diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..4a273c46 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,9 @@ +blank_issues_enabled: true +contact_links: + - name: Ask on Discord + url: https://discord.gg/7xgjhVAg6g + about: You can ask for help here! + - name: Want to contribute to step certificates? + url: https://github.com/smallstep/certificates/blob/master/docs/CONTRIBUTING.md + about: Be sure to read contributing guidelines! + diff --git a/.github/ISSUE_TEMPLATE/enhancement.md b/.github/ISSUE_TEMPLATE/enhancement.md index 28eec406..3a6ffc94 100644 --- a/.github/ISSUE_TEMPLATE/enhancement.md +++ b/.github/ISSUE_TEMPLATE/enhancement.md @@ -1,5 +1,5 @@ --- -name: Certificates Enhancement +name: Enhancement about: Suggest an enhancement to step certificates title: '' labels: enhancement, needs triage diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index f0011406..bf5b4bc0 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -1,14 +1,12 @@ -name: labeler +name: Pull Request Labeler on: - pull_request: - branches: - - master + pull_request_target jobs: label: runs-on: ubuntu-latest steps: - - uses: actions/labeler@v3 + - uses: actions/labeler@v3.0.2 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" - configuration-path: .github/needs-triage-labeler.yml + diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 819a470e..6da2aa27 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - go: [ '1.15', '1.16' ] + go: [ '1.15', '1.16', '1.17' ] outputs: is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} steps: @@ -62,8 +62,15 @@ jobs: needs: test runs-on: ubuntu-20.04 outputs: + debversion: ${{ steps.extract-tag.outputs.DEB_VERSION }} is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} steps: + - + name: Extract Tag Names + id: extract-tag + run: | + DEB_VERSION=$(echo ${GITHUB_REF#refs/tags/v} | sed 's/-/./') + echo "::set-output name=DEB_VERSION::${DEB_VERSION}" - name: Is Pre-release id: is_prerelease @@ -99,62 +106,71 @@ jobs: name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.16 - - - name: Run GoReleaser - uses: goreleaser/goreleaser-action@56f5b77f7fa4a8fe068bf22b732ec036cc9bc13f # v2.4.1 - with: - version: latest - args: release --rm-dist - env: - GITHUB_TOKEN: ${{ secrets.PAT }} - - release_deb: - name: Build & Upload Debian Package To Github - runs-on: ubuntu-20.04 - needs: create_release - steps: - - - name: Checkout - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Set up Go - uses: actions/setup-go@v2 - with: - go-version: '1.16' + go-version: 1.17 - name: APT Install id: aptInstall run: sudo apt-get -y install build-essential debhelper fakeroot - name: Build Debian package - id: build + id: make_debian run: | PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin make debian + # need to restore the git state otherwise goreleaser fails due to dirty state + git restore debian/changelog + git clean -fd - - name: Upload Debian Package - id: upload_deb + name: Install cosign + uses: sigstore/cosign-installer@v1.1.0 + with: + cosign-release: 'v1.1.0' + - + name: Write cosign key to disk + id: write_key + run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key" + - + name: Get Release Date + id: release_date run: | - tag_name="${GITHUB_REF##*/}" - hub release edit $(find ./.releases -type f -printf "-a %p ") -m "" "$tag_name" + RELEASE_DATE=$(date +"%y-%m-%d") + echo "::set-output name=RELEASE_DATE::${RELEASE_DATE}" + - + name: Run GoReleaser + uses: goreleaser/goreleaser-action@5a54d7e660bda43b405e8463261b3d25631ffe86 # v2.7.0 + with: + version: latest + args: release --rm-dist env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_TOKEN: ${{ secrets.PAT }} + COSIGN_PWD: ${{ secrets.COSIGN_PWD }} + DEB_VERSION: ${{ needs.create_release.outputs.debversion }} + RELEASE_DATE: ${{ steps.release_date.outputs.RELEASE_DATE }} build_upload_docker: name: Build & Upload Docker Images runs-on: ubuntu-20.04 needs: test steps: - - name: Checkout + - + name: Checkout uses: actions/checkout@v2 - - name: Setup Go + - + name: Setup Go uses: actions/setup-go@v2 with: - go-version: '1.16' - - name: Build + go-version: '1.17' + - + name: Install cosign + uses: sigstore/cosign-installer@v1.1.0 + with: + cosign-release: 'v1.1.0' + - + name: Write cosign key to disk + id: write_key + run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key" + - + name: Build id: build run: | PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin @@ -162,3 +178,4 @@ jobs: env: DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} + COSIGN_PWD: ${{ secrets.COSIGN_PWD }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9c73cfbd..96655664 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - go: [ '1.15', '1.16' ] + go: [ '1.15', '1.16', '1.17' ] steps: - name: Checkout diff --git a/.gitignore b/.gitignore index 7cba0d08..d87786b0 100644 --- a/.gitignore +++ b/.gitignore @@ -14,8 +14,8 @@ # Others *.swp -.travis-releases +.releases coverage.txt -vendor output +vendor .idea diff --git a/.golangci.yml b/.golangci.yml index 1bab3ba3..cf389517 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -8,7 +8,7 @@ linters-settings: - (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf - (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf - (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf - golint: + revive: min-confidence: 0 gocyclo: min-complexity: 10 @@ -36,22 +36,30 @@ linters-settings: - performance - style - experimental + - diagnostic disabled-checks: - - wrapperFunc - - dupImport # https://github.com/go-critic/go-critic/issues/845 + - commentFormatting + - commentedOutCode + - evalOrder + - hugeParam + - octalLiteral + - rangeValCopy + - tooManyResultsChecker + - unnamedResult linters: disable-all: true enable: - - gofmt - - golint - - govet - - misspell - - ineffassign - deadcode + - gocritic + - gofmt + - gosimple + - govet + - ineffassign + - misspell + - revive - staticcheck - unused - - gosimple run: skip-dirs: diff --git a/.goreleaser.yml b/.goreleaser.yml index 7a7e20d3..207c75bd 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -1,34 +1,27 @@ # This is an example .goreleaser.yml file with some sane defaults. # Make sure to check the documentation at http://goreleaser.com project_name: step-ca + before: hooks: # You may remove this if you don't use go modules. - go mod download + builds: - id: step-ca env: - CGO_ENABLED=0 - goos: - - linux - - darwin - - windows - goarch: - - amd64 - - arm - - arm64 - - 386 - goarm: - - 6 - - 7 - ignore: - - goos: windows - goarch: 386 - - goos: windows - goarm: 6 - - goos: windows - goarm: 7 + targets: + - darwin_amd64 + - darwin_arm64 + - freebsd_amd64 + - linux_386 + - linux_amd64 + - linux_arm64 + - linux_arm_6 + - linux_arm_7 + - windows_amd64 flags: - -trimpath main: ./cmd/step-ca/main.go @@ -39,25 +32,16 @@ builds: id: step-cloudkms-init env: - CGO_ENABLED=0 - goos: - - linux - - darwin - - windows - goarch: - - amd64 - - arm - - arm64 - - 386 - goarm: - - 6 - - 7 - ignore: - - goos: windows - goarch: 386 - - goos: windows - goarm: 6 - - goos: windows - goarm: 7 + targets: + - darwin_amd64 + - darwin_arm64 + - freebsd_amd64 + - linux_386 + - linux_amd64 + - linux_arm64 + - linux_arm_6 + - linux_arm_7 + - windows_amd64 flags: - -trimpath main: ./cmd/step-cloudkms-init/main.go @@ -68,31 +52,23 @@ builds: id: step-awskms-init env: - CGO_ENABLED=0 - goos: - - linux - - darwin - - windows - goarch: - - amd64 - - arm - - arm64 - - 386 - goarm: - - 6 - - 7 - ignore: - - goos: windows - goarch: 386 - - goos: windows - goarm: 6 - - goos: windows - goarm: 7 + targets: + - darwin_amd64 + - darwin_arm64 + - freebsd_amd64 + - linux_386 + - linux_amd64 + - linux_arm64 + - linux_arm_6 + - linux_arm_7 + - windows_amd64 flags: - -trimpath main: ./cmd/step-awskms-init/main.go binary: bin/step-awskms-init ldflags: - -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}} + archives: - # Can be used to change the archive formats for specific GOOSs. @@ -106,13 +82,25 @@ archives: files: - README.md - LICENSE + source: enabled: true name_template: '{{ .ProjectName }}_{{ .Version }}' + checksum: name_template: 'checksums.txt' + extra_files: + - glob: ./.releases/* + +signs: +- cmd: cosign + stdin: '{{ .Env.COSIGN_PWD }}' + args: ["sign-blob", "-key=/tmp/cosign.key", "-output=${signature}", "${artifact}"] + artifacts: all + snapshot: name_template: "{{ .Tag }}-next" + release: # Repo in which the release will be created. # Default is extracted from the origin remote URL or empty if its private hosted. @@ -139,7 +127,55 @@ release: # You can change the name of the release. # Default is `{{.Tag}}` - #name_template: "{{.ProjectName}}-v{{.Version}} {{.Env.USER}}" + name_template: "Step CA {{ .Tag }} ({{ .Env.RELEASE_DATE }})" + + # Header template for the release body. + # Defaults to empty. + header: | + ## Official Release Artifacts + + #### Linux + + - πŸ“¦ [step-ca_linux_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_linux_{{ .Version }}_amd64.tar.gz) + - πŸ“¦ [step-ca_{{ .Env.DEB_VERSION }}_amd64.deb](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ .Env.DEB_VERSION }}_amd64.deb) + + #### OSX Darwin + + - πŸ“¦ [step-ca_darwin_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_amd64.tar.gz) + - πŸ“¦ [step-ca_darwin_{{ .Version }}_arm64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_arm64.tar.gz) + + #### Windows + + - πŸ“¦ [step-ca_windows_{{ .Version }}_arm64.zip](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_windows_{{ .Version }}_amd64.zip) + + For more builds across platforms and architectures, see the `Assets` section below. + And for packaged versions (Docker, k8s, Homebrew), see our [installation docs](https://smallstep.com/docs/step-ca/installation). + + Don't see the artifact you need? Open an issue [here](https://github.com/smallstep/certificates/issues/new/choose). + + ## Signatures and Checksums + + `step-ca` uses [sigstore/cosign](https://github.com/sigstore/cosign) for signing and verifying release artifacts. + + Below is an example using `cosign` to verify a release artifact: + + ``` + cosign verify-blob \ + -key https://raw.githubusercontent.com/smallstep/certificates/master/cosign.pub \ + -signature ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig + ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz + ``` + + The `checksums.txt` file (in the `Assets` section below) contains a checksum for every artifact in the release. + + # Footer template for the release body. + # Defaults to empty. + footer: | + ## Thanks! + + Those were the changes on {{ .Tag }}! + + Come join us on [Discord](https://discord.gg/X2RKGwEbV9) to ask questions, chat about PKI, or get a sneak peak at the freshest PKI memes. # You can disable this pipe in order to not upload any artifacts. # Defaults to false. @@ -149,6 +185,8 @@ release: # The filename on the release will be the last part of the path (base). If # another file with the same name exists, the latest one found will be used. # Defaults to empty. + extra_files: + - glob: ./.releases/* #extra_files: # - glob: ./path/to/file.txt # - glob: ./glob/**/to/**/file/**/* diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a2b3e25..ca792f55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,10 +4,61 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). -## [Unreleased - 0.0.1] - DATE +## [Unreleased - 0.17.7] - DATE ### Added ### Changed ### Deprecated ### Removed ### Fixed ### Security + +## [0.17.6] - 2021-10-20 +### Notes +- 0.17.5 failed in CI/CD + +## [0.17.5] - 2021-10-20 +### Added +- Support for Azure Key Vault as a KMS. +- Adapt `pki` package to support key managers. +- gocritic linter +### Fixed +- gocritic warnings + +## [0.17.4] - 2021-09-28 +### Fixed +- Support host-only or user-only SSH CA. + +## [0.17.3] - 2021-09-24 +### Added +- go 1.17 to github action test matrix +- Support for CloudKMS RSA-PSS signers without using templates. +- Add flags to support individual passwords for the intermediate and SSH keys. +- Global support for group admins in the OIDC provisioner. +### Changed +- Using go 1.17 for binaries +### Fixed +- Upgrade go-jose.v2 to fix a bug in the JWK fingerprint of Ed25519 keys. +### Security +- Use cosign to sign and upload signatures for multi-arch Docker container. +- Add debian checksum + +## [0.17.2] - 2021-08-30 +### Added +- Additional way to distinguish Azure IID and Azure OIDC tokens. +### Security +- Sign over all goreleaser github artifacts using cosign + +## [0.17.1] - 2021-08-26 + +## [0.17.0] - 2021-08-25 +### Added +- Add support for Linked CAs using protocol buffers and gRPC +- `step-ca init` adds support for + - configuring a StepCAS RA + - configuring a Linked CA + - congifuring a `step-ca` using Helm +### Changed +- Update badger driver to use v2 by default +- Update TLS cipher suites to include 1.3 +### Security +- Fix key version when SHA512WithRSA is used. There was a typo creating RSA keys with SHA256 digests instead of SHA512. diff --git a/Makefile b/Makefile index 882a0122..09e342df 100644 --- a/Makefile +++ b/Makefile @@ -15,6 +15,7 @@ PREFIX?= SRC=$(shell find . -type f -name '*.go' -not -path "./vendor/*") GOOS_OVERRIDE ?= OUTPUT_ROOT=output/ +RELEASE=./.releases all: lint test build @@ -28,7 +29,7 @@ ci: testcgo build bootstra%: # Using a released version of golangci-lint to take into account custom replacements in their go.mod - $Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.39.0 + $Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v1.42.0 .PHONY: bootstra% @@ -67,7 +68,7 @@ PUSHTYPE := branch endif VERSION := $(shell echo $(VERSION) | sed 's/^v//') -DEB_VERSION := $(shell echo $(VERSION) | sed 's/-/~/g') +DEB_VERSION := $(shell echo $(VERSION) | sed 's/-/./g') ifdef V $(info TRAVIS_TAG is $(TRAVIS_TAG)) @@ -153,7 +154,7 @@ fmt: $Q gofmt -l -w $(SRC) lint: - $Q $(GOFLAGS) LOG_LEVEL=error golangci-lint run --timeout=30m + $Q golangci-lint run --timeout=30m lintcgo: $Q LOG_LEVEL=error golangci-lint run --timeout=30m diff --git a/README.md b/README.md index f0649175..65116b38 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,14 @@ You can use it to: Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [safe, sane defaults](https://smallstep.com/docs/step-ca/certificate-authority-server-production#sane-cryptographic-defaults). -**Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions).** +--- + +**Don't want to run your own CA?** +To get up and running quickly, or as an alternative to running your own `step-ca` server, consider creating a [free hosted smallstep Certificate Manager authority](https://info.smallstep.com/certificate-manager-early-access-mvp/). + +--- + +**Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions) or [Join our Discord](https://u.step.sm/discord).** [Website](https://smallstep.com/certificates) | [Documentation](https://smallstep.com/docs) | @@ -27,7 +34,6 @@ Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [ [Contributor's Guide](./docs/CONTRIBUTING.md) [![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest) -[![CA Image](https://images.microbadger.com/badges/image/smallstep/step-ca.svg)](https://microbadger.com/images/smallstep/step-ca) [![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates) [![Build Status](https://travis-ci.com/smallstep/certificates.svg?branch=master)](https://travis-ci.com/smallstep/certificates) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) @@ -58,10 +64,11 @@ You can issue certificates in exchange for: - ID tokens from Okta, GSuite, Azure AD, Auth0. - ID tokens from an OAuth OIDC service that you host, like [Keycloak](https://www.keycloak.org/) or [Dex](https://github.com/dexidp/dex) - [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure -- [Single-use, short-lived JWK tokens]() issued by your CD tool β€” Puppet, Chef, Ansible, Terraform, etc. +- [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool β€” Puppet, Chef, Ansible, Terraform, etc. - A trusted X.509 certificate (X5C provisioner) -- Expiring SSH host certificates needing rotation (the SSHPOP provisioner) -- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/configuration#jwk) +- A SCEP challenge (SCEP provisioner) +- An SSH host certificates needing renewal (the SSHPOP provisioner) +- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners) ### πŸ” Your own private ACME server @@ -74,16 +81,17 @@ ACME is the protocol used by Let's Encrypt to automate the issuance of HTTPS cer - For `tls-alpn-01`, respond to the challenge at the TLS layer ([as Caddy does](https://caddy.community/t/caddy-supports-the-acme-tls-alpn-challenge/4860)) to prove that you control the web server - Works with any ACME client. We've written examples for: - - [certbot](https://smallstep.com/blog/private-acme-server/#certbotuploadsacme-certbotpng-certbot-example) - - [acme.sh](https://smallstep.com/blog/private-acme-server/#acmeshuploadsacme-acme-shpng-acmesh-example) - - [Caddy](https://smallstep.com/blog/private-acme-server/#caddyuploadsacme-caddypng-caddy-example) - - [Traefik](https://smallstep.com/blog/private-acme-server/#traefikuploadsacme-traefikpng-traefik-example) - - [Apache](https://smallstep.com/blog/private-acme-server/#apacheuploadsacme-apachepng-apache-example) - - [nginx](https://smallstep.com/blog/private-acme-server/#nginxuploadsacme-nginxpng-nginx-example) + - [certbot](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#certbot) + - [acme.sh](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#acmesh) + - [win-acme](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#win-acme) + - [Caddy](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#caddy-v2) + - [Traefik](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#traefik) + - [Apache](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#apache) + - [nginx](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#nginx) - Get certificates programmatically using ACME, using these libraries: - - [`lego`](https://github.com/go-acme/lego) for Golang ([example usage](https://smallstep.com/blog/private-acme-server/#golanguploadsacme-golangpng-go-example)) - - certbot's [`acme` module](https://github.com/certbot/certbot/tree/master/acme) for Python ([example usage](https://smallstep.com/blog/private-acme-server/#pythonuploadsacme-pythonpng-python-example)) - - [`acme-client`](https://github.com/publishlab/node-acme-client) for Node.js ([example usage](https://smallstep.com/blog/private-acme-server/#nodejsuploadsacme-node-jspng-nodejs-example)) + - [`lego`](https://github.com/go-acme/lego) for Golang ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#golang)) + - certbot's [`acme` module](https://github.com/certbot/certbot/tree/master/acme) for Python ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#python)) + - [`acme-client`](https://github.com/publishlab/node-acme-client) for Node.js ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#node)) - Our own [`step` CLI tool](https://github.com/smallstep/cli) is also an ACME client! - See our [ACME tutorial](https://smallstep.com/docs/tutorials/acme-challenge) for more diff --git a/acme/api/account.go b/acme/api/account.go index b733c679..259cb2a2 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -19,7 +19,7 @@ type NewAccountRequest struct { func validateContacts(cs []string) error { for _, c := range cs { - if len(c) == 0 { + if c == "" { return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string") } } diff --git a/acme/api/account_test.go b/acme/api/account_test.go index c4d7a812..a45751a0 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -178,7 +178,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID) + u := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID) oids := []string{"foo", "bar"} oidURLs := []string{ @@ -255,7 +255,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.GetOrdersByAccountID(w, req) diff --git a/acme/api/handler.go b/acme/api/handler.go index 2a6d3a02..b05bd0c4 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -64,8 +64,14 @@ type HandlerOptions struct { // NewHandler returns a new ACME API handler. func NewHandler(ops HandlerOptions) api.RouterHandler { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } client := http.Client{ - Timeout: 30 * time.Second, + Timeout: 30 * time.Second, + Transport: transport, } dialer := &net.Dialer{ Timeout: 30 * time.Second, diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 5501479d..8112ad4c 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -148,7 +148,7 @@ func TestHandler_GetAuthorization(t *testing.T) { // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("authzID", az.ID) - url := fmt.Sprintf("%s/acme/%s/authz/%s", + u := fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, az.ID) type test struct { @@ -280,7 +280,7 @@ func TestHandler_GetAuthorization(t *testing.T) { expB, err := json.Marshal(az) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) - assert.Equals(t, res.Header["Location"], []string{url}) + assert.Equals(t, res.Header["Location"], []string{u}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) @@ -314,7 +314,7 @@ func TestHandler_GetCertificate(t *testing.T) { // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("certID", certID) - url := fmt.Sprintf("%s/acme/%s/certificate/%s", + u := fmt.Sprintf("%s/acme/%s/certificate/%s", baseURL.String(), provName, certID) type test struct { @@ -396,7 +396,7 @@ func TestHandler_GetCertificate(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{db: tc.db} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.GetCertificate(w, req) @@ -434,7 +434,7 @@ func TestHandler_GetChallenge(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s", + u := fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL.String(), provName, "authzID", "chID") type test struct { @@ -574,13 +574,13 @@ func TestHandler_GetChallenge(t *testing.T) { assert.Equals(t, azID, "authzID") return &acme.Challenge{ Status: acme.StatusPending, - Type: "http-01", + Type: acme.HTTP01, AccountID: "accID", }, nil }, MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error { assert.Equals(t, ch.Status, acme.StatusPending) - assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Type, acme.HTTP01) assert.Equals(t, ch.AccountID, "accID") assert.Equals(t, ch.AuthorizationID, "authzID") assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String()) @@ -616,13 +616,13 @@ func TestHandler_GetChallenge(t *testing.T) { return &acme.Challenge{ ID: "chID", Status: acme.StatusPending, - Type: "http-01", + Type: acme.HTTP01, AccountID: "accID", }, nil }, MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error { assert.Equals(t, ch.Status, acme.StatusPending) - assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Type, acme.HTTP01) assert.Equals(t, ch.AccountID, "accID") assert.Equals(t, ch.AuthorizationID, "authzID") assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String()) @@ -633,9 +633,9 @@ func TestHandler_GetChallenge(t *testing.T) { ID: "chID", Status: acme.StatusPending, AuthorizationID: "authzID", - Type: "http-01", + Type: acme.HTTP01, AccountID: "accID", - URL: url, + URL: u, Error: acme.NewError(acme.ErrorConnectionType, "force"), }, vco: &acme.ValidateChallengeOptions{ @@ -652,7 +652,7 @@ func TestHandler_GetChallenge(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.GetChallenge(w, req) @@ -678,7 +678,7 @@ func TestHandler_GetChallenge(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")}) - assert.Equals(t, res.Header["Location"], []string{url}) + assert.Equals(t, res.Header["Location"], []string{u}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 50f7146f..bc67dbc6 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -223,7 +223,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } - if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 { + if hdr.JSONWebKey == nil && hdr.KeyID == "" { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } @@ -288,13 +288,13 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - name := chi.URLParam(r, "provisionerID") - provID, err := url.PathUnescape(name) + nameEscaped := chi.URLParam(r, "provisionerID") + name, err := url.PathUnescape(nameEscaped) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner id '%s'", name)) + api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) return } - p, err := h.ca.LoadProvisionerByID("acme/" + provID) + p, err := h.ca.LoadProvisionerByName(name) if err != nil { api.WriteError(w, err) return @@ -367,7 +367,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { api.WriteError(w, err) return } - if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { + if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 40090e83..e8d22d53 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -108,7 +108,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) { } func TestHandler_addNonce(t *testing.T) { - url := "https://ca.smallstep.com/acme/new-nonce" + u := "https://ca.smallstep.com/acme/new-nonce" type test struct { db acme.DB err *acme.Error @@ -141,7 +141,7 @@ func TestHandler_addNonce(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{db: tc.db} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) w := httptest.NewRecorder() h.addNonce(testNext)(w, req) res := w.Result() @@ -230,7 +230,7 @@ func TestHandler_verifyContentType(t *testing.T) { prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) + u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) type test struct { h Handler ctx context.Context @@ -245,7 +245,7 @@ func TestHandler_verifyContentType(t *testing.T) { h: Handler{ linker: NewLinker("dns", "acme"), }, - url: url, + url: u, ctx: context.Background(), contentType: "foo", statusCode: 500, @@ -257,7 +257,7 @@ func TestHandler_verifyContentType(t *testing.T) { h: Handler{ linker: NewLinker("dns", "acme"), }, - url: url, + url: u, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "foo", statusCode: 400, @@ -319,11 +319,11 @@ func TestHandler_verifyContentType(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - _url := url + _u := u if tc.url != "" { - _url = tc.url + _u = tc.url } - req := httptest.NewRequest("GET", _url, nil) + req := httptest.NewRequest("GET", _u, nil) req = req.WithContext(tc.ctx) req.Header.Add("Content-Type", tc.contentType) w := httptest.NewRecorder() @@ -353,7 +353,7 @@ func TestHandler_verifyContentType(t *testing.T) { } func TestHandler_isPostAsGet(t *testing.T) { - url := "https://ca.smallstep.com/acme/new-account" + u := "https://ca.smallstep.com/acme/new-account" type test struct { ctx context.Context err *acme.Error @@ -392,7 +392,7 @@ func TestHandler_isPostAsGet(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.isPostAsGet(testNext)(w, req) @@ -430,7 +430,7 @@ func (errReader) Close() error { } func TestHandler_parseJWS(t *testing.T) { - url := "https://ca.smallstep.com/acme/new-account" + u := "https://ca.smallstep.com/acme/new-account" type test struct { next nextHTTP body io.Reader @@ -483,7 +483,7 @@ func TestHandler_parseJWS(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{} - req := httptest.NewRequest("GET", url, tc.body) + req := httptest.NewRequest("GET", u, tc.body) w := httptest.NewRecorder() h.parseJWS(tc.next)(w, req) res := w.Result() @@ -528,7 +528,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) { assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) - url := "https://ca.smallstep.com/acme/account/1234" + u := "https://ca.smallstep.com/acme/account/1234" type test struct { ctx context.Context next func(http.ResponseWriter, *http.Request) @@ -681,7 +681,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.verifyAndExtractJWSPayload(tc.next)(w, req) @@ -713,7 +713,7 @@ func TestHandler_lookupJWK(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/account/1234", + u := fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provName) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) @@ -883,7 +883,7 @@ func TestHandler_lookupJWK(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{db: tc.db, linker: tc.linker} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.lookupJWK(tc.next)(w, req) @@ -934,7 +934,7 @@ func TestHandler_extractJWK(t *testing.T) { assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) - url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", + u := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provName) type test struct { db acme.DB @@ -1079,7 +1079,7 @@ func TestHandler_extractJWK(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{db: tc.db} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.extractJWK(tc.next)(w, req) @@ -1108,7 +1108,7 @@ func TestHandler_extractJWK(t *testing.T) { } func TestHandler_validateJWS(t *testing.T) { - url := "https://ca.smallstep.com/acme/account/1234" + u := "https://ca.smallstep.com/acme/account/1234" type test struct { db acme.DB ctx context.Context @@ -1198,7 +1198,7 @@ func TestHandler_validateJWS(t *testing.T) { Algorithm: jose.RS256, JSONWebKey: &pub, ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": url, + "url": u, }, }, }, @@ -1226,7 +1226,7 @@ func TestHandler_validateJWS(t *testing.T) { Algorithm: jose.RS256, JSONWebKey: &pub, ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": url, + "url": u, }, }, }, @@ -1298,7 +1298,7 @@ func TestHandler_validateJWS(t *testing.T) { }, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", url), + err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", u), } }, "fail/both-jwk-kid": func(t *testing.T) test { @@ -1313,7 +1313,7 @@ func TestHandler_validateJWS(t *testing.T) { KeyID: "bar", JSONWebKey: &pub, ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": url, + "url": u, }, }, }, @@ -1337,7 +1337,7 @@ func TestHandler_validateJWS(t *testing.T) { Protected: jose.Header{ Algorithm: jose.ES256, ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": url, + "url": u, }, }, }, @@ -1362,7 +1362,7 @@ func TestHandler_validateJWS(t *testing.T) { Algorithm: jose.ES256, KeyID: "bar", ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": url, + "url": u, }, }, }, @@ -1392,7 +1392,7 @@ func TestHandler_validateJWS(t *testing.T) { Algorithm: jose.ES256, JSONWebKey: &pub, ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": url, + "url": u, }, }, }, @@ -1422,7 +1422,7 @@ func TestHandler_validateJWS(t *testing.T) { Algorithm: jose.RS256, JSONWebKey: &pub, ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": url, + "url": u, }, }, }, @@ -1446,7 +1446,7 @@ func TestHandler_validateJWS(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{db: tc.db} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.validateJWS(tc.next)(w, req) diff --git a/acme/api/order.go b/acme/api/order.go index 9d410173..9cf2c1eb 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/base64" "encoding/json" + "net" "net/http" "strings" "time" @@ -28,9 +29,12 @@ func (n *NewOrderRequest) Validate() error { return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty") } for _, id := range n.Identifiers { - if id.Type != "dns" { + if !(id.Type == acme.DNS || id.Type == acme.IP) { return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type) } + if id.Type == acme.IP && net.ParseIP(id.Value) == nil { + return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value) + } } return nil } @@ -85,6 +89,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { "failed to unmarshal new-order request payload")) return } + if err := nor.Validate(); err != nil { api.WriteError(w, err) return @@ -149,15 +154,9 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) } } - var ( - err error - chTypes = []string{"dns-01"} - ) - // HTTP and TLS challenges can only be used for identifiers without wildcards. - if !az.Wildcard { - chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...) - } + chTypes := challengeTypes(az) + var err error az.Token, err = randutil.Alphanumeric(32) if err != nil { return acme.WrapErrorISE(err, "error generating random alphanumeric ID") @@ -275,3 +274,24 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) api.JSON(w, o) } + +// challengeTypes determines the types of challenges that should be used +// for the ACME authorization request. +func challengeTypes(az *acme.Authorization) []acme.ChallengeType { + var chTypes []acme.ChallengeType + + switch az.Identifier.Type { + case acme.IP: + chTypes = []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01} + case acme.DNS: + chTypes = []acme.ChallengeType{acme.DNS01} + // HTTP and TLS challenges can only be used for identifiers without wildcards. + if !az.Wildcard { + chTypes = append(chTypes, []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}...) + } + default: + chTypes = []acme.ChallengeType{} + } + + return chTypes +} diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 300aa61b..3c6d768f 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -10,6 +10,7 @@ import ( "io/ioutil" "net/http/httptest" "net/url" + "reflect" "testing" "time" @@ -44,6 +45,22 @@ func TestNewOrderRequest_Validate(t *testing.T) { err: acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: foo"), } }, + "fail/bad-ip": func(t *testing.T) test { + nbf := time.Now().UTC().Add(time.Minute) + naf := time.Now().UTC().Add(5 * time.Minute) + return test{ + nor: &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "ip", Value: "192.168.42.1000"}, + }, + NotAfter: naf, + NotBefore: nbf, + }, + nbf: nbf, + naf: naf, + err: acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", "192.168.42.1000"), + } + }, "ok": func(t *testing.T) test { nbf := time.Now().UTC().Add(time.Minute) naf := time.Now().UTC().Add(5 * time.Minute) @@ -60,6 +77,68 @@ func TestNewOrderRequest_Validate(t *testing.T) { naf: naf, } }, + "ok/ipv4": func(t *testing.T) test { + nbf := time.Now().UTC().Add(time.Minute) + naf := time.Now().UTC().Add(5 * time.Minute) + return test{ + nor: &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "ip", Value: "192.168.42.42"}, + }, + NotAfter: naf, + NotBefore: nbf, + }, + nbf: nbf, + naf: naf, + } + }, + "ok/ipv6": func(t *testing.T) test { + nbf := time.Now().UTC().Add(time.Minute) + naf := time.Now().UTC().Add(5 * time.Minute) + return test{ + nor: &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "ip", Value: "2001:db8::1"}, + }, + NotAfter: naf, + NotBefore: nbf, + }, + nbf: nbf, + naf: naf, + } + }, + "ok/mixed-dns-and-ipv4": func(t *testing.T) test { + nbf := time.Now().UTC().Add(time.Minute) + naf := time.Now().UTC().Add(5 * time.Minute) + return test{ + nor: &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "example.com"}, + {Type: "ip", Value: "192.168.42.42"}, + }, + NotAfter: naf, + NotBefore: nbf, + }, + nbf: nbf, + naf: naf, + } + }, + "ok/mixed-ipv4-and-ipv6": func(t *testing.T) test { + nbf := time.Now().UTC().Add(time.Minute) + naf := time.Now().UTC().Add(5 * time.Minute) + return test{ + nor: &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "ip", Value: "192.168.42.42"}, + {Type: "ip", Value: "2001:db8::1"}, + }, + NotAfter: naf, + NotBefore: nbf, + }, + nbf: nbf, + naf: naf, + } + }, } for name, run := range tests { tc := run(t) @@ -185,7 +264,7 @@ func TestHandler_GetOrder(t *testing.T) { // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", o.ID) - url := fmt.Sprintf("%s/acme/%s/order/%s", + u := fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), escProvName, o.ID) type test struct { @@ -343,7 +422,7 @@ func TestHandler_GetOrder(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.GetOrder(w, req) @@ -369,7 +448,7 @@ func TestHandler_GetOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) - assert.Equals(t, res.Header["Location"], []string{url}) + assert.Equals(t, res.Header["Location"], []string{u}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) @@ -395,7 +474,7 @@ func TestHandler_newAuthorization(t *testing.T) { db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { assert.Equals(t, ch.AccountID, az.AccountID) - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) assert.Equals(t, ch.Token, az.Token) assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, az.Identifier.Value) @@ -424,15 +503,15 @@ func TestHandler_newAuthorization(t *testing.T) { switch count { case 0: ch.ID = "dns" - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" - assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" - assert.Equals(t, ch.Type, "tls-alpn-01") + assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) @@ -478,15 +557,15 @@ func TestHandler_newAuthorization(t *testing.T) { switch count { case 0: ch.ID = "dns" - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" - assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" - assert.Equals(t, ch.Type, "tls-alpn-01") + assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) @@ -528,7 +607,7 @@ func TestHandler_newAuthorization(t *testing.T) { db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { ch.ID = "dns" - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) assert.Equals(t, ch.AccountID, az.AccountID) assert.Equals(t, ch.Token, az.Token) assert.Equals(t, ch.Status, acme.StatusPending) @@ -584,7 +663,7 @@ func TestHandler_NewOrder(t *testing.T) { prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/order/ordID", + u := fmt.Sprintf("%s/acme/%s/order/ordID", baseURL.String(), escProvName) type test struct { @@ -695,7 +774,7 @@ func TestHandler_NewOrder(t *testing.T) { db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { assert.Equals(t, ch.AccountID, "accID") - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "zap.internal") @@ -730,15 +809,15 @@ func TestHandler_NewOrder(t *testing.T) { switch count { case 0: ch.ID = "dns" - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" - assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" - assert.Equals(t, ch.Type, "tls-alpn-01") + assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) @@ -802,22 +881,22 @@ func TestHandler_NewOrder(t *testing.T) { switch chCount { case 0: ch.ID = "dns" - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) assert.Equals(t, ch.Value, "zap.internal") ch1 = &ch case 1: ch.ID = "http" - assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Type, acme.HTTP01) assert.Equals(t, ch.Value, "zap.internal") ch2 = &ch case 2: ch.ID = "tls" - assert.Equals(t, ch.Type, "tls-alpn-01") + assert.Equals(t, ch.Type, acme.TLSALPN01) assert.Equals(t, ch.Value, "zap.internal") ch3 = &ch case 3: ch.ID = "dns" - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) assert.Equals(t, ch.Value, "zar.internal") ch4 = &ch default: @@ -842,7 +921,7 @@ func TestHandler_NewOrder(t *testing.T) { az.ID = "az2ID" az2ID = &az.ID assert.Equals(t, az.Identifier, acme.Identifier{ - Type: "dns", + Type: acme.DNS, Value: "zar.internal", }) assert.Equals(t, az.Wildcard, true) @@ -917,15 +996,15 @@ func TestHandler_NewOrder(t *testing.T) { switch count { case 0: ch.ID = "dns" - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" - assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" - assert.Equals(t, ch.Type, "tls-alpn-01") + assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) @@ -1009,15 +1088,15 @@ func TestHandler_NewOrder(t *testing.T) { switch count { case 0: ch.ID = "dns" - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" - assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" - assert.Equals(t, ch.Type, "tls-alpn-01") + assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) @@ -1100,15 +1179,15 @@ func TestHandler_NewOrder(t *testing.T) { switch count { case 0: ch.ID = "dns" - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" - assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" - assert.Equals(t, ch.Type, "tls-alpn-01") + assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) @@ -1192,15 +1271,15 @@ func TestHandler_NewOrder(t *testing.T) { switch count { case 0: ch.ID = "dns" - assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" - assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" - assert.Equals(t, ch.Type, "tls-alpn-01") + assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) @@ -1256,7 +1335,7 @@ func TestHandler_NewOrder(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.NewOrder(w, req) @@ -1284,7 +1363,7 @@ func TestHandler_NewOrder(t *testing.T) { tc.vr(t, ro) } - assert.Equals(t, res.Header["Location"], []string{url}) + assert.Equals(t, res.Header["Location"], []string{u}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) @@ -1327,7 +1406,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", o.ID) - url := fmt.Sprintf("%s/acme/%s/order/%s", + u := fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), escProvName, o.ID) _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") @@ -1546,7 +1625,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.FinalizeOrder(w, req) @@ -1575,9 +1654,58 @@ func TestHandler_FinalizeOrder(t *testing.T) { assert.FatalError(t, json.Unmarshal(body, ro)) assert.Equals(t, bytes.TrimSpace(body), expB) - assert.Equals(t, res.Header["Location"], []string{url}) + assert.Equals(t, res.Header["Location"], []string{u}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } + +func TestHandler_challengeTypes(t *testing.T) { + type args struct { + az *acme.Authorization + } + tests := []struct { + name string + args args + want []acme.ChallengeType + }{ + { + name: "ok/dns", + args: args{ + az: &acme.Authorization{ + Identifier: acme.Identifier{Type: "dns", Value: "example.com"}, + Wildcard: false, + }, + }, + want: []acme.ChallengeType{acme.DNS01, acme.HTTP01, acme.TLSALPN01}, + }, + { + name: "ok/wildcard", + args: args{ + az: &acme.Authorization{ + Identifier: acme.Identifier{Type: "dns", Value: "*.example.com"}, + Wildcard: true, + }, + }, + want: []acme.ChallengeType{acme.DNS01}, + }, + { + name: "ok/ip", + args: args{ + az: &acme.Authorization{ + Identifier: acme.Identifier{Type: "ip", Value: "192.168.42.42"}, + Wildcard: false, + }, + }, + want: []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := challengeTypes(tt.args.az); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Handler.challengeTypes() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/acme/challenge.go b/acme/challenge.go index 1059e437..b880708c 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -10,29 +10,39 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "errors" "fmt" "io/ioutil" "net" "net/http" "net/url" + "reflect" "strings" "time" "go.step.sm/crypto/jose" ) +type ChallengeType string + +const ( + HTTP01 ChallengeType = "http-01" + DNS01 ChallengeType = "dns-01" + TLSALPN01 ChallengeType = "tls-alpn-01" +) + // Challenge represents an ACME response Challenge type. type Challenge struct { - ID string `json:"-"` - AccountID string `json:"-"` - AuthorizationID string `json:"-"` - Value string `json:"-"` - Type string `json:"type"` - Status Status `json:"status"` - Token string `json:"token"` - ValidatedAt string `json:"validated,omitempty"` - URL string `json:"url"` - Error *Error `json:"error,omitempty"` + ID string `json:"-"` + AccountID string `json:"-"` + AuthorizationID string `json:"-"` + Value string `json:"-"` + Type ChallengeType `json:"type"` + Status Status `json:"status"` + Token string `json:"token"` + ValidatedAt string `json:"validated,omitempty"` + URL string `json:"url"` + Error *Error `json:"error,omitempty"` } // ToLog enables response logging. @@ -54,11 +64,11 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, return nil } switch ch.Type { - case "http-01": + case HTTP01: return http01Validate(ctx, ch, db, jwk, vo) - case "dns-01": + case DNS01: return dns01Validate(ctx, ch, db, jwk, vo) - case "tls-alpn-01": + case TLSALPN01: return tlsalpn01Validate(ctx, ch, db, jwk, vo) default: return NewErrorISE("unexpected challenge type '%s'", ch.Type) @@ -66,23 +76,23 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, } func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { - url := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} + u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} - resp, err := vo.HTTPGet(url.String()) + resp, err := vo.HTTPGet(u.String()) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, - "error doing http GET for url %s", url)) + "error doing http GET for url %s", u)) } defer resp.Body.Close() if resp.StatusCode >= 400 { return storeError(ctx, db, ch, false, NewError(ErrorConnectionType, - "error doing http GET for url %s with status code %d", url, resp.StatusCode)) + "error doing http GET for url %s with status code %d", u, resp.StatusCode)) } body, err := ioutil.ReadAll(resp.Body) if err != nil { return WrapErrorISE(err, "error reading "+ - "response body for url %s", url) + "response body for url %s", u) } keyAuth := strings.TrimSpace(string(body)) @@ -106,6 +116,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb return nil } +func tlsAlert(err error) uint8 { + var opErr *net.OpError + if errors.As(err, &opErr) { + v := reflect.ValueOf(opErr.Err) + if v.Kind() == reflect.Uint8 { + return uint8(v.Uint()) + } + } + return 0 +} + func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { config := &tls.Config{ NextProtos: []string{"acme-tls/1"}, @@ -113,7 +134,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON // ACME servers that implement "acme-tls/1" MUST only negotiate TLS 1.2 // [RFC5246] or higher when connecting to clients for validation. MinVersion: tls.VersionTLS12, - ServerName: ch.Value, + ServerName: serverName(ch), InsecureSkipVerify: true, // we expect a self-signed challenge certificate } @@ -121,6 +142,14 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON conn, err := vo.TLSDial("tcp", hostPort, config) if err != nil { + // With Go 1.17+ tls.Dial fails if there's no overlap between configured + // client and server protocols. When this happens the connection is + // closed with the error no_application_protocol(120) as required by + // RFC7301. See https://golang.org/doc/go1.17#ALPN + if tlsAlert(err) == 120 { + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) + } return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, "error doing TLS dial for %s", hostPort)) } @@ -141,9 +170,17 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON leafCert := certs[0] - if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) { - return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, - "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)) + // if no DNS names present, look for IP address and verify that exactly one exists + if len(leafCert.DNSNames) == 0 { + if len(leafCert.IPAddresses) != 1 || !leafCert.IPAddresses[0].Equal(net.ParseIP(ch.Value)) { + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value)) + } + } else { + if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) { + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value)) + } } idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} @@ -244,6 +281,65 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK return nil } +// serverName determines the SNI HostName to set based on an acme.Challenge +// for TLS-ALPN-01 challenges RFC8738 states that, if HostName is an IP, it +// should be the ARPA address https://datatracker.ietf.org/doc/html/rfc8738#section-6. +// It also references TLS Extensions [RFC6066]. +func serverName(ch *Challenge) string { + var serverName string + ip := net.ParseIP(ch.Value) + if ip != nil { + serverName = reverseAddr(ip) + } else { + serverName = ch.Value + } + return serverName +} + +// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP +// address addr suitable for rDNS (PTR) record lookup or an error if it fails +// to parse the IP address. +// Implementation taken and adapted from https://golang.org/src/net/dnsclient.go?s=780:834#L20 +func reverseAddr(ip net.IP) (arpa string) { + if ip.To4() != nil { + return uitoa(uint(ip[15])) + "." + uitoa(uint(ip[14])) + "." + uitoa(uint(ip[13])) + "." + uitoa(uint(ip[12])) + ".in-addr.arpa." + } + // Must be IPv6 + buf := make([]byte, 0, len(ip)*4+len("ip6.arpa.")) + // Add it, in reverse, to the buffer + for i := len(ip) - 1; i >= 0; i-- { + v := ip[i] + buf = append(buf, hexit[v&0xF], + '.', + hexit[v>>4], + '.') + } + // Append "ip6.arpa." and return (buf already has the final .) + buf = append(buf, "ip6.arpa."...) + return string(buf) +} + +// Convert unsigned integer to decimal string. +// Implementation taken from https://golang.org/src/net/parse.go +func uitoa(val uint) string { + if val == 0 { // avoid string allocation + return "0" + } + var buf [20]byte // big enough for 64bit value base 10 + i := len(buf) - 1 + for val >= 10 { + q := val / 10 + buf[i] = byte('0' + val - q*10) + i-- + val = q + } + // val < 10 + buf[i] = byte('0' + val) + return string(buf[i:]) +} + +const hexit = "0123456789abcdef" + // KeyAuthorization creates the ACME key authorization value from a token // and a jwk. func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 14287945..a522790f 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -1276,7 +1276,7 @@ func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, na oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} } - keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash[:]) + keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash) certTemplate.ExtraExtensions = []pkix.Extension{ { @@ -1395,7 +1395,7 @@ func TestTLSALPN01Validate(t *testing.T) { assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) - err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.Value) + err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443:", ch.Value) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) assert.Equals(t, updch.Error.Type, err.Type) @@ -1544,7 +1544,7 @@ func TestTLSALPN01Validate(t *testing.T) { err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok/no-names-error": func(t *testing.T) test { + "ok/no-names-nor-ips-error": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1573,7 +1573,7 @@ func TestTLSALPN01Validate(t *testing.T) { assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) - err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) assert.Equals(t, updch.Error.Type, err.Type) @@ -1616,7 +1616,7 @@ func TestTLSALPN01Validate(t *testing.T) { assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) - err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) assert.Equals(t, updch.Error.Type, err.Type) @@ -1660,7 +1660,7 @@ func TestTLSALPN01Validate(t *testing.T) { assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) - err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) assert.Equals(t, updch.Error.Type, err.Type) @@ -1703,7 +1703,7 @@ func TestTLSALPN01Validate(t *testing.T) { assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Value, ch.Value) - err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) assert.Equals(t, updch.Error.Type, err.Type) @@ -2187,6 +2187,43 @@ func TestTLSALPN01Validate(t *testing.T) { srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Error, nil) + return nil + }, + }, + srv: srv, + jwk: jwk, + } + }, + "ok/ip": func(t *testing.T) test { + ch := makeTLSCh() + ch.Value = "127.0.0.1" + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + return test{ ch: ch, vo: &ValidateChallengeOptions{ @@ -2235,3 +2272,82 @@ func TestTLSALPN01Validate(t *testing.T) { }) } } + +func Test_reverseAddr(t *testing.T) { + type args struct { + ip net.IP + } + tests := []struct { + name string + args args + wantArpa string + }{ + { + name: "ok/ipv4", + args: args{ + ip: net.ParseIP("127.0.0.1"), + }, + wantArpa: "1.0.0.127.in-addr.arpa.", + }, + { + name: "ok/ipv6", + args: args{ + ip: net.ParseIP("2001:db8::567:89ab"), + }, + wantArpa: "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotArpa := reverseAddr(tt.args.ip); gotArpa != tt.wantArpa { + t.Errorf("reverseAddr() = %v, want %v", gotArpa, tt.wantArpa) + } + }) + } +} + +func Test_serverName(t *testing.T) { + type args struct { + ch *Challenge + } + tests := []struct { + name string + args args + want string + }{ + { + name: "ok/dns", + args: args{ + ch: &Challenge{ + Value: "example.com", + }, + }, + want: "example.com", + }, + { + name: "ok/ipv4", + args: args{ + ch: &Challenge{ + Value: "127.0.0.1", + }, + }, + want: "1.0.0.127.in-addr.arpa.", + }, + { + name: "ok/ipv6", + args: args{ + ch: &Challenge{ + Value: "2001:db8::567:89ab", + }, + }, + want: "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := serverName(tt.args.ch); got != tt.want { + t.Errorf("serverName() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/acme/common.go b/acme/common.go index 26552c61..f18907fe 100644 --- a/acme/common.go +++ b/acme/common.go @@ -11,7 +11,7 @@ import ( // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - LoadProvisionerByID(string) (provisioner.Interface, error) + LoadProvisionerByName(string) (provisioner.Interface, error) } // Clock that returns time in UTC rounded to seconds. diff --git a/acme/db/nosql/account_test.go b/acme/db/nosql/account_test.go index 5ba99a73..a02e93dc 100644 --- a/acme/db/nosql/account_test.go +++ b/acme/db/nosql/account_test.go @@ -93,8 +93,8 @@ func TestDB_getDBAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil { + d := DB{db: tc.db} + if dbacc, err := d.getDBAccount(context.Background(), accID); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { @@ -109,15 +109,13 @@ func TestDB_getDBAccount(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, dbacc.ID, tc.dbacc.ID) - assert.Equals(t, dbacc.Status, tc.dbacc.Status) - assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt) - assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt) - assert.Equals(t, dbacc.Contact, tc.dbacc.Contact) - assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, dbacc.ID, tc.dbacc.ID) + assert.Equals(t, dbacc.Status, tc.dbacc.Status) + assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt) + assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt) + assert.Equals(t, dbacc.Contact, tc.dbacc.Contact) + assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID) } }) } @@ -174,8 +172,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil { + d := DB{db: tc.db} + if retAccID, err := d.getAccountIDByKeyID(context.Background(), kid); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { @@ -190,10 +188,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, retAccID, accID) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, retAccID, accID) } }) } @@ -250,8 +246,8 @@ func TestDB_GetAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if acc, err := db.GetAccount(context.Background(), accID); err != nil { + d := DB{db: tc.db} + if acc, err := d.GetAccount(context.Background(), accID); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { @@ -266,13 +262,11 @@ func TestDB_GetAccount(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acc.ID, tc.dbacc.ID) - assert.Equals(t, acc.Status, tc.dbacc.Status) - assert.Equals(t, acc.Contact, tc.dbacc.Contact) - assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, acc.ID, tc.dbacc.ID) + assert.Equals(t, acc.Status, tc.dbacc.Status) + assert.Equals(t, acc.Contact, tc.dbacc.Contact) + assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) } }) } @@ -358,8 +352,8 @@ func TestDB_GetAccountByKeyID(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil { + d := DB{db: tc.db} + if acc, err := d.GetAccountByKeyID(context.Background(), kid); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { @@ -374,13 +368,11 @@ func TestDB_GetAccountByKeyID(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acc.ID, tc.dbacc.ID) - assert.Equals(t, acc.Status, tc.dbacc.Status) - assert.Equals(t, acc.Contact, tc.dbacc.Contact) - assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, acc.ID, tc.dbacc.ID) + assert.Equals(t, acc.Status, tc.dbacc.Status) + assert.Equals(t, acc.Contact, tc.dbacc.Contact) + assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) } }) } @@ -527,8 +519,8 @@ func TestDB_CreateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if err := db.CreateAccount(context.Background(), tc.acc); err != nil { + d := DB{db: tc.db} + if err := d.CreateAccount(context.Background(), tc.acc); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -688,8 +680,8 @@ func TestDB_UpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if err := db.UpdateAccount(context.Background(), tc.acc); err != nil { + d := DB{db: tc.db} + if err := d.UpdateAccount(context.Background(), tc.acc); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/acme/db/nosql/authz_test.go b/acme/db/nosql/authz_test.go index 0c2cec50..01c255dc 100644 --- a/acme/db/nosql/authz_test.go +++ b/acme/db/nosql/authz_test.go @@ -97,8 +97,8 @@ func TestDB_getDBAuthz(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil { + d := DB{db: tc.db} + if dbaz, err := d.getDBAuthz(context.Background(), azID); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { @@ -113,18 +113,16 @@ func TestDB_getDBAuthz(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, dbaz.ID, tc.dbaz.ID) - assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID) - assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier) - assert.Equals(t, dbaz.Status, tc.dbaz.Status) - assert.Equals(t, dbaz.Token, tc.dbaz.Token) - assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt) - assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt) - assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error()) - assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, dbaz.ID, tc.dbaz.ID) + assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID) + assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier) + assert.Equals(t, dbaz.Status, tc.dbaz.Status) + assert.Equals(t, dbaz.Token, tc.dbaz.Token) + assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt) + assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt) + assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error()) + assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard) } }) } @@ -293,8 +291,8 @@ func TestDB_GetAuthorization(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if az, err := db.GetAuthorization(context.Background(), azID); err != nil { + d := DB{db: tc.db} + if az, err := d.GetAuthorization(context.Background(), azID); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { @@ -309,21 +307,19 @@ func TestDB_GetAuthorization(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, az.ID, tc.dbaz.ID) - assert.Equals(t, az.AccountID, tc.dbaz.AccountID) - assert.Equals(t, az.Identifier, tc.dbaz.Identifier) - assert.Equals(t, az.Status, tc.dbaz.Status) - assert.Equals(t, az.Token, tc.dbaz.Token) - assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard) - assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt) - assert.Equals(t, az.Challenges, []*acme.Challenge{ - {ID: "foo"}, - {ID: "bar"}, - }) - assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error()) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, az.ID, tc.dbaz.ID) + assert.Equals(t, az.AccountID, tc.dbaz.AccountID) + assert.Equals(t, az.Identifier, tc.dbaz.Identifier) + assert.Equals(t, az.Status, tc.dbaz.Status) + assert.Equals(t, az.Token, tc.dbaz.Token) + assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard) + assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt) + assert.Equals(t, az.Challenges, []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }) + assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error()) } }) } @@ -445,8 +441,8 @@ func TestDB_CreateAuthorization(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if err := db.CreateAuthorization(context.Background(), tc.az); err != nil { + d := DB{db: tc.db} + if err := d.CreateAuthorization(context.Background(), tc.az); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -594,8 +590,8 @@ func TestDB_UpdateAuthorization(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil { + d := DB{db: tc.db} + if err := d.UpdateAuthorization(context.Background(), tc.az); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/acme/db/nosql/certificate_test.go b/acme/db/nosql/certificate_test.go index 4ec4589e..37a61352 100644 --- a/acme/db/nosql/certificate_test.go +++ b/acme/db/nosql/certificate_test.go @@ -98,8 +98,8 @@ func TestDB_CreateCertificate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if err := db.CreateCertificate(context.Background(), tc.cert); err != nil { + d := DB{db: tc.db} + if err := d.CreateCertificate(context.Background(), tc.cert); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -228,8 +228,8 @@ func TestDB_GetCertificate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - cert, err := db.GetCertificate(context.Background(), certID) + d := DB{db: tc.db} + cert, err := d.GetCertificate(context.Background(), certID) if err != nil { switch k := err.(type) { case *acme.Error: @@ -245,14 +245,12 @@ func TestDB_GetCertificate(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, cert.ID, certID) - assert.Equals(t, cert.AccountID, "accountID") - assert.Equals(t, cert.OrderID, "orderID") - assert.Equals(t, cert.Leaf, leaf) - assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, cert.ID, certID) + assert.Equals(t, cert.AccountID, "accountID") + assert.Equals(t, cert.OrderID, "orderID") + assert.Equals(t, cert.Leaf, leaf) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) } }) } diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go index f3a3cfca..f84a6f4e 100644 --- a/acme/db/nosql/challenge.go +++ b/acme/db/nosql/challenge.go @@ -11,15 +11,15 @@ import ( ) type dbChallenge struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Type string `json:"type"` - Status acme.Status `json:"status"` - Token string `json:"token"` - Value string `json:"value"` - ValidatedAt string `json:"validatedAt"` - CreatedAt time.Time `json:"createdAt"` - Error *acme.Error `json:"error"` + ID string `json:"id"` + AccountID string `json:"accountID"` + Type acme.ChallengeType `json:"type"` + Status acme.Status `json:"status"` + Token string `json:"token"` + Value string `json:"value"` + ValidatedAt string `json:"validatedAt"` + CreatedAt time.Time `json:"createdAt"` + Error *acme.Error `json:"error"` } func (dbc *dbChallenge) clone() *dbChallenge { diff --git a/acme/db/nosql/challenge_test.go b/acme/db/nosql/challenge_test.go index b39395e8..4da5679b 100644 --- a/acme/db/nosql/challenge_test.go +++ b/acme/db/nosql/challenge_test.go @@ -92,8 +92,8 @@ func TestDB_getDBChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if ch, err := db.getDBChallenge(context.Background(), chID); err != nil { + d := DB{db: tc.db} + if ch, err := d.getDBChallenge(context.Background(), chID); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { @@ -108,17 +108,15 @@ func TestDB_getDBChallenge(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.ID, tc.dbc.ID) - assert.Equals(t, ch.AccountID, tc.dbc.AccountID) - assert.Equals(t, ch.Type, tc.dbc.Type) - assert.Equals(t, ch.Status, tc.dbc.Status) - assert.Equals(t, ch.Token, tc.dbc.Token) - assert.Equals(t, ch.Value, tc.dbc.Value) - assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) - assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error()) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, ch.ID, tc.dbc.ID) + assert.Equals(t, ch.AccountID, tc.dbc.AccountID) + assert.Equals(t, ch.Type, tc.dbc.Type) + assert.Equals(t, ch.Status, tc.dbc.Status) + assert.Equals(t, ch.Token, tc.dbc.Token) + assert.Equals(t, ch.Value, tc.dbc.Value) + assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) + assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error()) } }) } @@ -206,8 +204,8 @@ func TestDB_CreateChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if err := db.CreateChallenge(context.Background(), tc.ch); err != nil { + d := DB{db: tc.db} + if err := d.CreateChallenge(context.Background(), tc.ch); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -286,8 +284,8 @@ func TestDB_GetChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil { + d := DB{db: tc.db} + if ch, err := d.GetChallenge(context.Background(), chID, azID); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { @@ -302,17 +300,15 @@ func TestDB_GetChallenge(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.ID, tc.dbc.ID) - assert.Equals(t, ch.AccountID, tc.dbc.AccountID) - assert.Equals(t, ch.Type, tc.dbc.Type) - assert.Equals(t, ch.Status, tc.dbc.Status) - assert.Equals(t, ch.Token, tc.dbc.Token) - assert.Equals(t, ch.Value, tc.dbc.Value) - assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) - assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error()) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, ch.ID, tc.dbc.ID) + assert.Equals(t, ch.AccountID, tc.dbc.AccountID) + assert.Equals(t, ch.Type, tc.dbc.Type) + assert.Equals(t, ch.Status, tc.dbc.Status) + assert.Equals(t, ch.Token, tc.dbc.Token) + assert.Equals(t, ch.Value, tc.dbc.Value) + assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) + assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error()) } }) } @@ -442,8 +438,8 @@ func TestDB_UpdateChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil { + d := DB{db: tc.db} + if err := d.UpdateChallenge(context.Background(), tc.ch); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/acme/db/nosql/nonce.go b/acme/db/nosql/nonce.go index 9badae87..e438c9ed 100644 --- a/acme/db/nosql/nonce.go +++ b/acme/db/nosql/nonce.go @@ -31,7 +31,7 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) { ID: id, CreatedAt: clock.Now(), } - if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil { + if err := db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil { return "", err } return acme.Nonce(id), nil diff --git a/acme/db/nosql/nonce_test.go b/acme/db/nosql/nonce_test.go index 05d73d52..7dc5cc91 100644 --- a/acme/db/nosql/nonce_test.go +++ b/acme/db/nosql/nonce_test.go @@ -67,8 +67,8 @@ func TestDB_CreateNonce(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if n, err := db.CreateNonce(context.Background()); err != nil { + d := DB{db: tc.db} + if n, err := d.CreateNonce(context.Background()); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -144,8 +144,8 @@ func TestDB_DeleteNonce(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil { + d := DB{db: tc.db} + if err := d.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go index 052f5729..b1547373 100644 --- a/acme/db/nosql/nosql.go +++ b/acme/db/nosql/nosql.go @@ -41,7 +41,7 @@ func New(db nosqlDB.DB) (*DB, error) { // save writes the new data to the database, overwriting the old data if it // existed. -func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { +func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error { var ( err error newB []byte diff --git a/acme/db/nosql/nosql_test.go b/acme/db/nosql/nosql_test.go index 4396acc8..d9c0b484 100644 --- a/acme/db/nosql/nosql_test.go +++ b/acme/db/nosql/nosql_test.go @@ -126,8 +126,8 @@ func TestDB_save(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - db := &DB{db: tc.db} - if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil { + d := &DB{db: tc.db} + if err := d.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index ba3934af..0c6bf795 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -124,10 +124,8 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st ordersByAccountMux.Lock() defer ordersByAccountMux.Unlock() + var oldOids []string b, err := db.db.Get(ordersByAccountIDTable, []byte(accID)) - var ( - oldOids []string - ) if err != nil { if !nosql.IsErrNotFound(err) { return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID) diff --git a/acme/db/nosql/order_test.go b/acme/db/nosql/order_test.go index 7248700f..e92eb684 100644 --- a/acme/db/nosql/order_test.go +++ b/acme/db/nosql/order_test.go @@ -12,7 +12,7 @@ import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" - nosqldb "github.com/smallstep/nosql/database" + "github.com/smallstep/nosql/database" ) func TestDB_getDBOrder(t *testing.T) { @@ -31,7 +31,7 @@ func TestDB_getDBOrder(t *testing.T) { assert.Equals(t, bucket, orderTable) assert.Equals(t, string(key), orderID) - return nil, nosqldb.ErrNotFound + return nil, database.ErrNotFound }, }, acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"), @@ -100,8 +100,8 @@ func TestDB_getDBOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if dbo, err := db.getDBOrder(context.Background(), orderID); err != nil { + d := DB{db: tc.db} + if dbo, err := d.getDBOrder(context.Background(), orderID); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { @@ -116,20 +116,18 @@ func TestDB_getDBOrder(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, dbo.ID, tc.dbo.ID) - assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID) - assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID) - assert.Equals(t, dbo.Status, tc.dbo.Status) - assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt) - assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt) - assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore) - assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter) - assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers) - assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs) - assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error()) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, dbo.ID, tc.dbo.ID) + assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID) + assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID) + assert.Equals(t, dbo.Status, tc.dbo.Status) + assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt) + assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt) + assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore) + assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter) + assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers) + assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs) + assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error()) } }) } @@ -164,7 +162,7 @@ func TestDB_GetOrder(t *testing.T) { assert.Equals(t, bucket, orderTable) assert.Equals(t, string(key), orderID) - return nil, nosqldb.ErrNotFound + return nil, database.ErrNotFound }, }, acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"), @@ -206,8 +204,8 @@ func TestDB_GetOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if o, err := db.GetOrder(context.Background(), orderID); err != nil { + d := DB{db: tc.db} + if o, err := d.GetOrder(context.Background(), orderID); err != nil { switch k := err.(type) { case *acme.Error: if assert.NotNil(t, tc.acmeErr) { @@ -222,20 +220,18 @@ func TestDB_GetOrder(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, o.ID, tc.dbo.ID) - assert.Equals(t, o.AccountID, tc.dbo.AccountID) - assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID) - assert.Equals(t, o.CertificateID, tc.dbo.CertificateID) - assert.Equals(t, o.Status, tc.dbo.Status) - assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt) - assert.Equals(t, o.NotBefore, tc.dbo.NotBefore) - assert.Equals(t, o.NotAfter, tc.dbo.NotAfter) - assert.Equals(t, o.Identifiers, tc.dbo.Identifiers) - assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs) - assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error()) - } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, o.ID, tc.dbo.ID) + assert.Equals(t, o.AccountID, tc.dbo.AccountID) + assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID) + assert.Equals(t, o.CertificateID, tc.dbo.CertificateID) + assert.Equals(t, o.Status, tc.dbo.Status) + assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt) + assert.Equals(t, o.NotBefore, tc.dbo.NotBefore) + assert.Equals(t, o.NotAfter, tc.dbo.NotAfter) + assert.Equals(t, o.Identifiers, tc.dbo.Identifiers) + assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs) + assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error()) } }) } @@ -366,8 +362,8 @@ func TestDB_UpdateOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if err := db.UpdateOrder(context.Background(), tc.o); err != nil { + d := DB{db: tc.db} + if err := d.UpdateOrder(context.Background(), tc.o); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -511,7 +507,7 @@ func TestDB_CreateOrder(t *testing.T) { MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, string(bucket), string(ordersByAccountIDTable)) assert.Equals(t, string(key), o.AccountID) - return nil, nosqldb.ErrNotFound + return nil, database.ErrNotFound }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { switch string(bucket) { @@ -557,8 +553,8 @@ func TestDB_CreateOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} - if err := db.CreateOrder(context.Background(), tc.o); err != nil { + d := DB{db: tc.db} + if err := d.CreateOrder(context.Background(), tc.o); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -680,7 +676,7 @@ func TestDB_updateAddOrderIDs(t *testing.T) { MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, ordersByAccountIDTable) assert.Equals(t, key, []byte(accID)) - return nil, nosqldb.ErrNotFound + return nil, database.ErrNotFound }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, ordersByAccountIDTable) @@ -710,6 +706,34 @@ func TestDB_updateAddOrderIDs(t *testing.T) { err: errors.Errorf("error saving orderIDs index for account %s", accID), } }, + "ok/no-old": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + return nil, database.ErrNotFound + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, nil) + assert.Equals(t, nu, nil) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + res: []string{}, + } + }, "ok/all-old-not-pending": func(t *testing.T) test { oldOids := []string{"foo", "bar"} bOldOids, err := json.Marshal(oldOids) @@ -967,15 +991,15 @@ func TestDB_updateAddOrderIDs(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - db := DB{db: tc.db} + d := DB{db: tc.db} var ( res []string err error ) if tc.addOids == nil { - res, err = db.updateAddOrderIDs(context.Background(), accID) + res, err = d.updateAddOrderIDs(context.Background(), accID) } else { - res, err = db.updateAddOrderIDs(context.Background(), accID, tc.addOids...) + res, err = d.updateAddOrderIDs(context.Background(), accID, tc.addOids...) } if err != nil { @@ -993,10 +1017,8 @@ func TestDB_updateAddOrderIDs(t *testing.T) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } - } else { - if assert.Nil(t, tc.err) { - assert.True(t, reflect.DeepEqual(res, tc.res)) - } + } else if assert.Nil(t, tc.err) { + assert.True(t, reflect.DeepEqual(res, tc.res)) } }) } diff --git a/acme/order.go b/acme/order.go index a003fe9a..237c6979 100644 --- a/acme/order.go +++ b/acme/order.go @@ -1,9 +1,11 @@ package acme import ( + "bytes" "context" "crypto/x509" "encoding/json" + "net" "sort" "strings" "time" @@ -12,10 +14,17 @@ import ( "go.step.sm/crypto/x509util" ) +type IdentifierType string + +const ( + IP IdentifierType = "ip" + DNS IdentifierType = "dns" +) + // Identifier encodes the type that an order pertains to. type Identifier struct { - Type string `json:"type"` - Value string `json:"value"` + Type IdentifierType `json:"type"` + Value string `json:"value"` } // Order contains order metadata for the ACME protocol order type. @@ -131,41 +140,13 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID) } - // RFC8555: The CSR MUST indicate the exact same set of requested - // identifiers as the initial newOrder request. Identifiers of type "dns" - // MUST appear either in the commonName portion of the requested subject - // name or in an extensionRequest attribute [RFC2985] requesting a - // subjectAltName extension, or both. - if csr.Subject.CommonName != "" { - csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) - } - csr.DNSNames = uniqueSortedLowerNames(csr.DNSNames) - orderNames := make([]string, len(o.Identifiers)) - for i, n := range o.Identifiers { - orderNames[i] = n.Value - } - orderNames = uniqueSortedLowerNames(orderNames) + // canonicalize the CSR to allow for comparison + csr = canonicalize(csr) - // Validate identifier names against CSR alternative names. - // - // Note that with certificate templates we are not going to check for the - // absence of other SANs as they will only be set if the templates allows - // them. - if len(csr.DNSNames) != len(orderNames) { - return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ - "CSR names = %v, Order names = %v", csr.DNSNames, orderNames) - } - - sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames)) - for i := range csr.DNSNames { - if csr.DNSNames[i] != orderNames[i] { - return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ - "CSR names = %v, Order names = %v", csr.DNSNames, orderNames) - } - sans[i] = x509util.SubjectAlternativeName{ - Type: x509util.DNSType, - Value: csr.DNSNames[i], - } + // retrieve the requested SANs for the Order + sans, err := o.sans(csr) + if err != nil { + return err } // Get authorizations from the ACME provisioner. @@ -213,6 +194,123 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques return nil } +func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativeName, error) { + + var sans []x509util.SubjectAlternativeName + + // order the DNS names and IP addresses, so that they can be compared against the canonicalized CSR + orderNames := make([]string, numberOfIdentifierType(DNS, o.Identifiers)) + orderIPs := make([]net.IP, numberOfIdentifierType(IP, o.Identifiers)) + indexDNS, indexIP := 0, 0 + for _, n := range o.Identifiers { + switch n.Type { + case DNS: + orderNames[indexDNS] = n.Value + indexDNS++ + case IP: + orderIPs[indexIP] = net.ParseIP(n.Value) // NOTE: this assumes are all valid IPs at this time; or will result in nil entries + indexIP++ + default: + return sans, NewErrorISE("unsupported identifier type in order: %s", n.Type) + } + } + orderNames = uniqueSortedLowerNames(orderNames) + orderIPs = uniqueSortedIPs(orderIPs) + + totalNumberOfSANs := len(csr.DNSNames) + len(csr.IPAddresses) + sans = make([]x509util.SubjectAlternativeName, totalNumberOfSANs) + index := 0 + + // Validate identifier names against CSR alternative names. + // + // Note that with certificate templates we are not going to check for the + // absence of other SANs as they will only be set if the template allows + // them. + if len(csr.DNSNames) != len(orderNames) { + return sans, NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", csr.DNSNames, orderNames) + } + + for i := range csr.DNSNames { + if csr.DNSNames[i] != orderNames[i] { + return sans, NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", csr.DNSNames, orderNames) + } + sans[index] = x509util.SubjectAlternativeName{ + Type: x509util.DNSType, + Value: csr.DNSNames[i], + } + index++ + } + + if len(csr.IPAddresses) != len(orderIPs) { + return sans, NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+ + "CSR IPs = %v, Order IPs = %v", csr.IPAddresses, orderIPs) + } + + for i := range csr.IPAddresses { + if !ipsAreEqual(csr.IPAddresses[i], orderIPs[i]) { + return sans, NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+ + "CSR IPs = %v, Order IPs = %v", csr.IPAddresses, orderIPs) + } + sans[index] = x509util.SubjectAlternativeName{ + Type: x509util.IPType, + Value: csr.IPAddresses[i].String(), + } + index++ + } + + return sans, nil +} + +// numberOfIdentifierType returns the number of Identifiers that +// are of type typ. +func numberOfIdentifierType(typ IdentifierType, ids []Identifier) int { + c := 0 + for _, id := range ids { + if id.Type == typ { + c++ + } + } + return c +} + +// canonicalize canonicalizes a CSR so that it can be compared against an Order +// NOTE: this effectively changes the order of SANs in the CSR, which may be OK, +// but may not be expected. +func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.CertificateRequest) { + + // for clarity only; we're operating on the same object by pointer + canonicalized = csr + + // RFC8555: The CSR MUST indicate the exact same set of requested + // identifiers as the initial newOrder request. Identifiers of type "dns" + // MUST appear either in the commonName portion of the requested subject + // name or in an extensionRequest attribute [RFC2985] requesting a + // subjectAltName extension, or both. + if csr.Subject.CommonName != "" { + // nolint:gocritic + canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) + } + canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames) + canonicalized.IPAddresses = uniqueSortedIPs(csr.IPAddresses) + + return canonicalized +} + +// ipsAreEqual compares IPs to be equal. Nil values (i.e. invalid IPs) are +// not considered equal. IPv6 representations of IPv4 addresses are +// considered equal to the IPv4 address in this implementation, which is +// standard Go behavior. An example is "::ffff:192.168.42.42", which +// is equal to "192.168.42.42". This is considered a known issue within +// step and is tracked here too: https://github.com/golang/go/issues/37921. +func ipsAreEqual(x, y net.IP) bool { + if x == nil || y == nil { + return false + } + return x.Equal(y) +} + // uniqueSortedLowerNames returns the set of all unique names in the input after all // of them are lowercased. The returned names will be in their lowercased form // and sorted alphabetically. @@ -228,3 +326,23 @@ func uniqueSortedLowerNames(names []string) (unique []string) { sort.Strings(unique) return } + +// uniqueSortedIPs returns the set of all unique net.IPs in the input. They +// are sorted by their bytes (octet) representation. +func uniqueSortedIPs(ips []net.IP) (unique []net.IP) { + type entry struct { + ip net.IP + } + ipEntryMap := make(map[string]entry, len(ips)) + for _, ip := range ips { + ipEntryMap[ip.String()] = entry{ip: ip} + } + unique = make([]net.IP, 0, len(ipEntryMap)) + for _, entry := range ipEntryMap { + unique = append(unique, entry.ip) + } + sort.Slice(unique, func(i, j int) bool { + return bytes.Compare(unique[i], unique[j]) < 0 + }) + return +} diff --git a/acme/order_test.go b/acme/order_test.go index 993a92f2..83488c8c 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -5,12 +5,15 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/json" + "net" + "reflect" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/crypto/x509util" ) func TestOrder_UpdateStatus(t *testing.T) { @@ -261,10 +264,10 @@ func TestOrder_UpdateStatus(t *testing.T) { } type mockSignAuth struct { - sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - loadProvisionerByID func(string) (provisioner.Interface, error) - ret1, ret2 interface{} - err error + sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + loadProvisionerByName func(string) (provisioner.Interface, error) + ret1, ret2 interface{} + err error } func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { @@ -276,9 +279,9 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } -func (m *mockSignAuth) LoadProvisionerByID(id string) (provisioner.Interface, error) { - if m.loadProvisionerByID != nil { - return m.loadProvisionerByID(id) +func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) { + if m.loadProvisionerByName != nil { + return m.loadProvisionerByName(name) } return m.ret1.(provisioner.Interface), m.err } @@ -364,61 +367,6 @@ func TestOrder_Finalize(t *testing.T) { err: NewErrorISE("unrecognized order status: %s", o.Status), } }, - "fail/error-names-length-mismatch": func(t *testing.T) test { - now := clock.Now() - o := &Order{ - ID: "oID", - AccountID: "accID", - Status: StatusReady, - ExpiresAt: now.Add(5 * time.Minute), - AuthorizationIDs: []string{"a", "b"}, - Identifiers: []Identifier{ - {Type: "dns", Value: "foo.internal"}, - {Type: "dns", Value: "bar.internal"}, - }, - } - orderNames := []string{"bar.internal", "foo.internal"} - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "foo.internal", - }, - } - - return test{ - o: o, - csr: csr, - err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ - "CSR names = %v, Order names = %v", []string{"foo.internal"}, orderNames), - } - }, - "fail/error-names-mismatch": func(t *testing.T) test { - now := clock.Now() - o := &Order{ - ID: "oID", - AccountID: "accID", - Status: StatusReady, - ExpiresAt: now.Add(5 * time.Minute), - AuthorizationIDs: []string{"a", "b"}, - Identifiers: []Identifier{ - {Type: "dns", Value: "foo.internal"}, - {Type: "dns", Value: "bar.internal"}, - }, - } - orderNames := []string{"bar.internal", "foo.internal"} - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "foo.internal", - }, - DNSNames: []string{"zap.internal"}, - } - - return test{ - o: o, - csr: csr, - err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ - "CSR names = %v, Order names = %v", []string{"foo.internal", "zap.internal"}, orderNames), - } - }, "fail/error-provisioner-auth": func(t *testing.T) test { now := clock.Now() o := &Order{ @@ -650,7 +598,7 @@ func TestOrder_Finalize(t *testing.T) { err: NewErrorISE("error updating order oID: force"), } }, - "ok/new-cert": func(t *testing.T) test { + "ok/new-cert-dns": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", @@ -674,6 +622,131 @@ func TestOrder_Finalize(t *testing.T) { bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} + return test{ + o: o, + csr: csr, + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil + }, + MgetOptions: func() *provisioner.Options { + return nil + }, + }, + ca: &mockSignAuth{ + sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + assert.Equals(t, _csr, csr) + return []*x509.Certificate{foo, bar, baz}, nil + }, + }, + db: &MockDB{ + MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { + cert.ID = "certID" + assert.Equals(t, cert.AccountID, o.AccountID) + assert.Equals(t, cert.OrderID, o.ID) + assert.Equals(t, cert.Leaf, foo) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) + return nil + }, + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.CertificateID, "certID") + assert.Equals(t, updo.Status, StatusValid) + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, updo.Identifiers, o.Identifiers) + return nil + }, + }, + } + }, + "ok/new-cert-ip": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "ip", Value: "192.168.42.42"}, + {Type: "ip", Value: "192.168.43.42"}, + }, + } + csr := &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, // in case of IPs, no Common Name + } + + foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} + bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} + baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} + + return test{ + o: o, + csr: csr, + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil + }, + MgetOptions: func() *provisioner.Options { + return nil + }, + }, + ca: &mockSignAuth{ + sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + assert.Equals(t, _csr, csr) + return []*x509.Certificate{foo, bar, baz}, nil + }, + }, + db: &MockDB{ + MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { + cert.ID = "certID" + assert.Equals(t, cert.AccountID, o.AccountID) + assert.Equals(t, cert.OrderID, o.ID) + assert.Equals(t, cert.Leaf, foo) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) + return nil + }, + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.CertificateID, "certID") + assert.Equals(t, updo.Status, StatusValid) + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, updo.Identifiers, o.Identifiers) + return nil + }, + }, + } + }, + "ok/new-cert-dns-and-ip": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "ip", Value: "192.168.42.42"}, + }, + } + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo.internal", + }, + IPAddresses: []net.IP{net.ParseIP("192.168.42.42")}, + } + + foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} + bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} + baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} + return test{ o: o, csr: csr, @@ -737,3 +810,592 @@ func TestOrder_Finalize(t *testing.T) { }) } } + +func Test_uniqueSortedIPs(t *testing.T) { + type args struct { + ips []net.IP + } + tests := []struct { + name string + args args + wantUnique []net.IP + }{ + { + name: "ok/empty", + args: args{ + ips: []net.IP{}, + }, + wantUnique: []net.IP{}, + }, + { + name: "ok/single-ipv4", + args: args{ + ips: []net.IP{net.ParseIP("192.168.42.42")}, + }, + wantUnique: []net.IP{net.ParseIP("192.168.42.42")}, + }, + { + name: "ok/multiple-ipv4", + args: args{ + ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.1")}, + }, + wantUnique: []net.IP{net.ParseIP("192.168.42.1"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.42")}, + }, + { + name: "ok/unique-ipv4", + args: args{ + ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.42")}, + }, + wantUnique: []net.IP{net.ParseIP("192.168.42.42")}, + }, + { + name: "ok/single-ipv6", + args: args{ + ips: []net.IP{net.ParseIP("2001:db8::30")}, + }, + wantUnique: []net.IP{net.ParseIP("2001:db8::30")}, + }, + { + name: "ok/multiple-ipv6", + args: args{ + ips: []net.IP{net.ParseIP("2001:db8::30"), net.ParseIP("2001:db8::20"), net.ParseIP("2001:db8::10")}, + }, + wantUnique: []net.IP{net.ParseIP("2001:db8::10"), net.ParseIP("2001:db8::20"), net.ParseIP("2001:db8::30")}, + }, + { + name: "ok/unique-ipv6", + args: args{ + ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1")}, + }, + wantUnique: []net.IP{net.ParseIP("2001:db8::1")}, + }, + { + name: "ok/mixed-ipv4-and-ipv6", + args: args{ + ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1"), net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.42")}, + }, + wantUnique: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::1")}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotUnique := uniqueSortedIPs(tt.args.ips); !reflect.DeepEqual(gotUnique, tt.wantUnique) { + t.Errorf("uniqueSortedIPs() = %v, want %v", gotUnique, tt.wantUnique) + } + }) + } +} + +func Test_numberOfIdentifierType(t *testing.T) { + type args struct { + typ IdentifierType + ids []Identifier + } + tests := []struct { + name string + args args + want int + }{ + { + name: "ok/no-identifiers", + args: args{ + typ: DNS, + ids: []Identifier{}, + }, + want: 0, + }, + { + name: "ok/no-dns", + args: args{ + typ: DNS, + ids: []Identifier{ + { + Type: IP, + Value: "192.168.42.42", + }, + }, + }, + want: 0, + }, + { + name: "ok/no-ips", + args: args{ + typ: IP, + ids: []Identifier{ + { + Type: DNS, + Value: "example.com", + }, + }, + }, + want: 0, + }, + { + name: "ok/one-dns", + args: args{ + typ: DNS, + ids: []Identifier{ + { + Type: DNS, + Value: "example.com", + }, + { + Type: IP, + Value: "192.168.42.42", + }, + }, + }, + want: 1, + }, + { + name: "ok/one-ip", + args: args{ + typ: IP, + ids: []Identifier{ + { + Type: DNS, + Value: "example.com", + }, + { + Type: IP, + Value: "192.168.42.42", + }, + }, + }, + want: 1, + }, + { + name: "ok/more-dns", + args: args{ + typ: DNS, + ids: []Identifier{ + { + Type: DNS, + Value: "example.com", + }, + { + Type: DNS, + Value: "*.example.com", + }, + { + Type: IP, + Value: "192.168.42.42", + }, + }, + }, + want: 2, + }, + { + name: "ok/more-ips", + args: args{ + typ: IP, + ids: []Identifier{ + { + Type: DNS, + Value: "example.com", + }, + { + Type: IP, + Value: "192.168.42.42", + }, + { + Type: IP, + Value: "192.168.42.43", + }, + }, + }, + want: 2, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := numberOfIdentifierType(tt.args.typ, tt.args.ids); got != tt.want { + t.Errorf("numberOfIdentifierType() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_ipsAreEqual(t *testing.T) { + type args struct { + x net.IP + y net.IP + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "ok/ipv4", + args: args{ + x: net.ParseIP("192.168.42.42"), + y: net.ParseIP("192.168.42.42"), + }, + want: true, + }, + { + name: "fail/ipv4", + args: args{ + x: net.ParseIP("192.168.42.42"), + y: net.ParseIP("192.168.42.43"), + }, + want: false, + }, + { + name: "ok/ipv6", + args: args{ + x: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + }, + want: true, + }, + { + name: "fail/ipv6", + args: args{ + x: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7335"), + }, + want: false, + }, + { + name: "fail/ipv4-and-ipv6", + args: args{ + x: net.ParseIP("192.168.42.42"), + y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + }, + want: false, + }, + { + name: "ok/ipv4-mapped-to-ipv6", + args: args{ + x: net.ParseIP("192.168.42.42"), + y: net.ParseIP("::ffff:192.168.42.42"), // parsed to the same IPv4 by Go + }, + want: true, // we expect this to happen; a known issue in which ipv4 mapped ipv6 addresses are considered the same as their ipv4 counterpart + }, + { + name: "fail/invalid-ipv4-and-valid-ipv6", + args: args{ + x: net.ParseIP("192.168.42.1000"), + y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + }, + want: false, + }, + { + name: "fail/valid-ipv4-and-invalid-ipv6", + args: args{ + x: net.ParseIP("192.168.42.42"), + y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:733400"), + }, + want: false, + }, + { + name: "fail/invalid-ipv4-and-invalid-ipv6", + args: args{ + x: net.ParseIP("192.168.42.1000"), + y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:1000000"), + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ipsAreEqual(tt.args.x, tt.args.y); got != tt.want { + t.Errorf("ipsAreEqual() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_canonicalize(t *testing.T) { + type args struct { + csr *x509.CertificateRequest + } + tests := []struct { + name string + args args + wantCanonicalized *x509.CertificateRequest + }{ + { + name: "ok/dns", + args: args{ + csr: &x509.CertificateRequest{ + DNSNames: []string{"www.example.com", "example.com"}, + }, + }, + wantCanonicalized: &x509.CertificateRequest{ + DNSNames: []string{"example.com", "www.example.com"}, + IPAddresses: []net.IP{}, + }, + }, + { + name: "ok/common-name", + args: args{ + csr: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "example.com", + }, + DNSNames: []string{"www.example.com"}, + }, + }, + wantCanonicalized: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "example.com", + }, + DNSNames: []string{"example.com", "www.example.com"}, + IPAddresses: []net.IP{}, + }, + }, + { + name: "ok/ipv4", + args: args{ + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, + }, + }, + wantCanonicalized: &x509.CertificateRequest{ + DNSNames: []string{}, + IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, + }, + }, + { + name: "ok/mixed", + args: args{ + csr: &x509.CertificateRequest{ + DNSNames: []string{"www.example.com", "example.com"}, + IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, + }, + }, + wantCanonicalized: &x509.CertificateRequest{ + DNSNames: []string{"example.com", "www.example.com"}, + IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, + }, + }, + { + name: "ok/mixed-common-name", + args: args{ + csr: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "example.com", + }, + DNSNames: []string{"www.example.com"}, + IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, + }, + }, + wantCanonicalized: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "example.com", + }, + DNSNames: []string{"example.com", "www.example.com"}, + IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotCanonicalized := canonicalize(tt.args.csr); !reflect.DeepEqual(gotCanonicalized, tt.wantCanonicalized) { + t.Errorf("canonicalize() = %v, want %v", gotCanonicalized, tt.wantCanonicalized) + } + }) + } +} + +func TestOrder_sans(t *testing.T) { + type fields struct { + Identifiers []Identifier + } + tests := []struct { + name string + fields fields + csr *x509.CertificateRequest + want []x509util.SubjectAlternativeName + err *Error + }{ + { + name: "ok/dns", + fields: fields{ + Identifiers: []Identifier{ + {Type: "dns", Value: "example.com"}, + }, + }, + csr: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "example.com", + }, + }, + want: []x509util.SubjectAlternativeName{ + {Type: "dns", Value: "example.com"}, + }, + err: nil, + }, + { + name: "fail/error-names-length-mismatch", + fields: fields{ + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + }, + }, + csr: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo.internal", + }, + }, + want: []x509util.SubjectAlternativeName{}, + err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", []string{"foo.internal"}, []string{"bar.internal", "foo.internal"}), + }, + { + name: "fail/error-names-mismatch", + fields: fields{ + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + }, + }, + csr: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo.internal", + }, + DNSNames: []string{"zap.internal"}, + }, + want: []x509util.SubjectAlternativeName{}, + err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", []string{"foo.internal", "zap.internal"}, []string{"bar.internal", "foo.internal"}), + }, + { + name: "ok/ipv4", + fields: fields{ + Identifiers: []Identifier{ + {Type: "ip", Value: "192.168.43.42"}, + {Type: "ip", Value: "192.168.42.42"}, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, + }, + want: []x509util.SubjectAlternativeName{ + {Type: "ip", Value: "192.168.42.42"}, + {Type: "ip", Value: "192.168.43.42"}, + }, + err: nil, + }, + { + name: "ok/ipv6", + fields: fields{ + Identifiers: []Identifier{ + {Type: "ip", Value: "2001:0db8:85a3::8a2e:0370:7335"}, + {Type: "ip", Value: "2001:0db8:85a3::8a2e:0370:7334"}, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7335"), net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, + }, + want: []x509util.SubjectAlternativeName{ + {Type: "ip", Value: "2001:db8:85a3::8a2e:370:7334"}, + {Type: "ip", Value: "2001:db8:85a3::8a2e:370:7335"}, + }, + err: nil, + }, + { + name: "fail/error-ips-length-mismatch", + fields: fields{ + Identifiers: []Identifier{ + {Type: "ip", Value: "192.168.42.42"}, + {Type: "ip", Value: "192.168.43.42"}, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("192.168.42.42")}, + }, + want: []x509util.SubjectAlternativeName{}, + err: NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+ + "CSR IPs = %v, Order IPs = %v", []net.IP{net.ParseIP("192.168.42.42")}, []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}), + }, + { + name: "fail/error-ips-mismatch", + fields: fields{ + Identifiers: []Identifier{ + {Type: "ip", Value: "192.168.42.42"}, + {Type: "ip", Value: "192.168.43.42"}, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.32")}, + }, + want: []x509util.SubjectAlternativeName{}, + err: NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+ + "CSR IPs = %v, Order IPs = %v", []net.IP{net.ParseIP("192.168.42.32"), net.ParseIP("192.168.42.42")}, []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}), + }, + { + name: "ok/mixed", + fields: fields{ + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + {Type: "ip", Value: "192.168.43.42"}, + {Type: "ip", Value: "192.168.42.42"}, + {Type: "ip", Value: "2001:0db8:85a3:0000:0000:8a2e:0370:7334"}, + }, + }, + csr: &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "bar.internal", + }, + DNSNames: []string{"foo.internal"}, + IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42"), net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, + }, + want: []x509util.SubjectAlternativeName{ + {Type: "dns", Value: "bar.internal"}, + {Type: "dns", Value: "foo.internal"}, + {Type: "ip", Value: "192.168.42.42"}, + {Type: "ip", Value: "192.168.43.42"}, + {Type: "ip", Value: "2001:db8:85a3::8a2e:370:7334"}, + }, + err: nil, + }, + { + name: "fail/unsupported-identifier-type", + fields: fields{ + Identifiers: []Identifier{ + {Type: "ipv4", Value: "192.168.42.42"}, + }, + }, + csr: &x509.CertificateRequest{ + IPAddresses: []net.IP{net.ParseIP("192.168.42.42")}, + }, + want: []x509util.SubjectAlternativeName{}, + err: NewError(ErrorServerInternalType, "unsupported identifier type in order: ipv4"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := &Order{ + Identifiers: tt.fields.Identifiers, + } + canonicalizedCSR := canonicalize(tt.csr) + got, err := o.sans(canonicalizedCSR) + if tt.err != nil { + if err == nil { + t.Errorf("Order.sans() = %v, want error; got none", got) + return + } + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tt.err.Type) + assert.Equals(t, k.Detail, tt.err.Detail) + assert.Equals(t, k.Status, tt.err.Status) + assert.Equals(t, k.Err.Error(), tt.err.Err.Error()) + assert.Equals(t, k.Detail, tt.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Order.sans() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/api/api.go b/api/api.go index 6a0a7e8f..30ba03f9 100644 --- a/api/api.go +++ b/api/api.go @@ -21,6 +21,7 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" @@ -32,13 +33,13 @@ type Authority interface { // context specifies the Authorize[Sign|Revoke|etc.] method. Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error) - GetTLSOptions() *authority.TLSOptions + GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error) - LoadProvisionerByID(string) (provisioner.Interface, error) + LoadProvisionerByName(string) (provisioner.Interface, error) GetProvisioners(cursor string, limit int) (provisioner.List, string, error) Revoke(context.Context, *authority.RevokeOptions) error GetEncryptedKey(kid string) (string, error) @@ -239,9 +240,9 @@ type caHandler struct { } // New creates a new RouterHandler with the CA endpoints. -func New(authority Authority) RouterHandler { +func New(auth Authority) RouterHandler { return &caHandler{ - Authority: authority, + Authority: auth, } } @@ -294,7 +295,7 @@ func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { // certificate for the given SHA256. func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { sha := chi.URLParam(r, "sha") - sum := strings.ToLower(strings.Replace(sha, "-", "", -1)) + sum := strings.ToLower(strings.ReplaceAll(sha, "-", "")) // Load root certificate with the cert, err := h.Authority.Root(sum) if err != nil { @@ -315,7 +316,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { // Provisioners returns the list of provisioners configured in the authority. func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { - cursor, limit, err := parseCursor(r) + cursor, limit, err := ParseCursor(r) if err != nil { WriteError(w, errs.BadRequestErr(err)) return @@ -399,7 +400,7 @@ func logOtt(w http.ResponseWriter, token string) { func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) { if rl, ok := w.(logging.ResponseLogger); ok { m := map[string]interface{}{ - "serial": cert.SerialNumber, + "serial": cert.SerialNumber.String(), "subject": cert.Subject.CommonName, "issuer": cert.Issuer.CommonName, "valid-from": cert.NotBefore.Format(time.RFC3339), @@ -408,25 +409,27 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) { "certificate": base64.StdEncoding.EncodeToString(cert.Raw), } for _, ext := range cert.Extensions { - if ext.Id.Equal(oidStepProvisioner) { - val := &stepProvisioner{} - rest, err := asn1.Unmarshal(ext.Value, val) - if err != nil || len(rest) > 0 { - break - } - if len(val.CredentialID) > 0 { - m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID) - } else { - m["provisioner"] = fmt.Sprintf("%s", val.Name) - } + if !ext.Id.Equal(oidStepProvisioner) { + continue + } + val := &stepProvisioner{} + rest, err := asn1.Unmarshal(ext.Value, val) + if err != nil || len(rest) > 0 { break } + if len(val.CredentialID) > 0 { + m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID) + } else { + m["provisioner"] = string(val.Name) + } + break } rl.WithFields(m) } } -func parseCursor(r *http.Request) (cursor string, limit int, err error) { +// ParseCursor parses the cursor and limit from the request query params. +func ParseCursor(r *http.Request) (cursor string, limit int, err error) { q := r.URL.Query() cursor = q.Get("cursor") if v := q.Get("limit"); len(v) > 0 { diff --git a/api/api_test.go b/api/api_test.go index 62ef7740..89596165 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -186,8 +186,8 @@ func TestCertificate_MarshalJSON(t *testing.T) { }{ {"nil", fields{Certificate: nil}, []byte("null"), false}, {"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false}, - {"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"`), false}, - {"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`), false}, + {"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"`), false}, + {"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`), false}, } for _, tt := range tests { @@ -219,11 +219,11 @@ func TestCertificate_UnmarshalJSON(t *testing.T) { {"invalid string", []byte(`"foobar"`), false, true}, {"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true}, {"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true}, - {"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false, true}, + {"invalid type", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), false, true}, {"empty string", []byte(`""`), false, false}, {"json null", []byte(`null`), false, false}, - {"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true, false}, - {"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), true, false}, + {"valid root", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), true, false}, + {"valid cert", []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"`), true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -251,7 +251,7 @@ func TestCertificate_UnmarshalJSON_json(t *testing.T) { {"empty crt (null)", `{"crt":null}`, false, false}, {"empty crt (string)", `{"crt":""}`, false, false}, {"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true}, - {"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, true, false}, + {"valid crt", `{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"}`, true, false}, } type request struct { @@ -297,7 +297,7 @@ func TestCertificateRequest_MarshalJSON(t *testing.T) { }{ {"nil", fields{CertificateRequest: nil}, []byte("null"), false}, {"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false}, - {"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `\n"`), false}, + {"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `\n"`), false}, } for _, tt := range tests { @@ -329,10 +329,10 @@ func TestCertificateRequest_UnmarshalJSON(t *testing.T) { {"invalid string", []byte(`"foobar"`), false, true}, {"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true}, {"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true}, - {"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false, true}, + {"invalid type", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), false, true}, {"empty string", []byte(`""`), false, false}, {"json null", []byte(`null`), false, false}, - {"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true, false}, + {"valid csr", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -360,7 +360,7 @@ func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) { {"empty csr (null)", `{"csr":null}`, false, false}, {"empty csr (string)", `{"csr":""}`, false, false}, {"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true}, - {"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, true, false}, + {"valid csr", `{"csr":"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"}`, true, false}, } type request struct { @@ -430,6 +430,7 @@ type mockProvisioner struct { ret1, ret2, ret3 interface{} err error getID func() string + getIDForToken func() string getTokenID func(string) (string, error) getName func() string getType func() provisioner.Type @@ -452,6 +453,13 @@ func (m *mockProvisioner) GetID() string { return m.ret1.(string) } +func (m *mockProvisioner) GetIDForToken() string { + if m.getIDForToken != nil { + return m.getIDForToken() + } + return m.ret1.(string) +} + func (m *mockProvisioner) GetTokenID(token string) (string, error) { if m.getTokenID != nil { return m.getTokenID(token) @@ -553,7 +561,7 @@ type mockAuthority struct { renew func(cert *x509.Certificate) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) - loadProvisionerByID func(provID string) (provisioner.Interface, error) + loadProvisionerByName func(name string) (provisioner.Interface, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) revoke func(context.Context, *authority.RevokeOptions) error getEncryptedKey func(kid string) (string, error) @@ -633,9 +641,9 @@ func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (pr return m.ret1.(provisioner.Interface), m.err } -func (m *mockAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) { - if m.loadProvisionerByID != nil { - return m.loadProvisionerByID(provID) +func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { + if m.loadProvisionerByName != nil { + return m.loadProvisionerByName(name) } return m.ret1.(provisioner.Interface), m.err } @@ -731,7 +739,7 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin return m.ret1.(bool), m.err } -func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) { +func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) { if m.getSSHBastion != nil { return m.getSSHBastion(ctx, user, hostname) } @@ -808,7 +816,7 @@ func Test_caHandler_Root(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil) req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) - expected := []byte(`{"ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`) + expected := []byte(`{"ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -852,8 +860,8 @@ func Test_caHandler_Sign(t *testing.T) { t.Fatal(err) } - expected1 := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) - expected2 := []byte(`{"crt":"` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) + expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) tests := []struct { name string @@ -926,7 +934,7 @@ func Test_caHandler_Renew(t *testing.T) { {"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, } - expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -987,7 +995,7 @@ func Test_caHandler_Rekey(t *testing.T) { {"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest}, } - expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1202,7 +1210,7 @@ func Test_caHandler_Roots(t *testing.T) { {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, } - expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1248,7 +1256,7 @@ func Test_caHandler_Federation(t *testing.T) { {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, } - expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/api/errors.go b/api/errors.go index 438b873d..bff46b55 100644 --- a/api/errors.go +++ b/api/errors.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/scep" @@ -19,6 +20,9 @@ func WriteError(w http.ResponseWriter, err error) { case *acme.Error: acme.WriteError(w, k) return + case *admin.Error: + admin.WriteError(w, k) + return case *scep.Error: w.Header().Set("Content-Type", "text/plain") default: @@ -46,12 +50,10 @@ func WriteError(w http.ResponseWriter, err error) { rl.WithFields(map[string]interface{}{ "stack-trace": fmt.Sprintf("%+v", e), }) - } else { - if e, ok := cause.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } + } else if e, ok := cause.(errs.StackTracer); ok { + rl.WithFields(map[string]interface{}{ + "stack-trace": fmt.Sprintf("%+v", e), + }) } } } diff --git a/api/sign.go b/api/sign.go index 69e9a1a5..d6fd2bc6 100644 --- a/api/sign.go +++ b/api/sign.go @@ -5,7 +5,7 @@ import ( "encoding/json" "net/http" - "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) @@ -37,11 +37,11 @@ func (s *SignRequest) Validate() error { // SignResponse is the response object of the certificate signature request. type SignResponse struct { - ServerPEM Certificate `json:"crt"` - CaPEM Certificate `json:"ca"` - CertChainPEM []Certificate `json:"certChain"` - TLSOptions *authority.TLSOptions `json:"tlsOptions,omitempty"` - TLS *tls.ConnectionState `json:"-"` + ServerPEM Certificate `json:"crt"` + CaPEM Certificate `json:"ca"` + CertChainPEM []Certificate `json:"certChain"` + TLSOptions *config.TLSOptions `json:"tlsOptions,omitempty"` + TLS *tls.ConnectionState `json:"-"` } // Sign is an HTTP handler that reads a certificate request and an diff --git a/api/ssh.go b/api/ssh.go index 9962ad4f..7c7a5acd 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/templates" @@ -22,12 +23,12 @@ type SSHAuthority interface { RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) - GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) - GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) + GetSSHRoots(ctx context.Context) (*config.SSHKeys, error) + GetSSHFederation(ctx context.Context) (*config.SSHKeys, error) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) - GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) - GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) + GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) + GetSSHBastion(ctx context.Context, user string, hostname string) (*config.Bastion, error) } // SSHSignRequest is the request body of an SSH certificate request. @@ -51,7 +52,7 @@ func (s *SSHSignRequest) Validate() error { return errors.Errorf("unknown certType %s", s.CertType) case len(s.PublicKey) == 0: return errors.New("missing or empty publicKey") - case len(s.OTT) == 0: + case s.OTT == "": return errors.New("missing or empty ott") default: // Validate identity signature if provided @@ -86,7 +87,7 @@ type SSHCertificate struct { // SSHGetHostsResponse is the response object that returns the list of valid // hosts for SSH. type SSHGetHostsResponse struct { - Hosts []authority.Host `json:"hosts"` + Hosts []config.Host `json:"hosts"` } // MarshalJSON implements the json.Marshaler interface. Returns a quoted, @@ -239,8 +240,8 @@ func (r *SSHBastionRequest) Validate() error { // SSHBastionResponse is the response body used to return the bastion for a // given host. type SSHBastionResponse struct { - Hostname string `json:"hostname"` - Bastion *authority.Bastion `json:"bastion,omitempty"` + Hostname string `json:"hostname"` + Bastion *config.Bastion `json:"bastion,omitempty"` } // SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token @@ -407,18 +408,18 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { return } - var config SSHConfigResponse + var cfg SSHConfigResponse switch body.Type { case provisioner.SSHUserCert: - config.UserTemplates = ts + cfg.UserTemplates = ts case provisioner.SSHHostCert: - config.HostTemplates = ts + cfg.HostTemplates = ts default: WriteError(w, errs.InternalServer("it should hot get here")) return } - JSON(w, config) + JSON(w, cfg) } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. diff --git a/api/sshRekey.go b/api/sshRekey.go index 285422f9..9d9e17cf 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -2,6 +2,7 @@ package api import ( "net/http" + "time" "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" @@ -18,7 +19,7 @@ type SSHRekeyRequest struct { // Validate validates the SSHSignRekey. func (s *SSHRekeyRequest) Validate() error { switch { - case len(s.OTT) == 0: + case s.OTT == "": return errors.New("missing or empty ott") case len(s.PublicKey) == 0: return errors.New("missing or empty public key") @@ -72,7 +73,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { return } - identity, err := h.renewIdentityCertificate(r) + // Match identity cert with the SSH cert + notBefore := time.Unix(int64(oldCert.ValidAfter), 0) + notAfter := time.Unix(int64(oldCert.ValidBefore), 0) + + identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { WriteError(w, errs.ForbiddenErr(err)) return diff --git a/api/sshRenew.go b/api/sshRenew.go index 048c83a3..d0633ecf 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -1,7 +1,9 @@ package api import ( + "crypto/x509" "net/http" + "time" "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" @@ -16,7 +18,7 @@ type SSHRenewRequest struct { // Validate validates the SSHSignRequest. func (s *SSHRenewRequest) Validate() error { switch { - case len(s.OTT) == 0: + case s.OTT == "": return errors.New("missing or empty ott") default: return nil @@ -62,7 +64,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { return } - identity, err := h.renewIdentityCertificate(r) + // Match identity cert with the SSH cert + notBefore := time.Unix(int64(oldCert.ValidAfter), 0) + notAfter := time.Unix(int64(oldCert.ValidBefore), 0) + + identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { WriteError(w, errs.ForbiddenErr(err)) return @@ -74,13 +80,28 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { }, http.StatusCreated) } -// renewIdentityCertificate request the client TLS certificate if present. -func (h *caHandler) renewIdentityCertificate(r *http.Request) ([]Certificate, error) { +// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the +func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { return nil, nil } - certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) + // Clone the certificate as we can modify it. + cert, err := x509.ParseCertificate(r.TLS.PeerCertificates[0].Raw) + if err != nil { + return nil, errors.Wrap(err, "error parsing client certificate") + } + + // Enforce the cert to match another certificate, for example an ssh + // certificate. + if !notBefore.IsZero() { + cert.NotBefore = notBefore + } + if !notAfter.IsZero() { + cert.NotAfter = notAfter + } + + certChain, err := h.Authority.Renew(cert) if err != nil { return nil, err } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index 5a1c858c..c6ebe99d 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -36,7 +36,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { if !r.Passive { return errs.NotImplemented("non-passive revocation not implemented") } - if len(r.OTT) == 0 { + if r.OTT == "" { return errs.BadRequest("missing ott") } return diff --git a/api/ssh_test.go b/api/ssh_test.go index 1873a96d..a2e8748f 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -284,7 +284,7 @@ func Test_caHandler_SSHSign(t *testing.T) { identityCerts := []*x509.Certificate{ parseCertificate(certPEM), } - identityCertsPEM := []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`) + identityCertsPEM := []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`) tests := []struct { name string diff --git a/api/utils.go b/api/utils.go index 0d87a065..bf45db53 100644 --- a/api/utils.go +++ b/api/utils.go @@ -3,11 +3,14 @@ package api import ( "encoding/json" "io" + "io/ioutil" "log" "net/http" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" ) // EnableLogger is an interface that enables response logging for an object. @@ -64,6 +67,29 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) { LogEnabledResponse(w, v) } +// ProtoJSON writes the passed value into the http.ResponseWriter. +func ProtoJSON(w http.ResponseWriter, m proto.Message) { + ProtoJSONStatus(w, m, http.StatusOK) +} + +// ProtoJSONStatus writes the given value into the http.ResponseWriter and the +// given status is written as the status code of the response. +func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + + b, err := protojson.Marshal(m) + if err != nil { + LogError(w, err) + return + } + if _, err := w.Write(b); err != nil { + LogError(w, err) + return + } + //LogEnabledResponse(w, v) +} + // ReadJSON reads JSON from the request body and stores it in the value // pointed by v. func ReadJSON(r io.Reader, v interface{}) error { @@ -72,3 +98,13 @@ func ReadJSON(r io.Reader, v interface{}) error { } return nil } + +// ReadProtoJSON reads JSON from the request body and stores it in the value +// pointed by v. +func ReadProtoJSON(r io.Reader, m proto.Message) error { + data, err := ioutil.ReadAll(r) + if err != nil { + return errs.Wrap(http.StatusBadRequest, err, "error reading request body") + } + return protojson.Unmarshal(data, m) +} diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go new file mode 100644 index 00000000..bf79ebcf --- /dev/null +++ b/authority/admin/api/admin.go @@ -0,0 +1,160 @@ +package api + +import ( + "net/http" + + "github.com/go-chi/chi" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/admin" + "go.step.sm/linkedca" +) + +// CreateAdminRequest represents the body for a CreateAdmin request. +type CreateAdminRequest struct { + Subject string `json:"subject"` + Provisioner string `json:"provisioner"` + Type linkedca.Admin_Type `json:"type"` +} + +// Validate validates a new-admin request body. +func (car *CreateAdminRequest) Validate() error { + if car.Subject == "" { + return admin.NewError(admin.ErrorBadRequestType, "subject cannot be empty") + } + if car.Provisioner == "" { + return admin.NewError(admin.ErrorBadRequestType, "provisioner cannot be empty") + } + switch car.Type { + case linkedca.Admin_SUPER_ADMIN, linkedca.Admin_ADMIN: + default: + return admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type") + } + return nil +} + +// GetAdminsResponse for returning a list of admins. +type GetAdminsResponse struct { + Admins []*linkedca.Admin `json:"admins"` + NextCursor string `json:"nextCursor"` +} + +// UpdateAdminRequest represents the body for a UpdateAdmin request. +type UpdateAdminRequest struct { + Type linkedca.Admin_Type `json:"type"` +} + +// Validate validates a new-admin request body. +func (uar *UpdateAdminRequest) Validate() error { + switch uar.Type { + case linkedca.Admin_SUPER_ADMIN, linkedca.Admin_ADMIN: + default: + return admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type") + } + return nil +} + +// DeleteResponse is the resource for successful DELETE responses. +type DeleteResponse struct { + Status string `json:"status"` +} + +// GetAdmin returns the requested admin, or an error. +func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + + adm, ok := h.auth.LoadAdminByID(id) + if !ok { + api.WriteError(w, admin.NewError(admin.ErrorNotFoundType, + "admin %s not found", id)) + return + } + api.ProtoJSON(w, adm) +} + +// GetAdmins returns a segment of admins associated with the authority. +func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { + cursor, limit, err := api.ParseCursor(r) + if err != nil { + api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + "error parsing cursor and limit from query params")) + return + } + + admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) + return + } + api.JSON(w, &GetAdminsResponse{ + Admins: admins, + NextCursor: nextCursor, + }) +} + +// CreateAdmin creates a new admin. +func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { + var body CreateAdminRequest + if err := api.ReadJSON(r.Body, &body); err != nil { + api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + return + } + + if err := body.Validate(); err != nil { + api.WriteError(w, err) + return + } + + p, err := h.auth.LoadProvisionerByName(body.Provisioner) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) + return + } + adm := &linkedca.Admin{ + ProvisionerId: p.GetID(), + Subject: body.Subject, + Type: body.Type, + } + // Store to authority collection. + if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error storing admin")) + return + } + + api.ProtoJSONStatus(w, adm, http.StatusCreated) +} + +// DeleteAdmin deletes admin. +func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + + if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) + return + } + + api.JSON(w, &DeleteResponse{Status: "ok"}) +} + +// UpdateAdmin updates an existing admin. +func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { + var body UpdateAdminRequest + if err := api.ReadJSON(r.Body, &body); err != nil { + api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + return + } + + if err := body.Validate(); err != nil { + api.WriteError(w, err) + return + } + + id := chi.URLParam(r, "id") + + adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id)) + return + } + + api.ProtoJSON(w, adm) +} diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go new file mode 100644 index 00000000..d88edfa1 --- /dev/null +++ b/authority/admin/api/handler.go @@ -0,0 +1,41 @@ +package api + +import ( + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/admin" +) + +// Handler is the ACME API request handler. +type Handler struct { + db admin.DB + auth *authority.Authority +} + +// NewHandler returns a new Authority Config Handler. +func NewHandler(auth *authority.Authority) api.RouterHandler { + h := &Handler{db: auth.GetAdminDatabase(), auth: auth} + + return h +} + +// Route traffic and implement the Router interface. +func (h *Handler) Route(r api.Router) { + authnz := func(next nextHTTP) nextHTTP { + return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) + } + + // Provisioners + r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner)) + r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners)) + r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner)) + r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner)) + r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner)) + + // Admins + r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin)) + r.MethodFunc("GET", "/admins", authnz(h.GetAdmins)) + r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin)) + r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin)) + r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin)) +} diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go new file mode 100644 index 00000000..19025a9d --- /dev/null +++ b/authority/admin/api/middleware.go @@ -0,0 +1,54 @@ +package api + +import ( + "context" + "net/http" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/admin" +) + +type nextHTTP = func(http.ResponseWriter, *http.Request) + +// requireAPIEnabled is a middleware that ensures the Administration API +// is enabled before servicing requests. +func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + if !h.auth.IsAdminAPIEnabled() { + api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, + "administration API not enabled")) + return + } + next(w, r) + } +} + +// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. +func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + tok := r.Header.Get("Authorization") + if tok == "" { + api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, + "missing authorization header token")) + return + } + + adm, err := h.auth.AuthorizeAdminToken(r, tok) + if err != nil { + api.WriteError(w, err) + return + } + + ctx := context.WithValue(r.Context(), adminContextKey, adm) + next(w, r.WithContext(ctx)) + } +} + +// ContextKey is the key type for storing and searching for ACME request +// essentials in the context of a request. +type ContextKey string + +const ( + // adminContextKey account key + adminContextKey = ContextKey("admin") +) diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go new file mode 100644 index 00000000..fd1a02d5 --- /dev/null +++ b/authority/admin/api/provisioner.go @@ -0,0 +1,175 @@ +package api + +import ( + "net/http" + + "github.com/go-chi/chi" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" + "go.step.sm/linkedca" +) + +// GetProvisionersResponse is the type for GET /admin/provisioners responses. +type GetProvisionersResponse struct { + Provisioners provisioner.List `json:"provisioners"` + NextCursor string `json:"nextCursor"` +} + +// GetProvisioner returns the requested provisioner, or an error. +func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + id := r.URL.Query().Get("id") + name := chi.URLParam(r, "name") + + var ( + p provisioner.Interface + err error + ) + if len(id) > 0 { + if p, err = h.auth.LoadProvisionerByID(id); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + return + } + } else { + if p, err = h.auth.LoadProvisionerByName(name); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + return + } + } + + prov, err := h.db.GetProvisioner(ctx, p.GetID()) + if err != nil { + api.WriteError(w, err) + return + } + api.ProtoJSON(w, prov) +} + +// GetProvisioners returns the given segment of provisioners associated with the authority. +func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { + cursor, limit, err := api.ParseCursor(r) + if err != nil { + api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + "error parsing cursor & limit query params")) + return + } + + p, next, err := h.auth.GetProvisioners(cursor, limit) + if err != nil { + api.WriteError(w, errs.InternalServerErr(err)) + return + } + api.JSON(w, &GetProvisionersResponse{ + Provisioners: p, + NextCursor: next, + }) +} + +// CreateProvisioner creates a new prov. +func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { + var prov = new(linkedca.Provisioner) + if err := api.ReadProtoJSON(r.Body, prov); err != nil { + api.WriteError(w, err) + return + } + + // TODO: Validate inputs + if err := authority.ValidateClaims(prov.Claims); err != nil { + api.WriteError(w, err) + return + } + + if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) + return + } + api.ProtoJSONStatus(w, prov, http.StatusCreated) +} + +// DeleteProvisioner deletes a provisioner. +func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { + id := r.URL.Query().Get("id") + name := chi.URLParam(r, "name") + + var ( + p provisioner.Interface + err error + ) + if len(id) > 0 { + if p, err = h.auth.LoadProvisionerByID(id); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + return + } + } else { + if p, err = h.auth.LoadProvisionerByName(name); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + return + } + } + + if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) + return + } + + api.JSON(w, &DeleteResponse{Status: "ok"}) +} + +// UpdateProvisioner updates an existing prov. +func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { + var nu = new(linkedca.Provisioner) + if err := api.ReadProtoJSON(r.Body, nu); err != nil { + api.WriteError(w, err) + return + } + + name := chi.URLParam(r, "name") + _old, err := h.auth.LoadProvisionerByName(name) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) + return + } + + old, err := h.db.GetProvisioner(r.Context(), _old.GetID()) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) + return + } + + if nu.Id != old.Id { + api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID")) + return + } + if nu.Type != old.Type { + api.WriteError(w, admin.NewErrorISE("cannot change provisioner type")) + return + } + if nu.AuthorityId != old.AuthorityId { + api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID")) + return + } + if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) { + api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt")) + return + } + if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) { + api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt")) + return + } + + // TODO: Validate inputs + if err := authority.ValidateClaims(nu.Claims); err != nil { + api.WriteError(w, err) + return + } + + if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { + api.WriteError(w, err) + return + } + api.ProtoJSON(w, nu) +} diff --git a/authority/admin/db.go b/authority/admin/db.go new file mode 100644 index 00000000..8a6339d9 --- /dev/null +++ b/authority/admin/db.go @@ -0,0 +1,179 @@ +package admin + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/pkg/errors" + "go.step.sm/linkedca" +) + +const ( + // DefaultAuthorityID is the default AuthorityID. This will be the ID + // of the first Authority created, as well as the default AuthorityID + // if one is not specified in the configuration. + DefaultAuthorityID = "00000000-0000-0000-0000-000000000000" +) + +// ErrNotFound is an error that should be used by the authority.DB interface to +// indicate that an entity does not exist. +var ErrNotFound = errors.New("not found") + +// UnmarshalProvisionerDetails unmarshals details type to the specific provisioner details. +func UnmarshalProvisionerDetails(typ linkedca.Provisioner_Type, data []byte) (*linkedca.ProvisionerDetails, error) { + var v linkedca.ProvisionerDetails + switch typ { + case linkedca.Provisioner_JWK: + v.Data = new(linkedca.ProvisionerDetails_JWK) + case linkedca.Provisioner_OIDC: + v.Data = new(linkedca.ProvisionerDetails_OIDC) + case linkedca.Provisioner_GCP: + v.Data = new(linkedca.ProvisionerDetails_GCP) + case linkedca.Provisioner_AWS: + v.Data = new(linkedca.ProvisionerDetails_AWS) + case linkedca.Provisioner_AZURE: + v.Data = new(linkedca.ProvisionerDetails_Azure) + case linkedca.Provisioner_ACME: + v.Data = new(linkedca.ProvisionerDetails_ACME) + case linkedca.Provisioner_X5C: + v.Data = new(linkedca.ProvisionerDetails_X5C) + case linkedca.Provisioner_K8SSA: + v.Data = new(linkedca.ProvisionerDetails_K8SSA) + case linkedca.Provisioner_SSHPOP: + v.Data = new(linkedca.ProvisionerDetails_SSHPOP) + case linkedca.Provisioner_SCEP: + v.Data = new(linkedca.ProvisionerDetails_SCEP) + default: + return nil, fmt.Errorf("unsupported provisioner type %s", typ) + } + + if err := json.Unmarshal(data, v.Data); err != nil { + return nil, err + } + return &linkedca.ProvisionerDetails{Data: v.Data}, nil +} + +// DB is the DB interface expected by the step-ca Admin API. +type DB interface { + CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error + GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) + GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) + UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error + DeleteProvisioner(ctx context.Context, id string) error + + CreateAdmin(ctx context.Context, admin *linkedca.Admin) error + GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) + GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) + UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error + DeleteAdmin(ctx context.Context, id string) error +} + +// MockDB is an implementation of the DB interface that should only be used as +// a mock in tests. +type MockDB struct { + MockCreateProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error + MockGetProvisioner func(ctx context.Context, id string) (*linkedca.Provisioner, error) + MockGetProvisioners func(ctx context.Context) ([]*linkedca.Provisioner, error) + MockUpdateProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error + MockDeleteProvisioner func(ctx context.Context, id string) error + + MockCreateAdmin func(ctx context.Context, adm *linkedca.Admin) error + MockGetAdmin func(ctx context.Context, id string) (*linkedca.Admin, error) + MockGetAdmins func(ctx context.Context) ([]*linkedca.Admin, error) + MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error + MockDeleteAdmin func(ctx context.Context, id string) error + + MockError error + MockRet1 interface{} +} + +// CreateProvisioner mock. +func (m *MockDB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + if m.MockCreateProvisioner != nil { + return m.MockCreateProvisioner(ctx, prov) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetProvisioner mock. +func (m *MockDB) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) { + if m.MockGetProvisioner != nil { + return m.MockGetProvisioner(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*linkedca.Provisioner), m.MockError +} + +// GetProvisioners mock +func (m *MockDB) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) { + if m.MockGetProvisioners != nil { + return m.MockGetProvisioners(ctx) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.([]*linkedca.Provisioner), m.MockError +} + +// UpdateProvisioner mock +func (m *MockDB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + if m.MockUpdateProvisioner != nil { + return m.MockUpdateProvisioner(ctx, prov) + } + return m.MockError +} + +// DeleteProvisioner mock +func (m *MockDB) DeleteProvisioner(ctx context.Context, id string) error { + if m.MockDeleteProvisioner != nil { + return m.MockDeleteProvisioner(ctx, id) + } + return m.MockError +} + +// CreateAdmin mock +func (m *MockDB) CreateAdmin(ctx context.Context, admin *linkedca.Admin) error { + if m.MockCreateAdmin != nil { + return m.MockCreateAdmin(ctx, admin) + } + return m.MockError +} + +// GetAdmin mock. +func (m *MockDB) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) { + if m.MockGetAdmin != nil { + return m.MockGetAdmin(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*linkedca.Admin), m.MockError +} + +// GetAdmins mock +func (m *MockDB) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) { + if m.MockGetAdmins != nil { + return m.MockGetAdmins(ctx) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.([]*linkedca.Admin), m.MockError +} + +// UpdateAdmin mock +func (m *MockDB) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error { + if m.MockUpdateAdmin != nil { + return m.MockUpdateAdmin(ctx, adm) + } + return m.MockError +} + +// DeleteAdmin mock +func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error { + if m.MockDeleteAdmin != nil { + return m.MockDeleteAdmin(ctx, id) + } + return m.MockError +} diff --git a/authority/admin/db/nosql/admin.go b/authority/admin/db/nosql/admin.go new file mode 100644 index 00000000..6bb6bdd1 --- /dev/null +++ b/authority/admin/db/nosql/admin.go @@ -0,0 +1,178 @@ +package nosql + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/nosql" + "go.step.sm/linkedca" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// dbAdmin is the database representation of the Admin type. +type dbAdmin struct { + ID string `json:"id"` + AuthorityID string `json:"authorityID"` + ProvisionerID string `json:"provisionerID"` + Subject string `json:"subject"` + Type linkedca.Admin_Type `json:"type"` + CreatedAt time.Time `json:"createdAt"` + DeletedAt time.Time `json:"deletedAt"` +} + +func (dba *dbAdmin) convert() *linkedca.Admin { + return &linkedca.Admin{ + Id: dba.ID, + AuthorityId: dba.AuthorityID, + ProvisionerId: dba.ProvisionerID, + Subject: dba.Subject, + Type: dba.Type, + CreatedAt: timestamppb.New(dba.CreatedAt), + DeletedAt: timestamppb.New(dba.DeletedAt), + } +} + +func (dba *dbAdmin) clone() *dbAdmin { + u := *dba + return &u +} + +func (db *DB) getDBAdminBytes(ctx context.Context, id string) ([]byte, error) { + data, err := db.db.Get(adminsTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id) + } else if err != nil { + return nil, errors.Wrapf(err, "error loading admin %s", id) + } + return data, nil +} + +func (db *DB) unmarshalDBAdmin(data []byte, id string) (*dbAdmin, error) { + var dba = new(dbAdmin) + if err := json.Unmarshal(data, dba); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling admin %s into dbAdmin", id) + } + if !dba.DeletedAt.IsZero() { + return nil, admin.NewError(admin.ErrorDeletedType, "admin %s is deleted", id) + } + if dba.AuthorityID != db.authorityID { + return nil, admin.NewError(admin.ErrorAuthorityMismatchType, + "admin %s is not owned by authority %s", dba.ID, db.authorityID) + } + return dba, nil +} + +func (db *DB) getDBAdmin(ctx context.Context, id string) (*dbAdmin, error) { + data, err := db.getDBAdminBytes(ctx, id) + if err != nil { + return nil, err + } + dba, err := db.unmarshalDBAdmin(data, id) + if err != nil { + return nil, err + } + return dba, nil +} + +func (db *DB) unmarshalAdmin(data []byte, id string) (*linkedca.Admin, error) { + dba, err := db.unmarshalDBAdmin(data, id) + if err != nil { + return nil, err + } + return dba.convert(), nil +} + +// GetAdmin retrieves and unmarshals a admin from the database. +func (db *DB) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) { + data, err := db.getDBAdminBytes(ctx, id) + if err != nil { + return nil, err + } + adm, err := db.unmarshalAdmin(data, id) + if err != nil { + return nil, err + } + + return adm, nil +} + +// GetAdmins retrieves and unmarshals all active (not deleted) admins +// from the database. +// TODO should we be paginating? +func (db *DB) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) { + dbEntries, err := db.db.List(adminsTable) + if err != nil { + return nil, errors.Wrap(err, "error loading admins") + } + var admins = []*linkedca.Admin{} + for _, entry := range dbEntries { + adm, err := db.unmarshalAdmin(entry.Value, string(entry.Key)) + if err != nil { + switch k := err.(type) { + case *admin.Error: + if k.IsType(admin.ErrorDeletedType) || k.IsType(admin.ErrorAuthorityMismatchType) { + continue + } else { + return nil, err + } + default: + return nil, err + } + } + if adm.AuthorityId != db.authorityID { + continue + } + admins = append(admins, adm) + } + return admins, nil +} + +// CreateAdmin stores a new admin to the database. +func (db *DB) CreateAdmin(ctx context.Context, adm *linkedca.Admin) error { + var err error + adm.Id, err = randID() + if err != nil { + return admin.WrapErrorISE(err, "error generating random id for admin") + } + adm.AuthorityId = db.authorityID + + dba := &dbAdmin{ + ID: adm.Id, + AuthorityID: db.authorityID, + ProvisionerID: adm.ProvisionerId, + Subject: adm.Subject, + Type: adm.Type, + CreatedAt: clock.Now(), + } + + return db.save(ctx, dba.ID, dba, nil, "admin", adminsTable) +} + +// UpdateAdmin saves an updated admin to the database. +func (db *DB) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error { + old, err := db.getDBAdmin(ctx, adm.Id) + if err != nil { + return err + } + + nu := old.clone() + nu.Type = adm.Type + + return db.save(ctx, old.ID, nu, old, "admin", adminsTable) +} + +// DeleteAdmin saves an updated admin to the database. +func (db *DB) DeleteAdmin(ctx context.Context, id string) error { + old, err := db.getDBAdmin(ctx, id) + if err != nil { + return err + } + + nu := old.clone() + nu.DeletedAt = clock.Now() + + return db.save(ctx, old.ID, nu, old, "admin", adminsTable) +} diff --git a/authority/admin/db/nosql/admin_test.go b/authority/admin/db/nosql/admin_test.go new file mode 100644 index 00000000..4234d526 --- /dev/null +++ b/authority/admin/db/nosql/admin_test.go @@ -0,0 +1,1108 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" + "go.step.sm/linkedca" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestDB_getDBAdminBytes(t *testing.T) { + adminID := "adminID" + type test struct { + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return nil, database.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading admin adminID: force"), + } + }, + "ok": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return []byte("foo"), nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if b, err := d.getDBAdminBytes(context.Background(), adminID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, string(b), "foo") + } + }) + } +} + +func TestDB_getDBAdmin(t *testing.T) { + adminID := "adminID" + type test struct { + db nosql.DB + err error + adminErr *admin.Error + dba *dbAdmin + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return nil, database.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading admin adminID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling admin adminID into dbAdmin"), + } + }, + "fail/deleted": func(t *testing.T) test { + now := clock.Now() + dba := &dbAdmin{ + ID: adminID, + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: now, + DeletedAt: now, + } + b, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return b, nil + }, + }, + adminErr: admin.NewError(admin.ErrorDeletedType, "admin adminID is deleted"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dba := &dbAdmin{ + ID: adminID, + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: now, + } + b, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return b, nil + }, + }, + dba: dba, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if dba, err := d.getDBAdmin(context.Background(), adminID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, dba.ID, adminID) + assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID) + assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID) + assert.Equals(t, dba.Subject, tc.dba.Subject) + assert.Equals(t, dba.Type, tc.dba.Type) + assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt) + assert.Fatal(t, dba.DeletedAt.IsZero()) + } + }) + } +} + +func TestDB_unmarshalDBAdmin(t *testing.T) { + adminID := "adminID" + type test struct { + in []byte + err error + adminErr *admin.Error + dba *dbAdmin + } + var tests = map[string]func(t *testing.T) test{ + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + in: []byte("foo"), + err: errors.New("error unmarshaling admin adminID into dbAdmin"), + } + }, + "fail/deleted-error": func(t *testing.T) test { + dba := &dbAdmin{ + DeletedAt: time.Now(), + } + data, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + in: data, + adminErr: admin.NewError(admin.ErrorDeletedType, "admin adminID is deleted"), + } + }, + "fail/authority-mismatch-error": func(t *testing.T) test { + dba := &dbAdmin{ + ID: adminID, + AuthorityID: "foo", + } + data, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + in: data, + adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, + "admin %s is not owned by authority %s", adminID, admin.DefaultAuthorityID), + } + }, + "ok": func(t *testing.T) test { + dba := &dbAdmin{ + ID: adminID, + Subject: "max@smallstep.com", + ProvisionerID: "provID", + AuthorityID: admin.DefaultAuthorityID, + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: clock.Now(), + } + data, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + in: data, + dba: dba, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{authorityID: admin.DefaultAuthorityID} + if dba, err := d.unmarshalDBAdmin(tc.in, adminID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, dba.ID, adminID) + assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID) + assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID) + assert.Equals(t, dba.Subject, tc.dba.Subject) + assert.Equals(t, dba.Type, tc.dba.Type) + assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt) + assert.Fatal(t, dba.DeletedAt.IsZero()) + } + }) + } +} + +func TestDB_unmarshalAdmin(t *testing.T) { + adminID := "adminID" + type test struct { + in []byte + err error + adminErr *admin.Error + dba *dbAdmin + } + var tests = map[string]func(t *testing.T) test{ + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + in: []byte("foo"), + err: errors.New("error unmarshaling admin adminID into dbAdmin"), + } + }, + "fail/deleted-error": func(t *testing.T) test { + dba := &dbAdmin{ + DeletedAt: time.Now(), + } + data, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + in: data, + adminErr: admin.NewError(admin.ErrorDeletedType, "admin adminID is deleted"), + } + }, + "ok": func(t *testing.T) test { + dba := &dbAdmin{ + ID: adminID, + Subject: "max@smallstep.com", + ProvisionerID: "provID", + AuthorityID: admin.DefaultAuthorityID, + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: clock.Now(), + } + data, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + in: data, + dba: dba, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{authorityID: admin.DefaultAuthorityID} + if adm, err := d.unmarshalAdmin(tc.in, adminID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, adm.Id, adminID) + assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID) + assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID) + assert.Equals(t, adm.Subject, tc.dba.Subject) + assert.Equals(t, adm.Type, tc.dba.Type) + assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt)) + assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt)) + } + }) + } +} + +func TestDB_GetAdmin(t *testing.T) { + adminID := "adminID" + type test struct { + db nosql.DB + err error + adminErr *admin.Error + dba *dbAdmin + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return nil, database.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading admin adminID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling admin adminID into dbAdmin"), + } + }, + "fail/deleted": func(t *testing.T) test { + dba := &dbAdmin{ + ID: adminID, + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: clock.Now(), + DeletedAt: clock.Now(), + } + b, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return b, nil + }, + }, + dba: dba, + adminErr: admin.NewError(admin.ErrorDeletedType, "admin adminID is deleted"), + } + }, + "fail/authorityID-mismatch": func(t *testing.T) test { + dba := &dbAdmin{ + ID: adminID, + AuthorityID: "foo", + ProvisionerID: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: clock.Now(), + } + b, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return b, nil + }, + }, + dba: dba, + adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, + "admin %s is not owned by authority %s", dba.ID, admin.DefaultAuthorityID), + } + }, + "ok": func(t *testing.T) test { + dba := &dbAdmin{ + ID: adminID, + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: clock.Now(), + } + b, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return b, nil + }, + }, + dba: dba, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if adm, err := d.GetAdmin(context.Background(), adminID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, adm.Id, adminID) + assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID) + assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID) + assert.Equals(t, adm.Subject, tc.dba.Subject) + assert.Equals(t, adm.Type, tc.dba.Type) + assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt)) + assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt)) + } + }) + } +} + +func TestDB_DeleteAdmin(t *testing.T) { + adminID := "adminID" + type test struct { + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return nil, database.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading admin adminID: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + dba := &dbAdmin{ + ID: adminID, + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: clock.Now(), + } + data, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return data, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + assert.Equals(t, string(old), string(data)) + + var _dba = new(dbAdmin) + assert.FatalError(t, json.Unmarshal(nu, _dba)) + + assert.Equals(t, _dba.ID, dba.ID) + assert.Equals(t, _dba.AuthorityID, dba.AuthorityID) + assert.Equals(t, _dba.ProvisionerID, dba.ProvisionerID) + assert.Equals(t, _dba.Subject, dba.Subject) + assert.Equals(t, _dba.Type, dba.Type) + assert.Equals(t, _dba.CreatedAt, dba.CreatedAt) + + assert.True(t, _dba.DeletedAt.Before(time.Now())) + assert.True(t, _dba.DeletedAt.After(time.Now().Add(-time.Minute))) + + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving authority admin: force"), + } + }, + "ok": func(t *testing.T) test { + dba := &dbAdmin{ + ID: adminID, + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: clock.Now(), + } + data, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return data, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + assert.Equals(t, string(old), string(data)) + + var _dba = new(dbAdmin) + assert.FatalError(t, json.Unmarshal(nu, _dba)) + + assert.Equals(t, _dba.ID, dba.ID) + assert.Equals(t, _dba.AuthorityID, dba.AuthorityID) + assert.Equals(t, _dba.ProvisionerID, dba.ProvisionerID) + assert.Equals(t, _dba.Subject, dba.Subject) + assert.Equals(t, _dba.Type, dba.Type) + assert.Equals(t, _dba.CreatedAt, dba.CreatedAt) + + assert.True(t, _dba.DeletedAt.Before(time.Now())) + assert.True(t, _dba.DeletedAt.After(time.Now().Add(-time.Minute))) + + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if err := d.DeleteAdmin(context.Background(), adminID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } + }) + } +} + +func TestDB_UpdateAdmin(t *testing.T) { + adminID := "adminID" + type test struct { + db nosql.DB + err error + adminErr *admin.Error + adm *linkedca.Admin + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + adm: &linkedca.Admin{Id: adminID}, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return nil, database.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + adm: &linkedca.Admin{Id: adminID}, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading admin adminID: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + dba := &dbAdmin{ + ID: adminID, + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: clock.Now(), + } + + upd := dba.convert() + upd.Type = linkedca.Admin_ADMIN + + data, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + adm: upd, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return data, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + assert.Equals(t, string(old), string(data)) + + var _dba = new(dbAdmin) + assert.FatalError(t, json.Unmarshal(nu, _dba)) + + assert.Equals(t, _dba.ID, dba.ID) + assert.Equals(t, _dba.AuthorityID, dba.AuthorityID) + assert.Equals(t, _dba.ProvisionerID, dba.ProvisionerID) + assert.Equals(t, _dba.Subject, dba.Subject) + assert.Equals(t, _dba.Type, linkedca.Admin_ADMIN) + assert.Equals(t, _dba.CreatedAt, dba.CreatedAt) + + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving authority admin: force"), + } + }, + "ok": func(t *testing.T) test { + dba := &dbAdmin{ + ID: adminID, + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: clock.Now(), + } + + upd := dba.convert() + upd.Type = linkedca.Admin_ADMIN + + data, err := json.Marshal(dba) + assert.FatalError(t, err) + return test{ + adm: upd, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + + return data, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, string(key), adminID) + assert.Equals(t, string(old), string(data)) + + var _dba = new(dbAdmin) + assert.FatalError(t, json.Unmarshal(nu, _dba)) + + assert.Equals(t, _dba.ID, dba.ID) + assert.Equals(t, _dba.AuthorityID, dba.AuthorityID) + assert.Equals(t, _dba.ProvisionerID, dba.ProvisionerID) + assert.Equals(t, _dba.Subject, dba.Subject) + assert.Equals(t, _dba.Type, linkedca.Admin_ADMIN) + assert.Equals(t, _dba.CreatedAt, dba.CreatedAt) + + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if err := d.UpdateAdmin(context.Background(), tc.adm); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } + }) + } +} + +func TestDB_CreateAdmin(t *testing.T) { + type test struct { + db nosql.DB + err error + adminErr *admin.Error + adm *linkedca.Admin + } + var tests = map[string]func(t *testing.T) test{ + "fail/save-error": func(t *testing.T) test { + adm := &linkedca.Admin{ + AuthorityId: admin.DefaultAuthorityID, + ProvisionerId: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_ADMIN, + } + + return test{ + adm: adm, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, old, nil) + + var _dba = new(dbAdmin) + assert.FatalError(t, json.Unmarshal(nu, _dba)) + + assert.True(t, len(_dba.ID) > 0 && _dba.ID == string(key)) + assert.Equals(t, _dba.AuthorityID, adm.AuthorityId) + assert.Equals(t, _dba.ProvisionerID, adm.ProvisionerId) + assert.Equals(t, _dba.Subject, adm.Subject) + assert.Equals(t, _dba.Type, linkedca.Admin_ADMIN) + + assert.True(t, _dba.CreatedAt.Before(time.Now())) + assert.True(t, _dba.CreatedAt.After(time.Now().Add(-time.Minute))) + + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving authority admin: force"), + } + }, + "ok": func(t *testing.T) test { + adm := &linkedca.Admin{ + AuthorityId: admin.DefaultAuthorityID, + ProvisionerId: "provID", + Subject: "max@smallstep.com", + Type: linkedca.Admin_ADMIN, + } + + return test{ + adm: adm, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, adminsTable) + assert.Equals(t, old, nil) + + var _dba = new(dbAdmin) + assert.FatalError(t, json.Unmarshal(nu, _dba)) + + assert.True(t, len(_dba.ID) > 0 && _dba.ID == string(key)) + assert.Equals(t, _dba.AuthorityID, adm.AuthorityId) + assert.Equals(t, _dba.ProvisionerID, adm.ProvisionerId) + assert.Equals(t, _dba.Subject, adm.Subject) + assert.Equals(t, _dba.Type, linkedca.Admin_ADMIN) + + assert.True(t, _dba.CreatedAt.Before(time.Now())) + assert.True(t, _dba.CreatedAt.After(time.Now().Add(-time.Minute))) + + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if err := d.CreateAdmin(context.Background(), tc.adm); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } + }) + } +} + +func TestDB_GetAdmins(t *testing.T) { + now := clock.Now() + fooAdmin := &dbAdmin{ + ID: "foo", + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "foo@smallstep.com", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: now, + } + foob, err := json.Marshal(fooAdmin) + assert.FatalError(t, err) + + barAdmin := &dbAdmin{ + ID: "bar", + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "bar@smallstep.com", + Type: linkedca.Admin_ADMIN, + CreatedAt: now, + DeletedAt: now, + } + barb, err := json.Marshal(barAdmin) + assert.FatalError(t, err) + + bazAdmin := &dbAdmin{ + ID: "baz", + AuthorityID: "bazzer", + ProvisionerID: "provID", + Subject: "baz@smallstep.com", + Type: linkedca.Admin_ADMIN, + CreatedAt: now, + } + bazb, err := json.Marshal(bazAdmin) + assert.FatalError(t, err) + + zapAdmin := &dbAdmin{ + ID: "zap", + AuthorityID: admin.DefaultAuthorityID, + ProvisionerID: "provID", + Subject: "zap@smallstep.com", + Type: linkedca.Admin_ADMIN, + CreatedAt: now, + } + zapb, err := json.Marshal(zapAdmin) + assert.FatalError(t, err) + type test struct { + db nosql.DB + err error + adminErr *admin.Error + verify func(*testing.T, []*linkedca.Admin) + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.List-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*database.Entry, error) { + assert.Equals(t, bucket, adminsTable) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading admins: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + ret := []*database.Entry{ + {Bucket: adminsTable, Key: []byte("foo"), Value: foob}, + {Bucket: adminsTable, Key: []byte("bar"), Value: barb}, + {Bucket: adminsTable, Key: []byte("zap"), Value: []byte("zap")}, + } + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*database.Entry, error) { + assert.Equals(t, bucket, adminsTable) + + return ret, nil + }, + }, + err: errors.New("error unmarshaling admin zap into dbAdmin"), + } + }, + "ok/none": func(t *testing.T) test { + ret := []*database.Entry{} + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*database.Entry, error) { + assert.Equals(t, bucket, adminsTable) + + return ret, nil + }, + }, + verify: func(t *testing.T, admins []*linkedca.Admin) { + assert.Equals(t, len(admins), 0) + }, + } + }, + "ok/only-invalid": func(t *testing.T) test { + ret := []*database.Entry{ + {Bucket: adminsTable, Key: []byte("bar"), Value: barb}, + {Bucket: adminsTable, Key: []byte("baz"), Value: bazb}, + } + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*database.Entry, error) { + assert.Equals(t, bucket, adminsTable) + + return ret, nil + }, + }, + verify: func(t *testing.T, admins []*linkedca.Admin) { + assert.Equals(t, len(admins), 0) + }, + } + }, + "ok": func(t *testing.T) test { + ret := []*database.Entry{ + {Bucket: adminsTable, Key: []byte("foo"), Value: foob}, + {Bucket: adminsTable, Key: []byte("bar"), Value: barb}, + {Bucket: adminsTable, Key: []byte("baz"), Value: bazb}, + {Bucket: adminsTable, Key: []byte("zap"), Value: zapb}, + } + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*database.Entry, error) { + assert.Equals(t, bucket, adminsTable) + + return ret, nil + }, + }, + verify: func(t *testing.T, admins []*linkedca.Admin) { + assert.Equals(t, len(admins), 2) + + assert.Equals(t, admins[0].Id, fooAdmin.ID) + assert.Equals(t, admins[0].AuthorityId, fooAdmin.AuthorityID) + assert.Equals(t, admins[0].ProvisionerId, fooAdmin.ProvisionerID) + assert.Equals(t, admins[0].Subject, fooAdmin.Subject) + assert.Equals(t, admins[0].Type, fooAdmin.Type) + assert.Equals(t, admins[0].CreatedAt, timestamppb.New(fooAdmin.CreatedAt)) + assert.Equals(t, admins[0].DeletedAt, timestamppb.New(fooAdmin.DeletedAt)) + + assert.Equals(t, admins[1].Id, zapAdmin.ID) + assert.Equals(t, admins[1].AuthorityId, zapAdmin.AuthorityID) + assert.Equals(t, admins[1].ProvisionerId, zapAdmin.ProvisionerID) + assert.Equals(t, admins[1].Subject, zapAdmin.Subject) + assert.Equals(t, admins[1].Type, zapAdmin.Type) + assert.Equals(t, admins[1].CreatedAt, timestamppb.New(zapAdmin.CreatedAt)) + assert.Equals(t, admins[1].DeletedAt, timestamppb.New(zapAdmin.DeletedAt)) + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if admins, err := d.GetAdmins(context.Background()); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + tc.verify(t, admins) + } + }) + } +} diff --git a/authority/admin/db/nosql/nosql.go b/authority/admin/db/nosql/nosql.go new file mode 100644 index 00000000..22b049f5 --- /dev/null +++ b/authority/admin/db/nosql/nosql.go @@ -0,0 +1,88 @@ +package nosql + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" + nosqlDB "github.com/smallstep/nosql/database" + "go.step.sm/crypto/randutil" +) + +var ( + adminsTable = []byte("admins") + provisionersTable = []byte("provisioners") +) + +// DB is a struct that implements the AdminDB interface. +type DB struct { + db nosqlDB.DB + authorityID string +} + +// New configures and returns a new Authority DB backend implemented using a nosql DB. +func New(db nosqlDB.DB, authorityID string) (*DB, error) { + tables := [][]byte{adminsTable, provisionersTable} + for _, b := range tables { + if err := db.CreateTable(b); err != nil { + return nil, errors.Wrapf(err, "error creating table %s", + string(b)) + } + } + return &DB{db, authorityID}, nil +} + +// save writes the new data to the database, overwriting the old data if it +// existed. +func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error { + var ( + err error + newB []byte + ) + if nu == nil { + newB = nil + } else { + newB, err = json.Marshal(nu) + if err != nil { + return errors.Wrapf(err, "error marshaling authority type: %s, value: %v", typ, nu) + } + } + var oldB []byte + if old == nil { + oldB = nil + } else { + oldB, err = json.Marshal(old) + if err != nil { + return errors.Wrapf(err, "error marshaling admin type: %s, value: %v", typ, old) + } + } + + _, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB) + switch { + case err != nil: + return errors.Wrapf(err, "error saving authority %s", typ) + case !swapped: + return errors.Errorf("error saving authority %s; changed since last read", typ) + default: + return nil + } +} + +func randID() (val string, err error) { + val, err = randutil.UUIDv4() + if err != nil { + return "", errors.Wrap(err, "error generating random alphanumeric ID") + } + return val, nil +} + +// Clock that returns time in UTC rounded to seconds. +type Clock struct{} + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Truncate(time.Second) +} + +var clock = new(Clock) diff --git a/authority/admin/db/nosql/provisioner.go b/authority/admin/db/nosql/provisioner.go new file mode 100644 index 00000000..71d9c8d6 --- /dev/null +++ b/authority/admin/db/nosql/provisioner.go @@ -0,0 +1,211 @@ +package nosql + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/nosql" + "go.step.sm/linkedca" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// dbProvisioner is the database representation of a Provisioner type. +type dbProvisioner struct { + ID string `json:"id"` + AuthorityID string `json:"authorityID"` + Type linkedca.Provisioner_Type `json:"type"` + Name string `json:"name"` + Claims *linkedca.Claims `json:"claims"` + Details []byte `json:"details"` + X509Template *linkedca.Template `json:"x509Template"` + SSHTemplate *linkedca.Template `json:"sshTemplate"` + CreatedAt time.Time `json:"createdAt"` + DeletedAt time.Time `json:"deletedAt"` +} + +func (dbp *dbProvisioner) clone() *dbProvisioner { + u := *dbp + return &u +} + +func (dbp *dbProvisioner) convert2linkedca() (*linkedca.Provisioner, error) { + details, err := admin.UnmarshalProvisionerDetails(dbp.Type, dbp.Details) + if err != nil { + return nil, err + } + + return &linkedca.Provisioner{ + Id: dbp.ID, + AuthorityId: dbp.AuthorityID, + Type: dbp.Type, + Name: dbp.Name, + Claims: dbp.Claims, + Details: details, + X509Template: dbp.X509Template, + SshTemplate: dbp.SSHTemplate, + CreatedAt: timestamppb.New(dbp.CreatedAt), + DeletedAt: timestamppb.New(dbp.DeletedAt), + }, nil +} + +func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) { + data, err := db.db.Get(provisionersTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id) + } else if err != nil { + return nil, errors.Wrapf(err, "error loading provisioner %s", id) + } + return data, nil +} + +func (db *DB) unmarshalDBProvisioner(data []byte, id string) (*dbProvisioner, error) { + var dbp = new(dbProvisioner) + if err := json.Unmarshal(data, dbp); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling provisioner %s into dbProvisioner", id) + } + if !dbp.DeletedAt.IsZero() { + return nil, admin.NewError(admin.ErrorDeletedType, "provisioner %s is deleted", id) + } + if dbp.AuthorityID != db.authorityID { + return nil, admin.NewError(admin.ErrorAuthorityMismatchType, + "provisioner %s is not owned by authority %s", id, db.authorityID) + } + return dbp, nil +} + +func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner, error) { + data, err := db.getDBProvisionerBytes(ctx, id) + if err != nil { + return nil, err + } + dbp, err := db.unmarshalDBProvisioner(data, id) + if err != nil { + return nil, err + } + return dbp, nil +} + +func (db *DB) unmarshalProvisioner(data []byte, id string) (*linkedca.Provisioner, error) { + dbp, err := db.unmarshalDBProvisioner(data, id) + if err != nil { + return nil, err + } + + return dbp.convert2linkedca() +} + +// GetProvisioner retrieves and unmarshals a provisioner from the database. +func (db *DB) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) { + data, err := db.getDBProvisionerBytes(ctx, id) + if err != nil { + return nil, err + } + + prov, err := db.unmarshalProvisioner(data, id) + if err != nil { + return nil, err + } + return prov, nil +} + +// GetProvisioners retrieves and unmarshals all active (not deleted) provisioners +// from the database. +func (db *DB) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) { + dbEntries, err := db.db.List(provisionersTable) + if err != nil { + return nil, errors.Wrap(err, "error loading provisioners") + } + var provs []*linkedca.Provisioner + for _, entry := range dbEntries { + prov, err := db.unmarshalProvisioner(entry.Value, string(entry.Key)) + if err != nil { + switch k := err.(type) { + case *admin.Error: + if k.IsType(admin.ErrorDeletedType) || k.IsType(admin.ErrorAuthorityMismatchType) { + continue + } else { + return nil, err + } + default: + return nil, err + } + } + if prov.AuthorityId != db.authorityID { + continue + } + provs = append(provs, prov) + } + return provs, nil +} + +// CreateProvisioner stores a new provisioner to the database. +func (db *DB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + var err error + prov.Id, err = randID() + if err != nil { + return admin.WrapErrorISE(err, "error generating random id for provisioner") + } + + details, err := json.Marshal(prov.Details.GetData()) + if err != nil { + return admin.WrapErrorISE(err, "error marshaling details when creating provisioner %s", prov.Name) + } + + dbp := &dbProvisioner{ + ID: prov.Id, + AuthorityID: db.authorityID, + Type: prov.Type, + Name: prov.Name, + Claims: prov.Claims, + Details: details, + X509Template: prov.X509Template, + SSHTemplate: prov.SshTemplate, + CreatedAt: clock.Now(), + } + + if err := db.save(ctx, prov.Id, dbp, nil, "provisioner", provisionersTable); err != nil { + return admin.WrapErrorISE(err, "error creating provisioner %s", prov.Name) + } + + return nil +} + +// UpdateProvisioner saves an updated provisioner to the database. +func (db *DB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + old, err := db.getDBProvisioner(ctx, prov.Id) + if err != nil { + return err + } + + nu := old.clone() + + if old.Type != prov.Type { + return admin.NewError(admin.ErrorBadRequestType, "cannot update provisioner type") + } + nu.Name = prov.Name + nu.Claims = prov.Claims + nu.Details, err = json.Marshal(prov.Details.GetData()) + if err != nil { + return admin.WrapErrorISE(err, "error marshaling details when updating provisioner %s", prov.Name) + } + nu.X509Template = prov.X509Template + nu.SSHTemplate = prov.SshTemplate + + return db.save(ctx, prov.Id, nu, old, "provisioner", provisionersTable) +} + +// DeleteProvisioner saves an updated admin to the database. +func (db *DB) DeleteProvisioner(ctx context.Context, id string) error { + old, err := db.getDBProvisioner(ctx, id) + if err != nil { + return err + } + + nu := old.clone() + nu.DeletedAt = clock.Now() + + return db.save(ctx, old.ID, nu, old, "provisioner", provisionersTable) +} diff --git a/authority/admin/db/nosql/provisioner_test.go b/authority/admin/db/nosql/provisioner_test.go new file mode 100644 index 00000000..e599ea04 --- /dev/null +++ b/authority/admin/db/nosql/provisioner_test.go @@ -0,0 +1,1208 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" + "go.step.sm/linkedca" +) + +func TestDB_getDBProvisionerBytes(t *testing.T) { + provID := "provID" + type test struct { + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return nil, database.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading provisioner provID: force"), + } + }, + "ok": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return []byte("foo"), nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if b, err := d.getDBProvisionerBytes(context.Background(), provID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, string(b), "foo") + } + }) + } +} + +func TestDB_getDBProvisioner(t *testing.T) { + provID := "provID" + type test struct { + db nosql.DB + err error + adminErr *admin.Error + dbp *dbProvisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return nil, database.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading provisioner provID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling provisioner provID into dbProvisioner"), + } + }, + "fail/deleted": func(t *testing.T) test { + now := clock.Now() + dbp := &dbProvisioner{ + ID: provID, + AuthorityID: admin.DefaultAuthorityID, + Type: linkedca.Provisioner_JWK, + Name: "provName", + CreatedAt: now, + DeletedAt: now, + } + b, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return b, nil + }, + }, + adminErr: admin.NewError(admin.ErrorDeletedType, "provisioner provID is deleted"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbp := &dbProvisioner{ + ID: provID, + AuthorityID: admin.DefaultAuthorityID, + Type: linkedca.Provisioner_JWK, + Name: "provName", + CreatedAt: now, + } + b, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return b, nil + }, + }, + dbp: dbp, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if dbp, err := d.getDBProvisioner(context.Background(), provID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, dbp.ID, provID) + assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID) + assert.Equals(t, dbp.Type, tc.dbp.Type) + assert.Equals(t, dbp.Name, tc.dbp.Name) + assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt) + assert.Fatal(t, dbp.DeletedAt.IsZero()) + } + }) + } +} + +func TestDB_unmarshalDBProvisioner(t *testing.T) { + provID := "provID" + type test struct { + in []byte + err error + adminErr *admin.Error + dbp *dbProvisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + in: []byte("foo"), + err: errors.New("error unmarshaling provisioner provID into dbProvisioner"), + } + }, + "fail/deleted-error": func(t *testing.T) test { + dbp := &dbProvisioner{ + DeletedAt: clock.Now(), + } + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + in: data, + adminErr: admin.NewError(admin.ErrorDeletedType, "provisioner %s is deleted", provID), + } + }, + "fail/authority-mismatch-error": func(t *testing.T) test { + dbp := &dbProvisioner{ + ID: provID, + AuthorityID: "foo", + } + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + in: data, + adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, + "provisioner %s is not owned by authority %s", provID, admin.DefaultAuthorityID), + } + }, + "ok": func(t *testing.T) test { + dbp := &dbProvisioner{ + ID: provID, + AuthorityID: admin.DefaultAuthorityID, + Type: linkedca.Provisioner_JWK, + Name: "provName", + CreatedAt: clock.Now(), + } + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + in: data, + dbp: dbp, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{authorityID: admin.DefaultAuthorityID} + if dbp, err := d.unmarshalDBProvisioner(tc.in, provID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, dbp.ID, provID) + assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID) + assert.Equals(t, dbp.Type, tc.dbp.Type) + assert.Equals(t, dbp.Name, tc.dbp.Name) + assert.Equals(t, dbp.Details, tc.dbp.Details) + assert.Equals(t, dbp.Claims, tc.dbp.Claims) + assert.Equals(t, dbp.X509Template, tc.dbp.X509Template) + assert.Equals(t, dbp.SSHTemplate, tc.dbp.SSHTemplate) + assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt) + assert.Fatal(t, dbp.DeletedAt.IsZero()) + } + }) + } +} + +func defaultDBP(t *testing.T) *dbProvisioner { + details := &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + ForceCn: true, + }, + } + detailBytes, err := json.Marshal(details) + assert.FatalError(t, err) + + return &dbProvisioner{ + ID: "provID", + AuthorityID: admin.DefaultAuthorityID, + Type: linkedca.Provisioner_ACME, + Name: "provName", + Details: detailBytes, + Claims: &linkedca.Claims{ + DisableRenewal: true, + X509: &linkedca.X509Claims{ + Enabled: true, + Durations: &linkedca.Durations{ + Min: "5m", + Max: "12h", + Default: "6h", + }, + }, + Ssh: &linkedca.SSHClaims{ + Enabled: true, + UserDurations: &linkedca.Durations{ + Min: "5m", + Max: "12h", + Default: "6h", + }, + HostDurations: &linkedca.Durations{ + Min: "5m", + Max: "12h", + Default: "6h", + }, + }, + }, + X509Template: &linkedca.Template{ + Template: []byte("foo"), + Data: []byte("bar"), + }, + SSHTemplate: &linkedca.Template{ + Template: []byte("baz"), + Data: []byte("zap"), + }, + CreatedAt: clock.Now(), + } +} + +func TestDB_unmarshalProvisioner(t *testing.T) { + provID := "provID" + type test struct { + in []byte + err error + adminErr *admin.Error + dbp *dbProvisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + in: []byte("foo"), + err: errors.New("error unmarshaling provisioner provID into dbProvisioner"), + } + }, + "fail/deleted-error": func(t *testing.T) test { + dbp := &dbProvisioner{ + DeletedAt: time.Now(), + } + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + in: data, + adminErr: admin.NewError(admin.ErrorDeletedType, "provisioner provID is deleted"), + } + }, + "ok": func(t *testing.T) test { + dbp := defaultDBP(t) + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + in: data, + dbp: dbp, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{authorityID: admin.DefaultAuthorityID} + if prov, err := d.unmarshalProvisioner(tc.in, provID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, prov.Id, provID) + assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID) + assert.Equals(t, prov.Type, tc.dbp.Type) + assert.Equals(t, prov.Name, tc.dbp.Name) + assert.Equals(t, prov.Claims, tc.dbp.Claims) + assert.Equals(t, prov.X509Template, tc.dbp.X509Template) + assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate) + + retDetailsBytes, err := json.Marshal(prov.Details.GetData()) + assert.FatalError(t, err) + assert.Equals(t, retDetailsBytes, tc.dbp.Details) + } + }) + } +} + +func TestDB_GetProvisioner(t *testing.T) { + provID := "provID" + type test struct { + db nosql.DB + err error + adminErr *admin.Error + dbp *dbProvisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return nil, database.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading provisioner provID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling provisioner provID into dbProvisioner"), + } + }, + "fail/deleted": func(t *testing.T) test { + dbp := defaultDBP(t) + dbp.DeletedAt = clock.Now() + b, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return b, nil + }, + }, + dbp: dbp, + adminErr: admin.NewError(admin.ErrorDeletedType, "provisioner provID is deleted"), + } + }, + "fail/authorityID-mismatch": func(t *testing.T) test { + dbp := defaultDBP(t) + dbp.AuthorityID = "foo" + b, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return b, nil + }, + }, + dbp: dbp, + adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, + "provisioner %s is not owned by authority %s", dbp.ID, admin.DefaultAuthorityID), + } + }, + "ok": func(t *testing.T) test { + dbp := defaultDBP(t) + b, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return b, nil + }, + }, + dbp: dbp, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if prov, err := d.GetProvisioner(context.Background(), provID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, prov.Id, provID) + assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID) + assert.Equals(t, prov.Type, tc.dbp.Type) + assert.Equals(t, prov.Name, tc.dbp.Name) + assert.Equals(t, prov.Claims, tc.dbp.Claims) + assert.Equals(t, prov.X509Template, tc.dbp.X509Template) + assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate) + + retDetailsBytes, err := json.Marshal(prov.Details.GetData()) + assert.FatalError(t, err) + assert.Equals(t, retDetailsBytes, tc.dbp.Details) + } + }) + } +} + +func TestDB_DeleteProvisioner(t *testing.T) { + provID := "provID" + type test struct { + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return nil, database.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading provisioner provID: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + dbp := defaultDBP(t) + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return data, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + assert.Equals(t, string(old), string(data)) + + var _dbp = new(dbProvisioner) + assert.FatalError(t, json.Unmarshal(nu, _dbp)) + + assert.Equals(t, _dbp.ID, provID) + assert.Equals(t, _dbp.AuthorityID, dbp.AuthorityID) + assert.Equals(t, _dbp.Type, dbp.Type) + assert.Equals(t, _dbp.Name, dbp.Name) + assert.Equals(t, _dbp.Claims, dbp.Claims) + assert.Equals(t, _dbp.X509Template, dbp.X509Template) + assert.Equals(t, _dbp.SSHTemplate, dbp.SSHTemplate) + assert.Equals(t, _dbp.CreatedAt, dbp.CreatedAt) + assert.Equals(t, _dbp.Details, dbp.Details) + + assert.True(t, _dbp.DeletedAt.Before(time.Now())) + assert.True(t, _dbp.DeletedAt.After(time.Now().Add(-time.Minute))) + + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving authority provisioner: force"), + } + }, + "ok": func(t *testing.T) test { + dbp := defaultDBP(t) + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return data, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + assert.Equals(t, string(old), string(data)) + + var _dbp = new(dbProvisioner) + assert.FatalError(t, json.Unmarshal(nu, _dbp)) + + assert.Equals(t, _dbp.ID, provID) + assert.Equals(t, _dbp.AuthorityID, dbp.AuthorityID) + assert.Equals(t, _dbp.Type, dbp.Type) + assert.Equals(t, _dbp.Name, dbp.Name) + assert.Equals(t, _dbp.Claims, dbp.Claims) + assert.Equals(t, _dbp.X509Template, dbp.X509Template) + assert.Equals(t, _dbp.SSHTemplate, dbp.SSHTemplate) + assert.Equals(t, _dbp.CreatedAt, dbp.CreatedAt) + assert.Equals(t, _dbp.Details, dbp.Details) + + assert.True(t, _dbp.DeletedAt.Before(time.Now())) + assert.True(t, _dbp.DeletedAt.After(time.Now().Add(-time.Minute))) + + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if err := d.DeleteProvisioner(context.Background(), provID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } + }) + } +} + +func TestDB_GetProvisioners(t *testing.T) { + fooProv := defaultDBP(t) + fooProv.Name = "foo" + foob, err := json.Marshal(fooProv) + assert.FatalError(t, err) + + barProv := defaultDBP(t) + barProv.Name = "bar" + barProv.DeletedAt = clock.Now() + barb, err := json.Marshal(barProv) + assert.FatalError(t, err) + + bazProv := defaultDBP(t) + bazProv.Name = "baz" + bazProv.AuthorityID = "baz" + bazb, err := json.Marshal(bazProv) + assert.FatalError(t, err) + + zapProv := defaultDBP(t) + zapProv.Name = "zap" + zapb, err := json.Marshal(zapProv) + assert.FatalError(t, err) + + type test struct { + db nosql.DB + err error + adminErr *admin.Error + verify func(*testing.T, []*linkedca.Provisioner) + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.List-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*database.Entry, error) { + assert.Equals(t, bucket, provisionersTable) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading provisioners"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + ret := []*database.Entry{ + {Bucket: provisionersTable, Key: []byte("foo"), Value: foob}, + {Bucket: provisionersTable, Key: []byte("bar"), Value: barb}, + {Bucket: provisionersTable, Key: []byte("zap"), Value: []byte("zap")}, + } + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*database.Entry, error) { + assert.Equals(t, bucket, provisionersTable) + + return ret, nil + }, + }, + err: errors.New("error unmarshaling provisioner zap into dbProvisioner"), + } + }, + "ok/none": func(t *testing.T) test { + ret := []*database.Entry{} + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*database.Entry, error) { + assert.Equals(t, bucket, provisionersTable) + + return ret, nil + }, + }, + verify: func(t *testing.T, provs []*linkedca.Provisioner) { + assert.Equals(t, len(provs), 0) + }, + } + }, + "ok/only-invalid": func(t *testing.T) test { + ret := []*database.Entry{ + {Bucket: provisionersTable, Key: []byte("bar"), Value: barb}, + {Bucket: provisionersTable, Key: []byte("baz"), Value: bazb}, + } + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*database.Entry, error) { + assert.Equals(t, bucket, provisionersTable) + + return ret, nil + }, + }, + verify: func(t *testing.T, provs []*linkedca.Provisioner) { + assert.Equals(t, len(provs), 0) + }, + } + }, + "ok": func(t *testing.T) test { + ret := []*database.Entry{ + {Bucket: provisionersTable, Key: []byte("foo"), Value: foob}, + {Bucket: provisionersTable, Key: []byte("bar"), Value: barb}, + {Bucket: provisionersTable, Key: []byte("baz"), Value: bazb}, + {Bucket: provisionersTable, Key: []byte("zap"), Value: zapb}, + } + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*database.Entry, error) { + assert.Equals(t, bucket, provisionersTable) + + return ret, nil + }, + }, + verify: func(t *testing.T, provs []*linkedca.Provisioner) { + assert.Equals(t, len(provs), 2) + + assert.Equals(t, provs[0].Id, fooProv.ID) + assert.Equals(t, provs[0].AuthorityId, fooProv.AuthorityID) + assert.Equals(t, provs[0].Type, fooProv.Type) + assert.Equals(t, provs[0].Name, fooProv.Name) + assert.Equals(t, provs[0].Claims, fooProv.Claims) + assert.Equals(t, provs[0].X509Template, fooProv.X509Template) + assert.Equals(t, provs[0].SshTemplate, fooProv.SSHTemplate) + + retDetailsBytes, err := json.Marshal(provs[0].Details.GetData()) + assert.FatalError(t, err) + assert.Equals(t, retDetailsBytes, fooProv.Details) + + assert.Equals(t, provs[1].Id, zapProv.ID) + assert.Equals(t, provs[1].AuthorityId, zapProv.AuthorityID) + assert.Equals(t, provs[1].Type, zapProv.Type) + assert.Equals(t, provs[1].Name, zapProv.Name) + assert.Equals(t, provs[1].Claims, zapProv.Claims) + assert.Equals(t, provs[1].X509Template, zapProv.X509Template) + assert.Equals(t, provs[1].SshTemplate, zapProv.SSHTemplate) + + retDetailsBytes, err = json.Marshal(provs[1].Details.GetData()) + assert.FatalError(t, err) + assert.Equals(t, retDetailsBytes, zapProv.Details) + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if provs, err := d.GetProvisioners(context.Background()); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + tc.verify(t, provs) + } + }) + } +} + +func TestDB_CreateProvisioner(t *testing.T) { + type test struct { + db nosql.DB + err error + adminErr *admin.Error + prov *linkedca.Provisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/save-error": func(t *testing.T) test { + dbp := defaultDBP(t) + prov, err := dbp.convert2linkedca() + assert.FatalError(t, err) + + return test{ + prov: prov, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, old, nil) + + var _dbp = new(dbProvisioner) + assert.FatalError(t, json.Unmarshal(nu, _dbp)) + + assert.True(t, len(_dbp.ID) > 0 && _dbp.ID == string(key)) + assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId) + assert.Equals(t, _dbp.Type, prov.Type) + assert.Equals(t, _dbp.Name, prov.Name) + assert.Equals(t, _dbp.Claims, prov.Claims) + assert.Equals(t, _dbp.X509Template, prov.X509Template) + assert.Equals(t, _dbp.SSHTemplate, prov.SshTemplate) + + retDetailsBytes, err := json.Marshal(prov.Details.GetData()) + assert.FatalError(t, err) + assert.Equals(t, retDetailsBytes, _dbp.Details) + + assert.True(t, _dbp.DeletedAt.IsZero()) + assert.True(t, _dbp.CreatedAt.Before(time.Now())) + assert.True(t, _dbp.CreatedAt.After(time.Now().Add(-time.Minute))) + + return nil, false, errors.New("force") + }, + }, + adminErr: admin.NewErrorISE("error creating provisioner provName: error saving authority provisioner: force"), + } + }, + "ok": func(t *testing.T) test { + dbp := defaultDBP(t) + prov, err := dbp.convert2linkedca() + assert.FatalError(t, err) + + return test{ + prov: prov, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, old, nil) + + var _dbp = new(dbProvisioner) + assert.FatalError(t, json.Unmarshal(nu, _dbp)) + + assert.True(t, len(_dbp.ID) > 0 && _dbp.ID == string(key)) + assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId) + assert.Equals(t, _dbp.Type, prov.Type) + assert.Equals(t, _dbp.Name, prov.Name) + assert.Equals(t, _dbp.Claims, prov.Claims) + assert.Equals(t, _dbp.X509Template, prov.X509Template) + assert.Equals(t, _dbp.SSHTemplate, prov.SshTemplate) + + retDetailsBytes, err := json.Marshal(prov.Details.GetData()) + assert.FatalError(t, err) + assert.Equals(t, retDetailsBytes, _dbp.Details) + + assert.True(t, _dbp.DeletedAt.IsZero()) + assert.True(t, _dbp.CreatedAt.Before(time.Now())) + assert.True(t, _dbp.CreatedAt.After(time.Now().Add(-time.Minute))) + + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if err := d.CreateProvisioner(context.Background(), tc.prov); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } + }) + } +} + +func TestDB_UpdateProvisioner(t *testing.T) { + provID := "provID" + type test struct { + db nosql.DB + err error + adminErr *admin.Error + prov *linkedca.Provisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + prov: &linkedca.Provisioner{Id: provID}, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return nil, database.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + prov: &linkedca.Provisioner{Id: provID}, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading provisioner provID: force"), + } + }, + "fail/update-deleted": func(t *testing.T) test { + dbp := defaultDBP(t) + dbp.DeletedAt = clock.Now() + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + prov: &linkedca.Provisioner{Id: provID}, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return data, nil + }, + }, + adminErr: admin.NewError(admin.ErrorDeletedType, "provisioner %s is deleted", provID), + } + }, + "fail/update-type-error": func(t *testing.T) test { + dbp := defaultDBP(t) + + upd, err := dbp.convert2linkedca() + assert.FatalError(t, err) + upd.Type = linkedca.Provisioner_JWK + + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + prov: upd, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return data, nil + }, + }, + adminErr: admin.NewError(admin.ErrorBadRequestType, "cannot update provisioner type"), + } + }, + "fail/save-error": func(t *testing.T) test { + dbp := defaultDBP(t) + + prov, err := dbp.convert2linkedca() + assert.FatalError(t, err) + + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + prov: prov, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return data, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + assert.Equals(t, string(old), string(data)) + + var _dbp = new(dbProvisioner) + assert.FatalError(t, json.Unmarshal(nu, _dbp)) + + assert.True(t, len(_dbp.ID) > 0 && _dbp.ID == string(key)) + assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId) + assert.Equals(t, _dbp.Type, prov.Type) + assert.Equals(t, _dbp.Name, prov.Name) + assert.Equals(t, _dbp.Claims, prov.Claims) + assert.Equals(t, _dbp.X509Template, prov.X509Template) + assert.Equals(t, _dbp.SSHTemplate, prov.SshTemplate) + + retDetailsBytes, err := json.Marshal(prov.Details.GetData()) + assert.FatalError(t, err) + assert.Equals(t, retDetailsBytes, _dbp.Details) + + assert.True(t, _dbp.DeletedAt.IsZero()) + assert.True(t, _dbp.CreatedAt.Before(time.Now())) + assert.True(t, _dbp.CreatedAt.After(time.Now().Add(-time.Minute))) + + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving authority provisioner: force"), + } + }, + "ok": func(t *testing.T) test { + dbp := defaultDBP(t) + + prov, err := dbp.convert2linkedca() + assert.FatalError(t, err) + + prov.Name = "new-name" + prov.Claims = &linkedca.Claims{ + DisableRenewal: true, + X509: &linkedca.X509Claims{ + Enabled: true, + Durations: &linkedca.Durations{ + Min: "10m", + Max: "8h", + Default: "4h", + }, + }, + Ssh: &linkedca.SSHClaims{ + Enabled: true, + UserDurations: &linkedca.Durations{ + Min: "7m", + Max: "11h", + Default: "5h", + }, + HostDurations: &linkedca.Durations{ + Min: "4m", + Max: "24h", + Default: "24h", + }, + }, + } + prov.X509Template = &linkedca.Template{ + Template: []byte("x"), + Data: []byte("y"), + } + prov.SshTemplate = &linkedca.Template{ + Template: []byte("z"), + Data: []byte("w"), + } + prov.Details = &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + ForceCn: false, + }, + }, + } + + data, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + prov: prov, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + + return data, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, provisionersTable) + assert.Equals(t, string(key), provID) + assert.Equals(t, string(old), string(data)) + + var _dbp = new(dbProvisioner) + assert.FatalError(t, json.Unmarshal(nu, _dbp)) + + assert.True(t, len(_dbp.ID) > 0 && _dbp.ID == string(key)) + assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId) + assert.Equals(t, _dbp.Type, prov.Type) + assert.Equals(t, _dbp.Name, prov.Name) + assert.Equals(t, _dbp.Claims, prov.Claims) + assert.Equals(t, _dbp.X509Template, prov.X509Template) + assert.Equals(t, _dbp.SSHTemplate, prov.SshTemplate) + + retDetailsBytes, err := json.Marshal(prov.Details.GetData()) + assert.FatalError(t, err) + assert.Equals(t, retDetailsBytes, _dbp.Details) + + assert.True(t, _dbp.DeletedAt.IsZero()) + assert.True(t, _dbp.CreatedAt.Before(time.Now())) + assert.True(t, _dbp.CreatedAt.After(time.Now().Add(-time.Minute))) + + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if err := d.UpdateProvisioner(context.Background(), tc.prov); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } + }) + } +} diff --git a/authority/admin/errors.go b/authority/admin/errors.go new file mode 100644 index 00000000..607093b0 --- /dev/null +++ b/authority/admin/errors.go @@ -0,0 +1,223 @@ +package admin + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/logging" +) + +// ProblemType is the type of the Admin problem. +type ProblemType int + +const ( + // ErrorNotFoundType resource not found. + ErrorNotFoundType ProblemType = iota + // ErrorAuthorityMismatchType resource Authority ID does not match the + // context Authority ID. + ErrorAuthorityMismatchType + // ErrorDeletedType resource has been deleted. + ErrorDeletedType + // ErrorBadRequestType bad request. + ErrorBadRequestType + // ErrorNotImplementedType not implemented. + ErrorNotImplementedType + // ErrorUnauthorizedType internal server error. + ErrorUnauthorizedType + // ErrorServerInternalType internal server error. + ErrorServerInternalType +) + +// String returns the string representation of the admin problem type, +// fulfilling the Stringer interface. +func (ap ProblemType) String() string { + switch ap { + case ErrorNotFoundType: + return "notFound" + case ErrorAuthorityMismatchType: + return "authorityMismatch" + case ErrorDeletedType: + return "deleted" + case ErrorBadRequestType: + return "badRequest" + case ErrorNotImplementedType: + return "notImplemented" + case ErrorUnauthorizedType: + return "unauthorized" + case ErrorServerInternalType: + return "internalServerError" + default: + return fmt.Sprintf("unsupported error type '%d'", int(ap)) + } +} + +type errorMetadata struct { + details string + status int + typ string + String string +} + +var ( + errorServerInternalMetadata = errorMetadata{ + typ: ErrorServerInternalType.String(), + details: "the server experienced an internal error", + status: 500, + } + errorMap = map[ProblemType]errorMetadata{ + ErrorNotFoundType: { + typ: ErrorNotFoundType.String(), + details: "resource not found", + status: http.StatusNotFound, + }, + ErrorAuthorityMismatchType: { + typ: ErrorAuthorityMismatchType.String(), + details: "resource not owned by authority", + status: http.StatusUnauthorized, + }, + ErrorDeletedType: { + typ: ErrorDeletedType.String(), + details: "resource is deleted", + status: http.StatusNotFound, + }, + ErrorNotImplementedType: { + typ: ErrorNotImplementedType.String(), + details: "not implemented", + status: http.StatusNotImplemented, + }, + ErrorBadRequestType: { + typ: ErrorBadRequestType.String(), + details: "bad request", + status: http.StatusBadRequest, + }, + ErrorUnauthorizedType: { + typ: ErrorUnauthorizedType.String(), + details: "unauthorized", + status: http.StatusUnauthorized, + }, + ErrorServerInternalType: errorServerInternalMetadata, + } +) + +// Error represents an Admin +type Error struct { + Type string `json:"type"` + Detail string `json:"detail"` + Message string `json:"message"` + Err error `json:"-"` + Status int `json:"-"` +} + +// IsType returns true if the error type matches the input type. +func (e *Error) IsType(pt ProblemType) bool { + return pt.String() == e.Type +} + +// NewError creates a new Error type. +func NewError(pt ProblemType, msg string, args ...interface{}) *Error { + return newError(pt, errors.Errorf(msg, args...)) +} + +func newError(pt ProblemType, err error) *Error { + meta, ok := errorMap[pt] + if !ok { + meta = errorServerInternalMetadata + return &Error{ + Type: meta.typ, + Detail: meta.details, + Status: meta.status, + Err: err, + } + } + + return &Error{ + Type: meta.typ, + Detail: meta.details, + Status: meta.status, + Err: err, + } +} + +// NewErrorISE creates a new ErrorServerInternalType Error. +func NewErrorISE(msg string, args ...interface{}) *Error { + return NewError(ErrorServerInternalType, msg, args...) +} + +// WrapError attempts to wrap the internal error. +func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error { + switch e := err.(type) { + case nil: + return nil + case *Error: + if e.Err == nil { + e.Err = errors.Errorf(msg+"; "+e.Detail, args...) + } else { + e.Err = errors.Wrapf(e.Err, msg, args...) + } + return e + default: + return newError(typ, errors.Wrapf(err, msg, args...)) + } +} + +// WrapErrorISE shortcut to wrap an internal server error type. +func WrapErrorISE(err error, msg string, args ...interface{}) *Error { + return WrapError(ErrorServerInternalType, err, msg, args...) +} + +// StatusCode returns the status code and implements the StatusCoder interface. +func (e *Error) StatusCode() int { + return e.Status +} + +// Error allows AError to implement the error interface. +func (e *Error) Error() string { + return e.Err.Error() +} + +// Cause returns the internal error and implements the Causer interface. +func (e *Error) Cause() error { + if e.Err == nil { + return errors.New(e.Detail) + } + return e.Err +} + +// ToLog implements the EnableLogger interface. +func (e *Error) ToLog() (interface{}, error) { + b, err := json.Marshal(e) + if err != nil { + return nil, WrapErrorISE(err, "error marshaling authority.Error for logging") + } + return string(b), nil +} + +// WriteError writes to w a JSON representation of the given error. +func WriteError(w http.ResponseWriter, err *Error) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(err.StatusCode()) + + err.Message = err.Err.Error() + // Write errors in the response writer + if rl, ok := w.(logging.ResponseLogger); ok { + rl.WithFields(map[string]interface{}{ + "error": err.Err, + }) + if os.Getenv("STEPDEBUG") == "1" { + if e, ok := err.Err.(errs.StackTracer); ok { + rl.WithFields(map[string]interface{}{ + "stack-trace": fmt.Sprintf("%+v", e), + }) + } + } + } + + if err := json.NewEncoder(w).Encode(err); err != nil { + log.Println(err) + } +} diff --git a/authority/administrator/collection.go b/authority/administrator/collection.go new file mode 100644 index 00000000..88d7bb2c --- /dev/null +++ b/authority/administrator/collection.go @@ -0,0 +1,243 @@ +package administrator + +import ( + "sort" + "sync" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/linkedca" +) + +// DefaultAdminLimit is the default limit for listing provisioners. +const DefaultAdminLimit = 20 + +// DefaultAdminMax is the maximum limit for listing provisioners. +const DefaultAdminMax = 100 + +type adminSlice []*linkedca.Admin + +func (p adminSlice) Len() int { return len(p) } +func (p adminSlice) Less(i, j int) bool { return p[i].Id < p[j].Id } +func (p adminSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// Collection is a memory map of admins. +type Collection struct { + byID *sync.Map + bySubProv *sync.Map + byProv *sync.Map + sorted adminSlice + provisioners *provisioner.Collection + superCount int + superCountByProvisioner map[string]int +} + +// NewCollection initializes a collection of provisioners. The given list of +// audiences are the audiences used by the JWT provisioner. +func NewCollection(provisioners *provisioner.Collection) *Collection { + return &Collection{ + byID: new(sync.Map), + byProv: new(sync.Map), + bySubProv: new(sync.Map), + superCountByProvisioner: map[string]int{}, + provisioners: provisioners, + } +} + +// LoadByID a admin by the ID. +func (c *Collection) LoadByID(id string) (*linkedca.Admin, bool) { + return loadAdmin(c.byID, id) +} + +type subProv struct { + subject string + provisioner string +} + +func newSubProv(subject, prov string) subProv { + return subProv{subject, prov} +} + +// LoadBySubProv a admin by the subject and provisioner name. +func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) { + return loadAdmin(c.bySubProv, newSubProv(sub, provName)) +} + +// LoadByProvisioner a admin by the subject and provisioner name. +func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) { + val, ok := c.byProv.Load(provName) + if !ok { + return nil, false + } + admins, ok := val.([]*linkedca.Admin) + if !ok { + return nil, false + } + return admins, true +} + +// Store adds an admin to the collection and enforces the uniqueness of +// admin IDs and amdin subject <-> provisioner name combos. +func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error { + // Input validation. + if adm.ProvisionerId != prov.GetID() { + return admin.NewErrorISE("admin.provisionerId does not match provisioner argument") + } + + // Store admin always in byID. ID must be unique. + if _, loaded := c.byID.LoadOrStore(adm.Id, adm); loaded { + return errors.New("cannot add multiple admins with the same id") + } + + provName := prov.GetName() + // Store admin always in bySubProv. Subject <-> ProvisionerName must be unique. + if _, loaded := c.bySubProv.LoadOrStore(newSubProv(adm.Subject, provName), adm); loaded { + c.byID.Delete(adm.Id) + return errors.New("cannot add multiple admins with the same subject and provisioner") + } + + var isSuper = (adm.Type == linkedca.Admin_SUPER_ADMIN) + if admins, ok := c.LoadByProvisioner(provName); ok { + c.byProv.Store(provName, append(admins, adm)) + if isSuper { + c.superCountByProvisioner[provName]++ + } + } else { + c.byProv.Store(provName, []*linkedca.Admin{adm}) + if isSuper { + c.superCountByProvisioner[provName] = 1 + } + } + if isSuper { + c.superCount++ + } + + c.sorted = append(c.sorted, adm) + sort.Sort(c.sorted) + + return nil +} + +// Remove deletes an admin from all associated collections and lists. +func (c *Collection) Remove(id string) error { + adm, ok := c.LoadByID(id) + if !ok { + return admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id) + } + if adm.Type == linkedca.Admin_SUPER_ADMIN && c.SuperCount() == 1 { + return admin.NewError(admin.ErrorBadRequestType, "cannot remove the last super admin") + } + prov, ok := c.provisioners.Load(adm.ProvisionerId) + if !ok { + return admin.NewError(admin.ErrorNotFoundType, + "provisioner %s for admin %s not found", adm.ProvisionerId, id) + } + provName := prov.GetName() + adminsByProv, ok := c.LoadByProvisioner(provName) + if !ok { + return admin.NewError(admin.ErrorNotFoundType, + "admins not found for provisioner %s", provName) + } + + // Find index in sorted list. + sortedIndex := sort.Search(c.sorted.Len(), func(i int) bool { return c.sorted[i].Id >= adm.Id }) + if c.sorted[sortedIndex].Id != adm.Id { + return admin.NewError(admin.ErrorNotFoundType, + "admin %s not found in sorted list", adm.Id) + } + + var found bool + for i, a := range adminsByProv { + if a.Id == adm.Id { + // Remove admin from list. https://stackoverflow.com/questions/37334119/how-to-delete-an-element-from-a-slice-in-golang + // Order does not matter. + adminsByProv[i] = adminsByProv[len(adminsByProv)-1] + c.byProv.Store(provName, adminsByProv[:len(adminsByProv)-1]) + found = true + } + } + if !found { + return admin.NewError(admin.ErrorNotFoundType, + "admin %s not found in adminsByProvisioner list", adm.Id) + } + + // Remove index in sorted list + copy(c.sorted[sortedIndex:], c.sorted[sortedIndex+1:]) // Shift a[i+1:] left one index. + c.sorted[len(c.sorted)-1] = nil // Erase last element (write zero value). + c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice. + + c.byID.Delete(adm.Id) + c.bySubProv.Delete(newSubProv(adm.Subject, provName)) + + if adm.Type == linkedca.Admin_SUPER_ADMIN { + c.superCount-- + c.superCountByProvisioner[provName]-- + } + return nil +} + +// Update updates the given admin in all related lists and collections. +func (c *Collection) Update(id string, nu *linkedca.Admin) (*linkedca.Admin, error) { + adm, ok := c.LoadByID(id) + if !ok { + return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", adm.Id) + } + if adm.Type == nu.Type { + return adm, nil + } + if adm.Type == linkedca.Admin_SUPER_ADMIN && c.SuperCount() == 1 { + return nil, admin.NewError(admin.ErrorBadRequestType, "cannot change role of last super admin") + } + + adm.Type = nu.Type + return adm, nil +} + +// SuperCount returns the total number of admins. +func (c *Collection) SuperCount() int { + return c.superCount +} + +// SuperCountByProvisioner returns the total number of admins. +func (c *Collection) SuperCountByProvisioner(provName string) int { + if cnt, ok := c.superCountByProvisioner[provName]; ok { + return cnt + } + return 0 +} + +// Find implements pagination on a list of sorted admins. +func (c *Collection) Find(cursor string, limit int) ([]*linkedca.Admin, string) { + switch { + case limit <= 0: + limit = DefaultAdminLimit + case limit > DefaultAdminMax: + limit = DefaultAdminMax + } + + n := c.sorted.Len() + i := sort.Search(n, func(i int) bool { return c.sorted[i].Id >= cursor }) + + slice := []*linkedca.Admin{} + for ; i < n && len(slice) < limit; i++ { + slice = append(slice, c.sorted[i]) + } + + if i < n { + return slice, c.sorted[i].Id + } + return slice, "" +} + +func loadAdmin(m *sync.Map, key interface{}) (*linkedca.Admin, bool) { + val, ok := m.Load(key) + if !ok { + return nil, false + } + adm, ok := val.(*linkedca.Admin) + if !ok { + return nil, false + } + return adm, true +} diff --git a/authority/admins.go b/authority/admins.go new file mode 100644 index 00000000..b975297a --- /dev/null +++ b/authority/admins.go @@ -0,0 +1,97 @@ +package authority + +import ( + "context" + + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/linkedca" +) + +// LoadAdminByID returns an *linkedca.Admin with the given ID. +func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) { + a.adminMutex.RLock() + defer a.adminMutex.RUnlock() + return a.admins.LoadByID(id) +} + +// LoadAdminBySubProv returns an *linkedca.Admin with the given ID. +func (a *Authority) LoadAdminBySubProv(subject, prov string) (*linkedca.Admin, bool) { + a.adminMutex.RLock() + defer a.adminMutex.RUnlock() + return a.admins.LoadBySubProv(subject, prov) +} + +// GetAdmins returns a map listing each provisioner and the JWK Key Set +// with their public keys. +func (a *Authority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) { + a.adminMutex.RLock() + defer a.adminMutex.RUnlock() + admins, nextCursor := a.admins.Find(cursor, limit) + return admins, nextCursor, nil +} + +// StoreAdmin stores an *linkedca.Admin to the authority. +func (a *Authority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + if adm.ProvisionerId != prov.GetID() { + return admin.NewErrorISE("admin.provisionerId does not match provisioner argument") + } + + if _, ok := a.admins.LoadBySubProv(adm.Subject, prov.GetName()); ok { + return admin.NewError(admin.ErrorBadRequestType, + "admin with subject %s and provisioner %s already exists", adm.Subject, prov.GetName()) + } + // Store to database -- this will set the ID. + if err := a.adminDB.CreateAdmin(ctx, adm); err != nil { + return admin.WrapErrorISE(err, "error creating admin") + } + if err := a.admins.Store(adm, prov); err != nil { + if err := a.reloadAdminResources(ctx); err != nil { + return admin.WrapErrorISE(err, "error reloading admin resources on failed admin store") + } + return admin.WrapErrorISE(err, "error storing admin in authority cache") + } + return nil +} + +// UpdateAdmin stores an *linkedca.Admin to the authority. +func (a *Authority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + adm, err := a.admins.Update(id, nu) + if err != nil { + return nil, admin.WrapErrorISE(err, "error updating cached admin %s", id) + } + if err := a.adminDB.UpdateAdmin(ctx, adm); err != nil { + if err := a.reloadAdminResources(ctx); err != nil { + return nil, admin.WrapErrorISE(err, "error reloading admin resources on failed admin update") + } + return nil, admin.WrapErrorISE(err, "error updating admin %s", id) + } + return adm, nil +} + +// RemoveAdmin removes an *linkedca.Admin from the authority. +func (a *Authority) RemoveAdmin(ctx context.Context, id string) error { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + return a.removeAdmin(ctx, id) +} + +// removeAdmin helper that assumes lock. +func (a *Authority) removeAdmin(ctx context.Context, id string) error { + if err := a.admins.Remove(id); err != nil { + return admin.WrapErrorISE(err, "error removing admin %s from authority cache", id) + } + if err := a.adminDB.DeleteAdmin(ctx, id); err != nil { + if err := a.reloadAdminResources(ctx); err != nil { + return admin.WrapErrorISE(err, "error reloading admin resources on failed admin remove") + } + return admin.WrapErrorISE(err, "error deleting admin %s", id) + } + return nil +} diff --git a/authority/authority.go b/authority/authority.go index a2dca6e0..aa8698d7 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -7,39 +7,47 @@ import ( "crypto/x509" "encoding/hex" "log" + "strings" "sync" "time" - "github.com/smallstep/certificates/cas" - "github.com/smallstep/certificates/scep" - "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" + adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql" + "github.com/smallstep/certificates/authority/administrator" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/kms" kmsapi "github.com/smallstep/certificates/kms/apiv1" "github.com/smallstep/certificates/kms/sshagentkms" + "github.com/smallstep/certificates/scep" "github.com/smallstep/certificates/templates" + "github.com/smallstep/nosql" "go.step.sm/crypto/pemutil" + "go.step.sm/linkedca" "golang.org/x/crypto/ssh" ) -const ( - legacyAuthority = "step-certificate-authority" -) - // Authority implements the Certificate Authority internal interface. type Authority struct { - config *Config - keyManager kms.KeyManager - provisioners *provisioner.Collection - db db.AuthDB - templates *templates.Templates + config *config.Config + keyManager kms.KeyManager + provisioners *provisioner.Collection + admins *administrator.Collection + db db.AuthDB + adminDB admin.DB + templates *templates.Templates + linkedCAToken string // X509 CA + password []byte + issuerPassword []byte x509CAService cas.CertificateAuthorityService rootX509Certs []*x509.Certificate + rootX509CertPool *x509.CertPool federatedX509Certs []*x509.Certificate certificates *sync.Map @@ -47,6 +55,8 @@ type Authority struct { scepService *scep.Service // SSH CA + sshHostPassword []byte + sshUserPassword []byte sshCAUserCertSignKey ssh.Signer sshCAHostCertSignKey ssh.Signer sshCAUserCerts []ssh.PublicKey @@ -59,21 +69,23 @@ type Authority struct { startTime time.Time // Custom functions - sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error) + sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) - sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]Host, error) + sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) getIdentityFunc provisioner.GetIdentityFunc + + adminMutex sync.RWMutex } // New creates and initiates a new Authority type. -func New(config *Config, opts ...Option) (*Authority, error) { - err := config.Validate() +func New(cfg *config.Config, opts ...Option) (*Authority, error) { + err := cfg.Validate() if err != nil { return nil, err } var a = &Authority{ - config: config, + config: cfg, certificates: new(sync.Map), } @@ -96,7 +108,7 @@ func New(config *Config, opts ...Option) (*Authority, error) { // project without the limitations of the config. func NewEmbedded(opts ...Option) (*Authority, error) { a := &Authority{ - config: &Config{}, + config: &config.Config{}, certificates: new(sync.Map), } @@ -120,7 +132,7 @@ func NewEmbedded(opts ...Option) (*Authority, error) { } // Initialize config required fields. - a.config.init() + a.config.Init() // Initialize authority from options or configuration. if err := a.init(); err != nil { @@ -130,6 +142,65 @@ func NewEmbedded(opts ...Option) (*Authority, error) { return a, nil } +// reloadAdminResources reloads admins and provisioners from the DB. +func (a *Authority) reloadAdminResources(ctx context.Context) error { + var ( + provList provisioner.List + adminList []*linkedca.Admin + ) + if a.config.AuthorityConfig.EnableAdmin { + provs, err := a.adminDB.GetProvisioners(ctx) + if err != nil { + return admin.WrapErrorISE(err, "error getting provisioners to initialize authority") + } + provList, err = provisionerListToCertificates(provs) + if err != nil { + return admin.WrapErrorISE(err, "error converting provisioner list to certificates") + } + adminList, err = a.adminDB.GetAdmins(ctx) + if err != nil { + return admin.WrapErrorISE(err, "error getting admins to initialize authority") + } + } else { + provList = a.config.AuthorityConfig.Provisioners + adminList = a.config.AuthorityConfig.Admins + } + + provisionerConfig, err := a.generateProvisionerConfig(ctx) + if err != nil { + return admin.WrapErrorISE(err, "error generating provisioner config") + } + + // Create provisioner collection. + provClxn := provisioner.NewCollection(provisionerConfig.Audiences) + for _, p := range provList { + if err := p.Init(*provisionerConfig); err != nil { + return err + } + if err := provClxn.Store(p); err != nil { + return err + } + } + // Create admin collection. + adminClxn := administrator.NewCollection(provClxn) + for _, adm := range adminList { + p, ok := provClxn.Load(adm.ProvisionerId) + if !ok { + return admin.NewErrorISE("provisioner %s not found when loading admin %s", + adm.ProvisionerId, adm.Id) + } + if err := adminClxn.Store(adm, p); err != nil { + return err + } + } + + a.config.AuthorityConfig.Provisioners = provList + a.provisioners = provClxn + a.config.AuthorityConfig.Admins = adminList + a.admins = adminClxn + return nil +} + // init performs validation and initializes the fields of an Authority struct. func (a *Authority) init() error { // Check if handler has already been validated/initialized. @@ -139,6 +210,26 @@ func (a *Authority) init() error { var err error + // Set password if they are not set. + var configPassword []byte + if a.config.Password != "" { + configPassword = []byte(a.config.Password) + } + if configPassword != nil && a.password == nil { + a.password = configPassword + } + if a.sshHostPassword == nil { + a.sshHostPassword = a.password + } + if a.sshUserPassword == nil { + a.sshUserPassword = a.password + } + + // Automatically enable admin for all linked cas. + if a.linkedCAToken != "" { + a.config.AuthorityConfig.EnableAdmin = true + } + // Initialize step-ca Database if it's not already initialized with WithDB. // If a.config.DB is nil then a simple, barebones in memory DB will be used. if a.db == nil { @@ -166,6 +257,11 @@ func (a *Authority) init() error { options = *a.config.AuthorityConfig.Options } + // Set the issuer password if passed in the flags. + if options.CertificateIssuer != nil && a.issuerPassword != nil { + options.CertificateIssuer.Password = string(a.issuerPassword) + } + // Read intermediate and create X509 signer for default CAS. if options.Is(casapi.SoftCAS) { options.CertificateChain, err = pemutil.ReadCertificateBundle(a.config.IntermediateCert) @@ -174,7 +270,7 @@ func (a *Authority) init() error { } options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ SigningKey: a.config.IntermediateKey, - Password: []byte(a.config.Password), + Password: []byte(a.password), }) if err != nil { return err @@ -216,6 +312,11 @@ func (a *Authority) init() error { a.certificates.Store(hex.EncodeToString(sum[:]), crt) } + a.rootX509CertPool = x509.NewCertPool() + for _, cert := range a.rootX509Certs { + a.rootX509CertPool.AddCert(cert) + } + // Read federated certificates and store them in the certificates map. if len(a.federatedX509Certs) == 0 { a.federatedX509Certs = make([]*x509.Certificate, len(a.config.FederatedRoots)) @@ -238,7 +339,7 @@ func (a *Authority) init() error { if a.config.SSH.HostKey != "" { signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ SigningKey: a.config.SSH.HostKey, - Password: []byte(a.config.Password), + Password: []byte(a.sshHostPassword), }) if err != nil { return err @@ -264,7 +365,7 @@ func (a *Authority) init() error { if a.config.SSH.UserKey != "" { signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ SigningKey: a.config.SSH.UserKey, - Password: []byte(a.config.Password), + Password: []byte(a.sshUserPassword), }) if err != nil { return err @@ -288,59 +389,52 @@ func (a *Authority) init() error { a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, a.sshCAUserCertSignKey.PublicKey()) } - // Append other public keys + // Append other public keys and add them to the template variables. for _, key := range a.config.SSH.Keys { + publicKey := key.PublicKey() switch key.Type { case provisioner.SSHHostCert: if key.Federated { - a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, key.PublicKey()) + a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, publicKey) } else { - a.sshCAHostCerts = append(a.sshCAHostCerts, key.PublicKey()) + a.sshCAHostCerts = append(a.sshCAHostCerts, publicKey) } case provisioner.SSHUserCert: if key.Federated { - a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, key.PublicKey()) + a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, publicKey) } else { - a.sshCAUserCerts = append(a.sshCAUserCerts, key.PublicKey()) + a.sshCAUserCerts = append(a.sshCAUserCerts, publicKey) } default: return errors.Errorf("unsupported type %s", key.Type) } } + } - // Configure template variables. + // Configure template variables. On the template variables HostFederatedKeys + // and UserFederatedKeys we will skip the actual CA that will be available + // in HostKey and UserKey. + // + // We cannot do it in the previous blocks because this configuration can be + // injected using options. + if a.sshCAHostCertSignKey != nil { tmplVars.SSH.HostKey = a.sshCAHostCertSignKey.PublicKey() - tmplVars.SSH.UserKey = a.sshCAUserCertSignKey.PublicKey() - // On the templates we skip the first one because there's a distinction - // between the main key and federated keys. tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts[1:]...) + } else { + tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts...) + } + if a.sshCAUserCertSignKey != nil { + tmplVars.SSH.UserKey = a.sshCAUserCertSignKey.PublicKey() tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts[1:]...) + } else { + tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts...) } - // Merge global and configuration claims - claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, globalProvisionerClaims) - if err != nil { - return err - } - // TODO: should we also be combining the ssh federated roots here? - // If we rotate ssh roots keys, sshpop provisioner will lose ability to - // validate old SSH certificates, unless they are added as federated certs. - sshKeys, err := a.GetSSHRoots(context.Background()) - if err != nil { - return err - } - // Initialize provisioners - audiences := a.config.getAudiences() - a.provisioners = provisioner.NewCollection(audiences) - config := provisioner.Config{ - Claims: claimer.Claims(), - Audiences: audiences, - DB: a.db, - SSHKeys: &provisioner.SSHKeys{ - UserKeys: sshKeys.UserKeys, - HostKeys: sshKeys.HostKeys, - }, - GetIdentityFunc: a.getIdentityFunc, + // Check if a KMS with decryption capability is required and available + if a.requiresDecrypter() { + if _, ok := a.keyManager.(kmsapi.Decrypter); !ok { + return errors.New("keymanager doesn't provide crypto.Decrypter") + } } // Check if a KMS with decryption capability is required and available @@ -362,7 +456,7 @@ func (a *Authority) init() error { } options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ SigningKey: a.config.IntermediateKey, - Password: []byte(a.config.Password), + Password: []byte(a.password), }) if err != nil { return err @@ -371,7 +465,7 @@ func (a *Authority) init() error { if km, ok := a.keyManager.(kmsapi.Decrypter); ok { options.Decrypter, err = km.CreateDecrypter(&kmsapi.CreateDecrypterRequest{ DecryptionKey: a.config.IntermediateKey, - Password: []byte(a.config.Password), + Password: []byte(a.password), }) if err != nil { return err @@ -386,14 +480,56 @@ func (a *Authority) init() error { // TODO: mimick the x509CAService GetCertificateAuthority here too? } - // Store all the provisioners - for _, p := range a.config.AuthorityConfig.Provisioners { - if err := p.Init(config); err != nil { - return err + if a.config.AuthorityConfig.EnableAdmin { + // Initialize step-ca Admin Database if it's not already initialized using + // WithAdminDB. + if a.adminDB == nil { + if a.linkedCAToken == "" { + // Check if AuthConfig already exists + a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID) + if err != nil { + return err + } + } else { + // Use the linkedca client as the admindb. + client, err := newLinkedCAClient(a.linkedCAToken) + if err != nil { + return err + } + // If authorityId is configured make sure it matches the one in the token + if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) { + return errors.New("error initializing linkedca: token authority and configured authority do not match") + } + client.Run() + a.adminDB = client + } } - if err := a.provisioners.Store(p); err != nil { - return err + + provs, err := a.adminDB.GetProvisioners(context.Background()) + if err != nil { + return admin.WrapErrorISE(err, "error loading provisioners to initialize authority") } + if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") { + // Create First Provisioner + prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password)) + if err != nil { + return admin.WrapErrorISE(err, "error creating first provisioner") + } + + // Create first admin + if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{ + ProvisionerId: prov.Id, + Subject: "step", + Type: linkedca.Admin_SUPER_ADMIN, + }); err != nil { + return admin.WrapErrorISE(err, "error creating first admin") + } + } + } + + // Load Provisioners and Admins + if err := a.reloadAdminResources(context.Background()); err != nil { + return err } // Configure templates, currently only ssh templates are supported. @@ -423,6 +559,17 @@ func (a *Authority) GetDatabase() db.AuthDB { return a.db } +// GetAdminDatabase returns the admin database, if one exists. +func (a *Authority) GetAdminDatabase() admin.DB { + return a.adminDB +} + +// IsAdminAPIEnabled returns a boolean indicating whether the Admin API has +// been enabled. +func (a *Authority) IsAdminAPIEnabled() bool { + return a.config.AuthorityConfig.EnableAdmin +} + // Shutdown safely shuts down any clients, databases, etc. held by the Authority. func (a *Authority) Shutdown() error { if err := a.keyManager.Close(); err != nil { @@ -436,6 +583,9 @@ func (a *Authority) CloseForReload() { if err := a.keyManager.Close(); err != nil { log.Printf("error closing the key manager: %v", err) } + if client, ok := a.adminDB.(*linkedCaClient); ok { + client.Stop() + } } // requiresDecrypter returns whether the Authority diff --git a/authority/authority_test.go b/authority/authority_test.go index 618e7939..1e18a24f 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -11,6 +11,7 @@ import ( "net" "reflect" "testing" + "time" "github.com/pkg/errors" "github.com/smallstep/assert" @@ -82,6 +83,10 @@ func testAuthority(t *testing.T, opts ...Option) *Authority { } a, err := New(c, opts...) assert.FatalError(t, err) + // Avoid errors when test tokens are created before the test authority. This + // happens in some tests where we re-create the same authority to test + // special cases without re-creating the token. + a.startTime = a.startTime.Add(-1 * time.Minute) return a } @@ -454,8 +459,6 @@ func TestAuthority_GetSCEPService(t *testing.T) { // getIdentityFunc: tt.fields.getIdentityFunc, // } a, err := New(tt.fields.config) - fmt.Println(err) - fmt.Println(a) if (err != nil) != tt.wantErr { t.Errorf("Authority.New(), error = %v, wantErr %v", err, tt.wantErr) return diff --git a/authority/authorize.go b/authority/authorize.go index f84bd6f5..a4e7e591 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -6,11 +6,15 @@ import ( "crypto/x509" "encoding/hex" "net/http" + "strconv" "strings" + "time" + "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + "go.step.sm/linkedca" "golang.org/x/crypto/ssh" ) @@ -50,7 +54,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // key in order to verify the claims and we need the issuer from the claims // before we can look up the provisioner. var claims Claims - if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil { + if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken") } @@ -73,25 +77,124 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // Store the token to protect against reuse unless it's skipped. // If we cannot get a token id from the provisioner, just hash the token. if !SkipTokenReuseFromContext(ctx) { - if reuseKey, err := p.GetTokenID(token); err == nil { - if reuseKey == "" { - sum := sha256.Sum256([]byte(token)) - reuseKey = strings.ToLower(hex.EncodeToString(sum[:])) - } - ok, err := a.db.UseToken(reuseKey, token) - if err != nil { - return nil, errs.Wrap(http.StatusInternalServerError, err, - "authority.authorizeToken: failed when attempting to store token") - } - if !ok { - return nil, errs.Unauthorized("authority.authorizeToken: token already used") - } + if err := a.UseToken(token, p); err != nil { + return nil, err } } return p, nil } +// AuthorizeAdminToken authorize an Admin token. +func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) { + jwt, err := jose.ParseSigned(token) + if err != nil { + return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c token") + } + + verifiedChains, err := jwt.Headers[0].Certificates(x509.VerifyOptions{ + Roots: a.rootX509CertPool, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + if err != nil { + return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, + "adminHandler.authorizeToken; error verifying x5c certificate chain in token") + } + leaf := verifiedChains[0][0] + + if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 { + return nil, admin.NewError(admin.ErrorUnauthorizedType, "adminHandler.authorizeToken; certificate used to sign x5c token cannot be used for digital signature") + } + + // Using the leaf certificates key to validate the claims accomplishes two + // things: + // 1. Asserts that the private key used to sign the token corresponds + // to the public certificate in the `x5c` header of the token. + // 2. Asserts that the claims are valid - have not been tampered with. + var claims jose.Claims + if err := jwt.Claims(leaf.PublicKey, &claims); err != nil { + return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims") + } + + prov, err := a.LoadProvisionerByCertificate(leaf) + if err != nil { + return nil, err + } + + // Check that the token has not been used. + if err := a.UseToken(token, prov); err != nil { + return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token") + } + + // According to "rfc7519 JSON Web Token" acceptable skew should be no + // more than a few minutes. + if err := claims.ValidateWithLeeway(jose.Expected{ + Issuer: prov.GetName(), + Time: time.Now().UTC(), + }, time.Minute); err != nil { + return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "x5c.authorizeToken; invalid x5c claims") + } + + // validate audience: path matches the current path + if r.URL.Path != claims.Audience[0] { + return nil, admin.NewError(admin.ErrorUnauthorizedType, + "x5c.authorizeToken; x5c token has invalid audience "+ + "claim (aud); expected %s, but got %s", r.URL.Path, claims.Audience) + } + + if claims.Subject == "" { + return nil, admin.NewError(admin.ErrorUnauthorizedType, + "x5c.authorizeToken; x5c token subject cannot be empty") + } + + var ( + ok bool + adm *linkedca.Admin + ) + adminFound := false + adminSANs := append([]string{leaf.Subject.CommonName}, leaf.DNSNames...) + adminSANs = append(adminSANs, leaf.EmailAddresses...) + for _, san := range adminSANs { + if adm, ok = a.LoadAdminBySubProv(san, claims.Issuer); ok { + adminFound = true + break + } + } + if !adminFound { + return nil, admin.NewError(admin.ErrorUnauthorizedType, + "adminHandler.authorizeToken; unable to load admin with subject(s) %s and provisioner '%s'", + adminSANs, claims.Issuer) + } + + if strings.HasPrefix(r.URL.Path, "/admin/admins") && (r.Method != "GET") && adm.Type != linkedca.Admin_SUPER_ADMIN { + return nil, admin.NewError(admin.ErrorUnauthorizedType, "must have super admin access to make this request") + } + + return adm, nil +} + +// UseToken stores the token to protect against reuse. +// +// This method currently ignores any error coming from the GetTokenID, but it +// should specifically ignore the error provisioner.ErrAllowTokenReuse. +func (a *Authority) UseToken(token string, prov provisioner.Interface) error { + if reuseKey, err := prov.GetTokenID(token); err == nil { + if reuseKey == "" { + sum := sha256.Sum256([]byte(token)) + reuseKey = strings.ToLower(hex.EncodeToString(sum[:])) + } + ok, err := a.db.UseToken(reuseKey, token) + if err != nil { + return errs.Wrap(http.StatusInternalServerError, err, + "authority.authorizeToken: failed when attempting to store token") + } + if !ok { + return errs.Unauthorized("authority.authorizeToken: token already used") + } + } + return nil +} + // Authorize grabs the method from the context and authorizes the request by // validating the one-time-token. func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) { @@ -159,7 +262,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke") } - if err = p.AuthorizeRevoke(ctx, token); err != nil { + if err := p.AuthorizeRevoke(ctx, token); err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke") } return nil @@ -171,10 +274,19 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { // // TODO(mariano): should we authorize by default? func (a *Authority) authorizeRenew(cert *x509.Certificate) error { + var err error + var isRevoked bool var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())} // Check the passive revocation table. - isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String()) + serial := cert.SerialNumber.String() + if lca, ok := a.adminDB.(interface { + IsRevoked(string) (bool, error) + }); ok { + isRevoked, err = lca.IsRevoked(serial) + } else { + isRevoked, err = a.db.IsRevoked(serial) + } if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) } @@ -192,6 +304,28 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error { return nil } +// authorizeSSHCertificate returns an error if the given certificate is revoked. +func (a *Authority) authorizeSSHCertificate(ctx context.Context, cert *ssh.Certificate) error { + var err error + var isRevoked bool + + serial := strconv.FormatUint(cert.Serial, 10) + if lca, ok := a.adminDB.(interface { + IsSSHRevoked(string) (bool, error) + }); ok { + isRevoked, err = lca.IsSSHRevoked(serial) + } else { + isRevoked, err = a.db.IsSSHRevoked(serial) + } + if err != nil { + return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHCertificate", errs.WithKeyVal("serialNumber", serial)) + } + if isRevoked { + return errs.Unauthorized("authority.authorizeSSHCertificate: certificate has been revoked", errs.WithKeyVal("serialNumber", serial)) + } + return nil +} + // authorizeSSHSign loads the provisioner from the token, checks that it has not // been used again and calls the provisioner AuthorizeSSHSign method. Returns a // list of methods to apply to the signing flow. diff --git a/authority/authorize_test.go b/authority/authorize_test.go index f20e2976..6d524a25 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -822,7 +822,7 @@ func TestAuthority_authorizeRenew(t *testing.T) { return &authorizeTest{ auth: a, cert: renewDisabledCrt, - err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), + err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"), code: http.StatusUnauthorized, } }, @@ -917,7 +917,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, if err != nil { return nil, nil, err } - if err = cert.SignCert(rand.Reader, signer); err != nil { + if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } return cert, jwk, nil diff --git a/authority/config.go b/authority/config.go index 6baec855..744ca5e7 100644 --- a/authority/config.go +++ b/authority/config.go @@ -1,298 +1,46 @@ package authority -import ( - "encoding/json" - "fmt" - "net" - "os" - "time" +import "github.com/smallstep/certificates/authority/config" - "github.com/pkg/errors" - "github.com/smallstep/certificates/authority/provisioner" - cas "github.com/smallstep/certificates/cas/apiv1" - "github.com/smallstep/certificates/db" - kms "github.com/smallstep/certificates/kms/apiv1" - "github.com/smallstep/certificates/templates" -) +// Config is an alias to support older APIs. +type Config = config.Config -var ( - // DefaultTLSOptions represents the default TLS version as well as the cipher - // suites used in the TLS certificates. - DefaultTLSOptions = TLSOptions{ - CipherSuites: CipherSuites{ - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", - }, - MinVersion: 1.2, - MaxVersion: 1.2, - Renegotiation: false, - } - defaultBackdate = time.Minute - defaultDisableRenewal = false - defaultEnableSSHCA = false - globalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, - MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &defaultEnableSSHCA, - } -) +// LoadConfiguration is an alias to support older APIs. +var LoadConfiguration = config.LoadConfiguration -// Config represents the CA configuration and it's mapped to a JSON object. -type Config struct { - Root multiString `json:"root"` - FederatedRoots []string `json:"federatedRoots"` - IntermediateCert string `json:"crt"` - IntermediateKey string `json:"key"` - Address string `json:"address"` - InsecureAddress string `json:"insecureAddress"` - DNSNames []string `json:"dnsNames"` - KMS *kms.Options `json:"kms,omitempty"` - SSH *SSHConfig `json:"ssh,omitempty"` - Logger json.RawMessage `json:"logger,omitempty"` - DB *db.Config `json:"db,omitempty"` - Monitoring json.RawMessage `json:"monitoring,omitempty"` - AuthorityConfig *AuthConfig `json:"authority,omitempty"` - TLS *TLSOptions `json:"tls,omitempty"` - Password string `json:"password,omitempty"` - Templates *templates.Templates `json:"templates,omitempty"` -} +// AuthConfig is an alias to support older APIs. +type AuthConfig = config.AuthConfig -// ASN1DN contains ASN1.DN attributes that are used in Subject and Issuer -// x509 Certificate blocks. -type ASN1DN struct { - Country string `json:"country,omitempty" step:"country"` - Organization string `json:"organization,omitempty" step:"organization"` - OrganizationalUnit string `json:"organizationalUnit,omitempty" step:"organizationalUnit"` - Locality string `json:"locality,omitempty" step:"locality"` - Province string `json:"province,omitempty" step:"province"` - StreetAddress string `json:"streetAddress,omitempty" step:"streetAddress"` - CommonName string `json:"commonName,omitempty" step:"commonName"` -} +// TLS -// AuthConfig represents the configuration options for the authority. An -// underlaying registration authority can also be configured using the -// cas.Options. -type AuthConfig struct { - *cas.Options - Provisioners provisioner.List `json:"provisioners"` - Template *ASN1DN `json:"template,omitempty"` - Claims *provisioner.Claims `json:"claims,omitempty"` - DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` - Backdate *provisioner.Duration `json:"backdate,omitempty"` -} +// ASN1DN is an alias to support older APIs. +type ASN1DN = config.ASN1DN -// init initializes the required fields in the AuthConfig if they are not -// provided. -func (c *AuthConfig) init() { - if c.Provisioners == nil { - c.Provisioners = provisioner.List{} - } - if c.Template == nil { - c.Template = &ASN1DN{} - } - if c.Backdate == nil { - c.Backdate = &provisioner.Duration{ - Duration: defaultBackdate, - } - } -} +// DefaultTLSOptions is an alias to support older APIs. +var DefaultTLSOptions = config.DefaultTLSOptions -// Validate validates the authority configuration. -func (c *AuthConfig) Validate(audiences provisioner.Audiences) error { - if c == nil { - return errors.New("authority cannot be undefined") - } +// TLSOptions is an alias to support older APIs. +type TLSOptions = config.TLSOptions - // Initialize required fields. - c.init() +// CipherSuites is an alias to support older APIs. +type CipherSuites = config.CipherSuites - // Check that only one K8sSA is enabled - var k8sCount int - for _, p := range c.Provisioners { - if p.GetType() == provisioner.TypeK8sSA { - k8sCount++ - } - } - if k8sCount > 1 { - return errors.New("cannot have more than one kubernetes service account provisioner") - } +// SSH - if c.Backdate.Duration < 0 { - return errors.New("authority.backdate cannot be less than 0") - } +// SSHConfig is an alias to support older APIs. +type SSHConfig = config.SSHConfig - return nil -} +// Bastion is an alias to support older APIs. +type Bastion = config.Bastion -// LoadConfiguration parses the given filename in JSON format and returns the -// configuration struct. -func LoadConfiguration(filename string) (*Config, error) { - f, err := os.Open(filename) - if err != nil { - return nil, errors.Wrapf(err, "error opening %s", filename) - } - defer f.Close() +// HostTag is an alias to support older APIs. +type HostTag = config.HostTag - var c Config - if err := json.NewDecoder(f).Decode(&c); err != nil { - return nil, errors.Wrapf(err, "error parsing %s", filename) - } +// Host is an alias to support older APIs. +type Host = config.Host - c.init() +// SSHPublicKey is an alias to support older APIs. +type SSHPublicKey = config.SSHPublicKey - return &c, nil -} - -// initializes the minimal configuration required to create an authority. This -// is mainly used on embedded authorities. -func (c *Config) init() { - if c.DNSNames == nil { - c.DNSNames = []string{"localhost", "127.0.0.1", "::1"} - } - if c.TLS == nil { - c.TLS = &DefaultTLSOptions - } - if c.AuthorityConfig == nil { - c.AuthorityConfig = &AuthConfig{} - } - c.AuthorityConfig.init() -} - -// Save saves the configuration to the given filename. -func (c *Config) Save(filename string) error { - f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return errors.Wrapf(err, "error opening %s", filename) - } - defer f.Close() - - enc := json.NewEncoder(f) - enc.SetIndent("", "\t") - return errors.Wrapf(enc.Encode(c), "error writing %s", filename) -} - -// Validate validates the configuration. -func (c *Config) Validate() error { - switch { - case c.Address == "": - return errors.New("address cannot be empty") - - case len(c.DNSNames) == 0: - return errors.New("dnsNames cannot be empty") - } - - // Options holds the RA/CAS configuration. - ra := c.AuthorityConfig.Options - // The default RA/CAS requires root, crt and key. - if ra.Is(cas.SoftCAS) { - switch { - case c.Root.HasEmpties(): - return errors.New("root cannot be empty") - case c.IntermediateCert == "": - return errors.New("crt cannot be empty") - case c.IntermediateKey == "": - return errors.New("key cannot be empty") - } - } - - // Validate address (a port is required) - if _, _, err := net.SplitHostPort(c.Address); err != nil { - return errors.Errorf("invalid address %s", c.Address) - } - - // Validate insecure address if it is configured - if c.InsecureAddress != "" { - if _, _, err := net.SplitHostPort(c.InsecureAddress); err != nil { - return errors.Errorf("invalid address %s", c.InsecureAddress) - } - } - - if c.TLS == nil { - c.TLS = &DefaultTLSOptions - } else { - if len(c.TLS.CipherSuites) == 0 { - c.TLS.CipherSuites = DefaultTLSOptions.CipherSuites - } - if c.TLS.MaxVersion == 0 { - c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion - } - if c.TLS.MinVersion == 0 { - c.TLS.MinVersion = c.TLS.MaxVersion - } - if c.TLS.MinVersion > c.TLS.MaxVersion { - return errors.New("tls minVersion cannot exceed tls maxVersion") - } - c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation - } - - // Validate KMS options, nil is ok. - if err := c.KMS.Validate(); err != nil { - return err - } - - // Validate RA/CAS options, nil is ok. - if err := ra.Validate(); err != nil { - return err - } - - // Validate ssh: nil is ok - if err := c.SSH.Validate(); err != nil { - return err - } - - // Validate templates: nil is ok - if err := c.Templates.Validate(); err != nil { - return err - } - - return c.AuthorityConfig.Validate(c.getAudiences()) -} - -// getAudiences returns the legacy and possible urls without the ports that will -// be used as the default provisioner audiences. The CA might have proxies in -// front so we cannot rely on the port. -func (c *Config) getAudiences() provisioner.Audiences { - audiences := provisioner.Audiences{ - Sign: []string{legacyAuthority}, - Revoke: []string{legacyAuthority}, - SSHSign: []string{}, - SSHRevoke: []string{}, - SSHRenew: []string{}, - } - - for _, name := range c.DNSNames { - audiences.Sign = append(audiences.Sign, - fmt.Sprintf("https://%s/1.0/sign", name), - fmt.Sprintf("https://%s/sign", name), - fmt.Sprintf("https://%s/1.0/ssh/sign", name), - fmt.Sprintf("https://%s/ssh/sign", name)) - audiences.Revoke = append(audiences.Revoke, - fmt.Sprintf("https://%s/1.0/revoke", name), - fmt.Sprintf("https://%s/revoke", name)) - audiences.SSHSign = append(audiences.SSHSign, - fmt.Sprintf("https://%s/1.0/ssh/sign", name), - fmt.Sprintf("https://%s/ssh/sign", name), - fmt.Sprintf("https://%s/1.0/sign", name), - fmt.Sprintf("https://%s/sign", name)) - audiences.SSHRevoke = append(audiences.SSHRevoke, - fmt.Sprintf("https://%s/1.0/ssh/revoke", name), - fmt.Sprintf("https://%s/ssh/revoke", name)) - audiences.SSHRenew = append(audiences.SSHRenew, - fmt.Sprintf("https://%s/1.0/ssh/renew", name), - fmt.Sprintf("https://%s/ssh/renew", name)) - audiences.SSHRekey = append(audiences.SSHRekey, - fmt.Sprintf("https://%s/1.0/ssh/rekey", name), - fmt.Sprintf("https://%s/ssh/rekey", name)) - } - - return audiences -} +// SSHKeys is an alias to support older APIs. +type SSHKeys = config.SSHKeys diff --git a/authority/config/config.go b/authority/config/config.go new file mode 100644 index 00000000..75c32994 --- /dev/null +++ b/authority/config/config.go @@ -0,0 +1,297 @@ +package config + +import ( + "encoding/json" + "fmt" + "net" + "os" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + cas "github.com/smallstep/certificates/cas/apiv1" + "github.com/smallstep/certificates/db" + kms "github.com/smallstep/certificates/kms/apiv1" + "github.com/smallstep/certificates/templates" + "go.step.sm/linkedca" +) + +const ( + legacyAuthority = "step-certificate-authority" +) + +var ( + // DefaultBackdate length of time to backdate certificates to avoid + // clock skew validation issues. + DefaultBackdate = time.Minute + // DefaultDisableRenewal disables renewals per provisioner. + DefaultDisableRenewal = false + // DefaultEnableSSHCA enable SSH CA features per provisioner or globally + // for all provisioners. + DefaultEnableSSHCA = false + // GlobalProvisionerClaims default claims for the Authority. Can be overridden + // by provisioner specific claims. + GlobalProvisionerClaims = provisioner.Claims{ + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DisableRenewal: &DefaultDisableRenewal, + MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &DefaultEnableSSHCA, + } +) + +// Config represents the CA configuration and it's mapped to a JSON object. +type Config struct { + Root multiString `json:"root"` + FederatedRoots []string `json:"federatedRoots"` + IntermediateCert string `json:"crt"` + IntermediateKey string `json:"key"` + Address string `json:"address"` + InsecureAddress string `json:"insecureAddress"` + DNSNames []string `json:"dnsNames"` + KMS *kms.Options `json:"kms,omitempty"` + SSH *SSHConfig `json:"ssh,omitempty"` + Logger json.RawMessage `json:"logger,omitempty"` + DB *db.Config `json:"db,omitempty"` + Monitoring json.RawMessage `json:"monitoring,omitempty"` + AuthorityConfig *AuthConfig `json:"authority,omitempty"` + TLS *TLSOptions `json:"tls,omitempty"` + Password string `json:"password,omitempty"` + Templates *templates.Templates `json:"templates,omitempty"` +} + +// ASN1DN contains ASN1.DN attributes that are used in Subject and Issuer +// x509 Certificate blocks. +type ASN1DN struct { + Country string `json:"country,omitempty"` + Organization string `json:"organization,omitempty"` + OrganizationalUnit string `json:"organizationalUnit,omitempty"` + Locality string `json:"locality,omitempty"` + Province string `json:"province,omitempty"` + StreetAddress string `json:"streetAddress,omitempty"` + SerialNumber string `json:"serialNumber,omitempty"` + CommonName string `json:"commonName,omitempty"` +} + +// AuthConfig represents the configuration options for the authority. An +// underlaying registration authority can also be configured using the +// cas.Options. +type AuthConfig struct { + *cas.Options + AuthorityID string `json:"authorityId,omitempty"` + DeploymentType string `json:"deploymentType,omitempty"` + Provisioners provisioner.List `json:"provisioners,omitempty"` + Admins []*linkedca.Admin `json:"-"` + Template *ASN1DN `json:"template,omitempty"` + Claims *provisioner.Claims `json:"claims,omitempty"` + DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` + Backdate *provisioner.Duration `json:"backdate,omitempty"` + EnableAdmin bool `json:"enableAdmin,omitempty"` +} + +// init initializes the required fields in the AuthConfig if they are not +// provided. +func (c *AuthConfig) init() { + if c.Provisioners == nil { + c.Provisioners = provisioner.List{} + } + if c.Template == nil { + c.Template = &ASN1DN{} + } + if c.Backdate == nil { + c.Backdate = &provisioner.Duration{ + Duration: DefaultBackdate, + } + } +} + +// Validate validates the authority configuration. +func (c *AuthConfig) Validate(audiences provisioner.Audiences) error { + if c == nil { + return errors.New("authority cannot be undefined") + } + + // Initialize required fields. + c.init() + + // Check that only one K8sSA is enabled + var k8sCount int + for _, p := range c.Provisioners { + if p.GetType() == provisioner.TypeK8sSA { + k8sCount++ + } + } + if k8sCount > 1 { + return errors.New("cannot have more than one kubernetes service account provisioner") + } + + if c.Backdate.Duration < 0 { + return errors.New("authority.backdate cannot be less than 0") + } + + return nil +} + +// LoadConfiguration parses the given filename in JSON format and returns the +// configuration struct. +func LoadConfiguration(filename string) (*Config, error) { + f, err := os.Open(filename) + if err != nil { + return nil, errors.Wrapf(err, "error opening %s", filename) + } + defer f.Close() + + var c Config + if err := json.NewDecoder(f).Decode(&c); err != nil { + return nil, errors.Wrapf(err, "error parsing %s", filename) + } + + c.Init() + + return &c, nil +} + +// Init initializes the minimal configuration required to create an authority. This +// is mainly used on embedded authorities. +func (c *Config) Init() { + if c.DNSNames == nil { + c.DNSNames = []string{"localhost", "127.0.0.1", "::1"} + } + if c.TLS == nil { + c.TLS = &DefaultTLSOptions + } + if c.AuthorityConfig == nil { + c.AuthorityConfig = &AuthConfig{} + } + c.AuthorityConfig.init() +} + +// Save saves the configuration to the given filename. +func (c *Config) Save(filename string) error { + f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return errors.Wrapf(err, "error opening %s", filename) + } + defer f.Close() + + enc := json.NewEncoder(f) + enc.SetIndent("", "\t") + return errors.Wrapf(enc.Encode(c), "error writing %s", filename) +} + +// Validate validates the configuration. +func (c *Config) Validate() error { + switch { + case c.Address == "": + return errors.New("address cannot be empty") + case len(c.DNSNames) == 0: + return errors.New("dnsNames cannot be empty") + case c.AuthorityConfig == nil: + return errors.New("authority cannot be nil") + } + + // Options holds the RA/CAS configuration. + ra := c.AuthorityConfig.Options + // The default RA/CAS requires root, crt and key. + if ra.Is(cas.SoftCAS) { + switch { + case c.Root.HasEmpties(): + return errors.New("root cannot be empty") + case c.IntermediateCert == "": + return errors.New("crt cannot be empty") + case c.IntermediateKey == "": + return errors.New("key cannot be empty") + } + } + + // Validate address (a port is required) + if _, _, err := net.SplitHostPort(c.Address); err != nil { + return errors.Errorf("invalid address %s", c.Address) + } + + if c.TLS == nil { + c.TLS = &DefaultTLSOptions + } else { + if len(c.TLS.CipherSuites) == 0 { + c.TLS.CipherSuites = DefaultTLSOptions.CipherSuites + } + if c.TLS.MaxVersion == 0 { + c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion + } + if c.TLS.MinVersion == 0 { + c.TLS.MinVersion = DefaultTLSOptions.MinVersion + } + if c.TLS.MinVersion > c.TLS.MaxVersion { + return errors.New("tls minVersion cannot exceed tls maxVersion") + } + c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation + } + + // Validate KMS options, nil is ok. + if err := c.KMS.Validate(); err != nil { + return err + } + + // Validate RA/CAS options, nil is ok. + if err := ra.Validate(); err != nil { + return err + } + + // Validate ssh: nil is ok + if err := c.SSH.Validate(); err != nil { + return err + } + + // Validate templates: nil is ok + if err := c.Templates.Validate(); err != nil { + return err + } + + return c.AuthorityConfig.Validate(c.GetAudiences()) +} + +// GetAudiences returns the legacy and possible urls without the ports that will +// be used as the default provisioner audiences. The CA might have proxies in +// front so we cannot rely on the port. +func (c *Config) GetAudiences() provisioner.Audiences { + audiences := provisioner.Audiences{ + Sign: []string{legacyAuthority}, + Revoke: []string{legacyAuthority}, + SSHSign: []string{}, + SSHRevoke: []string{}, + SSHRenew: []string{}, + } + + for _, name := range c.DNSNames { + audiences.Sign = append(audiences.Sign, + fmt.Sprintf("https://%s/1.0/sign", name), + fmt.Sprintf("https://%s/sign", name), + fmt.Sprintf("https://%s/1.0/ssh/sign", name), + fmt.Sprintf("https://%s/ssh/sign", name)) + audiences.Revoke = append(audiences.Revoke, + fmt.Sprintf("https://%s/1.0/revoke", name), + fmt.Sprintf("https://%s/revoke", name)) + audiences.SSHSign = append(audiences.SSHSign, + fmt.Sprintf("https://%s/1.0/ssh/sign", name), + fmt.Sprintf("https://%s/ssh/sign", name), + fmt.Sprintf("https://%s/1.0/sign", name), + fmt.Sprintf("https://%s/sign", name)) + audiences.SSHRevoke = append(audiences.SSHRevoke, + fmt.Sprintf("https://%s/1.0/ssh/revoke", name), + fmt.Sprintf("https://%s/ssh/revoke", name)) + audiences.SSHRenew = append(audiences.SSHRenew, + fmt.Sprintf("https://%s/1.0/ssh/renew", name), + fmt.Sprintf("https://%s/ssh/renew", name)) + audiences.SSHRekey = append(audiences.SSHRekey, + fmt.Sprintf("https://%s/1.0/ssh/rekey", name), + fmt.Sprintf("https://%s/ssh/rekey", name)) + } + + return audiences +} diff --git a/authority/config_test.go b/authority/config/config_test.go similarity index 74% rename from authority/config_test.go rename to authority/config/config_test.go index 87cd3fba..a5b60513 100644 --- a/authority/config_test.go +++ b/authority/config/config_test.go @@ -1,4 +1,4 @@ -package authority +package config import ( "fmt" @@ -8,12 +8,14 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" + + _ "github.com/smallstep/certificates/cas" ) func TestConfigValidate(t *testing.T) { - maxjwk, err := jose.ReadKey("testdata/secrets/max_pub.jwk") + maxjwk, err := jose.ReadKey("../testdata/secrets/max_pub.jwk") assert.FatalError(t, err) - clijwk, err := jose.ReadKey("testdata/secrets/step_cli_key_pub.jwk") + clijwk, err := jose.ReadKey("../testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) ac := &AuthConfig{ Provisioners: provisioner.List{ @@ -39,9 +41,9 @@ func TestConfigValidate(t *testing.T) { "empty-address": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ - Root: []string{"testdata/secrets/root_ca.crt"}, - IntermediateCert: "testdata/secrets/intermediate_ca.crt", - IntermediateKey: "testdata/secrets/intermediate_ca_key", + Root: []string{"../testdata/secrets/root_ca.crt"}, + IntermediateCert: "../testdata/secrets/intermediate_ca.crt", + IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, @@ -53,9 +55,9 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1", - Root: []string{"testdata/secrets/root_ca.crt"}, - IntermediateCert: "testdata/secrets/intermediate_ca.crt", - IntermediateKey: "testdata/secrets/intermediate_ca_key", + Root: []string{"../testdata/secrets/root_ca.crt"}, + IntermediateCert: "../testdata/secrets/intermediate_ca.crt", + IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, @@ -67,8 +69,8 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - IntermediateCert: "testdata/secrets/intermediate_ca.crt", - IntermediateKey: "testdata/secrets/intermediate_ca_key", + IntermediateCert: "../testdata/secrets/intermediate_ca.crt", + IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, @@ -80,8 +82,8 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: []string{"testdata/secrets/root_ca.crt"}, - IntermediateKey: "testdata/secrets/intermediate_ca_key", + Root: []string{"../testdata/secrets/root_ca.crt"}, + IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, @@ -93,8 +95,8 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: []string{"testdata/secrets/root_ca.crt"}, - IntermediateCert: "testdata/secrets/intermediate_ca.crt", + Root: []string{"../testdata/secrets/root_ca.crt"}, + IntermediateCert: "../testdata/secrets/intermediate_ca.crt", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, @@ -106,9 +108,9 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: []string{"testdata/secrets/root_ca.crt"}, - IntermediateCert: "testdata/secrets/intermediate_ca.crt", - IntermediateKey: "testdata/secrets/intermediate_ca_key", + Root: []string{"../testdata/secrets/root_ca.crt"}, + IntermediateCert: "../testdata/secrets/intermediate_ca.crt", + IntermediateKey: "../testdata/secrets/intermediate_ca_key", Password: "pass", AuthorityConfig: ac, }, @@ -119,9 +121,9 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: []string{"testdata/secrets/root_ca.crt"}, - IntermediateCert: "testdata/secrets/intermediate_ca.crt", - IntermediateKey: "testdata/secrets/intermediate_ca_key", + Root: []string{"../testdata/secrets/root_ca.crt"}, + IntermediateCert: "../testdata/secrets/intermediate_ca.crt", + IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, @@ -133,9 +135,9 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: []string{"testdata/secrets/root_ca.crt"}, - IntermediateCert: "testdata/secrets/intermediate_ca.crt", - IntermediateKey: "testdata/secrets/intermediate_ca_key", + Root: []string{"../testdata/secrets/root_ca.crt"}, + IntermediateCert: "../testdata/secrets/intermediate_ca.crt", + IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, @@ -148,9 +150,9 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: []string{"testdata/secrets/root_ca.crt"}, - IntermediateCert: "testdata/secrets/intermediate_ca.crt", - IntermediateKey: "testdata/secrets/intermediate_ca_key", + Root: []string{"../testdata/secrets/root_ca.crt"}, + IntermediateCert: "../testdata/secrets/intermediate_ca.crt", + IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, @@ -177,9 +179,9 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: []string{"testdata/secrets/root_ca.crt"}, - IntermediateCert: "testdata/secrets/intermediate_ca.crt", - IntermediateKey: "testdata/secrets/intermediate_ca_key", + Root: []string{"../testdata/secrets/root_ca.crt"}, + IntermediateCert: "../testdata/secrets/intermediate_ca.crt", + IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, @@ -207,6 +209,8 @@ func TestConfigValidate(t *testing.T) { } } else { if assert.Nil(t, tc.err) { + fmt.Printf("tc.tls = %+v\n", tc.tls) + fmt.Printf("*tc.config.TLS = %+v\n", *tc.config.TLS) assert.Equals(t, *tc.config.TLS, tc.tls) } } @@ -224,9 +228,9 @@ func TestAuthConfigValidate(t *testing.T) { CommonName: "test", } - maxjwk, err := jose.ReadKey("testdata/secrets/max_pub.jwk") + maxjwk, err := jose.ReadKey("../testdata/secrets/max_pub.jwk") assert.FatalError(t, err) - clijwk, err := jose.ReadKey("testdata/secrets/step_cli_key_pub.jwk") + clijwk, err := jose.ReadKey("../testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) p := provisioner.List{ &provisioner.JWK{ diff --git a/authority/config/ssh.go b/authority/config/ssh.go new file mode 100644 index 00000000..4ba1bb38 --- /dev/null +++ b/authority/config/ssh.go @@ -0,0 +1,94 @@ +package config + +import ( + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/crypto/jose" + "golang.org/x/crypto/ssh" +) + +// SSHConfig contains the user and host keys. +type SSHConfig struct { + HostKey string `json:"hostKey"` + UserKey string `json:"userKey"` + Keys []*SSHPublicKey `json:"keys,omitempty"` + AddUserPrincipal string `json:"addUserPrincipal,omitempty"` + AddUserCommand string `json:"addUserCommand,omitempty"` + Bastion *Bastion `json:"bastion,omitempty"` +} + +// Bastion contains the custom properties used on bastion. +type Bastion struct { + Hostname string `json:"hostname"` + User string `json:"user,omitempty"` + Port string `json:"port,omitempty"` + Command string `json:"cmd,omitempty"` + Flags string `json:"flags,omitempty"` +} + +// HostTag are tagged with k,v pairs. These tags are how a user is ultimately +// associated with a host. +type HostTag struct { + ID string + Name string + Value string +} + +// Host defines expected attributes for an ssh host. +type Host struct { + HostID string `json:"hid"` + HostTags []HostTag `json:"host_tags"` + Hostname string `json:"hostname"` +} + +// Validate checks the fields in SSHConfig. +func (c *SSHConfig) Validate() error { + if c == nil { + return nil + } + for _, k := range c.Keys { + if err := k.Validate(); err != nil { + return err + } + } + return nil +} + +// SSHPublicKey contains a public key used by federated CAs to keep old signing +// keys for this ca. +type SSHPublicKey struct { + Type string `json:"type"` + Federated bool `json:"federated"` + Key jose.JSONWebKey `json:"key"` + publicKey ssh.PublicKey +} + +// Validate checks the fields in SSHPublicKey. +func (k *SSHPublicKey) Validate() error { + switch { + case k.Type == "": + return errors.New("type cannot be empty") + case k.Type != provisioner.SSHHostCert && k.Type != provisioner.SSHUserCert: + return errors.Errorf("invalid type %s, it must be user or host", k.Type) + case !k.Key.IsPublic(): + return errors.New("invalid key type, it must be a public key") + } + + key, err := ssh.NewPublicKey(k.Key.Key) + if err != nil { + return errors.Wrap(err, "error creating ssh key") + } + k.publicKey = key + return nil +} + +// PublicKey returns the ssh public key. +func (k *SSHPublicKey) PublicKey() ssh.PublicKey { + return k.publicKey +} + +// SSHKeys represents the SSH User and Host public keys. +type SSHKeys struct { + UserKeys []ssh.PublicKey + HostKeys []ssh.PublicKey +} diff --git a/authority/config/ssh_test.go b/authority/config/ssh_test.go new file mode 100644 index 00000000..2c4c8eac --- /dev/null +++ b/authority/config/ssh_test.go @@ -0,0 +1,73 @@ +package config + +import ( + "reflect" + "testing" + + "github.com/smallstep/assert" + "go.step.sm/crypto/jose" + "golang.org/x/crypto/ssh" +) + +func TestSSHPublicKey_Validate(t *testing.T) { + key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0) + assert.FatalError(t, err) + + type fields struct { + Type string + Federated bool + Key jose.JSONWebKey + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"user", fields{"user", true, key.Public()}, false}, + {"host", fields{"host", false, key.Public()}, false}, + {"empty", fields{"", true, key.Public()}, true}, + {"badType", fields{"bad", false, key.Public()}, true}, + {"badKey", fields{"user", false, *key}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &SSHPublicKey{ + Type: tt.fields.Type, + Federated: tt.fields.Federated, + Key: tt.fields.Key, + } + if err := k.Validate(); (err != nil) != tt.wantErr { + t.Errorf("SSHPublicKey.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSSHPublicKey_PublicKey(t *testing.T) { + key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0) + assert.FatalError(t, err) + pub, err := ssh.NewPublicKey(key.Public().Key) + assert.FatalError(t, err) + + type fields struct { + publicKey ssh.PublicKey + } + tests := []struct { + name string + fields fields + want ssh.PublicKey + }{ + {"ok", fields{pub}, pub}, + {"nil", fields{nil}, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &SSHPublicKey{ + publicKey: tt.fields.publicKey, + } + if got := k.PublicKey(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("SSHPublicKey.PublicKey() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/tls_options.go b/authority/config/tls_options.go similarity index 56% rename from authority/tls_options.go rename to authority/config/tls_options.go index 3edde605..0db202e5 100644 --- a/authority/tls_options.go +++ b/authority/config/tls_options.go @@ -1,4 +1,4 @@ -package authority +package config import ( "crypto/tls" @@ -15,8 +15,9 @@ var ( // DefaultTLSRenegotiation default TLS connection renegotiation policy. DefaultTLSRenegotiation = false // Never regnegotiate. // DefaultTLSCipherSuites specifies default step ciphersuite(s). + // These are TLS 1.0 - 1.2 cipher suites. DefaultTLSCipherSuites = CipherSuites{ - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", } // ApprovedTLSCipherSuites smallstep approved ciphersuites. @@ -26,13 +27,21 @@ var ( "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", + } + // DefaultTLSOptions represents the default TLS version as well as the cipher + // suites used in the TLS certificates. + DefaultTLSOptions = TLSOptions{ + CipherSuites: DefaultTLSCipherSuites, + MinVersion: DefaultTLSMinVersion, + MaxVersion: DefaultTLSMaxVersion, + Renegotiation: DefaultTLSRenegotiation, } ) @@ -107,27 +116,38 @@ func (c CipherSuites) Value() []uint16 { // cipherSuites has the list of supported cipher suites. var cipherSuites = map[string]uint16{ - "TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA, - "TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, - "TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, - "TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, - "TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256, - "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, - "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + // TLS 1.0 - 1.2 cipher suites. + "TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA, + "TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + "TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, + "TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, + "TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256, + "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + "TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + + // TLS 1.3 cipher sutes. + "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, + "TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384, + "TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256, + + // Legacy names. + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, } // TLSOptions represents the TLS options that can be specified on *tls.Config diff --git a/authority/tls_options_test.go b/authority/config/tls_options_test.go similarity index 99% rename from authority/tls_options_test.go rename to authority/config/tls_options_test.go index 96c58c5d..d7ccb20b 100644 --- a/authority/tls_options_test.go +++ b/authority/config/tls_options_test.go @@ -1,4 +1,4 @@ -package authority +package config import ( "crypto/tls" diff --git a/authority/types.go b/authority/config/types.go similarity index 97% rename from authority/types.go rename to authority/config/types.go index 0d0f2a90..5ca3b15f 100644 --- a/authority/types.go +++ b/authority/config/types.go @@ -1,4 +1,4 @@ -package authority +package config import ( "encoding/json" @@ -25,7 +25,7 @@ func (s multiString) HasEmpties() bool { return true } for _, ss := range s { - if len(ss) == 0 { + if ss == "" { return true } } diff --git a/authority/types_test.go b/authority/config/types_test.go similarity index 99% rename from authority/types_test.go rename to authority/config/types_test.go index 352c253f..b1a874d6 100644 --- a/authority/types_test.go +++ b/authority/config/types_test.go @@ -1,4 +1,4 @@ -package authority +package config import ( "reflect" diff --git a/authority/export.go b/authority/export.go new file mode 100644 index 00000000..8a5a257f --- /dev/null +++ b/authority/export.go @@ -0,0 +1,284 @@ +package authority + +import ( + "encoding/json" + "io/ioutil" + "net/url" + "path/filepath" + "strings" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/cli-utils/config" + "go.step.sm/linkedca" + "google.golang.org/protobuf/types/known/structpb" +) + +// Export creates a linkedca configuration form the current ca.json and loaded +// authorities. +// +// Note that export will not export neither the pki password nor the certificate +// issuer password. +func (a *Authority) Export() (c *linkedca.Configuration, err error) { + // Recover from panics + defer func() { + if r := recover(); r != nil { + err = r.(error) + } + }() + + files := make(map[string][]byte) + + // The exported configuration should not include the password in it. + c = &linkedca.Configuration{ + Version: "1.0", + Root: mustReadFilesOrURIs(a.config.Root, files), + FederatedRoots: mustReadFilesOrURIs(a.config.FederatedRoots, files), + Intermediate: mustReadFileOrURI(a.config.IntermediateCert, files), + IntermediateKey: mustReadFileOrURI(a.config.IntermediateKey, files), + Address: a.config.Address, + InsecureAddress: a.config.InsecureAddress, + DnsNames: a.config.DNSNames, + Db: mustMarshalToStruct(a.config.DB), + Logger: mustMarshalToStruct(a.config.Logger), + Monitoring: mustMarshalToStruct(a.config.Monitoring), + Authority: &linkedca.Authority{ + Id: a.config.AuthorityConfig.AuthorityID, + EnableAdmin: a.config.AuthorityConfig.EnableAdmin, + DisableIssuedAtCheck: a.config.AuthorityConfig.DisableIssuedAtCheck, + Backdate: mustDuration(a.config.AuthorityConfig.Backdate), + DeploymentType: a.config.AuthorityConfig.DeploymentType, + }, + Files: files, + } + + // SSH + if v := a.config.SSH; v != nil { + c.Ssh = &linkedca.SSH{ + HostKey: mustReadFileOrURI(v.HostKey, files), + UserKey: mustReadFileOrURI(v.UserKey, files), + AddUserPrincipal: v.AddUserPrincipal, + AddUserCommand: v.AddUserCommand, + } + for _, k := range v.Keys { + typ, ok := linkedca.SSHPublicKey_Type_value[strings.ToUpper(k.Type)] + if !ok { + return nil, errors.Errorf("unsupported ssh key type %s", k.Type) + } + c.Ssh.Keys = append(c.Ssh.Keys, &linkedca.SSHPublicKey{ + Type: linkedca.SSHPublicKey_Type(typ), + Federated: k.Federated, + Key: mustMarshalToStruct(k), + }) + } + if b := v.Bastion; b != nil { + c.Ssh.Bastion = &linkedca.Bastion{ + Hostname: b.Hostname, + User: b.User, + Port: b.Port, + Command: b.Command, + Flags: b.Flags, + } + } + } + + // KMS + if v := a.config.KMS; v != nil { + var typ int32 + var ok bool + if v.Type == "" { + typ = int32(linkedca.KMS_SOFTKMS) + } else { + typ, ok = linkedca.KMS_Type_value[strings.ToUpper(v.Type)] + if !ok { + return nil, errors.Errorf("unsupported kms type %s", v.Type) + } + } + c.Kms = &linkedca.KMS{ + Type: linkedca.KMS_Type(typ), + CredentialsFile: v.CredentialsFile, + Uri: v.URI, + Pin: v.Pin, + ManagementKey: v.ManagementKey, + Region: v.Region, + Profile: v.Profile, + } + } + + // Authority + // cas options + if v := a.config.AuthorityConfig.Options; v != nil { + c.Authority.Type = 0 + c.Authority.CertificateAuthority = v.CertificateAuthority + c.Authority.CertificateAuthorityFingerprint = v.CertificateAuthorityFingerprint + c.Authority.CredentialsFile = v.CredentialsFile + if iss := v.CertificateIssuer; iss != nil { + typ, ok := linkedca.CertificateIssuer_Type_value[strings.ToUpper(iss.Type)] + if !ok { + return nil, errors.Errorf("unknown certificate issuer type %s", iss.Type) + } + // The exported certificate issuer should not include the password. + c.Authority.CertificateIssuer = &linkedca.CertificateIssuer{ + Type: linkedca.CertificateIssuer_Type(typ), + Provisioner: iss.Provisioner, + Certificate: mustReadFileOrURI(iss.Certificate, files), + Key: mustReadFileOrURI(iss.Key, files), + } + } + } + // admins + for { + list, cursor := a.admins.Find("", 100) + c.Authority.Admins = append(c.Authority.Admins, list...) + if cursor == "" { + break + } + } + // provisioners + for { + list, cursor := a.provisioners.Find("", 100) + for _, p := range list { + lp, err := ProvisionerToLinkedca(p) + if err != nil { + return nil, err + } + c.Authority.Provisioners = append(c.Authority.Provisioners, lp) + } + if cursor == "" { + break + } + } + // global claims + c.Authority.Claims = claimsToLinkedca(a.config.AuthorityConfig.Claims) + // Distinguished names template + if v := a.config.AuthorityConfig.Template; v != nil { + c.Authority.Template = &linkedca.DistinguishedName{ + Country: v.Country, + Organization: v.Organization, + OrganizationalUnit: v.OrganizationalUnit, + Locality: v.Locality, + Province: v.Province, + StreetAddress: v.StreetAddress, + SerialNumber: v.SerialNumber, + CommonName: v.CommonName, + } + } + + // TLS + if v := a.config.TLS; v != nil { + c.Tls = &linkedca.TLS{ + MinVersion: v.MinVersion.String(), + MaxVersion: v.MaxVersion.String(), + Renegotiation: v.Renegotiation, + } + for _, cs := range v.CipherSuites.Value() { + c.Tls.CipherSuites = append(c.Tls.CipherSuites, linkedca.TLS_CiperSuite(cs)) + } + } + + // Templates + if v := a.config.Templates; v != nil { + c.Templates = &linkedca.ConfigTemplates{ + Ssh: &linkedca.SSHConfigTemplate{}, + Data: mustMarshalToStruct(v.Data), + } + // Remove automatically loaded vars + if c.Templates.Data != nil && c.Templates.Data.Fields != nil { + delete(c.Templates.Data.Fields, "Step") + } + for _, t := range v.SSH.Host { + typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))] + if !ok { + return nil, errors.Errorf("unsupported template type %s", t.Type) + } + c.Templates.Ssh.Hosts = append(c.Templates.Ssh.Hosts, &linkedca.ConfigTemplate{ + Type: linkedca.ConfigTemplate_Type(typ), + Name: t.Name, + Template: mustReadFileOrURI(t.TemplatePath, files), + Path: t.Path, + Comment: t.Comment, + Requires: t.RequiredData, + Content: t.Content, + }) + } + for _, t := range v.SSH.User { + typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))] + if !ok { + return nil, errors.Errorf("unsupported template type %s", t.Type) + } + c.Templates.Ssh.Users = append(c.Templates.Ssh.Users, &linkedca.ConfigTemplate{ + Type: linkedca.ConfigTemplate_Type(typ), + Name: t.Name, + Template: mustReadFileOrURI(t.TemplatePath, files), + Path: t.Path, + Comment: t.Comment, + Requires: t.RequiredData, + Content: t.Content, + }) + } + } + + return c, nil +} + +func mustDuration(d *provisioner.Duration) string { + if d == nil || d.Duration == 0 { + return "" + } + return d.String() +} + +func mustMarshalToStruct(v interface{}) *structpb.Struct { + b, err := json.Marshal(v) + if err != nil { + panic(errors.Wrapf(err, "error marshaling %T", v)) + } + var r *structpb.Struct + if err := json.Unmarshal(b, &r); err != nil { + panic(errors.Wrapf(err, "error unmarshaling %T", v)) + } + return r +} + +func mustReadFileOrURI(fn string, m map[string][]byte) string { + if fn == "" { + return "" + } + + stepPath := filepath.ToSlash(config.StepPath()) + if !strings.HasSuffix(stepPath, "/") { + stepPath += "/" + } + + fn = strings.TrimPrefix(filepath.ToSlash(fn), stepPath) + + ok, err := isFilename(fn) + if err != nil { + panic(err) + } + if ok { + b, err := ioutil.ReadFile(config.StepAbs(fn)) + if err != nil { + panic(errors.Wrapf(err, "error reading %s", fn)) + } + m[fn] = b + return fn + } + return fn +} + +func mustReadFilesOrURIs(fns []string, m map[string][]byte) []string { + var result []string + for _, fn := range fns { + result = append(result, mustReadFileOrURI(fn, m)) + } + return result +} + +func isFilename(fn string) (bool, error) { + u, err := url.Parse(fn) + if err != nil { + return false, errors.Wrapf(err, "error parsing %s", fn) + } + return u.Scheme == "" || u.Scheme == "file", nil +} diff --git a/authority/linkedca.go b/authority/linkedca.go new file mode 100644 index 00000000..b568dcbb --- /dev/null +++ b/authority/linkedca.go @@ -0,0 +1,490 @@ +package authority + +import ( + "context" + "crypto" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "encoding/pem" + "fmt" + "net/url" + "regexp" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/db" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/tlsutil" + "go.step.sm/crypto/x509util" + "go.step.sm/linkedca" + "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$" + +type linkedCaClient struct { + renewer *tlsutil.Renewer + client linkedca.MajordomoClient + authorityID string +} + +type linkedCAClaims struct { + jose.Claims + SANs []string `json:"sans"` + SHA string `json:"sha"` +} + +func newLinkedCAClient(token string) (*linkedCaClient, error) { + tok, err := jose.ParseSigned(token) + if err != nil { + return nil, errors.Wrap(err, "error parsing token") + } + + var claims linkedCAClaims + if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { + return nil, errors.Wrap(err, "error parsing token") + } + // Validate claims + if len(claims.Audience) != 1 { + return nil, errors.New("error parsing token: invalid aud claim") + } + if claims.SHA == "" { + return nil, errors.New("error parsing token: invalid sha claim") + } + // Get linkedCA endpoint from audience. + u, err := url.Parse(claims.Audience[0]) + if err != nil { + return nil, errors.New("error parsing token: invalid aud claim") + } + // Get authority from SANs + authority, err := getAuthority(claims.SANs) + if err != nil { + return nil, err + } + + // Create csr to login with + signer, err := keyutil.GenerateDefaultSigner() + if err != nil { + return nil, err + } + csr, err := x509util.CreateCertificateRequest(claims.Subject, claims.SANs, signer) + if err != nil { + return nil, err + } + + // Get and verify root certificate + root, err := getRootCertificate(u.Host, claims.SHA) + if err != nil { + return nil, err + } + + pool := x509.NewCertPool() + pool.AddCert(root) + + // Login with majordomo and get certificates + cert, tlsConfig, err := login(authority, token, csr, signer, u.Host, pool) + if err != nil { + return nil, err + } + + // Start TLS renewer and set the GetClientCertificate callback to it. + renewer, err := tlsutil.NewRenewer(cert, tlsConfig, func() (*tls.Certificate, *tls.Config, error) { + return login(authority, token, csr, signer, u.Host, pool) + }) + if err != nil { + return nil, err + } + tlsConfig.GetClientCertificate = renewer.GetClientCertificate + + // Start mTLS client + conn, err := grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + if err != nil { + return nil, errors.Wrapf(err, "error connecting %s", u.Host) + } + + return &linkedCaClient{ + renewer: renewer, + client: linkedca.NewMajordomoClient(conn), + authorityID: authority, + }, nil +} + +func (c *linkedCaClient) Run() { + c.renewer.Run() +} + +func (c *linkedCaClient) Stop() { + c.renewer.Stop() +} + +func (c *linkedCaClient) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + resp, err := c.client.CreateProvisioner(ctx, &linkedca.CreateProvisionerRequest{ + Type: prov.Type, + Name: prov.Name, + Details: prov.Details, + Claims: prov.Claims, + X509Template: prov.X509Template, + SshTemplate: prov.SshTemplate, + }) + if err != nil { + return errors.Wrap(err, "error creating provisioner") + } + prov.Id = resp.Id + prov.AuthorityId = resp.AuthorityId + return nil +} + +func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) { + resp, err := c.client.GetProvisioner(ctx, &linkedca.GetProvisionerRequest{ + Id: id, + }) + if err != nil { + return nil, errors.Wrap(err, "error getting provisioners") + } + return resp, nil +} + +func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) { + resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ + AuthorityId: c.authorityID, + }) + if err != nil { + return nil, errors.Wrap(err, "error getting provisioners") + } + return resp.Provisioners, nil +} + +func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + _, err := c.client.UpdateProvisioner(ctx, &linkedca.UpdateProvisionerRequest{ + Id: prov.Id, + Name: prov.Name, + Details: prov.Details, + Claims: prov.Claims, + X509Template: prov.X509Template, + SshTemplate: prov.SshTemplate, + }) + return errors.Wrap(err, "error updating provisioner") +} + +func (c *linkedCaClient) DeleteProvisioner(ctx context.Context, id string) error { + _, err := c.client.DeleteProvisioner(ctx, &linkedca.DeleteProvisionerRequest{ + Id: id, + }) + return errors.Wrap(err, "error deleting provisioner") +} + +func (c *linkedCaClient) CreateAdmin(ctx context.Context, adm *linkedca.Admin) error { + resp, err := c.client.CreateAdmin(ctx, &linkedca.CreateAdminRequest{ + Subject: adm.Subject, + ProvisionerId: adm.ProvisionerId, + Type: adm.Type, + }) + if err != nil { + return errors.Wrap(err, "error creating admin") + } + adm.Id = resp.Id + adm.AuthorityId = resp.AuthorityId + return nil +} + +func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) { + resp, err := c.client.GetAdmin(ctx, &linkedca.GetAdminRequest{ + Id: id, + }) + if err != nil { + return nil, errors.Wrap(err, "error getting admins") + } + return resp, nil +} + +func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) { + resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ + AuthorityId: c.authorityID, + }) + if err != nil { + return nil, errors.Wrap(err, "error getting admins") + } + return resp.Admins, nil +} + +func (c *linkedCaClient) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error { + _, err := c.client.UpdateAdmin(ctx, &linkedca.UpdateAdminRequest{ + Id: adm.Id, + Type: adm.Type, + }) + return errors.Wrap(err, "error updating admin") +} + +func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error { + _, err := c.client.DeleteAdmin(ctx, &linkedca.DeleteAdminRequest{ + Id: id, + }) + return errors.Wrap(err, "error deleting admin") +} + +func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + _, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{ + PemCertificate: serializeCertificateChain(fullchain[0]), + PemCertificateChain: serializeCertificateChain(fullchain[1:]...), + }) + return errors.Wrap(err, "error posting certificate") +} + +func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullchain ...*x509.Certificate) error { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + _, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{ + PemCertificate: serializeCertificateChain(fullchain[0]), + PemCertificateChain: serializeCertificateChain(fullchain[1:]...), + PemParentCertificate: serializeCertificateChain(parent), + }) + return errors.Wrap(err, "error posting certificate") +} + +func (c *linkedCaClient) StoreSSHCertificate(crt *ssh.Certificate) error { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + _, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{ + Certificate: string(ssh.MarshalAuthorizedKey(crt)), + }) + return errors.Wrap(err, "error posting ssh certificate") +} + +func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + _, err := c.client.RevokeCertificate(ctx, &linkedca.RevokeCertificateRequest{ + Serial: rci.Serial, + PemCertificate: serializeCertificate(crt), + Reason: rci.Reason, + ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode), + Passive: true, + }) + + return errors.Wrap(err, "error revoking certificate") +} + +func (c *linkedCaClient) RevokeSSH(cert *ssh.Certificate, rci *db.RevokedCertificateInfo) error { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + _, err := c.client.RevokeSSHCertificate(ctx, &linkedca.RevokeSSHCertificateRequest{ + Serial: rci.Serial, + Certificate: serializeSSHCertificate(cert), + Reason: rci.Reason, + ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode), + Passive: true, + }) + + return errors.Wrap(err, "error revoking ssh certificate") +} + +func (c *linkedCaClient) IsRevoked(serial string) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + resp, err := c.client.GetCertificateStatus(ctx, &linkedca.GetCertificateStatusRequest{ + Serial: serial, + }) + if err != nil { + return false, errors.Wrap(err, "error getting certificate status") + } + return resp.Status != linkedca.RevocationStatus_ACTIVE, nil +} + +func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + resp, err := c.client.GetSSHCertificateStatus(ctx, &linkedca.GetSSHCertificateStatusRequest{ + Serial: serial, + }) + if err != nil { + return false, errors.Wrap(err, "error getting certificate status") + } + return resp.Status != linkedca.RevocationStatus_ACTIVE, nil +} + +func serializeCertificate(crt *x509.Certificate) string { + if crt == nil { + return "" + } + return string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: crt.Raw, + })) +} + +func serializeCertificateChain(fullchain ...*x509.Certificate) string { + var chain string + for _, crt := range fullchain { + chain += string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: crt.Raw, + })) + } + return chain +} + +func serializeSSHCertificate(crt *ssh.Certificate) string { + if crt == nil { + return "" + } + return string(ssh.MarshalAuthorizedKey(crt)) +} + +func getAuthority(sans []string) (string, error) { + for _, s := range sans { + if strings.HasPrefix(s, "urn:smallstep:authority:") { + if regexp.MustCompile(uuidPattern).MatchString(s[24:]) { + return s[24:], nil + } + } + } + return "", fmt.Errorf("error parsing token: invalid sans claim") +} + +// getRootCertificate creates an insecure majordomo client and returns the +// verified root certificate. +func getRootCertificate(endpoint, fingerprint string) (*x509.Certificate, error) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := grpc.DialContext(ctx, endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + InsecureSkipVerify: true, + }))) + if err != nil { + return nil, errors.Wrapf(err, "error connecting %s", endpoint) + } + + ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + client := linkedca.NewMajordomoClient(conn) + resp, err := client.GetRootCertificate(ctx, &linkedca.GetRootCertificateRequest{ + Fingerprint: fingerprint, + }) + if err != nil { + return nil, fmt.Errorf("error getting root certificate: %w", err) + } + + var block *pem.Block + b := []byte(resp.PemCertificate) + for len(b) > 0 { + block, b = pem.Decode(b) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %w", err) + } + + // verify the sha256 + sum := sha256.Sum256(cert.Raw) + if !strings.EqualFold(fingerprint, hex.EncodeToString(sum[:])) { + return nil, fmt.Errorf("error verifying certificate: SHA256 fingerprint does not match") + } + + return cert, nil + } + + return nil, fmt.Errorf("error getting root certificate: certificate not found") +} + +// login creates a new majordomo client with just the root ca pool and returns +// the signed certificate and tls configuration. +func login(authority, token string, csr *x509.CertificateRequest, signer crypto.PrivateKey, endpoint string, rootCAs *x509.CertPool) (*tls.Certificate, *tls.Config, error) { + // Connect to majordomo + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := grpc.DialContext(ctx, endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + RootCAs: rootCAs, + }))) + if err != nil { + return nil, nil, errors.Wrapf(err, "error connecting %s", endpoint) + } + + // Login to get the signed certificate + ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + client := linkedca.NewMajordomoClient(conn) + resp, err := client.Login(ctx, &linkedca.LoginRequest{ + AuthorityId: authority, + Token: token, + PemCertificateRequest: string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csr.Raw, + })), + }) + if err != nil { + return nil, nil, errors.Wrapf(err, "error logging in %s", endpoint) + } + + // Parse login response + var block *pem.Block + var bundle []*x509.Certificate + rest := []byte(resp.PemCertificateChain) + for { + block, rest = pem.Decode(rest) + if block == nil { + break + } + if block.Type != "CERTIFICATE" { + return nil, nil, errors.New("error decoding login response: pemCertificateChain is not a certificate bundle") + } + crt, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, nil, errors.Wrap(err, "error parsing login response") + } + bundle = append(bundle, crt) + } + if len(bundle) == 0 { + return nil, nil, errors.New("error decoding login response: pemCertificateChain should not be empty") + } + + // Build tls.Certificate with PemCertificate and intermediates in the + // PemCertificateChain + cert := &tls.Certificate{ + PrivateKey: signer, + } + rest = []byte(resp.PemCertificate) + for { + block, rest = pem.Decode(rest) + if block == nil { + break + } + if block.Type == "CERTIFICATE" { + leaf, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, nil, errors.Wrap(err, "error parsing pemCertificate") + } + cert.Certificate = append(cert.Certificate, block.Bytes) + cert.Leaf = leaf + } + } + + // Add intermediates to the tls.Certificate + last := len(bundle) - 1 + for i := 0; i < last; i++ { + cert.Certificate = append(cert.Certificate, bundle[i].Raw) + } + + // Add root to the pool if it's not there yet + rootCAs.AddCert(bundle[last]) + + return cert, &tls.Config{ + RootCAs: rootCAs, + }, nil +} diff --git a/authority/options.go b/authority/options.go index 9594f989..0f80cbbf 100644 --- a/authority/options.go +++ b/authority/options.go @@ -7,6 +7,8 @@ import ( "encoding/pem" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" @@ -20,9 +22,9 @@ type Option func(*Authority) error // WithConfig replaces the current config with the given one. No validation is // performed in the given value. -func WithConfig(config *Config) Option { +func WithConfig(cfg *config.Config) Option { return func(a *Authority) error { - a.config = config + a.config = cfg return nil } } @@ -31,16 +33,52 @@ func WithConfig(config *Config) Option { // the current one. No validation is performed in the given configuration. func WithConfigFile(filename string) Option { return func(a *Authority) (err error) { - a.config, err = LoadConfiguration(filename) + a.config, err = config.LoadConfiguration(filename) + return + } +} + +// WithPassword set the password to decrypt the intermediate key as well as the +// ssh host and user keys if they are not overridden by other options. +func WithPassword(password []byte) Option { + return func(a *Authority) (err error) { + a.password = password + return + } +} + +// WithSSHHostPassword set the password to decrypt the key used to sign SSH host +// certificates. +func WithSSHHostPassword(password []byte) Option { + return func(a *Authority) (err error) { + a.sshHostPassword = password + return + } +} + +// WithSSHUserPassword set the password to decrypt the key used to sign SSH user +// certificates. +func WithSSHUserPassword(password []byte) Option { + return func(a *Authority) (err error) { + a.sshUserPassword = password + return + } +} + +// WithIssuerPassword set the password to decrypt the certificate issuer private +// key used in RA mode. +func WithIssuerPassword(password []byte) Option { + return func(a *Authority) (err error) { + a.issuerPassword = password return } } // WithDatabase sets an already initialized authority database to a new // authority. This option is intended to be use on graceful reloads. -func WithDatabase(db db.AuthDB) Option { +func WithDatabase(d db.AuthDB) Option { return func(a *Authority) error { - a.db = db + a.db = d return nil } } @@ -56,7 +94,7 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e // WithSSHBastionFunc sets a custom function to get the bastion for a // given user-host pair. -func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*Bastion, error)) Option { +func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option { return func(a *Authority) error { a.sshBastionFunc = fn return nil @@ -65,7 +103,7 @@ func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*Bastio // WithSSHGetHosts sets a custom function to get the bastion for a // given user-host pair. -func WithSSHGetHosts(fn func(ctx context.Context, cert *x509.Certificate) ([]Host, error)) Option { +func WithSSHGetHosts(fn func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)) Option { return func(a *Authority) error { a.sshGetHostsFunc = fn return nil @@ -186,6 +224,23 @@ func WithX509FederatedBundle(pemCerts []byte) Option { } } +// WithAdminDB is an option to set the database backing the admin APIs. +func WithAdminDB(d admin.DB) Option { + return func(a *Authority) error { + a.adminDB = d + return nil + } +} + +// WithLinkedCAToken is an option to set the authentication token used to enable +// linked ca. +func WithLinkedCAToken(token string) Option { + return func(a *Authority) error { + a.linkedCAToken = token + return nil + } +} + func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) { var block *pem.Block var certs []*x509.Certificate diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index a36c496d..d81b0231 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -13,6 +13,7 @@ import ( // provisioning flow. type ACME struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` ForceCN bool `json:"forceCN,omitempty"` @@ -23,6 +24,15 @@ type ACME struct { // GetID returns the provisioner unique identifier. func (p ACME) GetID() string { + if p.ID != "" { + return p.ID + } + return p.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner +// from a token. +func (p *ACME) GetIDForToken() string { return "acme/" + p.Name } @@ -95,7 +105,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // certificate was configured to allow renewals. func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID()) + return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()) } return nil } diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index 7b669d8d..bd173f87 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -61,7 +61,7 @@ func TestACME_Init(t *testing.T) { "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{Name: "foo", Type: "bar", Claims: &Claims{DefaultTLSDur: &Duration{0}}}, - err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"), + err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, "ok": func(t *testing.T) ProvisionerValidateTest { @@ -110,7 +110,7 @@ func TestACME_AuthorizeRenew(t *testing.T) { p: p, cert: &x509.Certificate{}, code: http.StatusUnauthorized, - err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID()), + err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index 75115154..cd129b7b 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -252,6 +252,7 @@ type awsInstanceIdentityDocument struct { // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html type AWS struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` Accounts []string `json:"accounts"` @@ -269,6 +270,15 @@ type AWS struct { // GetID returns the provisioner unique identifier. func (p *AWS) GetID() string { + if p.ID != "" { + return p.ID + } + return p.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner +// from a token. +func (p *AWS) GetIDForToken() string { return "aws/" + p.Name } @@ -286,7 +296,7 @@ func (p *AWS) GetTokenID(token string) (string, error) { } // Use provisioner + instance-id as the identifier. - unique := fmt.Sprintf("%s.%s", p.GetID(), payload.document.InstanceID) + unique := fmt.Sprintf("%s.%s", p.GetIDForToken(), payload.document.InstanceID) sum := sha256.Sum256([]byte(unique)) return strings.ToLower(hex.EncodeToString(sum[:])), nil } @@ -302,7 +312,7 @@ func (p *AWS) GetType() Type { } // GetEncryptedKey is not available in an AWS provisioner. -func (p *AWS) GetEncryptedKey() (kid string, key string, ok bool) { +func (p *AWS) GetEncryptedKey() (kid, key string, ok bool) { return "", "", false } @@ -334,7 +344,7 @@ func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) { return "", err } - audience, err := generateSignAudience(caURL, p.GetID()) + audience, err := generateSignAudience(caURL, p.GetIDForToken()) if err != nil { return "", err } @@ -342,7 +352,7 @@ func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) { // Create unique ID for Trust On First Use (TOFU). Only the first instance // per provisioner is allowed as we don't have a way to trust the given // sans. - unique := fmt.Sprintf("%s.%s", p.GetID(), idoc.InstanceID) + unique := fmt.Sprintf("%s.%s", p.GetIDForToken(), idoc.InstanceID) sum := sha256.Sum256([]byte(unique)) // Create a JWT from the identity document @@ -397,7 +407,7 @@ func (p *AWS) Init(config Config) (err error) { if p.config, err = newAWSConfig(p.IIDRoots); err != nil { return err } - p.audiences = config.Audiences.WithFragment(p.GetID()) + p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) // validate IMDS versions if len(p.IMDSVersions) == 0 { @@ -439,13 +449,15 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // There's no way to trust them other than TOFU. var so []SignOption if p.DisableCustomSANs { - dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) - so = append(so, dnsNamesValidator([]string{dnsName})) - so = append(so, ipAddressesValidator([]net.IP{ - net.ParseIP(doc.PrivateIP), - })) - so = append(so, emailAddressesValidator(nil)) - so = append(so, urisValidator(nil)) + dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region) + so = append(so, + dnsNamesValidator([]string{dnsName}), + ipAddressesValidator([]net.IP{ + net.ParseIP(doc.PrivateIP), + }), + emailAddressesValidator(nil), + urisValidator(nil), + ) // Template options data.SetSANs([]string{dnsName, doc.PrivateIP}) @@ -474,7 +486,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // certificate was configured to allow renewals. func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID()) + return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName()) } return nil } @@ -505,6 +517,11 @@ func (p *AWS) readURL(url string) ([]byte, error) { var resp *http.Response var err error + // Initialize IMDS versions when this is called from the cli. + if len(p.IMDSVersions) == 0 { + p.IMDSVersions = []string{"v2", "v1"} + } + for _, v := range p.IMDSVersions { switch v { case "v1": @@ -654,7 +671,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { if p.DisableCustomSANs { if payload.Subject != doc.InstanceID && payload.Subject != doc.PrivateIP && - payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) { + payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region) { return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)") } } @@ -687,7 +704,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID()) + return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token) if err != nil { @@ -705,7 +722,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validated principals. principals := []string{ doc.PrivateIP, - fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region), + fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region), } // Only enforce known principals if disable custom sans is true. diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index dadf1f17..0d2786db 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -141,6 +141,12 @@ func TestAWS_GetIdentityToken(t *testing.T) { p7.config.signatureURL = p1.config.signatureURL p7.config.tokenURL = p1.config.tokenURL + p8, err := generateAWS() + assert.FatalError(t, err) + p8.IMDSVersions = nil + p8.Accounts = p1.Accounts + p8.config = p1.config + caURL := "https://ca.smallstep.com" u, err := url.Parse(caURL) assert.FatalError(t, err) @@ -156,6 +162,7 @@ func TestAWS_GetIdentityToken(t *testing.T) { wantErr bool }{ {"ok", p1, args{"foo.local", caURL}, false}, + {"ok no imds", p8, args{"foo.local", caURL}, false}, {"fail ca url", p1, args{"foo.local", "://ca.smallstep.com"}, true}, {"fail identityURL", p2, args{"foo.local", caURL}, true}, {"fail signatureURL", p3, args{"foo.local", caURL}, true}, @@ -656,15 +663,15 @@ func TestAWS_AuthorizeSign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) - got, err := tt.aws.AuthorizeSign(ctx, tt.args.token) - if (err != nil) != tt.wantErr { + switch got, err := tt.aws.AuthorizeSign(ctx, tt.args.token); { + case (err != nil) != tt.wantErr: t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return - } else if err != nil { + case err != nil: sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.code) - } else { + default: assert.Len(t, tt.wantLen, got) for _, o := range got { switch v := o.(type) { diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index ea8b08ec..a90d1728 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -84,6 +84,7 @@ type azurePayload struct { // and https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service type Azure struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` TenantID string `json:"tenantID"` @@ -101,6 +102,15 @@ type Azure struct { // GetID returns the provisioner unique identifier. func (p *Azure) GetID() string { + if p.ID != "" { + return p.ID + } + return p.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner +// from a token. +func (p *Azure) GetIDForToken() string { return p.TenantID } @@ -121,9 +131,10 @@ func (p *Azure) GetTokenID(token string) (string, error) { return "", errors.Wrap(err, "error verifying claims") } - // If TOFU is disabled create return the token kid + // If TOFU is disabled then allow token re-use. Azure caches the token for + // 24h and without allowing the re-use we cannot use it twice. if p.DisableTrustOnFirstUse { - return claims.ID, nil + return "", ErrAllowTokenReuse } sum := sha256.Sum256([]byte(claims.XMSMirID)) @@ -141,7 +152,7 @@ func (p *Azure) GetType() Type { } // GetEncryptedKey is not available in an Azure provisioner. -func (p *Azure) GetEncryptedKey() (kid string, key string, ok bool) { +func (p *Azure) GetEncryptedKey() (kid, key string, ok bool) { return "", "", false } @@ -292,11 +303,13 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, var so []SignOption if p.DisableCustomSANs { // name will work only inside the virtual network - so = append(so, commonNameValidator(name)) - so = append(so, dnsNamesValidator([]string{name})) - so = append(so, ipAddressesValidator(nil)) - so = append(so, emailAddressesValidator(nil)) - so = append(so, urisValidator(nil)) + so = append(so, + commonNameValidator(name), + dnsNamesValidator([]string{name}), + ipAddressesValidator(nil), + emailAddressesValidator(nil), + urisValidator(nil), + ) // Enforce SANs in the template. data.SetSANs([]string{name}) @@ -324,7 +337,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // certificate was configured to allow renewals. func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID()) + return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName()) } return nil } @@ -332,7 +345,7 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID()) + return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName()) } _, name, _, err := p.authorizeToken(token) diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index f21a5676..b7c321a6 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -72,7 +72,7 @@ func TestAzure_GetTokenID(t *testing.T) { wantErr bool }{ {"ok", p1, args{t1}, w1, false}, - {"ok no TOFU", p2, args{t2}, "the-jti", false}, + {"ok no TOFU", p2, args{t2}, "", true}, {"fail token", p1, args{"bad-token"}, "", true}, {"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true}, } @@ -446,15 +446,15 @@ func TestAzure_AuthorizeSign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) - got, err := tt.azure.AuthorizeSign(ctx, tt.args.token) - if (err != nil) != tt.wantErr { + switch got, err := tt.azure.AuthorizeSign(ctx, tt.args.token); { + case (err != nil) != tt.wantErr: t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return - } else if err != nil { + case err != nil: sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.code) - } else { + default: assert.Len(t, tt.wantLen, got) for _, o := range got { switch v := o.(type) { diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go index 997d9ba3..629a313c 100644 --- a/authority/provisioner/claims.go +++ b/authority/provisioner/claims.go @@ -71,6 +71,9 @@ func (c *Claimer) DefaultTLSCertDuration() time.Duration { // minimum from the authority configuration will be used. func (c *Claimer) MinTLSCertDuration() time.Duration { if c.claims == nil || c.claims.MinTLSDur == nil { + if c.claims != nil && c.claims.DefaultTLSDur != nil && c.claims.DefaultTLSDur.Duration < c.global.MinTLSDur.Duration { + return c.claims.DefaultTLSDur.Duration + } return c.global.MinTLSDur.Duration } return c.claims.MinTLSDur.Duration @@ -81,6 +84,9 @@ func (c *Claimer) MinTLSCertDuration() time.Duration { // maximum from the authority configuration will be used. func (c *Claimer) MaxTLSCertDuration() time.Duration { if c.claims == nil || c.claims.MaxTLSDur == nil { + if c.claims != nil && c.claims.DefaultTLSDur != nil && c.claims.DefaultTLSDur.Duration > c.global.MaxTLSDur.Duration { + return c.claims.DefaultTLSDur.Duration + } return c.global.MaxTLSDur.Duration } return c.claims.MaxTLSDur.Duration @@ -126,6 +132,9 @@ func (c *Claimer) DefaultUserSSHCertDuration() time.Duration { // global minimum from the authority configuration will be used. func (c *Claimer) MinUserSSHCertDuration() time.Duration { if c.claims == nil || c.claims.MinUserSSHDur == nil { + if c.claims != nil && c.claims.DefaultUserSSHDur != nil && c.claims.DefaultUserSSHDur.Duration < c.global.MinUserSSHDur.Duration { + return c.claims.DefaultUserSSHDur.Duration + } return c.global.MinUserSSHDur.Duration } return c.claims.MinUserSSHDur.Duration @@ -136,6 +145,9 @@ func (c *Claimer) MinUserSSHCertDuration() time.Duration { // global maximum from the authority configuration will be used. func (c *Claimer) MaxUserSSHCertDuration() time.Duration { if c.claims == nil || c.claims.MaxUserSSHDur == nil { + if c.claims != nil && c.claims.DefaultUserSSHDur != nil && c.claims.DefaultUserSSHDur.Duration > c.global.MaxUserSSHDur.Duration { + return c.claims.DefaultUserSSHDur.Duration + } return c.global.MaxUserSSHDur.Duration } return c.claims.MaxUserSSHDur.Duration @@ -156,6 +168,9 @@ func (c *Claimer) DefaultHostSSHCertDuration() time.Duration { // global minimum from the authority configuration will be used. func (c *Claimer) MinHostSSHCertDuration() time.Duration { if c.claims == nil || c.claims.MinHostSSHDur == nil { + if c.claims != nil && c.claims.DefaultHostSSHDur != nil && c.claims.DefaultHostSSHDur.Duration < c.global.MinHostSSHDur.Duration { + return c.claims.DefaultHostSSHDur.Duration + } return c.global.MinHostSSHDur.Duration } return c.claims.MinHostSSHDur.Duration @@ -166,6 +181,9 @@ func (c *Claimer) MinHostSSHCertDuration() time.Duration { // global maximum from the authority configuration will be used. func (c *Claimer) MaxHostSSHCertDuration() time.Duration { if c.claims == nil || c.claims.MaxHostSSHDur == nil { + if c.claims != nil && c.claims.DefaultHostSSHDur != nil && c.claims.DefaultHostSSHDur.Duration > c.global.MaxHostSSHDur.Duration { + return c.claims.DefaultHostSSHDur.Duration + } return c.global.MaxHostSSHDur.Duration } return c.claims.MaxHostSSHDur.Duration diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index 30f950a5..1bec8689 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -12,7 +12,7 @@ import ( "strings" "sync" - "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" "go.step.sm/crypto/jose" ) @@ -37,14 +37,17 @@ func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // provisioner. type loadByTokenPayload struct { jose.Claims - AuthorizedParty string `json:"azp"` // OIDC client id - TenantID string `json:"tid"` // Microsoft Azure tenant id + Email string `json:"email"` // OIDC email + AuthorizedParty string `json:"azp"` // OIDC client id + TenantID string `json:"tid"` // Microsoft Azure tenant id } // Collection is a memory map of provisioners. type Collection struct { byID *sync.Map byKey *sync.Map + byName *sync.Map + byTokenID *sync.Map sorted provisionerSlice audiences Audiences } @@ -55,6 +58,8 @@ func NewCollection(audiences Audiences) *Collection { return &Collection{ byID: new(sync.Map), byKey: new(sync.Map), + byName: new(sync.Map), + byTokenID: new(sync.Map), audiences: audiences, } } @@ -64,6 +69,18 @@ func (c *Collection) Load(id string) (Interface, bool) { return loadProvisioner(c.byID, id) } +// LoadByName a provisioner by name. +func (c *Collection) LoadByName(name string) (Interface, bool) { + return loadProvisioner(c.byName, name) +} + +// LoadByTokenID a provisioner by identifier found in token. +// For different provisioner types this identifier may be found in in different +// attributes of the token. +func (c *Collection) LoadByTokenID(tokenProvisionerID string) (Interface, bool) { + return loadProvisioner(c.byTokenID, tokenProvisionerID) +} + // LoadByToken parses the token claims and loads the provisioner associated. func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) { var audiences []string @@ -79,11 +96,12 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) if matchesAudience(claims.Audience, audiences) { // Use fragment to get provisioner name (GCP, AWS, SSHPOP) if fragment != "" { - return c.Load(fragment) + return c.LoadByTokenID(fragment) } // If matches with stored audiences it will be a JWT token (default), and // the id would be :. - return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID) + // TODO: is this ok? + return c.LoadByTokenID(claims.Issuer + ":" + token.Headers[0].KeyID) } // The ID will be just the clientID stored in azp, aud or tid. @@ -94,7 +112,7 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) // Kubernetes Service Account tokens. if payload.Issuer == k8sSAIssuer { - if p, ok := c.Load(K8sSAID); ok { + if p, ok := c.LoadByTokenID(K8sSAID); ok { return p, ok } // Kubernetes service account provisioner not found @@ -108,18 +126,26 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) // Try with azp (OIDC) if len(payload.AuthorizedParty) > 0 { - if p, ok := c.Load(payload.AuthorizedParty); ok { + if p, ok := c.LoadByTokenID(payload.AuthorizedParty); ok { return p, ok } } - // Try with tid (Azure) + // Try with tid (Azure, Azure OIDC) if payload.TenantID != "" { - if p, ok := c.Load(payload.TenantID); ok { + // Try to load an OIDC provisioner first. + if payload.Email != "" { + if p, ok := c.LoadByTokenID(payload.Audience[0]); ok { + return p, ok + } + } + // Try to load an Azure provisioner. + if p, ok := c.LoadByTokenID(payload.TenantID); ok { return p, ok } } + // Fallback to aud - return c.Load(payload.Audience[0]) + return c.LoadByTokenID(payload.Audience[0]) } // LoadByCertificate looks for the provisioner extension and extracts the @@ -131,24 +157,7 @@ func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { return nil, false } - switch Type(provisioner.Type) { - case TypeJWK: - return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID)) - case TypeAWS: - return c.Load("aws/" + string(provisioner.Name)) - case TypeGCP: - return c.Load("gcp/" + string(provisioner.Name)) - case TypeACME: - return c.Load("acme/" + string(provisioner.Name)) - case TypeSCEP: - return c.Load("scep/" + string(provisioner.Name)) - case TypeX5C: - return c.Load("x5c/" + string(provisioner.Name)) - case TypeK8sSA: - return c.Load(K8sSAID) - default: - return c.Load(string(provisioner.CredentialID)) - } + return c.LoadByName(string(provisioner.Name)) } } @@ -173,7 +182,21 @@ func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) { func (c *Collection) Store(p Interface) error { // Store provisioner always in byID. ID must be unique. if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded { - return errors.New("cannot add multiple provisioners with the same id") + return admin.NewError(admin.ErrorBadRequestType, + "cannot add multiple provisioners with the same id") + } + // Store provisioner always by name. + if _, loaded := c.byName.LoadOrStore(p.GetName(), p); loaded { + c.byID.Delete(p.GetID()) + return admin.NewError(admin.ErrorBadRequestType, + "cannot add multiple provisioners with the same name") + } + // Store provisioner always by ID presented in token. + if _, loaded := c.byTokenID.LoadOrStore(p.GetIDForToken(), p); loaded { + c.byID.Delete(p.GetID()) + c.byName.Delete(p.GetName()) + return admin.NewError(admin.ErrorBadRequestType, + "cannot add multiple provisioners with the same token identifier") } // Store provisioner in byKey if EncryptedKey is defined. @@ -197,6 +220,66 @@ func (c *Collection) Store(p Interface) error { return nil } +// Remove deletes an provisioner from all associated collections and lists. +func (c *Collection) Remove(id string) error { + prov, ok := c.Load(id) + if !ok { + return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id) + } + + var found bool + for i, elem := range c.sorted { + if elem.provisioner.GetID() != id { + continue + } + // Remove index in sorted list + copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index. + c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value). + c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice. + found = true + break + } + if !found { + return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName()) + } + + c.byID.Delete(id) + c.byName.Delete(prov.GetName()) + c.byTokenID.Delete(prov.GetIDForToken()) + if kid, _, ok := prov.GetEncryptedKey(); ok { + c.byKey.Delete(kid) + } + + return nil +} + +// Update updates the given provisioner in all related lists and collections. +func (c *Collection) Update(nu Interface) error { + old, ok := c.Load(nu.GetID()) + if !ok { + return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", nu.GetID()) + } + + if old.GetName() != nu.GetName() { + if _, ok := c.LoadByName(nu.GetName()); ok { + return admin.NewError(admin.ErrorBadRequestType, + "provisioner with name %s already exists", nu.GetName()) + } + } + if old.GetIDForToken() != nu.GetIDForToken() { + if _, ok := c.LoadByTokenID(nu.GetIDForToken()); ok { + return admin.NewError(admin.ErrorBadRequestType, + "provisioner with Token ID %s already exists", nu.GetIDForToken()) + } + } + + if err := c.Remove(old.GetID()); err != nil { + return err + } + + return c.Store(nu) +} + // Find implements pagination on a list of sorted provisioners. func (c *Collection) Find(cursor string, limit int) (List, string) { switch { diff --git a/authority/provisioner/collection_test.go b/authority/provisioner/collection_test.go index a0a79e92..348b797c 100644 --- a/authority/provisioner/collection_test.go +++ b/authority/provisioner/collection_test.go @@ -132,6 +132,7 @@ func TestCollection_LoadByToken(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := &Collection{ byID: tt.fields.byID, + byTokenID: tt.fields.byID, audiences: tt.fields.audiences, } got, got1 := c.LoadByToken(tt.args.token, tt.args.claims) @@ -153,10 +154,10 @@ func TestCollection_LoadByCertificate(t *testing.T) { p3, err := generateACME() assert.FatalError(t, err) - byID := new(sync.Map) - byID.Store(p1.GetID(), p1) - byID.Store(p2.GetID(), p2) - byID.Store(p3.GetID(), p3) + byName := new(sync.Map) + byName.Store(p1.GetName(), p1) + byName.Store(p2.GetName(), p2) + byName.Store(p3.GetName(), p3) ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID) assert.FatalError(t, err) @@ -186,7 +187,7 @@ func TestCollection_LoadByCertificate(t *testing.T) { } type fields struct { - byID *sync.Map + byName *sync.Map audiences Audiences } type args struct { @@ -199,17 +200,17 @@ func TestCollection_LoadByCertificate(t *testing.T) { want Interface want1 bool }{ - {"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true}, - {"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true}, - {"ok3", fields{byID, testAudiences}, args{ok3Cert}, p3, true}, - {"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true}, - {"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false}, - {"badCert", fields{byID, testAudiences}, args{badCert}, nil, false}, + {"ok1", fields{byName, testAudiences}, args{ok1Cert}, p1, true}, + {"ok2", fields{byName, testAudiences}, args{ok2Cert}, p2, true}, + {"ok3", fields{byName, testAudiences}, args{ok3Cert}, p3, true}, + {"noExtension", fields{byName, testAudiences}, args{&x509.Certificate{}}, &noop{}, true}, + {"notFound", fields{byName, testAudiences}, args{notFoundCert}, nil, false}, + {"badCert", fields{byName, testAudiences}, args{badCert}, nil, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Collection{ - byID: tt.fields.byID, + byName: tt.fields.byName, audiences: tt.fields.audiences, } got, got1 := c.LoadByCertificate(tt.args.cert) diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index 830e7965..98d776d1 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -78,6 +78,7 @@ func newGCPConfig() *gcpConfig { // https://cloud.google.com/compute/docs/instances/verifying-instance-identity type GCP struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` ServiceAccounts []string `json:"serviceAccounts"` @@ -96,6 +97,16 @@ type GCP struct { // GetID returns the provisioner unique identifier. The name should uniquely // identify any GCP provisioner. func (p *GCP) GetID() string { + if p.ID != "" { + return p.ID + } + return p.GetIDForToken() + +} + +// GetIDForToken returns an identifier that will be used to load the provisioner +// from a token. +func (p *GCP) GetIDForToken() string { return "gcp/" + p.Name } @@ -123,7 +134,7 @@ func (p *GCP) GetTokenID(token string) (string, error) { // Create unique ID for Trust On First Use (TOFU). Only the first instance // per provisioner is allowed as we don't have a way to trust the given // sans. - unique := fmt.Sprintf("%s.%s", p.GetID(), claims.Google.ComputeEngine.InstanceID) + unique := fmt.Sprintf("%s.%s", p.GetIDForToken(), claims.Google.ComputeEngine.InstanceID) sum := sha256.Sum256([]byte(unique)) return strings.ToLower(hex.EncodeToString(sum[:])), nil } @@ -139,7 +150,7 @@ func (p *GCP) GetType() Type { } // GetEncryptedKey is not available in a GCP provisioner. -func (p *GCP) GetEncryptedKey() (kid string, key string, ok bool) { +func (p *GCP) GetEncryptedKey() (kid, key string, ok bool) { return "", "", false } @@ -157,7 +168,7 @@ func (p *GCP) GetIdentityURL(audience string) string { // GetIdentityToken does an HTTP request to the identity url. func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) { - audience, err := generateSignAudience(caURL, p.GetID()) + audience, err := generateSignAudience(caURL, p.GetIDForToken()) if err != nil { return "", err } @@ -205,7 +216,7 @@ func (p *GCP) Init(config Config) error { return err } - p.audiences = config.Audiences.WithFragment(p.GetID()) + p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) return nil } @@ -233,15 +244,17 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er if p.DisableCustomSANs { dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID) dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID) - so = append(so, commonNameSliceValidator([]string{ - ce.InstanceName, ce.InstanceID, dnsName1, dnsName2, - })) - so = append(so, dnsNamesValidator([]string{ - dnsName1, dnsName2, - })) - so = append(so, ipAddressesValidator(nil)) - so = append(so, emailAddressesValidator(nil)) - so = append(so, urisValidator(nil)) + so = append(so, + commonNameSliceValidator([]string{ + ce.InstanceName, ce.InstanceID, dnsName1, dnsName2, + }), + dnsNamesValidator([]string{ + dnsName1, dnsName2, + }), + ipAddressesValidator(nil), + emailAddressesValidator(nil), + urisValidator(nil), + ) // Template SANs data.SetSANs([]string{dnsName1, dnsName2}) @@ -266,7 +279,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // AuthorizeRenew returns an error if the renewal is disabled. func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID()) + return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName()) } return nil } @@ -371,7 +384,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID()) + return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token) if err != nil { diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index d6c4054c..5f6f9bc7 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -535,15 +535,15 @@ func TestGCP_AuthorizeSign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) - got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token) - if (err != nil) != tt.wantErr { + switch got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token); { + case (err != nil) != tt.wantErr: t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return - } else if err != nil { + case err != nil: sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.code) - } else { + default: assert.Len(t, tt.wantLen, got) for _, o := range got { switch v := o.(type) { diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index d6a97e2b..56768fb7 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -28,6 +28,7 @@ type stepPayload struct { // signature requests. type JWK struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` Key *jose.JSONWebKey `json:"key"` @@ -41,6 +42,15 @@ type JWK struct { // GetID returns the provisioner unique identifier. The name and credential id // should uniquely identify any JWK provisioner. func (p *JWK) GetID() string { + if p.ID != "" { + return p.ID + } + return p.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner +// from a token. +func (p *JWK) GetIDForToken() string { return p.Name + ":" + p.Key.KeyID } @@ -184,7 +194,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // certificate was configured to allow renewals. func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID()) + return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName()) } return nil } @@ -192,7 +202,7 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID()) + return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token, p.audiences.SSHSign) if err != nil { diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index 9198ff69..deae8f7a 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -77,7 +77,7 @@ func TestJWK_Init(t *testing.T) { "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, - err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"), + err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, "ok": func(t *testing.T) ProvisionerValidateTest { diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index 209a7dd4..d260f5ec 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -42,6 +42,7 @@ type k8sSAPayload struct { // entity trusted to make signature requests. type K8sSA struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` PubKeys []byte `json:"publicKeys,omitempty"` @@ -56,6 +57,15 @@ type K8sSA struct { // GetID returns the provisioner unique identifier. The name and credential id // should uniquely identify any K8sSA provisioner. func (p *K8sSA) GetID() string { + if p.ID != "" { + return p.ID + } + return p.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner +// from a token. +func (p *K8sSA) GetIDForToken() string { return K8sSAID } @@ -101,12 +111,12 @@ func (p *K8sSA) Init(config Config) (err error) { } key, err := pemutil.ParseKey(pem.EncodeToMemory(block)) if err != nil { - return errors.Wrapf(err, "error parsing public key in provisioner %s", p.GetID()) + return errors.Wrapf(err, "error parsing public key in provisioner '%s'", p.GetName()) } switch q := key.(type) { case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: default: - return errors.Errorf("Unexpected public key type %T in provisioner %s", q, p.GetID()) + return errors.Errorf("Unexpected public key type %T in provisioner '%s'", q, p.GetName()) } p.pubKeys = append(p.pubKeys, key) } @@ -240,7 +250,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // AuthorizeRenew returns an error if the renewal is disabled. func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID()) + return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()) } return nil } @@ -248,7 +258,7 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro // AuthorizeSSHSign validates an request for an SSH certificate. func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID()) + return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token, p.audiences.SSHSign) if err != nil { diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 03ae7eff..176cdfd3 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -198,7 +198,7 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) { p: p, cert: &x509.Certificate{}, code: http.StatusUnauthorized, - err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID()), + err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -319,7 +319,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { p: p, token: "foo", code: http.StatusUnauthorized, - err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID()), + err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()), } }, "fail/invalid-token": func(t *testing.T) test { diff --git a/authority/provisioner/keystore.go b/authority/provisioner/keystore.go index f775e150..d1811fab 100644 --- a/authority/provisioner/keystore.go +++ b/authority/provisioner/keystore.go @@ -18,7 +18,7 @@ const ( defaultCacheJitter = 1 * time.Hour ) -var maxAgeRegex = regexp.MustCompile("max-age=([0-9]+)") +var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`) type keyStore struct { sync.RWMutex diff --git a/authority/provisioner/noop.go b/authority/provisioner/noop.go index ccdeccf4..1709fbca 100644 --- a/authority/provisioner/noop.go +++ b/authority/provisioner/noop.go @@ -14,6 +14,10 @@ func (p *noop) GetID() string { return "noop" } +func (p *noop) GetIDForToken() string { + return "noop" +} + func (p *noop) GetTokenID(token string) (string, error) { return "", nil } @@ -25,7 +29,7 @@ func (p *noop) GetType() Type { return noopType } -func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) { +func (p *noop) GetEncryptedKey() (kid, key string, ok bool) { return "", "", false } diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 46e1c623..ac1f2a25 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -49,11 +49,35 @@ type openIDPayload struct { Groups []string `json:"groups"` } +func (o *openIDPayload) IsAdmin(admins []string) bool { + if o.Email != "" { + email := sanitizeEmail(o.Email) + for _, e := range admins { + if email == sanitizeEmail(e) { + return true + } + } + } + + // The groups and emails can be in the same array for now, but consider + // making a specialized option later. + for _, name := range o.Groups { + for _, admin := range admins { + if name == admin { + return true + } + } + } + + return false +} + // OIDC represents an OAuth 2.0 OpenID Connect provider. // // ClientSecret is mandatory, but it can be an empty string. type OIDC struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` ClientID string `json:"clientID"` @@ -72,35 +96,6 @@ type OIDC struct { getIdentityFunc GetIdentityFunc } -// IsAdmin returns true if the given email is in the Admins allowlist, false -// otherwise. -func (o *OIDC) IsAdmin(email string) bool { - if email != "" { - email = sanitizeEmail(email) - for _, e := range o.Admins { - if email == sanitizeEmail(e) { - return true - } - } - } - return false -} - -// IsAdminGroup returns true if the one group in the given list is in the Admins -// allowlist, false otherwise. -func (o *OIDC) IsAdminGroup(groups []string) bool { - for _, g := range groups { - // The groups and emails can be in the same array for now, but consider - // making a specialized option later. - for _, gadmin := range o.Admins { - if g == gadmin { - return true - } - } - } - return false -} - func sanitizeEmail(email string) string { if i := strings.LastIndex(email, "@"); i >= 0 { email = email[:i] + strings.ToLower(email[i:]) @@ -111,6 +106,15 @@ func sanitizeEmail(email string) string { // GetID returns the provisioner unique identifier, the OIDC provisioner the // uses the clientID for this. func (o *OIDC) GetID() string { + if o.ID != "" { + return o.ID + } + return o.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner +// from a token. +func (o *OIDC) GetIDForToken() string { return o.ClientID } @@ -144,7 +148,7 @@ func (o *OIDC) GetType() Type { } // GetEncryptedKey is not available in an OIDC provisioner. -func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) { +func (o *OIDC) GetEncryptedKey() (kid, key string, ok bool) { return "", "", false } @@ -189,7 +193,7 @@ func (o *OIDC) Init(config Config) (err error) { } // Replace {tenantid} with the configured one if o.TenantID != "" { - o.configuration.Issuer = strings.Replace(o.configuration.Issuer, "{tenantid}", o.TenantID, -1) + o.configuration.Issuer = strings.ReplaceAll(o.configuration.Issuer, "{tenantid}", o.TenantID) } // Get JWK key set o.keyStore, err = newKeyStore(o.configuration.JWKSetURI) @@ -224,7 +228,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error { } // Validate domains (case-insensitive) - if p.Email != "" && len(o.Domains) > 0 && !o.IsAdmin(p.Email) { + if p.Email != "" && len(o.Domains) > 0 && !p.IsAdmin(o.Admins) { email := sanitizeEmail(p.Email) var found bool for _, d := range o.Domains { @@ -303,9 +307,10 @@ func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error { } // Only admins can revoke certificates. - if o.IsAdmin(claims.Email) { + if claims.IsAdmin(o.Admins) { return nil } + return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token") } @@ -341,7 +346,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // Use the default template unless no-templates are configured and email is // an admin, in that case we will use the CR template. defaultTemplate := x509util.DefaultLeafTemplate - if !o.Options.GetX509Options().HasTemplate() && o.IsAdmin(claims.Email) { + if !o.Options.GetX509Options().HasTemplate() && claims.IsAdmin(o.Admins) { defaultTemplate = x509util.DefaultAdminLeafTemplate } @@ -367,7 +372,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // certificate was configured to allow renewals. func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if o.claimer.IsDisableRenewal() { - return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID()) + return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName()) } return nil } @@ -375,7 +380,7 @@ func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !o.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID()) + return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName()) } claims, err := o.authorizeToken(token) if err != nil { @@ -410,10 +415,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption // Use the default template unless no-templates are configured and email is // an admin, in that case we will use the parameters in the request. - isAdmin := o.IsAdmin(claims.Email) - if !isAdmin && len(claims.Groups) > 0 { - isAdmin = o.IsAdminGroup(claims.Groups) - } + isAdmin := claims.IsAdmin(o.Admins) defaultTemplate := sshutil.DefaultTemplate if isAdmin && !o.Options.GetSSHOptions().HasTemplate() { defaultTemplate = sshutil.DefaultAdminTemplate @@ -461,10 +463,11 @@ func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error { } // Only admins can revoke certificates. - if !o.IsAdmin(claims.Email) { - return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token") + if claims.IsAdmin(o.Admins) { + return nil } - return nil + + return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token") } func getAndDecode(uri string, v interface{}) error { diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 48f879a8..7bf6ad7a 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -321,32 +321,26 @@ func TestOIDC_AuthorizeSign(t *testing.T) { assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) - } else { - if assert.NotNil(t, got) { - if tt.name == "admin" { - assert.Len(t, 5, got) - } else { - assert.Len(t, 5, got) - } - for _, o := range got { - switch v := o.(type) { - case certificateOptionsFunc: - case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeOIDC)) - assert.Equals(t, v.Name, tt.prov.GetName()) - assert.Equals(t, v.CredentialID, tt.prov.ClientID) - assert.Len(t, 0, v.KeyValuePairs) - case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) - case defaultPublicKeyValidator: - case *validityValidator: - assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) - case emailOnlyIdentity: - assert.Equals(t, string(v), "name@smallstep.com") - default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) - } + } else if assert.NotNil(t, got) { + assert.Len(t, 5, got) + for _, o := range got { + switch v := o.(type) { + case certificateOptionsFunc: + case *provisionerExtensionOption: + assert.Equals(t, v.Type, int(TypeOIDC)) + assert.Equals(t, v.Name, tt.prov.GetName()) + assert.Equals(t, v.CredentialID, tt.prov.ClientID) + assert.Len(t, 0, v.KeyValuePairs) + case profileDefaultDuration: + assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) + case defaultPublicKeyValidator: + case *validityValidator: + assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + case emailOnlyIdentity: + assert.Equals(t, string(v), "name@smallstep.com") + default: + assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) } } } @@ -698,3 +692,39 @@ func Test_sanitizeEmail(t *testing.T) { }) } } + +func Test_openIDPayload_IsAdmin(t *testing.T) { + type fields struct { + Email string + Groups []string + } + type args struct { + admins []string + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + {"ok email", fields{"admin@smallstep.com", nil}, args{[]string{"admin@smallstep.com"}}, true}, + {"ok email multiple", fields{"admin@smallstep.com", []string{"admin", "eng"}}, args{[]string{"eng@smallstep.com", "admin@smallstep.com"}}, true}, + {"ok email sanitized", fields{"admin@Smallstep.com", nil}, args{[]string{"admin@smallStep.com"}}, true}, + {"ok group", fields{"", []string{"admin"}}, args{[]string{"admin"}}, true}, + {"ok group multiple", fields{"admin@smallstep.com", []string{"engineering", "admin"}}, args{[]string{"admin"}}, true}, + {"fail missing", fields{"eng@smallstep.com", []string{"admin"}}, args{[]string{"admin@smallstep.com"}}, false}, + {"fail email letter case", fields{"Admin@smallstep.com", []string{}}, args{[]string{"admin@smallstep.com"}}, false}, + {"fail group letter case", fields{"", []string{"Admin"}}, args{[]string{"admin"}}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := &openIDPayload{ + Email: tt.fields.Email, + Groups: tt.fields.Groups, + } + if got := o.IsAdmin(tt.args.admins); got != tt.want { + t.Errorf("openIDPayload.IsAdmin() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/provisioner/options.go b/authority/provisioner/options.go index 100aa588..f86c4863 100644 --- a/authority/provisioner/options.go +++ b/authority/provisioner/options.go @@ -138,7 +138,7 @@ func unsafeParseSigned(s string) (map[string]interface{}, error) { return nil, err } claims := make(map[string]interface{}) - if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { + if err := token.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, err } return claims, nil diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 1ac0bbf0..5d6b2f80 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -4,6 +4,7 @@ import ( "context" "crypto/x509" "encoding/json" + stderrors "errors" "net/url" "regexp" "strings" @@ -17,6 +18,7 @@ import ( // Interface is the interface that all provisioner types must implement. type Interface interface { GetID() string + GetIDForToken() string GetTokenID(token string) (string, error) GetName() string GetType() Type @@ -31,6 +33,17 @@ type Interface interface { AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) } +// ErrAllowTokenReuse is an error that is returned by provisioners that allows +// the reuse of tokens. +// +// This is, for example, returned by the Azure provisioner when +// DisableTrustOnFirstUse is set to true. Azure caches tokens for up to 24hr and +// has no mechanism for getting a different token - this can be an issue when +// rebooting a VM. In contrast, AWS and GCP have facilities for requesting a new +// token. Therefore, for the Azure provisioner we are enabling token reuse, with +// the understanding that we are not following security best practices +var ErrAllowTokenReuse = stderrors.New("allow token reuse") + // Audiences stores all supported audiences by request type. type Audiences struct { Sign []string @@ -110,7 +123,7 @@ func (a Audiences) WithFragment(fragment string) Audiences { // generateSignAudience generates a sign audience with the format // https:///1.0/sign#provisionerID -func generateSignAudience(caURL string, provisionerID string) (string, error) { +func generateSignAudience(caURL, provisionerID string) (string, error) { u, err := url.Parse(caURL) if err != nil { return "", errors.Wrapf(err, "error parsing %s", caURL) @@ -394,6 +407,7 @@ type MockProvisioner struct { Mret1, Mret2, Mret3 interface{} Merr error MgetID func() string + MgetIDForToken func() string MgetTokenID func(string) (string, error) MgetName func() string MgetType func() Type @@ -416,6 +430,14 @@ func (m *MockProvisioner) GetID() string { return m.Mret1.(string) } +// GetIDForToken mock +func (m *MockProvisioner) GetIDForToken() string { + if m.MgetIDForToken != nil { + return m.MgetIDForToken() + } + return m.Mret1.(string) +} + // GetTokenID mock func (m *MockProvisioner) GetTokenID(token string) (string, error) { if m.MgetTokenID != nil { diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index 7673ecc2..145a1920 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -11,6 +11,7 @@ import ( // SCEP provisioning flow type SCEP struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` @@ -27,7 +28,16 @@ type SCEP struct { } // GetID returns the provisioner unique identifier. -func (s SCEP) GetID() string { +func (s *SCEP) GetID() string { + if s.ID != "" { + return s.ID + } + return s.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner +// from a token. +func (s *SCEP) GetIDForToken() string { return "scep/" + s.Name } diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index a872513e..158470d1 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "math/big" + "strings" "time" "github.com/pkg/errors" @@ -455,10 +456,10 @@ func containsAllMembers(group, subgroup []string) bool { } visit := make(map[string]struct{}, lg) for i := 0; i < lg; i++ { - visit[group[i]] = struct{}{} + visit[strings.ToLower(group[i])] = struct{}{} } for i := 0; i < lsg; i++ { - if _, ok := visit[subgroup[i]]; !ok { + if _, ok := visit[strings.ToLower(subgroup[i])]; !ok { return false } } diff --git a/authority/provisioner/sign_ssh_options_test.go b/authority/provisioner/sign_ssh_options_test.go index 693690f6..3a1ff324 100644 --- a/authority/provisioner/sign_ssh_options_test.go +++ b/authority/provisioner/sign_ssh_options_test.go @@ -44,7 +44,7 @@ func TestSSHOptions_Modify(t *testing.T) { valid func(*ssh.Certificate) err error } - tests := map[string](func() test){ + tests := map[string]func() test{ "fail/unexpected-cert-type": func() test { return test{ so: SignSSHOptions{CertType: "foo"}, @@ -117,7 +117,7 @@ func TestSSHOptions_Match(t *testing.T) { cmp SignSSHOptions err error } - tests := map[string](func() test){ + tests := map[string]func() test{ "fail/cert-type": func() test { return test{ so: SignSSHOptions{CertType: "foo"}, @@ -208,7 +208,7 @@ func Test_sshCertPrincipalsModifier_Modify(t *testing.T) { cert *ssh.Certificate expected []string } - tests := map[string](func() test){ + tests := map[string]func() test{ "ok": func() test { a := []string{"foo", "bar"} return test{ @@ -234,7 +234,7 @@ func Test_sshCertKeyIDModifier_Modify(t *testing.T) { cert *ssh.Certificate expected string } - tests := map[string](func() test){ + tests := map[string]func() test{ "ok": func() test { a := "foo" return test{ @@ -260,7 +260,7 @@ func Test_sshCertTypeModifier_Modify(t *testing.T) { cert *ssh.Certificate expected uint32 } - tests := map[string](func() test){ + tests := map[string]func() test{ "ok/user": func() test { return test{ modifier: sshCertTypeModifier("user"), @@ -299,7 +299,7 @@ func Test_sshCertValidAfterModifier_Modify(t *testing.T) { cert *ssh.Certificate expected uint64 } - tests := map[string](func() test){ + tests := map[string]func() test{ "ok": func() test { return test{ modifier: sshCertValidAfterModifier(15), @@ -324,7 +324,7 @@ func Test_sshCertDefaultsModifier_Modify(t *testing.T) { cert *ssh.Certificate valid func(*ssh.Certificate) } - tests := map[string](func() test){ + tests := map[string]func() test{ "ok/changes": func() test { n := time.Now() va := NewTimeDuration(n.Add(1 * time.Minute)) @@ -388,7 +388,7 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) { valid func(*ssh.Certificate) err error } - tests := map[string](func() test){ + tests := map[string]func() test{ "fail/unexpected-cert-type": func() test { cert := &ssh.Certificate{CertType: 3} return test{ diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 223f0b9e..99974ff1 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -8,7 +8,6 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" @@ -26,10 +25,10 @@ type sshPOPPayload struct { // signature requests. type SSHPOP struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` Claims *Claims `json:"claims,omitempty"` - db db.AuthDB claimer *Claimer audiences Audiences sshPubKeys *SSHKeys @@ -38,6 +37,15 @@ type SSHPOP struct { // GetID returns the provisioner unique identifier. The name and credential id // should uniquely identify any SSH-POP provisioner. func (p *SSHPOP) GetID() string { + if p.ID != "" { + return p.ID + } + return p.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner +// from a token. +func (p *SSHPOP) GetIDForToken() string { return "sshpop/" + p.Name } @@ -91,8 +99,7 @@ func (p *SSHPOP) Init(config Config) error { return err } - p.audiences = config.Audiences.WithFragment(p.GetID()) - p.db = config.DB + p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) p.sshPubKeys = config.SSHKeys return nil } @@ -100,6 +107,8 @@ func (p *SSHPOP) Init(config Config) error { // authorizeToken performs common jwt authorization actions and returns the // claims for case specific downstream parsing. // e.g. a Sign request will auth/validate different fields than a Revoke request. +// +// Checking for certificate revocation has been moved to the authority package. func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { sshCert, jwt, err := ExtractSSHPOPCert(token) if err != nil { @@ -107,14 +116,6 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa "sshpop.authorizeToken; error extracting sshpop header from token") } - // Check for revocation. - if isRevoked, err := p.db.IsSSHRevoked(strconv.FormatUint(sshCert.Serial, 10)); err != nil { - return nil, errs.Wrap(http.StatusInternalServerError, err, - "sshpop.authorizeToken; error checking checking sshpop cert revocation") - } else if isRevoked { - return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked") - } - // Check validity period of the certificate. n := time.Now() if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index 5d51b90e..3d343967 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -11,7 +11,6 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" @@ -47,7 +46,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, if err != nil { return nil, nil, err } - if err = cert.SignCert(rand.Reader, signer); err != nil { + if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } return cert, jwk, nil @@ -83,52 +82,9 @@ func TestSSHPOP_authorizeToken(t *testing.T) { err: errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "), } }, - "fail/error-revoked-db-check": func(t *testing.T) test { - p, err := generateSSHPOP() - assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, errors.New("force") - }, - } - cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) - assert.FatalError(t, err) - tok, err := generateSSHPOPToken(p, cert, jwk) - assert.FatalError(t, err) - return test{ - p: p, - token: tok, - code: http.StatusInternalServerError, - err: errors.New("sshpop.authorizeToken; error checking checking sshpop cert revocation: force"), - } - }, - "fail/cert-already-revoked": func(t *testing.T) test { - p, err := generateSSHPOP() - assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return true, nil - }, - } - cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) - assert.FatalError(t, err) - tok, err := generateSSHPOPToken(p, cert, jwk) - assert.FatalError(t, err) - return test{ - p: p, - token: tok, - code: http.StatusUnauthorized, - err: errors.New("sshpop.authorizeToken; sshpop certificate is revoked"), - } - }, "fail/cert-not-yet-valid": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{ CertType: ssh.UserCert, ValidAfter: uint64(time.Now().Add(time.Minute).Unix()), @@ -146,11 +102,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) { "fail/cert-past-validity": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{ CertType: ssh.UserCert, ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()), @@ -168,11 +119,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) { "fail/no-signer-found": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) assert.FatalError(t, err) tok, err := generateSSHPOPToken(p, cert, jwk) @@ -187,11 +133,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) { "fail/error-parsing-claims-bad-sig": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) assert.FatalError(t, err) otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -208,11 +149,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) { "fail/invalid-claims-issuer": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) assert.FatalError(t, err) tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "", @@ -228,11 +164,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) { "fail/invalid-audience": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) assert.FatalError(t, err) tok, err := generateToken("foo", p.GetName(), "invalid-aud", "", @@ -248,11 +179,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) { "fail/empty-subject": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) assert.FatalError(t, err) tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "", @@ -268,11 +194,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) { "ok": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) assert.FatalError(t, err) tok, err := generateSSHPOPToken(p, cert, jwk) @@ -293,10 +214,8 @@ func TestSSHPOP_authorizeToken(t *testing.T) { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } - } else { - if assert.Nil(t, tc.err) { - assert.NotNil(t, claims) - } + } else if assert.Nil(t, tc.err) { + assert.NotNil(t, claims) } }) } @@ -330,11 +249,6 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { "fail/subject-not-equal-serial": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) assert.FatalError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "", @@ -350,11 +264,6 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { "ok": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner) assert.FatalError(t, err) tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "", @@ -419,11 +328,6 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { "fail/not-host-cert": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner) assert.FatalError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "", @@ -439,11 +343,6 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { "ok": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner) assert.FatalError(t, err) tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "", @@ -511,11 +410,6 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { "fail/not-host-cert": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner) assert.FatalError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "", @@ -531,11 +425,6 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { "ok": func(t *testing.T) test { p, err := generateSSHPOP() assert.FatalError(t, err) - p.db = &db.MockAuthDB{ - MIsSSHRevoked: func(sn string) (bool, error) { - return false, nil - }, - } cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner) assert.FatalError(t, err) tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "", diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 534e83cf..e39efbcf 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -732,7 +732,7 @@ func withSSHPOPFile(cert *ssh.Certificate) tokOption { } } -func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { +func generateToken(sub, iss, aud, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", jwk.KeyID) @@ -773,7 +773,7 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T return jose.Signed(sig).Claims(claims).CompactSerialize() } -func generateOIDCToken(sub, iss, aud string, email string, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { +func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", jwk.KeyID) diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 2b05f4c8..a05f39c7 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -26,6 +26,7 @@ type x5cPayload struct { // signature requests. type X5C struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` Roots []byte `json:"roots"` @@ -39,6 +40,15 @@ type X5C struct { // GetID returns the provisioner unique identifier. The name and credential id // should uniquely identify any X5C provisioner. func (p *X5C) GetID() string { + if p.ID != "" { + return p.ID + } + return p.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner +// from a token. +func (p *X5C) GetIDForToken() string { return "x5c/" + p.Name } @@ -106,7 +116,7 @@ func (p *X5C) Init(config Config) error { // Verify that at least one root was found. if len(p.rootPool.Subjects()) == 0 { - return errors.Errorf("no x509 certificates found in roots attribute for provisioner %s", p.GetName()) + return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName()) } // Update claims with global ones @@ -115,7 +125,7 @@ func (p *X5C) Init(config Config) error { return err } - p.audiences = config.Audiences.WithFragment(p.GetID()) + p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) return nil } @@ -129,7 +139,8 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err } verifiedChains, err := jwt.Headers[0].Certificates(x509.VerifyOptions{ - Roots: p.rootPool, + Roots: p.rootPool, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, }) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, @@ -224,7 +235,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // AuthorizeRenew returns an error if the renewal is disabled. func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID()) + return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()) } return nil } @@ -232,7 +243,7 @@ func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !p.claimer.IsSSHCAEnabled() { - return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID()) + return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token, p.audiences.SSHSign) diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 5d288de5..2959f8c6 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -70,7 +70,7 @@ func TestX5C_Init(t *testing.T) { "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences}, - err: errors.Errorf("no x509 certificates found in roots attribute for provisioner foo"), + err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), } }, "fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest { @@ -79,7 +79,7 @@ func TestX5C_Init(t *testing.T) { p.Claims = &Claims{DefaultTLSDur: &Duration{0}} return ProvisionerValidateTest{ p: p, - err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"), + err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, "ok": func(t *testing.T) ProvisionerValidateTest { @@ -568,7 +568,7 @@ func TestX5C_AuthorizeRenew(t *testing.T) { return test{ p: p, code: http.StatusUnauthorized, - err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner %s", p.GetID()), + err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -624,7 +624,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { p: p, token: "foo", code: http.StatusUnauthorized, - err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner %s", p.GetID()), + err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()), } }, "fail/invalid-token": func(t *testing.T) test { diff --git a/authority/provisioners.go b/authority/provisioners.go index 99a85d46..7e02126f 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -1,14 +1,29 @@ package authority import ( + "context" "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "io/ioutil" + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" + step "go.step.sm/cli-utils/config" + "go.step.sm/cli-utils/ui" + "go.step.sm/crypto/jose" + "go.step.sm/linkedca" + "gopkg.in/square/go-jose.v2/jwt" ) // GetEncryptedKey returns the JWE key corresponding to the given kid argument. func (a *Authority) GetEncryptedKey(kid string) (string, error) { + a.adminMutex.RLock() + defer a.adminMutex.RUnlock() key, ok := a.provisioners.LoadEncryptedKey(kid) if !ok { return "", errs.NotFound("encrypted key with kid %s was not found", kid) @@ -19,6 +34,8 @@ func (a *Authority) GetEncryptedKey(kid string) (string, error) { // GetProvisioners returns a map listing each provisioner and the JWK Key Set // with their public keys. func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List, string, error) { + a.adminMutex.RLock() + defer a.adminMutex.RUnlock() provisioners, nextCursor := a.provisioners.Find(cursor, limit) return provisioners, nextCursor, nil } @@ -26,18 +43,909 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List, // LoadProvisionerByCertificate returns an interface to the provisioner that // provisioned the certificate. func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) { + a.adminMutex.RLock() + defer a.adminMutex.RUnlock() p, ok := a.provisioners.LoadByCertificate(crt) if !ok { - return nil, errs.NotFound("provisioner not found") + return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate") + } + return p, nil +} + +// LoadProvisionerByToken returns an interface to the provisioner that +// provisioned the token. +func (a *Authority) LoadProvisionerByToken(token *jwt.JSONWebToken, claims *jwt.Claims) (provisioner.Interface, error) { + a.adminMutex.RLock() + defer a.adminMutex.RUnlock() + p, ok := a.provisioners.LoadByToken(token, claims) + if !ok { + return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from token") } return p, nil } // LoadProvisionerByID returns an interface to the provisioner with the given ID. func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) { + a.adminMutex.RLock() + defer a.adminMutex.RUnlock() p, ok := a.provisioners.Load(id) if !ok { - return nil, errs.NotFound("provisioner not found") + return nil, admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id) } return p, nil } + +// LoadProvisionerByName returns an interface to the provisioner with the given Name. +func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, error) { + a.adminMutex.RLock() + defer a.adminMutex.RUnlock() + p, ok := a.provisioners.LoadByName(name) + if !ok { + return nil, admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", name) + } + return p, nil +} + +func (a *Authority) generateProvisionerConfig(ctx context.Context) (*provisioner.Config, error) { + // Merge global and configuration claims + claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, config.GlobalProvisionerClaims) + if err != nil { + return nil, err + } + // TODO: should we also be combining the ssh federated roots here? + // If we rotate ssh roots keys, sshpop provisioner will lose ability to + // validate old SSH certificates, unless they are added as federated certs. + sshKeys, err := a.GetSSHRoots(ctx) + if err != nil { + return nil, err + } + return &provisioner.Config{ + Claims: claimer.Claims(), + Audiences: a.config.GetAudiences(), + DB: a.db, + SSHKeys: &provisioner.SSHKeys{ + UserKeys: sshKeys.UserKeys, + HostKeys: sshKeys.HostKeys, + }, + GetIdentityFunc: a.getIdentityFunc, + }, nil + +} + +// StoreProvisioner stores an provisioner.Interface to the authority. +func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + certProv, err := ProvisionerToCertificates(prov) + if err != nil { + return admin.WrapErrorISE(err, + "error converting to certificates provisioner from linkedca provisioner") + } + + if _, ok := a.provisioners.LoadByName(prov.GetName()); ok { + return admin.NewError(admin.ErrorBadRequestType, + "provisioner with name %s already exists", prov.GetName()) + } + if _, ok := a.provisioners.LoadByTokenID(certProv.GetIDForToken()); ok { + return admin.NewError(admin.ErrorBadRequestType, + "provisioner with token ID %s already exists", certProv.GetIDForToken()) + } + + // Store to database -- this will set the ID. + if err := a.adminDB.CreateProvisioner(ctx, prov); err != nil { + return admin.WrapErrorISE(err, "error creating admin") + } + + // We need a new conversion that has the newly set ID. + certProv, err = ProvisionerToCertificates(prov) + if err != nil { + return admin.WrapErrorISE(err, + "error converting to certificates provisioner from linkedca provisioner") + } + + provisionerConfig, err := a.generateProvisionerConfig(ctx) + if err != nil { + return admin.WrapErrorISE(err, "error generating provisioner config") + } + + if err := certProv.Init(*provisionerConfig); err != nil { + return admin.WrapErrorISE(err, "error initializing provisioner %s", prov.Name) + } + + if err := a.provisioners.Store(certProv); err != nil { + if err := a.reloadAdminResources(ctx); err != nil { + return admin.WrapErrorISE(err, "error reloading admin resources on failed provisioner store") + } + return admin.WrapErrorISE(err, "error storing provisioner in authority cache") + } + return nil +} + +// UpdateProvisioner stores an provisioner.Interface to the authority. +func (a *Authority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + certProv, err := ProvisionerToCertificates(nu) + if err != nil { + return admin.WrapErrorISE(err, + "error converting to certificates provisioner from linkedca provisioner") + } + + provisionerConfig, err := a.generateProvisionerConfig(ctx) + if err != nil { + return admin.WrapErrorISE(err, "error generating provisioner config") + } + + if err := certProv.Init(*provisionerConfig); err != nil { + return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name) + } + + if err := a.provisioners.Update(certProv); err != nil { + return admin.WrapErrorISE(err, "error updating provisioner '%s' in authority cache", nu.Name) + } + if err := a.adminDB.UpdateProvisioner(ctx, nu); err != nil { + if err := a.reloadAdminResources(ctx); err != nil { + return admin.WrapErrorISE(err, "error reloading admin resources on failed provisioner update") + } + return admin.WrapErrorISE(err, "error updating provisioner '%s'", nu.Name) + } + return nil +} + +// RemoveProvisioner removes an provisioner.Interface from the authority. +func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error { + a.adminMutex.Lock() + defer a.adminMutex.Unlock() + + p, ok := a.provisioners.Load(id) + if !ok { + return admin.NewError(admin.ErrorBadRequestType, + "provisioner %s not found", id) + } + + provName, provID := p.GetName(), p.GetID() + // Validate + // - Check that there will be SUPER_ADMINs that remain after we + // remove this provisioner. + if a.admins.SuperCount() == a.admins.SuperCountByProvisioner(provName) { + return admin.NewError(admin.ErrorBadRequestType, + "cannot remove provisioner %s because no super admins will remain", provName) + } + + // Delete all admins associated with the provisioner. + admins, ok := a.admins.LoadByProvisioner(provName) + if ok { + for _, adm := range admins { + if err := a.removeAdmin(ctx, adm.Id); err != nil { + return admin.WrapErrorISE(err, "error deleting admin %s, as part of provisioner %s deletion", adm.Subject, provName) + } + } + } + + // Remove provisioner from authority caches. + if err := a.provisioners.Remove(provID); err != nil { + return admin.WrapErrorISE(err, "error removing admin from authority cache") + } + // Remove provisioner from database. + if err := a.adminDB.DeleteProvisioner(ctx, provID); err != nil { + if err := a.reloadAdminResources(ctx); err != nil { + return admin.WrapErrorISE(err, "error reloading admin resources on failed provisioner remove") + } + return admin.WrapErrorISE(err, "error deleting provisioner %s", provName) + } + return nil +} + +func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) { + if password == "" { + pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one") + if err != nil { + return nil, err + } + password = string(pass) + } + + jwk, jwe, err := jose.GenerateDefaultKeyPair([]byte(password)) + if err != nil { + return nil, admin.WrapErrorISE(err, "error generating JWK key pair") + } + + jwkPubBytes, err := jwk.MarshalJSON() + if err != nil { + return nil, admin.WrapErrorISE(err, "error marshaling JWK") + } + jwePrivStr, err := jwe.CompactSerialize() + if err != nil { + return nil, admin.WrapErrorISE(err, "error serializing JWE") + } + + p := &linkedca.Provisioner{ + Name: "Admin JWK", + Type: linkedca.Provisioner_JWK, + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_JWK{ + JWK: &linkedca.JWKProvisioner{ + PublicKey: jwkPubBytes, + EncryptedPrivateKey: []byte(jwePrivStr), + }, + }, + }, + Claims: &linkedca.Claims{ + X509: &linkedca.X509Claims{ + Enabled: true, + Durations: &linkedca.Durations{ + Default: "5m", + }, + }, + }, + } + if err := db.CreateProvisioner(ctx, p); err != nil { + return nil, admin.WrapErrorISE(err, "error creating provisioner") + } + return p, nil +} + +func ValidateClaims(c *linkedca.Claims) error { + if c == nil { + return nil + } + if c.X509 != nil { + if c.X509.Durations != nil { + if err := ValidateDurations(c.X509.Durations); err != nil { + return err + } + } + } + if c.Ssh != nil { + if c.Ssh.UserDurations != nil { + if err := ValidateDurations(c.Ssh.UserDurations); err != nil { + return err + } + } + if c.Ssh.HostDurations != nil { + if err := ValidateDurations(c.Ssh.HostDurations); err != nil { + return err + } + } + } + return nil +} + +func ValidateDurations(d *linkedca.Durations) error { + var ( + err error + min, max, def *provisioner.Duration + ) + + if d.Min != "" { + min, err = provisioner.NewDuration(d.Min) + if err != nil { + return admin.WrapError(admin.ErrorBadRequestType, err, "min duration '%s' is invalid", d.Min) + } + if min.Value() < 0 { + return admin.WrapError(admin.ErrorBadRequestType, err, "min duration '%s' cannot be less than 0", d.Min) + } + } + if d.Max != "" { + max, err = provisioner.NewDuration(d.Max) + if err != nil { + return admin.WrapError(admin.ErrorBadRequestType, err, "max duration '%s' is invalid", d.Max) + } + if max.Value() < 0 { + return admin.WrapError(admin.ErrorBadRequestType, err, "max duration '%s' cannot be less than 0", d.Max) + } + } + if d.Default != "" { + def, err = provisioner.NewDuration(d.Default) + if err != nil { + return admin.WrapError(admin.ErrorBadRequestType, err, "default duration '%s' is invalid", d.Default) + } + if def.Value() < 0 { + return admin.WrapError(admin.ErrorBadRequestType, err, "default duration '%s' cannot be less than 0", d.Default) + } + } + if d.Min != "" && d.Max != "" && min.Value() > max.Value() { + return admin.NewError(admin.ErrorBadRequestType, + "min duration '%s' cannot be greater than max duration '%s'", d.Min, d.Max) + } + if d.Min != "" && d.Default != "" && min.Value() > def.Value() { + return admin.NewError(admin.ErrorBadRequestType, + "min duration '%s' cannot be greater than default duration '%s'", d.Min, d.Default) + } + if d.Default != "" && d.Max != "" && min.Value() > def.Value() { + return admin.NewError(admin.ErrorBadRequestType, + "default duration '%s' cannot be greater than max duration '%s'", d.Default, d.Max) + } + return nil +} + +func provisionerListToCertificates(l []*linkedca.Provisioner) (provisioner.List, error) { + var nu provisioner.List + for _, p := range l { + certProv, err := ProvisionerToCertificates(p) + if err != nil { + return nil, err + } + nu = append(nu, certProv) + } + return nu, nil +} + +func optionsToCertificates(p *linkedca.Provisioner) *provisioner.Options { + ops := &provisioner.Options{ + X509: &provisioner.X509Options{}, + SSH: &provisioner.SSHOptions{}, + } + if p.X509Template != nil { + ops.X509.Template = string(p.X509Template.Template) + ops.X509.TemplateData = p.X509Template.Data + } + if p.SshTemplate != nil { + ops.SSH.Template = string(p.SshTemplate.Template) + ops.SSH.TemplateData = p.SshTemplate.Data + } + return ops +} + +func durationsToCertificates(d *linkedca.Durations) (min, max, def *provisioner.Duration, err error) { + if len(d.Min) > 0 { + min, err = provisioner.NewDuration(d.Min) + if err != nil { + return nil, nil, nil, admin.WrapErrorISE(err, "error parsing minimum duration '%s'", d.Min) + } + } + if len(d.Max) > 0 { + max, err = provisioner.NewDuration(d.Max) + if err != nil { + return nil, nil, nil, admin.WrapErrorISE(err, "error parsing maximum duration '%s'", d.Max) + } + } + if len(d.Default) > 0 { + def, err = provisioner.NewDuration(d.Default) + if err != nil { + return nil, nil, nil, admin.WrapErrorISE(err, "error parsing default duration '%s'", d.Default) + } + } + return +} + +func durationsToLinkedca(d *provisioner.Duration) string { + if d == nil { + return "" + } + return d.Duration.String() +} + +// claimsToCertificates converts the linkedca provisioner claims type to the +// certifictes claims type. +func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) { + if c == nil { + return nil, nil + } + + pc := &provisioner.Claims{ + DisableRenewal: &c.DisableRenewal, + } + + var err error + + if xc := c.X509; xc != nil { + if d := xc.Durations; d != nil { + pc.MinTLSDur, pc.MaxTLSDur, pc.DefaultTLSDur, err = durationsToCertificates(d) + if err != nil { + return nil, err + } + } + } + if sc := c.Ssh; sc != nil { + pc.EnableSSHCA = &sc.Enabled + if d := sc.UserDurations; d != nil { + pc.MinUserSSHDur, pc.MaxUserSSHDur, pc.DefaultUserSSHDur, err = durationsToCertificates(d) + if err != nil { + return nil, err + } + } + if d := sc.HostDurations; d != nil { + pc.MinHostSSHDur, pc.MaxHostSSHDur, pc.DefaultHostSSHDur, err = durationsToCertificates(d) + if err != nil { + return nil, err + } + } + } + + return pc, nil +} + +func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims { + if c == nil { + return nil + } + + disableRenewal := config.DefaultDisableRenewal + if c.DisableRenewal != nil { + disableRenewal = *c.DisableRenewal + } + + lc := &linkedca.Claims{ + DisableRenewal: disableRenewal, + } + + if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil { + lc.X509 = &linkedca.X509Claims{ + Enabled: true, + Durations: &linkedca.Durations{ + Default: durationsToLinkedca(c.DefaultTLSDur), + Min: durationsToLinkedca(c.MinTLSDur), + Max: durationsToLinkedca(c.MaxTLSDur), + }, + } + } + + if c.EnableSSHCA != nil && *c.EnableSSHCA { + lc.Ssh = &linkedca.SSHClaims{ + Enabled: true, + } + if c.DefaultUserSSHDur != nil || c.MinUserSSHDur != nil || c.MaxUserSSHDur != nil { + lc.Ssh.UserDurations = &linkedca.Durations{ + Default: durationsToLinkedca(c.DefaultUserSSHDur), + Min: durationsToLinkedca(c.MinUserSSHDur), + Max: durationsToLinkedca(c.MaxUserSSHDur), + } + } + if c.DefaultHostSSHDur != nil || c.MinHostSSHDur != nil || c.MaxHostSSHDur != nil { + lc.Ssh.HostDurations = &linkedca.Durations{ + Default: durationsToLinkedca(c.DefaultHostSSHDur), + Min: durationsToLinkedca(c.MinHostSSHDur), + Max: durationsToLinkedca(c.MaxHostSSHDur), + } + } + } + + return lc +} + +func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *linkedca.Template, error) { + var err error + var x509Template, sshTemplate *linkedca.Template + + if p == nil { + return nil, nil, nil + } + + if p.X509 != nil && p.X509.HasTemplate() { + x509Template = &linkedca.Template{ + Template: nil, + Data: nil, + } + + if p.X509.Template != "" { + x509Template.Template = []byte(p.SSH.Template) + } else if p.X509.TemplateFile != "" { + filename := step.StepAbs(p.X509.TemplateFile) + if x509Template.Template, err = ioutil.ReadFile(filename); err != nil { + return nil, nil, errors.Wrap(err, "error reading x509 template") + } + } + } + + if p.SSH != nil && p.SSH.HasTemplate() { + sshTemplate = &linkedca.Template{ + Template: nil, + Data: nil, + } + + if p.SSH.Template != "" { + sshTemplate.Template = []byte(p.SSH.Template) + } else if p.SSH.TemplateFile != "" { + filename := step.StepAbs(p.SSH.TemplateFile) + if sshTemplate.Template, err = ioutil.ReadFile(filename); err != nil { + return nil, nil, errors.Wrap(err, "error reading ssh template") + } + } + } + + return x509Template, sshTemplate, nil +} + +func provisionerPEMToLinkedca(b []byte) [][]byte { + var roots [][]byte + var block *pem.Block + for { + if block, b = pem.Decode(b); block == nil { + break + } + roots = append(roots, pem.EncodeToMemory(block)) + } + return roots +} + +// ProvisionerToCertificates converts the linkedca provisioner type to the certificates provisioner +// interface. +func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, error) { + claims, err := claimsToCertificates(p.Claims) + if err != nil { + return nil, err + } + + details := p.Details.GetData() + if details == nil { + return nil, errors.New("provisioner does not have any details") + } + + options := optionsToCertificates(p) + + switch d := details.(type) { + case *linkedca.ProvisionerDetails_JWK: + jwk := new(jose.JSONWebKey) + if err := json.Unmarshal(d.JWK.PublicKey, &jwk); err != nil { + return nil, errors.Wrap(err, "error unmarshaling public key") + } + return &provisioner.JWK{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + Key: jwk, + EncryptedKey: string(d.JWK.EncryptedPrivateKey), + Claims: claims, + Options: options, + }, nil + case *linkedca.ProvisionerDetails_X5C: + var roots []byte + for i, root := range d.X5C.GetRoots() { + if i > 0 { + roots = append(roots, '\n') + } + roots = append(roots, root...) + } + return &provisioner.X5C{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + Roots: roots, + Claims: claims, + Options: options, + }, nil + case *linkedca.ProvisionerDetails_K8SSA: + var publicKeys []byte + for i, k := range d.K8SSA.GetPublicKeys() { + if i > 0 { + publicKeys = append(publicKeys, '\n') + } + publicKeys = append(publicKeys, k...) + } + return &provisioner.K8sSA{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + PubKeys: publicKeys, + Claims: claims, + Options: options, + }, nil + case *linkedca.ProvisionerDetails_SSHPOP: + return &provisioner.SSHPOP{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + Claims: claims, + }, nil + case *linkedca.ProvisionerDetails_ACME: + cfg := d.ACME + return &provisioner.ACME{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + ForceCN: cfg.ForceCn, + Claims: claims, + Options: options, + }, nil + case *linkedca.ProvisionerDetails_OIDC: + cfg := d.OIDC + return &provisioner.OIDC{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + TenantID: cfg.TenantId, + ClientID: cfg.ClientId, + ClientSecret: cfg.ClientSecret, + ConfigurationEndpoint: cfg.ConfigurationEndpoint, + Admins: cfg.Admins, + Domains: cfg.Domains, + Groups: cfg.Groups, + ListenAddress: cfg.ListenAddress, + Claims: claims, + Options: options, + }, nil + case *linkedca.ProvisionerDetails_AWS: + cfg := d.AWS + instanceAge, err := parseInstanceAge(cfg.InstanceAge) + if err != nil { + return nil, err + } + return &provisioner.AWS{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + Accounts: cfg.Accounts, + DisableCustomSANs: cfg.DisableCustomSans, + DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse, + InstanceAge: instanceAge, + Claims: claims, + Options: options, + }, nil + case *linkedca.ProvisionerDetails_GCP: + cfg := d.GCP + instanceAge, err := parseInstanceAge(cfg.InstanceAge) + if err != nil { + return nil, err + } + return &provisioner.GCP{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + ServiceAccounts: cfg.ServiceAccounts, + ProjectIDs: cfg.ProjectIds, + DisableCustomSANs: cfg.DisableCustomSans, + DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse, + InstanceAge: instanceAge, + Claims: claims, + Options: options, + }, nil + case *linkedca.ProvisionerDetails_Azure: + cfg := d.Azure + return &provisioner.Azure{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + TenantID: cfg.TenantId, + ResourceGroups: cfg.ResourceGroups, + Audience: cfg.Audience, + DisableCustomSANs: cfg.DisableCustomSans, + DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse, + Claims: claims, + Options: options, + }, nil + default: + return nil, fmt.Errorf("provisioner %s not implemented", p.Type) + } +} + +// ProvisionerToLinkedca converts a provisioner.Interface to a +// linkedca.Provisioner type. +func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, error) { + switch p := p.(type) { + case *provisioner.JWK: + x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + publicKey, err := json.Marshal(p.Key) + if err != nil { + return nil, errors.Wrap(err, "error marshaling key") + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_JWK, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_JWK{ + JWK: &linkedca.JWKProvisioner{ + PublicKey: publicKey, + EncryptedPrivateKey: []byte(p.EncryptedKey), + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + }, nil + case *provisioner.OIDC: + x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_OIDC, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_OIDC{ + OIDC: &linkedca.OIDCProvisioner{ + ClientId: p.ClientID, + ClientSecret: p.ClientSecret, + ConfigurationEndpoint: p.ConfigurationEndpoint, + Admins: p.Admins, + Domains: p.Domains, + Groups: p.Groups, + ListenAddress: p.ListenAddress, + TenantId: p.TenantID, + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + }, nil + case *provisioner.GCP: + x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_GCP, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_GCP{ + GCP: &linkedca.GCPProvisioner{ + ServiceAccounts: p.ServiceAccounts, + ProjectIds: p.ProjectIDs, + DisableCustomSans: p.DisableCustomSANs, + DisableTrustOnFirstUse: p.DisableTrustOnFirstUse, + InstanceAge: p.InstanceAge.String(), + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + }, nil + case *provisioner.AWS: + x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_AWS, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_AWS{ + AWS: &linkedca.AWSProvisioner{ + Accounts: p.Accounts, + DisableCustomSans: p.DisableCustomSANs, + DisableTrustOnFirstUse: p.DisableTrustOnFirstUse, + InstanceAge: p.InstanceAge.String(), + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + }, nil + case *provisioner.Azure: + x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_AZURE, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_Azure{ + Azure: &linkedca.AzureProvisioner{ + TenantId: p.TenantID, + ResourceGroups: p.ResourceGroups, + Audience: p.Audience, + DisableCustomSans: p.DisableCustomSANs, + DisableTrustOnFirstUse: p.DisableTrustOnFirstUse, + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + }, nil + case *provisioner.ACME: + x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_ACME, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + ForceCn: p.ForceCN, + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + }, nil + case *provisioner.X5C: + x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_X5C, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_X5C{ + X5C: &linkedca.X5CProvisioner{ + Roots: provisionerPEMToLinkedca(p.Roots), + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + }, nil + case *provisioner.K8sSA: + x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_K8SSA, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_K8SSA{ + K8SSA: &linkedca.K8SSAProvisioner{ + PublicKeys: provisionerPEMToLinkedca(p.PubKeys), + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + }, nil + case *provisioner.SSHPOP: + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_SSHPOP, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_SSHPOP{ + SSHPOP: &linkedca.SSHPOPProvisioner{}, + }, + }, + Claims: claimsToLinkedca(p.Claims), + }, nil + case *provisioner.SCEP: + x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_SCEP, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_SCEP{ + SCEP: &linkedca.SCEPProvisioner{ + ForceCn: p.ForceCN, + Challenge: p.GetChallengePassword(), + Capabilities: p.Capabilities, + MinimumPublicKeyLength: int32(p.MinimumPublicKeyLength), + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + }, nil + default: + return nil, fmt.Errorf("provisioner %s not implemented", p.GetType()) + } +} + +func parseInstanceAge(age string) (provisioner.Duration, error) { + var instanceAge provisioner.Duration + if age != "" { + iap, err := provisioner.NewDuration(age) + if err != nil { + return instanceAge, err + } + instanceAge = *iap + } + return instanceAge, nil +} diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index 94b2d715..3975031b 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -56,7 +56,7 @@ func TestGetEncryptedKey(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - val, ok := tc.a.provisioners.Load("max:" + tc.kid) + val, ok := tc.a.provisioners.Load("mike:" + tc.kid) assert.Fatal(t, ok) p, ok := val.(*provisioner.JWK) assert.Fatal(t, ok) diff --git a/authority/ssh.go b/authority/ssh.go index bb0ff562..762319ae 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -10,11 +10,11 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/templates" - "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" "go.step.sm/crypto/sshutil" "golang.org/x/crypto/ssh" @@ -32,103 +32,17 @@ const ( SSHAddUserCommand = "sudo useradd -m ; nc -q0 localhost 22" ) -// SSHConfig contains the user and host keys. -type SSHConfig struct { - HostKey string `json:"hostKey"` - UserKey string `json:"userKey"` - Keys []*SSHPublicKey `json:"keys,omitempty"` - AddUserPrincipal string `json:"addUserPrincipal,omitempty"` - AddUserCommand string `json:"addUserCommand,omitempty"` - Bastion *Bastion `json:"bastion,omitempty"` -} - -// Bastion contains the custom properties used on bastion. -type Bastion struct { - Hostname string `json:"hostname"` - User string `json:"user,omitempty"` - Port string `json:"port,omitempty"` - Command string `json:"cmd,omitempty"` - Flags string `json:"flags,omitempty"` -} - -// HostTag are tagged with k,v pairs. These tags are how a user is ultimately -// associated with a host. -type HostTag struct { - ID string - Name string - Value string -} - -// Host defines expected attributes for an ssh host. -type Host struct { - HostID string `json:"hid"` - HostTags []HostTag `json:"host_tags"` - Hostname string `json:"hostname"` -} - -// Validate checks the fields in SSHConfig. -func (c *SSHConfig) Validate() error { - if c == nil { - return nil - } - for _, k := range c.Keys { - if err := k.Validate(); err != nil { - return err - } - } - return nil -} - -// SSHPublicKey contains a public key used by federated CAs to keep old signing -// keys for this ca. -type SSHPublicKey struct { - Type string `json:"type"` - Federated bool `json:"federated"` - Key jose.JSONWebKey `json:"key"` - publicKey ssh.PublicKey -} - -// Validate checks the fields in SSHPublicKey. -func (k *SSHPublicKey) Validate() error { - switch { - case k.Type == "": - return errors.New("type cannot be empty") - case k.Type != provisioner.SSHHostCert && k.Type != provisioner.SSHUserCert: - return errors.Errorf("invalid type %s, it must be user or host", k.Type) - case !k.Key.IsPublic(): - return errors.New("invalid key type, it must be a public key") - } - - key, err := ssh.NewPublicKey(k.Key.Key) - if err != nil { - return errors.Wrap(err, "error creating ssh key") - } - k.publicKey = key - return nil -} - -// PublicKey returns the ssh public key. -func (k *SSHPublicKey) PublicKey() ssh.PublicKey { - return k.publicKey -} - -// SSHKeys represents the SSH User and Host public keys. -type SSHKeys struct { - UserKeys []ssh.PublicKey - HostKeys []ssh.PublicKey -} - // GetSSHRoots returns the SSH User and Host public keys. -func (a *Authority) GetSSHRoots(context.Context) (*SSHKeys, error) { - return &SSHKeys{ +func (a *Authority) GetSSHRoots(context.Context) (*config.SSHKeys, error) { + return &config.SSHKeys{ HostKeys: a.sshCAHostCerts, UserKeys: a.sshCAUserCerts, }, nil } // GetSSHFederation returns the public keys for federated SSH signers. -func (a *Authority) GetSSHFederation(context.Context) (*SSHKeys, error) { - return &SSHKeys{ +func (a *Authority) GetSSHFederation(context.Context) (*config.SSHKeys, error) { + return &config.SSHKeys{ HostKeys: a.sshCAHostFederatedCerts, UserKeys: a.sshCAUserFederatedCerts, }, nil @@ -194,7 +108,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin // GetSSHBastion returns the bastion configuration, for the given pair user, // hostname. -func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname string) (*Bastion, error) { +func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (*config.Bastion, error) { if a.sshBastionFunc != nil { bs, err := a.sshBastionFunc(ctx, user, hostname) return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion") @@ -325,7 +239,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi } } - if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db") } @@ -335,7 +249,11 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errs.BadRequest("rewnewSSH: cannot renew certificate without validity period") + return nil, errs.BadRequest("renewSSH: cannot renew certificate without validity period") + } + + if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { + return nil, err } backdate := a.config.AuthorityConfig.Backdate.Duration @@ -380,7 +298,7 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") } - if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db") } @@ -405,6 +323,10 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period") } + if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { + return nil, err + } + backdate := a.config.AuthorityConfig.Backdate.Duration duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second now := time.Now() @@ -455,13 +377,23 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub } } - if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db") } return cert, nil } +func (a *Authority) storeSSHCertificate(cert *ssh.Certificate) error { + type sshCertificateStorer interface { + StoreSSHCertificate(crt *ssh.Certificate) error + } + if s, ok := a.adminDB.(sshCertificateStorer); ok { + return s.StoreSSHCertificate(cert) + } + return a.db.StoreSSHCertificate(cert) +} + // IsValidForAddUser checks if a user provisioner certificate can be issued to // the given certificate. func IsValidForAddUser(cert *ssh.Certificate) error { @@ -537,7 +469,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje } cert.Signature = sig - if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db") } @@ -545,7 +477,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje } // CheckSSHHost checks the given principal has been registered before. -func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) { +func (a *Authority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { if a.sshCheckHostFunc != nil { exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates()) if err != nil { @@ -568,7 +500,7 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st } // GetSSHHosts returns a list of valid host principals. -func (a *Authority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]Host, error) { +func (a *Authority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) { if a.sshGetHostsFunc != nil { hosts, err := a.sshGetHostsFunc(ctx, cert) return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts") @@ -578,9 +510,9 @@ func (a *Authority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([] return nil, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts") } - hosts := make([]Host, len(hostnames)) + hosts := make([]config.Host, len(hostnames)) for i, hn := range hostnames { - hosts[i] = Host{Hostname: hn} + hosts[i] = config.Host{Hostname: hn} } return hosts, nil } @@ -599,5 +531,5 @@ func (a *Authority) getAddUserCommand(principal string) string { } else { cmd = a.config.SSH.AddUserCommand } - return strings.Replace(cmd, "", principal, -1) + return strings.ReplaceAll(cmd, "", principal) } diff --git a/authority/ssh_test.go b/authority/ssh_test.go index 1662260c..41df8576 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -87,6 +87,52 @@ func (m sshTestOptionsModifier) Modify(cert *ssh.Certificate, opts provisioner.S return fmt.Errorf(string(m)) } +func TestAuthority_initHostOnly(t *testing.T) { + auth := testAuthority(t, func(a *Authority) error { + a.config.SSH.UserKey = "" + return nil + }) + + // Check keys + keys, err := auth.GetSSHRoots(context.Background()) + assert.NoError(t, err) + assert.Len(t, 1, keys.HostKeys) + assert.Len(t, 0, keys.UserKeys) + + // Check templates, user templates should work fine. + _, err = auth.GetSSHConfig(context.Background(), "user", nil) + assert.NoError(t, err) + + _, err = auth.GetSSHConfig(context.Background(), "host", map[string]string{ + "Certificate": "ssh_host_ecdsa_key-cert.pub", + "Key": "ssh_host_ecdsa_key", + }) + assert.Error(t, err) +} + +func TestAuthority_initUserOnly(t *testing.T) { + auth := testAuthority(t, func(a *Authority) error { + a.config.SSH.HostKey = "" + return nil + }) + + // Check keys + keys, err := auth.GetSSHRoots(context.Background()) + assert.NoError(t, err) + assert.Len(t, 0, keys.HostKeys) + assert.Len(t, 1, keys.UserKeys) + + // Check templates, host templates should work fine. + _, err = auth.GetSSHConfig(context.Background(), "host", map[string]string{ + "Certificate": "ssh_host_ecdsa_key-cert.pub", + "Key": "ssh_host_ecdsa_key", + }) + assert.NoError(t, err) + + _, err = auth.GetSSHConfig(context.Background(), "user", nil) + assert.Error(t, err) +} + func TestAuthority_SignSSH(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) @@ -153,6 +199,8 @@ func TestAuthority_SignSSH(t *testing.T) { }{ {"ok-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, {"ok-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, + {"ok-user-only", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, + {"ok-host-only", fields{nil, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, {"ok-opts-type-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false}, {"ok-opts-type-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false}, {"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, @@ -590,69 +638,6 @@ func TestSSHConfig_Validate(t *testing.T) { } } -func TestSSHPublicKey_Validate(t *testing.T) { - key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0) - assert.FatalError(t, err) - - type fields struct { - Type string - Federated bool - Key jose.JSONWebKey - } - tests := []struct { - name string - fields fields - wantErr bool - }{ - {"user", fields{"user", true, key.Public()}, false}, - {"host", fields{"host", false, key.Public()}, false}, - {"empty", fields{"", true, key.Public()}, true}, - {"badType", fields{"bad", false, key.Public()}, true}, - {"badKey", fields{"user", false, *key}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - k := &SSHPublicKey{ - Type: tt.fields.Type, - Federated: tt.fields.Federated, - Key: tt.fields.Key, - } - if err := k.Validate(); (err != nil) != tt.wantErr { - t.Errorf("SSHPublicKey.Validate() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestSSHPublicKey_PublicKey(t *testing.T) { - key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0) - assert.FatalError(t, err) - pub, err := ssh.NewPublicKey(key.Public().Key) - assert.FatalError(t, err) - - type fields struct { - publicKey ssh.PublicKey - } - tests := []struct { - name string - fields fields - want ssh.PublicKey - }{ - {"ok", fields{pub}, pub}, - {"nil", fields{nil}, nil}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - k := &SSHPublicKey{ - publicKey: tt.fields.publicKey, - } - if got := k.PublicKey(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("SSHPublicKey.PublicKey() = %v, want %v", got, tt.want) - } - }) - } -} - func TestAuthority_GetSSHBastion(t *testing.T) { bastion := &Bastion{ Hostname: "bastion.local", @@ -813,6 +798,11 @@ func TestAuthority_RekeySSH(t *testing.T) { now := time.Now().UTC() a := testAuthority(t) + a.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, + } type test struct { auth *Authority @@ -826,6 +816,56 @@ func TestAuthority_RekeySSH(t *testing.T) { code int } tests := map[string]func(t *testing.T) *test{ + "fail/is-revoked": func(t *testing.T) *test { + auth := testAuthority(t) + auth.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return true, nil + }, + } + return &test{ + auth: auth, + userSigner: signer, + hostSigner: signer, + cert: &ssh.Certificate{ + Serial: 1234567890, + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + CertType: ssh.UserCert, + ValidPrincipals: []string{"foo", "bar"}, + KeyId: "foo", + }, + key: pub, + signOpts: []provisioner.SignOption{}, + err: errors.New("authority.authorizeSSHCertificate: certificate has been revoked"), + code: http.StatusUnauthorized, + } + }, + "fail/is-revoked-error": func(t *testing.T) *test { + auth := testAuthority(t) + auth.db = &db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, errors.New("an error") + }, + } + return &test{ + auth: auth, + userSigner: signer, + hostSigner: signer, + cert: &ssh.Certificate{ + Serial: 1234567890, + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + CertType: ssh.UserCert, + ValidPrincipals: []string{"foo", "bar"}, + KeyId: "foo", + }, + key: pub, + signOpts: []provisioner.SignOption{}, + err: errors.New("authority.authorizeSSHCertificate: an error"), + code: http.StatusInternalServerError, + } + }, "fail/opts-type": func(t *testing.T) *test { return &test{ userSigner: signer, @@ -894,6 +934,9 @@ func TestAuthority_RekeySSH(t *testing.T) { "fail/db-store": func(t *testing.T) *test { return &test{ auth: testAuthority(t, WithDatabase(&db.MockAuthDB{ + MIsSSHRevoked: func(sn string) (bool, error) { + return false, nil + }, MStoreSSHCertificate: func(cert *ssh.Certificate) error { return errors.New("force") }, diff --git a/authority/status/status.go b/authority/status/status.go new file mode 100644 index 00000000..49e4c0bb --- /dev/null +++ b/authority/status/status.go @@ -0,0 +1,11 @@ +package status + +// Type is the type for status. +type Type string + +var ( + // Active active + Active = Type("active") + // Deleted deleted + Deleted = Type("deleted") +) diff --git a/authority/tls.go b/authority/tls.go index b7b2f936..839866a2 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -12,6 +12,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" @@ -20,22 +21,22 @@ import ( "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" + "golang.org/x/crypto/ssh" ) // GetTLSOptions returns the tls options configured. -func (a *Authority) GetTLSOptions() *TLSOptions { +func (a *Authority) GetTLSOptions() *config.TLSOptions { return a.config.TLS } var oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35} var oidSubjectKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 14} -func withDefaultASN1DN(def *ASN1DN) provisioner.CertificateModifierFunc { +func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc { return func(crt *x509.Certificate, opts provisioner.SignOptions) error { if def == nil { return errors.New("default ASN1DN template cannot be nil") } - if len(crt.Subject.Country) == 0 && def.Country != "" { crt.Subject.Country = append(crt.Subject.Country, def.Country) } @@ -54,7 +55,12 @@ func withDefaultASN1DN(def *ASN1DN) provisioner.CertificateModifierFunc { if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" { crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress) } - + if crt.Subject.SerialNumber == "" && def.SerialNumber != "" { + crt.Subject.SerialNumber = def.SerialNumber + } + if crt.Subject.CommonName == "" && def.CommonName != "" { + crt.Subject.CommonName = def.CommonName + } return nil } } @@ -279,9 +285,15 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 // `StoreCertificate(...*x509.Certificate) error` instead of just // `StoreCertificate(*x509.Certificate) error`. func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error { - if s, ok := a.db.(interface { + type certificateChainStorer interface { StoreCertificateChain(...*x509.Certificate) error - }); ok { + } + // Store certificate in linkedca + if s, ok := a.adminDB.(certificateChainStorer); ok { + return s.StoreCertificateChain(fullchain...) + } + // Store certificate in local db + if s, ok := a.db.(certificateChainStorer); ok { return s.StoreCertificateChain(fullchain...) } return a.db.StoreCertificate(fullchain[0]) @@ -292,9 +304,15 @@ func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error { // // TODO: at some point we should implement this in the standard implementation. func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain []*x509.Certificate) error { - if s, ok := a.db.(interface { + type renewedCertificateChainStorer interface { StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error - }); ok { + } + // Store certificate in linkedca + if s, ok := a.adminDB.(renewedCertificateChainStorer); ok { + return s.StoreRenewedCertificate(oldCert, fullchain...) + } + // Store certificate in local db + if s, ok := a.db.(renewedCertificateChainStorer); ok { return s.StoreRenewedCertificate(oldCert, fullchain...) } return a.db.StoreCertificate(fullchain[0]) @@ -359,29 +377,28 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error } // This method will also validate the audiences for JWK provisioners. - var ok bool - p, ok = a.provisioners.LoadByToken(token, &claims.Claims) - if !ok { - return errs.InternalServer("authority.Revoke; provisioner not found", opts...) + p, err = a.LoadProvisionerByToken(token, &claims.Claims) + if err != nil { + return err } rci.ProvisionerID = p.GetID() rci.TokenID, err = p.GetTokenID(revokeOpts.OTT) - if err != nil { + if err != nil && !errors.Is(err, provisioner.ErrAllowTokenReuse) { return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke; could not get ID for token") } - opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID)) - opts = append(opts, errs.WithKeyVal("tokenID", rci.TokenID)) - } else { + opts = append(opts, + errs.WithKeyVal("provisionerID", rci.ProvisionerID), + errs.WithKeyVal("tokenID", rci.TokenID), + ) + } else if p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil { // Load the Certificate provisioner if one exists. - if p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil { - rci.ProvisionerID = p.GetID() - opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID)) - } + rci.ProvisionerID = p.GetID() + opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID)) } if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod { - err = a.db.RevokeSSH(rci) + err = a.revokeSSH(nil, rci) } else { // Revoke an X.509 certificate using CAS. If the certificate is not // provided we will try to read it from the db. If the read fails we @@ -408,7 +425,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error } // Save as revoked in the Db. - err = a.db.Revoke(rci) + err = a.revoke(revokedCert, rci) } switch err { case nil: @@ -423,6 +440,24 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error } } +func (a *Authority) revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error { + if lca, ok := a.adminDB.(interface { + Revoke(*x509.Certificate, *db.RevokedCertificateInfo) error + }); ok { + return lca.Revoke(crt, rci) + } + return a.db.Revoke(rci) +} + +func (a *Authority) revokeSSH(crt *ssh.Certificate, rci *db.RevokedCertificateInfo) error { + if lca, ok := a.adminDB.(interface { + RevokeSSH(*ssh.Certificate, *db.RevokedCertificateInfo) error + }); ok { + return lca.RevokeSSH(crt, rci) + } + return a.db.Revoke(rci) +} + // GetTLSCertificate creates a new leaf certificate to be used by the CA HTTPS server. func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) { fatal := func(err error) (*tls.Certificate, error) { diff --git a/authority/tls_test.go b/authority/tls_test.go index 4c936f0c..f1d1748d 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -426,6 +426,7 @@ ZYtQ9Ot36qc= {Id: stepOIDProvisioner, Value: []byte("foo")}, {Id: []int{1, 1, 1}, Value: []byte("bar")}})) now := time.Now().UTC() + // nolint:gocritic enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{ NotBefore: now, NotAfter: now.Add(365 * 24 * time.Hour), @@ -656,7 +657,7 @@ func TestAuthority_Renew(t *testing.T) { "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), + err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, @@ -856,7 +857,7 @@ func TestAuthority_Rekey(t *testing.T) { "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner dev:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"), + err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, diff --git a/ca/acmeClient.go b/ca/acmeClient.go index 5633dac5..d1f40f32 100644 --- a/ca/acmeClient.go +++ b/ca/acmeClient.go @@ -345,7 +345,7 @@ func readACMEError(r io.ReadCloser) error { ae := new(acme.Error) err = json.Unmarshal(b, &ae) // If we successfully marshaled to an ACMEError then return the ACMEError. - if err != nil || len(ae.Error()) == 0 { + if err != nil || ae.Error() == "" { fmt.Printf("b = %s\n", b) // Throw up our hands. return errors.Errorf("%s", b) diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index f5963de4..656a82cf 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -1247,6 +1247,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { Type: "Certificate", Bytes: leaf.Raw, }) + // nolint:gocritic certBytes := append(leafb, leafb...) certBytes = append(certBytes, leafb...) ac := &ACMEClient{ diff --git a/ca/adminClient.go b/ca/adminClient.go new file mode 100644 index 00000000..6022f677 --- /dev/null +++ b/ca/adminClient.go @@ -0,0 +1,570 @@ +package ca + +import ( + "bytes" + "crypto/x509" + "encoding/json" + "io" + "net/http" + "net/url" + "path" + "strconv" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" + adminAPI "github.com/smallstep/certificates/authority/admin/api" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/errs" + "go.step.sm/cli-utils/token" + "go.step.sm/cli-utils/token/provision" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/randutil" + "go.step.sm/linkedca" + "google.golang.org/protobuf/encoding/protojson" +) + +var adminURLPrefix = "admin" + +// AdminClient implements an HTTP client for the CA server. +type AdminClient struct { + client *uaClient + endpoint *url.URL + retryFunc RetryFunc + opts []ClientOption + x5cJWK *jose.JSONWebKey + x5cCertFile string + x5cCertStrs []string + x5cCert *x509.Certificate + x5cIssuer string + x5cSubject string +} + +// NewAdminClient creates a new AdminClient with the given endpoint and options. +func NewAdminClient(endpoint string, opts ...ClientOption) (*AdminClient, error) { + u, err := parseEndpoint(endpoint) + if err != nil { + return nil, err + } + // Retrieve transport from options. + o := new(clientOptions) + if err := o.apply(opts); err != nil { + return nil, err + } + tr, err := o.getTransport(endpoint) + if err != nil { + return nil, err + } + + return &AdminClient{ + client: newClient(tr), + endpoint: u, + retryFunc: o.retryFunc, + opts: opts, + x5cJWK: o.x5cJWK, + x5cCertFile: o.x5cCertFile, + x5cCertStrs: o.x5cCertStrs, + x5cCert: o.x5cCert, + x5cIssuer: o.x5cIssuer, + x5cSubject: o.x5cSubject, + }, nil +} + +func (c *AdminClient) generateAdminToken(urlPath string) (string, error) { + // A random jwt id will be used to identify duplicated tokens + jwtID, err := randutil.Hex(64) // 256 bits + if err != nil { + return "", err + } + + now := time.Now() + tokOptions := []token.Options{ + token.WithJWTID(jwtID), + token.WithKid(c.x5cJWK.KeyID), + token.WithIssuer(c.x5cIssuer), + token.WithAudience(urlPath), + token.WithValidity(now, now.Add(token.DefaultValidity)), + token.WithX5CCerts(c.x5cCertStrs), + } + + tok, err := provision.New(c.x5cSubject, tokOptions...) + if err != nil { + return "", err + } + + return tok.SignedString(c.x5cJWK.Algorithm, c.x5cJWK.Key) + +} + +func (c *AdminClient) retryOnError(r *http.Response) bool { + if c.retryFunc != nil { + if c.retryFunc(r.StatusCode) { + o := new(clientOptions) + if err := o.apply(c.opts); err != nil { + return false + } + tr, err := o.getTransport(c.endpoint.String()) + if err != nil { + return false + } + r.Body.Close() + c.client.SetTransport(tr) + return true + } + } + return false +} + +// GetAdmin performs the GET /admin/admin/{id} request to the CA. +func (c *AdminClient) GetAdmin(id string) (*linkedca.Admin, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins", id)}) +retry: + resp, err := c.client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var adm = new(linkedca.Admin) + if err := readProtoJSON(resp.Body, adm); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return adm, nil +} + +// AdminOption is the type of options passed to the Admin method. +type AdminOption func(o *adminOptions) error + +type adminOptions struct { + cursor string + limit int +} + +func (o *adminOptions) apply(opts []AdminOption) (err error) { + for _, fn := range opts { + if err = fn(o); err != nil { + return + } + } + return +} + +func (o *adminOptions) rawQuery() string { + v := url.Values{} + if len(o.cursor) > 0 { + v.Set("cursor", o.cursor) + } + if o.limit > 0 { + v.Set("limit", strconv.Itoa(o.limit)) + } + return v.Encode() +} + +// WithAdminCursor will request the admins starting with the given cursor. +func WithAdminCursor(cursor string) AdminOption { + return func(o *adminOptions) error { + o.cursor = cursor + return nil + } +} + +// WithAdminLimit will request the given number of admins. +func WithAdminLimit(limit int) AdminOption { + return func(o *adminOptions) error { + o.limit = limit + return nil + } +} + +// GetAdminsPaginate returns a page from the the GET /admin/admins request to the CA. +func (c *AdminClient) GetAdminsPaginate(opts ...AdminOption) (*adminAPI.GetAdminsResponse, error) { + var retried bool + o := new(adminOptions) + if err := o.apply(opts); err != nil { + return nil, err + } + u := c.endpoint.ResolveReference(&url.URL{ + Path: "/admin/admins", + RawQuery: o.rawQuery(), + }) + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return nil, errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return nil, errors.Wrapf(err, "create GET %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var body = new(adminAPI.GetAdminsResponse) + if err := readJSON(resp.Body, body); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return body, nil +} + +// GetAdmins returns all admins from the GET /admin/admins request to the CA. +func (c *AdminClient) GetAdmins(opts ...AdminOption) ([]*linkedca.Admin, error) { + var ( + cursor = "" + admins = []*linkedca.Admin{} + ) + for { + resp, err := c.GetAdminsPaginate(WithAdminCursor(cursor), WithAdminLimit(100)) + if err != nil { + return nil, err + } + admins = append(admins, resp.Admins...) + if resp.NextCursor == "" { + return admins, nil + } + cursor = resp.NextCursor + } +} + +// CreateAdmin performs the POST /admin/admins request to the CA. +func (c *AdminClient) CreateAdmin(createAdminRequest *adminAPI.CreateAdminRequest) (*linkedca.Admin, error) { + var retried bool + body, err := json.Marshal(createAdminRequest) + if err != nil { + return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: "/admin/admins"}) + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return nil, errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("POST", u.String(), bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "create GET %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrapf(err, "client POST %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var adm = new(linkedca.Admin) + if err := readProtoJSON(resp.Body, adm); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return adm, nil +} + +// RemoveAdmin performs the DELETE /admin/admins/{id} request to the CA. +func (c *AdminClient) RemoveAdmin(id string) error { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins", id)}) + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("DELETE", u.String(), nil) + if err != nil { + return errors.Wrapf(err, "create DELETE %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return errors.Wrapf(err, "client DELETE %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readAdminError(resp.Body) + } + return nil +} + +// UpdateAdmin performs the PUT /admin/admins/{id} request to the CA. +func (c *AdminClient) UpdateAdmin(id string, uar *adminAPI.UpdateAdminRequest) (*linkedca.Admin, error) { + var retried bool + body, err := json.Marshal(uar) + if err != nil { + return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins", id)}) + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return nil, errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("PATCH", u.String(), bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "create PUT %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrapf(err, "client PUT %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var adm = new(linkedca.Admin) + if err := readProtoJSON(resp.Body, adm); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return adm, nil +} + +// GetProvisioner performs the GET /admin/provisioners/{name} request to the CA. +func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provisioner, error) { + var retried bool + o := new(provisionerOptions) + if err := o.apply(opts); err != nil { + return nil, err + } + var u *url.URL + switch { + case len(o.id) > 0: + u = c.endpoint.ResolveReference(&url.URL{ + Path: "/admin/provisioners/id", + RawQuery: o.rawQuery(), + }) + case len(o.name) > 0: + u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)}) + default: + return nil, errors.New("must set either name or id in method options") + } + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return nil, errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return nil, errors.Wrapf(err, "create PUT %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var prov = new(linkedca.Provisioner) + if err := readProtoJSON(resp.Body, prov); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return prov, nil +} + +// GetProvisionersPaginate performs the GET /admin/provisioners request to the CA. +func (c *AdminClient) GetProvisionersPaginate(opts ...ProvisionerOption) (*adminAPI.GetProvisionersResponse, error) { + var retried bool + o := new(provisionerOptions) + if err := o.apply(opts); err != nil { + return nil, err + } + u := c.endpoint.ResolveReference(&url.URL{ + Path: "/admin/provisioners", + RawQuery: o.rawQuery(), + }) + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return nil, errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return nil, errors.Wrapf(err, "create PUT %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var body = new(adminAPI.GetProvisionersResponse) + if err := readJSON(resp.Body, body); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return body, nil +} + +// GetProvisioners returns all admins from the GET /admin/admins request to the CA. +func (c *AdminClient) GetProvisioners(opts ...AdminOption) (provisioner.List, error) { + var ( + cursor = "" + provs = provisioner.List{} + ) + for { + resp, err := c.GetProvisionersPaginate(WithProvisionerCursor(cursor), WithProvisionerLimit(100)) + if err != nil { + return nil, err + } + provs = append(provs, resp.Provisioners...) + if resp.NextCursor == "" { + return provs, nil + } + cursor = resp.NextCursor + } +} + +// RemoveProvisioner performs the DELETE /admin/provisioners/{name} request to the CA. +func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error { + var ( + u *url.URL + retried bool + ) + + o := new(provisionerOptions) + if err := o.apply(opts); err != nil { + return err + } + + switch { + case len(o.id) > 0: + u = c.endpoint.ResolveReference(&url.URL{ + Path: path.Join(adminURLPrefix, "provisioners/id"), + RawQuery: o.rawQuery(), + }) + case len(o.name) > 0: + u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)}) + default: + return errors.New("must set either name or id in method options") + } + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("DELETE", u.String(), nil) + if err != nil { + return errors.Wrapf(err, "create DELETE %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return errors.Wrapf(err, "client DELETE %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readAdminError(resp.Body) + } + return nil +} + +// CreateProvisioner performs the POST /admin/provisioners request to the CA. +func (c *AdminClient) CreateProvisioner(prov *linkedca.Provisioner) (*linkedca.Provisioner, error) { + var retried bool + body, err := protojson.Marshal(prov) + if err != nil { + return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners")}) + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return nil, errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("POST", u.String(), bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "create POST %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrapf(err, "client POST %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var nuProv = new(linkedca.Provisioner) + if err := readProtoJSON(resp.Body, nuProv); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return nuProv, nil +} + +// UpdateProvisioner performs the PUT /admin/provisioners/{name} request to the CA. +func (c *AdminClient) UpdateProvisioner(name string, prov *linkedca.Provisioner) error { + var retried bool + body, err := protojson.Marshal(prov) + if err != nil { + return errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", name)}) + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("PUT", u.String(), bytes.NewReader(body)) + if err != nil { + return errors.Wrapf(err, "create PUT %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return errors.Wrapf(err, "client PUT %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readAdminError(resp.Body) + } + return nil +} + +func readAdminError(r io.ReadCloser) error { + defer r.Close() + adminErr := new(admin.Error) + if err := json.NewDecoder(r).Decode(adminErr); err != nil { + return err + } + return errors.New(adminErr.Message) +} diff --git a/ca/bootstrap.go b/ca/bootstrap.go index c9e859bf..42087985 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -30,7 +30,7 @@ func Bootstrap(token string) (*Client, error) { // Validate bootstrap token switch { - case len(claims.SHA) == 0: + case claims.SHA == "": return nil, errors.New("invalid bootstrap token: sha claim is not present") case !strings.HasPrefix(strings.ToLower(claims.Audience[0]), "http"): return nil, errors.New("invalid bootstrap token: aud claim is not a url") @@ -39,6 +39,53 @@ func Bootstrap(token string) (*Client, error) { return NewClient(claims.Audience[0], WithRootSHA256(claims.SHA)) } +// BootstrapClient is a helper function that using the given bootstrap token +// return an http.Client configured with a Transport prepared to do TLS +// connections using the client certificate returned by the certificate +// authority. By default the server will kick off a routine that will renew the +// certificate after 2/3rd of the certificate's lifetime has expired. +// +// Usage: +// // Default example with certificate rotation. +// client, err := ca.BootstrapClient(ctx.Background(), token) +// +// // Example canceling automatic certificate rotation. +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// client, err := ca.BootstrapClient(ctx, token) +// if err != nil { +// return err +// } +// resp, err := client.Get("https://internal.smallstep.com") +func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) { + client, err := Bootstrap(token) + if err != nil { + return nil, err + } + + req, pk, err := CreateSignRequest(token) + if err != nil { + return nil, err + } + + sign, err := client.Sign(req) + if err != nil { + return nil, err + } + + // Make sure the tlsConfig have all supported roots on RootCAs + options = append(options, AddRootsToRootCAs()) + + transport, err := client.Transport(ctx, sign, pk, options...) + if err != nil { + return nil, err + } + + return &http.Client{ + Transport: transport, + }, nil +} + // BootstrapServer is a helper function that using the given token returns the // given http.Server configured with a TLS certificate signed by the Certificate // Authority. By default the server will kick off a routine that will renew the @@ -100,53 +147,6 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio return base, nil } -// BootstrapClient is a helper function that using the given bootstrap token -// return an http.Client configured with a Transport prepared to do TLS -// connections using the client certificate returned by the certificate -// authority. By default the server will kick off a routine that will renew the -// certificate after 2/3rd of the certificate's lifetime has expired. -// -// Usage: -// // Default example with certificate rotation. -// client, err := ca.BootstrapClient(ctx.Background(), token) -// -// // Example canceling automatic certificate rotation. -// ctx, cancel := context.WithCancel(context.Background()) -// defer cancel() -// client, err := ca.BootstrapClient(ctx, token) -// if err != nil { -// return err -// } -// resp, err := client.Get("https://internal.smallstep.com") -func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) { - client, err := Bootstrap(token) - if err != nil { - return nil, err - } - - req, pk, err := CreateSignRequest(token) - if err != nil { - return nil, err - } - - sign, err := client.Sign(req) - if err != nil { - return nil, err - } - - // Make sure the tlsConfig have all supported roots on RootCAs - options = append(options, AddRootsToRootCAs()) - - transport, err := client.Transport(ctx, sign, pk, options...) - if err != nil { - return nil, err - } - - return &http.Client{ - Transport: transport, - }, nil -} - // BootstrapListener is a helper function that using the given token returns a // TLS listener which accepts connections from an inner listener and wraps each // connection with Server. diff --git a/ca/ca.go b/ca/ca.go index 56f4c1f8..c76e8c0a 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -17,6 +17,8 @@ import ( acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" + adminAPI "github.com/smallstep/certificates/authority/admin/api" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/monitoring" @@ -27,10 +29,13 @@ import ( ) type options struct { - configFile string - password []byte - issuerPassword []byte - database db.AuthDB + configFile string + linkedCAToken string + password []byte + issuerPassword []byte + sshHostPassword []byte + sshUserPassword []byte + database db.AuthDB } func (o *options) apply(opts []Option) { @@ -58,6 +63,22 @@ func WithPassword(password []byte) Option { } } +// WithSSHHostPassword sets the given password to decrypt the key used to sign +// ssh host certificates. +func WithSSHHostPassword(password []byte) Option { + return func(o *options) { + o.sshHostPassword = password + } +} + +// WithSSHUserPassword sets the given password to decrypt the key used to sign +// ssh user certificates. +func WithSSHUserPassword(password []byte) Option { + return func(o *options) { + o.sshUserPassword = password + } +} + // WithIssuerPassword sets the given password as the configured certificate // issuer password in the CA options. func WithIssuerPassword(password []byte) Option { @@ -67,9 +88,16 @@ func WithIssuerPassword(password []byte) Option { } // WithDatabase sets the given authority database to the CA options. -func WithDatabase(db db.AuthDB) Option { +func WithDatabase(d db.AuthDB) Option { return func(o *options) { - o.database = db + o.database = d + } +} + +// WithLinkedCAToken sets the token used to authenticate with the linkedca. +func WithLinkedCAToken(token string) Option { + return func(o *options) { + o.linkedCAToken = token } } @@ -77,7 +105,7 @@ func WithDatabase(db db.AuthDB) Option { // the HTTP server, set ups the middlewares and the HTTP handlers. type CA struct { auth *authority.Authority - config *authority.Config + config *config.Config srv *server.Server insecureSrv *server.Server opts *options @@ -85,35 +113,34 @@ type CA struct { } // New creates and initializes the CA with the given configuration and options. -func New(config *authority.Config, opts ...Option) (*CA, error) { +func New(cfg *config.Config, opts ...Option) (*CA, error) { ca := &CA{ - config: config, + config: cfg, opts: new(options), } ca.opts.apply(opts) - return ca.Init(config) + return ca.Init(cfg) } // Init initializes the CA with the given configuration. -func (ca *CA) Init(config *authority.Config) (*CA, error) { - // Intermediate Password. - if len(ca.opts.password) > 0 { - ca.config.Password = string(ca.opts.password) +func (ca *CA) Init(cfg *config.Config) (*CA, error) { + // Set password, it's ok to set nil password, the ca will prompt for them if + // they are required. + opts := []authority.Option{ + authority.WithPassword(ca.opts.password), + authority.WithSSHHostPassword(ca.opts.sshHostPassword), + authority.WithSSHUserPassword(ca.opts.sshUserPassword), + authority.WithIssuerPassword(ca.opts.issuerPassword), + } + if ca.opts.linkedCAToken != "" { + opts = append(opts, authority.WithLinkedCAToken(ca.opts.linkedCAToken)) } - // Certificate issuer password for RA mode. - if len(ca.opts.issuerPassword) > 0 { - if ca.config.AuthorityConfig != nil && ca.config.AuthorityConfig.CertificateIssuer != nil { - ca.config.AuthorityConfig.CertificateIssuer.Password = string(ca.opts.issuerPassword) - } - } - - var opts []authority.Option if ca.opts.database != nil { opts = append(opts, authority.WithDatabase(ca.opts.database)) } - auth, err := authority.New(config, opts...) + auth, err := authority.New(cfg, opts...) if err != nil { return nil, err } @@ -139,8 +166,8 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { }) //Add ACME api endpoints in /acme and /1.0/acme - dns := config.DNSNames[0] - u, err := url.Parse("https://" + config.Address) + dns := cfg.DNSNames[0] + u, err := url.Parse("https://" + cfg.Address) if err != nil { return nil, err } @@ -149,9 +176,10 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { dns = fmt.Sprintf("%s:%s", dns, port) } + // ACME Router prefix := "acme" var acmeDB acme.DB - if config.DB == nil { + if cfg.DB == nil { acmeDB = nil } else { acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) @@ -160,7 +188,7 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { } } acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ - Backdate: *config.AuthorityConfig.Backdate, + Backdate: *cfg.AuthorityConfig.Backdate, DB: acmeDB, DNS: dns, Prefix: prefix, @@ -175,6 +203,17 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { acmeHandler.Route(r) }) + // Admin API Router + if cfg.AuthorityConfig.EnableAdmin { + adminDB := auth.GetAdminDatabase() + if adminDB != nil { + adminHandler := adminAPI.NewHandler(auth) + mux.Route("/admin", func(r chi.Router) { + adminHandler.Route(r) + }) + } + } + if ca.shouldServeSCEPEndpoints() { scepPrefix := "scep" scepAuthority, err := scep.New(auth, scep.AuthorityOptions{ @@ -209,8 +248,8 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { //dumpRoutes(mux) // Add monitoring if configured - if len(config.Monitoring) > 0 { - m, err := monitoring.New(config.Monitoring) + if len(cfg.Monitoring) > 0 { + m, err := monitoring.New(cfg.Monitoring) if err != nil { return nil, err } @@ -219,8 +258,8 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { } // Add logger if configured - if len(config.Logger) > 0 { - logger, err := logging.New("ca", config.Logger) + if len(cfg.Logger) > 0 { + logger, err := logging.New("ca", cfg.Logger) if err != nil { return nil, err } @@ -228,16 +267,16 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { insecureHandler = logger.Middleware(insecureHandler) } - ca.srv = server.New(config.Address, handler, tlsConfig) + ca.srv = server.New(cfg.Address, handler, tlsConfig) // only start the insecure server if the insecure address is configured // and, currently, also only when it should serve SCEP endpoints. - if ca.shouldServeSCEPEndpoints() && config.InsecureAddress != "" { + if ca.shouldServeSCEPEndpoints() && cfg.InsecureAddress != "" { // TODO: instead opt for having a single server.Server but two // http.Servers handling the HTTP and HTTPS handler? The latter // will probably introduce more complexity in terms of graceful // reload. - ca.insecureSrv = server.New(config.InsecureAddress, insecureHandler, nil) + ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil) } return ca, nil @@ -245,26 +284,25 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { // Run starts the CA calling to the server ListenAndServe method. func (ca *CA) Run() error { - var wg sync.WaitGroup - errors := make(chan error, 1) + errs := make(chan error, 1) if ca.insecureSrv != nil { wg.Add(1) go func() { defer wg.Done() - errors <- ca.insecureSrv.ListenAndServe() + errs <- ca.insecureSrv.ListenAndServe() }() } wg.Add(1) go func() { defer wg.Done() - errors <- ca.srv.ListenAndServe() + errs <- ca.srv.ListenAndServe() }() // wait till error occurs; ensures the servers keep listening - err := <-errors + err := <-errs wg.Wait() @@ -293,7 +331,7 @@ func (ca *CA) Stop() error { // Reload reloads the configuration of the CA and calls to the server Reload // method. func (ca *CA) Reload() error { - config, err := authority.LoadConfiguration(ca.opts.configFile) + cfg, err := config.LoadConfiguration(ca.opts.configFile) if err != nil { return errors.Wrap(err, "error reloading ca configuration") } @@ -305,14 +343,17 @@ func (ca *CA) Reload() error { } // Do not allow reload if the database configuration has changed. - if !reflect.DeepEqual(ca.config.DB, config.DB) { + if !reflect.DeepEqual(ca.config.DB, cfg.DB) { logContinue("Reload failed because the database configuration has changed.") return errors.New("error reloading ca: database configuration cannot change") } - newCA, err := New(config, + newCA, err := New(cfg, WithPassword(ca.opts.password), + WithSSHHostPassword(ca.opts.sshHostPassword), + WithSSHUserPassword(ca.opts.sshUserPassword), WithIssuerPassword(ca.opts.issuerPassword), + WithLinkedCAToken(ca.opts.linkedCAToken), WithConfigFile(ca.opts.configFile), WithDatabase(ca.auth.GetDatabase()), ) diff --git a/ca/ca_test.go b/ca/ca_test.go index 6e297733..ff264db7 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -322,7 +322,7 @@ ZEp7knvU2psWRw== assert.Equals(t, intermediate, realIntermediate) } else { err := readError(body) - if len(tc.errMsg) == 0 { + if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } assert.HasPrefix(t, err.Error(), tc.errMsg) @@ -375,7 +375,7 @@ func TestCAProvisioners(t *testing.T) { assert.Equals(t, a, b) } else { err := readError(body) - if len(tc.errMsg) == 0 { + if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } assert.HasPrefix(t, err.Error(), tc.errMsg) @@ -436,7 +436,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) { assert.Equals(t, ek.Key, tc.expectedKey) } else { err := readError(body) - if len(tc.errMsg) == 0 { + if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } assert.HasPrefix(t, err.Error(), tc.errMsg) @@ -497,7 +497,7 @@ func TestCARoot(t *testing.T) { assert.Equals(t, root.RootPEM.Certificate, rootCrt) } else { err := readError(body) - if len(tc.errMsg) == 0 { + if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } assert.HasPrefix(t, err.Error(), tc.errMsg) @@ -665,7 +665,7 @@ func TestCARenew(t *testing.T) { assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions) } else { err := readError(body) - if len(tc.errMsg) == 0 { + if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } assert.HasPrefix(t, err.Error(), tc.errMsg) diff --git a/ca/client.go b/ca/client.go index 2292c41e..cfeddba0 100644 --- a/ca/client.go +++ b/ca/client.go @@ -10,8 +10,10 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/asn1" "encoding/hex" "encoding/json" + "encoding/pem" "io" "io/ioutil" "net/http" @@ -28,10 +30,13 @@ import ( "github.com/smallstep/certificates/ca/identity" "github.com/smallstep/certificates/errs" "go.step.sm/cli-utils/config" + "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "golang.org/x/net/http2" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" "gopkg.in/square/go-jose.v2/jwt" ) @@ -69,17 +74,17 @@ func (c *uaClient) SetTransport(tr http.RoundTripper) { c.Client.Transport = tr } -func (c *uaClient) Get(url string) (*http.Response, error) { - req, err := http.NewRequest("GET", url, nil) +func (c *uaClient) Get(u string) (*http.Response, error) { + req, err := http.NewRequest("GET", u, nil) if err != nil { - return nil, errors.Wrapf(err, "new request GET %s failed", url) + return nil, errors.Wrapf(err, "new request GET %s failed", u) } req.Header.Set("User-Agent", UserAgent) return c.Client.Do(req) } -func (c *uaClient) Post(url, contentType string, body io.Reader) (*http.Response, error) { - req, err := http.NewRequest("POST", url, body) +func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest("POST", u, body) if err != nil { return nil, err } @@ -88,6 +93,11 @@ func (c *uaClient) Post(url, contentType string, body io.Reader) (*http.Response return c.Client.Do(req) } +func (c *uaClient) Do(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", UserAgent) + return c.Client.Do(req) +} + // RetryFunc defines the method used to retry a request. If it returns true, the // request will be retried once. type RetryFunc func(code int) bool @@ -103,6 +113,12 @@ type clientOptions struct { certificate tls.Certificate getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error) retryFunc RetryFunc + x5cJWK *jose.JSONWebKey + x5cCertFile string + x5cCertStrs []string + x5cCert *x509.Certificate + x5cIssuer string + x5cSubject string } func (o *clientOptions) apply(opts []ClientOption) (err error) { @@ -261,9 +277,66 @@ func WithCABundle(bundle []byte) ClientOption { // WithCertificate will set the given certificate as the TLS client certificate // in the client. -func WithCertificate(crt tls.Certificate) ClientOption { +func WithCertificate(cert tls.Certificate) ClientOption { return func(o *clientOptions) error { - o.certificate = crt + o.certificate = cert + return nil + } +} + +var ( + stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} + stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) +) + +type stepProvisionerASN1 struct { + Type int + Name []byte + CredentialID []byte + KeyValuePairs []string `asn1:"optional,omitempty"` +} + +// WithAdminX5C will set the given file as the X5C certificate for use +// by the client. +func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile string) ClientOption { + return func(o *clientOptions) error { + // Get private key from given key file + var ( + err error + opts []jose.Option + ) + if passwordFile != "" { + opts = append(opts, jose.WithPasswordFile(passwordFile)) + } + blk, err := pemutil.Serialize(key) + if err != nil { + return errors.Wrap(err, "error serializing private key") + } + o.x5cJWK, err = jose.ParseKey(pem.EncodeToMemory(blk), opts...) + if err != nil { + return err + } + o.x5cCertStrs, err = jose.ValidateX5C(certs, o.x5cJWK.Key) + if err != nil { + return errors.Wrap(err, "error validating x5c certificate chain and key for use in x5c header") + } + + o.x5cCert = certs[0] + o.x5cSubject = o.x5cCert.Subject.CommonName + + for _, e := range o.x5cCert.Extensions { + if e.Id.Equal(stepOIDProvisioner) { + var prov stepProvisionerASN1 + if _, err := asn1.Unmarshal(e.Value, &prov); err != nil { + return errors.Wrap(err, "error unmarshaling provisioner OID from certificate") + } + o.x5cIssuer = string(prov.Name) + } + } + if o.x5cIssuer == "" { + return errors.New("provisioner extension not found in certificate") + } + return nil } } @@ -367,6 +440,8 @@ type ProvisionerOption func(o *provisionerOptions) error type provisionerOptions struct { cursor string limit int + id string + name string } func (o *provisionerOptions) apply(opts []ProvisionerOption) (err error) { @@ -386,6 +461,12 @@ func (o *provisionerOptions) rawQuery() string { if o.limit > 0 { v.Set("limit", strconv.Itoa(o.limit)) } + if len(o.id) > 0 { + v.Set("id", o.id) + } + if len(o.name) > 0 { + v.Set("name", o.name) + } return v.Encode() } @@ -405,6 +486,22 @@ func WithProvisionerLimit(limit int) ProvisionerOption { } } +// WithProvisionerID will request the given provisioner. +func WithProvisionerID(id string) ProvisionerOption { + return func(o *provisionerOptions) error { + o.id = id + return nil + } +} + +// WithProvisionerName will request the given provisioner. +func WithProvisionerName(name string) ProvisionerOption { + return func(o *provisionerOptions) error { + o.name = name + return nil + } +} + // Client implements an HTTP client for the CA server. type Client struct { client *uaClient @@ -534,7 +631,7 @@ retry: // do not match. func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) { var retried bool - sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1)) + sha256Sum = strings.ToLower(strings.ReplaceAll(sha256Sum, "-", "")) u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum}) retry: resp, err := newInsecureClient().Get(u.String()) @@ -554,7 +651,7 @@ retry: } // verify the sha256 sum := sha256.Sum256(root.RootPEM.Raw) - if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) { + if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) { return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match") } return &root, nil @@ -969,16 +1066,16 @@ retry: } return nil, readError(resp.Body) } - var config api.SSHConfigResponse - if err := readJSON(resp.Body, &config); err != nil { + var cfg api.SSHConfigResponse + if err := readJSON(resp.Body, &cfg); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } - return &config, nil + return &cfg, nil } // SSHCheckHost performs the POST /ssh/check-host request to the CA with the // given principal. -func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrincipalResponse, error) { +func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalResponse, error) { var retried bool body, err := json.Marshal(&api.SSHCheckPrincipalRequest{ Type: provisioner.SSHHostCert, @@ -1206,6 +1303,15 @@ func readJSON(r io.ReadCloser, v interface{}) error { return json.NewDecoder(r).Decode(v) } +func readProtoJSON(r io.ReadCloser, m proto.Message) error { + defer r.Close() + data, err := ioutil.ReadAll(r) + if err != nil { + return err + } + return protojson.Unmarshal(data, m) +} + func readError(r io.ReadCloser) error { defer r.Close() apiErr := new(errs.Error) diff --git a/ca/client_test.go b/ca/client_test.go index 30669e6e..187066f0 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -135,7 +135,7 @@ func parseCertificateRequest(data string) *x509.CertificateRequest { return csr } -func equalJSON(t *testing.T, a interface{}, b interface{}) bool { +func equalJSON(t *testing.T, a, b interface{}) bool { if reflect.DeepEqual(a, b) { return true } diff --git a/ca/identity/client_test.go b/ca/identity/client_test.go index c792a6dc..402ec7b8 100644 --- a/ca/identity/client_test.go +++ b/ca/identity/client_test.go @@ -187,11 +187,12 @@ func TestLoadClient(t *testing.T) { } else { gotTransport := got.Client.Transport.(*http.Transport) wantTransport := tt.want.Client.Transport.(*http.Transport) - if gotTransport.TLSClientConfig.GetClientCertificate == nil { + switch { + case gotTransport.TLSClientConfig.GetClientCertificate == nil: t.Error("LoadClient() transport does not define GetClientCertificate") - } else if !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()) { + case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()): t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) - } else { + default: crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil) if err != nil { t.Errorf("LoadClient() GetClientCertificate error = %v", err) diff --git a/ca/testdata/ca.json b/ca/testdata/ca.json index b094c02e..d40325e8 100644 --- a/ca/testdata/ca.json +++ b/ca/testdata/ca.json @@ -9,12 +9,11 @@ "logger": {"format": "text"}, "tls": { "minVersion": 1.2, - "maxVersion": 1.2, + "maxVersion": 1.3, "renegotiation": false, "cipherSuites": [ - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" ] }, "authority": { @@ -34,7 +33,7 @@ "y": "ZhYcFQBqtErdC_pA7sOXrO7AboCEPIKP9Ik4CHJqANk" } }, { - "name": "max", + "name": "mike", "type": "jwk", "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlZsWnl0dUxrWTR5enlqZXJybnN0aGcifQ.QP15wQYjZ12BLgl-XTq2Vb12G3OHAfic.X35QqAaXwnlmeCUU._2qIUp0TI8yDI7c2e9upIRdrnmB5OvtLfrYN-Su2NLBpaoYtr9O55Wo0Iryc0W2pYqnVDPvgPPes4P4nQAnzw5WhFYc1Xf1ZEetfdNhwi1x2FNwPbACBAgxm5AW40O5AAlbLcWushYASfeMBZocTGXuSGUzwFqoWD-5EDJ80TWQ7cAj3ttHrJ_3QV9hi4O9KJUCiXngN-Yz2zXrhBL4NOH2fmRbaf5c0rF8xUJIIW-TcyYJeX_Fbx1IzzKKPd9USUwkDhxD4tLa51I345xVqjuwG1PEn6nF8JKqLRVUKEKFin-ShXrfE61KceyAvm4YhWKrbJWIm3bH5Hxaphy4.TexIrIhsRxJStpE3EJ925Q", "key": { @@ -76,7 +75,7 @@ "minTLSCertDuration": "1s" } }, { - "name": "mariano", + "name": "maxey", "type": "jwk", "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6Ik5SLTk5ZkVMSm1CLW1FZGllUlFFc3cifQ.Fr314BEUGTda4ICJl2uxFdjpEUGGqJEV.gBbu_DZE1ONDu14r.X-7MKMyokZIF1HTCVqqL0tTWgaC1ZGZBLLltd11ZUhQTswo_8kvgiTv3cFShj7ATF0tAY8HStyJmzLO8mKPVOPDXSwjdNsPriZclI6JWGi9iOu8pEiN9pZM6-itxan1JMcDUNg2U-P1BmKppHRbDKsOTivymfRyeUk51dBIlS54p5xNK1HFLc1YtWC1Rc_ngYVqOgqlhIrCHArAEBe3jrfUaH2ym-8fkVdwVqtxmte3XXK9g8FchsygRNnOKtRcr0TyzTUV-7bPi8_t02Zi-EHLFaSawVXWV_Qk1GeLYJR22Rp74beo-b5-lCNVp10btO0xdGySUWmCJ4v4_QZw.c8unwWycwtfdJMM_0b0fuA", "key": { diff --git a/ca/tls.go b/ca/tls.go index e4f585fe..0738d0e0 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -103,10 +103,9 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, return nil, nil, err } - // Update renew function with transport tr := getDefaultTransport(tlsConfig) // Use mutable tls.Config on renew - tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck + tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) @@ -155,7 +154,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, // Update renew function with transport tr := getDefaultTransport(tlsConfig) // Use mutable tls.Config on renew - tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck + tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) @@ -196,7 +195,7 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net } // buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport. -// nolint:unused +// nolint:unused,gocritic func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) { d := getDefaultDialer() @@ -254,6 +253,8 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific return nil, err } + // nolint:gocritic + // using a new variable for clarity chain := append(certPEM, caPEM...) cert, err := tls.X509KeyPair(chain, keyPEM) if err != nil { @@ -278,9 +279,9 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { // getDefaultDialer returns a new dialer with the default configuration. func getDefaultDialer() *net.Dialer { + // With the KeepAlive parameter set to 0, it will be use Golang's default. return &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, + Timeout: 30 * time.Second, } } diff --git a/cas/apiv1/options.go b/cas/apiv1/options.go index 4810d1f3..badad7fc 100644 --- a/cas/apiv1/options.go +++ b/cas/apiv1/options.go @@ -38,17 +38,29 @@ type Options struct { CertificateChain []*x509.Certificate `json:"-"` Signer crypto.Signer `json:"-"` - // IsCreator is set to true when we're creating a certificate authority. Is - // used to skip some validations when initializing a CertificateAuthority. + // IsCreator is set to true when we're creating a certificate authority. It + // is used to skip some validations when initializing a + // CertificateAuthority. This option is used on SoftCAS and CloudCAS. IsCreator bool `json:"-"` + // IsCAGetter is set to true when we're just using the + // CertificateAuthorityGetter interface to retrieve the root certificate. It + // is used to skip some validations when initializing a + // CertificateAuthority. This option is used on StepCAS. + IsCAGetter bool `json:"-"` + // KeyManager is the KMS used to generate keys in SoftCAS. KeyManager kms.KeyManager `json:"-"` - // Project and Location are parameters used in CloudCAS to create a new - // certificate authority. - Project string `json:"-"` - Location string `json:"-"` + // Project, Location, CaPool and GCSBucket are parameters used in CloudCAS + // to create a new certificate authority. If a CaPool does not exist it will + // be created. GCSBucket is optional, if not provided GCloud will create a + // managed bucket. + Project string `json:"-"` + Location string `json:"-"` + CaPool string `json:"-"` + CaPoolTier string `json:"-"` + GCSBucket string `json:"-"` } // CertificateIssuer contains the properties used to use the StepCAS certificate diff --git a/cas/apiv1/requests.go b/cas/apiv1/requests.go index b47a9c13..bf745c17 100644 --- a/cas/apiv1/requests.go +++ b/cas/apiv1/requests.go @@ -108,6 +108,9 @@ type GetCertificateAuthorityResponse struct { RootCertificate *x509.Certificate } +// CreateKeyRequest is the request used to generate a new key using a KMS. +type CreateKeyRequest = apiv1.CreateKeyRequest + // CreateCertificateAuthorityRequest is the request used to generate a root or // intermediate certificate. type CreateCertificateAuthorityRequest struct { @@ -126,7 +129,7 @@ type CreateCertificateAuthorityRequest struct { // CreateKey defines the KMS CreateKeyRequest to use when creating a new // CertificateAuthority. If CreateKey is nil, a default algorithm will be // used. - CreateKey *apiv1.CreateKeyRequest + CreateKey *CreateKeyRequest } // CreateCertificateAuthorityResponse is the response for @@ -136,6 +139,7 @@ type CreateCertificateAuthorityResponse struct { Name string Certificate *x509.Certificate CertificateChain []*x509.Certificate + KeyName string PublicKey crypto.PublicKey PrivateKey crypto.PrivateKey Signer crypto.Signer diff --git a/cas/apiv1/services.go b/cas/apiv1/services.go index d4dd3c8c..cf9a5470 100644 --- a/cas/apiv1/services.go +++ b/cas/apiv1/services.go @@ -1,6 +1,7 @@ package apiv1 import ( + "crypto/x509" "net/http" "strings" ) @@ -26,6 +27,12 @@ type CertificateAuthorityCreator interface { CreateCertificateAuthority(req *CreateCertificateAuthorityRequest) (*CreateCertificateAuthorityResponse, error) } +// SignatureAlgorithmGetter is an optional implementation in a crypto.Signer +// that returns the SignatureAlgorithm to use. +type SignatureAlgorithmGetter interface { + SignatureAlgorithm() x509.SignatureAlgorithm +} + // Type represents the CAS type used. type Type string diff --git a/cas/cloudcas/certificate.go b/cas/cloudcas/certificate.go index d7789992..6f229702 100644 --- a/cas/cloudcas/certificate.go +++ b/cas/cloudcas/certificate.go @@ -12,8 +12,7 @@ import ( "github.com/pkg/errors" kmsapi "github.com/smallstep/certificates/kms/apiv1" - pb "google.golang.org/genproto/googleapis/cloud/security/privateca/v1beta1" - wrapperspb "google.golang.org/protobuf/types/known/wrapperspb" + pb "google.golang.org/genproto/googleapis/cloud/security/privateca/v1" ) var ( @@ -67,11 +66,10 @@ func createCertificateConfig(tpl *x509.Certificate) (*pb.Certificate_Config, err config := &pb.CertificateConfig{ SubjectConfig: &pb.CertificateConfig_SubjectConfig{ Subject: createSubject(tpl), - CommonName: tpl.Subject.CommonName, SubjectAltName: createSubjectAlternativeNames(tpl), }, - ReusableConfig: createReusableConfig(tpl), - PublicKey: pk, + X509Config: createX509Parameters(tpl), + PublicKey: pk, } return &pb.Certificate_Config{ Config: config, @@ -86,7 +84,7 @@ func createPublicKey(key crypto.PublicKey) (*pb.PublicKey, error) { return nil, errors.Wrap(err, "error marshaling public key") } return &pb.PublicKey{ - Type: pb.PublicKey_PEM_EC_KEY, + Format: pb.PublicKey_PEM, Key: pem.EncodeToMemory(&pem.Block{ Type: "PUBLIC KEY", Bytes: asn1Bytes, @@ -94,7 +92,7 @@ func createPublicKey(key crypto.PublicKey) (*pb.PublicKey, error) { }, nil case *rsa.PublicKey: return &pb.PublicKey{ - Type: pb.PublicKey_PEM_RSA_KEY, + Format: pb.PublicKey_PEM, Key: pem.EncodeToMemory(&pem.Block{ Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(key), @@ -107,7 +105,9 @@ func createPublicKey(key crypto.PublicKey) (*pb.PublicKey, error) { func createSubject(cert *x509.Certificate) *pb.Subject { sub := cert.Subject - ret := new(pb.Subject) + ret := &pb.Subject{ + CommonName: sub.CommonName, + } if len(sub.Country) > 0 { ret.CountryCode = sub.Country[0] } @@ -196,7 +196,7 @@ func createSubjectAlternativeNames(cert *x509.Certificate) *pb.SubjectAltNames { return ret } -func createReusableConfig(cert *x509.Certificate) *pb.ReusableConfigWrapper { +func createX509Parameters(cert *x509.Certificate) *pb.X509Parameters { var unknownEKUs []*pb.ObjectId var ekuOptions = &pb.KeyUsage_ExtendedKeyUsageOptions{} for _, eku := range cert.ExtKeyUsage { @@ -241,22 +241,19 @@ func createReusableConfig(cert *x509.Certificate) *pb.ReusableConfigWrapper { policyIDs = append(policyIDs, createObjectID(oid)) } - var caOptions *pb.ReusableConfigValues_CaOptions + var caOptions *pb.X509Parameters_CaOptions if cert.BasicConstraintsValid { - var maxPathLength *wrapperspb.Int32Value + caOptions = new(pb.X509Parameters_CaOptions) + var maxPathLength int32 switch { case cert.MaxPathLenZero: - maxPathLength = wrapperspb.Int32(0) + maxPathLength = 0 + caOptions.MaxIssuerPathLength = &maxPathLength case cert.MaxPathLen > 0: - maxPathLength = wrapperspb.Int32(int32(cert.MaxPathLen)) - default: - maxPathLength = nil - } - - caOptions = &pb.ReusableConfigValues_CaOptions{ - IsCa: wrapperspb.Bool(cert.IsCA), - MaxIssuerPathLength: maxPathLength, + maxPathLength = int32(cert.MaxPathLen) + caOptions.MaxIssuerPathLength = &maxPathLength } + caOptions.IsCa = &cert.IsCA } var extraExtensions []*pb.X509Extension @@ -270,7 +267,7 @@ func createReusableConfig(cert *x509.Certificate) *pb.ReusableConfigWrapper { } } - values := &pb.ReusableConfigValues{ + return &pb.X509Parameters{ KeyUsage: &pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ DigitalSignature: cert.KeyUsage&x509.KeyUsageDigitalSignature > 0, @@ -291,12 +288,6 @@ func createReusableConfig(cert *x509.Certificate) *pb.ReusableConfigWrapper { AiaOcspServers: cert.OCSPServer, AdditionalExtensions: extraExtensions, } - - return &pb.ReusableConfigWrapper{ - ConfigValues: &pb.ReusableConfigWrapper_ReusableConfigValues{ - ReusableConfigValues: values, - }, - } } // isExtraExtension returns true if the extension oid is not managed in a diff --git a/cas/cloudcas/certificate_test.go b/cas/cloudcas/certificate_test.go index 0822e4c1..8bf67fb6 100644 --- a/cas/cloudcas/certificate_test.go +++ b/cas/cloudcas/certificate_test.go @@ -15,8 +15,7 @@ import ( "testing" kmsapi "github.com/smallstep/certificates/kms/apiv1" - pb "google.golang.org/genproto/googleapis/cloud/security/privateca/v1beta1" - wrapperspb "google.golang.org/protobuf/types/known/wrapperspb" + pb "google.golang.org/genproto/googleapis/cloud/security/privateca/v1" ) var ( @@ -67,30 +66,27 @@ func Test_createCertificateConfig(t *testing.T) { {"ok", args{cert}, &pb.Certificate_Config{ Config: &pb.CertificateConfig{ SubjectConfig: &pb.CertificateConfig_SubjectConfig{ - Subject: &pb.Subject{}, - CommonName: "test.smallstep.com", + Subject: &pb.Subject{ + CommonName: "test.smallstep.com", + }, SubjectAltName: &pb.SubjectAltNames{ DnsNames: []string{"test.smallstep.com"}, }, }, - ReusableConfig: &pb.ReusableConfigWrapper{ - ConfigValues: &pb.ReusableConfigWrapper_ReusableConfigValues{ - ReusableConfigValues: &pb.ReusableConfigValues{ - KeyUsage: &pb.KeyUsage{ - BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ - DigitalSignature: true, - }, - ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{ - ClientAuth: true, - ServerAuth: true, - }, - }, + X509Config: &pb.X509Parameters{ + KeyUsage: &pb.KeyUsage{ + BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ + DigitalSignature: true, + }, + ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{ + ClientAuth: true, + ServerAuth: true, }, }, }, PublicKey: &pb.PublicKey{ - Type: pb.PublicKey_PEM_EC_KEY, - Key: []byte(testLeafPublicKey), + Key: []byte(testLeafPublicKey), + Format: pb.PublicKey_PEM, }, }, }, false}, @@ -104,7 +100,7 @@ func Test_createCertificateConfig(t *testing.T) { return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("createCertificateConfig() = %v, want %v", got.Config.ReusableConfig, tt.want.Config.ReusableConfig) + t.Errorf("createCertificateConfig() = %v, want %v", got.Config, tt.want.Config) } }) } @@ -127,12 +123,12 @@ func Test_createPublicKey(t *testing.T) { wantErr bool }{ {"ok ec", args{ecCert.PublicKey}, &pb.PublicKey{ - Type: pb.PublicKey_PEM_EC_KEY, - Key: []byte(testLeafPublicKey), + Format: pb.PublicKey_PEM, + Key: []byte(testLeafPublicKey), }, false}, {"ok rsa", args{rsaCert.PublicKey}, &pb.PublicKey{ - Type: pb.PublicKey_PEM_RSA_KEY, - Key: []byte(testRSAPublicKey), + Format: pb.PublicKey_PEM, + Key: []byte(testRSAPublicKey), }, false}, {"fail ed25519", args{edpub}, nil, true}, {"fail ec marshal", args{&ecdsa.PublicKey{ @@ -185,6 +181,7 @@ func Test_createSubject(t *testing.T) { Province: "California", StreetAddress: "1 A St.", PostalCode: "12345", + CommonName: "test.smallstep.com", }}, } for _, tt := range tests { @@ -289,62 +286,55 @@ func Test_createSubjectAlternativeNames(t *testing.T) { } } -func Test_createReusableConfig(t *testing.T) { - withKU := func(ku *pb.KeyUsage) *pb.ReusableConfigWrapper { +func Test_createX509Parameters(t *testing.T) { + withKU := func(ku *pb.KeyUsage) *pb.X509Parameters { if ku.BaseKeyUsage == nil { ku.BaseKeyUsage = &pb.KeyUsage_KeyUsageOptions{} } if ku.ExtendedKeyUsage == nil { ku.ExtendedKeyUsage = &pb.KeyUsage_ExtendedKeyUsageOptions{} } - return &pb.ReusableConfigWrapper{ - ConfigValues: &pb.ReusableConfigWrapper_ReusableConfigValues{ - ReusableConfigValues: &pb.ReusableConfigValues{ - KeyUsage: ku, - }, - }, + return &pb.X509Parameters{ + KeyUsage: ku, } } - withRCV := func(rcv *pb.ReusableConfigValues) *pb.ReusableConfigWrapper { + withRCV := func(rcv *pb.X509Parameters) *pb.X509Parameters { if rcv.KeyUsage == nil { rcv.KeyUsage = &pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{}, ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{}, } } - return &pb.ReusableConfigWrapper{ - ConfigValues: &pb.ReusableConfigWrapper_ReusableConfigValues{ - ReusableConfigValues: rcv, - }, - } + return rcv } + vTrue := true + vFalse := false + vZero := int32(0) + vOne := int32(1) + type args struct { cert *x509.Certificate } tests := []struct { name string args args - want *pb.ReusableConfigWrapper + want *pb.X509Parameters }{ {"keyUsageDigitalSignature", args{&x509.Certificate{ KeyUsage: x509.KeyUsageDigitalSignature, - }}, &pb.ReusableConfigWrapper{ - ConfigValues: &pb.ReusableConfigWrapper_ReusableConfigValues{ - ReusableConfigValues: &pb.ReusableConfigValues{ - KeyUsage: &pb.KeyUsage{ - BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ - DigitalSignature: true, - }, - ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{}, - UnknownExtendedKeyUsages: nil, - }, - CaOptions: nil, - PolicyIds: nil, - AiaOcspServers: nil, - AdditionalExtensions: nil, + }}, &pb.X509Parameters{ + KeyUsage: &pb.KeyUsage{ + BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ + DigitalSignature: true, }, + ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{}, + UnknownExtendedKeyUsages: nil, }, + CaOptions: nil, + PolicyIds: nil, + AiaOcspServers: nil, + AdditionalExtensions: nil, }}, // KeyUsage {"KeyUsageDigitalSignature", args{&x509.Certificate{KeyUsage: x509.KeyUsageDigitalSignature}}, withKU(&pb.KeyUsage{ @@ -455,48 +445,48 @@ func Test_createReusableConfig(t *testing.T) { }, })}, // BasicCre - {"BasicConstraintsCAMax0", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, MaxPathLenZero: true}}, withRCV(&pb.ReusableConfigValues{ - CaOptions: &pb.ReusableConfigValues_CaOptions{ - IsCa: wrapperspb.Bool(true), - MaxIssuerPathLength: wrapperspb.Int32(0), + {"BasicConstraintsCAMax0", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, MaxPathLenZero: true}}, withRCV(&pb.X509Parameters{ + CaOptions: &pb.X509Parameters_CaOptions{ + IsCa: &vTrue, + MaxIssuerPathLength: &vZero, }, })}, - {"BasicConstraintsCAMax1", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: 1, MaxPathLenZero: false}}, withRCV(&pb.ReusableConfigValues{ - CaOptions: &pb.ReusableConfigValues_CaOptions{ - IsCa: wrapperspb.Bool(true), - MaxIssuerPathLength: wrapperspb.Int32(1), + {"BasicConstraintsCAMax1", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: 1, MaxPathLenZero: false}}, withRCV(&pb.X509Parameters{ + CaOptions: &pb.X509Parameters_CaOptions{ + IsCa: &vTrue, + MaxIssuerPathLength: &vOne, }, })}, - {"BasicConstraintsCANoMax", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: -1, MaxPathLenZero: false}}, withRCV(&pb.ReusableConfigValues{ - CaOptions: &pb.ReusableConfigValues_CaOptions{ - IsCa: wrapperspb.Bool(true), + {"BasicConstraintsCANoMax", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: -1, MaxPathLenZero: false}}, withRCV(&pb.X509Parameters{ + CaOptions: &pb.X509Parameters_CaOptions{ + IsCa: &vTrue, MaxIssuerPathLength: nil, }, })}, - {"BasicConstraintsCANoMax0", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, MaxPathLenZero: false}}, withRCV(&pb.ReusableConfigValues{ - CaOptions: &pb.ReusableConfigValues_CaOptions{ - IsCa: wrapperspb.Bool(true), + {"BasicConstraintsCANoMax0", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, MaxPathLenZero: false}}, withRCV(&pb.X509Parameters{ + CaOptions: &pb.X509Parameters_CaOptions{ + IsCa: &vTrue, MaxIssuerPathLength: nil, }, })}, - {"BasicConstraintsNoCA", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: false, MaxPathLen: 0, MaxPathLenZero: false}}, withRCV(&pb.ReusableConfigValues{ - CaOptions: &pb.ReusableConfigValues_CaOptions{ - IsCa: wrapperspb.Bool(false), + {"BasicConstraintsNoCA", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: false, MaxPathLen: 0, MaxPathLenZero: false}}, withRCV(&pb.X509Parameters{ + CaOptions: &pb.X509Parameters_CaOptions{ + IsCa: &vFalse, MaxIssuerPathLength: nil, }, })}, - {"BasicConstraintsNoValid", args{&x509.Certificate{BasicConstraintsValid: false, IsCA: false, MaxPathLen: 0, MaxPathLenZero: false}}, withRCV(&pb.ReusableConfigValues{ + {"BasicConstraintsNoValid", args{&x509.Certificate{BasicConstraintsValid: false, IsCA: false, MaxPathLen: 0, MaxPathLenZero: false}}, withRCV(&pb.X509Parameters{ CaOptions: nil, })}, // PolicyIdentifiers - {"PolicyIdentifiers", args{&x509.Certificate{PolicyIdentifiers: []asn1.ObjectIdentifier{{1, 2, 3, 4}, {4, 3, 2, 1}}}}, withRCV(&pb.ReusableConfigValues{ + {"PolicyIdentifiers", args{&x509.Certificate{PolicyIdentifiers: []asn1.ObjectIdentifier{{1, 2, 3, 4}, {4, 3, 2, 1}}}}, withRCV(&pb.X509Parameters{ PolicyIds: []*pb.ObjectId{ {ObjectIdPath: []int32{1, 2, 3, 4}}, {ObjectIdPath: []int32{4, 3, 2, 1}}, }, })}, // OCSPServer - {"OCPServers", args{&x509.Certificate{OCSPServer: []string{"https://oscp.doe.com", "https://doe.com/ocsp"}}}, withRCV(&pb.ReusableConfigValues{ + {"OCPServers", args{&x509.Certificate{OCSPServer: []string{"https://oscp.doe.com", "https://doe.com/ocsp"}}}, withRCV(&pb.X509Parameters{ AiaOcspServers: []string{"https://oscp.doe.com", "https://doe.com/ocsp"}, })}, // Extensions @@ -505,7 +495,7 @@ func Test_createReusableConfig(t *testing.T) { {Id: []int{2, 5, 29, 17}, Critical: true, Value: []byte("SANs")}, // {Id: []int{4, 3, 2, 1}, Critical: false, Value: []byte("zoobar")}, {Id: []int{2, 5, 29, 31}, Critical: false, Value: []byte("CRL Distribution points")}, - }}}, withRCV(&pb.ReusableConfigValues{ + }}}, withRCV(&pb.X509Parameters{ AdditionalExtensions: []*pb.X509Extension{ {ObjectId: &pb.ObjectId{ObjectIdPath: []int32{1, 2, 3, 4}}, Critical: true, Value: []byte("foobar")}, {ObjectId: &pb.ObjectId{ObjectIdPath: []int32{4, 3, 2, 1}}, Critical: false, Value: []byte("zoobar")}, @@ -514,8 +504,8 @@ func Test_createReusableConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := createReusableConfig(tt.args.cert); !reflect.DeepEqual(got, tt.want) { - t.Errorf("createReusableConfig() = %v, want %v", got, tt.want) + if got := createX509Parameters(tt.args.cert); !reflect.DeepEqual(got, tt.want) { + t.Errorf("createX509Parameters() = %v, want %v", got, tt.want) } }) } diff --git a/cas/cloudcas/cloudcas.go b/cas/cloudcas/cloudcas.go index 695258c9..e3e956a9 100644 --- a/cas/cloudcas/cloudcas.go +++ b/cas/cloudcas/cloudcas.go @@ -10,14 +10,16 @@ import ( "strings" "time" - privateca "cloud.google.com/go/security/privateca/apiv1beta1" + privateca "cloud.google.com/go/security/privateca/apiv1" "github.com/google/uuid" gax "github.com/googleapis/gax-go/v2" "github.com/pkg/errors" "github.com/smallstep/certificates/cas/apiv1" "go.step.sm/crypto/x509util" "google.golang.org/api/option" - pb "google.golang.org/genproto/googleapis/cloud/security/privateca/v1beta1" + pb "google.golang.org/genproto/googleapis/cloud/security/privateca/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" durationpb "google.golang.org/protobuf/types/known/durationpb" ) @@ -27,14 +29,12 @@ func init() { }) } -var now = func() time.Time { - return time.Now() -} +var now = time.Now // The actual regular expression that matches a certificate authority is: -// ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/certificateAuthorities/[a-zA-Z0-9-_]+$ +// ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$ // But we will allow a more flexible one to fail if this changes. -var caRegexp = regexp.MustCompile("^projects/[^/]+/locations/[^/]+/certificateAuthorities/[^/]+$") +var caRegexp = regexp.MustCompile("^projects/[^/]+/locations/[^/]+/caPools/[^/]+/certificateAuthorities/[^/]+$") // CertificateAuthorityClient is the interface implemented by the Google CAS // client. @@ -45,6 +45,9 @@ type CertificateAuthorityClient interface { CreateCertificateAuthority(ctx context.Context, req *pb.CreateCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.CreateCertificateAuthorityOperation, error) FetchCertificateAuthorityCsr(ctx context.Context, req *pb.FetchCertificateAuthorityCsrRequest, opts ...gax.CallOption) (*pb.FetchCertificateAuthorityCsrResponse, error) ActivateCertificateAuthority(ctx context.Context, req *pb.ActivateCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.ActivateCertificateAuthorityOperation, error) + EnableCertificateAuthority(ctx context.Context, req *pb.EnableCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.EnableCertificateAuthorityOperation, error) + GetCaPool(ctx context.Context, req *pb.GetCaPoolRequest, opts ...gax.CallOption) (*pb.CaPool, error) + CreateCaPool(ctx context.Context, req *pb.CreateCaPoolRequest, opts ...gax.CallOption) (*privateca.CreateCaPoolOperation, error) } // recocationCodeMap maps revocation reason codes from RFC 5280, to Google CAS @@ -62,12 +65,22 @@ var revocationCodeMap = map[int]pb.RevocationReason{ 10: pb.RevocationReason_ATTRIBUTE_AUTHORITY_COMPROMISE, } +// caPoolTierMap contains the map between apiv1.Options.Tier and the pb type. +var caPoolTierMap = map[string]pb.CaPool_Tier{ + "": pb.CaPool_DEVOPS, + "ENTERPRISE": pb.CaPool_ENTERPRISE, + "DEVOPS": pb.CaPool_DEVOPS, +} + // CloudCAS implements a Certificate Authority Service using Google Cloud CAS. type CloudCAS struct { client CertificateAuthorityClient certificateAuthority string project string location string + caPool string + caPoolTier pb.CaPool_Tier + gcsBucket string } // newCertificateAuthorityClient creates the certificate authority client. This @@ -87,12 +100,19 @@ var newCertificateAuthorityClient = func(ctx context.Context, credentialsFile st // New creates a new CertificateAuthorityService implementation using Google // Cloud CAS. func New(ctx context.Context, opts apiv1.Options) (*CloudCAS, error) { - if opts.IsCreator { + var caPoolTier pb.CaPool_Tier + if opts.IsCreator && opts.CertificateAuthority == "" { switch { case opts.Project == "": return nil, errors.New("cloudCAS 'project' cannot be empty") case opts.Location == "": return nil, errors.New("cloudCAS 'location' cannot be empty") + case opts.CaPool == "": + return nil, errors.New("cloudCAS 'caPool' cannot be empty") + } + var ok bool + if caPoolTier, ok = caPoolTierMap[strings.ToUpper(opts.CaPoolTier)]; !ok { + return nil, errors.New("cloudCAS 'caPoolTier' is not a valid tier") } } else { if opts.CertificateAuthority == "" { @@ -102,13 +122,16 @@ func New(ctx context.Context, opts apiv1.Options) (*CloudCAS, error) { return nil, errors.New("cloudCAS 'certificateAuthority' is not valid certificate authority resource") } // Extract project and location from CertificateAuthority - if parts := strings.Split(opts.CertificateAuthority, "/"); len(parts) == 6 { + if parts := strings.Split(opts.CertificateAuthority, "/"); len(parts) == 8 { if opts.Project == "" { opts.Project = parts[1] } if opts.Location == "" { opts.Location = parts[3] } + if opts.CaPool == "" { + opts.CaPool = parts[5] + } } } @@ -117,11 +140,15 @@ func New(ctx context.Context, opts apiv1.Options) (*CloudCAS, error) { return nil, err } + // GCSBucket is the the bucket name or empty for a managed bucket. return &CloudCAS{ client: client, certificateAuthority: opts.CertificateAuthority, project: opts.Project, location: opts.Location, + caPool: opts.CaPool, + gcsBucket: opts.GCSBucket, + caPoolTier: caPoolTier, }, nil } @@ -251,6 +278,10 @@ func (c *CloudCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthor return nil, errors.New("cloudCAS `project` cannot be empty") case c.location == "": return nil, errors.New("cloudCAS `location` cannot be empty") + case c.caPool == "": + return nil, errors.New("cloudCAS `caPool` cannot be empty") + case c.caPoolTier == 0: + return nil, errors.New("cloudCAS `caPoolTier` cannot be empty") case req.Template == nil: return nil, errors.New("createCertificateAuthorityRequest `template` cannot be nil") case req.Lifetime == 0: @@ -301,28 +332,30 @@ func (c *CloudCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthor } req.Template.ExtraExtensions = append(req.Template.ExtraExtensions, casExtension) + // Create the caPool if necessary + parent, err := c.createCaPoolIfNecessary() + if err != nil { + return nil, err + } + // Prepare CreateCertificateAuthorityRequest pbReq := &pb.CreateCertificateAuthorityRequest{ - Parent: "projects/" + c.project + "/locations/" + c.location, + Parent: parent, CertificateAuthorityId: caID, RequestId: req.RequestID, CertificateAuthority: &pb.CertificateAuthority{ Type: caType, - Tier: pb.CertificateAuthority_ENTERPRISE, Config: &pb.CertificateConfig{ SubjectConfig: &pb.CertificateConfig_SubjectConfig{ - Subject: createSubject(req.Template), - CommonName: req.Template.Subject.CommonName, + Subject: createSubject(req.Template), + SubjectAltName: createSubjectAlternativeNames(req.Template), }, - ReusableConfig: createReusableConfig(req.Template), + X509Config: createX509Parameters(req.Template), }, - Lifetime: durationpb.New(req.Lifetime), - KeySpec: keySpec, - IssuingOptions: &pb.CertificateAuthority_IssuingOptions{ - IncludeCaCertUrl: true, - IncludeCrlAccessUrl: true, - }, - Labels: map[string]string{}, + Lifetime: durationpb.New(req.Lifetime), + KeySpec: keySpec, + GcsBucket: c.gcsBucket, + Labels: map[string]string{}, }, } @@ -346,12 +379,18 @@ func (c *CloudCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthor // Sign Intermediate CAs with the parent. if req.Type == apiv1.IntermediateCA { - ca, err = c.signIntermediateCA(ca.Name, req) + ca, err = c.signIntermediateCA(parent, ca.Name, req) if err != nil { return nil, err } } + // Enable Certificate Authority. + ca, err = c.enableCertificateAuthority(ca) + if err != nil { + return nil, err + } + if len(ca.PemCaCertificates) == 0 { return nil, errors.New("cloudCAS CreateCertificateAuthority failed: PemCaCertificates is empty") } @@ -378,6 +417,83 @@ func (c *CloudCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthor }, nil } +func (c *CloudCAS) createCaPoolIfNecessary() (string, error) { + ctx, cancel := defaultContext() + defer cancel() + + pool, err := c.client.GetCaPool(ctx, &pb.GetCaPoolRequest{ + Name: "projects/" + c.project + "/locations/" + c.location + "/caPools/" + c.caPool, + }) + if err == nil { + return pool.Name, nil + } + + if status.Code(err) != codes.NotFound { + return "", errors.Wrap(err, "cloudCAS GetCaPool failed") + } + + // PublishCrl is only supported by the enterprise tier + var publishCrl bool + if c.caPoolTier == pb.CaPool_ENTERPRISE { + publishCrl = true + } + + ctx, cancel = defaultContext() + defer cancel() + + op, err := c.client.CreateCaPool(ctx, &pb.CreateCaPoolRequest{ + Parent: "projects/" + c.project + "/locations/" + c.location, + CaPoolId: c.caPool, + CaPool: &pb.CaPool{ + Tier: c.caPoolTier, + IssuancePolicy: nil, + PublishingOptions: &pb.CaPool_PublishingOptions{ + PublishCaCert: true, + PublishCrl: publishCrl, + }, + }, + }) + if err != nil { + return "", errors.Wrap(err, "cloudCAS CreateCaPool failed") + } + + ctx, cancel = defaultInitiatorContext() + defer cancel() + + pool, err = op.Wait(ctx) + if err != nil { + return "", errors.Wrap(err, "cloudCAS CreateCaPool failed") + } + + return pool.Name, nil +} + +func (c *CloudCAS) enableCertificateAuthority(ca *pb.CertificateAuthority) (*pb.CertificateAuthority, error) { + if ca.State == pb.CertificateAuthority_ENABLED { + return ca, nil + } + + ctx, cancel := defaultContext() + defer cancel() + + resp, err := c.client.EnableCertificateAuthority(ctx, &pb.EnableCertificateAuthorityRequest{ + Name: ca.Name, + }) + if err != nil { + return nil, errors.Wrap(err, "cloudCAS EnableCertificateAuthority failed") + } + + ctx, cancel = defaultInitiatorContext() + defer cancel() + + ca, err = resp.Wait(ctx) + if err != nil { + return nil, errors.Wrap(err, "cloudCAS EnableCertificateAuthority failed") + } + + return ca, nil +} + func (c *CloudCAS) createCertificate(tpl *x509.Certificate, lifetime time.Duration, requestID string) (*x509.Certificate, []*x509.Certificate, error) { // Removes the CAS extension if it exists. apiv1.RemoveCertificateAuthorityExtension(tpl) @@ -403,14 +519,15 @@ func (c *CloudCAS) createCertificate(tpl *x509.Certificate, lifetime time.Durati defer cancel() cert, err := c.client.CreateCertificate(ctx, &pb.CreateCertificateRequest{ - Parent: c.certificateAuthority, + Parent: "projects/" + c.project + "/locations/" + c.location + "/caPools/" + c.caPool, CertificateId: id, Certificate: &pb.Certificate{ CertificateConfig: certConfig, Lifetime: durationpb.New(lifetime), Labels: map[string]string{}, }, - RequestId: requestID, + IssuingCertificateAuthorityId: getResourceName(c.certificateAuthority), + RequestId: requestID, }) if err != nil { return nil, nil, errors.Wrap(err, "cloudCAS CreateCertificate failed") @@ -420,7 +537,7 @@ func (c *CloudCAS) createCertificate(tpl *x509.Certificate, lifetime time.Durati return getCertificateAndChain(cert) } -func (c *CloudCAS) signIntermediateCA(name string, req *apiv1.CreateCertificateAuthorityRequest) (*pb.CertificateAuthority, error) { +func (c *CloudCAS) signIntermediateCA(parent, name string, req *apiv1.CreateCertificateAuthorityRequest) (*pb.CertificateAuthority, error) { id, err := createCertificateID() if err != nil { return nil, err @@ -477,7 +594,7 @@ func (c *CloudCAS) signIntermediateCA(name string, req *apiv1.CreateCertificateA defer cancel() cert, err = c.client.CreateCertificate(ctx, &pb.CreateCertificateRequest{ - Parent: req.Parent.Name, + Parent: parent, CertificateId: id, Certificate: &pb.Certificate{ CertificateConfig: &pb.Certificate_PemCsr{ @@ -486,7 +603,8 @@ func (c *CloudCAS) signIntermediateCA(name string, req *apiv1.CreateCertificateA Lifetime: durationpb.New(req.Lifetime), Labels: map[string]string{}, }, - RequestId: req.RequestID, + IssuingCertificateAuthorityId: getResourceName(req.Parent.Name), + RequestId: req.RequestID, }) if err != nil { return nil, errors.Wrap(err, "cloudCAS CreateCertificate failed") @@ -587,7 +705,12 @@ func getCertificateAndChain(certpb *pb.Certificate) (*x509.Certificate, []*x509. } return cert, chain, nil +} +// getResourceName returns the last part of a resource. +func getResourceName(name string) string { + parts := strings.Split(name, "/") + return parts[len(parts)-1] } // Normalize a certificate authority name to comply with [a-zA-Z0-9-_]. diff --git a/cas/cloudcas/cloudcas_test.go b/cas/cloudcas/cloudcas_test.go index eb682e28..7f996c15 100644 --- a/cas/cloudcas/cloudcas_test.go +++ b/cas/cloudcas/cloudcas_test.go @@ -12,7 +12,6 @@ import ( "encoding/pem" "fmt" "io" - "log" "net" "os" "reflect" @@ -20,7 +19,7 @@ import ( "time" lroauto "cloud.google.com/go/longrunning/autogen" - privateca "cloud.google.com/go/security/privateca/apiv1beta1" + privateca "cloud.google.com/go/security/privateca/apiv1" gomock "github.com/golang/mock/gomock" "github.com/google/uuid" gax "github.com/googleapis/gax-go/v2" @@ -28,19 +27,23 @@ import ( "github.com/smallstep/certificates/cas/apiv1" kmsapi "github.com/smallstep/certificates/kms/apiv1" "google.golang.org/api/option" - pb "google.golang.org/genproto/googleapis/cloud/security/privateca/v1beta1" + pb "google.golang.org/genproto/googleapis/cloud/security/privateca/v1" longrunningpb "google.golang.org/genproto/googleapis/longrunning" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/types/known/anypb" ) var ( errTest = errors.New("test error") - testAuthorityName = "projects/test-project/locations/us-west1/certificateAuthorities/test-ca" - testCertificateName = "projects/test-project/locations/us-west1/certificateAuthorities/test-ca/certificates/test-certificate" + testCaPoolName = "projects/test-project/locations/us-west1/caPools/test-capool" + testAuthorityName = "projects/test-project/locations/us-west1/caPools/test-capool/certificateAuthorities/test-ca" + testCertificateName = "projects/test-project/locations/us-west1/caPools/test-capool/certificateAuthorities/test-ca/certificates/test-certificate" testProject = "test-project" testLocation = "us-west1" + testCaPool = "test-capool" testRootCertificate = `-----BEGIN CERTIFICATE----- MIIBeDCCAR+gAwIBAgIQcXWWjtSZ/PAyH8D1Ou4L9jAKBggqhkjOPQQDAjAbMRkw FwYDVQQDExBDbG91ZENBUyBSb290IENBMB4XDTIwMTAyNzIyNTM1NFoXDTMwMTAy @@ -99,7 +102,7 @@ MHcCAQEEIN51Rgg6YcQVLeCRzumdw4pjM3VWqFIdCbnsV3Up1e/goAoGCCqGSM49 AwEHoUQDQgAEjJIcDhvvxi7gu4aFkiW/8+E3BfPhmhXU5RlDQusre+MHXc7XYMtk Lm6PXPeTF1DNdS21Ju1G/j1yUykGJOmxkg== -----END EC PRIVATE KEY-----` - // nolint:unused,deadcode + // nolint:unused,deadcode,gocritic testIntermediateKey = `-----BEGIN EC PRIVATE KEY----- MHcCAQEEIMMX/XkXGnRDD4fYu7Z4rHACdJn/iyOy2UTwsv+oZ0C+oAoGCCqGSM49 AwEHoUQDQgAE8u6rGAFj5CZpdzzMogLwUyCMnp0X9wtv4OKDRcpzkYf9PU5GuGA6 @@ -186,7 +189,7 @@ func (b *badSigner) Public() crypto.PublicKey { return b.pub } -func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { +func (b *badSigner) Sign(rnd io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { return nil, fmt.Errorf("πŸ’₯") } @@ -214,6 +217,18 @@ func (c *testClient) ActivateCertificateAuthority(ctx context.Context, req *pb.A return nil, errors.New("use NewMockCertificateAuthorityClient") } +func (c *testClient) EnableCertificateAuthority(ctx context.Context, req *pb.EnableCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.EnableCertificateAuthorityOperation, error) { + return nil, errors.New("use NewMockCertificateAuthorityClient") +} + +func (c *testClient) GetCaPool(ctx context.Context, req *pb.GetCaPoolRequest, opts ...gax.CallOption) (*pb.CaPool, error) { + return nil, errors.New("use NewMockCertificateAuthorityClient") +} + +func (c *testClient) CreateCaPool(ctx context.Context, req *pb.CreateCaPoolRequest, opts ...gax.CallOption) (*privateca.CreateCaPoolOperation, error) { + return nil, errors.New("use NewMockCertificateAuthorityClient") +} + func mustParseCertificate(t *testing.T, pemCert string) *x509.Certificate { t.Helper() crt, err := parseCertificate(pemCert) @@ -262,6 +277,18 @@ func TestNew(t *testing.T) { certificateAuthority: testAuthorityName, project: testProject, location: testLocation, + caPool: testCaPool, + caPoolTier: 0, + }, false}, + {"ok authority and creator", args{context.Background(), apiv1.Options{ + CertificateAuthority: testAuthorityName, IsCreator: true, + }}, &CloudCAS{ + client: &testClient{}, + certificateAuthority: testAuthorityName, + project: testProject, + location: testLocation, + caPool: testCaPool, + caPoolTier: 0, }, false}, {"ok with credentials", args{context.Background(), apiv1.Options{ CertificateAuthority: testAuthorityName, CredentialsFile: "testdata/credentials.json", @@ -270,16 +297,38 @@ func TestNew(t *testing.T) { certificateAuthority: testAuthorityName, project: testProject, location: testLocation, + caPool: testCaPool, + caPoolTier: 0, }, false}, {"ok creator", args{context.Background(), apiv1.Options{ - IsCreator: true, Project: testProject, Location: testLocation, + IsCreator: true, Project: testProject, Location: testLocation, CaPool: testCaPool, }}, &CloudCAS{ - client: &testClient{}, - project: testProject, - location: testLocation, + client: &testClient{}, + project: testProject, + location: testLocation, + caPool: testCaPool, + caPoolTier: pb.CaPool_DEVOPS, + }, false}, + {"ok creator devops", args{context.Background(), apiv1.Options{ + IsCreator: true, Project: testProject, Location: testLocation, CaPool: testCaPool, CaPoolTier: "DevOps", + }}, &CloudCAS{ + client: &testClient{}, + project: testProject, + location: testLocation, + caPool: testCaPool, + caPoolTier: pb.CaPool_DEVOPS, + }, false}, + {"ok creator enterprise", args{context.Background(), apiv1.Options{ + IsCreator: true, Project: testProject, Location: testLocation, CaPool: testCaPool, CaPoolTier: "ENTERPRISE", + }}, &CloudCAS{ + client: &testClient{}, + project: testProject, + location: testLocation, + caPool: testCaPool, + caPoolTier: pb.CaPool_ENTERPRISE, }, false}, {"fail certificate authority", args{context.Background(), apiv1.Options{ - CertificateAuthority: "projects/ok1234/locations/ok1234/certificateAuthorities/ok1234/bad", + CertificateAuthority: "projects/ok1234/locations/ok1234/caPools/ok1234/certificateAuthorities/ok1234/bad", }}, nil, true}, {"fail certificate authority regex", args{context.Background(), apiv1.Options{}}, nil, true}, {"fail with credentials", args{context.Background(), apiv1.Options{ @@ -291,6 +340,9 @@ func TestNew(t *testing.T) { {"fail creator location", args{context.Background(), apiv1.Options{ IsCreator: true, Project: testProject, Location: "", }}, nil, true}, + {"fail caPool", args{context.Background(), apiv1.Options{ + IsCreator: true, Project: testProject, Location: testLocation, CaPool: "", + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -320,6 +372,7 @@ func TestNew_register(t *testing.T) { certificateAuthority: testAuthorityName, project: testProject, location: testLocation, + caPool: testCaPool, } newFn, ok := apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.CloudCAS) @@ -338,7 +391,6 @@ func TestNew_register(t *testing.T) { if !reflect.DeepEqual(got, want) { t.Errorf("New() = %v, want %v", got, want) } - } func TestNew_real(t *testing.T) { @@ -677,7 +729,7 @@ func TestCloudCAS_RevokeCertificate(t *testing.T) { func Test_createCertificateID(t *testing.T) { buf := new(bytes.Buffer) setTeeReader(t, buf) - uuid, err := uuid.NewRandomFromReader(rand.Reader) + id, err := uuid.NewRandomFromReader(rand.Reader) if err != nil { t.Fatal(err) } @@ -688,7 +740,7 @@ func Test_createCertificateID(t *testing.T) { want string wantErr bool }{ - {"ok", uuid.String(), false}, + {"ok", id.String(), false}, {"fail", "", true}, } for _, tt := range tests { @@ -805,21 +857,34 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { return lis.Dial() })) if err != nil { - log.Fatal(err) + t.Fatal(err) } client, err := lroauto.NewOperationsClient(context.Background(), option.WithGRPCConn(conn)) if err != nil { t.Fatal(err) } - fake := &privateca.CertificateAuthorityClient{ - LROClient: client, + fake, err := privateca.NewCertificateAuthorityClient(context.Background(), option.WithGRPCConn(conn)) + if err != nil { + t.Fatal(err) } + fake.LROClient = client // Configure mocks any := gomock.Any() // ok root + m.EXPECT().GetCaPool(any, any).Return(nil, status.Error(codes.NotFound, "not found")) + m.EXPECT().CreateCaPool(any, any).Return(fake.CreateCaPoolOperation("CreateCaPool"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "CreateCaPool", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CaPool{ + Name: testCaPoolName, + })).(*anypb.Any), + }, + }, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", @@ -831,33 +896,20 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { })).(*anypb.Any), }, }, nil) - - // ok intermediate - m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) + m.EXPECT().EnableCertificateAuthority(any, any).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ - Name: "CreateCertificateAuthority", - Done: true, - Result: &longrunningpb.Operation_Response{ - Response: must(anypb.New(&pb.CertificateAuthority{ - Name: testAuthorityName, - })).(*anypb.Any), - }, - }, nil) - m.EXPECT().FetchCertificateAuthorityCsr(any, any).Return(&pb.FetchCertificateAuthorityCsrResponse{ - PemCsr: testIntermediateCsr, - }, nil) - m.EXPECT().ActivateCertificateAuthority(any, any).Return(fake.ActivateCertificateAuthorityOperation("ActivateCertificateAuthority"), nil) - mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ - Name: "ActivateCertificateAuthority", + Name: "EnableCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, - PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, + PemCaCertificates: []string{testRootCertificate}, })).(*anypb.Any), }, }, nil) - // ok intermediate local signer + + // ok intermediate + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", @@ -886,8 +938,58 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { })).(*anypb.Any), }, }, nil) + m.EXPECT().EnableCertificateAuthority(any, any).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "EnableCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, + })).(*anypb.Any), + }, + }, nil) + + // ok intermediate local signer + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) + m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "CreateCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + })).(*anypb.Any), + }, + }, nil) + m.EXPECT().FetchCertificateAuthorityCsr(any, any).Return(&pb.FetchCertificateAuthorityCsrResponse{ + PemCsr: testIntermediateCsr, + }, nil) + m.EXPECT().ActivateCertificateAuthority(any, any).Return(fake.ActivateCertificateAuthorityOperation("ActivateCertificateAuthority"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "ActivateCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, + })).(*anypb.Any), + }, + }, nil) + m.EXPECT().EnableCertificateAuthority(any, any).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "EnableCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, + })).(*anypb.Any), + }, + }, nil) // ok create key + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", @@ -899,15 +1001,137 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { })).(*anypb.Any), }, }, nil) + m.EXPECT().EnableCertificateAuthority(any, any).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "EnableCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + PemCaCertificates: []string{testRootCertificate}, + })).(*anypb.Any), + }, + }, nil) + + // fail GetCaPool + m.EXPECT().GetCaPool(any, any).Return(nil, errTest) + + // fail CreateCaPool + m.EXPECT().GetCaPool(any, any).Return(nil, status.Error(codes.NotFound, "not found")) + m.EXPECT().CreateCaPool(any, any).Return(nil, errTest) + + // fail CreateCaPool.Wait + m.EXPECT().GetCaPool(any, any).Return(nil, status.Error(codes.NotFound, "not found")) + m.EXPECT().CreateCaPool(any, any).Return(fake.CreateCaPoolOperation("CreateCaPool"), nil) + mos.EXPECT().GetOperation(any, any).Return(nil, errTest) // fail CreateCertificateAuthority + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(nil, errTest) // fail CreateCertificateAuthority.Wait + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(any, any).Return(nil, errTest) + // fail EnableCertificateAuthority + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) + m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "CreateCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + PemCaCertificates: []string{testRootCertificate}, + })).(*anypb.Any), + }, + }, nil) + m.EXPECT().EnableCertificateAuthority(any, any).Return(nil, errTest) + + // fail EnableCertificateAuthority.Wait + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) + m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "CreateCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + PemCaCertificates: []string{testRootCertificate}, + })).(*anypb.Any), + }, + }, nil) + m.EXPECT().EnableCertificateAuthority(any, any).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) + mos.EXPECT().GetOperation(any, any).Return(nil, errTest) + + // fail EnableCertificateAuthority intermediate + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) + m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "CreateCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + })).(*anypb.Any), + }, + }, nil) + m.EXPECT().FetchCertificateAuthorityCsr(any, any).Return(&pb.FetchCertificateAuthorityCsrResponse{ + PemCsr: testIntermediateCsr, + }, nil) + m.EXPECT().CreateCertificate(any, any).Return(&pb.Certificate{ + PemCertificate: testIntermediateCertificate, + PemCertificateChain: []string{testRootCertificate}, + }, nil) + m.EXPECT().ActivateCertificateAuthority(any, any).Return(fake.ActivateCertificateAuthorityOperation("ActivateCertificateAuthority"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "ActivateCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, + })).(*anypb.Any), + }, + }, nil) + m.EXPECT().EnableCertificateAuthority(any, any).Return(nil, errTest) + + // fail EnableCertificateAuthority.Wait intermediate + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) + m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "CreateCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + })).(*anypb.Any), + }, + }, nil) + m.EXPECT().FetchCertificateAuthorityCsr(any, any).Return(&pb.FetchCertificateAuthorityCsrResponse{ + PemCsr: testIntermediateCsr, + }, nil) + m.EXPECT().CreateCertificate(any, any).Return(&pb.Certificate{ + PemCertificate: testIntermediateCertificate, + PemCertificateChain: []string{testRootCertificate}, + }, nil) + m.EXPECT().ActivateCertificateAuthority(any, any).Return(fake.ActivateCertificateAuthorityOperation("ActivateCertificateAuthority"), nil) + mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ + Name: "ActivateCertificateAuthority", + Done: true, + Result: &longrunningpb.Operation_Response{ + Response: must(anypb.New(&pb.CertificateAuthority{ + Name: testAuthorityName, + PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, + })).(*anypb.Any), + }, + }, nil) + m.EXPECT().EnableCertificateAuthority(any, any).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) + mos.EXPECT().GetOperation(any, any).Return(nil, errTest) + // fail FetchCertificateAuthorityCsr + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", @@ -921,6 +1145,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { m.EXPECT().FetchCertificateAuthorityCsr(any, any).Return(nil, errTest) // fail CreateCertificate + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", @@ -937,6 +1162,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { m.EXPECT().CreateCertificate(any, any).Return(nil, errTest) // fail ActivateCertificateAuthority + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", @@ -957,6 +1183,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { m.EXPECT().ActivateCertificateAuthority(any, any).Return(nil, errTest) // fail ActivateCertificateAuthority.Wait + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", @@ -978,6 +1205,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { mos.EXPECT().GetOperation(any, any).Return(nil, errTest) // fail x509util.CreateCertificate + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", @@ -993,6 +1221,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { }, nil) // fail parseCertificateRequest + m.EXPECT().GetCaPool(any, any).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(any, any).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(any, any).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", @@ -1015,6 +1244,8 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { certificateAuthority string project string location string + caPool string + caPoolTier pb.CaPool_Tier } type args struct { req *apiv1.CreateCertificateAuthorityRequest @@ -1026,7 +1257,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { want *apiv1.CreateCertificateAuthorityResponse wantErr bool }{ - {"ok root", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"ok root", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_ENTERPRISE}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, @@ -1034,7 +1265,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { Name: testAuthorityName, Certificate: rootCrt, }, false}, - {"ok intermediate", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"ok intermediate", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, @@ -1047,7 +1278,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { Certificate: intCrt, CertificateChain: []*x509.Certificate{rootCrt}, }, false}, - {"ok intermediate local signer", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"ok intermediate local signer", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_ENTERPRISE}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, @@ -1060,7 +1291,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { Certificate: intCrt, CertificateChain: []*x509.Certificate{rootCrt}, }, false}, - {"ok create key", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"ok create key", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, @@ -1071,41 +1302,46 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { Name: testAuthorityName, Certificate: rootCrt, }, false}, - {"fail project", fields{m, "", "", testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail project", fields{m, "", "", testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, - {"fail location", fields{m, "", testProject, ""}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail location", fields{m, "", testProject, "", testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, - {"fail template", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail caPool", fields{m, "", testProject, testLocation, "", pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: mustParseCertificate(t, testRootCertificate), + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail template", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Lifetime: 24 * time.Hour, }}, nil, true}, - {"fail lifetime", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail lifetime", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), }}, nil, true}, - {"fail parent", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail parent", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, - {"fail parent name", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail parent name", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{}, }}, nil, true}, - {"fail type", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail type", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: 0, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, - {"fail create key", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail create key", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, @@ -1113,17 +1349,43 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { SignatureAlgorithm: kmsapi.PureEd25519, }, }}, nil, true}, - {"fail CreateCertificateAuthority", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail GetCaPool", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, - {"fail CreateCertificateAuthority.Wait", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail CreateCaPool", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, - {"fail FetchCertificateAuthorityCsr", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail CreateCaPool.Wait", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: mustParseCertificate(t, testRootCertificate), + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail CreateCertificateAuthority", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: mustParseCertificate(t, testRootCertificate), + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail CreateCertificateAuthority.Wait", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: mustParseCertificate(t, testRootCertificate), + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail EnableCertificateAuthority", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: mustParseCertificate(t, testRootCertificate), + Lifetime: 24 * time.Hour, + }}, nil, true}, + {"fail EnableCertificateAuthority.Wait", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: mustParseCertificate(t, testRootCertificate), + Lifetime: 24 * time.Hour, + }}, nil, true}, + + {"fail EnableCertificateAuthority intermediate", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, @@ -1132,7 +1394,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { Certificate: rootCrt, }, }}, nil, true}, - {"fail CreateCertificate", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail EnableCertificateAuthority.Wait intermediate", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, @@ -1141,7 +1403,8 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { Certificate: rootCrt, }, }}, nil, true}, - {"fail ActivateCertificateAuthority", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + + {"fail FetchCertificateAuthorityCsr", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, @@ -1150,7 +1413,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { Certificate: rootCrt, }, }}, nil, true}, - {"fail ActivateCertificateAuthority.Wait", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail CreateCertificate", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, @@ -1159,7 +1422,25 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { Certificate: rootCrt, }, }}, nil, true}, - {"fail x509util.CreateCertificate", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail ActivateCertificateAuthority", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: mustParseCertificate(t, testIntermediateCertificate), + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Name: testAuthorityName, + Certificate: rootCrt, + }, + }}, nil, true}, + {"fail ActivateCertificateAuthority.Wait", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: mustParseCertificate(t, testIntermediateCertificate), + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Name: testAuthorityName, + Certificate: rootCrt, + }, + }}, nil, true}, + {"fail x509util.CreateCertificate", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, @@ -1168,7 +1449,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { Signer: createBadSigner(t), }, }}, nil, true}, - {"fail parseCertificateRequest", fields{m, "", testProject, testLocation}, args{&apiv1.CreateCertificateAuthorityRequest{ + {"fail parseCertificateRequest", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, @@ -1185,6 +1466,8 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { certificateAuthority: tt.fields.certificateAuthority, project: tt.fields.project, location: tt.fields.location, + caPool: tt.fields.caPool, + caPoolTier: tt.fields.caPoolTier, } got, err := c.CreateCertificateAuthority(tt.args.req) if (err != nil) != tt.wantErr { diff --git a/cas/cloudcas/mock_client_test.go b/cas/cloudcas/mock_client_test.go index b81d3135..de5c2acb 100644 --- a/cas/cloudcas/mock_client_test.go +++ b/cas/cloudcas/mock_client_test.go @@ -1,15 +1,15 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: ./cas/cloudcas/cloudcas.go +// Source: github.com/smallstep/certificates/cas/cloudcas (interfaces: CertificateAuthorityClient) // Package cloudcas is a generated GoMock package. package cloudcas import ( - privateca "cloud.google.com/go/security/privateca/apiv1beta1" + privateca "cloud.google.com/go/security/privateca/apiv1" context "context" gomock "github.com/golang/mock/gomock" gax "github.com/googleapis/gax-go/v2" - privateca0 "google.golang.org/genproto/googleapis/cloud/security/privateca/v1beta1" + privateca0 "google.golang.org/genproto/googleapis/cloud/security/privateca/v1" reflect "reflect" ) @@ -36,111 +36,11 @@ func (m *MockCertificateAuthorityClient) EXPECT() *MockCertificateAuthorityClien return m.recorder } -// CreateCertificate mocks base method -func (m *MockCertificateAuthorityClient) CreateCertificate(ctx context.Context, req *privateca0.CreateCertificateRequest, opts ...gax.CallOption) (*privateca0.Certificate, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, req} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "CreateCertificate", varargs...) - ret0, _ := ret[0].(*privateca0.Certificate) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateCertificate indicates an expected call of CreateCertificate -func (mr *MockCertificateAuthorityClientMockRecorder) CreateCertificate(ctx, req interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, req}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCertificate", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).CreateCertificate), varargs...) -} - -// RevokeCertificate mocks base method -func (m *MockCertificateAuthorityClient) RevokeCertificate(ctx context.Context, req *privateca0.RevokeCertificateRequest, opts ...gax.CallOption) (*privateca0.Certificate, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, req} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "RevokeCertificate", varargs...) - ret0, _ := ret[0].(*privateca0.Certificate) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// RevokeCertificate indicates an expected call of RevokeCertificate -func (mr *MockCertificateAuthorityClientMockRecorder) RevokeCertificate(ctx, req interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, req}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeCertificate", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).RevokeCertificate), varargs...) -} - -// GetCertificateAuthority mocks base method -func (m *MockCertificateAuthorityClient) GetCertificateAuthority(ctx context.Context, req *privateca0.GetCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca0.CertificateAuthority, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, req} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetCertificateAuthority", varargs...) - ret0, _ := ret[0].(*privateca0.CertificateAuthority) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetCertificateAuthority indicates an expected call of GetCertificateAuthority -func (mr *MockCertificateAuthorityClientMockRecorder) GetCertificateAuthority(ctx, req interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, req}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCertificateAuthority", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).GetCertificateAuthority), varargs...) -} - -// CreateCertificateAuthority mocks base method -func (m *MockCertificateAuthorityClient) CreateCertificateAuthority(ctx context.Context, req *privateca0.CreateCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.CreateCertificateAuthorityOperation, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, req} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "CreateCertificateAuthority", varargs...) - ret0, _ := ret[0].(*privateca.CreateCertificateAuthorityOperation) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateCertificateAuthority indicates an expected call of CreateCertificateAuthority -func (mr *MockCertificateAuthorityClientMockRecorder) CreateCertificateAuthority(ctx, req interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, req}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCertificateAuthority", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).CreateCertificateAuthority), varargs...) -} - -// FetchCertificateAuthorityCsr mocks base method -func (m *MockCertificateAuthorityClient) FetchCertificateAuthorityCsr(ctx context.Context, req *privateca0.FetchCertificateAuthorityCsrRequest, opts ...gax.CallOption) (*privateca0.FetchCertificateAuthorityCsrResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, req} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "FetchCertificateAuthorityCsr", varargs...) - ret0, _ := ret[0].(*privateca0.FetchCertificateAuthorityCsrResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// FetchCertificateAuthorityCsr indicates an expected call of FetchCertificateAuthorityCsr -func (mr *MockCertificateAuthorityClientMockRecorder) FetchCertificateAuthorityCsr(ctx, req interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, req}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchCertificateAuthorityCsr", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).FetchCertificateAuthorityCsr), varargs...) -} - // ActivateCertificateAuthority mocks base method -func (m *MockCertificateAuthorityClient) ActivateCertificateAuthority(ctx context.Context, req *privateca0.ActivateCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.ActivateCertificateAuthorityOperation, error) { +func (m *MockCertificateAuthorityClient) ActivateCertificateAuthority(arg0 context.Context, arg1 *privateca0.ActivateCertificateAuthorityRequest, arg2 ...gax.CallOption) (*privateca.ActivateCertificateAuthorityOperation, error) { m.ctrl.T.Helper() - varargs := []interface{}{ctx, req} - for _, a := range opts { + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ActivateCertificateAuthority", varargs...) @@ -150,8 +50,168 @@ func (m *MockCertificateAuthorityClient) ActivateCertificateAuthority(ctx contex } // ActivateCertificateAuthority indicates an expected call of ActivateCertificateAuthority -func (mr *MockCertificateAuthorityClientMockRecorder) ActivateCertificateAuthority(ctx, req interface{}, opts ...interface{}) *gomock.Call { +func (mr *MockCertificateAuthorityClientMockRecorder) ActivateCertificateAuthority(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, req}, opts...) + varargs := append([]interface{}{arg0, arg1}, arg2...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivateCertificateAuthority", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).ActivateCertificateAuthority), varargs...) } + +// CreateCaPool mocks base method +func (m *MockCertificateAuthorityClient) CreateCaPool(arg0 context.Context, arg1 *privateca0.CreateCaPoolRequest, arg2 ...gax.CallOption) (*privateca.CreateCaPoolOperation, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateCaPool", varargs...) + ret0, _ := ret[0].(*privateca.CreateCaPoolOperation) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateCaPool indicates an expected call of CreateCaPool +func (mr *MockCertificateAuthorityClientMockRecorder) CreateCaPool(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCaPool", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).CreateCaPool), varargs...) +} + +// CreateCertificate mocks base method +func (m *MockCertificateAuthorityClient) CreateCertificate(arg0 context.Context, arg1 *privateca0.CreateCertificateRequest, arg2 ...gax.CallOption) (*privateca0.Certificate, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateCertificate", varargs...) + ret0, _ := ret[0].(*privateca0.Certificate) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateCertificate indicates an expected call of CreateCertificate +func (mr *MockCertificateAuthorityClientMockRecorder) CreateCertificate(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCertificate", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).CreateCertificate), varargs...) +} + +// CreateCertificateAuthority mocks base method +func (m *MockCertificateAuthorityClient) CreateCertificateAuthority(arg0 context.Context, arg1 *privateca0.CreateCertificateAuthorityRequest, arg2 ...gax.CallOption) (*privateca.CreateCertificateAuthorityOperation, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateCertificateAuthority", varargs...) + ret0, _ := ret[0].(*privateca.CreateCertificateAuthorityOperation) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateCertificateAuthority indicates an expected call of CreateCertificateAuthority +func (mr *MockCertificateAuthorityClientMockRecorder) CreateCertificateAuthority(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCertificateAuthority", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).CreateCertificateAuthority), varargs...) +} + +// EnableCertificateAuthority mocks base method +func (m *MockCertificateAuthorityClient) EnableCertificateAuthority(arg0 context.Context, arg1 *privateca0.EnableCertificateAuthorityRequest, arg2 ...gax.CallOption) (*privateca.EnableCertificateAuthorityOperation, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "EnableCertificateAuthority", varargs...) + ret0, _ := ret[0].(*privateca.EnableCertificateAuthorityOperation) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EnableCertificateAuthority indicates an expected call of EnableCertificateAuthority +func (mr *MockCertificateAuthorityClientMockRecorder) EnableCertificateAuthority(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableCertificateAuthority", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).EnableCertificateAuthority), varargs...) +} + +// FetchCertificateAuthorityCsr mocks base method +func (m *MockCertificateAuthorityClient) FetchCertificateAuthorityCsr(arg0 context.Context, arg1 *privateca0.FetchCertificateAuthorityCsrRequest, arg2 ...gax.CallOption) (*privateca0.FetchCertificateAuthorityCsrResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "FetchCertificateAuthorityCsr", varargs...) + ret0, _ := ret[0].(*privateca0.FetchCertificateAuthorityCsrResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchCertificateAuthorityCsr indicates an expected call of FetchCertificateAuthorityCsr +func (mr *MockCertificateAuthorityClientMockRecorder) FetchCertificateAuthorityCsr(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchCertificateAuthorityCsr", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).FetchCertificateAuthorityCsr), varargs...) +} + +// GetCaPool mocks base method +func (m *MockCertificateAuthorityClient) GetCaPool(arg0 context.Context, arg1 *privateca0.GetCaPoolRequest, arg2 ...gax.CallOption) (*privateca0.CaPool, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetCaPool", varargs...) + ret0, _ := ret[0].(*privateca0.CaPool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCaPool indicates an expected call of GetCaPool +func (mr *MockCertificateAuthorityClientMockRecorder) GetCaPool(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCaPool", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).GetCaPool), varargs...) +} + +// GetCertificateAuthority mocks base method +func (m *MockCertificateAuthorityClient) GetCertificateAuthority(arg0 context.Context, arg1 *privateca0.GetCertificateAuthorityRequest, arg2 ...gax.CallOption) (*privateca0.CertificateAuthority, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetCertificateAuthority", varargs...) + ret0, _ := ret[0].(*privateca0.CertificateAuthority) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCertificateAuthority indicates an expected call of GetCertificateAuthority +func (mr *MockCertificateAuthorityClientMockRecorder) GetCertificateAuthority(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCertificateAuthority", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).GetCertificateAuthority), varargs...) +} + +// RevokeCertificate mocks base method +func (m *MockCertificateAuthorityClient) RevokeCertificate(arg0 context.Context, arg1 *privateca0.RevokeCertificateRequest, arg2 ...gax.CallOption) (*privateca0.Certificate, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "RevokeCertificate", varargs...) + ret0, _ := ret[0].(*privateca0.Certificate) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RevokeCertificate indicates an expected call of RevokeCertificate +func (mr *MockCertificateAuthorityClientMockRecorder) RevokeCertificate(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeCertificate", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).RevokeCertificate), varargs...) +} diff --git a/cas/cloudcas/mock_operation_server_test.go b/cas/cloudcas/mock_operation_server_test.go index 48564cd1..ee2743d4 100644 --- a/cas/cloudcas/mock_operation_server_test.go +++ b/cas/cloudcas/mock_operation_server_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: /Users/mariano/go/pkg/mod/google.golang.org/genproto@v0.0.0-20200904004341-0bd0a958aa1d/googleapis/longrunning/operations.pb.go +// Source: google.golang.org/genproto/googleapis/longrunning (interfaces: OperationsServer) // Package cloudcas is a generated GoMock package. package cloudcas @@ -8,169 +8,10 @@ import ( context "context" gomock "github.com/golang/mock/gomock" longrunning "google.golang.org/genproto/googleapis/longrunning" - grpc "google.golang.org/grpc" emptypb "google.golang.org/protobuf/types/known/emptypb" reflect "reflect" ) -// MockisOperation_Result is a mock of isOperation_Result interface -type MockisOperation_Result struct { - ctrl *gomock.Controller - recorder *MockisOperation_ResultMockRecorder -} - -// MockisOperation_ResultMockRecorder is the mock recorder for MockisOperation_Result -type MockisOperation_ResultMockRecorder struct { - mock *MockisOperation_Result -} - -// NewMockisOperation_Result creates a new mock instance -func NewMockisOperation_Result(ctrl *gomock.Controller) *MockisOperation_Result { - mock := &MockisOperation_Result{ctrl: ctrl} - mock.recorder = &MockisOperation_ResultMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockisOperation_Result) EXPECT() *MockisOperation_ResultMockRecorder { - return m.recorder -} - -// isOperation_Result mocks base method -func (m *MockisOperation_Result) isOperation_Result() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "isOperation_Result") -} - -// isOperation_Result indicates an expected call of isOperation_Result -func (mr *MockisOperation_ResultMockRecorder) isOperation_Result() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "isOperation_Result", reflect.TypeOf((*MockisOperation_Result)(nil).isOperation_Result)) -} - -// MockOperationsClient is a mock of OperationsClient interface -type MockOperationsClient struct { - ctrl *gomock.Controller - recorder *MockOperationsClientMockRecorder -} - -// MockOperationsClientMockRecorder is the mock recorder for MockOperationsClient -type MockOperationsClientMockRecorder struct { - mock *MockOperationsClient -} - -// NewMockOperationsClient creates a new mock instance -func NewMockOperationsClient(ctrl *gomock.Controller) *MockOperationsClient { - mock := &MockOperationsClient{ctrl: ctrl} - mock.recorder = &MockOperationsClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockOperationsClient) EXPECT() *MockOperationsClientMockRecorder { - return m.recorder -} - -// ListOperations mocks base method -func (m *MockOperationsClient) ListOperations(ctx context.Context, in *longrunning.ListOperationsRequest, opts ...grpc.CallOption) (*longrunning.ListOperationsResponse, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, in} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "ListOperations", varargs...) - ret0, _ := ret[0].(*longrunning.ListOperationsResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListOperations indicates an expected call of ListOperations -func (mr *MockOperationsClientMockRecorder) ListOperations(ctx, in interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, in}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListOperations", reflect.TypeOf((*MockOperationsClient)(nil).ListOperations), varargs...) -} - -// GetOperation mocks base method -func (m *MockOperationsClient) GetOperation(ctx context.Context, in *longrunning.GetOperationRequest, opts ...grpc.CallOption) (*longrunning.Operation, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, in} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetOperation", varargs...) - ret0, _ := ret[0].(*longrunning.Operation) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetOperation indicates an expected call of GetOperation -func (mr *MockOperationsClientMockRecorder) GetOperation(ctx, in interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, in}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOperation", reflect.TypeOf((*MockOperationsClient)(nil).GetOperation), varargs...) -} - -// DeleteOperation mocks base method -func (m *MockOperationsClient) DeleteOperation(ctx context.Context, in *longrunning.DeleteOperationRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, in} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "DeleteOperation", varargs...) - ret0, _ := ret[0].(*emptypb.Empty) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DeleteOperation indicates an expected call of DeleteOperation -func (mr *MockOperationsClientMockRecorder) DeleteOperation(ctx, in interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, in}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOperation", reflect.TypeOf((*MockOperationsClient)(nil).DeleteOperation), varargs...) -} - -// CancelOperation mocks base method -func (m *MockOperationsClient) CancelOperation(ctx context.Context, in *longrunning.CancelOperationRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, in} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "CancelOperation", varargs...) - ret0, _ := ret[0].(*emptypb.Empty) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CancelOperation indicates an expected call of CancelOperation -func (mr *MockOperationsClientMockRecorder) CancelOperation(ctx, in interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, in}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelOperation", reflect.TypeOf((*MockOperationsClient)(nil).CancelOperation), varargs...) -} - -// WaitOperation mocks base method -func (m *MockOperationsClient) WaitOperation(ctx context.Context, in *longrunning.WaitOperationRequest, opts ...grpc.CallOption) (*longrunning.Operation, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, in} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "WaitOperation", varargs...) - ret0, _ := ret[0].(*longrunning.Operation) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// WaitOperation indicates an expected call of WaitOperation -func (mr *MockOperationsClientMockRecorder) WaitOperation(ctx, in interface{}, opts ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, in}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitOperation", reflect.TypeOf((*MockOperationsClient)(nil).WaitOperation), varargs...) -} - // MockOperationsServer is a mock of OperationsServer interface type MockOperationsServer struct { ctrl *gomock.Controller @@ -194,34 +35,19 @@ func (m *MockOperationsServer) EXPECT() *MockOperationsServerMockRecorder { return m.recorder } -// ListOperations mocks base method -func (m *MockOperationsServer) ListOperations(arg0 context.Context, arg1 *longrunning.ListOperationsRequest) (*longrunning.ListOperationsResponse, error) { +// CancelOperation mocks base method +func (m *MockOperationsServer) CancelOperation(arg0 context.Context, arg1 *longrunning.CancelOperationRequest) (*emptypb.Empty, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListOperations", arg0, arg1) - ret0, _ := ret[0].(*longrunning.ListOperationsResponse) + ret := m.ctrl.Call(m, "CancelOperation", arg0, arg1) + ret0, _ := ret[0].(*emptypb.Empty) ret1, _ := ret[1].(error) return ret0, ret1 } -// ListOperations indicates an expected call of ListOperations -func (mr *MockOperationsServerMockRecorder) ListOperations(arg0, arg1 interface{}) *gomock.Call { +// CancelOperation indicates an expected call of CancelOperation +func (mr *MockOperationsServerMockRecorder) CancelOperation(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListOperations", reflect.TypeOf((*MockOperationsServer)(nil).ListOperations), arg0, arg1) -} - -// GetOperation mocks base method -func (m *MockOperationsServer) GetOperation(arg0 context.Context, arg1 *longrunning.GetOperationRequest) (*longrunning.Operation, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOperation", arg0, arg1) - ret0, _ := ret[0].(*longrunning.Operation) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetOperation indicates an expected call of GetOperation -func (mr *MockOperationsServerMockRecorder) GetOperation(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOperation", reflect.TypeOf((*MockOperationsServer)(nil).GetOperation), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelOperation", reflect.TypeOf((*MockOperationsServer)(nil).CancelOperation), arg0, arg1) } // DeleteOperation mocks base method @@ -239,19 +65,34 @@ func (mr *MockOperationsServerMockRecorder) DeleteOperation(arg0, arg1 interface return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOperation", reflect.TypeOf((*MockOperationsServer)(nil).DeleteOperation), arg0, arg1) } -// CancelOperation mocks base method -func (m *MockOperationsServer) CancelOperation(arg0 context.Context, arg1 *longrunning.CancelOperationRequest) (*emptypb.Empty, error) { +// GetOperation mocks base method +func (m *MockOperationsServer) GetOperation(arg0 context.Context, arg1 *longrunning.GetOperationRequest) (*longrunning.Operation, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CancelOperation", arg0, arg1) - ret0, _ := ret[0].(*emptypb.Empty) + ret := m.ctrl.Call(m, "GetOperation", arg0, arg1) + ret0, _ := ret[0].(*longrunning.Operation) ret1, _ := ret[1].(error) return ret0, ret1 } -// CancelOperation indicates an expected call of CancelOperation -func (mr *MockOperationsServerMockRecorder) CancelOperation(arg0, arg1 interface{}) *gomock.Call { +// GetOperation indicates an expected call of GetOperation +func (mr *MockOperationsServerMockRecorder) GetOperation(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelOperation", reflect.TypeOf((*MockOperationsServer)(nil).CancelOperation), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOperation", reflect.TypeOf((*MockOperationsServer)(nil).GetOperation), arg0, arg1) +} + +// ListOperations mocks base method +func (m *MockOperationsServer) ListOperations(arg0 context.Context, arg1 *longrunning.ListOperationsRequest) (*longrunning.ListOperationsResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListOperations", arg0, arg1) + ret0, _ := ret[0].(*longrunning.ListOperationsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListOperations indicates an expected call of ListOperations +func (mr *MockOperationsServerMockRecorder) ListOperations(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListOperations", reflect.TypeOf((*MockOperationsServer)(nil).ListOperations), arg0, arg1) } // WaitOperation mocks base method diff --git a/cas/softcas/softcas.go b/cas/softcas/softcas.go index 21760490..8e67d016 100644 --- a/cas/softcas/softcas.go +++ b/cas/softcas/softcas.go @@ -19,9 +19,7 @@ func init() { }) } -var now = func() time.Time { - return time.Now() -} +var now = time.Now // SoftCAS implements a Certificate Authority Service using Golang or KMS // crypto. This is the default CAS used in step-ca. @@ -68,7 +66,7 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 } req.Template.Issuer = c.CertificateChain[0].Subject - cert, err := x509util.CreateCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) + cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) if err != nil { return nil, err } @@ -93,7 +91,7 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R req.Template.NotAfter = t.Add(req.Lifetime) req.Template.Issuer = c.CertificateChain[0].Subject - cert, err := x509util.CreateCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) + cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) if err != nil { return nil, err } @@ -150,12 +148,12 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori var cert *x509.Certificate switch req.Type { case apiv1.RootCA: - cert, err = x509util.CreateCertificate(req.Template, req.Template, signer.Public(), signer) + cert, err = createCertificate(req.Template, req.Template, signer.Public(), signer) if err != nil { return nil, err } case apiv1.IntermediateCA: - cert, err = x509util.CreateCertificate(req.Template, req.Parent.Certificate, signer.Public(), req.Parent.Signer) + cert, err = createCertificate(req.Template, req.Parent.Certificate, signer.Public(), req.Parent.Signer) if err != nil { return nil, err } @@ -174,6 +172,7 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori Name: cert.Subject.CommonName, Certificate: cert, CertificateChain: chain, + KeyName: key.Name, PublicKey: key.PublicKey, PrivateKey: key.PrivateKey, Signer: signer, @@ -210,3 +209,16 @@ func (c *SoftCAS) createSigner(req *kmsapi.CreateSignerRequest) (crypto.Signer, } return c.KeyManager.CreateSigner(req) } + +// createCertificate sets the SignatureAlgorithm of the template if necessary +// and calls x509util.CreateCertificate. +func createCertificate(template, parent *x509.Certificate, pub crypto.PublicKey, signer crypto.Signer) (*x509.Certificate, error) { + // Signers can specify the signature algorithm. This is especially important + // when x509.CreateCertificate attempts to validate a RSAPSS signature. + if template.SignatureAlgorithm == 0 { + if sa, ok := signer.(apiv1.SignatureAlgorithmGetter); ok { + template.SignatureAlgorithm = sa.SignatureAlgorithm() + } + } + return x509util.CreateCertificate(template, parent, pub, signer) +} diff --git a/cas/softcas/softcas_test.go b/cas/softcas/softcas_test.go index 092a0337..7d3add4f 100644 --- a/cas/softcas/softcas_test.go +++ b/cas/softcas/softcas_test.go @@ -75,6 +75,15 @@ var ( testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour)) ) +type signatureAlgorithmSigner struct { + crypto.Signer + algorithm x509.SignatureAlgorithm +} + +func (s *signatureAlgorithmSigner) SignatureAlgorithm() x509.SignatureAlgorithm { + return s.algorithm +} + type mockKeyManager struct { signer crypto.Signer errGetPublicKey error @@ -97,6 +106,7 @@ func (m *mockKeyManager) CreateKey(req *kmsapi.CreateKeyRequest) (*kmsapi.Create signer = m.signer } return &kmsapi.CreateKeyResponse{ + Name: req.Name, PrivateKey: signer, PublicKey: signer.Public(), }, m.errCreateKey @@ -124,7 +134,7 @@ func (b *badSigner) Public() crypto.PublicKey { return testSigner.Public() } -func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { +func (b *badSigner) Sign(_ io.Reader, _ []byte, _ crypto.SignerOpts) ([]byte, error) { return nil, fmt.Errorf("πŸ’₯") } @@ -247,6 +257,13 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { tmplNoSerial := *testTemplate tmplNoSerial.SerialNumber = nil + saTemplate := *testSignedTemplate + saTemplate.SignatureAlgorithm = 0 + saSigner := &signatureAlgorithmSigner{ + Signer: testSigner, + algorithm: x509.PureEd25519, + } + type fields struct { Issuer *x509.Certificate Signer crypto.Signer @@ -267,6 +284,12 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, + {"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.CreateCertificateRequest{ + Template: &saTemplate, Lifetime: 24 * time.Hour, + }}, &apiv1.CreateCertificateResponse{ + Certificate: testSignedTemplate, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, {"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ Template: &tmplNotBefore, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ @@ -316,6 +339,11 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { tmplNoSerial := *testTemplate tmplNoSerial.SerialNumber = nil + saSigner := &signatureAlgorithmSigner{ + Signer: testSigner, + algorithm: x509.PureEd25519, + } + type fields struct { Issuer *x509.Certificate Signer crypto.Signer @@ -336,6 +364,12 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, + {"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.RenewCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, &apiv1.RenewCertificateResponse{ + Certificate: testSignedTemplate, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, {"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, {"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, {"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ @@ -425,6 +459,11 @@ func Test_now(t *testing.T) { func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { mockNow(t) + saSigner := &signatureAlgorithmSigner{ + Signer: testSigner, + algorithm: x509.PureEd25519, + } + type fields struct { Issuer *x509.Certificate Signer crypto.Signer @@ -467,6 +506,33 @@ func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { PrivateKey: testSigner, Signer: testSigner, }, false}, + {"ok signature algorithm", fields{nil, nil, &mockKeyManager{signer: saSigner}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: testRootTemplate, + Lifetime: 24 * time.Hour, + }}, &apiv1.CreateCertificateAuthorityResponse{ + Name: "Test Root CA", + Certificate: testSignedRootTemplate, + PublicKey: testSignedRootTemplate.PublicKey, + PrivateKey: saSigner, + Signer: saSigner, + }, false}, + {"ok createKey", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: testRootTemplate, + Lifetime: 24 * time.Hour, + CreateKey: &kmsapi.CreateKeyRequest{ + Name: "root_ca.crt", + SignatureAlgorithm: kmsapi.ECDSAWithSHA256, + }, + }}, &apiv1.CreateCertificateAuthorityResponse{ + Name: "Test Root CA", + Certificate: testSignedRootTemplate, + PublicKey: testSignedRootTemplate.PublicKey, + KeyName: "root_ca.crt", + PrivateKey: testSigner, + Signer: testSigner, + }, false}, {"fail template", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Lifetime: 24 * time.Hour, diff --git a/cas/stepcas/stepcas.go b/cas/stepcas/stepcas.go index 49a99963..9fcbd36c 100644 --- a/cas/stepcas/stepcas.go +++ b/cas/stepcas/stepcas.go @@ -47,10 +47,13 @@ func New(ctx context.Context, opts apiv1.Options) (*StepCAS, error) { return nil, err } - // Create configured issuer - iss, err := newStepIssuer(caURL, client, opts.CertificateIssuer) - if err != nil { - return nil, err + var iss stepIssuer + // Create configured issuer unless we only want to use GetCertificateAuthority. + // This avoid the request for the password if not provided. + if !opts.IsCAGetter { + if iss, err = newStepIssuer(caURL, client, opts.CertificateIssuer); err != nil { + return nil, err + } } return &StepCAS{ @@ -87,9 +90,9 @@ func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R return nil, apiv1.ErrNotImplemented{Message: "stepCAS does not support mTLS renewals"} } +// RevokeCertificate revokes a certificate. func (s *StepCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { - switch { - case req.SerialNumber == "" && req.Certificate == nil: + if req.SerialNumber == "" && req.Certificate == nil { return nil, errors.New("revokeCertificateRequest `serialNumber` or `certificate` are required") } diff --git a/cas/stepcas/stepcas_test.go b/cas/stepcas/stepcas_test.go index fb8259f5..f430a1dd 100644 --- a/cas/stepcas/stepcas_test.go +++ b/cas/stepcas/stepcas_test.go @@ -411,6 +411,19 @@ func TestNew(t *testing.T) { client: client, fingerprint: testRootFingerprint, }, false}, + {"ok ca getter", args{context.TODO(), apiv1.Options{ + IsCAGetter: true, + CertificateAuthority: caURL.String(), + CertificateAuthorityFingerprint: testRootFingerprint, + CertificateIssuer: &apiv1.CertificateIssuer{ + Type: "jwk", + Provisioner: "ra@doe.org", + }, + }}, &StepCAS{ + iss: nil, + client: client, + fingerprint: testRootFingerprint, + }, false}, {"fail authority", args{context.TODO(), apiv1.Options{ CertificateAuthority: "", CertificateAuthorityFingerprint: testRootFingerprint, diff --git a/cas/stepcas/x5c_issuer.go b/cas/stepcas/x5c_issuer.go index da4aa27e..76ed9c3c 100644 --- a/cas/stepcas/x5c_issuer.go +++ b/cas/stepcas/x5c_issuer.go @@ -19,9 +19,7 @@ const defaultValidity = 5 * time.Minute // timeNow returns the current time. // This method is used for unit testing purposes. -var timeNow = func() time.Time { - return time.Now() -} +var timeNow = time.Now type x5cIssuer struct { caURL *url.URL @@ -143,7 +141,11 @@ func newX5CSigner(certFile, keyFile, password string) (jose.Signer, error) { if err != nil { return nil, err } - certs, err := jose.ValidateX5C(certFile, signer) + certs, err := pemutil.ReadCertificateBundle(certFile) + if err != nil { + return nil, errors.Wrap(err, "error reading x5c certificate chain") + } + certStrs, err := jose.ValidateX5C(certs, signer) if err != nil { return nil, errors.Wrap(err, "error validating x5c certificate chain and key") } @@ -151,7 +153,7 @@ func newX5CSigner(certFile, keyFile, password string) (jose.Signer, error) { so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", kid) - so.WithHeader("x5c", certs) + so.WithHeader("x5c", certStrs) return newJoseSigner(signer, so) } diff --git a/cas/stepcas/x5c_issuer_test.go b/cas/stepcas/x5c_issuer_test.go index a3190255..b1bc653d 100644 --- a/cas/stepcas/x5c_issuer_test.go +++ b/cas/stepcas/x5c_issuer_test.go @@ -22,7 +22,7 @@ func (b noneSigner) Public() crypto.PublicKey { return []byte(b) } -func (b noneSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { +func (b noneSigner) Sign(rnd io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { return digest, nil } diff --git a/cmd/step-awskms-init/main.go b/cmd/step-awskms-init/main.go index 0d686239..8e30745f 100644 --- a/cmd/step-awskms-init/main.go +++ b/cmd/step-awskms-init/main.go @@ -24,13 +24,16 @@ import ( func main() { var credentialsFile, region string - var ssh bool + var enableSSH bool flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the AWS KMS credentials.") flag.StringVar(®ion, "region", "", "AWS KMS region name.") - flag.BoolVar(&ssh, "ssh", false, "Create SSH keys.") + flag.BoolVar(&enableSSH, "ssh", false, "Create SSH keys.") flag.Usage = usage flag.Parse() + // Initialize windows terminal + ui.Init() + c, err := awskms.New(context.Background(), apiv1.Options{ Type: string(apiv1.AmazonKMS), Region: region, @@ -44,16 +47,20 @@ func main() { fatal(err) } - if ssh { + if enableSSH { ui.Println() if err := createSSH(c); err != nil { fatal(err) } } + + // Reset windows terminal + ui.Reset() } func fatal(err error) { fmt.Fprintln(os.Stderr, err) + ui.Reset() os.Exit(1) } @@ -113,7 +120,7 @@ func createX509(c *awskms.KMS) error { return err } - if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{ + if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: b, }), 0600); err != nil { @@ -156,7 +163,7 @@ func createX509(c *awskms.KMS) error { return err } - if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{ + if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: b, }), 0600); err != nil { @@ -186,7 +193,7 @@ func createSSH(c *awskms.KMS) error { return err } - if err = fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { + if err := fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { return err } @@ -207,7 +214,7 @@ func createSSH(c *awskms.KMS) error { return err } - if err = fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { + if err := fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { return err } diff --git a/cmd/step-ca/main.go b/cmd/step-ca/main.go index 4396e028..01d800d8 100644 --- a/cmd/step-ca/main.go +++ b/cmd/step-ca/main.go @@ -22,10 +22,12 @@ import ( "go.step.sm/cli-utils/command" "go.step.sm/cli-utils/command/version" "go.step.sm/cli-utils/config" + "go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/usage" // Enabled kms interfaces. _ "github.com/smallstep/certificates/kms/awskms" + _ "github.com/smallstep/certificates/kms/azurekms" _ "github.com/smallstep/certificates/kms/cloudkms" _ "github.com/smallstep/certificates/kms/softkms" _ "github.com/smallstep/certificates/kms/sshagentkms" @@ -52,6 +54,11 @@ func init() { rand.Seed(time.Now().UnixNano()) } +func exit(code int) { + ui.Reset() + os.Exit(code) +} + // appHelpTemplate contains the modified template for the main app var appHelpTemplate = `## NAME **{{.HelpName}}** -- {{.Usage}} @@ -90,6 +97,9 @@ Please send us a sentence or two, good or bad: **feedback@smallstep.com** or htt ` func main() { + // Initialize windows terminal + ui.Init() + // Override global framework components cli.VersionPrinter = func(c *cli.Context) { version.Command(c) @@ -107,7 +117,9 @@ func main() { app.HelpName = "step-ca" app.Version = config.Version() app.Usage = "an online certificate authority for secure automated certificate management" - app.UsageText = `**step-ca** [**--password-file**=] [**--issuer-password-file**=] [**--resolver**=] [**--help**] [**--version**]` + app.UsageText = `**step-ca** [**--password-file**=] +[**--ssh-host-password-file**=] [**--ssh-user-password-file**=] +[**--issuer-password-file**=] [**--resolver**=] [**--help**] [**--version**]` app.Description = `**step-ca** runs the Step Online Certificate Authority (Step CA) using the given configuration. See the README.md for more detailed configuration documentation. @@ -162,8 +174,10 @@ $ step-ca $STEPPATH/config/ca.json --password-file ./password.txt } else { fmt.Fprintln(os.Stderr, err) } - os.Exit(1) + exit(1) } + + exit(0) } func flagValue(f cli.Flag) reflect.Value { @@ -178,8 +192,8 @@ var placeholderString = regexp.MustCompile(`<.*?>`) func stringifyFlag(f cli.Flag) string { fv := flagValue(f) - usage := fv.FieldByName("Usage").String() - placeholder := placeholderString.FindString(usage) + usg := fv.FieldByName("Usage").String() + placeholder := placeholderString.FindString(usg) if placeholder == "" { switch f.(type) { case cli.BoolFlag, cli.BoolTFlag: @@ -187,5 +201,5 @@ func stringifyFlag(f cli.Flag) string { placeholder = "" } } - return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usage + return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usg } diff --git a/cmd/step-cloudkms-init/main.go b/cmd/step-cloudkms-init/main.go index 69573c5d..27dc82ad 100644 --- a/cmd/step-cloudkms-init/main.go +++ b/cmd/step-cloudkms-init/main.go @@ -27,13 +27,13 @@ func main() { var credentialsFile string var project, location, ring string var protectionLevelName string - var ssh bool + var enableSSH bool flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the Google's Cloud KMS credentials.") flag.StringVar(&project, "project", "", "Google Cloud Project ID.") flag.StringVar(&location, "location", "global", "Cloud KMS location name.") flag.StringVar(&ring, "ring", "pki", "Cloud KMS ring name.") flag.StringVar(&protectionLevelName, "protection-level", "SOFTWARE", "Protection level to use, SOFTWARE or HSM.") - flag.BoolVar(&ssh, "ssh", false, "Create SSH keys.") + flag.BoolVar(&enableSSH, "ssh", false, "Create SSH keys.") flag.Usage = usage flag.Parse() @@ -62,6 +62,9 @@ func main() { os.Exit(1) } + // Initialize windows terminal + ui.Init() + c, err := cloudkms.New(context.Background(), apiv1.Options{ Type: string(apiv1.CloudKMS), CredentialsFile: credentialsFile, @@ -74,16 +77,20 @@ func main() { fatal(err) } - if ssh { + if enableSSH { ui.Println() if err := createSSH(c, project, location, ring, protectionLevel); err != nil { fatal(err) } } + + // Reset windows terminal + ui.Reset() } func fatal(err error) { fmt.Fprintln(os.Stderr, err) + ui.Reset() os.Exit(1) } @@ -146,7 +153,7 @@ func createPKI(c *cloudkms.CloudKMS, project, location, keyRing string, protecti return err } - if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{ + if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: b, }), 0600); err != nil { @@ -190,7 +197,7 @@ func createPKI(c *cloudkms.CloudKMS, project, location, keyRing string, protecti return err } - if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{ + if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: b, }), 0600); err != nil { @@ -223,7 +230,7 @@ func createSSH(c *cloudkms.CloudKMS, project, location, keyRing string, protecti return err } - if err = fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { + if err := fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { return err } @@ -245,7 +252,7 @@ func createSSH(c *cloudkms.CloudKMS, project, location, keyRing string, protecti return err } - if err = fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { + if err := fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { return err } diff --git a/cmd/step-pkcs11-init/main.go b/cmd/step-pkcs11-init/main.go index c7ac9b0f..0db1e4d9 100644 --- a/cmd/step-pkcs11-init/main.go +++ b/cmd/step-pkcs11-init/main.go @@ -131,6 +131,9 @@ func main() { fatal(err) } + // Initialize windows terminal + ui.Init() + if u.Get("pin-value") == "" && u.Get("pin-source") == "" && c.Pin == "" { pin, err := ui.PromptPassword("What is the PKCS#11 PIN?") if err != nil { @@ -203,6 +206,9 @@ func main() { if err := createPKI(k, c); err != nil { fatalClose(err, k) } + + // Reset windows terminal + ui.Reset() } func fatal(err error) { @@ -211,6 +217,7 @@ func fatal(err error) { } else { fmt.Fprintln(os.Stderr, err) } + ui.Reset() os.Exit(1) } @@ -325,7 +332,7 @@ func createPKI(k kms.KeyManager, c Config) error { } if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts { - if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{ + if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{ Name: c.RootObject, Certificate: root, Extractable: c.Extractable, @@ -334,7 +341,7 @@ func createPKI(k kms.KeyManager, c Config) error { } } - if err = fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{ + if err := fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: b, }), 0600); err != nil { @@ -404,7 +411,7 @@ func createPKI(k kms.KeyManager, c Config) error { } if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts { - if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{ + if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{ Name: c.CrtObject, Certificate: intermediate, Extractable: c.Extractable, @@ -413,7 +420,7 @@ func createPKI(k kms.KeyManager, c Config) error { } } - if err = fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{ + if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: b, }), 0600); err != nil { diff --git a/cmd/step-yubikey-init/main.go b/cmd/step-yubikey-init/main.go index df7b9ea8..8b0ffab5 100644 --- a/cmd/step-yubikey-init/main.go +++ b/cmd/step-yubikey-init/main.go @@ -87,6 +87,9 @@ func main() { fatal(err) } + // Initialize windows terminal + ui.Init() + pin, err := ui.PromptPassword("What is the YubiKey PIN?") if err != nil { fatal(err) @@ -119,6 +122,9 @@ func main() { defer func() { _ = k.Close() }() + + // Reset windows terminal + ui.Reset() } func fatal(err error) { @@ -127,6 +133,7 @@ func fatal(err error) { } else { fmt.Fprintln(os.Stderr, err) } + ui.Reset() os.Exit(1) } @@ -221,7 +228,7 @@ func createPKI(k kms.KeyManager, c Config) error { } if cm, ok := k.(kms.CertificateManager); ok { - if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{ + if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{ Name: c.RootSlot, Certificate: root, }); err != nil { @@ -229,7 +236,7 @@ func createPKI(k kms.KeyManager, c Config) error { } } - if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{ + if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: b, }), 0600); err != nil { @@ -298,7 +305,7 @@ func createPKI(k kms.KeyManager, c Config) error { } if cm, ok := k.(kms.CertificateManager); ok { - if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{ + if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{ Name: c.CrtSlot, Certificate: intermediate, }); err != nil { @@ -306,7 +313,7 @@ func createPKI(k kms.KeyManager, c Config) error { } } - if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{ + if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: b, }), 0600); err != nil { diff --git a/commands/app.go b/commands/app.go index aff9d473..84232a6c 100644 --- a/commands/app.go +++ b/commands/app.go @@ -8,11 +8,13 @@ import ( "net" "net/http" "os" + "strings" "unicode" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/ca" + "github.com/smallstep/certificates/pki" "github.com/urfave/cli" "go.step.sm/cli-utils/errs" ) @@ -21,13 +23,26 @@ import ( var AppCommand = cli.Command{ Name: "start", Action: appAction, - UsageText: `**step-ca** -[**--password-file**=] [**--issuer-password-file**=] [**--resolver**=]`, + UsageText: `**step-ca** [**--password-file**=] +[**--ssh-host-password-file**=] [**--ssh-user-password-file**=] +[**--issuer-password-file**=] [**--resolver**=]`, Flags: []cli.Flag{ cli.StringFlag{ Name: "password-file", Usage: `path to the containing the password to decrypt the intermediate private key.`, + }, + cli.StringFlag{ + Name: "ssh-host-password-file", + Usage: `path to the containing the password to decrypt the +private key used to sign SSH host certificates. If the flag is not passed it +will default to --password-file.`, + }, + cli.StringFlag{ + Name: "ssh-user-password-file", + Usage: `path to the containing the password to decrypt the +private key used to sign SSH user certificates. If the flag is not passed it +will default to --password-file.`, }, cli.StringFlag{ Name: "issuer-password-file", @@ -38,14 +53,22 @@ certificate issuer private key used in the RA mode.`, Name: "resolver", Usage: "address of a DNS resolver to be used instead of the default.", }, + cli.StringFlag{ + Name: "token", + Usage: "token used to enable the linked ca.", + EnvVar: "STEP_CA_TOKEN", + }, }, } // AppAction is the action used when the top command runs. func appAction(ctx *cli.Context) error { passFile := ctx.String("password-file") + sshHostPassFile := ctx.String("ssh-host-password-file") + sshUserPassFile := ctx.String("ssh-user-password-file") issuerPassFile := ctx.String("issuer-password-file") resolver := ctx.String("resolver") + token := ctx.String("token") // If zero cmd line args show help, if >1 cmd line args show error. if ctx.NArg() == 0 { @@ -56,11 +79,23 @@ func appAction(ctx *cli.Context) error { } configFile := ctx.Args().Get(0) - config, err := authority.LoadConfiguration(configFile) + cfg, err := config.LoadConfiguration(configFile) if err != nil { fatal(err) } + if cfg.AuthorityConfig != nil { + if token == "" && strings.EqualFold(cfg.AuthorityConfig.DeploymentType, pki.LinkedDeployment.String()) { + return errors.New(`'step-ca' requires the '--token' flag for linked deploy type. + +To get a linked authority token: + 1. Log in or create a Certificate Manager account at ` + "\033[1mhttps://u.step.sm/linked\033[0m" + ` + 2. Add a new authority and select "Link a step-ca instance" + 3. Follow instructions in browser to start 'step-ca' using the '--token' flag +`) + } + } + var password []byte if passFile != "" { if password, err = ioutil.ReadFile(passFile); err != nil { @@ -69,6 +104,22 @@ func appAction(ctx *cli.Context) error { password = bytes.TrimRightFunc(password, unicode.IsSpace) } + var sshHostPassword []byte + if sshHostPassFile != "" { + if sshHostPassword, err = ioutil.ReadFile(sshHostPassFile); err != nil { + fatal(errors.Wrapf(err, "error reading %s", sshHostPassFile)) + } + sshHostPassword = bytes.TrimRightFunc(sshHostPassword, unicode.IsSpace) + } + + var sshUserPassword []byte + if sshUserPassFile != "" { + if sshUserPassword, err = ioutil.ReadFile(sshUserPassFile); err != nil { + fatal(errors.Wrapf(err, "error reading %s", sshUserPassFile)) + } + sshUserPassword = bytes.TrimRightFunc(sshUserPassword, unicode.IsSpace) + } + var issuerPassword []byte if issuerPassFile != "" { if issuerPassword, err = ioutil.ReadFile(issuerPassFile); err != nil { @@ -85,10 +136,13 @@ func appAction(ctx *cli.Context) error { } } - srv, err := ca.New(config, + srv, err := ca.New(cfg, ca.WithConfigFile(configFile), ca.WithPassword(password), - ca.WithIssuerPassword(issuerPassword)) + ca.WithSSHHostPassword(sshHostPassword), + ca.WithSSHUserPassword(sshUserPassword), + ca.WithIssuerPassword(issuerPassword), + ca.WithLinkedCAToken(token)) if err != nil { fatal(err) } diff --git a/commands/export.go b/commands/export.go new file mode 100644 index 00000000..5586f576 --- /dev/null +++ b/commands/export.go @@ -0,0 +1,113 @@ +package commands + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "unicode" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/config" + "github.com/urfave/cli" + "google.golang.org/protobuf/encoding/protojson" + + "go.step.sm/cli-utils/command" + "go.step.sm/cli-utils/errs" +) + +func init() { + command.Register(cli.Command{ + Name: "export", + Usage: "export the current configuration of step-ca", + UsageText: "**step-ca export** ", + Action: exportAction, + Description: `**step-ca export** exports the current configuration of step-ca. + +Note that neither the PKI password nor the certificate issuer password will be +included in the export file. + +## POSITIONAL ARGUMENTS + + +: The ca.json that contains the step-ca configuration. + +## EXAMPLES + +Export the current configuration: +''' +$ step-ca export $(step path)/config/ca.json +'''`, + Flags: []cli.Flag{ + cli.StringFlag{ + Name: "password-file", + Usage: `path to the containing the password to decrypt the +intermediate private key.`, + }, + cli.StringFlag{ + Name: "issuer-password-file", + Usage: `path to the containing the password to decrypt the +certificate issuer private key used in the RA mode.`, + }, + }, + }) +} + +func exportAction(ctx *cli.Context) error { + if err := errs.NumberOfArguments(ctx, 1); err != nil { + return err + } + + configFile := ctx.Args().Get(0) + passwordFile := ctx.String("password-file") + issuerPasswordFile := ctx.String("issuer-password-file") + + cfg, err := config.LoadConfiguration(configFile) + if err != nil { + return err + } + if err := cfg.Validate(); err != nil { + return err + } + + if passwordFile != "" { + b, err := ioutil.ReadFile(passwordFile) + if err != nil { + return errors.Wrapf(err, "error reading %s", passwordFile) + } + cfg.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace)) + } + if issuerPasswordFile != "" { + b, err := ioutil.ReadFile(issuerPasswordFile) + if err != nil { + return errors.Wrapf(err, "error reading %s", issuerPasswordFile) + } + if cfg.AuthorityConfig.CertificateIssuer != nil { + cfg.AuthorityConfig.CertificateIssuer.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace)) + } + } + + auth, err := authority.New(cfg) + if err != nil { + return err + } + + export, err := auth.Export() + if err != nil { + return err + } + + b, err := protojson.Marshal(export) + if err != nil { + return errors.Wrap(err, "error marshaling export") + } + + var buf bytes.Buffer + if err := json.Indent(&buf, b, "", "\t"); err != nil { + return errors.Wrap(err, "error indenting export") + } + + fmt.Println(buf.String()) + return nil +} diff --git a/commands/onboard.go b/commands/onboard.go index 13c32304..ebd468f5 100644 --- a/commands/onboard.go +++ b/commands/onboard.go @@ -9,7 +9,7 @@ import ( "os" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/ca" "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/pki" @@ -103,8 +103,8 @@ func onboardAction(ctx *cli.Context) error { return errors.Wrap(msg, "error receiving onboarding guide") } - var config onboardingConfiguration - if err := readJSON(res.Body, &config); err != nil { + var cfg onboardingConfiguration + if err := readJSON(res.Body, &cfg); err != nil { return errors.Wrap(err, "error unmarshaling response") } @@ -112,16 +112,16 @@ func onboardAction(ctx *cli.Context) error { if err != nil { return err } - config.password = []byte(password) + cfg.password = []byte(password) ui.Println("Initializing step-ca with the following configuration:") - ui.PrintSelected("Name", config.Name) - ui.PrintSelected("DNS", config.DNS) - ui.PrintSelected("Address", config.Address) + ui.PrintSelected("Name", cfg.Name) + ui.PrintSelected("DNS", cfg.DNS) + ui.PrintSelected("Address", cfg.Address) ui.PrintSelected("Password", password) ui.Println() - caConfig, fp, err := onboardPKI(config) + caConfig, fp, err := onboardPKI(cfg) if err != nil { return err } @@ -149,47 +149,55 @@ func onboardAction(ctx *cli.Context) error { ui.Println("Initialized!") ui.Println("Step CA is starting. Please return to the onboarding guide in your browser to continue.") - srv, err := ca.New(caConfig, ca.WithPassword(config.password)) + srv, err := ca.New(caConfig, ca.WithPassword(cfg.password)) if err != nil { fatal(err) } go ca.StopReloaderHandler(srv) - if err = srv.Run(); err != nil && err != http.ErrServerClosed { + if err := srv.Run(); err != nil && err != http.ErrServerClosed { fatal(err) } return nil } -func onboardPKI(config onboardingConfiguration) (*authority.Config, string, error) { +func onboardPKI(cfg onboardingConfiguration) (*config.Config, string, error) { + var opts = []pki.Option{ + pki.WithAddress(cfg.Address), + pki.WithDNSNames([]string{cfg.DNS}), + pki.WithProvisioner("admin"), + } + p, err := pki.New(apiv1.Options{ Type: apiv1.SoftCAS, IsCreator: true, - }) + }, opts...) if err != nil { return nil, "", err } - p.SetAddress(config.Address) - p.SetDNSNames([]string{config.DNS}) - + // Generate pki ui.Println("Generating root certificate...") - root, err := p.GenerateRootCertificate(config.Name, config.Name, config.Name, config.password) + root, err := p.GenerateRootCertificate(cfg.Name, cfg.Name, cfg.Name, cfg.password) if err != nil { return nil, "", err } ui.Println("Generating intermediate certificate...") - err = p.GenerateIntermediateCertificate(config.Name, config.Name, config.Name, root, config.password) + err = p.GenerateIntermediateCertificate(cfg.Name, cfg.Name, cfg.Name, root, cfg.password) if err != nil { return nil, "", err } + // Write files to disk + if err := p.WriteFiles(); err != nil { + return nil, "", err + } + // Generate provisioner - p.SetProvisioner("admin") ui.Println("Generating admin provisioner...") - if err = p.GenerateKeyPairs(config.password); err != nil { + if err := p.GenerateKeyPairs(cfg.password); err != nil { return nil, "", err } @@ -203,7 +211,7 @@ func onboardPKI(config onboardingConfiguration) (*authority.Config, string, erro if err != nil { return nil, "", errors.Wrapf(err, "error marshaling %s", p.GetCAConfigPath()) } - if err = fileutil.WriteFile(p.GetCAConfigPath(), b, 0666); err != nil { + if err := fileutil.WriteFile(p.GetCAConfigPath(), b, 0666); err != nil { return nil, "", errs.FileError(err, p.GetCAConfigPath()) } diff --git a/cosign.pub b/cosign.pub new file mode 100644 index 00000000..9a0b42be --- /dev/null +++ b/cosign.pub @@ -0,0 +1,4 @@ +-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEs+6THbAiXx4bja5ARQFNZmPwZjlD +GRvt5H+9ZFDhrcFPR1E7eB2rt1B/DhobANdHGKjvEBZEf0v4X/7S+SHrIw== +-----END PUBLIC KEY----- diff --git a/db/db_test.go b/db/db_test.go index 7efc623e..40f59215 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -144,15 +144,15 @@ func TestUseToken(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - ok, err := tc.db.UseToken(tc.id, tc.tok) - if err != nil { + switch ok, err := tc.db.UseToken(tc.id, tc.tok); { + case err != nil: if assert.NotNil(t, tc.want.err) { assert.HasPrefix(t, err.Error(), tc.want.err.Error()) } assert.False(t, ok) - } else if ok { + case ok: assert.True(t, tc.want.ok) - } else { + default: assert.False(t, tc.want.ok) } }) diff --git a/docker/Dockerfile.step-ca b/docker/Dockerfile.step-ca index 4a1908d6..9363b6ae 100644 --- a/docker/Dockerfile.step-ca +++ b/docker/Dockerfile.step-ca @@ -24,4 +24,7 @@ VOLUME ["/home/step"] STOPSIGNAL SIGTERM HEALTHCHECK CMD step ca health 2>/dev/null | grep "^ok" >/dev/null +COPY docker/entrypoint.sh /entrypoint.sh + +ENTRYPOINT ["/bin/bash", "/entrypoint.sh"] CMD exec /usr/local/bin/step-ca --password-file $PWDPATH $CONFIGPATH diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh new file mode 100644 index 00000000..1f48c028 --- /dev/null +++ b/docker/entrypoint.sh @@ -0,0 +1,60 @@ +#!/bin/bash +set -eo pipefail + +# Paraphrased from: +# https://github.com/influxdata/influxdata-docker/blob/0d341f18067c4652dfa8df7dcb24d69bf707363d/influxdb/2.0/entrypoint.sh +# (a repo with no LICENSE.md) + +export STEPPATH=$(step path) + +# List of env vars required for step ca init +declare -ra REQUIRED_INIT_VARS=(DOCKER_STEPCA_INIT_NAME DOCKER_STEPCA_INIT_DNS_NAMES) + +# Ensure all env vars required to run step ca init are set. +function init_if_possible () { + local missing_vars=0 + for var in "${REQUIRED_INIT_VARS[@]}"; do + if [ -z "${!var}" ]; then + missing_vars=1 + fi + done + if [ ${missing_vars} = 1 ]; then + >&2 echo "there is no ca.json config file; please run step ca init, or provide config parameters via DOCKER_STEPCA_INIT_ vars" + else + step_ca_init "${@}" + fi +} + +function generate_password () { + set +o pipefail + < /dev/urandom tr -dc A-Za-z0-9 | head -c40 + echo + set -o pipefail +} + +# Initialize a CA if not already initialized +function step_ca_init () { + local -a setup_args=( + --name "${DOCKER_STEPCA_INIT_NAME}" + --dns "${DOCKER_STEPCA_INIT_DNS_NAMES}" + --provisioner "${DOCKER_STEPCA_INIT_PROVISIONER_NAME:-admin}" + --password-file "${STEPPATH}/password" + --address ":9000" + ) + if [ -n "${DOCKER_STEPCA_INIT_PASSWORD}" ]; then + echo "${DOCKER_STEPCA_INIT_PASSWORD}" > "${STEPPATH}/password" + else + generate_password > "${STEPPATH}/password" + fi + if [ -n "${DOCKER_STEPCA_INIT_SSH}" ]; then + setup_args=("${setup_args[@]}" --ssh) + fi + step ca init "${setup_args[@]}" + mv $STEPPATH/password $PWDPATH +} + +if [ ! -f "${STEPPATH}/config/ca.json" ]; then + init_if_possible +fi + +exec "${@}" diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 93749026..35f75159 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -7,12 +7,20 @@ to manage issues, etc. ## Table of Contents -* [Building From Source](#building-from-source) -* [Asking Support Questions](#asking-support-questions) -* [Reporting Issues](#reporting-issues) -* [Submitting Patches](#submitting-patches) - * [Code Contribution Guidelines](#code-contribution-guidelines) - * [Git Commit Message Guidelines](#git-commit-message-guidelines) +- [Contributing to `step certificates`](#contributing-to-step-certificates) + - [Table of Contents](#table-of-contents) + - [Building From Source](#building-from-source) + - [Build a standard `step-ca`](#build-a-standard-step-ca) + - [Build `step-ca` using CGO](#build-step-ca-using-cgo) + - [The CGO build enables PKCS #11 and YubiKey PIV support](#the-cgo-build-enables-pkcs-11-and-yubikey-piv-support) + - [1. Install PCSC support](#1-install-pcsc-support) + - [2. Build `step-ca`](#2-build-step-ca) + - [Asking Support Questions](#asking-support-questions) + - [Reporting Issues](#reporting-issues) + - [Code Contribution](#code-contribution) + - [Submitting Patches](#submitting-patches) + - [Code Contribution Guidelines](#code-contribution-guidelines) + - [Git Commit Message Guidelines](#git-commit-message-guidelines) ## Building From Source @@ -73,7 +81,7 @@ When the build is complete, you will find binaries in `bin/`. ## Asking Support Questions -Feel free to post a question on our [GitHub Discussions](https://github.com/smallstep/certificates/discussions) page, or find us on [Gitter](https://gitter.im/smallstep/community). +Feel free to post a question on our [GitHub Discussions](https://github.com/smallstep/certificates/discussions) page, or find us on [Discord](https://bit.ly/step-discord). ## Reporting Issues diff --git a/docs/provisioners.md b/docs/provisioners.md index 7ee9af50..18010f88 100644 --- a/docs/provisioners.md +++ b/docs/provisioners.md @@ -1,7 +1,7 @@ # Provisioners > Note: The canonical documentation for `step-ca` provisioners now lives at -> https://smallstep.com/docs/step-ca/configuration#provisioners. Documentation +> https://smallstep.com/docs/step-ca/provisioners. Documentation > found on this page may be out of date. Provisioners are people or code that are registered with the CA and authorized diff --git a/docs/revocation.md b/docs/revocation.md index e994940d..4f3a7d5e 100644 --- a/docs/revocation.md +++ b/docs/revocation.md @@ -202,7 +202,8 @@ through an example. [Use TLS Everywhere](https://smallstep.com/blog/use-tls.html) and let us know what you think of our tools. Get in touch over [Twitter](twitter.com/smallsteplabs) or through our -[GitHub Discussions](https://github.com/smallstep/certificates/discussions) to chat with us in real time. +[GitHub Discussions](https://github.com/smallstep/certificates/discussions) to find answers to frequently asked questions. +[Discord](https://bit.ly/step-discord) to chat with us in real time. ## Further Reading diff --git a/examples/basic-client/client.go b/examples/basic-client/client.go index db6092bf..42358ac8 100644 --- a/examples/basic-client/client.go +++ b/examples/basic-client/client.go @@ -116,7 +116,6 @@ func main() { Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, DualStack: true, }).DialContext, MaxIdleConns: 100, diff --git a/go.mod b/go.mod index 74aa56e9..48299d6a 100644 --- a/go.mod +++ b/go.mod @@ -1,41 +1,54 @@ module github.com/smallstep/certificates -go 1.14 +go 1.15 require ( - cloud.google.com/go v0.70.0 - github.com/Masterminds/sprig/v3 v3.1.0 + cloud.google.com/go v0.83.0 + github.com/Azure/azure-sdk-for-go v58.0.0+incompatible + github.com/Azure/go-autorest/autorest v0.11.17 + github.com/Azure/go-autorest/autorest/azure/auth v0.5.8 + github.com/Azure/go-autorest/autorest/date v0.3.0 + github.com/Azure/go-autorest/autorest/to v0.4.0 // indirect + github.com/Azure/go-autorest/autorest/validation v0.3.1 // indirect + github.com/Masterminds/sprig/v3 v3.2.2 github.com/ThalesIgnite/crypto11 v1.2.4 github.com/aws/aws-sdk-go v1.30.29 + github.com/dgraph-io/ristretto v0.0.4-0.20200906165740-41ebdbffecfd // indirect github.com/go-chi/chi v4.0.2+incompatible github.com/go-kit/kit v0.10.0 // indirect github.com/go-piv/piv-go v1.7.0 - github.com/golang/mock v1.4.4 - github.com/google/uuid v1.1.2 + github.com/golang/mock v1.6.0 + github.com/google/uuid v1.3.0 github.com/googleapis/gax-go/v2 v2.0.5 github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect - github.com/micromdm/scep/v2 v2.0.0 - github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f + github.com/mattn/go-colorable v0.1.8 // indirect + github.com/mattn/go-isatty v0.0.13 // indirect + github.com/micromdm/scep/v2 v2.1.0 github.com/newrelic/go-agent v2.15.0+incompatible github.com/pkg/errors v0.9.1 github.com/rs/xid v1.2.1 github.com/sirupsen/logrus v1.4.2 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 - github.com/smallstep/nosql v0.3.6 + github.com/smallstep/nosql v0.3.8 github.com/urfave/cli v1.22.4 - go.mozilla.org/pkcs7 v0.0.0-20200128120323-432b2356ecb1 - go.step.sm/cli-utils v0.2.0 - go.step.sm/crypto v0.8.3 - golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a - golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 - google.golang.org/api v0.33.0 - google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154 - google.golang.org/grpc v1.32.0 - google.golang.org/protobuf v1.25.0 - gopkg.in/square/go-jose.v2 v2.5.1 + go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 + go.step.sm/cli-utils v0.6.1 + go.step.sm/crypto v0.13.0 + go.step.sm/linkedca v0.7.0 + golang.org/x/crypto v0.0.0-20210915214749-c084706c2272 + golang.org/x/net v0.0.0-20210913180222-943fd674d43e + google.golang.org/api v0.47.0 + google.golang.org/genproto v0.0.0-20210719143636-1d5a45f8e492 + google.golang.org/grpc v1.39.0 + google.golang.org/protobuf v1.27.1 + gopkg.in/square/go-jose.v2 v2.6.0 ) +// avoid license conflict from juju/ansiterm until https://github.com/manifoldco/promptui/pull/181 +// is merged or other dependency in path currently in violation fixes compliance +replace github.com/manifoldco/promptui => github.com/nguyer/promptui v0.8.1-0.20210517132806-70ccd4709797 + // replace github.com/smallstep/nosql => ../nosql // replace go.step.sm/crypto => ../crypto - -replace go.mozilla.org/pkcs7 v0.0.0-20200128120323-432b2356ecb1 => github.com/omorsi/pkcs7 v0.0.0-20210217142924-a7b80a2a8568 +// replace go.step.sm/cli-utils => ../cli-utils +// replace go.step.sm/linkedca => ../linkedca diff --git a/go.sum b/go.sum index 60c37a32..252832ea 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,13 @@ cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKV cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= -cloud.google.com/go v0.70.0 h1:ujhG1RejZYi+HYfJNlgBh3j/bVKD8DewM7AkJ5UPyBc= -cloud.google.com/go v0.70.0/go.mod h1:/UTKYRQTWjVnSe7nGvoSzxEFUELzSI/yAYd0JQT6cRo= +cloud.google.com/go v0.72.0/go.mod h1:M+5Vjvlc2wnp6tjzE102Dw08nGShTscUx2nZMufOKPI= +cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmWk= +cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg= +cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8= +cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0= +cloud.google.com/go v0.83.0 h1:bAMqZidYkmIsUqe6PtkEPT7Q+vfizScn+jfNA6jwK9c= +cloud.google.com/go v0.83.0/go.mod h1:Z7MJUsANfY0pYPdw0lbnivPx4/vhy/e2FEkSkF7vAVY= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= @@ -35,36 +40,79 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIoKjsnZuH8vjyaysT/ses3EvZeaV/1UkF2M= github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= +github.com/Azure/azure-sdk-for-go v58.0.0+incompatible h1:Cw16jiP4dI+CK761aq44ol4RV5dUiIIXky1+EKpoiVM= +github.com/Azure/azure-sdk-for-go v58.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= +github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= +github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= +github.com/Azure/go-autorest/autorest v0.11.17 h1:2zCdHwNgRH+St1J+ZMf66xI8aLr/5KMy+wWLH97zwYM= +github.com/Azure/go-autorest/autorest v0.11.17/go.mod h1:eipySxLmqSyC5s5k1CLupqet0PSENBEDP93LQ9a8QYw= +github.com/Azure/go-autorest/autorest/adal v0.9.5/go.mod h1:B7KF7jKIeC9Mct5spmyCB/A8CG/sEz1vwIRGv/bbw7A= +github.com/Azure/go-autorest/autorest/adal v0.9.11 h1:L4/pmq7poLdsy41Bj1FayKvBhayuWRYkx9HU5i4Ybl0= +github.com/Azure/go-autorest/autorest/adal v0.9.11/go.mod h1:nBKAnTomx8gDtl+3ZCJv2v0KACFHWTB2drffI1B68Pk= +github.com/Azure/go-autorest/autorest/azure/auth v0.5.8 h1:TzPg6B6fTZ0G1zBf3T54aI7p3cAT6u//TOXGPmFMOXg= +github.com/Azure/go-autorest/autorest/azure/auth v0.5.8/go.mod h1:kxyKZTSfKh8OVFWPAgOgQ/frrJgeYQJPyR5fLFmXko4= +github.com/Azure/go-autorest/autorest/azure/cli v0.4.2 h1:dMOmEJfkLKW/7JsokJqkyoYSgmR08hi9KrhjZb+JALY= +github.com/Azure/go-autorest/autorest/azure/cli v0.4.2/go.mod h1:7qkJkT+j6b+hIpzMOwPChJhTqS8VbsqqgULzMNRugoM= +github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8KY+LPI6wiWrP/myHw= +github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74= +github.com/Azure/go-autorest/autorest/mocks v0.4.1 h1:K0laFcLE6VLTOwNgSxaGbUcLPuGXlNkbVvq4cW4nIHk= +github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= +github.com/Azure/go-autorest/autorest/to v0.4.0 h1:oXVqrxakqqV1UZdSazDOPOLvOIz+XA683u8EctwboHk= +github.com/Azure/go-autorest/autorest/to v0.4.0/go.mod h1:fE8iZBn7LQR7zH/9XU2NcPR4o9jEImooCeWJcYV/zLE= +github.com/Azure/go-autorest/autorest/validation v0.3.1 h1:AgyqjAd94fwNAoTjl/WQXg4VvFeRFpO+UhNyRXqF1ac= +github.com/Azure/go-autorest/autorest/validation v0.3.1/go.mod h1:yhLgjC0Wda5DYXl6JAsWyUe4KVNffhoDhG0zVzUMo3E= +github.com/Azure/go-autorest/logger v0.2.0 h1:e4RVHVZKC5p6UANLJHkM4OfR1UKZPj8Wt8Pcx+3oqrE= +github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= +github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= +github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= -github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ= -github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= -github.com/Masterminds/goutils v1.1.0 h1:zukEsf/1JZwCMgHiK3GZftabmxiCw4apj3a28RPBiVg= github.com/Masterminds/goutils v1.1.0/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= -github.com/Masterminds/semver/v3 v3.1.0 h1:Y2lUDsFKVRSYGojLJ1yLxSXdMmMYTYls0rCvoqmMUQk= +github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= +github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= github.com/Masterminds/semver/v3 v3.1.0/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= -github.com/Masterminds/sprig/v3 v3.1.0 h1:j7GpgZ7PdFqNsmncycTHsLmVPf5/3wJtlgW9TNDYD9Y= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Masterminds/sprig/v3 v3.1.0/go.mod h1:ONGMf7UfYGAbMXCZmQLy8x3lCDIPrEZE/rU8pmrbihA= +github.com/Masterminds/sprig/v3 v3.2.2 h1:17jRggJu518dr3QaafizSXOjKYp94wKfABxUmyxvxX8= +github.com/Masterminds/sprig/v3 v3.2.2/go.mod h1:UoaO7Yp8KlPnJIYWTFkMaqPUYKTfGFPhxNuwnnxkKlk= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= +github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= +github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/ThalesIgnite/crypto11 v1.2.4 h1:3MebRK/U0mA2SmSthXAIZAdUA9w8+ZuKem2O6HuR1f8= +github.com/ThalesIgnite/crypto11 v1.2.4 h1:3MebRK/U0mA2SmSthXAIZAdUA9w8+ZuKem2O6HuR1f8= +github.com/ThalesIgnite/crypto11 v1.2.4/go.mod h1:ILDKtnCKiQ7zRoNxcp36Y1ZR8LBPmR2E23+wTQe/MlE= github.com/ThalesIgnite/crypto11 v1.2.4/go.mod h1:ILDKtnCKiQ7zRoNxcp36Y1ZR8LBPmR2E23+wTQe/MlE= github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= +github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= +github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= +github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= @@ -76,6 +124,7 @@ github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+Ce github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= +github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -91,6 +140,9 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= @@ -109,15 +161,19 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/badger v1.6.2 h1:mNw0qs90GVgGGWylh0umH5iag1j6n/PeJtNvL6KY/x8= github.com/dgraph-io/badger v1.6.2/go.mod h1:JW2yswe3V058sS0kZ2h/AXeDSqFjxnZcRrVH//y2UQE= -github.com/dgraph-io/badger/v2 v2.0.1-rc1.0.20201003150343-5d1bab4fc658 h1:/WBjuutuivOA02gpDtrvrWKw01ugkyt3QnimB7enbtI= -github.com/dgraph-io/badger/v2 v2.0.1-rc1.0.20201003150343-5d1bab4fc658/go.mod h1:2uGEvGm+JSDLd5UAaKIFSbXDcYyeH0fWJP4N2HMMYMI= +github.com/dgraph-io/badger/v2 v2.2007.4 h1:TRWBQg8UrlUhaFdco01nO2uXwzKS7zd+HVdwV/GHc4o= +github.com/dgraph-io/badger/v2 v2.2007.4/go.mod h1:vSw/ax2qojzbN6eXHIx6KPKtCSHJN/Uz0X0VPruTIhk= github.com/dgraph-io/ristretto v0.0.2/go.mod h1:KPxhHT9ZxKefz+PCeOGsrHpl1qZ7i70dGTu2u+Ahh6E= +github.com/dgraph-io/ristretto v0.0.3-0.20200630154024-f66de99634de/go.mod h1:KPxhHT9ZxKefz+PCeOGsrHpl1qZ7i70dGTu2u+Ahh6E= github.com/dgraph-io/ristretto v0.0.4-0.20200906165740-41ebdbffecfd h1:KoJOtZf+6wpQaDTuOWGuo61GxcPBIfhwRxRTaTWGCTc= github.com/dgraph-io/ristretto v0.0.4-0.20200906165740-41ebdbffecfd/go.mod h1:YylP9MpCYGVZQrly/j/diqcdUetCRRePeBB0c2VGXsA= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dimchansky/utfbom v1.1.0/go.mod h1:rO41eb7gLfo8SF1jd9F8HplJm1Fewwi4mQvIirEdv+8= +github.com/dimchansky/utfbom v1.1.1 h1:vV6w1AhK4VMnhBno/TPVCoK9U/LP0PkLCS9tbxHdi/U= +github.com/dimchansky/utfbom v1.1.1/go.mod h1:SxdoEBH5qIqFocHMyGOXVAybYJdr71b1Q/j0mACtrfE= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -129,8 +185,14 @@ github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4s github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= +github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= +github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= @@ -173,8 +235,10 @@ github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFU github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -188,12 +252,14 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.1/go.mod h1:DopwsBzvsk0Fs44TXzsVbJyPhcCPeIwnvohx4u74HPM= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw= -github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -203,11 +269,17 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= +github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= +github.com/google/martian/v3 v3.2.1/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -215,12 +287,17 @@ github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20201009210932-67992a1a5a35/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= @@ -235,6 +312,7 @@ github.com/groob/finalizer v0.0.0-20170707115354-4c2ed49aabda/go.mod h1:MyndkAZd github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE= github.com/hashicorp/consul/sdk v0.3.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -257,13 +335,16 @@ github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0m github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/huandu/xstrings v1.3.1 h1:4jgBlKK6tLKFvO8u5pmYjG91cqytmDCDvGh7ECVFfFs= github.com/huandu/xstrings v1.3.1/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/huandu/xstrings v1.3.2 h1:L18LIDzqlW6xN2rEkpdV8+oL/IXWJ1APd+vsdYy4Wdw= +github.com/huandu/xstrings v1.3.2/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/imdario/mergo v0.3.8 h1:CGgOkSJeqMRmt0D9XLWExdT4m4F1vd3FV3VPt+0VxkQ= github.com/imdario/mergo v0.3.8/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= +github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= +github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU= +github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= @@ -276,11 +357,11 @@ github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= -github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU= -github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.12.3 h1:G5AfA94pHPysR56qqrkO2pxEexdDzrpFJ6yt/VqWxVU= +github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -294,39 +375,43 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lunixbochs/vtclean v0.0.0-20180621232353-2d01aacdc34a/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= -github.com/lunixbochs/vtclean v1.0.0 h1:xu2sLAri4lGiovBDQKxl5mrXyESr3gUr5m5SM5+LVb8= -github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/manifoldco/promptui v0.8.0 h1:R95mMF+McvXZQ7j1g8ucVZE1gLP3Sv6j9vlF9kyRqQo= -github.com/manifoldco/promptui v0.8.0/go.mod h1:n4zTdgP0vr0S3w7/O/g98U+e0gwLScEXGwov2nIKuGQ= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= -github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= +github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.13 h1:qdl+GuBjcsKKDco5BsxPJlId98mSWNKqYA+Co0SC1yA= +github.com/mattn/go-isatty v0.0.13/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/micromdm/scep/v2 v2.0.0 h1:cRzcY0S5QX+0+J+7YC4P2uZSnfMup8S8zJu/bLFgOkA= -github.com/micromdm/scep/v2 v2.0.0/go.mod h1:ouaDs5tcjOjdHD/h8BGaQsWE87MUnQ/wMTMgfMMIpPc= +github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= +github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/micromdm/scep/v2 v2.1.0 h1:2fS9Rla7qRR266hvUoEauBJ7J6FhgssEiq2OkSKXmaU= +github.com/micromdm/scep/v2 v2.1.0/go.mod h1:BkF7TkPPhmgJAMtHfP+sFTKXmgzNJgLQlvvGoOExBcc= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f h1:eVB9ELsoq5ouItQBr5Tj334bhPJG/MX+m7rTchmzVUQ= github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= -github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= +github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= +github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY= github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= +github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= +github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -341,11 +426,11 @@ github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxzi github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/newrelic/go-agent v2.15.0+incompatible h1:IB0Fy+dClpBq9aEoIrLyQXzU34JyI1xVTanPLB/+jvU= github.com/newrelic/go-agent v2.15.0+incompatible/go.mod h1:a8Fv1b/fYhFSReoTU6HDkTYIMZeSVNffmoS726Y0LzQ= +github.com/nguyer/promptui v0.8.1-0.20210517132806-70ccd4709797 h1:unCiBzwNjcuVbP3bgM76z0ORyIuI4sspop1qhkQJ044= +github.com/nguyer/promptui v0.8.1-0.20210517132806-70ccd4709797/go.mod h1:CBMXL3a2sC3Q8TjpLcQt8w/3aQ23VSy6r7UFeCG6phA= github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= -github.com/omorsi/pkcs7 v0.0.0-20210217142924-a7b80a2a8568 h1:+MPqEswjYiS0S1FCTg8MIhMBMzxiVQ94rooFwvPPiWk= -github.com/omorsi/pkcs7 v0.0.0-20210217142924-a7b80a2a8568/go.mod h1:SNgMg+EgDFwmvSmLRTNKC5fegJjB7v23qTQ0XLGUNHk= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= @@ -391,6 +476,7 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= @@ -399,10 +485,10 @@ github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= -github.com/samfoo/ansi v0.0.0-20160124022901-b6bd2ded7189 h1:CmSpbxmewNQbzqztaY0bke1qzHhyNyC29wYgh17Gxfo= -github.com/samfoo/ansi v0.0.0-20160124022901-b6bd2ded7189/go.mod h1:UUwuHEJ9zkkPDxspIHOa59PUeSkGFljESGzbxntLmIg= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= @@ -411,8 +497,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= -github.com/smallstep/nosql v0.3.6 h1:cq6a3NwjFJxkVlWU1T4qGskcfEXr0fO1WqQrraDO1Po= -github.com/smallstep/nosql v0.3.6/go.mod h1:h1zC/Z54uNHc8euquLED4qJNCrMHd3nytA141ZZh4qQ= +github.com/smallstep/nosql v0.3.8 h1:1/EWUbbEdz9ai0g9Fd09VekVjtxp+5+gIHpV2PdwW3o= +github.com/smallstep/nosql v0.3.8/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= @@ -422,8 +508,9 @@ github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0b github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cast v1.4.1 h1:s0hze+J0196ZfEMTs80N7UlFt0BDuQ7Q+JDnHiMWKdA= +github.com/spf13/cast v1.4.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= @@ -438,8 +525,9 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/thales-e-security/pool v0.0.2 h1:RAPs4q2EbWsTit6tpzuvTFlgFRJ3S8Evf5gtvVDbmPg= github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpupY8mv0Phz0gjhU= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= @@ -455,10 +543,14 @@ github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= +go.mozilla.org/pkcs7 v0.0.0-20210730143726-725912489c62/go.mod h1:SNgMg+EgDFwmvSmLRTNKC5fegJjB7v23qTQ0XLGUNHk= +go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 h1:CCriYyAfq1Br1aIYettdHZTy8mBTIPo7We18TuO/bak= +go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352/go.mod h1:SNgMg+EgDFwmvSmLRTNKC5fegJjB7v23qTQ0XLGUNHk= go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= @@ -466,13 +558,17 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.5 h1:dntmOdLpSpHlVqbW5Eay97DelsZHe+55D+xC6i0dDS0= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= -go.step.sm/cli-utils v0.2.0 h1:hpVu9+6dpv/7/Bd8nGJFc3V+gQ+TciSJRTu9TavDUQ4= -go.step.sm/cli-utils v0.2.0/go.mod h1:+t4qCp5NO+080DdGkJxEh3xL5S4TcYC2JTPLMM72b6Y= -go.step.sm/crypto v0.6.1/go.mod h1:AKS4yMZVZD4EGjpSkY4eibuMenrvKCscb+BpWMet8c0= -go.step.sm/crypto v0.8.3 h1:TO/OPlaUrYXhs8srGEFNyL6OWVQvRmEPCUONNnQUuEM= -go.step.sm/crypto v0.8.3/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= +go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= +go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= +go.step.sm/cli-utils v0.6.1 h1:v31ctEh/BFPGU067fF9Y8u2EIg6LRldUbN2dc/+u/V8= +go.step.sm/cli-utils v0.6.1/go.mod h1:stgyXHHHi9KwcR86sgzDdFC6e/tAmpF4NbqwSK7q/GM= +go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= +go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8= +go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg= +go.step.sm/linkedca v0.7.0 h1:ydYigs0CgLFkPGjOO4KJcAcAWbuPP8ECF1IsyHdftYc= +go.step.sm/linkedca v0.7.0/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= @@ -491,8 +587,10 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200414173820-0848c9571904/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc= +golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/crypto v0.0.0-20210915214749-c084706c2272 h1:3erb+vDS8lU1sxfDHF4/hhWyaXnhIaO+7RgL4fDZORA= +golang.org/x/crypto v0.0.0-20210915214749-c084706c2272/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -515,6 +613,8 @@ golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= @@ -523,6 +623,9 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20170726083632-f5079bd7f6f7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -558,17 +661,30 @@ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8= +golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 h1:ld7aEMNHoBnnDAX15v1T6z31v8HwR2A9FYOuAhWqkwc= golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20201109201403-9fd604954f58/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c h1:pkQiBZBvdos9qq4wBAHqlzuZHEXo07pqV06ef90u1WI= +golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -577,6 +693,9 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20170728174421-0f826bdd13b5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -605,6 +724,7 @@ golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -621,16 +741,33 @@ golang.org/x/sys v0.0.0-20200828194041-157a740278f4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c h1:VwygUrnw9jn88c4u8GD3rZQbqrP/tgas88tPUbBxQrk= +golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 h1:7ZDGnxgHAMw7thfC5bEos0RDAccZKxioiWBhfIe+tvw= +golang.org/x/sys v0.0.0-20210915083310-ed5796bab164/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -682,7 +819,14 @@ golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE= -golang.org/x/tools v0.0.0-20201017001424-6003fad69a88/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= +golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -705,16 +849,22 @@ google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0M google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= -google.golang.org/api v0.33.0 h1:+gL0XvACeMIvpwLZ5rQZzLn5cwOsgg8dIcfJ2SYfBVw= -google.golang.org/api v0.33.0/go.mod h1:/XrVsuzM0rZmrsbjJutiuftIzeuTQcEeaYcSk/mQ1dg= +google.golang.org/api v0.35.0/go.mod h1:/XrVsuzM0rZmrsbjJutiuftIzeuTQcEeaYcSk/mQ1dg= +google.golang.org/api v0.36.0/go.mod h1:+z5ficQTmoYpPn8LCUNVpK5I7hwkpjbcgqA7I34qYtE= +google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjRCQ8= +google.golang.org/api v0.41.0/go.mod h1:RkxM5lITDfTzmyKFPt+wGrCJbVfniCr2ool8kTBzRTU= +google.golang.org/api v0.43.0/go.mod h1:nQsDGjRXMo4lvh5hP0TKqF244gqhGcr/YSIykhUk/94= +google.golang.org/api v0.47.0 h1:sQLWZQvP6jPGIP4JGPkJu4zHswrv81iobiyszr3b/0I= +google.golang.org/api v0.47.0/go.mod h1:Wbvgpq1HddcWVtzsVLyfLp8lDg6AA241LmgIL59tHXo= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.6 h1:lMO5rYAqUxkmaj76jAkRUvt5JZgFymx/+Q5Mzfivuhc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -739,6 +889,7 @@ google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfG google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= @@ -746,8 +897,19 @@ google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154 h1:bFFRpT+e8JJVY7lMMfvezL1ZIwqiwmPl2bsE2yx4HqM= -google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20210222152913-aa3ee6e6a81c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20210303154014-9728d6b83eeb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= +google.golang.org/genproto v0.0.0-20210513213006-bf773b8c8384/go.mod h1:P3QM42oQyzQSnHPnZ/vqoCdDmzH28fzWByN9asMeM8A= +google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= +google.golang.org/genproto v0.0.0-20210719143636-1d5a45f8e492 h1:7yQQsvnwjfEahbNNEKcBHv3mR+HnB1ctGY/z1JXzx8M= +google.golang.org/genproto v0.0.0-20210719143636-1d5a45f8e492/go.mod h1:ob2IJxKrgPT52GcgX759i1sleT07tiKowYBGbczaW48= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= @@ -766,8 +928,18 @@ google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3Iji google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.32.0 h1:zWTV+LMdc3kaiJMSTOFz2UgSBgx8RNQoTGiZu3fR9S0= -google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= +google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= +google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= +google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/grpc v1.37.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= +google.golang.org/grpc v1.37.1/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= +google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= +google.golang.org/grpc v1.39.0 h1:Klz8I9kdtkIN6EpHHUOMLCYhTn/2WAe5a0s1hcBkdTI= +google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -777,8 +949,11 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= -google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -789,15 +964,20 @@ gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= +gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/kms/apiv1/options.go b/kms/apiv1/options.go index 7cc7f748..79b07a60 100644 --- a/kms/apiv1/options.go +++ b/kms/apiv1/options.go @@ -29,6 +29,12 @@ type CertificateManager interface { StoreCertificate(req *StoreCertificateRequest) error } +// ValidateName is an interface that KeyManager can implement to validate a +// given name or URI. +type NameValidator interface { + ValidateName(s string) error +} + // ErrNotImplemented is the type of error returned if an operation is not // implemented. type ErrNotImplemented struct { @@ -73,6 +79,8 @@ const ( YubiKey Type = "yubikey" // SSHAgentKMS is a KMS implementation using ssh-agent to access keys. SSHAgentKMS Type = "sshagentkms" + // AzureKMS is a KMS implementation using Azure Key Vault. + AzureKMS Type = "azurekms" ) // Options are the KMS options. They represent the kms object in the ca.json. @@ -81,18 +89,18 @@ type Options struct { Type string `json:"type"` // Path to the credentials file used in CloudKMS and AmazonKMS. - CredentialsFile string `json:"credentialsFile"` + CredentialsFile string `json:"credentialsFile,omitempty"` // URI is based on the PKCS #11 URI Scheme defined in // https://tools.ietf.org/html/rfc7512 and represents the configuration used // to connect to the KMS. // // Used by: pkcs11 - URI string `json:"uri"` + URI string `json:"uri,omitempty"` // Pin used to access the PKCS11 module. It can be defined in the URI using // the pin-value or pin-source properties. - Pin string `json:"pin"` + Pin string `json:"pin,omitempty"` // ManagementKey used in YubiKeys. Default management key is the hexadecimal // string 010203040506070801020304050607080102030405060708: @@ -101,13 +109,13 @@ type Options struct { // 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // } - ManagementKey string `json:"managementKey"` + ManagementKey string `json:"managementKey,omitempty"` // Region to use in AmazonKMS. - Region string `json:"region"` + Region string `json:"region,omitempty"` // Profile to use in AmazonKMS. - Profile string `json:"profile"` + Profile string `json:"profile,omitempty"` } // Validate checks the fields in Options. @@ -118,8 +126,9 @@ func (o *Options) Validate() error { switch Type(strings.ToLower(o.Type)) { case DefaultKMS, SoftKMS: // Go crypto based kms. - case CloudKMS, AmazonKMS, SSHAgentKMS: // Cloud based kms. + case CloudKMS, AmazonKMS, AzureKMS: // Cloud based kms. case YubiKey, PKCS11: // Hardware based kms. + case SSHAgentKMS: // Others default: return errors.Errorf("unsupported kms type %s", o.Type) } diff --git a/kms/azurekms/internal/mock/key_vault_client.go b/kms/azurekms/internal/mock/key_vault_client.go new file mode 100644 index 00000000..42bd55fd --- /dev/null +++ b/kms/azurekms/internal/mock/key_vault_client.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/smallstep/certificates/kms/azurekms (interfaces: KeyVaultClient) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + keyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// KeyVaultClient is a mock of KeyVaultClient interface +type KeyVaultClient struct { + ctrl *gomock.Controller + recorder *KeyVaultClientMockRecorder +} + +// KeyVaultClientMockRecorder is the mock recorder for KeyVaultClient +type KeyVaultClientMockRecorder struct { + mock *KeyVaultClient +} + +// NewKeyVaultClient creates a new mock instance +func NewKeyVaultClient(ctrl *gomock.Controller) *KeyVaultClient { + mock := &KeyVaultClient{ctrl: ctrl} + mock.recorder = &KeyVaultClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *KeyVaultClient) EXPECT() *KeyVaultClientMockRecorder { + return m.recorder +} + +// CreateKey mocks base method +func (m *KeyVaultClient) CreateKey(arg0 context.Context, arg1, arg2 string, arg3 keyvault.KeyCreateParameters) (keyvault.KeyBundle, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateKey", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(keyvault.KeyBundle) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateKey indicates an expected call of CreateKey +func (mr *KeyVaultClientMockRecorder) CreateKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateKey", reflect.TypeOf((*KeyVaultClient)(nil).CreateKey), arg0, arg1, arg2, arg3) +} + +// GetKey mocks base method +func (m *KeyVaultClient) GetKey(arg0 context.Context, arg1, arg2, arg3 string) (keyvault.KeyBundle, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKey", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(keyvault.KeyBundle) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKey indicates an expected call of GetKey +func (mr *KeyVaultClientMockRecorder) GetKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKey", reflect.TypeOf((*KeyVaultClient)(nil).GetKey), arg0, arg1, arg2, arg3) +} + +// Sign mocks base method +func (m *KeyVaultClient) Sign(arg0 context.Context, arg1, arg2, arg3 string, arg4 keyvault.KeySignParameters) (keyvault.KeyOperationResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sign", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(keyvault.KeyOperationResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Sign indicates an expected call of Sign +func (mr *KeyVaultClientMockRecorder) Sign(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*KeyVaultClient)(nil).Sign), arg0, arg1, arg2, arg3, arg4) +} diff --git a/kms/azurekms/key_vault.go b/kms/azurekms/key_vault.go new file mode 100644 index 00000000..34d9c3f1 --- /dev/null +++ b/kms/azurekms/key_vault.go @@ -0,0 +1,342 @@ +package azurekms + +import ( + "context" + "crypto" + "regexp" + "time" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/azure/auth" + "github.com/Azure/go-autorest/autorest/date" + "github.com/pkg/errors" + "github.com/smallstep/certificates/kms/apiv1" + "github.com/smallstep/certificates/kms/uri" +) + +func init() { + apiv1.Register(apiv1.AzureKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) { + return New(ctx, opts) + }) +} + +// Scheme is the scheme used for the Azure Key Vault uris. +const Scheme = "azurekms" + +// keyIDRegexp is the regular expression that Key Vault uses on the kid. We can +// extract the vault, name and version of the key. +var keyIDRegexp = regexp.MustCompile(`^https://([0-9a-zA-Z-]+)\.vault\.azure\.net/keys/([0-9a-zA-Z-]+)/([0-9a-zA-Z-]+)$`) + +var ( + valueTrue = true + value2048 int32 = 2048 + value3072 int32 = 3072 + value4096 int32 = 4096 +) + +var now = func() time.Time { + return time.Now().UTC() +} + +type keyType struct { + Kty keyvault.JSONWebKeyType + Curve keyvault.JSONWebKeyCurveName +} + +func (k keyType) KeyType(pl apiv1.ProtectionLevel) keyvault.JSONWebKeyType { + switch k.Kty { + case keyvault.EC: + if pl == apiv1.HSM { + return keyvault.ECHSM + } + return k.Kty + case keyvault.RSA: + if pl == apiv1.HSM { + return keyvault.RSAHSM + } + return k.Kty + default: + return "" + } +} + +var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]keyType{ + apiv1.UnspecifiedSignAlgorithm: { + Kty: keyvault.EC, + Curve: keyvault.P256, + }, + apiv1.SHA256WithRSA: { + Kty: keyvault.RSA, + }, + apiv1.SHA384WithRSA: { + Kty: keyvault.RSA, + }, + apiv1.SHA512WithRSA: { + Kty: keyvault.RSA, + }, + apiv1.SHA256WithRSAPSS: { + Kty: keyvault.RSA, + }, + apiv1.SHA384WithRSAPSS: { + Kty: keyvault.RSA, + }, + apiv1.SHA512WithRSAPSS: { + Kty: keyvault.RSA, + }, + apiv1.ECDSAWithSHA256: { + Kty: keyvault.EC, + Curve: keyvault.P256, + }, + apiv1.ECDSAWithSHA384: { + Kty: keyvault.EC, + Curve: keyvault.P384, + }, + apiv1.ECDSAWithSHA512: { + Kty: keyvault.EC, + Curve: keyvault.P521, + }, +} + +// vaultResource is the value the client will use as audience. +const vaultResource = "https://vault.azure.net" + +// KeyVaultClient is the interface implemented by keyvault.BaseClient. It will +// be used for testing purposes. +type KeyVaultClient interface { + GetKey(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string) (keyvault.KeyBundle, error) + CreateKey(ctx context.Context, vaultBaseURL string, keyName string, parameters keyvault.KeyCreateParameters) (keyvault.KeyBundle, error) + Sign(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string, parameters keyvault.KeySignParameters) (keyvault.KeyOperationResult, error) +} + +// KeyVault implements a KMS using Azure Key Vault. +// +// The URI format used in Azure Key Vault is the following: +// +// - azurekms:name=key-name;vault=vault-name +// - azurekms:name=key-name;vault=vault-name?version=key-version +// - azurekms:name=key-name;vault=vault-name?hsm=true +// +// The scheme is "azurekms"; "name" is the key name; "vault" is the key vault +// name where the key is located; "version" is an optional parameter that +// defines the version of they key, if version is not given, the latest one will +// be used; "hsm" defines if an HSM want to be used for this key, this is +// specially useful when this is used from `step`. +// +// TODO(mariano): The implementation is using /services/keyvault/v7.1/keyvault +// package, at some point Azure might create a keyvault client with all the +// functionality in /sdk/keyvault, we should migrate to that once available. +type KeyVault struct { + baseClient KeyVaultClient + defaults DefaultOptions +} + +// DefaultOptions are custom options that can be passed as defaults using the +// URI in apiv1.Options. +type DefaultOptions struct { + Vault string + ProtectionLevel apiv1.ProtectionLevel +} + +var createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + baseClient := keyvault.New() + + // With an URI, try to log in only using client credentials in the URI. + // Client credentials requires: + // - client-id + // - client-secret + // - tenant-id + // And optionally the aad-endpoint to support custom clouds: + // - aad-endpoint (defaults to https://login.microsoftonline.com/) + if opts.URI != "" { + u, err := uri.ParseWithScheme(Scheme, opts.URI) + if err != nil { + return nil, err + } + + // Required options + clientID := u.Get("client-id") + clientSecret := u.Get("client-secret") + tenantID := u.Get("tenant-id") + // optional + aadEndpoint := u.Get("aad-endpoint") + + if clientID != "" && clientSecret != "" && tenantID != "" { + s := auth.EnvironmentSettings{ + Values: map[string]string{ + auth.ClientID: clientID, + auth.ClientSecret: clientSecret, + auth.TenantID: tenantID, + auth.Resource: vaultResource, + }, + Environment: azure.PublicCloud, + } + if aadEndpoint != "" { + s.Environment.ActiveDirectoryEndpoint = aadEndpoint + } + baseClient.Authorizer, err = s.GetAuthorizer() + if err != nil { + return nil, err + } + return baseClient, nil + } + } + + // Attempt to authorize with the following methods: + // 1. Environment variables. + // - Client credentials + // - Client certificate + // - Username and password + // - MSI + // 2. Using Azure CLI 2.0 on local development. + authorizer, err := auth.NewAuthorizerFromEnvironmentWithResource(vaultResource) + if err != nil { + authorizer, err = auth.NewAuthorizerFromCLIWithResource(vaultResource) + if err != nil { + return nil, errors.Wrap(err, "error getting authorizer for key vault") + } + } + baseClient.Authorizer = authorizer + return &baseClient, nil +} + +// New initializes a new KMS implemented using Azure Key Vault. +func New(ctx context.Context, opts apiv1.Options) (*KeyVault, error) { + baseClient, err := createClient(ctx, opts) + if err != nil { + return nil, err + } + + // step and step-ca do not need and URI, but having a default vault and + // protection level is useful if this package is used as an api + var defaults DefaultOptions + if opts.URI != "" { + u, err := uri.ParseWithScheme(Scheme, opts.URI) + if err != nil { + return nil, err + } + defaults.Vault = u.Get("vault") + if u.GetBool("hsm") { + defaults.ProtectionLevel = apiv1.HSM + } + } + + return &KeyVault{ + baseClient: baseClient, + defaults: defaults, + }, nil +} + +// GetPublicKey loads a public key from Azure Key Vault by its resource name. +func (k *KeyVault) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { + if req.Name == "" { + return nil, errors.New("getPublicKeyRequest 'name' cannot be empty") + } + + vault, name, version, _, err := parseKeyName(req.Name, k.defaults) + if err != nil { + return nil, err + } + + ctx, cancel := defaultContext() + defer cancel() + + resp, err := k.baseClient.GetKey(ctx, vaultBaseURL(vault), name, version) + if err != nil { + return nil, errors.Wrap(err, "keyVault GetKey failed") + } + + return convertKey(resp.Key) +} + +// CreateKey creates a asymmetric key in Azure Key Vault. +func (k *KeyVault) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { + if req.Name == "" { + return nil, errors.New("createKeyRequest 'name' cannot be empty") + } + + vault, name, _, hsm, err := parseKeyName(req.Name, k.defaults) + if err != nil { + return nil, err + } + + // Override protection level to HSM only if it's not specified, and is given + // in the uri. + protectionLevel := req.ProtectionLevel + if protectionLevel == apiv1.UnspecifiedProtectionLevel && hsm { + protectionLevel = apiv1.HSM + } + + kt, ok := signatureAlgorithmMapping[req.SignatureAlgorithm] + if !ok { + return nil, errors.Errorf("keyVault does not support signature algorithm '%s'", req.SignatureAlgorithm) + } + var keySize *int32 + if kt.Kty == keyvault.RSA || kt.Kty == keyvault.RSAHSM { + switch req.Bits { + case 2048: + keySize = &value2048 + case 0, 3072: + keySize = &value3072 + case 4096: + keySize = &value4096 + default: + return nil, errors.Errorf("keyVault does not support key size %d", req.Bits) + } + } + + created := date.UnixTime(now()) + + ctx, cancel := defaultContext() + defer cancel() + + resp, err := k.baseClient.CreateKey(ctx, vaultBaseURL(vault), name, keyvault.KeyCreateParameters{ + Kty: kt.KeyType(protectionLevel), + KeySize: keySize, + Curve: kt.Curve, + KeyOps: &[]keyvault.JSONWebKeyOperation{ + keyvault.Sign, keyvault.Verify, + }, + KeyAttributes: &keyvault.KeyAttributes{ + Enabled: &valueTrue, + Created: &created, + NotBefore: &created, + }, + }) + if err != nil { + return nil, errors.Wrap(err, "keyVault CreateKey failed") + } + + publicKey, err := convertKey(resp.Key) + if err != nil { + return nil, err + } + + keyURI := getKeyName(vault, name, resp) + return &apiv1.CreateKeyResponse{ + Name: keyURI, + PublicKey: publicKey, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: keyURI, + }, + }, nil +} + +// CreateSigner returns a crypto.Signer from a previously created asymmetric key. +func (k *KeyVault) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { + if req.SigningKey == "" { + return nil, errors.New("createSignerRequest 'signingKey' cannot be empty") + } + return NewSigner(k.baseClient, req.SigningKey, k.defaults) +} + +// Close closes the client connection to the Azure Key Vault. This is a noop. +func (k *KeyVault) Close() error { + return nil +} + +// ValidateName validates that the given string is a valid URI. +func (k *KeyVault) ValidateName(s string) error { + _, _, _, _, err := parseKeyName(s, k.defaults) + return err +} diff --git a/kms/azurekms/key_vault_test.go b/kms/azurekms/key_vault_test.go new file mode 100644 index 00000000..8f968189 --- /dev/null +++ b/kms/azurekms/key_vault_test.go @@ -0,0 +1,653 @@ +//go:generate mockgen -package mock -mock_names=KeyVaultClient=KeyVaultClient -destination internal/mock/key_vault_client.go github.com/smallstep/certificates/kms/azurekms KeyVaultClient +package azurekms + +import ( + "context" + "crypto" + "encoding/json" + "fmt" + "reflect" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/Azure/go-autorest/autorest/date" + "github.com/golang/mock/gomock" + "github.com/smallstep/certificates/kms/apiv1" + "github.com/smallstep/certificates/kms/azurekms/internal/mock" + "go.step.sm/crypto/keyutil" + "gopkg.in/square/go-jose.v2" +) + +var errTest = fmt.Errorf("test error") + +func mockNow(t *testing.T) time.Time { + old := now + t0 := time.Unix(1234567890, 123).UTC() + now = func() time.Time { + return t0 + } + t.Cleanup(func() { + now = old + }) + return t0 +} + +func mockClient(t *testing.T) *mock.KeyVaultClient { + t.Helper() + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + return mock.NewKeyVaultClient(ctrl) +} + +func createJWK(t *testing.T, pub crypto.PublicKey) *keyvault.JSONWebKey { + t.Helper() + b, err := json.Marshal(&jose.JSONWebKey{ + Key: pub, + }) + if err != nil { + t.Fatal(err) + } + key := new(keyvault.JSONWebKey) + if err := json.Unmarshal(b, key); err != nil { + t.Fatal(err) + } + return key +} + +func Test_now(t *testing.T) { + t0 := now() + if loc := t0.Location(); loc != time.UTC { + t.Errorf("now() Location = %v, want %v", loc, time.UTC) + } +} + +func TestNew(t *testing.T) { + client := mockClient(t) + old := createClient + t.Cleanup(func() { + createClient = old + }) + + type args struct { + ctx context.Context + opts apiv1.Options + } + tests := []struct { + name string + setup func() + args args + want *KeyVault + wantErr bool + }{ + {"ok", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{}}, &KeyVault{ + baseClient: client, + }, false}, + {"ok with vault", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{ + URI: "azurekms:vault=my-vault", + }}, &KeyVault{ + baseClient: client, + defaults: DefaultOptions{ + Vault: "my-vault", + ProtectionLevel: apiv1.UnspecifiedProtectionLevel, + }, + }, false}, + {"ok with vault + hsm", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{ + URI: "azurekms:vault=my-vault;hsm=true", + }}, &KeyVault{ + baseClient: client, + defaults: DefaultOptions{ + Vault: "my-vault", + ProtectionLevel: apiv1.HSM, + }, + }, false}, + {"fail", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return nil, errTest + } + }, args{context.Background(), apiv1.Options{}}, nil, true}, + {"fail uri", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{ + URI: "kms:vault=my-vault;hsm=true", + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup() + got, err := New(tt.args.ctx, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("New() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyVault_createClient(t *testing.T) { + type args struct { + ctx context.Context + opts apiv1.Options + } + tests := []struct { + name string + args args + skip bool + wantErr bool + }{ + {"ok", args{context.Background(), apiv1.Options{}}, true, false}, + {"ok with uri", args{context.Background(), apiv1.Options{ + URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id", + }}, false, false}, + {"ok with uri+aad", args{context.Background(), apiv1.Options{ + URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id;aad-enpoint=https%3A%2F%2Flogin.microsoftonline.us%2F", + }}, false, false}, + {"ok with uri no config", args{context.Background(), apiv1.Options{ + URI: "azurekms:", + }}, true, false}, + {"fail uri", args{context.Background(), apiv1.Options{ + URI: "kms:client-id=id;client-secret=secret;tenant-id=id", + }}, false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skip { + t.SkipNow() + } + _, err := createClient(tt.args.ctx, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestKeyVault_GetPublicKey(t *testing.T) { + key, err := keyutil.GenerateDefaultSigner() + if err != nil { + t.Fatal(err) + } + pub := key.Public() + jwk := createJWK(t, pub) + + client := mockClient(t) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest) + + type fields struct { + baseClient KeyVaultClient + } + type args struct { + req *apiv1.GetPublicKeyRequest + } + tests := []struct { + name string + fields fields + args args + want crypto.PublicKey + wantErr bool + }{ + {"ok", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + }}, pub, false}, + {"ok with version", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key?version=my-version", + }}, pub, false}, + {"fail GetKey", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "azurekms:vault=my-vault;name=not-found?version=my-version", + }}, nil, true}, + {"fail empty", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "", + }}, nil, true}, + {"fail vault", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "azurekms:vault=;name=not-found?version=my-version", + }}, nil, true}, + {"fail id", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "azurekms:vault=;name=?version=my-version", + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KeyVault{ + baseClient: tt.fields.baseClient, + } + got, err := k.GetPublicKey(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("KeyVault.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("KeyVault.GetPublicKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyVault_CreateKey(t *testing.T) { + ecKey, err := keyutil.GenerateDefaultSigner() + if err != nil { + t.Fatal(err) + } + rsaKey, err := keyutil.GenerateSigner("RSA", "", 2048) + if err != nil { + t.Fatal(err) + } + ecPub := ecKey.Public() + rsaPub := rsaKey.Public() + ecJWK := createJWK(t, ecPub) + rsaJWK := createJWK(t, rsaPub) + + t0 := date.UnixTime(mockNow(t)) + client := mockClient(t) + + expects := []struct { + Name string + Kty keyvault.JSONWebKeyType + KeySize *int32 + Curve keyvault.JSONWebKeyCurveName + Key *keyvault.JSONWebKey + }{ + {"P-256", keyvault.EC, nil, keyvault.P256, ecJWK}, + {"P-256 HSM", keyvault.ECHSM, nil, keyvault.P256, ecJWK}, + {"P-256 HSM (uri)", keyvault.ECHSM, nil, keyvault.P256, ecJWK}, + {"P-256 Default", keyvault.EC, nil, keyvault.P256, ecJWK}, + {"P-384", keyvault.EC, nil, keyvault.P384, ecJWK}, + {"P-521", keyvault.EC, nil, keyvault.P521, ecJWK}, + {"RSA 0", keyvault.RSA, &value3072, "", rsaJWK}, + {"RSA 0 HSM", keyvault.RSAHSM, &value3072, "", rsaJWK}, + {"RSA 0 HSM (uri)", keyvault.RSAHSM, &value3072, "", rsaJWK}, + {"RSA 2048", keyvault.RSA, &value2048, "", rsaJWK}, + {"RSA 3072", keyvault.RSA, &value3072, "", rsaJWK}, + {"RSA 4096", keyvault.RSA, &value4096, "", rsaJWK}, + } + + for _, e := range expects { + client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", keyvault.KeyCreateParameters{ + Kty: e.Kty, + KeySize: e.KeySize, + Curve: e.Curve, + KeyOps: &[]keyvault.JSONWebKeyOperation{ + keyvault.Sign, keyvault.Verify, + }, + KeyAttributes: &keyvault.KeyAttributes{ + Enabled: &valueTrue, + Created: &t0, + NotBefore: &t0, + }, + }).Return(keyvault.KeyBundle{ + Key: e.Key, + }, nil) + } + client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{}, errTest) + client.EXPECT().CreateKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", gomock.Any()).Return(keyvault.KeyBundle{ + Key: nil, + }, nil) + + type fields struct { + baseClient KeyVaultClient + } + type args struct { + req *apiv1.CreateKeyRequest + } + tests := []struct { + name string + fields fields + args args + want *apiv1.CreateKeyResponse + wantErr bool + }{ + {"ok P-256", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + ProtectionLevel: apiv1.Software, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok P-256 HSM", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + ProtectionLevel: apiv1.HSM, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok P-256 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key?hsm=true", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok P-256 Default", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok P-384", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + SignatureAlgorithm: apiv1.ECDSAWithSHA384, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok P-521", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + SignatureAlgorithm: apiv1.ECDSAWithSHA512, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 0", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + Bits: 0, + SignatureAlgorithm: apiv1.SHA256WithRSA, + ProtectionLevel: apiv1.Software, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 0 HSM", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + Bits: 0, + SignatureAlgorithm: apiv1.SHA256WithRSAPSS, + ProtectionLevel: apiv1.HSM, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 0 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key;hsm=true", + Bits: 0, + SignatureAlgorithm: apiv1.SHA256WithRSAPSS, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 2048", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + Bits: 2048, + SignatureAlgorithm: apiv1.SHA384WithRSA, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 3072", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + Bits: 3072, + SignatureAlgorithm: apiv1.SHA512WithRSA, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"ok RSA 4096", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key", + Bits: 4096, + SignatureAlgorithm: apiv1.SHA512WithRSAPSS, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, + {"fail createKey", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=not-found", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, nil, true}, + {"fail convertKey", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=not-found", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, nil, true}, + {"fail name", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "", + }}, nil, true}, + {"fail vault", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=;name=not-found?version=my-version", + }}, nil, true}, + {"fail id", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=?version=my-version", + }}, nil, true}, + {"fail SignatureAlgorithm", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=not-found", + SignatureAlgorithm: apiv1.PureEd25519, + }}, nil, true}, + {"fail bit size", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=not-found", + SignatureAlgorithm: apiv1.SHA384WithRSAPSS, + Bits: 1024, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KeyVault{ + baseClient: tt.fields.baseClient, + } + got, err := k.CreateKey(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("KeyVault.CreateKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("KeyVault.CreateKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyVault_CreateSigner(t *testing.T) { + key, err := keyutil.GenerateDefaultSigner() + if err != nil { + t.Fatal(err) + } + pub := key.Public() + jwk := createJWK(t, pub) + + client := mockClient(t) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest) + + type fields struct { + baseClient KeyVaultClient + } + type args struct { + req *apiv1.CreateSignerRequest + } + tests := []struct { + name string + fields fields + args args + want crypto.Signer + wantErr bool + }{ + {"ok", fields{client}, args{&apiv1.CreateSignerRequest{ + SigningKey: "azurekms:vault=my-vault;name=my-key", + }}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "", + publicKey: pub, + }, false}, + {"ok with version", fields{client}, args{&apiv1.CreateSignerRequest{ + SigningKey: "azurekms:vault=my-vault;name=my-key;version=my-version", + }}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "my-version", + publicKey: pub, + }, false}, + {"fail GetKey", fields{client}, args{&apiv1.CreateSignerRequest{ + SigningKey: "azurekms:vault=my-vault;name=not-found;version=my-version", + }}, nil, true}, + {"fail SigningKey", fields{client}, args{&apiv1.CreateSignerRequest{ + SigningKey: "", + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KeyVault{ + baseClient: tt.fields.baseClient, + } + got, err := k.CreateSigner(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("KeyVault.CreateSigner() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("KeyVault.CreateSigner() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyVault_Close(t *testing.T) { + client := mockClient(t) + type fields struct { + baseClient KeyVaultClient + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"ok", fields{client}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KeyVault{ + baseClient: tt.fields.baseClient, + } + if err := k.Close(); (err != nil) != tt.wantErr { + t.Errorf("KeyVault.Close() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_keyType_KeyType(t *testing.T) { + type fields struct { + Kty keyvault.JSONWebKeyType + Curve keyvault.JSONWebKeyCurveName + } + type args struct { + pl apiv1.ProtectionLevel + } + tests := []struct { + name string + fields fields + args args + want keyvault.JSONWebKeyType + }{ + {"ec", fields{keyvault.EC, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.EC}, + {"ec software", fields{keyvault.EC, keyvault.P384}, args{apiv1.Software}, keyvault.EC}, + {"ec hsm", fields{keyvault.EC, keyvault.P521}, args{apiv1.HSM}, keyvault.ECHSM}, + {"rsa", fields{keyvault.RSA, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.RSA}, + {"rsa software", fields{keyvault.RSA, ""}, args{apiv1.Software}, keyvault.RSA}, + {"rsa hsm", fields{keyvault.RSA, ""}, args{apiv1.HSM}, keyvault.RSAHSM}, + {"empty", fields{"FOO", ""}, args{apiv1.UnspecifiedProtectionLevel}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := keyType{ + Kty: tt.fields.Kty, + Curve: tt.fields.Curve, + } + if got := k.KeyType(tt.args.pl); !reflect.DeepEqual(got, tt.want) { + t.Errorf("keyType.KeyType() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyVault_ValidateName(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{"azurekms:name=my-key;vault=my-vault"}, false}, + {"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true"}, false}, + {"fail scheme", args{"azure:name=my-key;vault=my-vault"}, true}, + {"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault"}, true}, + {"fail no name", args{"azurekms:vault=my-vault"}, true}, + {"fail no vault", args{"azurekms:name=my-key"}, true}, + {"fail empty", args{""}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KeyVault{} + if err := k.ValidateName(tt.args.s); (err != nil) != tt.wantErr { + t.Errorf("KeyVault.ValidateName() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/kms/azurekms/signer.go b/kms/azurekms/signer.go new file mode 100644 index 00000000..b0349108 --- /dev/null +++ b/kms/azurekms/signer.go @@ -0,0 +1,182 @@ +package azurekms + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "encoding/base64" + "io" + "math/big" + "time" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/pkg/errors" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +// Signer implements a crypto.Signer using the AWS KMS. +type Signer struct { + client KeyVaultClient + vaultBaseURL string + name string + version string + publicKey crypto.PublicKey +} + +// NewSigner creates a new signer using a key in the AWS KMS. +func NewSigner(client KeyVaultClient, signingKey string, defaults DefaultOptions) (crypto.Signer, error) { + vault, name, version, _, err := parseKeyName(signingKey, defaults) + if err != nil { + return nil, err + } + + // Make sure that the key exists. + signer := &Signer{ + client: client, + vaultBaseURL: vaultBaseURL(vault), + name: name, + version: version, + } + if err := signer.preloadKey(); err != nil { + return nil, err + } + + return signer, nil +} + +func (s *Signer) preloadKey() error { + ctx, cancel := defaultContext() + defer cancel() + + resp, err := s.client.GetKey(ctx, s.vaultBaseURL, s.name, s.version) + if err != nil { + return errors.Wrap(err, "keyVault GetKey failed") + } + + s.publicKey, err = convertKey(resp.Key) + return err +} + +// Public returns the public key of this signer or an error. +func (s *Signer) Public() crypto.PublicKey { + return s.publicKey +} + +// Sign signs digest with the private key stored in the AWS KMS. +func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + alg, err := getSigningAlgorithm(s.Public(), opts) + if err != nil { + return nil, err + } + + b64 := base64.RawURLEncoding.EncodeToString(digest) + + // Sign with retry if the key is not ready + resp, err := s.signWithRetry(alg, b64, 3) + if err != nil { + return nil, errors.Wrap(err, "keyVault Sign failed") + } + + sig, err := base64.RawURLEncoding.DecodeString(*resp.Result) + if err != nil { + return nil, errors.Wrap(err, "error decoding keyVault Sign result") + } + + var octetSize int + switch alg { + case keyvault.ES256: + octetSize = 32 // 256-bit, concat(R,S) = 64 bytes + case keyvault.ES384: + octetSize = 48 // 384-bit, concat(R,S) = 96 bytes + case keyvault.ES512: + octetSize = 66 // 528-bit, concat(R,S) = 132 bytes + default: + return sig, nil + } + + // Convert to asn1 + if len(sig) != octetSize*2 { + return nil, errors.Errorf("keyVault Sign failed: unexpected signature length") + } + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(new(big.Int).SetBytes(sig[:octetSize])) // R + b.AddASN1BigInt(new(big.Int).SetBytes(sig[octetSize:])) // S + }) + return b.Bytes() +} + +func (s *Signer) signWithRetry(alg keyvault.JSONWebKeySignatureAlgorithm, b64 string, retryAttempts int) (keyvault.KeyOperationResult, error) { +retry: + ctx, cancel := defaultContext() + defer cancel() + + resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{ + Algorithm: alg, + Value: &b64, + }) + if err != nil && retryAttempts > 0 { + var requestError *azure.RequestError + if errors.As(err, &requestError) { + if se := requestError.ServiceError; se != nil && se.InnerError != nil { + code, ok := se.InnerError["code"].(string) + if ok && code == "KeyNotYetValid" { + time.Sleep(time.Second / time.Duration(retryAttempts)) + retryAttempts-- + goto retry + } + } + } + } + return resp, err +} + +func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) { + switch key.(type) { + case *rsa.PublicKey: + hashFunc := opts.HashFunc() + pss, isPSS := opts.(*rsa.PSSOptions) + // Random salt lengths are not supported + if isPSS && + pss.SaltLength != rsa.PSSSaltLengthAuto && + pss.SaltLength != rsa.PSSSaltLengthEqualsHash && + pss.SaltLength != hashFunc.Size() { + return "", errors.Errorf("unsupported RSA-PSS salt length %d", pss.SaltLength) + } + + switch h := hashFunc; h { + case crypto.SHA256: + if isPSS { + return keyvault.PS256, nil + } + return keyvault.RS256, nil + case crypto.SHA384: + if isPSS { + return keyvault.PS384, nil + } + return keyvault.RS384, nil + case crypto.SHA512: + if isPSS { + return keyvault.PS512, nil + } + return keyvault.RS512, nil + default: + return "", errors.Errorf("unsupported hash function %v", h) + } + case *ecdsa.PublicKey: + switch h := opts.HashFunc(); h { + case crypto.SHA256: + return keyvault.ES256, nil + case crypto.SHA384: + return keyvault.ES384, nil + case crypto.SHA512: + return keyvault.ES512, nil + default: + return "", errors.Errorf("unsupported hash function %v", h) + } + default: + return "", errors.Errorf("unsupported key type %T", key) + } +} diff --git a/kms/azurekms/signer_test.go b/kms/azurekms/signer_test.go new file mode 100644 index 00000000..bd072b25 --- /dev/null +++ b/kms/azurekms/signer_test.go @@ -0,0 +1,493 @@ +package azurekms + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "io" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/golang/mock/gomock" + "github.com/smallstep/certificates/kms/apiv1" + "go.step.sm/crypto/keyutil" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +func TestNewSigner(t *testing.T) { + key, err := keyutil.GenerateDefaultSigner() + if err != nil { + t.Fatal(err) + } + pub := key.Public() + jwk := createJWK(t, pub) + + client := mockClient(t) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest) + + var noOptions DefaultOptions + type args struct { + client KeyVaultClient + signingKey string + defaults DefaultOptions + } + tests := []struct { + name string + args args + want crypto.Signer + wantErr bool + }{ + {"ok", args{client, "azurekms:vault=my-vault;name=my-key", noOptions}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "", + publicKey: pub, + }, false}, + {"ok with version", args{client, "azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "my-version", + publicKey: pub, + }, false}, + {"ok with options", args{client, "azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault", ProtectionLevel: apiv1.HSM}}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "my-version", + publicKey: pub, + }, false}, + {"fail GetKey", args{client, "azurekms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true}, + {"fail vault", args{client, "azurekms:name=not-found;vault=", noOptions}, nil, true}, + {"fail id", args{client, "azurekms:name=;vault=my-vault?version=my-version", noOptions}, nil, true}, + {"fail scheme", args{client, "kms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewSigner(tt.args.client, tt.args.signingKey, tt.args.defaults) + if (err != nil) != tt.wantErr { + t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewSigner() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSigner_Public(t *testing.T) { + key, err := keyutil.GenerateDefaultSigner() + if err != nil { + t.Fatal(err) + } + pub := key.Public() + + type fields struct { + publicKey crypto.PublicKey + } + tests := []struct { + name string + fields fields + want crypto.PublicKey + }{ + {"ok", fields{pub}, pub}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Signer{ + publicKey: tt.fields.publicKey, + } + if got := s.Public(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signer.Public() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSigner_Sign(t *testing.T) { + sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) { + key, err := keyutil.GenerateSigner(kty, crv, bits) + if err != nil { + t.Fatal(err) + } + h := opts.HashFunc().New() + h.Write([]byte("random-data")) + sum := h.Sum(nil) + + var sig, resultSig []byte + if priv, ok := key.(*ecdsa.PrivateKey); ok { + r, s, err := ecdsa.Sign(rand.Reader, priv, sum) + if err != nil { + t.Fatal(err) + } + curveBits := priv.Params().BitSize + keyBytes := curveBits / 8 + if curveBits%8 > 0 { + keyBytes++ + } + rBytes := r.Bytes() + rBytesPadded := make([]byte, keyBytes) + copy(rBytesPadded[keyBytes-len(rBytes):], rBytes) + + sBytes := s.Bytes() + sBytesPadded := make([]byte, keyBytes) + copy(sBytesPadded[keyBytes-len(sBytes):], sBytes) + // nolint:gocritic + resultSig = append(rBytesPadded, sBytesPadded...) + + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(r) + b.AddASN1BigInt(s) + }) + sig, err = b.Bytes() + if err != nil { + t.Fatal(err) + } + } else { + sig, err = key.Sign(rand.Reader, sum, opts) + if err != nil { + t.Fatal(err) + } + resultSig = sig + } + + return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig + } + + p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256) + p384, p384Digest, p386ResultSig, p384Sig := sign("EC", "P-384", 0, crypto.SHA384) + p521, p521Digest, p521ResultSig, p521Sig := sign("EC", "P-521", 0, crypto.SHA512) + rsaSHA256, rsaSHA256Digest, rsaSHA256ResultSig, rsaSHA256Sig := sign("RSA", "", 2048, crypto.SHA256) + rsaSHA384, rsaSHA384Digest, rsaSHA384ResultSig, rsaSHA384Sig := sign("RSA", "", 2048, crypto.SHA384) + rsaSHA512, rsaSHA512Digest, rsaSHA512ResultSig, rsaSHA512Sig := sign("RSA", "", 2048, crypto.SHA512) + rsaPSSSHA256, rsaPSSSHA256Digest, rsaPSSSHA256ResultSig, rsaPSSSHA256Sig := sign("RSA", "", 2048, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + Hash: crypto.SHA256, + }) + rsaPSSSHA384, rsaPSSSHA384Digest, rsaPSSSHA384ResultSig, rsaPSSSHA384Sig := sign("RSA", "", 2048, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + Hash: crypto.SHA512, + }) + rsaPSSSHA512, rsaPSSSHA512Digest, rsaPSSSHA512ResultSig, rsaPSSSHA512Sig := sign("RSA", "", 2048, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + Hash: crypto.SHA512, + }) + + ed25519Key, err := keyutil.GenerateSigner("OKP", "Ed25519", 0) + if err != nil { + t.Fatal(err) + } + + client := mockClient(t) + expects := []struct { + name string + keyVersion string + alg keyvault.JSONWebKeySignatureAlgorithm + digest []byte + result keyvault.KeyOperationResult + err error + }{ + {"P-256", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{ + Result: &p256ResultSig, + }, nil}, + {"P-384", "my-version", keyvault.ES384, p384Digest, keyvault.KeyOperationResult{ + Result: &p386ResultSig, + }, nil}, + {"P-521", "my-version", keyvault.ES512, p521Digest, keyvault.KeyOperationResult{ + Result: &p521ResultSig, + }, nil}, + {"RSA SHA256", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{ + Result: &rsaSHA256ResultSig, + }, nil}, + {"RSA SHA384", "", keyvault.RS384, rsaSHA384Digest, keyvault.KeyOperationResult{ + Result: &rsaSHA384ResultSig, + }, nil}, + {"RSA SHA512", "", keyvault.RS512, rsaSHA512Digest, keyvault.KeyOperationResult{ + Result: &rsaSHA512ResultSig, + }, nil}, + {"RSA-PSS SHA256", "", keyvault.PS256, rsaPSSSHA256Digest, keyvault.KeyOperationResult{ + Result: &rsaPSSSHA256ResultSig, + }, nil}, + {"RSA-PSS SHA384", "", keyvault.PS384, rsaPSSSHA384Digest, keyvault.KeyOperationResult{ + Result: &rsaPSSSHA384ResultSig, + }, nil}, + {"RSA-PSS SHA512", "", keyvault.PS512, rsaPSSSHA512Digest, keyvault.KeyOperationResult{ + Result: &rsaPSSSHA512ResultSig, + }, nil}, + // Errors + {"fail Sign", "", keyvault.RS256, rsaSHA256Digest, keyvault.KeyOperationResult{}, errTest}, + {"fail sign length", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{ + Result: &rsaSHA256ResultSig, + }, nil}, + {"fail base64", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{ + Result: func() *string { + v := "😎" + return &v + }(), + }, nil}, + } + for _, e := range expects { + value := base64.RawURLEncoding.EncodeToString(e.digest) + client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{ + Algorithm: e.alg, + Value: &value, + }).Return(e.result, e.err) + } + + type fields struct { + client KeyVaultClient + vaultBaseURL string + name string + version string + publicKey crypto.PublicKey + } + type args struct { + rand io.Reader + digest []byte + opts crypto.SignerOpts + } + tests := []struct { + name string + fields fields + args args + want []byte + wantErr bool + }{ + {"ok P-256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest, crypto.SHA256, + }, p256Sig, false}, + {"ok P-384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p384}, args{ + rand.Reader, p384Digest, crypto.SHA384, + }, p384Sig, false}, + {"ok P-521", fields{client, "https://my-vault.vault.azure.net/", "my-key", "my-version", p521}, args{ + rand.Reader, p521Digest, crypto.SHA512, + }, p521Sig, false}, + {"ok RSA SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{ + rand.Reader, rsaSHA256Digest, crypto.SHA256, + }, rsaSHA256Sig, false}, + {"ok RSA SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA384}, args{ + rand.Reader, rsaSHA384Digest, crypto.SHA384, + }, rsaSHA384Sig, false}, + {"ok RSA SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA512}, args{ + rand.Reader, rsaSHA512Digest, crypto.SHA512, + }, rsaSHA512Sig, false}, + {"ok RSA-PSS SHA256", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{ + rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + Hash: crypto.SHA256, + }, + }, rsaPSSSHA256Sig, false}, + {"ok RSA-PSS SHA384", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA384}, args{ + rand.Reader, rsaPSSSHA384Digest, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + Hash: crypto.SHA384, + }, + }, rsaPSSSHA384Sig, false}, + {"ok RSA-PSS SHA512", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA512}, args{ + rand.Reader, rsaPSSSHA512Digest, &rsa.PSSOptions{ + SaltLength: 64, + Hash: crypto.SHA512, + }, + }, rsaPSSSHA512Sig, false}, + {"fail Sign", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{ + rand.Reader, rsaSHA256Digest, crypto.SHA256, + }, nil, true}, + {"fail sign length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest, crypto.SHA256, + }, nil, true}, + {"fail base64", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest, crypto.SHA256, + }, nil, true}, + {"fail RSA-PSS salt length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{ + rand.Reader, rsaPSSSHA256Digest, &rsa.PSSOptions{ + SaltLength: 64, + Hash: crypto.SHA256, + }, + }, nil, true}, + {"fail RSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaSHA256}, args{ + rand.Reader, rsaSHA256Digest, crypto.SHA1, + }, nil, true}, + {"fail ECDSA Hash", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest, crypto.MD5, + }, nil, true}, + {"fail Ed25519", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", ed25519Key}, args{ + rand.Reader, []byte("message"), crypto.Hash(0), + }, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Signer{ + client: tt.fields.client, + vaultBaseURL: tt.fields.vaultBaseURL, + name: tt.fields.name, + version: tt.fields.version, + publicKey: tt.fields.publicKey, + } + got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signer.Sign() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSigner_Sign_signWithRetry(t *testing.T) { + sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) { + key, err := keyutil.GenerateSigner(kty, crv, bits) + if err != nil { + t.Fatal(err) + } + h := opts.HashFunc().New() + h.Write([]byte("random-data")) + sum := h.Sum(nil) + + var sig, resultSig []byte + if priv, ok := key.(*ecdsa.PrivateKey); ok { + r, s, err := ecdsa.Sign(rand.Reader, priv, sum) + if err != nil { + t.Fatal(err) + } + curveBits := priv.Params().BitSize + keyBytes := curveBits / 8 + if curveBits%8 > 0 { + keyBytes++ + } + rBytes := r.Bytes() + rBytesPadded := make([]byte, keyBytes) + copy(rBytesPadded[keyBytes-len(rBytes):], rBytes) + + sBytes := s.Bytes() + sBytesPadded := make([]byte, keyBytes) + copy(sBytesPadded[keyBytes-len(sBytes):], sBytes) + // nolint:gocritic + resultSig = append(rBytesPadded, sBytesPadded...) + + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(r) + b.AddASN1BigInt(s) + }) + sig, err = b.Bytes() + if err != nil { + t.Fatal(err) + } + } else { + sig, err = key.Sign(rand.Reader, sum, opts) + if err != nil { + t.Fatal(err) + } + resultSig = sig + } + + return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig + } + + p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256) + okResult := keyvault.KeyOperationResult{ + Result: &p256ResultSig, + } + failResult := keyvault.KeyOperationResult{} + retryError := autorest.DetailedError{ + Original: &azure.RequestError{ + ServiceError: &azure.ServiceError{ + InnerError: map[string]interface{}{ + "code": "KeyNotYetValid", + }, + }, + }, + } + + client := mockClient(t) + expects := []struct { + name string + keyVersion string + alg keyvault.JSONWebKeySignatureAlgorithm + digest []byte + result keyvault.KeyOperationResult + err error + }{ + {"ok 1", "", keyvault.ES256, p256Digest, failResult, retryError}, + {"ok 2", "", keyvault.ES256, p256Digest, failResult, retryError}, + {"ok 3", "", keyvault.ES256, p256Digest, failResult, retryError}, + {"ok 4", "", keyvault.ES256, p256Digest, okResult, nil}, + {"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError}, + {"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError}, + {"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError}, + {"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError}, + } + for _, e := range expects { + value := base64.RawURLEncoding.EncodeToString(e.digest) + client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{ + Algorithm: e.alg, + Value: &value, + }).Return(e.result, e.err) + } + + type fields struct { + client KeyVaultClient + vaultBaseURL string + name string + version string + publicKey crypto.PublicKey + } + type args struct { + rand io.Reader + digest []byte + opts crypto.SignerOpts + } + tests := []struct { + name string + fields fields + args args + want []byte + wantErr bool + }{ + {"ok", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest, crypto.SHA256, + }, p256Sig, false}, + {"fail", fields{client, "https://my-vault.vault.azure.net/", "my-key", "fail-version", p256}, args{ + rand.Reader, p256Digest, crypto.SHA256, + }, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Signer{ + client: tt.fields.client, + vaultBaseURL: tt.fields.vaultBaseURL, + name: tt.fields.name, + version: tt.fields.version, + publicKey: tt.fields.publicKey, + } + got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signer.Sign() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/kms/azurekms/utils.go b/kms/azurekms/utils.go new file mode 100644 index 00000000..d4201907 --- /dev/null +++ b/kms/azurekms/utils.go @@ -0,0 +1,98 @@ +package azurekms + +import ( + "context" + "crypto" + "encoding/json" + "net/url" + "time" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/pkg/errors" + "github.com/smallstep/certificates/kms/apiv1" + "github.com/smallstep/certificates/kms/uri" + "go.step.sm/crypto/jose" +) + +// defaultContext returns the default context used in requests to azure. +func defaultContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 15*time.Second) +} + +// getKeyName returns the uri of the key vault key. +func getKeyName(vault, name string, bundle keyvault.KeyBundle) string { + if bundle.Key != nil && bundle.Key.Kid != nil { + sm := keyIDRegexp.FindAllStringSubmatch(*bundle.Key.Kid, 1) + if len(sm) == 1 && len(sm[0]) == 4 { + m := sm[0] + u := uri.New(Scheme, url.Values{ + "vault": []string{m[1]}, + "name": []string{m[2]}, + }) + u.RawQuery = url.Values{"version": []string{m[3]}}.Encode() + return u.String() + } + } + // Fallback to URI without id. + return uri.New(Scheme, url.Values{ + "vault": []string{vault}, + "name": []string{name}, + }).String() +} + +// parseKeyName returns the key vault, name and version from URIs like: +// +// - azurekms:vault=key-vault;name=key-name +// - azurekms:vault=key-vault;name=key-name?version=key-id +// - azurekms:vault=key-vault;name=key-name?version=key-id&hsm=true +// +// The key-id defines the version of the key, if it is not passed the latest +// version will be used. +// +// HSM can also be passed to define the protection level if this is not given in +// CreateQuery. +func parseKeyName(rawURI string, defaults DefaultOptions) (vault, name, version string, hsm bool, err error) { + var u *uri.URI + + u, err = uri.ParseWithScheme(Scheme, rawURI) + if err != nil { + return + } + if name = u.Get("name"); name == "" { + err = errors.Errorf("key uri %s is not valid: name is missing", rawURI) + return + } + if vault = u.Get("vault"); vault == "" { + if defaults.Vault == "" { + name = "" + err = errors.Errorf("key uri %s is not valid: vault is missing", rawURI) + return + } + vault = defaults.Vault + } + if u.Get("hsm") == "" { + hsm = (defaults.ProtectionLevel == apiv1.HSM) + } else { + hsm = u.GetBool("hsm") + } + + version = u.Get("version") + + return +} + +func vaultBaseURL(vault string) string { + return "https://" + vault + ".vault.azure.net/" +} + +func convertKey(key *keyvault.JSONWebKey) (crypto.PublicKey, error) { + b, err := json.Marshal(key) + if err != nil { + return nil, errors.Wrap(err, "error marshaling key") + } + var jwk jose.JSONWebKey + if err := jwk.UnmarshalJSON(b); err != nil { + return nil, errors.Wrap(err, "error unmarshaling key") + } + return jwk.Key, nil +} diff --git a/kms/azurekms/utils_test.go b/kms/azurekms/utils_test.go new file mode 100644 index 00000000..cded50ea --- /dev/null +++ b/kms/azurekms/utils_test.go @@ -0,0 +1,96 @@ +package azurekms + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/smallstep/certificates/kms/apiv1" +) + +func Test_getKeyName(t *testing.T) { + getBundle := func(kid string) keyvault.KeyBundle { + return keyvault.KeyBundle{ + Key: &keyvault.JSONWebKey{ + Kid: &kid, + }, + } + } + + type args struct { + vault string + name string + bundle keyvault.KeyBundle + } + tests := []struct { + name string + args args + want string + }{ + {"ok", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault?version=my-version"}, + {"ok default", args{"my-vault", "my-key", getBundle("https://my-vault.foo.net/keys/my-key/my-version")}, "azurekms:name=my-key;vault=my-vault"}, + {"ok too short", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-version")}, "azurekms:name=my-key;vault=my-vault"}, + {"ok too long", args{"my-vault", "my-key", getBundle("https://my-vault.vault.azure.net/keys/my-key/my-version/sign")}, "azurekms:name=my-key;vault=my-vault"}, + {"ok nil key", args{"my-vault", "my-key", keyvault.KeyBundle{}}, "azurekms:name=my-key;vault=my-vault"}, + {"ok nil kid", args{"my-vault", "my-key", keyvault.KeyBundle{Key: &keyvault.JSONWebKey{}}}, "azurekms:name=my-key;vault=my-vault"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getKeyName(tt.args.vault, tt.args.name, tt.args.bundle); got != tt.want { + t.Errorf("getKeyName() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseKeyName(t *testing.T) { + var noOptions DefaultOptions + type args struct { + rawURI string + defaults DefaultOptions + } + tests := []struct { + name string + args args + wantVault string + wantName string + wantVersion string + wantHsm bool + wantErr bool + }{ + {"ok", args{"azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false}, + {"ok opaque version", args{"azurekms:name=my-key;vault=my-vault;version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false}, + {"ok no version", args{"azurekms:name=my-key;vault=my-vault", noOptions}, "my-vault", "my-key", "", false, false}, + {"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true", noOptions}, "my-vault", "my-key", "", true, false}, + {"ok hsm false", args{"azurekms:name=my-key;vault=my-vault?hsm=false", noOptions}, "my-vault", "my-key", "", false, false}, + {"ok default vault", args{"azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault"}}, "my-vault", "my-key", "my-version", false, false}, + {"ok default hsm", args{"azurekms:name=my-key;vault=my-vault?version=my-version", DefaultOptions{Vault: "other-vault", ProtectionLevel: apiv1.HSM}}, "my-vault", "my-key", "my-version", true, false}, + {"fail scheme", args{"azure:name=my-key;vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail no name", args{"azurekms:vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail empty name", args{"azurekms:name=;vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail no vault", args{"azurekms:name=my-key", noOptions}, "", "", "", false, true}, + {"fail empty vault", args{"azurekms:name=my-key;vault=", noOptions}, "", "", "", false, true}, + {"fail empty", args{"", noOptions}, "", "", "", false, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotVault, gotName, gotVersion, gotHsm, err := parseKeyName(tt.args.rawURI, tt.args.defaults) + if (err != nil) != tt.wantErr { + t.Errorf("parseKeyName() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotVault != tt.wantVault { + t.Errorf("parseKeyName() gotVault = %v, want %v", gotVault, tt.wantVault) + } + if gotName != tt.wantName { + t.Errorf("parseKeyName() gotName = %v, want %v", gotName, tt.wantName) + } + if gotVersion != tt.wantVersion { + t.Errorf("parseKeyName() gotVersion = %v, want %v", gotVersion, tt.wantVersion) + } + if gotHsm != tt.wantHsm { + t.Errorf("parseKeyName() gotHsm = %v, want %v", gotHsm, tt.wantHsm) + } + }) + } +} diff --git a/kms/cloudkms/cloudkms.go b/kms/cloudkms/cloudkms.go index cfbf8235..65d06048 100644 --- a/kms/cloudkms/cloudkms.go +++ b/kms/cloudkms/cloudkms.go @@ -3,6 +3,7 @@ package cloudkms import ( "context" "crypto" + "crypto/x509" "log" "strings" "time" @@ -46,8 +47,8 @@ var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]interface{}{ 4096: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256, }, apiv1.SHA512WithRSA: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{ - 0: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256, - 4096: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256, + 0: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512, + 4096: kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512, }, apiv1.SHA256WithRSAPSS: map[int]kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm{ 0: kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256, @@ -63,6 +64,19 @@ var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]interface{}{ apiv1.ECDSAWithSHA384: kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384, } +var cryptoKeyVersionMapping = map[kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm]x509.SignatureAlgorithm{ + kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256: x509.ECDSAWithSHA256, + kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384: x509.ECDSAWithSHA384, + kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256: x509.SHA256WithRSA, + kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256: x509.SHA256WithRSA, + kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256: x509.SHA256WithRSA, + kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512: x509.SHA512WithRSA, + kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256: x509.SHA256WithRSAPSS, + kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256: x509.SHA256WithRSAPSS, + kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256: x509.SHA256WithRSAPSS, + kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512: x509.SHA512WithRSAPSS, +} + // KeyManagementClient defines the methods on KeyManagementClient that this // package will use. This interface will be used for unit testing. type KeyManagementClient interface { diff --git a/kms/cloudkms/signer.go b/kms/cloudkms/signer.go index 686aca25..5a5443cf 100644 --- a/kms/cloudkms/signer.go +++ b/kms/cloudkms/signer.go @@ -2,6 +2,7 @@ package cloudkms import ( "crypto" + "crypto/x509" "io" "github.com/pkg/errors" @@ -13,6 +14,7 @@ import ( type Signer struct { client KeyManagementClient signingKey string + algorithm x509.SignatureAlgorithm publicKey crypto.PublicKey } @@ -40,7 +42,7 @@ func (s *Signer) preloadKey(signingKey string) error { if err != nil { return errors.Wrap(err, "cloudKMS GetPublicKey failed") } - + s.algorithm = cryptoKeyVersionMapping[response.Algorithm] s.publicKey, err = pemutil.ParseKey([]byte(response.Pem)) return err } @@ -84,3 +86,10 @@ func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([] return response.Signature, nil } + +// SignatureAlgorithm returns the algorithm that must be specified in a +// certificate to sign. This is specially important to distinguish RSA and +// RSAPSS schemas. +func (s *Signer) SignatureAlgorithm() x509.SignatureAlgorithm { + return s.algorithm +} diff --git a/kms/cloudkms/signer_test.go b/kms/cloudkms/signer_test.go index fa730fe3..a8f964f1 100644 --- a/kms/cloudkms/signer_test.go +++ b/kms/cloudkms/signer_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto" "crypto/rand" + "crypto/x509" "fmt" "io" "io/ioutil" @@ -156,3 +157,79 @@ func Test_signer_Sign(t *testing.T) { }) } } + +func TestSigner_SignatureAlgorithm(t *testing.T) { + pemBytes, err := ioutil.ReadFile("testdata/pub.pem") + if err != nil { + t.Fatal(err) + } + + client := &MockClient{ + getPublicKey: func(_ context.Context, req *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { + var algorithm kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm + switch req.Name { + case "ECDSA-SHA256": + algorithm = kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256 + case "ECDSA-SHA384": + algorithm = kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384 + case "SHA256-RSA-2048": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256 + case "SHA256-RSA-3072": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256 + case "SHA256-RSA-4096": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256 + case "SHA512-RSA-4096": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512 + case "SHA256-RSAPSS-2048": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256 + case "SHA256-RSAPSS-3072": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256 + case "SHA256-RSAPSS-4096": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256 + case "SHA512-RSAPSS-4096": + algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512 + } + return &kmspb.PublicKey{ + Pem: string(pemBytes), + Algorithm: algorithm, + }, nil + }, + } + + if err != nil { + t.Fatal(err) + } + + type fields struct { + client KeyManagementClient + signingKey string + } + tests := []struct { + name string + fields fields + want x509.SignatureAlgorithm + }{ + {"ECDSA-SHA256", fields{client, "ECDSA-SHA256"}, x509.ECDSAWithSHA256}, + {"ECDSA-SHA384", fields{client, "ECDSA-SHA384"}, x509.ECDSAWithSHA384}, + {"SHA256-RSA-2048", fields{client, "SHA256-RSA-2048"}, x509.SHA256WithRSA}, + {"SHA256-RSA-3072", fields{client, "SHA256-RSA-3072"}, x509.SHA256WithRSA}, + {"SHA256-RSA-4096", fields{client, "SHA256-RSA-4096"}, x509.SHA256WithRSA}, + {"SHA512-RSA-4096", fields{client, "SHA512-RSA-4096"}, x509.SHA512WithRSA}, + {"SHA256-RSAPSS-2048", fields{client, "SHA256-RSAPSS-2048"}, x509.SHA256WithRSAPSS}, + {"SHA256-RSAPSS-3072", fields{client, "SHA256-RSAPSS-3072"}, x509.SHA256WithRSAPSS}, + {"SHA256-RSAPSS-4096", fields{client, "SHA256-RSAPSS-4096"}, x509.SHA256WithRSAPSS}, + {"SHA512-RSAPSS-4096", fields{client, "SHA512-RSAPSS-4096"}, x509.SHA512WithRSAPSS}, + {"unknown", fields{client, "UNKNOWN"}, x509.UnknownSignatureAlgorithm}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + signer, err := NewSigner(tt.fields.client, tt.fields.signingKey) + if err != nil { + t.Errorf("NewSigner() error = %v", err) + } + if got := signer.SignatureAlgorithm(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signer.SignatureAlgorithm() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/kms/kms.go b/kms/kms.go index 3eddca93..92b544df 100644 --- a/kms/kms.go +++ b/kms/kms.go @@ -8,7 +8,7 @@ import ( "github.com/smallstep/certificates/kms/apiv1" // Enable default implementation - _ "github.com/smallstep/certificates/kms/softkms" + "github.com/smallstep/certificates/kms/softkms" ) // KeyManager is the interface implemented by all the KMS. @@ -18,6 +18,12 @@ type KeyManager = apiv1.KeyManager // store x509.Certificates. type CertificateManager = apiv1.CertificateManager +// Options are the KMS options. They represent the kms object in the ca.json. +type Options = apiv1.Options + +// Default is the implementation of the default KMS. +var Default = &softkms.SoftKMS{} + // New initializes a new KMS from the given type. func New(ctx context.Context, opts apiv1.Options) (KeyManager, error) { if err := opts.Validate(); err != nil { diff --git a/kms/pkcs11/benchmark_test.go b/kms/pkcs11/benchmark_test.go index 30e21117..c567872f 100644 --- a/kms/pkcs11/benchmark_test.go +++ b/kms/pkcs11/benchmark_test.go @@ -1,3 +1,4 @@ +//go:build cgo // +build cgo package pkcs11 diff --git a/kms/pkcs11/opensc_test.go b/kms/pkcs11/opensc_test.go index f3b61932..b365e614 100644 --- a/kms/pkcs11/opensc_test.go +++ b/kms/pkcs11/opensc_test.go @@ -1,3 +1,4 @@ +//go:build opensc // +build opensc package pkcs11 diff --git a/kms/pkcs11/other_test.go b/kms/pkcs11/other_test.go index 51732853..9f4ab4a8 100644 --- a/kms/pkcs11/other_test.go +++ b/kms/pkcs11/other_test.go @@ -201,10 +201,10 @@ func (s *privateKey) Delete() error { return nil } -func (s *privateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) { +func (s *privateKey) Decrypt(rnd io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) { k, ok := s.Signer.(*rsa.PrivateKey) if !ok { return nil, errors.New("key is not an rsa key") } - return k.Decrypt(rand, msg, opts) + return k.Decrypt(rnd, msg, opts) } diff --git a/kms/pkcs11/pkcs11.go b/kms/pkcs11/pkcs11.go index 7a418e19..a0c8cea6 100644 --- a/kms/pkcs11/pkcs11.go +++ b/kms/pkcs11/pkcs11.go @@ -15,7 +15,6 @@ import ( "sync" "github.com/ThalesIgnite/crypto11" - "github.com/miekg/pkcs11" "github.com/pkg/errors" "github.com/smallstep/certificates/kms/apiv1" "github.com/smallstep/certificates/kms/uri" @@ -146,8 +145,7 @@ func (k *PKCS11) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons // CreateSigner creates a signer using the key present in the PKCS#11 MODULE signature // slot. func (k *PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { - switch { - case req.SigningKey == "": + if req.SigningKey == "" { return nil, errors.New("createSignerRequest 'signingKey' cannot be empty") } @@ -209,9 +207,7 @@ func (k *PKCS11) StoreCertificate(req *apiv1.StoreCertificateRequest) error { return errors.Wrap(err, "storeCertificate failed") } if req.Extractable { - template.AddIfNotPresent([]*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, true), - }) + template.Set(crypto11.CkaExtractable, true) } if err := k.p11.ImportCertificateWithAttributes(template, cert); err != nil { return errors.Wrap(err, "storeCertificate failed") @@ -221,8 +217,8 @@ func (k *PKCS11) StoreCertificate(req *apiv1.StoreCertificateRequest) error { } // DeleteKey is a utility function to delete a key given an uri. -func (k *PKCS11) DeleteKey(uri string) error { - id, object, err := parseObject(uri) +func (k *PKCS11) DeleteKey(u string) error { + id, object, err := parseObject(u) if err != nil { return errors.Wrap(err, "deleteKey failed") } @@ -240,8 +236,8 @@ func (k *PKCS11) DeleteKey(uri string) error { } // DeleteCertificate is a utility function to delete a certificate given an uri. -func (k *PKCS11) DeleteCertificate(uri string) error { - id, object, err := parseObject(uri) +func (k *PKCS11) DeleteCertificate(u string) error { + id, object, err := parseObject(u) if err != nil { return errors.Wrap(err, "deleteCertificate failed") } @@ -309,9 +305,7 @@ func generateKey(ctx P11, req *apiv1.CreateKeyRequest) (crypto11.Signer, error) } private := public.Copy() if req.Extractable { - private.AddIfNotPresent([]*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, true), - }) + private.Set(crypto11.CkaExtractable, true) } bits := req.Bits @@ -339,20 +333,6 @@ func generateKey(ctx P11, req *apiv1.CreateKeyRequest) (crypto11.Signer, error) } } -func GenerateECDSAKeyPairWithLabel(ctx P11, id, label []byte, curve elliptic.Curve, extractable bool) (crypto11.Signer, error) { - public, err := crypto11.NewAttributeSetWithIDAndLabel(id, label) - if err != nil { - return nil, err - } - // Copy the AttributeSet to allow modifications. - private := public.Copy() - private.AddIfNotPresent([]*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, extractable), - }) - - return ctx.GenerateECDSAKeyPairWithAttributes(public, private, curve) -} - func findSigner(ctx P11, rawuri string) (crypto11.Signer, error) { id, object, err := parseObject(rawuri) if err != nil { diff --git a/kms/pkcs11/pkcs11_no_cgo.go b/kms/pkcs11/pkcs11_no_cgo.go index 87c9a36b..6fa51dff 100644 --- a/kms/pkcs11/pkcs11_no_cgo.go +++ b/kms/pkcs11/pkcs11_no_cgo.go @@ -1,3 +1,4 @@ +//go:build !cgo // +build !cgo package pkcs11 diff --git a/kms/pkcs11/pkcs11_test.go b/kms/pkcs11/pkcs11_test.go index 77277366..6df9b92a 100644 --- a/kms/pkcs11/pkcs11_test.go +++ b/kms/pkcs11/pkcs11_test.go @@ -1,3 +1,4 @@ +//go:build cgo // +build cgo package pkcs11 diff --git a/kms/pkcs11/setup_test.go b/kms/pkcs11/setup_test.go index c9ff9311..52dc5207 100644 --- a/kms/pkcs11/setup_test.go +++ b/kms/pkcs11/setup_test.go @@ -1,3 +1,4 @@ +//go:build cgo // +build cgo package pkcs11 diff --git a/kms/pkcs11/softhsm2_test.go b/kms/pkcs11/softhsm2_test.go index 37aa667d..ed2ff208 100644 --- a/kms/pkcs11/softhsm2_test.go +++ b/kms/pkcs11/softhsm2_test.go @@ -1,3 +1,4 @@ +//go:build cgo && softhsm2 // +build cgo,softhsm2 package pkcs11 diff --git a/kms/pkcs11/yubihsm2_test.go b/kms/pkcs11/yubihsm2_test.go index 6d02a420..281aff54 100644 --- a/kms/pkcs11/yubihsm2_test.go +++ b/kms/pkcs11/yubihsm2_test.go @@ -1,3 +1,4 @@ +//go:build cgo && yubihsm2 // +build cgo,yubihsm2 package pkcs11 diff --git a/kms/sshagentkms/sshagentkms_test.go b/kms/sshagentkms/sshagentkms_test.go index 30edd5d1..d3a9e9f5 100644 --- a/kms/sshagentkms/sshagentkms_test.go +++ b/kms/sshagentkms/sshagentkms_test.go @@ -378,6 +378,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) { t.Errorf("SSHAgentKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr) return } + // nolint:gocritic switch s := got.(type) { case *WrappedSSHSigner: gotPkS := s.Sshsigner.PublicKey().(*agent.Key).String() + "\n" @@ -562,6 +563,7 @@ func TestSSHAgentKMS_GetPublicKey(t *testing.T) { t.Errorf("SSHAgentKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr) return } + // nolint:gocritic switch tt.want.(type) { case ssh.PublicKey: // If we want a ssh.PublicKey, protote got to a diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 94009c47..36e15e7d 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -59,7 +59,9 @@ func Parse(rawuri string) (*URI, error) { if u.Scheme == "" { return nil, errors.Errorf("error parsing %s: scheme is missing", rawuri) } - v, err := url.ParseQuery(u.Opaque) + // Starting with Go 1.17 url.ParseQuery returns an error using semicolon as + // separator. + v, err := url.ParseQuery(strings.ReplaceAll(u.Opaque, ";", "&")) if err != nil { return nil, errors.Wrapf(err, "error parsing %s", rawuri) } @@ -93,6 +95,16 @@ func (u *URI) Get(key string) string { return v } +// GetBool returns true if a given key has the value "true". It returns false +// otherwise. +func (u *URI) GetBool(key string) bool { + v := u.Values.Get(key) + if v == "" { + v = u.URL.Query().Get(key) + } + return strings.EqualFold(v, "true") +} + // GetEncoded returns the first value in the uri with the given key, it will // return empty nil if that field is not present or is empty. If the return // value is hex encoded it will decode it and return it. diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index aa420db4..01fbad0f 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -212,6 +212,40 @@ func TestURI_Get(t *testing.T) { } } +func TestURI_GetBool(t *testing.T) { + mustParse := func(s string) *URI { + u, err := Parse(s) + if err != nil { + t.Fatal(err) + } + return u + } + type args struct { + key string + } + tests := []struct { + name string + uri *URI + args args + want bool + }{ + {"true", mustParse("azurekms:name=foo;vault=bar;hsm=true"), args{"hsm"}, true}, + {"TRUE", mustParse("azurekms:name=foo;vault=bar;hsm=TRUE"), args{"hsm"}, true}, + {"tRUe query", mustParse("azurekms:name=foo;vault=bar?hsm=tRUe"), args{"hsm"}, true}, + {"false", mustParse("azurekms:name=foo;vault=bar;hsm=false"), args{"hsm"}, false}, + {"false query", mustParse("azurekms:name=foo;vault=bar?hsm=false"), args{"hsm"}, false}, + {"empty", mustParse("azurekms:name=foo;vault=bar;hsm=?bar=true"), args{"hsm"}, false}, + {"missing", mustParse("azurekms:name=foo;vault=bar"), args{"hsm"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.uri.GetBool(tt.args.key); got != tt.want { + t.Errorf("URI.GetBool() = %v, want %v", got, tt.want) + } + }) + } +} + func TestURI_GetEncoded(t *testing.T) { mustParse := func(s string) *URI { u, err := Parse(s) @@ -274,3 +308,28 @@ func TestURI_Pin(t *testing.T) { }) } } + +func TestURI_String(t *testing.T) { + mustParse := func(s string) *URI { + u, err := Parse(s) + if err != nil { + t.Fatal(err) + } + return u + } + tests := []struct { + name string + uri *URI + want string + }{ + {"ok new", New("yubikey", url.Values{"slot-id": []string{"9a"}, "foo": []string{"bar"}}), "yubikey:foo=bar;slot-id=9a"}, + {"ok parse", mustParse("yubikey:slot-id=9a;foo=bar?bar=zar"), "yubikey:slot-id=9a;foo=bar?bar=zar"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.uri.String(); got != tt.want { + t.Errorf("URI.String() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/kms/yubikey/yubikey.go b/kms/yubikey/yubikey.go index 2dde244a..b1d5f7e3 100644 --- a/kms/yubikey/yubikey.go +++ b/kms/yubikey/yubikey.go @@ -1,3 +1,4 @@ +//go:build cgo // +build cgo package yubikey diff --git a/kms/yubikey/yubikey_no_cgo.go b/kms/yubikey/yubikey_no_cgo.go index 6ed7c630..24a76174 100644 --- a/kms/yubikey/yubikey_no_cgo.go +++ b/kms/yubikey/yubikey_no_cgo.go @@ -1,3 +1,4 @@ +//go:build !cgo // +build !cgo package yubikey diff --git a/make/docker.mk b/make/docker.mk index 8ed25219..edb82423 100644 --- a/make/docker.mk +++ b/make/docker.mk @@ -54,6 +54,8 @@ define DOCKER_BUILDX # $(1) -- Image Tag # $(2) -- Push (empty is no push | --push will push to dockerhub) docker buildx build . --progress plain -t $(DOCKER_IMAGE_NAME):$(1) -f docker/Dockerfile.step-ca --platform="$(DOCKER_PLATFORMS)" $(2) + echo -n "$(COSIGN_PWD)" | cosign sign -key /tmp/cosign.key -r $(DOCKER_IMAGE_NAME):$(1) + endef # For non-master builds don't build the docker containers. diff --git a/pki/helm.go b/pki/helm.go new file mode 100644 index 00000000..0a2f7f02 --- /dev/null +++ b/pki/helm.go @@ -0,0 +1,155 @@ +package pki + +import ( + "io" + "text/template" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority" + authconfig "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/templates" + "go.step.sm/linkedca" +) + +type helmVariables struct { + *linkedca.Configuration + Defaults *linkedca.Defaults + Password string + EnableSSH bool + TLS authconfig.TLSOptions + Provisioners []provisioner.Interface +} + +// WriteHelmTemplate a helm template to configure the +// smallstep/step-certificates helm chart. +func (p *PKI) WriteHelmTemplate(w io.Writer) error { + tmpl, err := template.New("helm").Funcs(templates.StepFuncMap()).Parse(helmTemplate) + if err != nil { + return errors.Wrap(err, "error writing helm template") + } + + // Delete ssh section if it is not enabled + if !p.options.enableSSH { + p.Ssh = nil + } + + // Convert provisioner to ca.json + provisioners := make([]provisioner.Interface, len(p.Authority.Provisioners)) + for i, p := range p.Authority.Provisioners { + pp, err := authority.ProvisionerToCertificates(p) + if err != nil { + return err + } + provisioners[i] = pp + } + + if err := tmpl.Execute(w, helmVariables{ + Configuration: &p.Configuration, + Defaults: &p.Defaults, + Password: "", + EnableSSH: p.options.enableSSH, + TLS: authconfig.DefaultTLSOptions, + Provisioners: provisioners, + }); err != nil { + return errors.Wrap(err, "error executing helm template") + } + return nil +} + +const helmTemplate = `# Helm template +inject: + enabled: true + # Config contains the configuration files ca.json and defaults.json + config: + files: + ca.json: + root: {{ first .Root }} + federateRoots: [] + crt: {{ .Intermediate }} + key: {{ .IntermediateKey }} + {{- if .EnableSSH }} + ssh: + hostKey: {{ .Ssh.HostKey }} + userKey: {{ .Ssh.UserKey }} + {{- end }} + address: {{ .Address }} + dnsNames: + {{- range .DnsNames }} + - {{ . }} + {{- end }} + logger: + format: json + db: + type: badgerv2 + dataSource: /home/step/db + authority: + provisioners: + {{- range .Provisioners }} + - {{ . | toJson }} + {{- end }} + tls: + cipherSuites: + {{- range .TLS.CipherSuites }} + - {{ . }} + {{- end }} + minVersion: {{ .TLS.MinVersion }} + maxVersion: {{ .TLS.MaxVersion }} + renegotiation: {{ .TLS.Renegotiation }} + + defaults.json: + ca-url: {{ .Defaults.CaUrl }} + ca-config: {{ .Defaults.CaConfig }} + fingerprint: {{ .Defaults.Fingerprint }} + root: {{ .Defaults.Root }} + + # Certificates contains the root and intermediate certificate and + # optionally the SSH host and user public keys + certificates: + # intermediate_ca contains the text of the intermediate CA Certificate + intermediate_ca: | + {{- index .Files .Intermediate | toString | nindent 6 }} + + # root_ca contains the text of the root CA Certificate + root_ca: | + {{- first .Root | index .Files | toString | nindent 6 }} + + {{- if .Ssh }} + # ssh_host_ca contains the text of the public ssh key for the SSH root CA + ssh_host_ca: {{ index .Files .Ssh.HostPublicKey | toString }} + + # ssh_user_ca contains the text of the public ssh key for the SSH root CA + ssh_user_ca: {{ index .Files .Ssh.UserPublicKey | toString }} + {{- end }} + + # Secrets contains the root and intermediate keys and optionally the SSH + # private keys + secrets: + # ca_password contains the password used to encrypt x509.intermediate_ca_key, ssh.host_ca_key and ssh.user_ca_key + # This value must be base64 encoded. + ca_password: {{ .Password | b64enc }} + provisioner_password: {{ .Password | b64enc}} + + x509: + # intermediate_ca_key contains the contents of your encrypted intermediate CA key + intermediate_ca_key: | + {{- index .Files .IntermediateKey | toString | nindent 8 }} + + # root_ca_key contains the contents of your encrypted root CA key + # Note that this value can be omitted without impacting the functionality of step-certificates + # If supplied, this should be encrypted using a unique password that is not used for encrypting + # the intermediate_ca_key, ssh.host_ca_key or ssh.user_ca_key. + root_ca_key: | + {{- first .RootKey | index .Files | toString | nindent 8 }} + + {{- if .Ssh }} + ssh: + # ssh_host_ca_key contains the contents of your encrypted SSH Host CA key + host_ca_key: | + {{- index .Files .Ssh.HostKey | toString | nindent 8 }} + + # ssh_user_ca_key contains the contents of your encrypted SSH User CA key + user_ca_key: | + {{- index .Files .Ssh.UserKey | toString | nindent 8 }} + {{- end }} +` diff --git a/pki/pki.go b/pki/pki.go index c95ca985..61e20b6b 100644 --- a/pki/pki.go +++ b/pki/pki.go @@ -10,31 +10,65 @@ import ( "encoding/json" "encoding/pem" "fmt" - "html" "net" "os" "path/filepath" - "strconv" "strings" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/admin" + admindb "github.com/smallstep/certificates/authority/admin/db/nosql" + authconfig "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/ca" "github.com/smallstep/certificates/cas" "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/kms" + kmsapi "github.com/smallstep/certificates/kms/apiv1" + "github.com/smallstep/nosql" "go.step.sm/cli-utils/config" "go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/ui" "go.step.sm/crypto/jose" - "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" + "go.step.sm/linkedca" "golang.org/x/crypto/ssh" ) +// DeploymentType defines witch type of deployment a user is initializing +type DeploymentType int + +const ( + // StandaloneDeployment is a deployment where all the components like keys, + // provisioners, admins, certificates and others are managed by the user. + StandaloneDeployment DeploymentType = iota + // LinkedDeployment is a deployment where the keys are managed by the user, + // but provisioners, admins and the record of certificates are managed in + // the cloud. + LinkedDeployment + // HostedDeployment is a deployment where all the components are managed in + // the cloud by smallstep.com/certificate-manager. + HostedDeployment +) + +// String returns the string version of the deployment type. +func (d DeploymentType) String() string { + switch d { + case StandaloneDeployment: + return "standalone" + case LinkedDeployment: + return "linked" + case HostedDeployment: + return "hosted" + default: + return "unknown" + } +} + const ( // ConfigPath is the directory name under the step path where the configuration // files will be stored. @@ -95,7 +129,7 @@ func GetTemplatesPath() string { // GetProvisioners returns the map of provisioners on the given CA. func GetProvisioners(caURL, rootFile string) (provisioner.List, error) { - if len(rootFile) == 0 { + if rootFile == "" { rootFile = GetRootCAPath() } client, err := ca.NewClient(caURL, ca.WithRootFile(rootFile)) @@ -120,7 +154,7 @@ func GetProvisioners(caURL, rootFile string) (provisioner.List, error) { // GetProvisionerKey returns the encrypted provisioner key with the for the // given kid. func GetProvisionerKey(caURL, rootFile, kid string) (string, error) { - if len(rootFile) == 0 { + if rootFile == "" { rootFile = GetRootCAPath() } client, err := ca.NewClient(caURL, ca.WithRootFile(rootFile)) @@ -134,43 +168,150 @@ func GetProvisionerKey(caURL, rootFile, kid string) (string, error) { return resp.Key, nil } +type options struct { + provisioner string + pkiOnly bool + enableACME bool + enableSSH bool + enableAdmin bool + noDB bool + isHelm bool + deploymentType DeploymentType + rootKeyURI string + intermediateKeyURI string + hostKeyURI string + userKeyURI string +} + +// Option is the type of a configuration option on the pki constructor. +type Option func(p *PKI) + +// WithAddress sets the listen address of step-ca. +func WithAddress(s string) Option { + return func(p *PKI) { + p.Address = s + } +} + +// WithCaURL sets the default ca-url of step-ca. +func WithCaURL(s string) Option { + return func(p *PKI) { + p.Defaults.CaUrl = s + } +} + +// WithDNSNames sets the SANs of step-ca. +func WithDNSNames(s []string) Option { + return func(p *PKI) { + p.DnsNames = s + } +} + +// WithProvisioner defines the name of the default provisioner. +func WithProvisioner(s string) Option { + return func(p *PKI) { + p.options.provisioner = s + } +} + +// WithPKIOnly will only generate the PKI without the step-ca config files. +func WithPKIOnly() Option { + return func(p *PKI) { + p.options.pkiOnly = true + } +} + +// WithACME enables acme provisioner in step-ca. +func WithACME() Option { + return func(p *PKI) { + p.options.enableACME = true + } +} + +// WithSSH enables ssh in step-ca. +func WithSSH() Option { + return func(p *PKI) { + p.options.enableSSH = true + } +} + +// WithAdmin enables the admin api in step-ca. +func WithAdmin() Option { + return func(p *PKI) { + p.options.enableAdmin = true + } +} + +// WithNoDB disables the db in step-ca. +func WithNoDB() Option { + return func(p *PKI) { + p.options.noDB = true + } +} + +// WithHelm configures the pki to create a helm values.yaml. +func WithHelm() Option { + return func(p *PKI) { + p.options.isHelm = true + } +} + +// WithDeploymentType defines the deployment type of step-ca. +func WithDeploymentType(dt DeploymentType) Option { + return func(p *PKI) { + p.options.deploymentType = dt + } +} + +// WithKMS enables the kms with the given name. +func WithKMS(name string) Option { + return func(p *PKI) { + typ := linkedca.KMS_Type_value[strings.ToUpper(name)] + p.Configuration.Kms = &linkedca.KMS{ + Type: linkedca.KMS_Type(typ), + } + } +} + +// WithKeyURIs defines the key uris for X.509 and SSH keys. +func WithKeyURIs(rootKey, intermediateKey, hostKey, userKey string) Option { + return func(p *PKI) { + p.options.rootKeyURI = rootKey + p.options.intermediateKeyURI = intermediateKey + p.options.hostKeyURI = hostKey + p.options.userKeyURI = userKey + } +} + // PKI represents the Public Key Infrastructure used by a certificate authority. type PKI struct { - casOptions apiv1.Options - caCreator apiv1.CertificateAuthorityCreator - root, rootKey, rootFingerprint string - intermediate, intermediateKey string - sshHostPubKey, sshHostKey string - sshUserPubKey, sshUserKey string - config, defaults string - ottPublicKey *jose.JSONWebKey - ottPrivateKey *jose.JSONWebEncryption - provisioner string - address string - dnsNames []string - caURL string - enableSSH bool + linkedca.Configuration + Defaults linkedca.Defaults + casOptions apiv1.Options + caService apiv1.CertificateAuthorityService + caCreator apiv1.CertificateAuthorityCreator + keyManager kmsapi.KeyManager + config string + defaults string + ottPublicKey *jose.JSONWebKey + ottPrivateKey *jose.JSONWebEncryption + options *options } // New creates a new PKI configuration. -func New(opts apiv1.Options) (*PKI, error) { - caCreator, err := cas.NewCreator(context.Background(), opts) +func New(o apiv1.Options, opts ...Option) (*PKI, error) { + caService, err := cas.New(context.Background(), o) if err != nil { return nil, err } - public := GetPublicPath() - private := GetSecretsPath() - config := GetConfigPath() - - // Create directories - dirs := []string{public, private, config, GetTemplatesPath()} - for _, name := range dirs { - if _, err := os.Stat(name); os.IsNotExist(err) { - if err = os.MkdirAll(name, 0700); err != nil { - return nil, errs.FileError(err, name) - } + var caCreator apiv1.CertificateAuthorityCreator + if o.IsCreator { + creator, ok := caService.(apiv1.CertificateAuthorityCreator) + if !ok { + return nil, errors.Errorf("cas type '%s' does not implements CertificateAuthorityCreator", o.Type) } + caCreator = creator } // get absolute path for dir/name @@ -180,45 +321,105 @@ func New(opts apiv1.Options) (*PKI, error) { } p := &PKI{ - casOptions: opts, - caCreator: caCreator, - provisioner: "step-cli", - address: "127.0.0.1:9000", - dnsNames: []string{"127.0.0.1"}, + Configuration: linkedca.Configuration{ + Address: "127.0.0.1:9000", + DnsNames: []string{"127.0.0.1"}, + Ssh: &linkedca.SSH{}, + Authority: &linkedca.Authority{}, + Files: make(map[string][]byte), + }, + casOptions: o, + caService: caService, + caCreator: caCreator, + keyManager: o.KeyManager, + options: &options{ + provisioner: "step-cli", + }, } - if p.root, err = getPath(public, "root_ca.crt"); err != nil { - return nil, err + for _, fn := range opts { + fn(p) } - if p.rootKey, err = getPath(private, "root_ca_key"); err != nil { - return nil, err + + // Use default key manager + if p.keyManager == nil { + p.keyManager = kms.Default } - if p.intermediate, err = getPath(public, "intermediate_ca.crt"); err != nil { - return nil, err - } - if p.intermediateKey, err = getPath(private, "intermediate_ca_key"); err != nil { - return nil, err - } - if p.sshHostPubKey, err = getPath(public, "ssh_host_ca_key.pub"); err != nil { - return nil, err - } - if p.sshUserPubKey, err = getPath(public, "ssh_user_ca_key.pub"); err != nil { - return nil, err - } - if p.sshHostKey, err = getPath(private, "ssh_host_ca_key"); err != nil { - return nil, err - } - if p.sshUserKey, err = getPath(private, "ssh_user_ca_key"); err != nil { - return nil, err - } - if len(config) > 0 { - if p.config, err = getPath(config, "ca.json"); err != nil { - return nil, err - } - if p.defaults, err = getPath(config, "defaults.json"); err != nil { - return nil, err + + // Use /home/step as the step path in helm configurations. + // Use the current step path when creating pki in files. + var public, private, cfg string + if p.options.isHelm { + public = "/home/step/certs" + private = "/home/step/secrets" + cfg = "/home/step/config" + } else { + public = GetPublicPath() + private = GetSecretsPath() + cfg = GetConfigPath() + // Create directories + dirs := []string{public, private, cfg, GetTemplatesPath()} + for _, name := range dirs { + if _, err := os.Stat(name); os.IsNotExist(err) { + if err = os.MkdirAll(name, 0700); err != nil { + return nil, errs.FileError(err, name) + } + } } } + if p.Defaults.CaUrl == "" { + p.Defaults.CaUrl = p.DnsNames[0] + _, port, err := net.SplitHostPort(p.Address) + if err != nil { + return nil, errors.Wrapf(err, "error parsing %s", p.Address) + } + // On k8s we usually access through a service, and this is configured on + // port 443 by default. + if port == "443" || p.options.isHelm { + p.Defaults.CaUrl = fmt.Sprintf("https://%s", p.Defaults.CaUrl) + } else { + p.Defaults.CaUrl = fmt.Sprintf("https://%s:%s", p.Defaults.CaUrl, port) + } + } + + root, err := getPath(public, "root_ca.crt") + if err != nil { + return nil, err + } + rootKey, err := getPath(private, "root_ca_key") + if err != nil { + return nil, err + } + p.Root = []string{root} + p.RootKey = []string{rootKey} + p.Defaults.Root = root + + if p.Intermediate, err = getPath(public, "intermediate_ca.crt"); err != nil { + return nil, err + } + if p.IntermediateKey, err = getPath(private, "intermediate_ca_key"); err != nil { + return nil, err + } + if p.Ssh.HostPublicKey, err = getPath(public, "ssh_host_ca_key.pub"); err != nil { + return nil, err + } + if p.Ssh.UserPublicKey, err = getPath(public, "ssh_user_ca_key.pub"); err != nil { + return nil, err + } + if p.Ssh.HostKey, err = getPath(private, "ssh_host_ca_key"); err != nil { + return nil, err + } + if p.Ssh.UserKey, err = getPath(private, "ssh_user_ca_key"); err != nil { + return nil, err + } + if p.defaults, err = getPath(cfg, "defaults.json"); err != nil { + return nil, err + } + if p.config, err = getPath(cfg, "ca.json"); err != nil { + return nil, err + } + p.Defaults.CaConfig = p.config + return p, nil } @@ -229,27 +430,7 @@ func (p *PKI) GetCAConfigPath() string { // GetRootFingerprint returns the root fingerprint. func (p *PKI) GetRootFingerprint() string { - return p.rootFingerprint -} - -// SetProvisioner sets the provisioner name of the OTT keys. -func (p *PKI) SetProvisioner(s string) { - p.provisioner = s -} - -// SetAddress sets the listening address of the CA. -func (p *PKI) SetAddress(s string) { - p.address = s -} - -// SetDNSNames sets the dns names of the CA. -func (p *PKI) SetDNSNames(s []string) { - p.dnsNames = s -} - -// SetCAURL sets the ca-url to use in the defaults.json. -func (p *PKI) SetCAURL(s string) { - p.caURL = s + return p.Defaults.Fingerprint } // GenerateKeyPairs generates the key pairs used by the certificate authority. @@ -261,17 +442,56 @@ func (p *PKI) GenerateKeyPairs(pass []byte) error { return err } + var claims *linkedca.Claims + if p.options.enableSSH { + claims = &linkedca.Claims{ + Ssh: &linkedca.SSHClaims{ + Enabled: true, + }, + } + } + + // Add JWK provisioner to the configuration. + publicKey, err := json.Marshal(p.ottPublicKey) + if err != nil { + return errors.Wrap(err, "error marshaling public key") + } + encryptedKey, err := p.ottPrivateKey.CompactSerialize() + if err != nil { + return errors.Wrap(err, "error serializing private key") + } + p.Authority.Provisioners = append(p.Authority.Provisioners, &linkedca.Provisioner{ + Type: linkedca.Provisioner_JWK, + Name: p.options.provisioner, + Claims: claims, + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_JWK{ + JWK: &linkedca.JWKProvisioner{ + PublicKey: publicKey, + EncryptedPrivateKey: []byte(encryptedKey), + }, + }, + }, + }) + return nil } // GenerateRootCertificate generates a root certificate with the given name // and using the default key type. func (p *PKI) GenerateRootCertificate(name, org, resource string, pass []byte) (*apiv1.CreateCertificateAuthorityResponse, error) { + if uri := p.options.rootKeyURI; uri != "" { + p.RootKey[0] = uri + } + resp, err := p.caCreator.CreateCertificateAuthority(&apiv1.CreateCertificateAuthorityRequest{ - Name: resource + "-Root-CA", - Type: apiv1.RootCA, - Lifetime: 10 * 365 * 24 * time.Hour, - CreateKey: nil, // use default + Name: resource + "-Root-CA", + Type: apiv1.RootCA, + Lifetime: 10 * 365 * 24 * time.Hour, + CreateKey: &apiv1.CreateKeyRequest{ + Name: p.RootKey[0], + SignatureAlgorithm: kmsapi.UnspecifiedSignAlgorithm, + }, Template: &x509.Certificate{ Subject: pkix.Name{ CommonName: name + " Root CA", @@ -288,6 +508,13 @@ func (p *PKI) GenerateRootCertificate(name, org, resource string, pass []byte) ( return nil, err } + // Replace key name with the one from the key manager if available. On + // softcas this will be the original filename, on any other kms will be the + // uri to the key. + if resp.KeyName != "" { + p.RootKey[0] = resp.KeyName + } + // PrivateKey will only be set if we have access to it (SoftCAS). if err := p.WriteRootCertificate(resp.Certificate, resp.PrivateKey, pass); err != nil { return nil, err @@ -296,14 +523,36 @@ func (p *PKI) GenerateRootCertificate(name, org, resource string, pass []byte) ( return resp, nil } +// WriteRootCertificate writes to the buffer the given certificate and key if given. +func (p *PKI) WriteRootCertificate(rootCrt *x509.Certificate, rootKey interface{}, pass []byte) error { + p.Files[p.Root[0]] = encodeCertificate(rootCrt) + if rootKey != nil { + var err error + p.Files[p.RootKey[0]], err = encodePrivateKey(rootKey, pass) + if err != nil { + return err + } + } + sum := sha256.Sum256(rootCrt.Raw) + p.Defaults.Fingerprint = strings.ToLower(hex.EncodeToString(sum[:])) + return nil +} + // GenerateIntermediateCertificate generates an intermediate certificate with // the given name and using the default key type. func (p *PKI) GenerateIntermediateCertificate(name, org, resource string, parent *apiv1.CreateCertificateAuthorityResponse, pass []byte) error { + if uri := p.options.intermediateKeyURI; uri != "" { + p.IntermediateKey = uri + } + resp, err := p.caCreator.CreateCertificateAuthority(&apiv1.CreateCertificateAuthorityRequest{ - Name: resource + "-Intermediate-CA", - Type: apiv1.IntermediateCA, - Lifetime: 10 * 365 * 24 * time.Hour, - CreateKey: nil, // use default + Name: resource + "-Intermediate-CA", + Type: apiv1.IntermediateCA, + Lifetime: 10 * 365 * 24 * time.Hour, + CreateKey: &apiv1.CreateKeyRequest{ + Name: p.IntermediateKey, + SignatureAlgorithm: kmsapi.UnspecifiedSignAlgorithm, + }, Template: &x509.Certificate{ Subject: pkix.Name{ CommonName: name + " Intermediate CA", @@ -322,46 +571,21 @@ func (p *PKI) GenerateIntermediateCertificate(name, org, resource string, parent } p.casOptions.CertificateAuthority = resp.Name - return p.WriteIntermediateCertificate(resp.Certificate, resp.PrivateKey, pass) -} + p.Files[p.Intermediate] = encodeCertificate(resp.Certificate) -// WriteRootCertificate writes to disk the given certificate and key. -func (p *PKI) WriteRootCertificate(rootCrt *x509.Certificate, rootKey interface{}, pass []byte) error { - if err := fileutil.WriteFile(p.root, pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: rootCrt.Raw, - }), 0600); err != nil { - return err + // Replace the key name with the one from the key manager. On softcas this + // will be the original filename, on any other kms will be the uri to the + // key. + if resp.KeyName != "" { + p.IntermediateKey = resp.KeyName } - if rootKey != nil { - _, err := pemutil.Serialize(rootKey, pemutil.WithPassword(pass), pemutil.ToFile(p.rootKey, 0600)) - if err != nil { - return err - } + // If a kms is used it will not have the private key + if resp.PrivateKey != nil { + p.Files[p.IntermediateKey], err = encodePrivateKey(resp.PrivateKey, pass) } - sum := sha256.Sum256(rootCrt.Raw) - p.rootFingerprint = strings.ToLower(hex.EncodeToString(sum[:])) - - return nil -} - -// WriteIntermediateCertificate writes to disk the given certificate and key. -func (p *PKI) WriteIntermediateCertificate(crt *x509.Certificate, key interface{}, pass []byte) error { - if err := fileutil.WriteFile(p.intermediate, pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: crt.Raw, - }), 0600); err != nil { - return err - } - if key != nil { - _, err := pemutil.Serialize(key, pemutil.WithPassword(pass), pemutil.ToFile(p.intermediateKey, 0600)) - if err != nil { - return err - } - } - return nil + return err } // CreateCertificateAuthorityResponse returns a @@ -379,7 +603,7 @@ func (p *PKI) CreateCertificateAuthorityResponse(cert *x509.Certificate, key cry // GetCertificateAuthority attempts to load the certificate authority from the // RA. func (p *PKI) GetCertificateAuthority() error { - srv, ok := p.caCreator.(apiv1.CertificateAuthorityGetter) + srv, ok := p.caService.(apiv1.CertificateAuthorityGetter) if !ok { return nil } @@ -396,8 +620,8 @@ func (p *PKI) GetCertificateAuthority() error { } // Issuer is in the RA - p.intermediate = "" - p.intermediateKey = "" + p.Intermediate = "" + p.IntermediateKey = "" return nil } @@ -405,71 +629,120 @@ func (p *PKI) GetCertificateAuthority() error { // GenerateSSHSigningKeys generates and encrypts a private key used for signing // SSH user certificates and a private key used for signing host certificates. func (p *PKI) GenerateSSHSigningKeys(password []byte) error { - var pubNames = []string{p.sshHostPubKey, p.sshUserPubKey} - var privNames = []string{p.sshHostKey, p.sshUserKey} - for i := 0; i < 2; i++ { - pub, priv, err := keyutil.GenerateDefaultKeyPair() + // Enable SSH + p.options.enableSSH = true + + // Create SSH key used to sign host certificates. Using + // kmsapi.UnspecifiedSignAlgorithm will default to the default algorithm. + name := p.Ssh.HostKey + if uri := p.options.hostKeyURI; uri != "" { + name = uri + } + resp, err := p.keyManager.CreateKey(&kmsapi.CreateKeyRequest{ + Name: name, + SignatureAlgorithm: kmsapi.UnspecifiedSignAlgorithm, + }) + if err != nil { + return err + } + sshKey, err := ssh.NewPublicKey(resp.PublicKey) + if err != nil { + return errors.Wrapf(err, "error converting public key") + } + p.Files[p.Ssh.HostPublicKey] = ssh.MarshalAuthorizedKey(sshKey) + + // On softkms we will have the private key + if resp.PrivateKey != nil { + p.Files[p.Ssh.HostKey], err = encodePrivateKey(resp.PrivateKey, password) if err != nil { return err } - if _, ok := priv.(crypto.Signer); !ok { - return errors.Errorf("key of type %T is not a crypto.Signer", priv) - } - sshKey, err := ssh.NewPublicKey(pub) - if err != nil { - return errors.Wrapf(err, "error converting public key") - } - _, err = pemutil.Serialize(priv, pemutil.WithFilename(privNames[i]), pemutil.WithPassword(password)) + } else { + p.Ssh.HostKey = resp.Name + } + + // Create SSH key used to sign user certificates. Using + // kmsapi.UnspecifiedSignAlgorithm will default to the default algorithm. + name = p.Ssh.UserKey + if uri := p.options.userKeyURI; uri != "" { + name = uri + } + resp, err = p.keyManager.CreateKey(&kmsapi.CreateKeyRequest{ + Name: name, + SignatureAlgorithm: kmsapi.UnspecifiedSignAlgorithm, + }) + if err != nil { + return err + } + sshKey, err = ssh.NewPublicKey(resp.PublicKey) + if err != nil { + return errors.Wrapf(err, "error converting public key") + } + p.Files[p.Ssh.UserPublicKey] = ssh.MarshalAuthorizedKey(sshKey) + + // On softkms we will have the private key + if resp.PrivateKey != nil { + p.Files[p.Ssh.UserKey], err = encodePrivateKey(resp.PrivateKey, password) if err != nil { return err } - if err = fileutil.WriteFile(pubNames[i], ssh.MarshalAuthorizedKey(sshKey), 0600); err != nil { + } else { + p.Ssh.UserKey = resp.Name + } + + return nil +} + +// WriteFiles writes on disk the previously generated files. +func (p *PKI) WriteFiles() error { + for fn, b := range p.Files { + if err := fileutil.WriteFile(fn, b, 0600); err != nil { return err } } - p.enableSSH = true return nil } func (p *PKI) askFeedback() { ui.Println() - ui.Printf("\033[1mFEEDBACK\033[0m %s %s\n", - html.UnescapeString("&#"+strconv.Itoa(128525)+";"), - html.UnescapeString("&#"+strconv.Itoa(127867)+";")) - ui.Println(" The \033[1mstep\033[0m utility is not instrumented for usage statistics. It does not") - ui.Println(" phone home. But your feedback is extremely valuable. Any information you") - ui.Println(" can provide regarding how you’re using `step` helps. Please send us a") - ui.Println(" sentence or two, good or bad: \033[1mfeedback@smallstep.com\033[0m or join") - ui.Println(" \033[1mhttps://github.com/smallstep/certificates/discussions\033[0m.") -} + ui.Println("\033[1mFEEDBACK\033[0m 😍 🍻") + ui.Println(" The \033[1mstep\033[0m utility is not instrumented for usage statistics. It does not phone") + ui.Println(" home. But your feedback is extremely valuable. Any information you can provide") + ui.Println(" regarding how you’re using `step` helps. Please send us a sentence or two,") + ui.Println(" good or bad at \033[1mfeedback@smallstep.com\033[0m or join GitHub Discussions") + ui.Println(" \033[1mhttps://github.com/smallstep/certificates/discussions\033[0m and our Discord ") + ui.Println(" \033[1mhttps://u.step.sm/discord\033[0m.") -// TellPKI outputs the locations of public and private keys generated -// generated for a new PKI. Generally this will consist of a root certificate -// and key and an intermediate certificate and key. -func (p *PKI) TellPKI() { - p.tellPKI() - p.askFeedback() + if p.options.deploymentType == LinkedDeployment { + ui.Println() + ui.Println("\033[1mNEXT STEPS\033[0m") + ui.Println(" 1. Log in or create a Certificate Manager account at \033[1mhttps://u.step.sm/linked\033[0m") + ui.Println(" 2. Add a new authority and select \"Link a step-ca instance\"") + ui.Println(" 3. Follow instructions in browser to start `step-ca` using the `--token` flag") + ui.Println() + } } func (p *PKI) tellPKI() { ui.Println() - if p.casOptions.Is(apiv1.SoftCAS) { - ui.PrintSelected("Root certificate", p.root) - ui.PrintSelected("Root private key", p.rootKey) - ui.PrintSelected("Root fingerprint", p.rootFingerprint) - ui.PrintSelected("Intermediate certificate", p.intermediate) - ui.PrintSelected("Intermediate private key", p.intermediateKey) - } else if p.rootFingerprint != "" { - ui.PrintSelected("Root certificate", p.root) - ui.PrintSelected("Root fingerprint", p.rootFingerprint) - } else { + switch { + case p.casOptions.Is(apiv1.SoftCAS): + ui.PrintSelected("Root certificate", p.Root[0]) + ui.PrintSelected("Root private key", p.RootKey[0]) + ui.PrintSelected("Root fingerprint", p.Defaults.Fingerprint) + ui.PrintSelected("Intermediate certificate", p.Intermediate) + ui.PrintSelected("Intermediate private key", p.IntermediateKey) + case p.Defaults.Fingerprint != "": + ui.PrintSelected("Root certificate", p.Root[0]) + ui.PrintSelected("Root fingerprint", p.Defaults.Fingerprint) + default: ui.Printf(`{{ "%s" | red }} {{ "Root certificate:" | bold }} failed to retrieve it from RA`+"\n", ui.IconBad) } - if p.enableSSH { - ui.PrintSelected("SSH user root certificate", p.sshUserPubKey) - ui.PrintSelected("SSH user root private key", p.sshUserKey) - ui.PrintSelected("SSH host root certificate", p.sshHostPubKey) - ui.PrintSelected("SSH host root private key", p.sshHostKey) + if p.options.enableSSH { + ui.PrintSelected("SSH user public key", p.Ssh.UserPublicKey) + ui.PrintSelected("SSH user private key", p.Ssh.UserKey) + ui.PrintSelected("SSH host public key", p.Ssh.HostPublicKey) + ui.PrintSelected("SSH host private key", p.Ssh.HostKey) } } @@ -480,176 +753,237 @@ type caDefaults struct { Root string `json:"root"` } -// Option is the type for modifiers over the auth config object. -type Option func(c *authority.Config) error - -// WithDefaultDB is a configuration modifier that adds a default DB stanza to -// the authority config. -func WithDefaultDB() Option { - return func(c *authority.Config) error { - c.DB = &db.Config{ - Type: "badger", - DataSource: GetDBPath(), - } - return nil - } -} - -// WithoutDB is a configuration modifier that adds a default DB stanza to -// the authority config. -func WithoutDB() Option { - return func(c *authority.Config) error { - c.DB = nil - return nil - } -} +// ConfigOption is the type for modifiers over the auth config object. +type ConfigOption func(c *authconfig.Config) error // GenerateConfig returns the step certificates configuration. -func (p *PKI) GenerateConfig(opt ...Option) (*authority.Config, error) { - key, err := p.ottPrivateKey.CompactSerialize() - if err != nil { - return nil, errors.Wrap(err, "error serializing private key") - } - - prov := &provisioner.JWK{ - Name: p.provisioner, - Type: "JWK", - Key: p.ottPublicKey, - EncryptedKey: key, - } - +func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) { var authorityOptions *apiv1.Options if !p.casOptions.Is(apiv1.SoftCAS) { authorityOptions = &p.casOptions } - config := &authority.Config{ - Root: []string{p.root}, - FederatedRoots: []string{}, - IntermediateCert: p.intermediate, - IntermediateKey: p.intermediateKey, - Address: p.address, - DNSNames: p.dnsNames, + cfg := &authconfig.Config{ + Root: p.Root, + FederatedRoots: p.FederatedRoots, + IntermediateCert: p.Intermediate, + IntermediateKey: p.IntermediateKey, + Address: p.Address, + DNSNames: p.DnsNames, Logger: []byte(`{"format": "text"}`), DB: &db.Config{ - Type: "badger", + Type: "badgerv2", DataSource: GetDBPath(), }, - AuthorityConfig: &authority.AuthConfig{ + AuthorityConfig: &authconfig.AuthConfig{ Options: authorityOptions, DisableIssuedAtCheck: false, - Provisioners: provisioner.List{prov}, - }, - TLS: &authority.TLSOptions{ - MinVersion: authority.DefaultTLSMinVersion, - MaxVersion: authority.DefaultTLSMaxVersion, - Renegotiation: authority.DefaultTLSRenegotiation, - CipherSuites: authority.DefaultTLSCipherSuites, + EnableAdmin: false, }, + TLS: &authconfig.DefaultTLSOptions, Templates: p.getTemplates(), } - if p.enableSSH { - enableSSHCA := true - config.SSH = &authority.SSHConfig{ - HostKey: p.sshHostKey, - UserKey: p.sshUserKey, + + // Add linked as a deployment type to detect it on start and provide a + // message if the token is not given. + if p.options.deploymentType == LinkedDeployment { + cfg.AuthorityConfig.DeploymentType = LinkedDeployment.String() + } + + // Enable KMS if necessary + if p.Kms != nil { + cfg.KMS = &kmsapi.Options{ + Type: strings.ToLower(p.Kms.Type.String()), } - // Enable SSH authorization for default JWK provisioner - prov.Claims = &provisioner.Claims{ - EnableSSHCA: &enableSSHCA, + } + + // On standalone deployments add the provisioners to either the ca.json or + // the database. + var provisioners []provisioner.Interface + if p.options.deploymentType == StandaloneDeployment { + key, err := p.ottPrivateKey.CompactSerialize() + if err != nil { + return nil, errors.Wrap(err, "error serializing private key") } - // Add default SSHPOP provisioner - sshpop := &provisioner.SSHPOP{ - Type: "SSHPOP", - Name: "sshpop", - Claims: &provisioner.Claims{ + + prov := &provisioner.JWK{ + Name: p.options.provisioner, + Type: "JWK", + Key: p.ottPublicKey, + EncryptedKey: key, + } + provisioners = append(provisioners, prov) + + // Add default ACME provisioner if enabled + if p.options.enableACME { + provisioners = append(provisioners, &provisioner.ACME{ + Type: "ACME", + Name: "acme", + }) + } + + if p.options.enableSSH { + enableSSHCA := true + cfg.SSH = &authconfig.SSHConfig{ + HostKey: p.Ssh.HostKey, + UserKey: p.Ssh.UserKey, + } + // Enable SSH authorization for default JWK provisioner + prov.Claims = &provisioner.Claims{ EnableSSHCA: &enableSSHCA, - }, + } + + // Add default SSHPOP provisioner + provisioners = append(provisioners, &provisioner.SSHPOP{ + Type: "SSHPOP", + Name: "sshpop", + Claims: &provisioner.Claims{ + EnableSSHCA: &enableSSHCA, + }, + }) } - config.AuthorityConfig.Provisioners = append(config.AuthorityConfig.Provisioners, sshpop) } // Apply configuration modifiers for _, o := range opt { - if err = o(config); err != nil { + if err := o(cfg); err != nil { return nil, err } } - return config, nil + // Set authority.enableAdmin to true + if p.options.enableAdmin { + cfg.AuthorityConfig.EnableAdmin = true + } + + if p.options.deploymentType == StandaloneDeployment { + if !cfg.AuthorityConfig.EnableAdmin { + cfg.AuthorityConfig.Provisioners = provisioners + } else { + // At this moment this code path is never used because `step ca + // init` will always set enableAdmin to false for a standalone + // deployment. Once we move `step beta` commands out of the beta we + // should probably default to this route. + // + // Note that we might want to be able to define the database as a + // flag in `step ca init` so we can write to the proper place. + _db, err := db.New(cfg.DB) + if err != nil { + return nil, err + } + adminDB, err := admindb.New(_db.(nosql.DB), admin.DefaultAuthorityID) + if err != nil { + return nil, err + } + // Add all the provisioners to the db. + var adminID string + for i, p := range provisioners { + prov, err := authority.ProvisionerToLinkedca(p) + if err != nil { + return nil, err + } + if err := adminDB.CreateProvisioner(context.Background(), prov); err != nil { + return nil, err + } + if i == 0 { + adminID = prov.Id + } + } + // Add the first provisioner as an admin. + if err := adminDB.CreateAdmin(context.Background(), &linkedca.Admin{ + AuthorityId: admin.DefaultAuthorityID, + Subject: "step", + Type: linkedca.Admin_SUPER_ADMIN, + ProvisionerId: adminID, + }); err != nil { + return nil, err + } + } + } + + return cfg, nil } // Save stores the pki on a json file that will be used as the certificate // authority configuration. -func (p *PKI) Save(opt ...Option) error { +func (p *PKI) Save(opt ...ConfigOption) error { + // Write generated files + if err := p.WriteFiles(); err != nil { + return err + } + + // Display the files written p.tellPKI() // Generate and write ca.json - config, err := p.GenerateConfig(opt...) - if err != nil { - return err - } - - b, err := json.MarshalIndent(config, "", "\t") - if err != nil { - return errors.Wrapf(err, "error marshaling %s", p.config) - } - if err = fileutil.WriteFile(p.config, b, 0644); err != nil { - return errs.FileError(err, p.config) - } - - // Generate the CA URL. - if p.caURL == "" { - p.caURL = p.dnsNames[0] - var port string - _, port, err = net.SplitHostPort(p.address) + if !p.options.pkiOnly { + cfg, err := p.GenerateConfig(opt...) if err != nil { - return errors.Wrapf(err, "error parsing %s", p.address) + return err } - if port == "443" { - p.caURL = fmt.Sprintf("https://%s", p.caURL) - } else { - p.caURL = fmt.Sprintf("https://%s:%s", p.caURL, port) + + b, err := json.MarshalIndent(cfg, "", "\t") + if err != nil { + return errors.Wrapf(err, "error marshaling %s", p.config) + } + if err = fileutil.WriteFile(p.config, b, 0644); err != nil { + return errs.FileError(err, p.config) } - } - // Generate and write defaults.json - defaults := &caDefaults{ - Root: p.root, - CAConfig: p.config, - CAUrl: p.caURL, - Fingerprint: p.rootFingerprint, - } - b, err = json.MarshalIndent(defaults, "", "\t") - if err != nil { - return errors.Wrapf(err, "error marshaling %s", p.defaults) - } - if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil { - return errs.FileError(err, p.defaults) - } + // Generate and write defaults.json + defaults := &caDefaults{ + Root: p.Defaults.Root, + CAConfig: p.Defaults.CaConfig, + CAUrl: p.Defaults.CaUrl, + Fingerprint: p.Defaults.Fingerprint, + } + b, err = json.MarshalIndent(defaults, "", "\t") + if err != nil { + return errors.Wrapf(err, "error marshaling %s", p.defaults) + } + if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil { + return errs.FileError(err, p.defaults) + } - // Generate and write templates - if err := generateTemplates(config.Templates); err != nil { - return err - } + // Generate and write templates + if err := generateTemplates(cfg.Templates); err != nil { + return err + } - if config.DB != nil { - ui.PrintSelected("Database folder", config.DB.DataSource) - } - if config.Templates != nil { - ui.PrintSelected("Templates folder", GetTemplatesPath()) - } + if cfg.DB != nil { + ui.PrintSelected("Database folder", cfg.DB.DataSource) + } + if cfg.Templates != nil { + ui.PrintSelected("Templates folder", GetTemplatesPath()) + } - ui.PrintSelected("Default configuration", p.defaults) - ui.PrintSelected("Certificate Authority configuration", p.config) - ui.Println() - if p.casOptions.Is(apiv1.SoftCAS) { - ui.Println("Your PKI is ready to go. To generate certificates for individual services see 'step help ca'.") - } else { - ui.Println("Your registration authority is ready to go. To generate certificates for individual services see 'step help ca'.") + ui.PrintSelected("Default configuration", p.defaults) + ui.PrintSelected("Certificate Authority configuration", p.config) + if p.options.deploymentType != LinkedDeployment { + ui.Println() + if p.casOptions.Is(apiv1.SoftCAS) { + ui.Println("Your PKI is ready to go. To generate certificates for individual services see 'step help ca'.") + } else { + ui.Println("Your registration authority is ready to go. To generate certificates for individual services see 'step help ca'.") + } + } } p.askFeedback() - return nil } + +func encodeCertificate(c *x509.Certificate) []byte { + return pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: c.Raw, + }) +} + +func encodePrivateKey(key crypto.PrivateKey, pass []byte) ([]byte, error) { + block, err := pemutil.Serialize(key, pemutil.WithPassword(pass)) + if err != nil { + return nil, err + } + return pem.EncodeToMemory(block), nil +} diff --git a/pki/templates.go b/pki/templates.go index 4c5309bb..3506a96d 100644 --- a/pki/templates.go +++ b/pki/templates.go @@ -13,7 +13,7 @@ import ( // getTemplates returns all the templates enabled func (p *PKI) getTemplates() *templates.Templates { - if !p.enableSSH { + if !p.options.enableSSH { return nil } return &templates.Templates{ diff --git a/scep/api/api.go b/scep/api/api.go index e64eef83..4e02d4a1 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -198,14 +198,14 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return } - provisioner, ok := p.(*provisioner.SCEP) + prov, ok := p.(*provisioner.SCEP) if !ok { api.WriteError(w, errors.New("provisioner must be of type SCEP")) return } ctx := r.Context() - ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(provisioner)) + ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov)) next(w, r.WithContext(ctx)) } } diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 00000000..80d3cdba --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,4 @@ +# Scripts folder + +Please note that `install-step-ra.sh` is referenced on the `files.smallstep.com` S3 website bucket as a redirect to `raw.githubusercontent.com`. If you move it, please update the S3 redirect. + diff --git a/scripts/install-step-ra.sh b/scripts/install-step-ra.sh new file mode 100644 index 00000000..1da64ed6 --- /dev/null +++ b/scripts/install-step-ra.sh @@ -0,0 +1,278 @@ +#!/bin/bash +set -e + +# TODO: +# - Parse params using argbash (argbash.io). Here's a template that I have tested but have not implemented yet: +# +# ARG_OPTIONAL_SINGLE([ca-url], , [the URL of the upstream (issuing) step-ca server]) +# ARG_OPTIONAL_SINGLE([fingerprint], , [the SHA256 fingerprint of the upstream peer step-ca server]) +# ARG_OPTIONAL_SINGLE([provisioner-name], , [the name of a JWK provisioner on the upstream CA that this RA will use]) +# ARG_OPTIONAL_SINGLE([provisioner-password-file], , [the name a file containing the upstream JWK provisioner password]) +# ARG_OPTIONAL_REPEATED([dns-name], , [DNS name of this RA that will appear on its TLS certificate; you may pass this flag multiple times]) +# ARG_OPTIONAL_SINGLE([listen-address], , [the address (and port #) this RA will listen on, eg. :443 or 127.0.0.1:4443]) +# ARG_HELP([This script will install and configure a Registration Authority that connects to an upstream CA running step-ca.]) +# ARGBASH_GO + +echo "This script will install and start a step-ca server running in Registration Authority (RA) mode." +echo "" +echo "You will need an upstream CA (URL and fingerprint)" +echo "Don't have a CA? Sign up for a hosted CA at smallstep.com β€” or run your own." +echo "" + +# Fail if this script is not run as root. +if ! [ $(id -u) = 0 ]; then + echo "This script must be run as root" + exit 1 +fi + +# Architecture detection +arch=$(uname -m) +case $arch in + x86_64) arch="amd64" ;; + x86) arch="386" ;; + i686) arch="386" ;; + i386) arch="386" ;; + aarch64) arch="arm64" ;; + armv5*) arch="armv5" ;; + armv6*) arch="armv6" ;; + armv7*) arch="armv7" ;; +esac + +if [ "$arch" = "armv5" ]; then + echo "This script doesn't work on armv5 machines" + exit 1 +fi + +if ! hash jq &> /dev/null; then + echo "This script requires the jq commmand; please install it." + exit 1 +fi + +if ! hash curl &> /dev/null; then + echo "This script requires the curl commmand; please install it." + exit 1 +fi + +if ! hash tar &> /dev/null; then + echo "This script requires the tar commmand; please install it." + exit 1 +fi + +while [ $# -gt 0 ]; do + case "$1" in + --ca-url) + CA_URL="$2" + shift + shift + ;; + --fingerprint) + CA_FINGERPRINT="$2" + shift + shift + ;; + --provisioner-name) + CA_PROVISIONER_NAME="$2" + shift + shift + ;; + --provisioner-password-file) + CA_PROVISIONER_JWK_PASSWORD_FILE="$2" + shift + shift + ;; + --dns-names) + RA_DNS_NAMES="$2" + shift + shift + ;; + --listen-address) + RA_ADDRESS="$2" + shift + shift + ;; + *) + shift + ;; + esac +done + +# Install step +if ! hash step &> /dev/null; then + echo "Installing 'step' in /usr/bin..." + STEP_VERSION=$(curl -s https://api.github.com/repos/smallstep/cli/releases/latest | jq -r '.tag_name') + + curl -sLO https://github.com/smallstep/cli/releases/download/$STEP_VERSION/step_linux_${STEP_VERSION:1}_$arch.tar.gz + tar xvzf step_linux_${STEP_VERSION:1}_$arch.tar.gz + install -m 0755 -t /usr/bin step_${STEP_VERSION:1}/bin/step + + rm step_linux_${STEP_VERSION:1}_$arch.tar.gz + rm -rf step_${STEP_VERSION:1} +fi + +# Prompt for required parameters +if [ -z "$CA_URL" ]; then + CA_URL="" + while [[ $CA_URL = "" ]]; do + read -p "Issuing CA URL: " CA_URL < /dev/tty + done +fi + +if [ -z "$CA_FINGERPRINT" ]; then + CA_FINGERPRINT="" + while [[ $CA_FINGERPRINT = "" ]]; do + read -p "Issuing CA Fingerprint: " CA_FINGERPRINT < /dev/tty + done +fi + +echo "Bootstrapping with the CA..." +export STEPPATH=$(mktemp -d) + +step ca bootstrap --ca-url $CA_URL --fingerprint $CA_FINGERPRINT + +if [ -z "$CA_PROVISIONER_NAME" ]; then + declare -a provisioners + readarray -t provisioners < <(step ca provisioner list | jq -r '.[] | select(.type == "JWK") | .name') + printf '%s\n' "${provisioners[@]}" + + printf "%b" "\nSelect a JWK provisioner:\n" >&2 + select provisioner in "${provisioners[@]}"; do + if [ -n "$provisioner" ]; then + echo "Using existing provisioner $provisioner." + CA_PROVISIONER_NAME=$provisioner + break + else + echo "Invalid selection!" + fi + done +fi + +if [ -z "$RA_DNS_NAMES" ]; then + RA_DNS_NAMES="" + while [[ $RA_DNS_NAMES = "" ]]; do + echo "What DNS names or IP addresses will your RA use?" + read -p "(e.g. acme.example.com[,1.1.1.1,etc.]): " RA_DNS_NAMES < /dev/tty + done +fi + + +count=0 +ra_dns_names_quoted="" + +for i in ${RA_DNS_NAMES//,/ } +do + if [ "$count" = "0" ]; then + ra_dns_names_quoted="\"$i\"" + else + ra_dns_names_quoted="${ra_dns_names_quoted}, \"$i\"" + fi + count=$((count+1)) +done + +if [ "$count" = "0" ]; then + echo "You must supply at least one RA DNS name" + exit 1 +fi + +echo "Got here" + +if [ -z "$RA_ADDRESS" ]; then + RA_ADDRESS="" + while [[ $RA_ADDRESS = "" ]] ; do + echo "What address should your RA listen on?" + read -p "(e.g. :443 or 10.2.1.201:4430): " RA_ADDRESS < /dev/tty + done +fi + +if [ -z "$CA_PROVISIONER_JWK_PASSWORD_FILE" ]; then + read -s -p "Enter the CA Provisioner Password: " CA_PROVISIONER_JWK_PASSWORD < /dev/tty + printf "%b" "\n" +fi + +echo "Installing 'step-ca' in /usr/bin..." +CA_VERSION=$(curl -s https://api.github.com/repos/smallstep/certificates/releases/latest | jq -r '.tag_name') + +curl -sLO https://github.com/smallstep/certificates/releases/download/$CA_VERSION/step-ca_linux_${CA_VERSION:1}_$arch.tar.gz +tar -xf step-ca_linux_${CA_VERSION:1}_$arch.tar.gz +install -m 0755 -t /usr/bin step-ca_${CA_VERSION:1}/bin/step-ca +setcap CAP_NET_BIND_SERVICE=+eip $(which step-ca) +rm step-ca_linux_${CA_VERSION:1}_$arch.tar.gz +rm -rf step-ca_${CA_VERSION:1} + +echo "Creating 'step' user..." +export STEPPATH=/etc/step-ca + +useradd --system --home $(step path) --shell /bin/false step + +echo "Creating RA configuration..." +mkdir -p $(step path)/db +mkdir -p $(step path)/config + +cat < $(step path)/config/ca.json +{ + "address": "$RA_ADDRESS", + "dnsNames": [$ra_dns_names_quoted], + "db": { + "type": "badgerV2", + "dataSource": "/etc/step-ca/db" + }, + "logger": {"format": "text"}, + "authority": { + "type": "stepcas", + "certificateAuthority": "$CA_URL", + "certificateAuthorityFingerprint": "$CA_FINGERPRINT", + "certificateIssuer": { + "type" : "jwk", + "provisioner": "$CA_PROVISIONER_NAME" + }, + "provisioners": [{ + "type": "ACME", + "name": "acme" + }] + }, + "tls": { + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" + ], + "minVersion": 1.2, + "maxVersion": 1.3, + "renegotiation": false + } +} +EOF + +if ! [ -z "$CA_PROVISIONER_JWK_PASSWORD" ]; then + echo "Saving provisoiner password to $(step path)/password.txt..." + echo $CA_PROVISIONER_JWK_PASSWORD > $(step path)/password.txt +else + echo "Copying provisioner password file to $(step path)/password.txt..." + cp $CA_PROVISIONER_JWK_PASSWORD_FILE $(step path)/password.txt +fi +chmod 440 $(step path)/password.txt + +# Add a service to systemd for the RA. +echo "Creating systemd service step-ca.service..." +curl -sL https://raw.githubusercontent.com/smallstep/certificates/master/systemd/step-ca.service \ + -o /etc/systemd/system/step-ca.service + +echo "Creating RA mode override /etc/systemd/system/step-ca.service.d/local.conf..." +mkdir /etc/systemd/system/step-ca.service.d +cat < /etc/systemd/system/step-ca.service.d/local.conf +[Service] +; The empty ExecStart= clears the inherited ExecStart= value +ExecStart= +ExecStart=/usr/bin/step-ca config/ca.json --issuer-password-file password.txt +EOF + +echo "Starting step-ca.service..." +systemctl daemon-reload + +chown -R step:step $(step path) + +systemctl enable --now step-ca + +echo "Adding STEPPATH export to /root/.bash_profile..." +echo "export STEPPATH=$STEPPATH" >> /root/.bash_profile + +echo "Finished. Check the journal with journalctl -fu step-ca.service" + diff --git a/server/server.go b/server/server.go index d3968c4a..2b864148 100644 --- a/server/server.go +++ b/server/server.go @@ -72,10 +72,10 @@ func (srv *Server) Serve(ln net.Listener) error { // Start server if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) { log.Printf("Serving HTTP on %s ...", srv.Addr) - err = srv.Server.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) + err = srv.Server.Serve(ln) } else { log.Printf("Serving HTTPS on %s ...", srv.Addr) - err = srv.Server.ServeTLS(tcpKeepAliveListener{ln.(*net.TCPListener)}, "", "") + err = srv.Server.ServeTLS(ln, "", "") } // log unexpected errors @@ -155,21 +155,3 @@ func (srv *Server) Forbidden(w http.ResponseWriter) { w.WriteHeader(http.StatusForbidden) w.Write([]byte("Forbidden.\n")) } - -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - tc, err := ln.AcceptTCP() - if err != nil { - return - } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) - return tc, nil -} diff --git a/systemd/cert-renewer@.service b/systemd/cert-renewer@.service index f38951b5..a9962c2e 100644 --- a/systemd/cert-renewer@.service +++ b/systemd/cert-renewer@.service @@ -12,17 +12,13 @@ Environment=STEPPATH=/etc/step-ca \ CERT_LOCATION=/etc/step/certs/%i.crt \ KEY_LOCATION=/etc/step/certs/%i.key -; ExecStartPre checks if the certificate is ready for renewal, +; ExecCondition checks if the certificate is ready for renewal, ; based on the exit status of the command. -; (In systemd 243 and above, you can use ExecCondition= here.) -ExecStartPre=/usr/bin/env bash -c \ - 'step certificate inspect $CERT_LOCATION --format json --roots "$STEPPATH/certs/root_ca.crt" | \ - jq -e "(((.validity.start | fromdate) + \ - ((.validity.end | fromdate) - (.validity.start | fromdate)) * 0.66) \ - - now) <= 0" > /dev/null' +; (In systemd 242 or below, you can use ExecStartPre= here.) +ExecCondition=/usr/bin/step certificate needs-renewal ${CERT_LOCATION} ; ExecStart renews the certificate, if ExecStartPre was successful. -ExecStart=/usr/bin/step ca renew --force $CERT_LOCATION $KEY_LOCATION +ExecStart=/usr/bin/step ca renew --force ${CERT_LOCATION} ${KEY_LOCATION} ; Try to reload or restart the systemd service that relies on this cert-renewer ; If the relying service doesn't exist, forge ahead. diff --git a/templates/templates.go b/templates/templates.go index f98fb866..09416b68 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "os" "path/filepath" + "strings" "text/template" "github.com/Masterminds/sprig/v3" @@ -182,7 +183,7 @@ func (t *Template) Load() error { // the template fails. func (t *Template) LoadBytes(b []byte) error { t.backfill(b) - tmpl, err := template.New(t.Name).Funcs(sprig.TxtFuncMap()).Parse(string(b)) + tmpl, err := template.New(t.Name).Funcs(StepFuncMap()).Parse(string(b)) if err != nil { return errors.Wrapf(err, "error parsing template %s", t.Name) } @@ -226,14 +227,11 @@ func (t *Template) Output(data interface{}) (Output, error) { // backfill updates old templates with the required data. func (t *Template) backfill(b []byte) { - switch t.Name { - case "sshd_config.tpl": - if len(t.RequiredData) == 0 { - a := bytes.TrimSpace(b) - b := bytes.TrimSpace([]byte(DefaultSSHTemplateData[t.Name])) - if bytes.Equal(a, b) { - t.RequiredData = []string{"Certificate", "Key"} - } + if strings.EqualFold(t.Name, "sshd_config.tpl") && len(t.RequiredData) == 0 { + a := bytes.TrimSpace(b) + b := bytes.TrimSpace([]byte(DefaultSSHTemplateData[t.Name])) + if bytes.Equal(a, b) { + t.RequiredData = []string{"Certificate", "Key"} } } } @@ -272,3 +270,12 @@ func mkdir(path string, perm os.FileMode) error { } return nil } + +// StepFuncMap returns sprig.TxtFuncMap but removing the "env" and "expandenv" +// functions to avoid any leak of information. +func StepFuncMap() template.FuncMap { + m := sprig.TxtFuncMap() + delete(m, "env") + delete(m, "expandenv") + return m +}