forked from TrueCloudLab/certificates
Merge branch 'smallstep_master' into extractable
This commit is contained in:
commit
aa80bf9f07
190 changed files with 14980 additions and 2736 deletions
9
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
9
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: Ask on Discord
|
||||
url: https://discord.gg/7xgjhVAg6g
|
||||
about: You can ask for help here!
|
||||
- name: Want to contribute to step certificates?
|
||||
url: https://github.com/smallstep/certificates/blob/master/docs/CONTRIBUTING.md
|
||||
about: Be sure to read contributing guidelines!
|
||||
|
2
.github/ISSUE_TEMPLATE/enhancement.md
vendored
2
.github/ISSUE_TEMPLATE/enhancement.md
vendored
|
@ -1,5 +1,5 @@
|
|||
---
|
||||
name: Certificates Enhancement
|
||||
name: Enhancement
|
||||
about: Suggest an enhancement to step certificates
|
||||
title: ''
|
||||
labels: enhancement, needs triage
|
||||
|
|
10
.github/workflows/labeler.yml
vendored
10
.github/workflows/labeler.yml
vendored
|
@ -1,14 +1,12 @@
|
|||
name: labeler
|
||||
name: Pull Request Labeler
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
pull_request_target
|
||||
|
||||
jobs:
|
||||
label:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v3
|
||||
- uses: actions/labeler@v3.0.2
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
configuration-path: .github/needs-triage-labeler.yml
|
||||
|
||||
|
|
89
.github/workflows/release.yml
vendored
89
.github/workflows/release.yml
vendored
|
@ -12,7 +12,7 @@ jobs:
|
|||
runs-on: ubuntu-20.04
|
||||
strategy:
|
||||
matrix:
|
||||
go: [ '1.15', '1.16' ]
|
||||
go: [ '1.15', '1.16', '1.17' ]
|
||||
outputs:
|
||||
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||
steps:
|
||||
|
@ -62,8 +62,15 @@ jobs:
|
|||
needs: test
|
||||
runs-on: ubuntu-20.04
|
||||
outputs:
|
||||
debversion: ${{ steps.extract-tag.outputs.DEB_VERSION }}
|
||||
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||
steps:
|
||||
-
|
||||
name: Extract Tag Names
|
||||
id: extract-tag
|
||||
run: |
|
||||
DEB_VERSION=$(echo ${GITHUB_REF#refs/tags/v} | sed 's/-/./')
|
||||
echo "::set-output name=DEB_VERSION::${DEB_VERSION}"
|
||||
-
|
||||
name: Is Pre-release
|
||||
id: is_prerelease
|
||||
|
@ -99,62 +106,71 @@ jobs:
|
|||
name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.16
|
||||
-
|
||||
name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@56f5b77f7fa4a8fe068bf22b732ec036cc9bc13f # v2.4.1
|
||||
with:
|
||||
version: latest
|
||||
args: release --rm-dist
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.PAT }}
|
||||
|
||||
release_deb:
|
||||
name: Build & Upload Debian Package To Github
|
||||
runs-on: ubuntu-20.04
|
||||
needs: create_release
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: '1.16'
|
||||
go-version: 1.17
|
||||
-
|
||||
name: APT Install
|
||||
id: aptInstall
|
||||
run: sudo apt-get -y install build-essential debhelper fakeroot
|
||||
-
|
||||
name: Build Debian package
|
||||
id: build
|
||||
id: make_debian
|
||||
run: |
|
||||
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
|
||||
make debian
|
||||
# need to restore the git state otherwise goreleaser fails due to dirty state
|
||||
git restore debian/changelog
|
||||
git clean -fd
|
||||
-
|
||||
name: Upload Debian Package
|
||||
id: upload_deb
|
||||
name: Install cosign
|
||||
uses: sigstore/cosign-installer@v1.1.0
|
||||
with:
|
||||
cosign-release: 'v1.1.0'
|
||||
-
|
||||
name: Write cosign key to disk
|
||||
id: write_key
|
||||
run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key"
|
||||
-
|
||||
name: Get Release Date
|
||||
id: release_date
|
||||
run: |
|
||||
tag_name="${GITHUB_REF##*/}"
|
||||
hub release edit $(find ./.releases -type f -printf "-a %p ") -m "" "$tag_name"
|
||||
RELEASE_DATE=$(date +"%y-%m-%d")
|
||||
echo "::set-output name=RELEASE_DATE::${RELEASE_DATE}"
|
||||
-
|
||||
name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@5a54d7e660bda43b405e8463261b3d25631ffe86 # v2.7.0
|
||||
with:
|
||||
version: latest
|
||||
args: release --rm-dist
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_TOKEN: ${{ secrets.PAT }}
|
||||
COSIGN_PWD: ${{ secrets.COSIGN_PWD }}
|
||||
DEB_VERSION: ${{ needs.create_release.outputs.debversion }}
|
||||
RELEASE_DATE: ${{ steps.release_date.outputs.RELEASE_DATE }}
|
||||
|
||||
build_upload_docker:
|
||||
name: Build & Upload Docker Images
|
||||
runs-on: ubuntu-20.04
|
||||
needs: test
|
||||
steps:
|
||||
- name: Checkout
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
- name: Setup Go
|
||||
-
|
||||
name: Setup Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: '1.16'
|
||||
- name: Build
|
||||
go-version: '1.17'
|
||||
-
|
||||
name: Install cosign
|
||||
uses: sigstore/cosign-installer@v1.1.0
|
||||
with:
|
||||
cosign-release: 'v1.1.0'
|
||||
-
|
||||
name: Write cosign key to disk
|
||||
id: write_key
|
||||
run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key"
|
||||
-
|
||||
name: Build
|
||||
id: build
|
||||
run: |
|
||||
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
|
||||
|
@ -162,3 +178,4 @@ jobs:
|
|||
env:
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
|
||||
COSIGN_PWD: ${{ secrets.COSIGN_PWD }}
|
||||
|
|
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
@ -14,7 +14,7 @@ jobs:
|
|||
runs-on: ubuntu-20.04
|
||||
strategy:
|
||||
matrix:
|
||||
go: [ '1.15', '1.16' ]
|
||||
go: [ '1.15', '1.16', '1.17' ]
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
|
|
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -14,8 +14,8 @@
|
|||
|
||||
# Others
|
||||
*.swp
|
||||
.travis-releases
|
||||
.releases
|
||||
coverage.txt
|
||||
vendor
|
||||
output
|
||||
vendor
|
||||
.idea
|
||||
|
|
|
@ -8,7 +8,7 @@ linters-settings:
|
|||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf
|
||||
golint:
|
||||
revive:
|
||||
min-confidence: 0
|
||||
gocyclo:
|
||||
min-complexity: 10
|
||||
|
@ -36,22 +36,30 @@ linters-settings:
|
|||
- performance
|
||||
- style
|
||||
- experimental
|
||||
- diagnostic
|
||||
disabled-checks:
|
||||
- wrapperFunc
|
||||
- dupImport # https://github.com/go-critic/go-critic/issues/845
|
||||
- commentFormatting
|
||||
- commentedOutCode
|
||||
- evalOrder
|
||||
- hugeParam
|
||||
- octalLiteral
|
||||
- rangeValCopy
|
||||
- tooManyResultsChecker
|
||||
- unnamedResult
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- gofmt
|
||||
- golint
|
||||
- govet
|
||||
- misspell
|
||||
- ineffassign
|
||||
- deadcode
|
||||
- gocritic
|
||||
- gofmt
|
||||
- gosimple
|
||||
- govet
|
||||
- ineffassign
|
||||
- misspell
|
||||
- revive
|
||||
- staticcheck
|
||||
- unused
|
||||
- gosimple
|
||||
|
||||
run:
|
||||
skip-dirs:
|
||||
|
|
154
.goreleaser.yml
154
.goreleaser.yml
|
@ -1,34 +1,27 @@
|
|||
# This is an example .goreleaser.yml file with some sane defaults.
|
||||
# Make sure to check the documentation at http://goreleaser.com
|
||||
project_name: step-ca
|
||||
|
||||
before:
|
||||
hooks:
|
||||
# You may remove this if you don't use go modules.
|
||||
- go mod download
|
||||
|
||||
builds:
|
||||
-
|
||||
id: step-ca
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm
|
||||
- arm64
|
||||
- 386
|
||||
goarm:
|
||||
- 6
|
||||
- 7
|
||||
ignore:
|
||||
- goos: windows
|
||||
goarch: 386
|
||||
- goos: windows
|
||||
goarm: 6
|
||||
- goos: windows
|
||||
goarm: 7
|
||||
targets:
|
||||
- darwin_amd64
|
||||
- darwin_arm64
|
||||
- freebsd_amd64
|
||||
- linux_386
|
||||
- linux_amd64
|
||||
- linux_arm64
|
||||
- linux_arm_6
|
||||
- linux_arm_7
|
||||
- windows_amd64
|
||||
flags:
|
||||
- -trimpath
|
||||
main: ./cmd/step-ca/main.go
|
||||
|
@ -39,25 +32,16 @@ builds:
|
|||
id: step-cloudkms-init
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm
|
||||
- arm64
|
||||
- 386
|
||||
goarm:
|
||||
- 6
|
||||
- 7
|
||||
ignore:
|
||||
- goos: windows
|
||||
goarch: 386
|
||||
- goos: windows
|
||||
goarm: 6
|
||||
- goos: windows
|
||||
goarm: 7
|
||||
targets:
|
||||
- darwin_amd64
|
||||
- darwin_arm64
|
||||
- freebsd_amd64
|
||||
- linux_386
|
||||
- linux_amd64
|
||||
- linux_arm64
|
||||
- linux_arm_6
|
||||
- linux_arm_7
|
||||
- windows_amd64
|
||||
flags:
|
||||
- -trimpath
|
||||
main: ./cmd/step-cloudkms-init/main.go
|
||||
|
@ -68,31 +52,23 @@ builds:
|
|||
id: step-awskms-init
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm
|
||||
- arm64
|
||||
- 386
|
||||
goarm:
|
||||
- 6
|
||||
- 7
|
||||
ignore:
|
||||
- goos: windows
|
||||
goarch: 386
|
||||
- goos: windows
|
||||
goarm: 6
|
||||
- goos: windows
|
||||
goarm: 7
|
||||
targets:
|
||||
- darwin_amd64
|
||||
- darwin_arm64
|
||||
- freebsd_amd64
|
||||
- linux_386
|
||||
- linux_amd64
|
||||
- linux_arm64
|
||||
- linux_arm_6
|
||||
- linux_arm_7
|
||||
- windows_amd64
|
||||
flags:
|
||||
- -trimpath
|
||||
main: ./cmd/step-awskms-init/main.go
|
||||
binary: bin/step-awskms-init
|
||||
ldflags:
|
||||
- -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}}
|
||||
|
||||
archives:
|
||||
-
|
||||
# Can be used to change the archive formats for specific GOOSs.
|
||||
|
@ -106,13 +82,25 @@ archives:
|
|||
files:
|
||||
- README.md
|
||||
- LICENSE
|
||||
|
||||
source:
|
||||
enabled: true
|
||||
name_template: '{{ .ProjectName }}_{{ .Version }}'
|
||||
|
||||
checksum:
|
||||
name_template: 'checksums.txt'
|
||||
extra_files:
|
||||
- glob: ./.releases/*
|
||||
|
||||
signs:
|
||||
- cmd: cosign
|
||||
stdin: '{{ .Env.COSIGN_PWD }}'
|
||||
args: ["sign-blob", "-key=/tmp/cosign.key", "-output=${signature}", "${artifact}"]
|
||||
artifacts: all
|
||||
|
||||
snapshot:
|
||||
name_template: "{{ .Tag }}-next"
|
||||
|
||||
release:
|
||||
# Repo in which the release will be created.
|
||||
# Default is extracted from the origin remote URL or empty if its private hosted.
|
||||
|
@ -139,7 +127,55 @@ release:
|
|||
|
||||
# You can change the name of the release.
|
||||
# Default is `{{.Tag}}`
|
||||
#name_template: "{{.ProjectName}}-v{{.Version}} {{.Env.USER}}"
|
||||
name_template: "Step CA {{ .Tag }} ({{ .Env.RELEASE_DATE }})"
|
||||
|
||||
# Header template for the release body.
|
||||
# Defaults to empty.
|
||||
header: |
|
||||
## Official Release Artifacts
|
||||
|
||||
#### Linux
|
||||
|
||||
- 📦 [step-ca_linux_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_linux_{{ .Version }}_amd64.tar.gz)
|
||||
- 📦 [step-ca_{{ .Env.DEB_VERSION }}_amd64.deb](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ .Env.DEB_VERSION }}_amd64.deb)
|
||||
|
||||
#### OSX Darwin
|
||||
|
||||
- 📦 [step-ca_darwin_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_amd64.tar.gz)
|
||||
- 📦 [step-ca_darwin_{{ .Version }}_arm64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_arm64.tar.gz)
|
||||
|
||||
#### Windows
|
||||
|
||||
- 📦 [step-ca_windows_{{ .Version }}_arm64.zip](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_windows_{{ .Version }}_amd64.zip)
|
||||
|
||||
For more builds across platforms and architectures, see the `Assets` section below.
|
||||
And for packaged versions (Docker, k8s, Homebrew), see our [installation docs](https://smallstep.com/docs/step-ca/installation).
|
||||
|
||||
Don't see the artifact you need? Open an issue [here](https://github.com/smallstep/certificates/issues/new/choose).
|
||||
|
||||
## Signatures and Checksums
|
||||
|
||||
`step-ca` uses [sigstore/cosign](https://github.com/sigstore/cosign) for signing and verifying release artifacts.
|
||||
|
||||
Below is an example using `cosign` to verify a release artifact:
|
||||
|
||||
```
|
||||
cosign verify-blob \
|
||||
-key https://raw.githubusercontent.com/smallstep/certificates/master/cosign.pub \
|
||||
-signature ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig
|
||||
~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz
|
||||
```
|
||||
|
||||
The `checksums.txt` file (in the `Assets` section below) contains a checksum for every artifact in the release.
|
||||
|
||||
# Footer template for the release body.
|
||||
# Defaults to empty.
|
||||
footer: |
|
||||
## Thanks!
|
||||
|
||||
Those were the changes on {{ .Tag }}!
|
||||
|
||||
Come join us on [Discord](https://discord.gg/X2RKGwEbV9) to ask questions, chat about PKI, or get a sneak peak at the freshest PKI memes.
|
||||
|
||||
# You can disable this pipe in order to not upload any artifacts.
|
||||
# Defaults to false.
|
||||
|
@ -149,6 +185,8 @@ release:
|
|||
# The filename on the release will be the last part of the path (base). If
|
||||
# another file with the same name exists, the latest one found will be used.
|
||||
# Defaults to empty.
|
||||
extra_files:
|
||||
- glob: ./.releases/*
|
||||
#extra_files:
|
||||
# - glob: ./path/to/file.txt
|
||||
# - glob: ./glob/**/to/**/file/**/*
|
||||
|
|
53
CHANGELOG.md
53
CHANGELOG.md
|
@ -4,10 +4,61 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
||||
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased - 0.0.1] - DATE
|
||||
## [Unreleased - 0.17.7] - DATE
|
||||
### Added
|
||||
### Changed
|
||||
### Deprecated
|
||||
### Removed
|
||||
### Fixed
|
||||
### Security
|
||||
|
||||
## [0.17.6] - 2021-10-20
|
||||
### Notes
|
||||
- 0.17.5 failed in CI/CD
|
||||
|
||||
## [0.17.5] - 2021-10-20
|
||||
### Added
|
||||
- Support for Azure Key Vault as a KMS.
|
||||
- Adapt `pki` package to support key managers.
|
||||
- gocritic linter
|
||||
### Fixed
|
||||
- gocritic warnings
|
||||
|
||||
## [0.17.4] - 2021-09-28
|
||||
### Fixed
|
||||
- Support host-only or user-only SSH CA.
|
||||
|
||||
## [0.17.3] - 2021-09-24
|
||||
### Added
|
||||
- go 1.17 to github action test matrix
|
||||
- Support for CloudKMS RSA-PSS signers without using templates.
|
||||
- Add flags to support individual passwords for the intermediate and SSH keys.
|
||||
- Global support for group admins in the OIDC provisioner.
|
||||
### Changed
|
||||
- Using go 1.17 for binaries
|
||||
### Fixed
|
||||
- Upgrade go-jose.v2 to fix a bug in the JWK fingerprint of Ed25519 keys.
|
||||
### Security
|
||||
- Use cosign to sign and upload signatures for multi-arch Docker container.
|
||||
- Add debian checksum
|
||||
|
||||
## [0.17.2] - 2021-08-30
|
||||
### Added
|
||||
- Additional way to distinguish Azure IID and Azure OIDC tokens.
|
||||
### Security
|
||||
- Sign over all goreleaser github artifacts using cosign
|
||||
|
||||
## [0.17.1] - 2021-08-26
|
||||
|
||||
## [0.17.0] - 2021-08-25
|
||||
### Added
|
||||
- Add support for Linked CAs using protocol buffers and gRPC
|
||||
- `step-ca init` adds support for
|
||||
- configuring a StepCAS RA
|
||||
- configuring a Linked CA
|
||||
- congifuring a `step-ca` using Helm
|
||||
### Changed
|
||||
- Update badger driver to use v2 by default
|
||||
- Update TLS cipher suites to include 1.3
|
||||
### Security
|
||||
- Fix key version when SHA512WithRSA is used. There was a typo creating RSA keys with SHA256 digests instead of SHA512.
|
||||
|
|
7
Makefile
7
Makefile
|
@ -15,6 +15,7 @@ PREFIX?=
|
|||
SRC=$(shell find . -type f -name '*.go' -not -path "./vendor/*")
|
||||
GOOS_OVERRIDE ?=
|
||||
OUTPUT_ROOT=output/
|
||||
RELEASE=./.releases
|
||||
|
||||
all: lint test build
|
||||
|
||||
|
@ -28,7 +29,7 @@ ci: testcgo build
|
|||
|
||||
bootstra%:
|
||||
# Using a released version of golangci-lint to take into account custom replacements in their go.mod
|
||||
$Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.39.0
|
||||
$Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v1.42.0
|
||||
|
||||
.PHONY: bootstra%
|
||||
|
||||
|
@ -67,7 +68,7 @@ PUSHTYPE := branch
|
|||
endif
|
||||
|
||||
VERSION := $(shell echo $(VERSION) | sed 's/^v//')
|
||||
DEB_VERSION := $(shell echo $(VERSION) | sed 's/-/~/g')
|
||||
DEB_VERSION := $(shell echo $(VERSION) | sed 's/-/./g')
|
||||
|
||||
ifdef V
|
||||
$(info TRAVIS_TAG is $(TRAVIS_TAG))
|
||||
|
@ -153,7 +154,7 @@ fmt:
|
|||
$Q gofmt -l -w $(SRC)
|
||||
|
||||
lint:
|
||||
$Q $(GOFLAGS) LOG_LEVEL=error golangci-lint run --timeout=30m
|
||||
$Q golangci-lint run --timeout=30m
|
||||
|
||||
lintcgo:
|
||||
$Q LOG_LEVEL=error golangci-lint run --timeout=30m
|
||||
|
|
36
README.md
36
README.md
|
@ -18,7 +18,14 @@ You can use it to:
|
|||
|
||||
Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [safe, sane defaults](https://smallstep.com/docs/step-ca/certificate-authority-server-production#sane-cryptographic-defaults).
|
||||
|
||||
**Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions).**
|
||||
---
|
||||
|
||||
**Don't want to run your own CA?**
|
||||
To get up and running quickly, or as an alternative to running your own `step-ca` server, consider creating a [free hosted smallstep Certificate Manager authority](https://info.smallstep.com/certificate-manager-early-access-mvp/).
|
||||
|
||||
---
|
||||
|
||||
**Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions) or [Join our Discord](https://u.step.sm/discord).**
|
||||
|
||||
[Website](https://smallstep.com/certificates) |
|
||||
[Documentation](https://smallstep.com/docs) |
|
||||
|
@ -27,7 +34,6 @@ Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [
|
|||
[Contributor's Guide](./docs/CONTRIBUTING.md)
|
||||
|
||||
[![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest)
|
||||
[![CA Image](https://images.microbadger.com/badges/image/smallstep/step-ca.svg)](https://microbadger.com/images/smallstep/step-ca)
|
||||
[![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates)
|
||||
[![Build Status](https://travis-ci.com/smallstep/certificates.svg?branch=master)](https://travis-ci.com/smallstep/certificates)
|
||||
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
|
||||
|
@ -58,10 +64,11 @@ You can issue certificates in exchange for:
|
|||
- ID tokens from Okta, GSuite, Azure AD, Auth0.
|
||||
- ID tokens from an OAuth OIDC service that you host, like [Keycloak](https://www.keycloak.org/) or [Dex](https://github.com/dexidp/dex)
|
||||
- [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure
|
||||
- [Single-use, short-lived JWK tokens]() issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
|
||||
- [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
|
||||
- A trusted X.509 certificate (X5C provisioner)
|
||||
- Expiring SSH host certificates needing rotation (the SSHPOP provisioner)
|
||||
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/configuration#jwk)
|
||||
- A SCEP challenge (SCEP provisioner)
|
||||
- An SSH host certificates needing renewal (the SSHPOP provisioner)
|
||||
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners)
|
||||
|
||||
### 🏔 Your own private ACME server
|
||||
|
||||
|
@ -74,16 +81,17 @@ ACME is the protocol used by Let's Encrypt to automate the issuance of HTTPS cer
|
|||
- For `tls-alpn-01`, respond to the challenge at the TLS layer ([as Caddy does](https://caddy.community/t/caddy-supports-the-acme-tls-alpn-challenge/4860)) to prove that you control the web server
|
||||
|
||||
- Works with any ACME client. We've written examples for:
|
||||
- [certbot](https://smallstep.com/blog/private-acme-server/#certbotuploadsacme-certbotpng-certbot-example)
|
||||
- [acme.sh](https://smallstep.com/blog/private-acme-server/#acmeshuploadsacme-acme-shpng-acmesh-example)
|
||||
- [Caddy](https://smallstep.com/blog/private-acme-server/#caddyuploadsacme-caddypng-caddy-example)
|
||||
- [Traefik](https://smallstep.com/blog/private-acme-server/#traefikuploadsacme-traefikpng-traefik-example)
|
||||
- [Apache](https://smallstep.com/blog/private-acme-server/#apacheuploadsacme-apachepng-apache-example)
|
||||
- [nginx](https://smallstep.com/blog/private-acme-server/#nginxuploadsacme-nginxpng-nginx-example)
|
||||
- [certbot](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#certbot)
|
||||
- [acme.sh](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#acmesh)
|
||||
- [win-acme](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#win-acme)
|
||||
- [Caddy](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#caddy-v2)
|
||||
- [Traefik](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#traefik)
|
||||
- [Apache](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#apache)
|
||||
- [nginx](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#nginx)
|
||||
- Get certificates programmatically using ACME, using these libraries:
|
||||
- [`lego`](https://github.com/go-acme/lego) for Golang ([example usage](https://smallstep.com/blog/private-acme-server/#golanguploadsacme-golangpng-go-example))
|
||||
- certbot's [`acme` module](https://github.com/certbot/certbot/tree/master/acme) for Python ([example usage](https://smallstep.com/blog/private-acme-server/#pythonuploadsacme-pythonpng-python-example))
|
||||
- [`acme-client`](https://github.com/publishlab/node-acme-client) for Node.js ([example usage](https://smallstep.com/blog/private-acme-server/#nodejsuploadsacme-node-jspng-nodejs-example))
|
||||
- [`lego`](https://github.com/go-acme/lego) for Golang ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#golang))
|
||||
- certbot's [`acme` module](https://github.com/certbot/certbot/tree/master/acme) for Python ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#python))
|
||||
- [`acme-client`](https://github.com/publishlab/node-acme-client) for Node.js ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#node))
|
||||
- Our own [`step` CLI tool](https://github.com/smallstep/cli) is also an ACME client!
|
||||
- See our [ACME tutorial](https://smallstep.com/docs/tutorials/acme-challenge) for more
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ type NewAccountRequest struct {
|
|||
|
||||
func validateContacts(cs []string) error {
|
||||
for _, c := range cs {
|
||||
if len(c) == 0 {
|
||||
if c == "" {
|
||||
return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -178,7 +178,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
|||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
|
||||
u := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
|
||||
|
||||
oids := []string{"foo", "bar"}
|
||||
oidURLs := []string{
|
||||
|
@ -255,7 +255,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrdersByAccountID(w, req)
|
||||
|
|
|
@ -64,8 +64,14 @@ type HandlerOptions struct {
|
|||
|
||||
// NewHandler returns a new ACME API handler.
|
||||
func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
|
|
|
@ -148,7 +148,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("authzID", az.ID)
|
||||
url := fmt.Sprintf("%s/acme/%s/authz/%s",
|
||||
u := fmt.Sprintf("%s/acme/%s/authz/%s",
|
||||
baseURL.String(), provName, az.ID)
|
||||
|
||||
type test struct {
|
||||
|
@ -280,7 +280,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
expB, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Location"], []string{u})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
|
@ -314,7 +314,7 @@ func TestHandler_GetCertificate(t *testing.T) {
|
|||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("certID", certID)
|
||||
url := fmt.Sprintf("%s/acme/%s/certificate/%s",
|
||||
u := fmt.Sprintf("%s/acme/%s/certificate/%s",
|
||||
baseURL.String(), provName, certID)
|
||||
|
||||
type test struct {
|
||||
|
@ -396,7 +396,7 @@ func TestHandler_GetCertificate(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetCertificate(w, req)
|
||||
|
@ -434,7 +434,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
|
||||
url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
|
||||
u := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
|
||||
baseURL.String(), provName, "authzID", "chID")
|
||||
|
||||
type test struct {
|
||||
|
@ -574,13 +574,13 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
assert.Equals(t, azID, "authzID")
|
||||
return &acme.Challenge{
|
||||
Status: acme.StatusPending,
|
||||
Type: "http-01",
|
||||
Type: acme.HTTP01,
|
||||
AccountID: "accID",
|
||||
}, nil
|
||||
},
|
||||
MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
assert.Equals(t, ch.Status, acme.StatusPending)
|
||||
assert.Equals(t, ch.Type, "http-01")
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
assert.Equals(t, ch.AccountID, "accID")
|
||||
assert.Equals(t, ch.AuthorizationID, "authzID")
|
||||
assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String())
|
||||
|
@ -616,13 +616,13 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
return &acme.Challenge{
|
||||
ID: "chID",
|
||||
Status: acme.StatusPending,
|
||||
Type: "http-01",
|
||||
Type: acme.HTTP01,
|
||||
AccountID: "accID",
|
||||
}, nil
|
||||
},
|
||||
MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
assert.Equals(t, ch.Status, acme.StatusPending)
|
||||
assert.Equals(t, ch.Type, "http-01")
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
assert.Equals(t, ch.AccountID, "accID")
|
||||
assert.Equals(t, ch.AuthorizationID, "authzID")
|
||||
assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String())
|
||||
|
@ -633,9 +633,9 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
ID: "chID",
|
||||
Status: acme.StatusPending,
|
||||
AuthorizationID: "authzID",
|
||||
Type: "http-01",
|
||||
Type: acme.HTTP01,
|
||||
AccountID: "accID",
|
||||
URL: url,
|
||||
URL: u,
|
||||
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
||||
},
|
||||
vco: &acme.ValidateChallengeOptions{
|
||||
|
@ -652,7 +652,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetChallenge(w, req)
|
||||
|
@ -678,7 +678,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")})
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Location"], []string{u})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
|
|
|
@ -223,7 +223,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
|||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
||||
return
|
||||
}
|
||||
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
|
||||
if hdr.JSONWebKey == nil && hdr.KeyID == "" {
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
||||
return
|
||||
}
|
||||
|
@ -288,13 +288,13 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
name := chi.URLParam(r, "provisionerID")
|
||||
provID, err := url.PathUnescape(name)
|
||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||
name, err := url.PathUnescape(nameEscaped)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner id '%s'", name))
|
||||
api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
return
|
||||
}
|
||||
p, err := h.ca.LoadProvisionerByID("acme/" + provID)
|
||||
p, err := h.ca.LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
|
@ -367,7 +367,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
|||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -108,7 +108,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandler_addNonce(t *testing.T) {
|
||||
url := "https://ca.smallstep.com/acme/new-nonce"
|
||||
u := "https://ca.smallstep.com/acme/new-nonce"
|
||||
type test struct {
|
||||
db acme.DB
|
||||
err *acme.Error
|
||||
|
@ -141,7 +141,7 @@ func TestHandler_addNonce(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.addNonce(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
@ -230,7 +230,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
prov := newProv()
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
||||
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
||||
type test struct {
|
||||
h Handler
|
||||
ctx context.Context
|
||||
|
@ -245,7 +245,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
url: url,
|
||||
url: u,
|
||||
ctx: context.Background(),
|
||||
contentType: "foo",
|
||||
statusCode: 500,
|
||||
|
@ -257,7 +257,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
url: url,
|
||||
url: u,
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
contentType: "foo",
|
||||
statusCode: 400,
|
||||
|
@ -319,11 +319,11 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
_url := url
|
||||
_u := u
|
||||
if tc.url != "" {
|
||||
_url = tc.url
|
||||
_u = tc.url
|
||||
}
|
||||
req := httptest.NewRequest("GET", _url, nil)
|
||||
req := httptest.NewRequest("GET", _u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req.Header.Add("Content-Type", tc.contentType)
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -353,7 +353,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandler_isPostAsGet(t *testing.T) {
|
||||
url := "https://ca.smallstep.com/acme/new-account"
|
||||
u := "https://ca.smallstep.com/acme/new-account"
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
err *acme.Error
|
||||
|
@ -392,7 +392,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.isPostAsGet(testNext)(w, req)
|
||||
|
@ -430,7 +430,7 @@ func (errReader) Close() error {
|
|||
}
|
||||
|
||||
func TestHandler_parseJWS(t *testing.T) {
|
||||
url := "https://ca.smallstep.com/acme/new-account"
|
||||
u := "https://ca.smallstep.com/acme/new-account"
|
||||
type test struct {
|
||||
next nextHTTP
|
||||
body io.Reader
|
||||
|
@ -483,7 +483,7 @@ func TestHandler_parseJWS(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest("GET", url, tc.body)
|
||||
req := httptest.NewRequest("GET", u, tc.body)
|
||||
w := httptest.NewRecorder()
|
||||
h.parseJWS(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
@ -528,7 +528,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
url := "https://ca.smallstep.com/acme/account/1234"
|
||||
u := "https://ca.smallstep.com/acme/account/1234"
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
next func(http.ResponseWriter, *http.Request)
|
||||
|
@ -681,7 +681,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.verifyAndExtractJWSPayload(tc.next)(w, req)
|
||||
|
@ -713,7 +713,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
url := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||
baseURL, provName)
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -883,7 +883,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: tc.linker}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.lookupJWK(tc.next)(w, req)
|
||||
|
@ -934,7 +934,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234",
|
||||
u := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234",
|
||||
provName)
|
||||
type test struct {
|
||||
db acme.DB
|
||||
|
@ -1079,7 +1079,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.extractJWK(tc.next)(w, req)
|
||||
|
@ -1108,7 +1108,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandler_validateJWS(t *testing.T) {
|
||||
url := "https://ca.smallstep.com/acme/account/1234"
|
||||
u := "https://ca.smallstep.com/acme/account/1234"
|
||||
type test struct {
|
||||
db acme.DB
|
||||
ctx context.Context
|
||||
|
@ -1198,7 +1198,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
Algorithm: jose.RS256,
|
||||
JSONWebKey: &pub,
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"url": url,
|
||||
"url": u,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1226,7 +1226,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
Algorithm: jose.RS256,
|
||||
JSONWebKey: &pub,
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"url": url,
|
||||
"url": u,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1298,7 +1298,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", url),
|
||||
err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", u),
|
||||
}
|
||||
},
|
||||
"fail/both-jwk-kid": func(t *testing.T) test {
|
||||
|
@ -1313,7 +1313,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
KeyID: "bar",
|
||||
JSONWebKey: &pub,
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"url": url,
|
||||
"url": u,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1337,7 +1337,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
Protected: jose.Header{
|
||||
Algorithm: jose.ES256,
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"url": url,
|
||||
"url": u,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1362,7 +1362,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
Algorithm: jose.ES256,
|
||||
KeyID: "bar",
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"url": url,
|
||||
"url": u,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1392,7 +1392,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
Algorithm: jose.ES256,
|
||||
JSONWebKey: &pub,
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"url": url,
|
||||
"url": u,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1422,7 +1422,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
Algorithm: jose.RS256,
|
||||
JSONWebKey: &pub,
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"url": url,
|
||||
"url": u,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1446,7 +1446,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.validateJWS(tc.next)(w, req)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -28,9 +29,12 @@ func (n *NewOrderRequest) Validate() error {
|
|||
return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty")
|
||||
}
|
||||
for _, id := range n.Identifiers {
|
||||
if id.Type != "dns" {
|
||||
if !(id.Type == acme.DNS || id.Type == acme.IP) {
|
||||
return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type)
|
||||
}
|
||||
if id.Type == acme.IP && net.ParseIP(id.Value) == nil {
|
||||
return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -85,6 +89,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
"failed to unmarshal new-order request payload"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := nor.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
|
@ -149,15 +154,9 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
|
|||
}
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
chTypes = []string{"dns-01"}
|
||||
)
|
||||
// HTTP and TLS challenges can only be used for identifiers without wildcards.
|
||||
if !az.Wildcard {
|
||||
chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...)
|
||||
}
|
||||
chTypes := challengeTypes(az)
|
||||
|
||||
var err error
|
||||
az.Token, err = randutil.Alphanumeric(32)
|
||||
if err != nil {
|
||||
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
|
||||
|
@ -275,3 +274,24 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
|||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
api.JSON(w, o)
|
||||
}
|
||||
|
||||
// challengeTypes determines the types of challenges that should be used
|
||||
// for the ACME authorization request.
|
||||
func challengeTypes(az *acme.Authorization) []acme.ChallengeType {
|
||||
var chTypes []acme.ChallengeType
|
||||
|
||||
switch az.Identifier.Type {
|
||||
case acme.IP:
|
||||
chTypes = []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}
|
||||
case acme.DNS:
|
||||
chTypes = []acme.ChallengeType{acme.DNS01}
|
||||
// HTTP and TLS challenges can only be used for identifiers without wildcards.
|
||||
if !az.Wildcard {
|
||||
chTypes = append(chTypes, []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}...)
|
||||
}
|
||||
default:
|
||||
chTypes = []acme.ChallengeType{}
|
||||
}
|
||||
|
||||
return chTypes
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -44,6 +45,22 @@ func TestNewOrderRequest_Validate(t *testing.T) {
|
|||
err: acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: foo"),
|
||||
}
|
||||
},
|
||||
"fail/bad-ip": func(t *testing.T) test {
|
||||
nbf := time.Now().UTC().Add(time.Minute)
|
||||
naf := time.Now().UTC().Add(5 * time.Minute)
|
||||
return test{
|
||||
nor: &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "ip", Value: "192.168.42.1000"},
|
||||
},
|
||||
NotAfter: naf,
|
||||
NotBefore: nbf,
|
||||
},
|
||||
nbf: nbf,
|
||||
naf: naf,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", "192.168.42.1000"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
nbf := time.Now().UTC().Add(time.Minute)
|
||||
naf := time.Now().UTC().Add(5 * time.Minute)
|
||||
|
@ -60,6 +77,68 @@ func TestNewOrderRequest_Validate(t *testing.T) {
|
|||
naf: naf,
|
||||
}
|
||||
},
|
||||
"ok/ipv4": func(t *testing.T) test {
|
||||
nbf := time.Now().UTC().Add(time.Minute)
|
||||
naf := time.Now().UTC().Add(5 * time.Minute)
|
||||
return test{
|
||||
nor: &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
},
|
||||
NotAfter: naf,
|
||||
NotBefore: nbf,
|
||||
},
|
||||
nbf: nbf,
|
||||
naf: naf,
|
||||
}
|
||||
},
|
||||
"ok/ipv6": func(t *testing.T) test {
|
||||
nbf := time.Now().UTC().Add(time.Minute)
|
||||
naf := time.Now().UTC().Add(5 * time.Minute)
|
||||
return test{
|
||||
nor: &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "ip", Value: "2001:db8::1"},
|
||||
},
|
||||
NotAfter: naf,
|
||||
NotBefore: nbf,
|
||||
},
|
||||
nbf: nbf,
|
||||
naf: naf,
|
||||
}
|
||||
},
|
||||
"ok/mixed-dns-and-ipv4": func(t *testing.T) test {
|
||||
nbf := time.Now().UTC().Add(time.Minute)
|
||||
naf := time.Now().UTC().Add(5 * time.Minute)
|
||||
return test{
|
||||
nor: &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
},
|
||||
NotAfter: naf,
|
||||
NotBefore: nbf,
|
||||
},
|
||||
nbf: nbf,
|
||||
naf: naf,
|
||||
}
|
||||
},
|
||||
"ok/mixed-ipv4-and-ipv6": func(t *testing.T) test {
|
||||
nbf := time.Now().UTC().Add(time.Minute)
|
||||
naf := time.Now().UTC().Add(5 * time.Minute)
|
||||
return test{
|
||||
nor: &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
{Type: "ip", Value: "2001:db8::1"},
|
||||
},
|
||||
NotAfter: naf,
|
||||
NotBefore: nbf,
|
||||
},
|
||||
nbf: nbf,
|
||||
naf: naf,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
|
@ -185,7 +264,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("ordID", o.ID)
|
||||
url := fmt.Sprintf("%s/acme/%s/order/%s",
|
||||
u := fmt.Sprintf("%s/acme/%s/order/%s",
|
||||
baseURL.String(), escProvName, o.ID)
|
||||
|
||||
type test struct {
|
||||
|
@ -343,7 +422,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrder(w, req)
|
||||
|
@ -369,7 +448,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Location"], []string{u})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
|
@ -395,7 +474,7 @@ func TestHandler_newAuthorization(t *testing.T) {
|
|||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
assert.Equals(t, ch.AccountID, az.AccountID)
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
assert.Equals(t, ch.Token, az.Token)
|
||||
assert.Equals(t, ch.Status, acme.StatusPending)
|
||||
assert.Equals(t, ch.Value, az.Identifier.Value)
|
||||
|
@ -424,15 +503,15 @@ func TestHandler_newAuthorization(t *testing.T) {
|
|||
switch count {
|
||||
case 0:
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
ch1 = &ch
|
||||
case 1:
|
||||
ch.ID = "http"
|
||||
assert.Equals(t, ch.Type, "http-01")
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
ch2 = &ch
|
||||
case 2:
|
||||
ch.ID = "tls"
|
||||
assert.Equals(t, ch.Type, "tls-alpn-01")
|
||||
assert.Equals(t, ch.Type, acme.TLSALPN01)
|
||||
ch3 = &ch
|
||||
default:
|
||||
assert.FatalError(t, errors.New("test logic error"))
|
||||
|
@ -478,15 +557,15 @@ func TestHandler_newAuthorization(t *testing.T) {
|
|||
switch count {
|
||||
case 0:
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
ch1 = &ch
|
||||
case 1:
|
||||
ch.ID = "http"
|
||||
assert.Equals(t, ch.Type, "http-01")
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
ch2 = &ch
|
||||
case 2:
|
||||
ch.ID = "tls"
|
||||
assert.Equals(t, ch.Type, "tls-alpn-01")
|
||||
assert.Equals(t, ch.Type, acme.TLSALPN01)
|
||||
ch3 = &ch
|
||||
default:
|
||||
assert.FatalError(t, errors.New("test logic error"))
|
||||
|
@ -528,7 +607,7 @@ func TestHandler_newAuthorization(t *testing.T) {
|
|||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
assert.Equals(t, ch.AccountID, az.AccountID)
|
||||
assert.Equals(t, ch.Token, az.Token)
|
||||
assert.Equals(t, ch.Status, acme.StatusPending)
|
||||
|
@ -584,7 +663,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
prov := newProv()
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
url := fmt.Sprintf("%s/acme/%s/order/ordID",
|
||||
u := fmt.Sprintf("%s/acme/%s/order/ordID",
|
||||
baseURL.String(), escProvName)
|
||||
|
||||
type test struct {
|
||||
|
@ -695,7 +774,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
assert.Equals(t, ch.AccountID, "accID")
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
assert.NotEquals(t, ch.Token, "")
|
||||
assert.Equals(t, ch.Status, acme.StatusPending)
|
||||
assert.Equals(t, ch.Value, "zap.internal")
|
||||
|
@ -730,15 +809,15 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
switch count {
|
||||
case 0:
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
ch1 = &ch
|
||||
case 1:
|
||||
ch.ID = "http"
|
||||
assert.Equals(t, ch.Type, "http-01")
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
ch2 = &ch
|
||||
case 2:
|
||||
ch.ID = "tls"
|
||||
assert.Equals(t, ch.Type, "tls-alpn-01")
|
||||
assert.Equals(t, ch.Type, acme.TLSALPN01)
|
||||
ch3 = &ch
|
||||
default:
|
||||
assert.FatalError(t, errors.New("test logic error"))
|
||||
|
@ -802,22 +881,22 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
switch chCount {
|
||||
case 0:
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
assert.Equals(t, ch.Value, "zap.internal")
|
||||
ch1 = &ch
|
||||
case 1:
|
||||
ch.ID = "http"
|
||||
assert.Equals(t, ch.Type, "http-01")
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
assert.Equals(t, ch.Value, "zap.internal")
|
||||
ch2 = &ch
|
||||
case 2:
|
||||
ch.ID = "tls"
|
||||
assert.Equals(t, ch.Type, "tls-alpn-01")
|
||||
assert.Equals(t, ch.Type, acme.TLSALPN01)
|
||||
assert.Equals(t, ch.Value, "zap.internal")
|
||||
ch3 = &ch
|
||||
case 3:
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
assert.Equals(t, ch.Value, "zar.internal")
|
||||
ch4 = &ch
|
||||
default:
|
||||
|
@ -842,7 +921,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
az.ID = "az2ID"
|
||||
az2ID = &az.ID
|
||||
assert.Equals(t, az.Identifier, acme.Identifier{
|
||||
Type: "dns",
|
||||
Type: acme.DNS,
|
||||
Value: "zar.internal",
|
||||
})
|
||||
assert.Equals(t, az.Wildcard, true)
|
||||
|
@ -917,15 +996,15 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
switch count {
|
||||
case 0:
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
ch1 = &ch
|
||||
case 1:
|
||||
ch.ID = "http"
|
||||
assert.Equals(t, ch.Type, "http-01")
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
ch2 = &ch
|
||||
case 2:
|
||||
ch.ID = "tls"
|
||||
assert.Equals(t, ch.Type, "tls-alpn-01")
|
||||
assert.Equals(t, ch.Type, acme.TLSALPN01)
|
||||
ch3 = &ch
|
||||
default:
|
||||
assert.FatalError(t, errors.New("test logic error"))
|
||||
|
@ -1009,15 +1088,15 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
switch count {
|
||||
case 0:
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
ch1 = &ch
|
||||
case 1:
|
||||
ch.ID = "http"
|
||||
assert.Equals(t, ch.Type, "http-01")
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
ch2 = &ch
|
||||
case 2:
|
||||
ch.ID = "tls"
|
||||
assert.Equals(t, ch.Type, "tls-alpn-01")
|
||||
assert.Equals(t, ch.Type, acme.TLSALPN01)
|
||||
ch3 = &ch
|
||||
default:
|
||||
assert.FatalError(t, errors.New("test logic error"))
|
||||
|
@ -1100,15 +1179,15 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
switch count {
|
||||
case 0:
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
ch1 = &ch
|
||||
case 1:
|
||||
ch.ID = "http"
|
||||
assert.Equals(t, ch.Type, "http-01")
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
ch2 = &ch
|
||||
case 2:
|
||||
ch.ID = "tls"
|
||||
assert.Equals(t, ch.Type, "tls-alpn-01")
|
||||
assert.Equals(t, ch.Type, acme.TLSALPN01)
|
||||
ch3 = &ch
|
||||
default:
|
||||
assert.FatalError(t, errors.New("test logic error"))
|
||||
|
@ -1192,15 +1271,15 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
switch count {
|
||||
case 0:
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, "dns-01")
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
ch1 = &ch
|
||||
case 1:
|
||||
ch.ID = "http"
|
||||
assert.Equals(t, ch.Type, "http-01")
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
ch2 = &ch
|
||||
case 2:
|
||||
ch.ID = "tls"
|
||||
assert.Equals(t, ch.Type, "tls-alpn-01")
|
||||
assert.Equals(t, ch.Type, acme.TLSALPN01)
|
||||
ch3 = &ch
|
||||
default:
|
||||
assert.FatalError(t, errors.New("test logic error"))
|
||||
|
@ -1256,7 +1335,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.NewOrder(w, req)
|
||||
|
@ -1284,7 +1363,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
tc.vr(t, ro)
|
||||
}
|
||||
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Location"], []string{u})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
|
@ -1327,7 +1406,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("ordID", o.ID)
|
||||
url := fmt.Sprintf("%s/acme/%s/order/%s",
|
||||
u := fmt.Sprintf("%s/acme/%s/order/%s",
|
||||
baseURL.String(), escProvName, o.ID)
|
||||
|
||||
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
||||
|
@ -1546,7 +1625,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.FinalizeOrder(w, req)
|
||||
|
@ -1575,9 +1654,58 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
assert.FatalError(t, json.Unmarshal(body, ro))
|
||||
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Location"], []string{u})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_challengeTypes(t *testing.T) {
|
||||
type args struct {
|
||||
az *acme.Authorization
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []acme.ChallengeType
|
||||
}{
|
||||
{
|
||||
name: "ok/dns",
|
||||
args: args{
|
||||
az: &acme.Authorization{
|
||||
Identifier: acme.Identifier{Type: "dns", Value: "example.com"},
|
||||
Wildcard: false,
|
||||
},
|
||||
},
|
||||
want: []acme.ChallengeType{acme.DNS01, acme.HTTP01, acme.TLSALPN01},
|
||||
},
|
||||
{
|
||||
name: "ok/wildcard",
|
||||
args: args{
|
||||
az: &acme.Authorization{
|
||||
Identifier: acme.Identifier{Type: "dns", Value: "*.example.com"},
|
||||
Wildcard: true,
|
||||
},
|
||||
},
|
||||
want: []acme.ChallengeType{acme.DNS01},
|
||||
},
|
||||
{
|
||||
name: "ok/ip",
|
||||
args: args{
|
||||
az: &acme.Authorization{
|
||||
Identifier: acme.Identifier{Type: "ip", Value: "192.168.42.42"},
|
||||
Wildcard: false,
|
||||
},
|
||||
},
|
||||
want: []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := challengeTypes(tt.args.az); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Handler.challengeTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,29 +10,39 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
type ChallengeType string
|
||||
|
||||
const (
|
||||
HTTP01 ChallengeType = "http-01"
|
||||
DNS01 ChallengeType = "dns-01"
|
||||
TLSALPN01 ChallengeType = "tls-alpn-01"
|
||||
)
|
||||
|
||||
// Challenge represents an ACME response Challenge type.
|
||||
type Challenge struct {
|
||||
ID string `json:"-"`
|
||||
AccountID string `json:"-"`
|
||||
AuthorizationID string `json:"-"`
|
||||
Value string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Status Status `json:"status"`
|
||||
Token string `json:"token"`
|
||||
ValidatedAt string `json:"validated,omitempty"`
|
||||
URL string `json:"url"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
ID string `json:"-"`
|
||||
AccountID string `json:"-"`
|
||||
AuthorizationID string `json:"-"`
|
||||
Value string `json:"-"`
|
||||
Type ChallengeType `json:"type"`
|
||||
Status Status `json:"status"`
|
||||
Token string `json:"token"`
|
||||
ValidatedAt string `json:"validated,omitempty"`
|
||||
URL string `json:"url"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
|
@ -54,11 +64,11 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
|
|||
return nil
|
||||
}
|
||||
switch ch.Type {
|
||||
case "http-01":
|
||||
case HTTP01:
|
||||
return http01Validate(ctx, ch, db, jwk, vo)
|
||||
case "dns-01":
|
||||
case DNS01:
|
||||
return dns01Validate(ctx, ch, db, jwk, vo)
|
||||
case "tls-alpn-01":
|
||||
case TLSALPN01:
|
||||
return tlsalpn01Validate(ctx, ch, db, jwk, vo)
|
||||
default:
|
||||
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
|
||||
|
@ -66,23 +76,23 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
|
|||
}
|
||||
|
||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
url := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||
|
||||
resp, err := vo.HTTPGet(url.String())
|
||||
resp, err := vo.HTTPGet(u.String())
|
||||
if err != nil {
|
||||
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||
"error doing http GET for url %s", url))
|
||||
"error doing http GET for url %s", u))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return storeError(ctx, db, ch, false, NewError(ErrorConnectionType,
|
||||
"error doing http GET for url %s with status code %d", url, resp.StatusCode))
|
||||
"error doing http GET for url %s with status code %d", u, resp.StatusCode))
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return WrapErrorISE(err, "error reading "+
|
||||
"response body for url %s", url)
|
||||
"response body for url %s", u)
|
||||
}
|
||||
keyAuth := strings.TrimSpace(string(body))
|
||||
|
||||
|
@ -106,6 +116,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
|
|||
return nil
|
||||
}
|
||||
|
||||
func tlsAlert(err error) uint8 {
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
v := reflect.ValueOf(opErr.Err)
|
||||
if v.Kind() == reflect.Uint8 {
|
||||
return uint8(v.Uint())
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
config := &tls.Config{
|
||||
NextProtos: []string{"acme-tls/1"},
|
||||
|
@ -113,7 +134,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
// ACME servers that implement "acme-tls/1" MUST only negotiate TLS 1.2
|
||||
// [RFC5246] or higher when connecting to clients for validation.
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: ch.Value,
|
||||
ServerName: serverName(ch),
|
||||
InsecureSkipVerify: true, // we expect a self-signed challenge certificate
|
||||
}
|
||||
|
||||
|
@ -121,6 +142,14 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
|
||||
conn, err := vo.TLSDial("tcp", hostPort, config)
|
||||
if err != nil {
|
||||
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
|
||||
// client and server protocols. When this happens the connection is
|
||||
// closed with the error no_application_protocol(120) as required by
|
||||
// RFC7301. See https://golang.org/doc/go1.17#ALPN
|
||||
if tlsAlert(err) == 120 {
|
||||
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||
"cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge"))
|
||||
}
|
||||
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||
"error doing TLS dial for %s", hostPort))
|
||||
}
|
||||
|
@ -141,9 +170,17 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
|
||||
leafCert := certs[0]
|
||||
|
||||
if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) {
|
||||
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||
"incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value))
|
||||
// if no DNS names present, look for IP address and verify that exactly one exists
|
||||
if len(leafCert.DNSNames) == 0 {
|
||||
if len(leafCert.IPAddresses) != 1 || !leafCert.IPAddresses[0].Equal(net.ParseIP(ch.Value)) {
|
||||
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||
"incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value))
|
||||
}
|
||||
} else {
|
||||
if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) {
|
||||
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||
"incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value))
|
||||
}
|
||||
}
|
||||
|
||||
idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31}
|
||||
|
@ -244,6 +281,65 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK
|
|||
return nil
|
||||
}
|
||||
|
||||
// serverName determines the SNI HostName to set based on an acme.Challenge
|
||||
// for TLS-ALPN-01 challenges RFC8738 states that, if HostName is an IP, it
|
||||
// should be the ARPA address https://datatracker.ietf.org/doc/html/rfc8738#section-6.
|
||||
// It also references TLS Extensions [RFC6066].
|
||||
func serverName(ch *Challenge) string {
|
||||
var serverName string
|
||||
ip := net.ParseIP(ch.Value)
|
||||
if ip != nil {
|
||||
serverName = reverseAddr(ip)
|
||||
} else {
|
||||
serverName = ch.Value
|
||||
}
|
||||
return serverName
|
||||
}
|
||||
|
||||
// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
|
||||
// address addr suitable for rDNS (PTR) record lookup or an error if it fails
|
||||
// to parse the IP address.
|
||||
// Implementation taken and adapted from https://golang.org/src/net/dnsclient.go?s=780:834#L20
|
||||
func reverseAddr(ip net.IP) (arpa string) {
|
||||
if ip.To4() != nil {
|
||||
return uitoa(uint(ip[15])) + "." + uitoa(uint(ip[14])) + "." + uitoa(uint(ip[13])) + "." + uitoa(uint(ip[12])) + ".in-addr.arpa."
|
||||
}
|
||||
// Must be IPv6
|
||||
buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
|
||||
// Add it, in reverse, to the buffer
|
||||
for i := len(ip) - 1; i >= 0; i-- {
|
||||
v := ip[i]
|
||||
buf = append(buf, hexit[v&0xF],
|
||||
'.',
|
||||
hexit[v>>4],
|
||||
'.')
|
||||
}
|
||||
// Append "ip6.arpa." and return (buf already has the final .)
|
||||
buf = append(buf, "ip6.arpa."...)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
// Convert unsigned integer to decimal string.
|
||||
// Implementation taken from https://golang.org/src/net/parse.go
|
||||
func uitoa(val uint) string {
|
||||
if val == 0 { // avoid string allocation
|
||||
return "0"
|
||||
}
|
||||
var buf [20]byte // big enough for 64bit value base 10
|
||||
i := len(buf) - 1
|
||||
for val >= 10 {
|
||||
q := val / 10
|
||||
buf[i] = byte('0' + val - q*10)
|
||||
i--
|
||||
val = q
|
||||
}
|
||||
// val < 10
|
||||
buf[i] = byte('0' + val)
|
||||
return string(buf[i:])
|
||||
}
|
||||
|
||||
const hexit = "0123456789abcdef"
|
||||
|
||||
// KeyAuthorization creates the ACME key authorization value from a token
|
||||
// and a jwk.
|
||||
func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
|
||||
|
|
|
@ -1276,7 +1276,7 @@ func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, na
|
|||
oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
|
||||
}
|
||||
|
||||
keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash[:])
|
||||
keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash)
|
||||
|
||||
certTemplate.ExtraExtensions = []pkix.Extension{
|
||||
{
|
||||
|
@ -1395,7 +1395,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
assert.Equals(t, updch.Type, ch.Type)
|
||||
assert.Equals(t, updch.Value, ch.Value)
|
||||
|
||||
err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.Value)
|
||||
err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443:", ch.Value)
|
||||
|
||||
assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
|
||||
assert.Equals(t, updch.Error.Type, err.Type)
|
||||
|
@ -1544,7 +1544,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
err: NewErrorISE("failure saving error to acme challenge: force"),
|
||||
}
|
||||
},
|
||||
"ok/no-names-error": func(t *testing.T) test {
|
||||
"ok/no-names-nor-ips-error": func(t *testing.T) test {
|
||||
ch := makeTLSCh()
|
||||
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
|
@ -1573,7 +1573,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
assert.Equals(t, updch.Type, ch.Type)
|
||||
assert.Equals(t, updch.Value, ch.Value)
|
||||
|
||||
err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)
|
||||
err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value)
|
||||
|
||||
assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
|
||||
assert.Equals(t, updch.Error.Type, err.Type)
|
||||
|
@ -1616,7 +1616,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
assert.Equals(t, updch.Type, ch.Type)
|
||||
assert.Equals(t, updch.Value, ch.Value)
|
||||
|
||||
err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)
|
||||
err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value)
|
||||
|
||||
assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
|
||||
assert.Equals(t, updch.Error.Type, err.Type)
|
||||
|
@ -1660,7 +1660,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
assert.Equals(t, updch.Type, ch.Type)
|
||||
assert.Equals(t, updch.Value, ch.Value)
|
||||
|
||||
err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)
|
||||
err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value)
|
||||
|
||||
assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
|
||||
assert.Equals(t, updch.Error.Type, err.Type)
|
||||
|
@ -1703,7 +1703,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
assert.Equals(t, updch.Type, ch.Type)
|
||||
assert.Equals(t, updch.Value, ch.Value)
|
||||
|
||||
err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)
|
||||
err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value)
|
||||
|
||||
assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
|
||||
assert.Equals(t, updch.Error.Type, err.Type)
|
||||
|
@ -2187,6 +2187,43 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
srv, tlsDial := newTestTLSALPNServer(cert)
|
||||
srv.Start()
|
||||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
assert.Equals(t, updch.ID, ch.ID)
|
||||
assert.Equals(t, updch.Token, ch.Token)
|
||||
assert.Equals(t, updch.Status, StatusValid)
|
||||
assert.Equals(t, updch.Type, ch.Type)
|
||||
assert.Equals(t, updch.Value, ch.Value)
|
||||
assert.Equals(t, updch.Error, nil)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
srv: srv,
|
||||
jwk: jwk,
|
||||
}
|
||||
},
|
||||
"ok/ip": func(t *testing.T) test {
|
||||
ch := makeTLSCh()
|
||||
ch.Value = "127.0.0.1"
|
||||
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
|
||||
assert.FatalError(t, err)
|
||||
expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
|
||||
|
||||
cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
srv, tlsDial := newTestTLSALPNServer(cert)
|
||||
srv.Start()
|
||||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
|
@ -2235,3 +2272,82 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_reverseAddr(t *testing.T) {
|
||||
type args struct {
|
||||
ip net.IP
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantArpa string
|
||||
}{
|
||||
{
|
||||
name: "ok/ipv4",
|
||||
args: args{
|
||||
ip: net.ParseIP("127.0.0.1"),
|
||||
},
|
||||
wantArpa: "1.0.0.127.in-addr.arpa.",
|
||||
},
|
||||
{
|
||||
name: "ok/ipv6",
|
||||
args: args{
|
||||
ip: net.ParseIP("2001:db8::567:89ab"),
|
||||
},
|
||||
wantArpa: "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if gotArpa := reverseAddr(tt.args.ip); gotArpa != tt.wantArpa {
|
||||
t.Errorf("reverseAddr() = %v, want %v", gotArpa, tt.wantArpa)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_serverName(t *testing.T) {
|
||||
type args struct {
|
||||
ch *Challenge
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "ok/dns",
|
||||
args: args{
|
||||
ch: &Challenge{
|
||||
Value: "example.com",
|
||||
},
|
||||
},
|
||||
want: "example.com",
|
||||
},
|
||||
{
|
||||
name: "ok/ipv4",
|
||||
args: args{
|
||||
ch: &Challenge{
|
||||
Value: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
want: "1.0.0.127.in-addr.arpa.",
|
||||
},
|
||||
{
|
||||
name: "ok/ipv6",
|
||||
args: args{
|
||||
ch: &Challenge{
|
||||
Value: "2001:db8::567:89ab",
|
||||
},
|
||||
},
|
||||
want: "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := serverName(tt.args.ch); got != tt.want {
|
||||
t.Errorf("serverName() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
// CertificateAuthority is the interface implemented by a CA authority.
|
||||
type CertificateAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
|
|
|
@ -93,8 +93,8 @@ func TestDB_getDBAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if dbacc, err := d.getDBAccount(context.Background(), accID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
@ -109,15 +109,13 @@ func TestDB_getDBAccount(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbacc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, dbacc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
|
||||
assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
|
||||
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbacc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, dbacc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
|
||||
assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
|
||||
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -174,8 +172,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if retAccID, err := d.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
@ -190,10 +188,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, retAccID, accID)
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, retAccID, accID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -250,8 +246,8 @@ func TestDB_GetAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if acc, err := db.GetAccount(context.Background(), accID); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if acc, err := d.GetAccount(context.Background(), accID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
@ -266,13 +262,11 @@ func TestDB_GetAccount(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -358,8 +352,8 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if acc, err := d.GetAccountByKeyID(context.Background(), kid); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
@ -374,13 +368,11 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -527,8 +519,8 @@ func TestDB_CreateAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.CreateAccount(context.Background(), tc.acc); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.CreateAccount(context.Background(), tc.acc); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -688,8 +680,8 @@ func TestDB_UpdateAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.UpdateAccount(context.Background(), tc.acc); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.UpdateAccount(context.Background(), tc.acc); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
|
|
@ -97,8 +97,8 @@ func TestDB_getDBAuthz(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if dbaz, err := d.getDBAuthz(context.Background(), azID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
@ -113,18 +113,16 @@ func TestDB_getDBAuthz(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbaz.ID, tc.dbaz.ID)
|
||||
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
|
||||
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
|
||||
assert.Equals(t, dbaz.Status, tc.dbaz.Status)
|
||||
assert.Equals(t, dbaz.Token, tc.dbaz.Token)
|
||||
assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt)
|
||||
assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
|
||||
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbaz.ID, tc.dbaz.ID)
|
||||
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
|
||||
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
|
||||
assert.Equals(t, dbaz.Status, tc.dbaz.Status)
|
||||
assert.Equals(t, dbaz.Token, tc.dbaz.Token)
|
||||
assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt)
|
||||
assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
|
||||
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -293,8 +291,8 @@ func TestDB_GetAuthorization(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if az, err := db.GetAuthorization(context.Background(), azID); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if az, err := d.GetAuthorization(context.Background(), azID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
@ -309,21 +307,19 @@ func TestDB_GetAuthorization(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, az.ID, tc.dbaz.ID)
|
||||
assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
|
||||
assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
|
||||
assert.Equals(t, az.Status, tc.dbaz.Status)
|
||||
assert.Equals(t, az.Token, tc.dbaz.Token)
|
||||
assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard)
|
||||
assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||
assert.Equals(t, az.Challenges, []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
})
|
||||
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, az.ID, tc.dbaz.ID)
|
||||
assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
|
||||
assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
|
||||
assert.Equals(t, az.Status, tc.dbaz.Status)
|
||||
assert.Equals(t, az.Token, tc.dbaz.Token)
|
||||
assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard)
|
||||
assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||
assert.Equals(t, az.Challenges, []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
})
|
||||
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -445,8 +441,8 @@ func TestDB_CreateAuthorization(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.CreateAuthorization(context.Background(), tc.az); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.CreateAuthorization(context.Background(), tc.az); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -594,8 +590,8 @@ func TestDB_UpdateAuthorization(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.UpdateAuthorization(context.Background(), tc.az); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
|
|
@ -98,8 +98,8 @@ func TestDB_CreateCertificate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.CreateCertificate(context.Background(), tc.cert); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.CreateCertificate(context.Background(), tc.cert); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -228,8 +228,8 @@ func TestDB_GetCertificate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
cert, err := db.GetCertificate(context.Background(), certID)
|
||||
d := DB{db: tc.db}
|
||||
cert, err := d.GetCertificate(context.Background(), certID)
|
||||
if err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
|
@ -245,14 +245,12 @@ func TestDB_GetCertificate(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, cert.ID, certID)
|
||||
assert.Equals(t, cert.AccountID, "accountID")
|
||||
assert.Equals(t, cert.OrderID, "orderID")
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, cert.ID, certID)
|
||||
assert.Equals(t, cert.AccountID, "accountID")
|
||||
assert.Equals(t, cert.OrderID, "orderID")
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -11,15 +11,15 @@ import (
|
|||
)
|
||||
|
||||
type dbChallenge struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Type string `json:"type"`
|
||||
Status acme.Status `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Value string `json:"value"`
|
||||
ValidatedAt string `json:"validatedAt"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Error *acme.Error `json:"error"`
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Type acme.ChallengeType `json:"type"`
|
||||
Status acme.Status `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Value string `json:"value"`
|
||||
ValidatedAt string `json:"validatedAt"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Error *acme.Error `json:"error"`
|
||||
}
|
||||
|
||||
func (dbc *dbChallenge) clone() *dbChallenge {
|
||||
|
|
|
@ -92,8 +92,8 @@ func TestDB_getDBChallenge(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if ch, err := db.getDBChallenge(context.Background(), chID); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if ch, err := d.getDBChallenge(context.Background(), chID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
@ -108,17 +108,15 @@ func TestDB_getDBChallenge(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -206,8 +204,8 @@ func TestDB_CreateChallenge(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.CreateChallenge(context.Background(), tc.ch); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.CreateChallenge(context.Background(), tc.ch); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -286,8 +284,8 @@ func TestDB_GetChallenge(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if ch, err := d.GetChallenge(context.Background(), chID, azID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
@ -302,17 +300,15 @@ func TestDB_GetChallenge(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -442,8 +438,8 @@ func TestDB_UpdateChallenge(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.UpdateChallenge(context.Background(), tc.ch); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
|
|||
ID: id,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
|
||||
if err := db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return acme.Nonce(id), nil
|
||||
|
|
|
@ -67,8 +67,8 @@ func TestDB_CreateNonce(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if n, err := db.CreateNonce(context.Background()); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if n, err := d.CreateNonce(context.Background()); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -144,8 +144,8 @@ func TestDB_DeleteNonce(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
|
|
@ -41,7 +41,7 @@ func New(db nosqlDB.DB) (*DB, error) {
|
|||
|
||||
// save writes the new data to the database, overwriting the old data if it
|
||||
// existed.
|
||||
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error {
|
||||
func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
|
||||
var (
|
||||
err error
|
||||
newB []byte
|
||||
|
|
|
@ -126,8 +126,8 @@ func TestDB_save(t *testing.T) {
|
|||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := &DB{db: tc.db}
|
||||
if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
|
||||
d := &DB{db: tc.db}
|
||||
if err := d.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
|
|
@ -124,10 +124,8 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st
|
|||
ordersByAccountMux.Lock()
|
||||
defer ordersByAccountMux.Unlock()
|
||||
|
||||
var oldOids []string
|
||||
b, err := db.db.Get(ordersByAccountIDTable, []byte(accID))
|
||||
var (
|
||||
oldOids []string
|
||||
)
|
||||
if err != nil {
|
||||
if !nosql.IsErrNotFound(err) {
|
||||
return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID)
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
nosqldb "github.com/smallstep/nosql/database"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestDB_getDBOrder(t *testing.T) {
|
||||
|
@ -31,7 +31,7 @@ func TestDB_getDBOrder(t *testing.T) {
|
|||
assert.Equals(t, bucket, orderTable)
|
||||
assert.Equals(t, string(key), orderID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
|
||||
|
@ -100,8 +100,8 @@ func TestDB_getDBOrder(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if dbo, err := db.getDBOrder(context.Background(), orderID); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if dbo, err := d.getDBOrder(context.Background(), orderID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
@ -116,20 +116,18 @@ func TestDB_getDBOrder(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbo.ID, tc.dbo.ID)
|
||||
assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID)
|
||||
assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID)
|
||||
assert.Equals(t, dbo.Status, tc.dbo.Status)
|
||||
assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt)
|
||||
assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt)
|
||||
assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore)
|
||||
assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter)
|
||||
assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers)
|
||||
assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
||||
assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbo.ID, tc.dbo.ID)
|
||||
assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID)
|
||||
assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID)
|
||||
assert.Equals(t, dbo.Status, tc.dbo.Status)
|
||||
assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt)
|
||||
assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt)
|
||||
assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore)
|
||||
assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter)
|
||||
assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers)
|
||||
assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
||||
assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -164,7 +162,7 @@ func TestDB_GetOrder(t *testing.T) {
|
|||
assert.Equals(t, bucket, orderTable)
|
||||
assert.Equals(t, string(key), orderID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
|
||||
|
@ -206,8 +204,8 @@ func TestDB_GetOrder(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if o, err := db.GetOrder(context.Background(), orderID); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if o, err := d.GetOrder(context.Background(), orderID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
|
@ -222,20 +220,18 @@ func TestDB_GetOrder(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, o.ID, tc.dbo.ID)
|
||||
assert.Equals(t, o.AccountID, tc.dbo.AccountID)
|
||||
assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID)
|
||||
assert.Equals(t, o.CertificateID, tc.dbo.CertificateID)
|
||||
assert.Equals(t, o.Status, tc.dbo.Status)
|
||||
assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt)
|
||||
assert.Equals(t, o.NotBefore, tc.dbo.NotBefore)
|
||||
assert.Equals(t, o.NotAfter, tc.dbo.NotAfter)
|
||||
assert.Equals(t, o.Identifiers, tc.dbo.Identifiers)
|
||||
assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
||||
assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, o.ID, tc.dbo.ID)
|
||||
assert.Equals(t, o.AccountID, tc.dbo.AccountID)
|
||||
assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID)
|
||||
assert.Equals(t, o.CertificateID, tc.dbo.CertificateID)
|
||||
assert.Equals(t, o.Status, tc.dbo.Status)
|
||||
assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt)
|
||||
assert.Equals(t, o.NotBefore, tc.dbo.NotBefore)
|
||||
assert.Equals(t, o.NotAfter, tc.dbo.NotAfter)
|
||||
assert.Equals(t, o.Identifiers, tc.dbo.Identifiers)
|
||||
assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
||||
assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -366,8 +362,8 @@ func TestDB_UpdateOrder(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.UpdateOrder(context.Background(), tc.o); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.UpdateOrder(context.Background(), tc.o); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -511,7 +507,7 @@ func TestDB_CreateOrder(t *testing.T) {
|
|||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
|
||||
assert.Equals(t, string(key), o.AccountID)
|
||||
return nil, nosqldb.ErrNotFound
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
switch string(bucket) {
|
||||
|
@ -557,8 +553,8 @@ func TestDB_CreateOrder(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
if err := db.CreateOrder(context.Background(), tc.o); err != nil {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.CreateOrder(context.Background(), tc.o); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -680,7 +676,7 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
|||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||
assert.Equals(t, key, []byte(accID))
|
||||
return nil, nosqldb.ErrNotFound
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||
|
@ -710,6 +706,34 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
|||
err: errors.Errorf("error saving orderIDs index for account %s", accID),
|
||||
}
|
||||
},
|
||||
"ok/no-old": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(ordersByAccountIDTable):
|
||||
return nil, database.ErrNotFound
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
switch string(bucket) {
|
||||
case string(ordersByAccountIDTable):
|
||||
assert.Equals(t, key, []byte(accID))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, nu, nil)
|
||||
return nil, true, nil
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
res: []string{},
|
||||
}
|
||||
},
|
||||
"ok/all-old-not-pending": func(t *testing.T) test {
|
||||
oldOids := []string{"foo", "bar"}
|
||||
bOldOids, err := json.Marshal(oldOids)
|
||||
|
@ -967,15 +991,15 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := DB{db: tc.db}
|
||||
d := DB{db: tc.db}
|
||||
var (
|
||||
res []string
|
||||
err error
|
||||
)
|
||||
if tc.addOids == nil {
|
||||
res, err = db.updateAddOrderIDs(context.Background(), accID)
|
||||
res, err = d.updateAddOrderIDs(context.Background(), accID)
|
||||
} else {
|
||||
res, err = db.updateAddOrderIDs(context.Background(), accID, tc.addOids...)
|
||||
res, err = d.updateAddOrderIDs(context.Background(), accID, tc.addOids...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -993,10 +1017,8 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.True(t, reflect.DeepEqual(res, tc.res))
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.True(t, reflect.DeepEqual(res, tc.res))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
190
acme/order.go
190
acme/order.go
|
@ -1,9 +1,11 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -12,10 +14,17 @@ import (
|
|||
"go.step.sm/crypto/x509util"
|
||||
)
|
||||
|
||||
type IdentifierType string
|
||||
|
||||
const (
|
||||
IP IdentifierType = "ip"
|
||||
DNS IdentifierType = "dns"
|
||||
)
|
||||
|
||||
// Identifier encodes the type that an order pertains to.
|
||||
type Identifier struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
Type IdentifierType `json:"type"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// Order contains order metadata for the ACME protocol order type.
|
||||
|
@ -131,41 +140,13 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
|
|||
return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID)
|
||||
}
|
||||
|
||||
// RFC8555: The CSR MUST indicate the exact same set of requested
|
||||
// identifiers as the initial newOrder request. Identifiers of type "dns"
|
||||
// MUST appear either in the commonName portion of the requested subject
|
||||
// name or in an extensionRequest attribute [RFC2985] requesting a
|
||||
// subjectAltName extension, or both.
|
||||
if csr.Subject.CommonName != "" {
|
||||
csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
|
||||
}
|
||||
csr.DNSNames = uniqueSortedLowerNames(csr.DNSNames)
|
||||
orderNames := make([]string, len(o.Identifiers))
|
||||
for i, n := range o.Identifiers {
|
||||
orderNames[i] = n.Value
|
||||
}
|
||||
orderNames = uniqueSortedLowerNames(orderNames)
|
||||
// canonicalize the CSR to allow for comparison
|
||||
csr = canonicalize(csr)
|
||||
|
||||
// Validate identifier names against CSR alternative names.
|
||||
//
|
||||
// Note that with certificate templates we are not going to check for the
|
||||
// absence of other SANs as they will only be set if the templates allows
|
||||
// them.
|
||||
if len(csr.DNSNames) != len(orderNames) {
|
||||
return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||
"CSR names = %v, Order names = %v", csr.DNSNames, orderNames)
|
||||
}
|
||||
|
||||
sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames))
|
||||
for i := range csr.DNSNames {
|
||||
if csr.DNSNames[i] != orderNames[i] {
|
||||
return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||
"CSR names = %v, Order names = %v", csr.DNSNames, orderNames)
|
||||
}
|
||||
sans[i] = x509util.SubjectAlternativeName{
|
||||
Type: x509util.DNSType,
|
||||
Value: csr.DNSNames[i],
|
||||
}
|
||||
// retrieve the requested SANs for the Order
|
||||
sans, err := o.sans(csr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get authorizations from the ACME provisioner.
|
||||
|
@ -213,6 +194,123 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
|
|||
return nil
|
||||
}
|
||||
|
||||
func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativeName, error) {
|
||||
|
||||
var sans []x509util.SubjectAlternativeName
|
||||
|
||||
// order the DNS names and IP addresses, so that they can be compared against the canonicalized CSR
|
||||
orderNames := make([]string, numberOfIdentifierType(DNS, o.Identifiers))
|
||||
orderIPs := make([]net.IP, numberOfIdentifierType(IP, o.Identifiers))
|
||||
indexDNS, indexIP := 0, 0
|
||||
for _, n := range o.Identifiers {
|
||||
switch n.Type {
|
||||
case DNS:
|
||||
orderNames[indexDNS] = n.Value
|
||||
indexDNS++
|
||||
case IP:
|
||||
orderIPs[indexIP] = net.ParseIP(n.Value) // NOTE: this assumes are all valid IPs at this time; or will result in nil entries
|
||||
indexIP++
|
||||
default:
|
||||
return sans, NewErrorISE("unsupported identifier type in order: %s", n.Type)
|
||||
}
|
||||
}
|
||||
orderNames = uniqueSortedLowerNames(orderNames)
|
||||
orderIPs = uniqueSortedIPs(orderIPs)
|
||||
|
||||
totalNumberOfSANs := len(csr.DNSNames) + len(csr.IPAddresses)
|
||||
sans = make([]x509util.SubjectAlternativeName, totalNumberOfSANs)
|
||||
index := 0
|
||||
|
||||
// Validate identifier names against CSR alternative names.
|
||||
//
|
||||
// Note that with certificate templates we are not going to check for the
|
||||
// absence of other SANs as they will only be set if the template allows
|
||||
// them.
|
||||
if len(csr.DNSNames) != len(orderNames) {
|
||||
return sans, NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||
"CSR names = %v, Order names = %v", csr.DNSNames, orderNames)
|
||||
}
|
||||
|
||||
for i := range csr.DNSNames {
|
||||
if csr.DNSNames[i] != orderNames[i] {
|
||||
return sans, NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||
"CSR names = %v, Order names = %v", csr.DNSNames, orderNames)
|
||||
}
|
||||
sans[index] = x509util.SubjectAlternativeName{
|
||||
Type: x509util.DNSType,
|
||||
Value: csr.DNSNames[i],
|
||||
}
|
||||
index++
|
||||
}
|
||||
|
||||
if len(csr.IPAddresses) != len(orderIPs) {
|
||||
return sans, NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+
|
||||
"CSR IPs = %v, Order IPs = %v", csr.IPAddresses, orderIPs)
|
||||
}
|
||||
|
||||
for i := range csr.IPAddresses {
|
||||
if !ipsAreEqual(csr.IPAddresses[i], orderIPs[i]) {
|
||||
return sans, NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+
|
||||
"CSR IPs = %v, Order IPs = %v", csr.IPAddresses, orderIPs)
|
||||
}
|
||||
sans[index] = x509util.SubjectAlternativeName{
|
||||
Type: x509util.IPType,
|
||||
Value: csr.IPAddresses[i].String(),
|
||||
}
|
||||
index++
|
||||
}
|
||||
|
||||
return sans, nil
|
||||
}
|
||||
|
||||
// numberOfIdentifierType returns the number of Identifiers that
|
||||
// are of type typ.
|
||||
func numberOfIdentifierType(typ IdentifierType, ids []Identifier) int {
|
||||
c := 0
|
||||
for _, id := range ids {
|
||||
if id.Type == typ {
|
||||
c++
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// canonicalize canonicalizes a CSR so that it can be compared against an Order
|
||||
// NOTE: this effectively changes the order of SANs in the CSR, which may be OK,
|
||||
// but may not be expected.
|
||||
func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.CertificateRequest) {
|
||||
|
||||
// for clarity only; we're operating on the same object by pointer
|
||||
canonicalized = csr
|
||||
|
||||
// RFC8555: The CSR MUST indicate the exact same set of requested
|
||||
// identifiers as the initial newOrder request. Identifiers of type "dns"
|
||||
// MUST appear either in the commonName portion of the requested subject
|
||||
// name or in an extensionRequest attribute [RFC2985] requesting a
|
||||
// subjectAltName extension, or both.
|
||||
if csr.Subject.CommonName != "" {
|
||||
// nolint:gocritic
|
||||
canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
|
||||
}
|
||||
canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames)
|
||||
canonicalized.IPAddresses = uniqueSortedIPs(csr.IPAddresses)
|
||||
|
||||
return canonicalized
|
||||
}
|
||||
|
||||
// ipsAreEqual compares IPs to be equal. Nil values (i.e. invalid IPs) are
|
||||
// not considered equal. IPv6 representations of IPv4 addresses are
|
||||
// considered equal to the IPv4 address in this implementation, which is
|
||||
// standard Go behavior. An example is "::ffff:192.168.42.42", which
|
||||
// is equal to "192.168.42.42". This is considered a known issue within
|
||||
// step and is tracked here too: https://github.com/golang/go/issues/37921.
|
||||
func ipsAreEqual(x, y net.IP) bool {
|
||||
if x == nil || y == nil {
|
||||
return false
|
||||
}
|
||||
return x.Equal(y)
|
||||
}
|
||||
|
||||
// uniqueSortedLowerNames returns the set of all unique names in the input after all
|
||||
// of them are lowercased. The returned names will be in their lowercased form
|
||||
// and sorted alphabetically.
|
||||
|
@ -228,3 +326,23 @@ func uniqueSortedLowerNames(names []string) (unique []string) {
|
|||
sort.Strings(unique)
|
||||
return
|
||||
}
|
||||
|
||||
// uniqueSortedIPs returns the set of all unique net.IPs in the input. They
|
||||
// are sorted by their bytes (octet) representation.
|
||||
func uniqueSortedIPs(ips []net.IP) (unique []net.IP) {
|
||||
type entry struct {
|
||||
ip net.IP
|
||||
}
|
||||
ipEntryMap := make(map[string]entry, len(ips))
|
||||
for _, ip := range ips {
|
||||
ipEntryMap[ip.String()] = entry{ip: ip}
|
||||
}
|
||||
unique = make([]net.IP, 0, len(ipEntryMap))
|
||||
for _, entry := range ipEntryMap {
|
||||
unique = append(unique, entry.ip)
|
||||
}
|
||||
sort.Slice(unique, func(i, j int) bool {
|
||||
return bytes.Compare(unique[i], unique[j]) < 0
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -5,12 +5,15 @@ import (
|
|||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/x509util"
|
||||
)
|
||||
|
||||
func TestOrder_UpdateStatus(t *testing.T) {
|
||||
|
@ -261,10 +264,10 @@ func TestOrder_UpdateStatus(t *testing.T) {
|
|||
}
|
||||
|
||||
type mockSignAuth struct {
|
||||
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
loadProvisionerByID func(string) (provisioner.Interface, error)
|
||||
ret1, ret2 interface{}
|
||||
err error
|
||||
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
loadProvisionerByName func(string) (provisioner.Interface, error)
|
||||
ret1, ret2 interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
|
@ -276,9 +279,9 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S
|
|||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockSignAuth) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByID != nil {
|
||||
return m.loadProvisionerByID(id)
|
||||
func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByName != nil {
|
||||
return m.loadProvisionerByName(name)
|
||||
}
|
||||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
@ -364,61 +367,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
err: NewErrorISE("unrecognized order status: %s", o.Status),
|
||||
}
|
||||
},
|
||||
"fail/error-names-length-mismatch": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
ID: "oID",
|
||||
AccountID: "accID",
|
||||
Status: StatusReady,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
AuthorizationIDs: []string{"a", "b"},
|
||||
Identifiers: []Identifier{
|
||||
{Type: "dns", Value: "foo.internal"},
|
||||
{Type: "dns", Value: "bar.internal"},
|
||||
},
|
||||
}
|
||||
orderNames := []string{"bar.internal", "foo.internal"}
|
||||
csr := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "foo.internal",
|
||||
},
|
||||
}
|
||||
|
||||
return test{
|
||||
o: o,
|
||||
csr: csr,
|
||||
err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||
"CSR names = %v, Order names = %v", []string{"foo.internal"}, orderNames),
|
||||
}
|
||||
},
|
||||
"fail/error-names-mismatch": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
ID: "oID",
|
||||
AccountID: "accID",
|
||||
Status: StatusReady,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
AuthorizationIDs: []string{"a", "b"},
|
||||
Identifiers: []Identifier{
|
||||
{Type: "dns", Value: "foo.internal"},
|
||||
{Type: "dns", Value: "bar.internal"},
|
||||
},
|
||||
}
|
||||
orderNames := []string{"bar.internal", "foo.internal"}
|
||||
csr := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "foo.internal",
|
||||
},
|
||||
DNSNames: []string{"zap.internal"},
|
||||
}
|
||||
|
||||
return test{
|
||||
o: o,
|
||||
csr: csr,
|
||||
err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||
"CSR names = %v, Order names = %v", []string{"foo.internal", "zap.internal"}, orderNames),
|
||||
}
|
||||
},
|
||||
"fail/error-provisioner-auth": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
|
@ -650,7 +598,7 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
err: NewErrorISE("error updating order oID: force"),
|
||||
}
|
||||
},
|
||||
"ok/new-cert": func(t *testing.T) test {
|
||||
"ok/new-cert-dns": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
ID: "oID",
|
||||
|
@ -674,6 +622,131 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}}
|
||||
baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}}
|
||||
|
||||
return test{
|
||||
o: o,
|
||||
csr: csr,
|
||||
prov: &MockProvisioner{
|
||||
MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
assert.Equals(t, token, "")
|
||||
return nil, nil
|
||||
},
|
||||
MgetOptions: func() *provisioner.Options {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ca: &mockSignAuth{
|
||||
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
assert.Equals(t, _csr, csr)
|
||||
return []*x509.Certificate{foo, bar, baz}, nil
|
||||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
|
||||
cert.ID = "certID"
|
||||
assert.Equals(t, cert.AccountID, o.AccountID)
|
||||
assert.Equals(t, cert.OrderID, o.ID)
|
||||
assert.Equals(t, cert.Leaf, foo)
|
||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz})
|
||||
return nil
|
||||
},
|
||||
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
|
||||
assert.Equals(t, updo.CertificateID, "certID")
|
||||
assert.Equals(t, updo.Status, StatusValid)
|
||||
assert.Equals(t, updo.ID, o.ID)
|
||||
assert.Equals(t, updo.AccountID, o.AccountID)
|
||||
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
|
||||
assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs)
|
||||
assert.Equals(t, updo.Identifiers, o.Identifiers)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/new-cert-ip": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
ID: "oID",
|
||||
AccountID: "accID",
|
||||
Status: StatusReady,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
AuthorizationIDs: []string{"a", "b"},
|
||||
Identifiers: []Identifier{
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
{Type: "ip", Value: "192.168.43.42"},
|
||||
},
|
||||
}
|
||||
csr := &x509.CertificateRequest{
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, // in case of IPs, no Common Name
|
||||
}
|
||||
|
||||
foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}}
|
||||
bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}}
|
||||
baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}}
|
||||
|
||||
return test{
|
||||
o: o,
|
||||
csr: csr,
|
||||
prov: &MockProvisioner{
|
||||
MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
assert.Equals(t, token, "")
|
||||
return nil, nil
|
||||
},
|
||||
MgetOptions: func() *provisioner.Options {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ca: &mockSignAuth{
|
||||
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
assert.Equals(t, _csr, csr)
|
||||
return []*x509.Certificate{foo, bar, baz}, nil
|
||||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
|
||||
cert.ID = "certID"
|
||||
assert.Equals(t, cert.AccountID, o.AccountID)
|
||||
assert.Equals(t, cert.OrderID, o.ID)
|
||||
assert.Equals(t, cert.Leaf, foo)
|
||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz})
|
||||
return nil
|
||||
},
|
||||
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
|
||||
assert.Equals(t, updo.CertificateID, "certID")
|
||||
assert.Equals(t, updo.Status, StatusValid)
|
||||
assert.Equals(t, updo.ID, o.ID)
|
||||
assert.Equals(t, updo.AccountID, o.AccountID)
|
||||
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
|
||||
assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs)
|
||||
assert.Equals(t, updo.Identifiers, o.Identifiers)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/new-cert-dns-and-ip": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
ID: "oID",
|
||||
AccountID: "accID",
|
||||
Status: StatusReady,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
AuthorizationIDs: []string{"a", "b"},
|
||||
Identifiers: []Identifier{
|
||||
{Type: "dns", Value: "foo.internal"},
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
},
|
||||
}
|
||||
csr := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "foo.internal",
|
||||
},
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.42.42")},
|
||||
}
|
||||
|
||||
foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}}
|
||||
bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}}
|
||||
baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}}
|
||||
|
||||
return test{
|
||||
o: o,
|
||||
csr: csr,
|
||||
|
@ -737,3 +810,592 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_uniqueSortedIPs(t *testing.T) {
|
||||
type args struct {
|
||||
ips []net.IP
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantUnique []net.IP
|
||||
}{
|
||||
{
|
||||
name: "ok/empty",
|
||||
args: args{
|
||||
ips: []net.IP{},
|
||||
},
|
||||
wantUnique: []net.IP{},
|
||||
},
|
||||
{
|
||||
name: "ok/single-ipv4",
|
||||
args: args{
|
||||
ips: []net.IP{net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
wantUnique: []net.IP{net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
{
|
||||
name: "ok/multiple-ipv4",
|
||||
args: args{
|
||||
ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.1")},
|
||||
},
|
||||
wantUnique: []net.IP{net.ParseIP("192.168.42.1"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
{
|
||||
name: "ok/unique-ipv4",
|
||||
args: args{
|
||||
ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
wantUnique: []net.IP{net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
{
|
||||
name: "ok/single-ipv6",
|
||||
args: args{
|
||||
ips: []net.IP{net.ParseIP("2001:db8::30")},
|
||||
},
|
||||
wantUnique: []net.IP{net.ParseIP("2001:db8::30")},
|
||||
},
|
||||
{
|
||||
name: "ok/multiple-ipv6",
|
||||
args: args{
|
||||
ips: []net.IP{net.ParseIP("2001:db8::30"), net.ParseIP("2001:db8::20"), net.ParseIP("2001:db8::10")},
|
||||
},
|
||||
wantUnique: []net.IP{net.ParseIP("2001:db8::10"), net.ParseIP("2001:db8::20"), net.ParseIP("2001:db8::30")},
|
||||
},
|
||||
{
|
||||
name: "ok/unique-ipv6",
|
||||
args: args{
|
||||
ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1")},
|
||||
},
|
||||
wantUnique: []net.IP{net.ParseIP("2001:db8::1")},
|
||||
},
|
||||
{
|
||||
name: "ok/mixed-ipv4-and-ipv6",
|
||||
args: args{
|
||||
ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1"), net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
wantUnique: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::1")},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if gotUnique := uniqueSortedIPs(tt.args.ips); !reflect.DeepEqual(gotUnique, tt.wantUnique) {
|
||||
t.Errorf("uniqueSortedIPs() = %v, want %v", gotUnique, tt.wantUnique)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_numberOfIdentifierType(t *testing.T) {
|
||||
type args struct {
|
||||
typ IdentifierType
|
||||
ids []Identifier
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "ok/no-identifiers",
|
||||
args: args{
|
||||
typ: DNS,
|
||||
ids: []Identifier{},
|
||||
},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "ok/no-dns",
|
||||
args: args{
|
||||
typ: DNS,
|
||||
ids: []Identifier{
|
||||
{
|
||||
Type: IP,
|
||||
Value: "192.168.42.42",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "ok/no-ips",
|
||||
args: args{
|
||||
typ: IP,
|
||||
ids: []Identifier{
|
||||
{
|
||||
Type: DNS,
|
||||
Value: "example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "ok/one-dns",
|
||||
args: args{
|
||||
typ: DNS,
|
||||
ids: []Identifier{
|
||||
{
|
||||
Type: DNS,
|
||||
Value: "example.com",
|
||||
},
|
||||
{
|
||||
Type: IP,
|
||||
Value: "192.168.42.42",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "ok/one-ip",
|
||||
args: args{
|
||||
typ: IP,
|
||||
ids: []Identifier{
|
||||
{
|
||||
Type: DNS,
|
||||
Value: "example.com",
|
||||
},
|
||||
{
|
||||
Type: IP,
|
||||
Value: "192.168.42.42",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "ok/more-dns",
|
||||
args: args{
|
||||
typ: DNS,
|
||||
ids: []Identifier{
|
||||
{
|
||||
Type: DNS,
|
||||
Value: "example.com",
|
||||
},
|
||||
{
|
||||
Type: DNS,
|
||||
Value: "*.example.com",
|
||||
},
|
||||
{
|
||||
Type: IP,
|
||||
Value: "192.168.42.42",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 2,
|
||||
},
|
||||
{
|
||||
name: "ok/more-ips",
|
||||
args: args{
|
||||
typ: IP,
|
||||
ids: []Identifier{
|
||||
{
|
||||
Type: DNS,
|
||||
Value: "example.com",
|
||||
},
|
||||
{
|
||||
Type: IP,
|
||||
Value: "192.168.42.42",
|
||||
},
|
||||
{
|
||||
Type: IP,
|
||||
Value: "192.168.42.43",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 2,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := numberOfIdentifierType(tt.args.typ, tt.args.ids); got != tt.want {
|
||||
t.Errorf("numberOfIdentifierType() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ipsAreEqual(t *testing.T) {
|
||||
type args struct {
|
||||
x net.IP
|
||||
y net.IP
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "ok/ipv4",
|
||||
args: args{
|
||||
x: net.ParseIP("192.168.42.42"),
|
||||
y: net.ParseIP("192.168.42.42"),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "fail/ipv4",
|
||||
args: args{
|
||||
x: net.ParseIP("192.168.42.42"),
|
||||
y: net.ParseIP("192.168.42.43"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ok/ipv6",
|
||||
args: args{
|
||||
x: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
|
||||
y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "fail/ipv6",
|
||||
args: args{
|
||||
x: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
|
||||
y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7335"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "fail/ipv4-and-ipv6",
|
||||
args: args{
|
||||
x: net.ParseIP("192.168.42.42"),
|
||||
y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ok/ipv4-mapped-to-ipv6",
|
||||
args: args{
|
||||
x: net.ParseIP("192.168.42.42"),
|
||||
y: net.ParseIP("::ffff:192.168.42.42"), // parsed to the same IPv4 by Go
|
||||
},
|
||||
want: true, // we expect this to happen; a known issue in which ipv4 mapped ipv6 addresses are considered the same as their ipv4 counterpart
|
||||
},
|
||||
{
|
||||
name: "fail/invalid-ipv4-and-valid-ipv6",
|
||||
args: args{
|
||||
x: net.ParseIP("192.168.42.1000"),
|
||||
y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "fail/valid-ipv4-and-invalid-ipv6",
|
||||
args: args{
|
||||
x: net.ParseIP("192.168.42.42"),
|
||||
y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:733400"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "fail/invalid-ipv4-and-invalid-ipv6",
|
||||
args: args{
|
||||
x: net.ParseIP("192.168.42.1000"),
|
||||
y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:1000000"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ipsAreEqual(tt.args.x, tt.args.y); got != tt.want {
|
||||
t.Errorf("ipsAreEqual() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_canonicalize(t *testing.T) {
|
||||
type args struct {
|
||||
csr *x509.CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantCanonicalized *x509.CertificateRequest
|
||||
}{
|
||||
{
|
||||
name: "ok/dns",
|
||||
args: args{
|
||||
csr: &x509.CertificateRequest{
|
||||
DNSNames: []string{"www.example.com", "example.com"},
|
||||
},
|
||||
},
|
||||
wantCanonicalized: &x509.CertificateRequest{
|
||||
DNSNames: []string{"example.com", "www.example.com"},
|
||||
IPAddresses: []net.IP{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok/common-name",
|
||||
args: args{
|
||||
csr: &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "example.com",
|
||||
},
|
||||
DNSNames: []string{"www.example.com"},
|
||||
},
|
||||
},
|
||||
wantCanonicalized: &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "example.com",
|
||||
},
|
||||
DNSNames: []string{"example.com", "www.example.com"},
|
||||
IPAddresses: []net.IP{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok/ipv4",
|
||||
args: args{
|
||||
csr: &x509.CertificateRequest{
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
},
|
||||
wantCanonicalized: &x509.CertificateRequest{
|
||||
DNSNames: []string{},
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok/mixed",
|
||||
args: args{
|
||||
csr: &x509.CertificateRequest{
|
||||
DNSNames: []string{"www.example.com", "example.com"},
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
},
|
||||
wantCanonicalized: &x509.CertificateRequest{
|
||||
DNSNames: []string{"example.com", "www.example.com"},
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok/mixed-common-name",
|
||||
args: args{
|
||||
csr: &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "example.com",
|
||||
},
|
||||
DNSNames: []string{"www.example.com"},
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
},
|
||||
wantCanonicalized: &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "example.com",
|
||||
},
|
||||
DNSNames: []string{"example.com", "www.example.com"},
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if gotCanonicalized := canonicalize(tt.args.csr); !reflect.DeepEqual(gotCanonicalized, tt.wantCanonicalized) {
|
||||
t.Errorf("canonicalize() = %v, want %v", gotCanonicalized, tt.wantCanonicalized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrder_sans(t *testing.T) {
|
||||
type fields struct {
|
||||
Identifiers []Identifier
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
csr *x509.CertificateRequest
|
||||
want []x509util.SubjectAlternativeName
|
||||
err *Error
|
||||
}{
|
||||
{
|
||||
name: "ok/dns",
|
||||
fields: fields{
|
||||
Identifiers: []Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
},
|
||||
},
|
||||
csr: &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "example.com",
|
||||
},
|
||||
},
|
||||
want: []x509util.SubjectAlternativeName{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "fail/error-names-length-mismatch",
|
||||
fields: fields{
|
||||
Identifiers: []Identifier{
|
||||
{Type: "dns", Value: "foo.internal"},
|
||||
{Type: "dns", Value: "bar.internal"},
|
||||
},
|
||||
},
|
||||
csr: &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "foo.internal",
|
||||
},
|
||||
},
|
||||
want: []x509util.SubjectAlternativeName{},
|
||||
err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||
"CSR names = %v, Order names = %v", []string{"foo.internal"}, []string{"bar.internal", "foo.internal"}),
|
||||
},
|
||||
{
|
||||
name: "fail/error-names-mismatch",
|
||||
fields: fields{
|
||||
Identifiers: []Identifier{
|
||||
{Type: "dns", Value: "foo.internal"},
|
||||
{Type: "dns", Value: "bar.internal"},
|
||||
},
|
||||
},
|
||||
csr: &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "foo.internal",
|
||||
},
|
||||
DNSNames: []string{"zap.internal"},
|
||||
},
|
||||
want: []x509util.SubjectAlternativeName{},
|
||||
err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||
"CSR names = %v, Order names = %v", []string{"foo.internal", "zap.internal"}, []string{"bar.internal", "foo.internal"}),
|
||||
},
|
||||
{
|
||||
name: "ok/ipv4",
|
||||
fields: fields{
|
||||
Identifiers: []Identifier{
|
||||
{Type: "ip", Value: "192.168.43.42"},
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
},
|
||||
},
|
||||
csr: &x509.CertificateRequest{
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
want: []x509util.SubjectAlternativeName{
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
{Type: "ip", Value: "192.168.43.42"},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ok/ipv6",
|
||||
fields: fields{
|
||||
Identifiers: []Identifier{
|
||||
{Type: "ip", Value: "2001:0db8:85a3::8a2e:0370:7335"},
|
||||
{Type: "ip", Value: "2001:0db8:85a3::8a2e:0370:7334"},
|
||||
},
|
||||
},
|
||||
csr: &x509.CertificateRequest{
|
||||
IPAddresses: []net.IP{net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7335"), net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
|
||||
},
|
||||
want: []x509util.SubjectAlternativeName{
|
||||
{Type: "ip", Value: "2001:db8:85a3::8a2e:370:7334"},
|
||||
{Type: "ip", Value: "2001:db8:85a3::8a2e:370:7335"},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "fail/error-ips-length-mismatch",
|
||||
fields: fields{
|
||||
Identifiers: []Identifier{
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
{Type: "ip", Value: "192.168.43.42"},
|
||||
},
|
||||
},
|
||||
csr: &x509.CertificateRequest{
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
want: []x509util.SubjectAlternativeName{},
|
||||
err: NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+
|
||||
"CSR IPs = %v, Order IPs = %v", []net.IP{net.ParseIP("192.168.42.42")}, []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}),
|
||||
},
|
||||
{
|
||||
name: "fail/error-ips-mismatch",
|
||||
fields: fields{
|
||||
Identifiers: []Identifier{
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
{Type: "ip", Value: "192.168.43.42"},
|
||||
},
|
||||
},
|
||||
csr: &x509.CertificateRequest{
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.32")},
|
||||
},
|
||||
want: []x509util.SubjectAlternativeName{},
|
||||
err: NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+
|
||||
"CSR IPs = %v, Order IPs = %v", []net.IP{net.ParseIP("192.168.42.32"), net.ParseIP("192.168.42.42")}, []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}),
|
||||
},
|
||||
{
|
||||
name: "ok/mixed",
|
||||
fields: fields{
|
||||
Identifiers: []Identifier{
|
||||
{Type: "dns", Value: "foo.internal"},
|
||||
{Type: "dns", Value: "bar.internal"},
|
||||
{Type: "ip", Value: "192.168.43.42"},
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
{Type: "ip", Value: "2001:0db8:85a3:0000:0000:8a2e:0370:7334"},
|
||||
},
|
||||
},
|
||||
csr: &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "bar.internal",
|
||||
},
|
||||
DNSNames: []string{"foo.internal"},
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42"), net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")},
|
||||
},
|
||||
want: []x509util.SubjectAlternativeName{
|
||||
{Type: "dns", Value: "bar.internal"},
|
||||
{Type: "dns", Value: "foo.internal"},
|
||||
{Type: "ip", Value: "192.168.42.42"},
|
||||
{Type: "ip", Value: "192.168.43.42"},
|
||||
{Type: "ip", Value: "2001:db8:85a3::8a2e:370:7334"},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "fail/unsupported-identifier-type",
|
||||
fields: fields{
|
||||
Identifiers: []Identifier{
|
||||
{Type: "ipv4", Value: "192.168.42.42"},
|
||||
},
|
||||
},
|
||||
csr: &x509.CertificateRequest{
|
||||
IPAddresses: []net.IP{net.ParseIP("192.168.42.42")},
|
||||
},
|
||||
want: []x509util.SubjectAlternativeName{},
|
||||
err: NewError(ErrorServerInternalType, "unsupported identifier type in order: ipv4"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := &Order{
|
||||
Identifiers: tt.fields.Identifiers,
|
||||
}
|
||||
canonicalizedCSR := canonicalize(tt.csr)
|
||||
got, err := o.sans(canonicalizedCSR)
|
||||
if tt.err != nil {
|
||||
if err == nil {
|
||||
t.Errorf("Order.sans() = %v, want error; got none", got)
|
||||
return
|
||||
}
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
assert.Equals(t, k.Type, tt.err.Type)
|
||||
assert.Equals(t, k.Detail, tt.err.Detail)
|
||||
assert.Equals(t, k.Status, tt.err.Status)
|
||||
assert.Equals(t, k.Err.Error(), tt.err.Err.Error())
|
||||
assert.Equals(t, k.Detail, tt.err.Detail)
|
||||
default:
|
||||
assert.FatalError(t, errors.New("unexpected error type"))
|
||||
}
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Order.sans() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
41
api/api.go
41
api/api.go
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
|
@ -32,13 +33,13 @@ type Authority interface {
|
|||
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||
GetTLSOptions() *authority.TLSOptions
|
||||
GetTLSOptions() *config.TLSOptions
|
||||
Root(shasum string) (*x509.Certificate, error)
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
|
||||
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
GetEncryptedKey(kid string) (string, error)
|
||||
|
@ -239,9 +240,9 @@ type caHandler struct {
|
|||
}
|
||||
|
||||
// New creates a new RouterHandler with the CA endpoints.
|
||||
func New(authority Authority) RouterHandler {
|
||||
func New(auth Authority) RouterHandler {
|
||||
return &caHandler{
|
||||
Authority: authority,
|
||||
Authority: auth,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -294,7 +295,7 @@ func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
|
|||
// certificate for the given SHA256.
|
||||
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
||||
sha := chi.URLParam(r, "sha")
|
||||
sum := strings.ToLower(strings.Replace(sha, "-", "", -1))
|
||||
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
|
||||
// Load root certificate with the
|
||||
cert, err := h.Authority.Root(sum)
|
||||
if err != nil {
|
||||
|
@ -315,7 +316,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
|||
|
||||
// Provisioners returns the list of provisioners configured in the authority.
|
||||
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := parseCursor(r)
|
||||
cursor, limit, err := ParseCursor(r)
|
||||
if err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
return
|
||||
|
@ -399,7 +400,7 @@ func logOtt(w http.ResponseWriter, token string) {
|
|||
func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
m := map[string]interface{}{
|
||||
"serial": cert.SerialNumber,
|
||||
"serial": cert.SerialNumber.String(),
|
||||
"subject": cert.Subject.CommonName,
|
||||
"issuer": cert.Issuer.CommonName,
|
||||
"valid-from": cert.NotBefore.Format(time.RFC3339),
|
||||
|
@ -408,25 +409,27 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
|
|||
"certificate": base64.StdEncoding.EncodeToString(cert.Raw),
|
||||
}
|
||||
for _, ext := range cert.Extensions {
|
||||
if ext.Id.Equal(oidStepProvisioner) {
|
||||
val := &stepProvisioner{}
|
||||
rest, err := asn1.Unmarshal(ext.Value, val)
|
||||
if err != nil || len(rest) > 0 {
|
||||
break
|
||||
}
|
||||
if len(val.CredentialID) > 0 {
|
||||
m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID)
|
||||
} else {
|
||||
m["provisioner"] = fmt.Sprintf("%s", val.Name)
|
||||
}
|
||||
if !ext.Id.Equal(oidStepProvisioner) {
|
||||
continue
|
||||
}
|
||||
val := &stepProvisioner{}
|
||||
rest, err := asn1.Unmarshal(ext.Value, val)
|
||||
if err != nil || len(rest) > 0 {
|
||||
break
|
||||
}
|
||||
if len(val.CredentialID) > 0 {
|
||||
m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID)
|
||||
} else {
|
||||
m["provisioner"] = string(val.Name)
|
||||
}
|
||||
break
|
||||
}
|
||||
rl.WithFields(m)
|
||||
}
|
||||
}
|
||||
|
||||
func parseCursor(r *http.Request) (cursor string, limit int, err error) {
|
||||
// ParseCursor parses the cursor and limit from the request query params.
|
||||
func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
|
||||
q := r.URL.Query()
|
||||
cursor = q.Get("cursor")
|
||||
if v := q.Get("limit"); len(v) > 0 {
|
||||
|
|
|
@ -186,8 +186,8 @@ func TestCertificate_MarshalJSON(t *testing.T) {
|
|||
}{
|
||||
{"nil", fields{Certificate: nil}, []byte("null"), false},
|
||||
{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
|
||||
{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"`), false},
|
||||
{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`), false},
|
||||
{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"`), false},
|
||||
{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -219,11 +219,11 @@ func TestCertificate_UnmarshalJSON(t *testing.T) {
|
|||
{"invalid string", []byte(`"foobar"`), false, true},
|
||||
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
|
||||
{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true},
|
||||
{"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false, true},
|
||||
{"invalid type", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), false, true},
|
||||
{"empty string", []byte(`""`), false, false},
|
||||
{"json null", []byte(`null`), false, false},
|
||||
{"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true, false},
|
||||
{"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), true, false},
|
||||
{"valid root", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), true, false},
|
||||
{"valid cert", []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"`), true, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -251,7 +251,7 @@ func TestCertificate_UnmarshalJSON_json(t *testing.T) {
|
|||
{"empty crt (null)", `{"crt":null}`, false, false},
|
||||
{"empty crt (string)", `{"crt":""}`, false, false},
|
||||
{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true},
|
||||
{"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, true, false},
|
||||
{"valid crt", `{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"}`, true, false},
|
||||
}
|
||||
|
||||
type request struct {
|
||||
|
@ -297,7 +297,7 @@ func TestCertificateRequest_MarshalJSON(t *testing.T) {
|
|||
}{
|
||||
{"nil", fields{CertificateRequest: nil}, []byte("null"), false},
|
||||
{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
|
||||
{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `\n"`), false},
|
||||
{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `\n"`), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -329,10 +329,10 @@ func TestCertificateRequest_UnmarshalJSON(t *testing.T) {
|
|||
{"invalid string", []byte(`"foobar"`), false, true},
|
||||
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
|
||||
{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true},
|
||||
{"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false, true},
|
||||
{"invalid type", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), false, true},
|
||||
{"empty string", []byte(`""`), false, false},
|
||||
{"json null", []byte(`null`), false, false},
|
||||
{"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true, false},
|
||||
{"valid csr", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), true, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -360,7 +360,7 @@ func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) {
|
|||
{"empty csr (null)", `{"csr":null}`, false, false},
|
||||
{"empty csr (string)", `{"csr":""}`, false, false},
|
||||
{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true},
|
||||
{"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, true, false},
|
||||
{"valid csr", `{"csr":"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"}`, true, false},
|
||||
}
|
||||
|
||||
type request struct {
|
||||
|
@ -430,6 +430,7 @@ type mockProvisioner struct {
|
|||
ret1, ret2, ret3 interface{}
|
||||
err error
|
||||
getID func() string
|
||||
getIDForToken func() string
|
||||
getTokenID func(string) (string, error)
|
||||
getName func() string
|
||||
getType func() provisioner.Type
|
||||
|
@ -452,6 +453,13 @@ func (m *mockProvisioner) GetID() string {
|
|||
return m.ret1.(string)
|
||||
}
|
||||
|
||||
func (m *mockProvisioner) GetIDForToken() string {
|
||||
if m.getIDForToken != nil {
|
||||
return m.getIDForToken()
|
||||
}
|
||||
return m.ret1.(string)
|
||||
}
|
||||
|
||||
func (m *mockProvisioner) GetTokenID(token string) (string, error) {
|
||||
if m.getTokenID != nil {
|
||||
return m.getTokenID(token)
|
||||
|
@ -553,7 +561,7 @@ type mockAuthority struct {
|
|||
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||
loadProvisionerByID func(provID string) (provisioner.Interface, error)
|
||||
loadProvisionerByName func(name string) (provisioner.Interface, error)
|
||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||
revoke func(context.Context, *authority.RevokeOptions) error
|
||||
getEncryptedKey func(kid string) (string, error)
|
||||
|
@ -633,9 +641,9 @@ func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (pr
|
|||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByID != nil {
|
||||
return m.loadProvisionerByID(provID)
|
||||
func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByName != nil {
|
||||
return m.loadProvisionerByName(name)
|
||||
}
|
||||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
@ -731,7 +739,7 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin
|
|||
return m.ret1.(bool), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) {
|
||||
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||
if m.getSSHBastion != nil {
|
||||
return m.getSSHBastion(ctx, user, hostname)
|
||||
}
|
||||
|
@ -808,7 +816,7 @@ func Test_caHandler_Root(t *testing.T) {
|
|||
req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil)
|
||||
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
|
||||
|
||||
expected := []byte(`{"ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`)
|
||||
expected := []byte(`{"ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"}`)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -852,8 +860,8 @@ func Test_caHandler_Sign(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected1 := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||
expected2 := []byte(`{"crt":"` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||
expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||
expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -926,7 +934,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
||||
}
|
||||
|
||||
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -987,7 +995,7 @@ func Test_caHandler_Rekey(t *testing.T) {
|
|||
{"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest},
|
||||
}
|
||||
|
||||
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -1202,7 +1210,7 @@ func Test_caHandler_Roots(t *testing.T) {
|
|||
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||
}
|
||||
|
||||
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||
expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -1248,7 +1256,7 @@ func Test_caHandler_Federation(t *testing.T) {
|
|||
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||
}
|
||||
|
||||
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||
expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/scep"
|
||||
|
@ -19,6 +20,9 @@ func WriteError(w http.ResponseWriter, err error) {
|
|||
case *acme.Error:
|
||||
acme.WriteError(w, k)
|
||||
return
|
||||
case *admin.Error:
|
||||
admin.WriteError(w, k)
|
||||
return
|
||||
case *scep.Error:
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
default:
|
||||
|
@ -46,12 +50,10 @@ func WriteError(w http.ResponseWriter, err error) {
|
|||
rl.WithFields(map[string]interface{}{
|
||||
"stack-trace": fmt.Sprintf("%+v", e),
|
||||
})
|
||||
} else {
|
||||
if e, ok := cause.(errs.StackTracer); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"stack-trace": fmt.Sprintf("%+v", e),
|
||||
})
|
||||
}
|
||||
} else if e, ok := cause.(errs.StackTracer); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"stack-trace": fmt.Sprintf("%+v", e),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
12
api/sign.go
12
api/sign.go
|
@ -5,7 +5,7 @@ import (
|
|||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
@ -37,11 +37,11 @@ func (s *SignRequest) Validate() error {
|
|||
|
||||
// SignResponse is the response object of the certificate signature request.
|
||||
type SignResponse struct {
|
||||
ServerPEM Certificate `json:"crt"`
|
||||
CaPEM Certificate `json:"ca"`
|
||||
CertChainPEM []Certificate `json:"certChain"`
|
||||
TLSOptions *authority.TLSOptions `json:"tlsOptions,omitempty"`
|
||||
TLS *tls.ConnectionState `json:"-"`
|
||||
ServerPEM Certificate `json:"crt"`
|
||||
CaPEM Certificate `json:"ca"`
|
||||
CertChainPEM []Certificate `json:"certChain"`
|
||||
TLSOptions *config.TLSOptions `json:"tlsOptions,omitempty"`
|
||||
TLS *tls.ConnectionState `json:"-"`
|
||||
}
|
||||
|
||||
// Sign is an HTTP handler that reads a certificate request and an
|
||||
|
|
25
api/ssh.go
25
api/ssh.go
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
|
@ -22,12 +23,12 @@ type SSHAuthority interface {
|
|||
RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error)
|
||||
GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error)
|
||||
GetSSHRoots(ctx context.Context) (*config.SSHKeys, error)
|
||||
GetSSHFederation(ctx context.Context) (*config.SSHKeys, error)
|
||||
GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
|
||||
CheckSSHHost(ctx context.Context, principal string, token string) (bool, error)
|
||||
GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error)
|
||||
GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
|
||||
GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
|
||||
GetSSHBastion(ctx context.Context, user string, hostname string) (*config.Bastion, error)
|
||||
}
|
||||
|
||||
// SSHSignRequest is the request body of an SSH certificate request.
|
||||
|
@ -51,7 +52,7 @@ func (s *SSHSignRequest) Validate() error {
|
|||
return errors.Errorf("unknown certType %s", s.CertType)
|
||||
case len(s.PublicKey) == 0:
|
||||
return errors.New("missing or empty publicKey")
|
||||
case len(s.OTT) == 0:
|
||||
case s.OTT == "":
|
||||
return errors.New("missing or empty ott")
|
||||
default:
|
||||
// Validate identity signature if provided
|
||||
|
@ -86,7 +87,7 @@ type SSHCertificate struct {
|
|||
// SSHGetHostsResponse is the response object that returns the list of valid
|
||||
// hosts for SSH.
|
||||
type SSHGetHostsResponse struct {
|
||||
Hosts []authority.Host `json:"hosts"`
|
||||
Hosts []config.Host `json:"hosts"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface. Returns a quoted,
|
||||
|
@ -239,8 +240,8 @@ func (r *SSHBastionRequest) Validate() error {
|
|||
// SSHBastionResponse is the response body used to return the bastion for a
|
||||
// given host.
|
||||
type SSHBastionResponse struct {
|
||||
Hostname string `json:"hostname"`
|
||||
Bastion *authority.Bastion `json:"bastion,omitempty"`
|
||||
Hostname string `json:"hostname"`
|
||||
Bastion *config.Bastion `json:"bastion,omitempty"`
|
||||
}
|
||||
|
||||
// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
||||
|
@ -407,18 +408,18 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var config SSHConfigResponse
|
||||
var cfg SSHConfigResponse
|
||||
switch body.Type {
|
||||
case provisioner.SSHUserCert:
|
||||
config.UserTemplates = ts
|
||||
cfg.UserTemplates = ts
|
||||
case provisioner.SSHHostCert:
|
||||
config.HostTemplates = ts
|
||||
cfg.HostTemplates = ts
|
||||
default:
|
||||
WriteError(w, errs.InternalServer("it should hot get here"))
|
||||
return
|
||||
}
|
||||
|
||||
JSON(w, config)
|
||||
JSON(w, cfg)
|
||||
}
|
||||
|
||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||
|
|
|
@ -2,6 +2,7 @@ package api
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
|
@ -18,7 +19,7 @@ type SSHRekeyRequest struct {
|
|||
// Validate validates the SSHSignRekey.
|
||||
func (s *SSHRekeyRequest) Validate() error {
|
||||
switch {
|
||||
case len(s.OTT) == 0:
|
||||
case s.OTT == "":
|
||||
return errors.New("missing or empty ott")
|
||||
case len(s.PublicKey) == 0:
|
||||
return errors.New("missing or empty public key")
|
||||
|
@ -72,7 +73,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r)
|
||||
// Match identity cert with the SSH cert
|
||||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
|
@ -16,7 +18,7 @@ type SSHRenewRequest struct {
|
|||
// Validate validates the SSHSignRequest.
|
||||
func (s *SSHRenewRequest) Validate() error {
|
||||
switch {
|
||||
case len(s.OTT) == 0:
|
||||
case s.OTT == "":
|
||||
return errors.New("missing or empty ott")
|
||||
default:
|
||||
return nil
|
||||
|
@ -62,7 +64,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r)
|
||||
// Match identity cert with the SSH cert
|
||||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
WriteError(w, errs.ForbiddenErr(err))
|
||||
return
|
||||
|
@ -74,13 +80,28 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// renewIdentityCertificate request the client TLS certificate if present.
|
||||
func (h *caHandler) renewIdentityCertificate(r *http.Request) ([]Certificate, error) {
|
||||
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
|
||||
func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
|
||||
// Clone the certificate as we can modify it.
|
||||
cert, err := x509.ParseCertificate(r.TLS.PeerCertificates[0].Raw)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing client certificate")
|
||||
}
|
||||
|
||||
// Enforce the cert to match another certificate, for example an ssh
|
||||
// certificate.
|
||||
if !notBefore.IsZero() {
|
||||
cert.NotBefore = notBefore
|
||||
}
|
||||
if !notAfter.IsZero() {
|
||||
cert.NotAfter = notAfter
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Renew(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
|||
if !r.Passive {
|
||||
return errs.NotImplemented("non-passive revocation not implemented")
|
||||
}
|
||||
if len(r.OTT) == 0 {
|
||||
if r.OTT == "" {
|
||||
return errs.BadRequest("missing ott")
|
||||
}
|
||||
return
|
||||
|
|
|
@ -284,7 +284,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
|||
identityCerts := []*x509.Certificate{
|
||||
parseCertificate(certPEM),
|
||||
}
|
||||
identityCertsPEM := []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`)
|
||||
identityCertsPEM := []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
36
api/utils.go
36
api/utils.go
|
@ -3,11 +3,14 @@ package api
|
|||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// EnableLogger is an interface that enables response logging for an object.
|
||||
|
@ -64,6 +67,29 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
|||
LogEnabledResponse(w, v)
|
||||
}
|
||||
|
||||
// ProtoJSON writes the passed value into the http.ResponseWriter.
|
||||
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
|
||||
ProtoJSONStatus(w, m, http.StatusOK)
|
||||
}
|
||||
|
||||
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
|
||||
// given status is written as the status code of the response.
|
||||
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
b, err := protojson.Marshal(m)
|
||||
if err != nil {
|
||||
LogError(w, err)
|
||||
return
|
||||
}
|
||||
if _, err := w.Write(b); err != nil {
|
||||
LogError(w, err)
|
||||
return
|
||||
}
|
||||
//LogEnabledResponse(w, v)
|
||||
}
|
||||
|
||||
// ReadJSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
func ReadJSON(r io.Reader, v interface{}) error {
|
||||
|
@ -72,3 +98,13 @@ func ReadJSON(r io.Reader, v interface{}) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadProtoJSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
func ReadProtoJSON(r io.Reader, m proto.Message) error {
|
||||
data, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusBadRequest, err, "error reading request body")
|
||||
}
|
||||
return protojson.Unmarshal(data, m)
|
||||
}
|
||||
|
|
160
authority/admin/api/admin.go
Normal file
160
authority/admin/api/admin.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// CreateAdminRequest represents the body for a CreateAdmin request.
|
||||
type CreateAdminRequest struct {
|
||||
Subject string `json:"subject"`
|
||||
Provisioner string `json:"provisioner"`
|
||||
Type linkedca.Admin_Type `json:"type"`
|
||||
}
|
||||
|
||||
// Validate validates a new-admin request body.
|
||||
func (car *CreateAdminRequest) Validate() error {
|
||||
if car.Subject == "" {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "subject cannot be empty")
|
||||
}
|
||||
if car.Provisioner == "" {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "provisioner cannot be empty")
|
||||
}
|
||||
switch car.Type {
|
||||
case linkedca.Admin_SUPER_ADMIN, linkedca.Admin_ADMIN:
|
||||
default:
|
||||
return admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAdminsResponse for returning a list of admins.
|
||||
type GetAdminsResponse struct {
|
||||
Admins []*linkedca.Admin `json:"admins"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
}
|
||||
|
||||
// UpdateAdminRequest represents the body for a UpdateAdmin request.
|
||||
type UpdateAdminRequest struct {
|
||||
Type linkedca.Admin_Type `json:"type"`
|
||||
}
|
||||
|
||||
// Validate validates a new-admin request body.
|
||||
func (uar *UpdateAdminRequest) Validate() error {
|
||||
switch uar.Type {
|
||||
case linkedca.Admin_SUPER_ADMIN, linkedca.Admin_ADMIN:
|
||||
default:
|
||||
return admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteResponse is the resource for successful DELETE responses.
|
||||
type DeleteResponse struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// GetAdmin returns the requested admin, or an error.
|
||||
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
adm, ok := h.auth.LoadAdminByID(id)
|
||||
if !ok {
|
||||
api.WriteError(w, admin.NewError(admin.ErrorNotFoundType,
|
||||
"admin %s not found", id))
|
||||
return
|
||||
}
|
||||
api.ProtoJSON(w, adm)
|
||||
}
|
||||
|
||||
// GetAdmins returns a segment of admins associated with the authority.
|
||||
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
"error parsing cursor and limit from query params"))
|
||||
return
|
||||
}
|
||||
|
||||
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
||||
return
|
||||
}
|
||||
api.JSON(w, &GetAdminsResponse{
|
||||
Admins: admins,
|
||||
NextCursor: nextCursor,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateAdmin creates a new admin.
|
||||
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body CreateAdminRequest
|
||||
if err := api.ReadJSON(r.Body, &body); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
||||
return
|
||||
}
|
||||
adm := &linkedca.Admin{
|
||||
ProvisionerId: p.GetID(),
|
||||
Subject: body.Subject,
|
||||
Type: body.Type,
|
||||
}
|
||||
// Store to authority collection.
|
||||
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error storing admin"))
|
||||
return
|
||||
}
|
||||
|
||||
api.ProtoJSONStatus(w, adm, http.StatusCreated)
|
||||
}
|
||||
|
||||
// DeleteAdmin deletes admin.
|
||||
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
api.JSON(w, &DeleteResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// UpdateAdmin updates an existing admin.
|
||||
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body UpdateAdminRequest
|
||||
if err := api.ReadJSON(r.Body, &body); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
api.ProtoJSON(w, adm)
|
||||
}
|
41
authority/admin/api/handler.go
Normal file
41
authority/admin/api/handler.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
)
|
||||
|
||||
// Handler is the ACME API request handler.
|
||||
type Handler struct {
|
||||
db admin.DB
|
||||
auth *authority.Authority
|
||||
}
|
||||
|
||||
// NewHandler returns a new Authority Config Handler.
|
||||
func NewHandler(auth *authority.Authority) api.RouterHandler {
|
||||
h := &Handler{db: auth.GetAdminDatabase(), auth: auth}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
authnz := func(next nextHTTP) nextHTTP {
|
||||
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next))
|
||||
}
|
||||
|
||||
// Provisioners
|
||||
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners))
|
||||
r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner))
|
||||
r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner))
|
||||
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner))
|
||||
|
||||
// Admins
|
||||
r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin))
|
||||
r.MethodFunc("GET", "/admins", authnz(h.GetAdmins))
|
||||
r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin))
|
||||
r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin))
|
||||
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin))
|
||||
}
|
54
authority/admin/api/middleware.go
Normal file
54
authority/admin/api/middleware.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
)
|
||||
|
||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||
|
||||
// requireAPIEnabled is a middleware that ensures the Administration API
|
||||
// is enabled before servicing requests.
|
||||
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.auth.IsAdminAPIEnabled() {
|
||||
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||
"administration API not enabled"))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
|
||||
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
tok := r.Header.Get("Authorization")
|
||||
if tok == "" {
|
||||
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
"missing authorization header token"))
|
||||
return
|
||||
}
|
||||
|
||||
adm, err := h.auth.AuthorizeAdminToken(r, tok)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), adminContextKey, adm)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// ContextKey is the key type for storing and searching for ACME request
|
||||
// essentials in the context of a request.
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// adminContextKey account key
|
||||
adminContextKey = ContextKey("admin")
|
||||
)
|
175
authority/admin/api/provisioner.go
Normal file
175
authority/admin/api/provisioner.go
Normal file
|
@ -0,0 +1,175 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// GetProvisionersResponse is the type for GET /admin/provisioners responses.
|
||||
type GetProvisionersResponse struct {
|
||||
Provisioners provisioner.List `json:"provisioners"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
}
|
||||
|
||||
// GetProvisioner returns the requested provisioner, or an error.
|
||||
func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if len(id) > 0 {
|
||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
prov, err := h.db.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
api.ProtoJSON(w, prov)
|
||||
}
|
||||
|
||||
// GetProvisioners returns the given segment of provisioners associated with the authority.
|
||||
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
"error parsing cursor & limit query params"))
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := h.auth.GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
api.WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
api.JSON(w, &GetProvisionersResponse{
|
||||
Provisioners: p,
|
||||
NextCursor: next,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateProvisioner creates a new prov.
|
||||
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var prov = new(linkedca.Provisioner)
|
||||
if err := api.ReadProtoJSON(r.Body, prov); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Validate inputs
|
||||
if err := authority.ValidateClaims(prov.Claims); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
||||
return
|
||||
}
|
||||
api.ProtoJSONStatus(w, prov, http.StatusCreated)
|
||||
}
|
||||
|
||||
// DeleteProvisioner deletes a provisioner.
|
||||
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if len(id) > 0 {
|
||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
api.JSON(w, &DeleteResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// UpdateProvisioner updates an existing prov.
|
||||
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var nu = new(linkedca.Provisioner)
|
||||
if err := api.ReadProtoJSON(r.Body, nu); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
name := chi.URLParam(r, "name")
|
||||
_old, err := h.auth.LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
||||
return
|
||||
}
|
||||
|
||||
old, err := h.db.GetProvisioner(r.Context(), _old.GetID())
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
|
||||
return
|
||||
}
|
||||
|
||||
if nu.Id != old.Id {
|
||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID"))
|
||||
return
|
||||
}
|
||||
if nu.Type != old.Type {
|
||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner type"))
|
||||
return
|
||||
}
|
||||
if nu.AuthorityId != old.AuthorityId {
|
||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID"))
|
||||
return
|
||||
}
|
||||
if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) {
|
||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt"))
|
||||
return
|
||||
}
|
||||
if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) {
|
||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Validate inputs
|
||||
if err := authority.ValidateClaims(nu.Claims); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
api.ProtoJSON(w, nu)
|
||||
}
|
179
authority/admin/db.go
Normal file
179
authority/admin/db.go
Normal file
|
@ -0,0 +1,179 @@
|
|||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultAuthorityID is the default AuthorityID. This will be the ID
|
||||
// of the first Authority created, as well as the default AuthorityID
|
||||
// if one is not specified in the configuration.
|
||||
DefaultAuthorityID = "00000000-0000-0000-0000-000000000000"
|
||||
)
|
||||
|
||||
// ErrNotFound is an error that should be used by the authority.DB interface to
|
||||
// indicate that an entity does not exist.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// UnmarshalProvisionerDetails unmarshals details type to the specific provisioner details.
|
||||
func UnmarshalProvisionerDetails(typ linkedca.Provisioner_Type, data []byte) (*linkedca.ProvisionerDetails, error) {
|
||||
var v linkedca.ProvisionerDetails
|
||||
switch typ {
|
||||
case linkedca.Provisioner_JWK:
|
||||
v.Data = new(linkedca.ProvisionerDetails_JWK)
|
||||
case linkedca.Provisioner_OIDC:
|
||||
v.Data = new(linkedca.ProvisionerDetails_OIDC)
|
||||
case linkedca.Provisioner_GCP:
|
||||
v.Data = new(linkedca.ProvisionerDetails_GCP)
|
||||
case linkedca.Provisioner_AWS:
|
||||
v.Data = new(linkedca.ProvisionerDetails_AWS)
|
||||
case linkedca.Provisioner_AZURE:
|
||||
v.Data = new(linkedca.ProvisionerDetails_Azure)
|
||||
case linkedca.Provisioner_ACME:
|
||||
v.Data = new(linkedca.ProvisionerDetails_ACME)
|
||||
case linkedca.Provisioner_X5C:
|
||||
v.Data = new(linkedca.ProvisionerDetails_X5C)
|
||||
case linkedca.Provisioner_K8SSA:
|
||||
v.Data = new(linkedca.ProvisionerDetails_K8SSA)
|
||||
case linkedca.Provisioner_SSHPOP:
|
||||
v.Data = new(linkedca.ProvisionerDetails_SSHPOP)
|
||||
case linkedca.Provisioner_SCEP:
|
||||
v.Data = new(linkedca.ProvisionerDetails_SCEP)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provisioner type %s", typ)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, v.Data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &linkedca.ProvisionerDetails{Data: v.Data}, nil
|
||||
}
|
||||
|
||||
// DB is the DB interface expected by the step-ca Admin API.
|
||||
type DB interface {
|
||||
CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error)
|
||||
GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error)
|
||||
UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
DeleteProvisioner(ctx context.Context, id string) error
|
||||
|
||||
CreateAdmin(ctx context.Context, admin *linkedca.Admin) error
|
||||
GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error)
|
||||
GetAdmins(ctx context.Context) ([]*linkedca.Admin, error)
|
||||
UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error
|
||||
DeleteAdmin(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
// a mock in tests.
|
||||
type MockDB struct {
|
||||
MockCreateProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
MockGetProvisioner func(ctx context.Context, id string) (*linkedca.Provisioner, error)
|
||||
MockGetProvisioners func(ctx context.Context) ([]*linkedca.Provisioner, error)
|
||||
MockUpdateProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
MockDeleteProvisioner func(ctx context.Context, id string) error
|
||||
|
||||
MockCreateAdmin func(ctx context.Context, adm *linkedca.Admin) error
|
||||
MockGetAdmin func(ctx context.Context, id string) (*linkedca.Admin, error)
|
||||
MockGetAdmins func(ctx context.Context) ([]*linkedca.Admin, error)
|
||||
MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error
|
||||
MockDeleteAdmin func(ctx context.Context, id string) error
|
||||
|
||||
MockError error
|
||||
MockRet1 interface{}
|
||||
}
|
||||
|
||||
// CreateProvisioner mock.
|
||||
func (m *MockDB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
if m.MockCreateProvisioner != nil {
|
||||
return m.MockCreateProvisioner(ctx, prov)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetProvisioner mock.
|
||||
func (m *MockDB) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
if m.MockGetProvisioner != nil {
|
||||
return m.MockGetProvisioner(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Provisioner), m.MockError
|
||||
}
|
||||
|
||||
// GetProvisioners mock
|
||||
func (m *MockDB) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
|
||||
if m.MockGetProvisioners != nil {
|
||||
return m.MockGetProvisioners(ctx)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.([]*linkedca.Provisioner), m.MockError
|
||||
}
|
||||
|
||||
// UpdateProvisioner mock
|
||||
func (m *MockDB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
if m.MockUpdateProvisioner != nil {
|
||||
return m.MockUpdateProvisioner(ctx, prov)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// DeleteProvisioner mock
|
||||
func (m *MockDB) DeleteProvisioner(ctx context.Context, id string) error {
|
||||
if m.MockDeleteProvisioner != nil {
|
||||
return m.MockDeleteProvisioner(ctx, id)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateAdmin mock
|
||||
func (m *MockDB) CreateAdmin(ctx context.Context, admin *linkedca.Admin) error {
|
||||
if m.MockCreateAdmin != nil {
|
||||
return m.MockCreateAdmin(ctx, admin)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetAdmin mock.
|
||||
func (m *MockDB) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) {
|
||||
if m.MockGetAdmin != nil {
|
||||
return m.MockGetAdmin(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Admin), m.MockError
|
||||
}
|
||||
|
||||
// GetAdmins mock
|
||||
func (m *MockDB) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
||||
if m.MockGetAdmins != nil {
|
||||
return m.MockGetAdmins(ctx)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.([]*linkedca.Admin), m.MockError
|
||||
}
|
||||
|
||||
// UpdateAdmin mock
|
||||
func (m *MockDB) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||
if m.MockUpdateAdmin != nil {
|
||||
return m.MockUpdateAdmin(ctx, adm)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// DeleteAdmin mock
|
||||
func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error {
|
||||
if m.MockDeleteAdmin != nil {
|
||||
return m.MockDeleteAdmin(ctx, id)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
178
authority/admin/db/nosql/admin.go
Normal file
178
authority/admin/db/nosql/admin.go
Normal file
|
@ -0,0 +1,178 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/nosql"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// dbAdmin is the database representation of the Admin type.
|
||||
type dbAdmin struct {
|
||||
ID string `json:"id"`
|
||||
AuthorityID string `json:"authorityID"`
|
||||
ProvisionerID string `json:"provisionerID"`
|
||||
Subject string `json:"subject"`
|
||||
Type linkedca.Admin_Type `json:"type"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
DeletedAt time.Time `json:"deletedAt"`
|
||||
}
|
||||
|
||||
func (dba *dbAdmin) convert() *linkedca.Admin {
|
||||
return &linkedca.Admin{
|
||||
Id: dba.ID,
|
||||
AuthorityId: dba.AuthorityID,
|
||||
ProvisionerId: dba.ProvisionerID,
|
||||
Subject: dba.Subject,
|
||||
Type: dba.Type,
|
||||
CreatedAt: timestamppb.New(dba.CreatedAt),
|
||||
DeletedAt: timestamppb.New(dba.DeletedAt),
|
||||
}
|
||||
}
|
||||
|
||||
func (dba *dbAdmin) clone() *dbAdmin {
|
||||
u := *dba
|
||||
return &u
|
||||
}
|
||||
|
||||
func (db *DB) getDBAdminBytes(ctx context.Context, id string) ([]byte, error) {
|
||||
data, err := db.db.Get(adminsTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading admin %s", id)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalDBAdmin(data []byte, id string) (*dbAdmin, error) {
|
||||
var dba = new(dbAdmin)
|
||||
if err := json.Unmarshal(data, dba); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling admin %s into dbAdmin", id)
|
||||
}
|
||||
if !dba.DeletedAt.IsZero() {
|
||||
return nil, admin.NewError(admin.ErrorDeletedType, "admin %s is deleted", id)
|
||||
}
|
||||
if dba.AuthorityID != db.authorityID {
|
||||
return nil, admin.NewError(admin.ErrorAuthorityMismatchType,
|
||||
"admin %s is not owned by authority %s", dba.ID, db.authorityID)
|
||||
}
|
||||
return dba, nil
|
||||
}
|
||||
|
||||
func (db *DB) getDBAdmin(ctx context.Context, id string) (*dbAdmin, error) {
|
||||
data, err := db.getDBAdminBytes(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dba, err := db.unmarshalDBAdmin(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dba, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalAdmin(data []byte, id string) (*linkedca.Admin, error) {
|
||||
dba, err := db.unmarshalDBAdmin(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dba.convert(), nil
|
||||
}
|
||||
|
||||
// GetAdmin retrieves and unmarshals a admin from the database.
|
||||
func (db *DB) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) {
|
||||
data, err := db.getDBAdminBytes(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
adm, err := db.unmarshalAdmin(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return adm, nil
|
||||
}
|
||||
|
||||
// GetAdmins retrieves and unmarshals all active (not deleted) admins
|
||||
// from the database.
|
||||
// TODO should we be paginating?
|
||||
func (db *DB) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
||||
dbEntries, err := db.db.List(adminsTable)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error loading admins")
|
||||
}
|
||||
var admins = []*linkedca.Admin{}
|
||||
for _, entry := range dbEntries {
|
||||
adm, err := db.unmarshalAdmin(entry.Value, string(entry.Key))
|
||||
if err != nil {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if k.IsType(admin.ErrorDeletedType) || k.IsType(admin.ErrorAuthorityMismatchType) {
|
||||
continue
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if adm.AuthorityId != db.authorityID {
|
||||
continue
|
||||
}
|
||||
admins = append(admins, adm)
|
||||
}
|
||||
return admins, nil
|
||||
}
|
||||
|
||||
// CreateAdmin stores a new admin to the database.
|
||||
func (db *DB) CreateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||
var err error
|
||||
adm.Id, err = randID()
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error generating random id for admin")
|
||||
}
|
||||
adm.AuthorityId = db.authorityID
|
||||
|
||||
dba := &dbAdmin{
|
||||
ID: adm.Id,
|
||||
AuthorityID: db.authorityID,
|
||||
ProvisionerID: adm.ProvisionerId,
|
||||
Subject: adm.Subject,
|
||||
Type: adm.Type,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
|
||||
return db.save(ctx, dba.ID, dba, nil, "admin", adminsTable)
|
||||
}
|
||||
|
||||
// UpdateAdmin saves an updated admin to the database.
|
||||
func (db *DB) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||
old, err := db.getDBAdmin(ctx, adm.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
nu.Type = adm.Type
|
||||
|
||||
return db.save(ctx, old.ID, nu, old, "admin", adminsTable)
|
||||
}
|
||||
|
||||
// DeleteAdmin saves an updated admin to the database.
|
||||
func (db *DB) DeleteAdmin(ctx context.Context, id string) error {
|
||||
old, err := db.getDBAdmin(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
nu.DeletedAt = clock.Now()
|
||||
|
||||
return db.save(ctx, old.ID, nu, old, "admin", adminsTable)
|
||||
}
|
1108
authority/admin/db/nosql/admin_test.go
Normal file
1108
authority/admin/db/nosql/admin_test.go
Normal file
File diff suppressed because it is too large
Load diff
88
authority/admin/db/nosql/nosql.go
Normal file
88
authority/admin/db/nosql/nosql.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
nosqlDB "github.com/smallstep/nosql/database"
|
||||
"go.step.sm/crypto/randutil"
|
||||
)
|
||||
|
||||
var (
|
||||
adminsTable = []byte("admins")
|
||||
provisionersTable = []byte("provisioners")
|
||||
)
|
||||
|
||||
// DB is a struct that implements the AdminDB interface.
|
||||
type DB struct {
|
||||
db nosqlDB.DB
|
||||
authorityID string
|
||||
}
|
||||
|
||||
// New configures and returns a new Authority DB backend implemented using a nosql DB.
|
||||
func New(db nosqlDB.DB, authorityID string) (*DB, error) {
|
||||
tables := [][]byte{adminsTable, provisionersTable}
|
||||
for _, b := range tables {
|
||||
if err := db.CreateTable(b); err != nil {
|
||||
return nil, errors.Wrapf(err, "error creating table %s",
|
||||
string(b))
|
||||
}
|
||||
}
|
||||
return &DB{db, authorityID}, nil
|
||||
}
|
||||
|
||||
// save writes the new data to the database, overwriting the old data if it
|
||||
// existed.
|
||||
func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
|
||||
var (
|
||||
err error
|
||||
newB []byte
|
||||
)
|
||||
if nu == nil {
|
||||
newB = nil
|
||||
} else {
|
||||
newB, err = json.Marshal(nu)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error marshaling authority type: %s, value: %v", typ, nu)
|
||||
}
|
||||
}
|
||||
var oldB []byte
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
oldB, err = json.Marshal(old)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error marshaling admin type: %s, value: %v", typ, old)
|
||||
}
|
||||
}
|
||||
|
||||
_, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return errors.Wrapf(err, "error saving authority %s", typ)
|
||||
case !swapped:
|
||||
return errors.Errorf("error saving authority %s; changed since last read", typ)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func randID() (val string, err error) {
|
||||
val, err = randutil.UUIDv4()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error generating random alphanumeric ID")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock struct{}
|
||||
|
||||
// Now returns the UTC time rounded to seconds.
|
||||
func (c *Clock) Now() time.Time {
|
||||
return time.Now().UTC().Truncate(time.Second)
|
||||
}
|
||||
|
||||
var clock = new(Clock)
|
211
authority/admin/db/nosql/provisioner.go
Normal file
211
authority/admin/db/nosql/provisioner.go
Normal file
|
@ -0,0 +1,211 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/nosql"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// dbProvisioner is the database representation of a Provisioner type.
|
||||
type dbProvisioner struct {
|
||||
ID string `json:"id"`
|
||||
AuthorityID string `json:"authorityID"`
|
||||
Type linkedca.Provisioner_Type `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Claims *linkedca.Claims `json:"claims"`
|
||||
Details []byte `json:"details"`
|
||||
X509Template *linkedca.Template `json:"x509Template"`
|
||||
SSHTemplate *linkedca.Template `json:"sshTemplate"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
DeletedAt time.Time `json:"deletedAt"`
|
||||
}
|
||||
|
||||
func (dbp *dbProvisioner) clone() *dbProvisioner {
|
||||
u := *dbp
|
||||
return &u
|
||||
}
|
||||
|
||||
func (dbp *dbProvisioner) convert2linkedca() (*linkedca.Provisioner, error) {
|
||||
details, err := admin.UnmarshalProvisionerDetails(dbp.Type, dbp.Details)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &linkedca.Provisioner{
|
||||
Id: dbp.ID,
|
||||
AuthorityId: dbp.AuthorityID,
|
||||
Type: dbp.Type,
|
||||
Name: dbp.Name,
|
||||
Claims: dbp.Claims,
|
||||
Details: details,
|
||||
X509Template: dbp.X509Template,
|
||||
SshTemplate: dbp.SSHTemplate,
|
||||
CreatedAt: timestamppb.New(dbp.CreatedAt),
|
||||
DeletedAt: timestamppb.New(dbp.DeletedAt),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) {
|
||||
data, err := db.db.Get(provisionersTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading provisioner %s", id)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalDBProvisioner(data []byte, id string) (*dbProvisioner, error) {
|
||||
var dbp = new(dbProvisioner)
|
||||
if err := json.Unmarshal(data, dbp); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling provisioner %s into dbProvisioner", id)
|
||||
}
|
||||
if !dbp.DeletedAt.IsZero() {
|
||||
return nil, admin.NewError(admin.ErrorDeletedType, "provisioner %s is deleted", id)
|
||||
}
|
||||
if dbp.AuthorityID != db.authorityID {
|
||||
return nil, admin.NewError(admin.ErrorAuthorityMismatchType,
|
||||
"provisioner %s is not owned by authority %s", id, db.authorityID)
|
||||
}
|
||||
return dbp, nil
|
||||
}
|
||||
|
||||
func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner, error) {
|
||||
data, err := db.getDBProvisionerBytes(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbp, err := db.unmarshalDBProvisioner(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbp, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalProvisioner(data []byte, id string) (*linkedca.Provisioner, error) {
|
||||
dbp, err := db.unmarshalDBProvisioner(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dbp.convert2linkedca()
|
||||
}
|
||||
|
||||
// GetProvisioner retrieves and unmarshals a provisioner from the database.
|
||||
func (db *DB) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
data, err := db.getDBProvisionerBytes(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prov, err := db.unmarshalProvisioner(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return prov, nil
|
||||
}
|
||||
|
||||
// GetProvisioners retrieves and unmarshals all active (not deleted) provisioners
|
||||
// from the database.
|
||||
func (db *DB) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
|
||||
dbEntries, err := db.db.List(provisionersTable)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error loading provisioners")
|
||||
}
|
||||
var provs []*linkedca.Provisioner
|
||||
for _, entry := range dbEntries {
|
||||
prov, err := db.unmarshalProvisioner(entry.Value, string(entry.Key))
|
||||
if err != nil {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if k.IsType(admin.ErrorDeletedType) || k.IsType(admin.ErrorAuthorityMismatchType) {
|
||||
continue
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if prov.AuthorityId != db.authorityID {
|
||||
continue
|
||||
}
|
||||
provs = append(provs, prov)
|
||||
}
|
||||
return provs, nil
|
||||
}
|
||||
|
||||
// CreateProvisioner stores a new provisioner to the database.
|
||||
func (db *DB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
var err error
|
||||
prov.Id, err = randID()
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error generating random id for provisioner")
|
||||
}
|
||||
|
||||
details, err := json.Marshal(prov.Details.GetData())
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error marshaling details when creating provisioner %s", prov.Name)
|
||||
}
|
||||
|
||||
dbp := &dbProvisioner{
|
||||
ID: prov.Id,
|
||||
AuthorityID: db.authorityID,
|
||||
Type: prov.Type,
|
||||
Name: prov.Name,
|
||||
Claims: prov.Claims,
|
||||
Details: details,
|
||||
X509Template: prov.X509Template,
|
||||
SSHTemplate: prov.SshTemplate,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
|
||||
if err := db.save(ctx, prov.Id, dbp, nil, "provisioner", provisionersTable); err != nil {
|
||||
return admin.WrapErrorISE(err, "error creating provisioner %s", prov.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProvisioner saves an updated provisioner to the database.
|
||||
func (db *DB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
old, err := db.getDBProvisioner(ctx, prov.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
|
||||
if old.Type != prov.Type {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "cannot update provisioner type")
|
||||
}
|
||||
nu.Name = prov.Name
|
||||
nu.Claims = prov.Claims
|
||||
nu.Details, err = json.Marshal(prov.Details.GetData())
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error marshaling details when updating provisioner %s", prov.Name)
|
||||
}
|
||||
nu.X509Template = prov.X509Template
|
||||
nu.SSHTemplate = prov.SshTemplate
|
||||
|
||||
return db.save(ctx, prov.Id, nu, old, "provisioner", provisionersTable)
|
||||
}
|
||||
|
||||
// DeleteProvisioner saves an updated admin to the database.
|
||||
func (db *DB) DeleteProvisioner(ctx context.Context, id string) error {
|
||||
old, err := db.getDBProvisioner(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
nu.DeletedAt = clock.Now()
|
||||
|
||||
return db.save(ctx, old.ID, nu, old, "provisioner", provisionersTable)
|
||||
}
|
1208
authority/admin/db/nosql/provisioner_test.go
Normal file
1208
authority/admin/db/nosql/provisioner_test.go
Normal file
File diff suppressed because it is too large
Load diff
223
authority/admin/errors.go
Normal file
223
authority/admin/errors.go
Normal file
|
@ -0,0 +1,223 @@
|
|||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// ProblemType is the type of the Admin problem.
|
||||
type ProblemType int
|
||||
|
||||
const (
|
||||
// ErrorNotFoundType resource not found.
|
||||
ErrorNotFoundType ProblemType = iota
|
||||
// ErrorAuthorityMismatchType resource Authority ID does not match the
|
||||
// context Authority ID.
|
||||
ErrorAuthorityMismatchType
|
||||
// ErrorDeletedType resource has been deleted.
|
||||
ErrorDeletedType
|
||||
// ErrorBadRequestType bad request.
|
||||
ErrorBadRequestType
|
||||
// ErrorNotImplementedType not implemented.
|
||||
ErrorNotImplementedType
|
||||
// ErrorUnauthorizedType internal server error.
|
||||
ErrorUnauthorizedType
|
||||
// ErrorServerInternalType internal server error.
|
||||
ErrorServerInternalType
|
||||
)
|
||||
|
||||
// String returns the string representation of the admin problem type,
|
||||
// fulfilling the Stringer interface.
|
||||
func (ap ProblemType) String() string {
|
||||
switch ap {
|
||||
case ErrorNotFoundType:
|
||||
return "notFound"
|
||||
case ErrorAuthorityMismatchType:
|
||||
return "authorityMismatch"
|
||||
case ErrorDeletedType:
|
||||
return "deleted"
|
||||
case ErrorBadRequestType:
|
||||
return "badRequest"
|
||||
case ErrorNotImplementedType:
|
||||
return "notImplemented"
|
||||
case ErrorUnauthorizedType:
|
||||
return "unauthorized"
|
||||
case ErrorServerInternalType:
|
||||
return "internalServerError"
|
||||
default:
|
||||
return fmt.Sprintf("unsupported error type '%d'", int(ap))
|
||||
}
|
||||
}
|
||||
|
||||
type errorMetadata struct {
|
||||
details string
|
||||
status int
|
||||
typ string
|
||||
String string
|
||||
}
|
||||
|
||||
var (
|
||||
errorServerInternalMetadata = errorMetadata{
|
||||
typ: ErrorServerInternalType.String(),
|
||||
details: "the server experienced an internal error",
|
||||
status: 500,
|
||||
}
|
||||
errorMap = map[ProblemType]errorMetadata{
|
||||
ErrorNotFoundType: {
|
||||
typ: ErrorNotFoundType.String(),
|
||||
details: "resource not found",
|
||||
status: http.StatusNotFound,
|
||||
},
|
||||
ErrorAuthorityMismatchType: {
|
||||
typ: ErrorAuthorityMismatchType.String(),
|
||||
details: "resource not owned by authority",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
ErrorDeletedType: {
|
||||
typ: ErrorDeletedType.String(),
|
||||
details: "resource is deleted",
|
||||
status: http.StatusNotFound,
|
||||
},
|
||||
ErrorNotImplementedType: {
|
||||
typ: ErrorNotImplementedType.String(),
|
||||
details: "not implemented",
|
||||
status: http.StatusNotImplemented,
|
||||
},
|
||||
ErrorBadRequestType: {
|
||||
typ: ErrorBadRequestType.String(),
|
||||
details: "bad request",
|
||||
status: http.StatusBadRequest,
|
||||
},
|
||||
ErrorUnauthorizedType: {
|
||||
typ: ErrorUnauthorizedType.String(),
|
||||
details: "unauthorized",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
ErrorServerInternalType: errorServerInternalMetadata,
|
||||
}
|
||||
)
|
||||
|
||||
// Error represents an Admin
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Message string `json:"message"`
|
||||
Err error `json:"-"`
|
||||
Status int `json:"-"`
|
||||
}
|
||||
|
||||
// IsType returns true if the error type matches the input type.
|
||||
func (e *Error) IsType(pt ProblemType) bool {
|
||||
return pt.String() == e.Type
|
||||
}
|
||||
|
||||
// NewError creates a new Error type.
|
||||
func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
|
||||
return newError(pt, errors.Errorf(msg, args...))
|
||||
}
|
||||
|
||||
func newError(pt ProblemType, err error) *Error {
|
||||
meta, ok := errorMap[pt]
|
||||
if !ok {
|
||||
meta = errorServerInternalMetadata
|
||||
return &Error{
|
||||
Type: meta.typ,
|
||||
Detail: meta.details,
|
||||
Status: meta.status,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
return &Error{
|
||||
Type: meta.typ,
|
||||
Detail: meta.details,
|
||||
Status: meta.status,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorISE creates a new ErrorServerInternalType Error.
|
||||
func NewErrorISE(msg string, args ...interface{}) *Error {
|
||||
return NewError(ErrorServerInternalType, msg, args...)
|
||||
}
|
||||
|
||||
// WrapError attempts to wrap the internal error.
|
||||
func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error {
|
||||
switch e := err.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case *Error:
|
||||
if e.Err == nil {
|
||||
e.Err = errors.Errorf(msg+"; "+e.Detail, args...)
|
||||
} else {
|
||||
e.Err = errors.Wrapf(e.Err, msg, args...)
|
||||
}
|
||||
return e
|
||||
default:
|
||||
return newError(typ, errors.Wrapf(err, msg, args...))
|
||||
}
|
||||
}
|
||||
|
||||
// WrapErrorISE shortcut to wrap an internal server error type.
|
||||
func WrapErrorISE(err error, msg string, args ...interface{}) *Error {
|
||||
return WrapError(ErrorServerInternalType, err, msg, args...)
|
||||
}
|
||||
|
||||
// StatusCode returns the status code and implements the StatusCoder interface.
|
||||
func (e *Error) StatusCode() int {
|
||||
return e.Status
|
||||
}
|
||||
|
||||
// Error allows AError to implement the error interface.
|
||||
func (e *Error) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
// Cause returns the internal error and implements the Causer interface.
|
||||
func (e *Error) Cause() error {
|
||||
if e.Err == nil {
|
||||
return errors.New(e.Detail)
|
||||
}
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// ToLog implements the EnableLogger interface.
|
||||
func (e *Error) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return nil, WrapErrorISE(err, "error marshaling authority.Error for logging")
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// WriteError writes to w a JSON representation of the given error.
|
||||
func WriteError(w http.ResponseWriter, err *Error) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(err.StatusCode())
|
||||
|
||||
err.Message = err.Err.Error()
|
||||
// Write errors in the response writer
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"error": err.Err,
|
||||
})
|
||||
if os.Getenv("STEPDEBUG") == "1" {
|
||||
if e, ok := err.Err.(errs.StackTracer); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"stack-trace": fmt.Sprintf("%+v", e),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
243
authority/administrator/collection.go
Normal file
243
authority/administrator/collection.go
Normal file
|
@ -0,0 +1,243 @@
|
|||
package administrator
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// DefaultAdminLimit is the default limit for listing provisioners.
|
||||
const DefaultAdminLimit = 20
|
||||
|
||||
// DefaultAdminMax is the maximum limit for listing provisioners.
|
||||
const DefaultAdminMax = 100
|
||||
|
||||
type adminSlice []*linkedca.Admin
|
||||
|
||||
func (p adminSlice) Len() int { return len(p) }
|
||||
func (p adminSlice) Less(i, j int) bool { return p[i].Id < p[j].Id }
|
||||
func (p adminSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
// Collection is a memory map of admins.
|
||||
type Collection struct {
|
||||
byID *sync.Map
|
||||
bySubProv *sync.Map
|
||||
byProv *sync.Map
|
||||
sorted adminSlice
|
||||
provisioners *provisioner.Collection
|
||||
superCount int
|
||||
superCountByProvisioner map[string]int
|
||||
}
|
||||
|
||||
// NewCollection initializes a collection of provisioners. The given list of
|
||||
// audiences are the audiences used by the JWT provisioner.
|
||||
func NewCollection(provisioners *provisioner.Collection) *Collection {
|
||||
return &Collection{
|
||||
byID: new(sync.Map),
|
||||
byProv: new(sync.Map),
|
||||
bySubProv: new(sync.Map),
|
||||
superCountByProvisioner: map[string]int{},
|
||||
provisioners: provisioners,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadByID a admin by the ID.
|
||||
func (c *Collection) LoadByID(id string) (*linkedca.Admin, bool) {
|
||||
return loadAdmin(c.byID, id)
|
||||
}
|
||||
|
||||
type subProv struct {
|
||||
subject string
|
||||
provisioner string
|
||||
}
|
||||
|
||||
func newSubProv(subject, prov string) subProv {
|
||||
return subProv{subject, prov}
|
||||
}
|
||||
|
||||
// LoadBySubProv a admin by the subject and provisioner name.
|
||||
func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) {
|
||||
return loadAdmin(c.bySubProv, newSubProv(sub, provName))
|
||||
}
|
||||
|
||||
// LoadByProvisioner a admin by the subject and provisioner name.
|
||||
func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) {
|
||||
val, ok := c.byProv.Load(provName)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
admins, ok := val.([]*linkedca.Admin)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return admins, true
|
||||
}
|
||||
|
||||
// Store adds an admin to the collection and enforces the uniqueness of
|
||||
// admin IDs and amdin subject <-> provisioner name combos.
|
||||
func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error {
|
||||
// Input validation.
|
||||
if adm.ProvisionerId != prov.GetID() {
|
||||
return admin.NewErrorISE("admin.provisionerId does not match provisioner argument")
|
||||
}
|
||||
|
||||
// Store admin always in byID. ID must be unique.
|
||||
if _, loaded := c.byID.LoadOrStore(adm.Id, adm); loaded {
|
||||
return errors.New("cannot add multiple admins with the same id")
|
||||
}
|
||||
|
||||
provName := prov.GetName()
|
||||
// Store admin always in bySubProv. Subject <-> ProvisionerName must be unique.
|
||||
if _, loaded := c.bySubProv.LoadOrStore(newSubProv(adm.Subject, provName), adm); loaded {
|
||||
c.byID.Delete(adm.Id)
|
||||
return errors.New("cannot add multiple admins with the same subject and provisioner")
|
||||
}
|
||||
|
||||
var isSuper = (adm.Type == linkedca.Admin_SUPER_ADMIN)
|
||||
if admins, ok := c.LoadByProvisioner(provName); ok {
|
||||
c.byProv.Store(provName, append(admins, adm))
|
||||
if isSuper {
|
||||
c.superCountByProvisioner[provName]++
|
||||
}
|
||||
} else {
|
||||
c.byProv.Store(provName, []*linkedca.Admin{adm})
|
||||
if isSuper {
|
||||
c.superCountByProvisioner[provName] = 1
|
||||
}
|
||||
}
|
||||
if isSuper {
|
||||
c.superCount++
|
||||
}
|
||||
|
||||
c.sorted = append(c.sorted, adm)
|
||||
sort.Sort(c.sorted)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove deletes an admin from all associated collections and lists.
|
||||
func (c *Collection) Remove(id string) error {
|
||||
adm, ok := c.LoadByID(id)
|
||||
if !ok {
|
||||
return admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)
|
||||
}
|
||||
if adm.Type == linkedca.Admin_SUPER_ADMIN && c.SuperCount() == 1 {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "cannot remove the last super admin")
|
||||
}
|
||||
prov, ok := c.provisioners.Load(adm.ProvisionerId)
|
||||
if !ok {
|
||||
return admin.NewError(admin.ErrorNotFoundType,
|
||||
"provisioner %s for admin %s not found", adm.ProvisionerId, id)
|
||||
}
|
||||
provName := prov.GetName()
|
||||
adminsByProv, ok := c.LoadByProvisioner(provName)
|
||||
if !ok {
|
||||
return admin.NewError(admin.ErrorNotFoundType,
|
||||
"admins not found for provisioner %s", provName)
|
||||
}
|
||||
|
||||
// Find index in sorted list.
|
||||
sortedIndex := sort.Search(c.sorted.Len(), func(i int) bool { return c.sorted[i].Id >= adm.Id })
|
||||
if c.sorted[sortedIndex].Id != adm.Id {
|
||||
return admin.NewError(admin.ErrorNotFoundType,
|
||||
"admin %s not found in sorted list", adm.Id)
|
||||
}
|
||||
|
||||
var found bool
|
||||
for i, a := range adminsByProv {
|
||||
if a.Id == adm.Id {
|
||||
// Remove admin from list. https://stackoverflow.com/questions/37334119/how-to-delete-an-element-from-a-slice-in-golang
|
||||
// Order does not matter.
|
||||
adminsByProv[i] = adminsByProv[len(adminsByProv)-1]
|
||||
c.byProv.Store(provName, adminsByProv[:len(adminsByProv)-1])
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return admin.NewError(admin.ErrorNotFoundType,
|
||||
"admin %s not found in adminsByProvisioner list", adm.Id)
|
||||
}
|
||||
|
||||
// Remove index in sorted list
|
||||
copy(c.sorted[sortedIndex:], c.sorted[sortedIndex+1:]) // Shift a[i+1:] left one index.
|
||||
c.sorted[len(c.sorted)-1] = nil // Erase last element (write zero value).
|
||||
c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice.
|
||||
|
||||
c.byID.Delete(adm.Id)
|
||||
c.bySubProv.Delete(newSubProv(adm.Subject, provName))
|
||||
|
||||
if adm.Type == linkedca.Admin_SUPER_ADMIN {
|
||||
c.superCount--
|
||||
c.superCountByProvisioner[provName]--
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update updates the given admin in all related lists and collections.
|
||||
func (c *Collection) Update(id string, nu *linkedca.Admin) (*linkedca.Admin, error) {
|
||||
adm, ok := c.LoadByID(id)
|
||||
if !ok {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", adm.Id)
|
||||
}
|
||||
if adm.Type == nu.Type {
|
||||
return adm, nil
|
||||
}
|
||||
if adm.Type == linkedca.Admin_SUPER_ADMIN && c.SuperCount() == 1 {
|
||||
return nil, admin.NewError(admin.ErrorBadRequestType, "cannot change role of last super admin")
|
||||
}
|
||||
|
||||
adm.Type = nu.Type
|
||||
return adm, nil
|
||||
}
|
||||
|
||||
// SuperCount returns the total number of admins.
|
||||
func (c *Collection) SuperCount() int {
|
||||
return c.superCount
|
||||
}
|
||||
|
||||
// SuperCountByProvisioner returns the total number of admins.
|
||||
func (c *Collection) SuperCountByProvisioner(provName string) int {
|
||||
if cnt, ok := c.superCountByProvisioner[provName]; ok {
|
||||
return cnt
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Find implements pagination on a list of sorted admins.
|
||||
func (c *Collection) Find(cursor string, limit int) ([]*linkedca.Admin, string) {
|
||||
switch {
|
||||
case limit <= 0:
|
||||
limit = DefaultAdminLimit
|
||||
case limit > DefaultAdminMax:
|
||||
limit = DefaultAdminMax
|
||||
}
|
||||
|
||||
n := c.sorted.Len()
|
||||
i := sort.Search(n, func(i int) bool { return c.sorted[i].Id >= cursor })
|
||||
|
||||
slice := []*linkedca.Admin{}
|
||||
for ; i < n && len(slice) < limit; i++ {
|
||||
slice = append(slice, c.sorted[i])
|
||||
}
|
||||
|
||||
if i < n {
|
||||
return slice, c.sorted[i].Id
|
||||
}
|
||||
return slice, ""
|
||||
}
|
||||
|
||||
func loadAdmin(m *sync.Map, key interface{}) (*linkedca.Admin, bool) {
|
||||
val, ok := m.Load(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
adm, ok := val.(*linkedca.Admin)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return adm, true
|
||||
}
|
97
authority/admins.go
Normal file
97
authority/admins.go
Normal file
|
@ -0,0 +1,97 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// LoadAdminByID returns an *linkedca.Admin with the given ID.
|
||||
func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
|
||||
a.adminMutex.RLock()
|
||||
defer a.adminMutex.RUnlock()
|
||||
return a.admins.LoadByID(id)
|
||||
}
|
||||
|
||||
// LoadAdminBySubProv returns an *linkedca.Admin with the given ID.
|
||||
func (a *Authority) LoadAdminBySubProv(subject, prov string) (*linkedca.Admin, bool) {
|
||||
a.adminMutex.RLock()
|
||||
defer a.adminMutex.RUnlock()
|
||||
return a.admins.LoadBySubProv(subject, prov)
|
||||
}
|
||||
|
||||
// GetAdmins returns a map listing each provisioner and the JWK Key Set
|
||||
// with their public keys.
|
||||
func (a *Authority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) {
|
||||
a.adminMutex.RLock()
|
||||
defer a.adminMutex.RUnlock()
|
||||
admins, nextCursor := a.admins.Find(cursor, limit)
|
||||
return admins, nextCursor, nil
|
||||
}
|
||||
|
||||
// StoreAdmin stores an *linkedca.Admin to the authority.
|
||||
func (a *Authority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
if adm.ProvisionerId != prov.GetID() {
|
||||
return admin.NewErrorISE("admin.provisionerId does not match provisioner argument")
|
||||
}
|
||||
|
||||
if _, ok := a.admins.LoadBySubProv(adm.Subject, prov.GetName()); ok {
|
||||
return admin.NewError(admin.ErrorBadRequestType,
|
||||
"admin with subject %s and provisioner %s already exists", adm.Subject, prov.GetName())
|
||||
}
|
||||
// Store to database -- this will set the ID.
|
||||
if err := a.adminDB.CreateAdmin(ctx, adm); err != nil {
|
||||
return admin.WrapErrorISE(err, "error creating admin")
|
||||
}
|
||||
if err := a.admins.Store(adm, prov); err != nil {
|
||||
if err := a.reloadAdminResources(ctx); err != nil {
|
||||
return admin.WrapErrorISE(err, "error reloading admin resources on failed admin store")
|
||||
}
|
||||
return admin.WrapErrorISE(err, "error storing admin in authority cache")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateAdmin stores an *linkedca.Admin to the authority.
|
||||
func (a *Authority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
adm, err := a.admins.Update(id, nu)
|
||||
if err != nil {
|
||||
return nil, admin.WrapErrorISE(err, "error updating cached admin %s", id)
|
||||
}
|
||||
if err := a.adminDB.UpdateAdmin(ctx, adm); err != nil {
|
||||
if err := a.reloadAdminResources(ctx); err != nil {
|
||||
return nil, admin.WrapErrorISE(err, "error reloading admin resources on failed admin update")
|
||||
}
|
||||
return nil, admin.WrapErrorISE(err, "error updating admin %s", id)
|
||||
}
|
||||
return adm, nil
|
||||
}
|
||||
|
||||
// RemoveAdmin removes an *linkedca.Admin from the authority.
|
||||
func (a *Authority) RemoveAdmin(ctx context.Context, id string) error {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
return a.removeAdmin(ctx, id)
|
||||
}
|
||||
|
||||
// removeAdmin helper that assumes lock.
|
||||
func (a *Authority) removeAdmin(ctx context.Context, id string) error {
|
||||
if err := a.admins.Remove(id); err != nil {
|
||||
return admin.WrapErrorISE(err, "error removing admin %s from authority cache", id)
|
||||
}
|
||||
if err := a.adminDB.DeleteAdmin(ctx, id); err != nil {
|
||||
if err := a.reloadAdminResources(ctx); err != nil {
|
||||
return admin.WrapErrorISE(err, "error reloading admin resources on failed admin remove")
|
||||
}
|
||||
return admin.WrapErrorISE(err, "error deleting admin %s", id)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -7,39 +7,47 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/cas"
|
||||
"github.com/smallstep/certificates/scep"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||
"github.com/smallstep/certificates/authority/administrator"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/cas"
|
||||
casapi "github.com/smallstep/certificates/cas/apiv1"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/kms"
|
||||
kmsapi "github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/kms/sshagentkms"
|
||||
"github.com/smallstep/certificates/scep"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"github.com/smallstep/nosql"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
"go.step.sm/linkedca"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
legacyAuthority = "step-certificate-authority"
|
||||
)
|
||||
|
||||
// Authority implements the Certificate Authority internal interface.
|
||||
type Authority struct {
|
||||
config *Config
|
||||
keyManager kms.KeyManager
|
||||
provisioners *provisioner.Collection
|
||||
db db.AuthDB
|
||||
templates *templates.Templates
|
||||
config *config.Config
|
||||
keyManager kms.KeyManager
|
||||
provisioners *provisioner.Collection
|
||||
admins *administrator.Collection
|
||||
db db.AuthDB
|
||||
adminDB admin.DB
|
||||
templates *templates.Templates
|
||||
linkedCAToken string
|
||||
|
||||
// X509 CA
|
||||
password []byte
|
||||
issuerPassword []byte
|
||||
x509CAService cas.CertificateAuthorityService
|
||||
rootX509Certs []*x509.Certificate
|
||||
rootX509CertPool *x509.CertPool
|
||||
federatedX509Certs []*x509.Certificate
|
||||
certificates *sync.Map
|
||||
|
||||
|
@ -47,6 +55,8 @@ type Authority struct {
|
|||
scepService *scep.Service
|
||||
|
||||
// SSH CA
|
||||
sshHostPassword []byte
|
||||
sshUserPassword []byte
|
||||
sshCAUserCertSignKey ssh.Signer
|
||||
sshCAHostCertSignKey ssh.Signer
|
||||
sshCAUserCerts []ssh.PublicKey
|
||||
|
@ -59,21 +69,23 @@ type Authority struct {
|
|||
startTime time.Time
|
||||
|
||||
// Custom functions
|
||||
sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error)
|
||||
sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error)
|
||||
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
|
||||
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]Host, error)
|
||||
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
|
||||
getIdentityFunc provisioner.GetIdentityFunc
|
||||
|
||||
adminMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates and initiates a new Authority type.
|
||||
func New(config *Config, opts ...Option) (*Authority, error) {
|
||||
err := config.Validate()
|
||||
func New(cfg *config.Config, opts ...Option) (*Authority, error) {
|
||||
err := cfg.Validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var a = &Authority{
|
||||
config: config,
|
||||
config: cfg,
|
||||
certificates: new(sync.Map),
|
||||
}
|
||||
|
||||
|
@ -96,7 +108,7 @@ func New(config *Config, opts ...Option) (*Authority, error) {
|
|||
// project without the limitations of the config.
|
||||
func NewEmbedded(opts ...Option) (*Authority, error) {
|
||||
a := &Authority{
|
||||
config: &Config{},
|
||||
config: &config.Config{},
|
||||
certificates: new(sync.Map),
|
||||
}
|
||||
|
||||
|
@ -120,7 +132,7 @@ func NewEmbedded(opts ...Option) (*Authority, error) {
|
|||
}
|
||||
|
||||
// Initialize config required fields.
|
||||
a.config.init()
|
||||
a.config.Init()
|
||||
|
||||
// Initialize authority from options or configuration.
|
||||
if err := a.init(); err != nil {
|
||||
|
@ -130,6 +142,65 @@ func NewEmbedded(opts ...Option) (*Authority, error) {
|
|||
return a, nil
|
||||
}
|
||||
|
||||
// reloadAdminResources reloads admins and provisioners from the DB.
|
||||
func (a *Authority) reloadAdminResources(ctx context.Context) error {
|
||||
var (
|
||||
provList provisioner.List
|
||||
adminList []*linkedca.Admin
|
||||
)
|
||||
if a.config.AuthorityConfig.EnableAdmin {
|
||||
provs, err := a.adminDB.GetProvisioners(ctx)
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error getting provisioners to initialize authority")
|
||||
}
|
||||
provList, err = provisionerListToCertificates(provs)
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error converting provisioner list to certificates")
|
||||
}
|
||||
adminList, err = a.adminDB.GetAdmins(ctx)
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error getting admins to initialize authority")
|
||||
}
|
||||
} else {
|
||||
provList = a.config.AuthorityConfig.Provisioners
|
||||
adminList = a.config.AuthorityConfig.Admins
|
||||
}
|
||||
|
||||
provisionerConfig, err := a.generateProvisionerConfig(ctx)
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error generating provisioner config")
|
||||
}
|
||||
|
||||
// Create provisioner collection.
|
||||
provClxn := provisioner.NewCollection(provisionerConfig.Audiences)
|
||||
for _, p := range provList {
|
||||
if err := p.Init(*provisionerConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := provClxn.Store(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Create admin collection.
|
||||
adminClxn := administrator.NewCollection(provClxn)
|
||||
for _, adm := range adminList {
|
||||
p, ok := provClxn.Load(adm.ProvisionerId)
|
||||
if !ok {
|
||||
return admin.NewErrorISE("provisioner %s not found when loading admin %s",
|
||||
adm.ProvisionerId, adm.Id)
|
||||
}
|
||||
if err := adminClxn.Store(adm, p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
a.config.AuthorityConfig.Provisioners = provList
|
||||
a.provisioners = provClxn
|
||||
a.config.AuthorityConfig.Admins = adminList
|
||||
a.admins = adminClxn
|
||||
return nil
|
||||
}
|
||||
|
||||
// init performs validation and initializes the fields of an Authority struct.
|
||||
func (a *Authority) init() error {
|
||||
// Check if handler has already been validated/initialized.
|
||||
|
@ -139,6 +210,26 @@ func (a *Authority) init() error {
|
|||
|
||||
var err error
|
||||
|
||||
// Set password if they are not set.
|
||||
var configPassword []byte
|
||||
if a.config.Password != "" {
|
||||
configPassword = []byte(a.config.Password)
|
||||
}
|
||||
if configPassword != nil && a.password == nil {
|
||||
a.password = configPassword
|
||||
}
|
||||
if a.sshHostPassword == nil {
|
||||
a.sshHostPassword = a.password
|
||||
}
|
||||
if a.sshUserPassword == nil {
|
||||
a.sshUserPassword = a.password
|
||||
}
|
||||
|
||||
// Automatically enable admin for all linked cas.
|
||||
if a.linkedCAToken != "" {
|
||||
a.config.AuthorityConfig.EnableAdmin = true
|
||||
}
|
||||
|
||||
// Initialize step-ca Database if it's not already initialized with WithDB.
|
||||
// If a.config.DB is nil then a simple, barebones in memory DB will be used.
|
||||
if a.db == nil {
|
||||
|
@ -166,6 +257,11 @@ func (a *Authority) init() error {
|
|||
options = *a.config.AuthorityConfig.Options
|
||||
}
|
||||
|
||||
// Set the issuer password if passed in the flags.
|
||||
if options.CertificateIssuer != nil && a.issuerPassword != nil {
|
||||
options.CertificateIssuer.Password = string(a.issuerPassword)
|
||||
}
|
||||
|
||||
// Read intermediate and create X509 signer for default CAS.
|
||||
if options.Is(casapi.SoftCAS) {
|
||||
options.CertificateChain, err = pemutil.ReadCertificateBundle(a.config.IntermediateCert)
|
||||
|
@ -174,7 +270,7 @@ func (a *Authority) init() error {
|
|||
}
|
||||
options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
||||
SigningKey: a.config.IntermediateKey,
|
||||
Password: []byte(a.config.Password),
|
||||
Password: []byte(a.password),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -216,6 +312,11 @@ func (a *Authority) init() error {
|
|||
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
|
||||
}
|
||||
|
||||
a.rootX509CertPool = x509.NewCertPool()
|
||||
for _, cert := range a.rootX509Certs {
|
||||
a.rootX509CertPool.AddCert(cert)
|
||||
}
|
||||
|
||||
// Read federated certificates and store them in the certificates map.
|
||||
if len(a.federatedX509Certs) == 0 {
|
||||
a.federatedX509Certs = make([]*x509.Certificate, len(a.config.FederatedRoots))
|
||||
|
@ -238,7 +339,7 @@ func (a *Authority) init() error {
|
|||
if a.config.SSH.HostKey != "" {
|
||||
signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
||||
SigningKey: a.config.SSH.HostKey,
|
||||
Password: []byte(a.config.Password),
|
||||
Password: []byte(a.sshHostPassword),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -264,7 +365,7 @@ func (a *Authority) init() error {
|
|||
if a.config.SSH.UserKey != "" {
|
||||
signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
||||
SigningKey: a.config.SSH.UserKey,
|
||||
Password: []byte(a.config.Password),
|
||||
Password: []byte(a.sshUserPassword),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -288,59 +389,52 @@ func (a *Authority) init() error {
|
|||
a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, a.sshCAUserCertSignKey.PublicKey())
|
||||
}
|
||||
|
||||
// Append other public keys
|
||||
// Append other public keys and add them to the template variables.
|
||||
for _, key := range a.config.SSH.Keys {
|
||||
publicKey := key.PublicKey()
|
||||
switch key.Type {
|
||||
case provisioner.SSHHostCert:
|
||||
if key.Federated {
|
||||
a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, key.PublicKey())
|
||||
a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, publicKey)
|
||||
} else {
|
||||
a.sshCAHostCerts = append(a.sshCAHostCerts, key.PublicKey())
|
||||
a.sshCAHostCerts = append(a.sshCAHostCerts, publicKey)
|
||||
}
|
||||
case provisioner.SSHUserCert:
|
||||
if key.Federated {
|
||||
a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, key.PublicKey())
|
||||
a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, publicKey)
|
||||
} else {
|
||||
a.sshCAUserCerts = append(a.sshCAUserCerts, key.PublicKey())
|
||||
a.sshCAUserCerts = append(a.sshCAUserCerts, publicKey)
|
||||
}
|
||||
default:
|
||||
return errors.Errorf("unsupported type %s", key.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Configure template variables.
|
||||
// Configure template variables. On the template variables HostFederatedKeys
|
||||
// and UserFederatedKeys we will skip the actual CA that will be available
|
||||
// in HostKey and UserKey.
|
||||
//
|
||||
// We cannot do it in the previous blocks because this configuration can be
|
||||
// injected using options.
|
||||
if a.sshCAHostCertSignKey != nil {
|
||||
tmplVars.SSH.HostKey = a.sshCAHostCertSignKey.PublicKey()
|
||||
tmplVars.SSH.UserKey = a.sshCAUserCertSignKey.PublicKey()
|
||||
// On the templates we skip the first one because there's a distinction
|
||||
// between the main key and federated keys.
|
||||
tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts[1:]...)
|
||||
} else {
|
||||
tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts...)
|
||||
}
|
||||
if a.sshCAUserCertSignKey != nil {
|
||||
tmplVars.SSH.UserKey = a.sshCAUserCertSignKey.PublicKey()
|
||||
tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts[1:]...)
|
||||
} else {
|
||||
tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts...)
|
||||
}
|
||||
|
||||
// Merge global and configuration claims
|
||||
claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: should we also be combining the ssh federated roots here?
|
||||
// If we rotate ssh roots keys, sshpop provisioner will lose ability to
|
||||
// validate old SSH certificates, unless they are added as federated certs.
|
||||
sshKeys, err := a.GetSSHRoots(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Initialize provisioners
|
||||
audiences := a.config.getAudiences()
|
||||
a.provisioners = provisioner.NewCollection(audiences)
|
||||
config := provisioner.Config{
|
||||
Claims: claimer.Claims(),
|
||||
Audiences: audiences,
|
||||
DB: a.db,
|
||||
SSHKeys: &provisioner.SSHKeys{
|
||||
UserKeys: sshKeys.UserKeys,
|
||||
HostKeys: sshKeys.HostKeys,
|
||||
},
|
||||
GetIdentityFunc: a.getIdentityFunc,
|
||||
// Check if a KMS with decryption capability is required and available
|
||||
if a.requiresDecrypter() {
|
||||
if _, ok := a.keyManager.(kmsapi.Decrypter); !ok {
|
||||
return errors.New("keymanager doesn't provide crypto.Decrypter")
|
||||
}
|
||||
}
|
||||
|
||||
// Check if a KMS with decryption capability is required and available
|
||||
|
@ -362,7 +456,7 @@ func (a *Authority) init() error {
|
|||
}
|
||||
options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
||||
SigningKey: a.config.IntermediateKey,
|
||||
Password: []byte(a.config.Password),
|
||||
Password: []byte(a.password),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -371,7 +465,7 @@ func (a *Authority) init() error {
|
|||
if km, ok := a.keyManager.(kmsapi.Decrypter); ok {
|
||||
options.Decrypter, err = km.CreateDecrypter(&kmsapi.CreateDecrypterRequest{
|
||||
DecryptionKey: a.config.IntermediateKey,
|
||||
Password: []byte(a.config.Password),
|
||||
Password: []byte(a.password),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -386,14 +480,56 @@ func (a *Authority) init() error {
|
|||
// TODO: mimick the x509CAService GetCertificateAuthority here too?
|
||||
}
|
||||
|
||||
// Store all the provisioners
|
||||
for _, p := range a.config.AuthorityConfig.Provisioners {
|
||||
if err := p.Init(config); err != nil {
|
||||
return err
|
||||
if a.config.AuthorityConfig.EnableAdmin {
|
||||
// Initialize step-ca Admin Database if it's not already initialized using
|
||||
// WithAdminDB.
|
||||
if a.adminDB == nil {
|
||||
if a.linkedCAToken == "" {
|
||||
// Check if AuthConfig already exists
|
||||
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Use the linkedca client as the admindb.
|
||||
client, err := newLinkedCAClient(a.linkedCAToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If authorityId is configured make sure it matches the one in the token
|
||||
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) {
|
||||
return errors.New("error initializing linkedca: token authority and configured authority do not match")
|
||||
}
|
||||
client.Run()
|
||||
a.adminDB = client
|
||||
}
|
||||
}
|
||||
if err := a.provisioners.Store(p); err != nil {
|
||||
return err
|
||||
|
||||
provs, err := a.adminDB.GetProvisioners(context.Background())
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
||||
}
|
||||
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
|
||||
// Create First Provisioner
|
||||
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password))
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error creating first provisioner")
|
||||
}
|
||||
|
||||
// Create first admin
|
||||
if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{
|
||||
ProvisionerId: prov.Id,
|
||||
Subject: "step",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
}); err != nil {
|
||||
return admin.WrapErrorISE(err, "error creating first admin")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load Provisioners and Admins
|
||||
if err := a.reloadAdminResources(context.Background()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure templates, currently only ssh templates are supported.
|
||||
|
@ -423,6 +559,17 @@ func (a *Authority) GetDatabase() db.AuthDB {
|
|||
return a.db
|
||||
}
|
||||
|
||||
// GetAdminDatabase returns the admin database, if one exists.
|
||||
func (a *Authority) GetAdminDatabase() admin.DB {
|
||||
return a.adminDB
|
||||
}
|
||||
|
||||
// IsAdminAPIEnabled returns a boolean indicating whether the Admin API has
|
||||
// been enabled.
|
||||
func (a *Authority) IsAdminAPIEnabled() bool {
|
||||
return a.config.AuthorityConfig.EnableAdmin
|
||||
}
|
||||
|
||||
// Shutdown safely shuts down any clients, databases, etc. held by the Authority.
|
||||
func (a *Authority) Shutdown() error {
|
||||
if err := a.keyManager.Close(); err != nil {
|
||||
|
@ -436,6 +583,9 @@ func (a *Authority) CloseForReload() {
|
|||
if err := a.keyManager.Close(); err != nil {
|
||||
log.Printf("error closing the key manager: %v", err)
|
||||
}
|
||||
if client, ok := a.adminDB.(*linkedCaClient); ok {
|
||||
client.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// requiresDecrypter returns whether the Authority
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
|
@ -82,6 +83,10 @@ func testAuthority(t *testing.T, opts ...Option) *Authority {
|
|||
}
|
||||
a, err := New(c, opts...)
|
||||
assert.FatalError(t, err)
|
||||
// Avoid errors when test tokens are created before the test authority. This
|
||||
// happens in some tests where we re-create the same authority to test
|
||||
// special cases without re-creating the token.
|
||||
a.startTime = a.startTime.Add(-1 * time.Minute)
|
||||
return a
|
||||
}
|
||||
|
||||
|
@ -454,8 +459,6 @@ func TestAuthority_GetSCEPService(t *testing.T) {
|
|||
// getIdentityFunc: tt.fields.getIdentityFunc,
|
||||
// }
|
||||
a, err := New(tt.fields.config)
|
||||
fmt.Println(err)
|
||||
fmt.Println(a)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.New(), error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
|
|
@ -6,11 +6,15 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/linkedca"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
|
@ -50,7 +54,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
|||
// key in order to verify the claims and we need the issuer from the claims
|
||||
// before we can look up the provisioner.
|
||||
var claims Claims
|
||||
if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
|
||||
}
|
||||
|
||||
|
@ -73,25 +77,124 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
|||
// Store the token to protect against reuse unless it's skipped.
|
||||
// If we cannot get a token id from the provisioner, just hash the token.
|
||||
if !SkipTokenReuseFromContext(ctx) {
|
||||
if reuseKey, err := p.GetTokenID(token); err == nil {
|
||||
if reuseKey == "" {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
reuseKey = strings.ToLower(hex.EncodeToString(sum[:]))
|
||||
}
|
||||
ok, err := a.db.UseToken(reuseKey, token)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.authorizeToken: failed when attempting to store token")
|
||||
}
|
||||
if !ok {
|
||||
return nil, errs.Unauthorized("authority.authorizeToken: token already used")
|
||||
}
|
||||
if err := a.UseToken(token, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// AuthorizeAdminToken authorize an Admin token.
|
||||
func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c token")
|
||||
}
|
||||
|
||||
verifiedChains, err := jwt.Headers[0].Certificates(x509.VerifyOptions{
|
||||
Roots: a.rootX509CertPool,
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err,
|
||||
"adminHandler.authorizeToken; error verifying x5c certificate chain in token")
|
||||
}
|
||||
leaf := verifiedChains[0][0]
|
||||
|
||||
if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 {
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType, "adminHandler.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")
|
||||
}
|
||||
|
||||
// Using the leaf certificates key to validate the claims accomplishes two
|
||||
// things:
|
||||
// 1. Asserts that the private key used to sign the token corresponds
|
||||
// to the public certificate in the `x5c` header of the token.
|
||||
// 2. Asserts that the claims are valid - have not been tampered with.
|
||||
var claims jose.Claims
|
||||
if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
|
||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims")
|
||||
}
|
||||
|
||||
prov, err := a.LoadProvisionerByCertificate(leaf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check that the token has not been used.
|
||||
if err := a.UseToken(token, prov); err != nil {
|
||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token")
|
||||
}
|
||||
|
||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
// more than a few minutes.
|
||||
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||
Issuer: prov.GetName(),
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "x5c.authorizeToken; invalid x5c claims")
|
||||
}
|
||||
|
||||
// validate audience: path matches the current path
|
||||
if r.URL.Path != claims.Audience[0] {
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
"x5c.authorizeToken; x5c token has invalid audience "+
|
||||
"claim (aud); expected %s, but got %s", r.URL.Path, claims.Audience)
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
"x5c.authorizeToken; x5c token subject cannot be empty")
|
||||
}
|
||||
|
||||
var (
|
||||
ok bool
|
||||
adm *linkedca.Admin
|
||||
)
|
||||
adminFound := false
|
||||
adminSANs := append([]string{leaf.Subject.CommonName}, leaf.DNSNames...)
|
||||
adminSANs = append(adminSANs, leaf.EmailAddresses...)
|
||||
for _, san := range adminSANs {
|
||||
if adm, ok = a.LoadAdminBySubProv(san, claims.Issuer); ok {
|
||||
adminFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !adminFound {
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
"adminHandler.authorizeToken; unable to load admin with subject(s) %s and provisioner '%s'",
|
||||
adminSANs, claims.Issuer)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(r.URL.Path, "/admin/admins") && (r.Method != "GET") && adm.Type != linkedca.Admin_SUPER_ADMIN {
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType, "must have super admin access to make this request")
|
||||
}
|
||||
|
||||
return adm, nil
|
||||
}
|
||||
|
||||
// UseToken stores the token to protect against reuse.
|
||||
//
|
||||
// This method currently ignores any error coming from the GetTokenID, but it
|
||||
// should specifically ignore the error provisioner.ErrAllowTokenReuse.
|
||||
func (a *Authority) UseToken(token string, prov provisioner.Interface) error {
|
||||
if reuseKey, err := prov.GetTokenID(token); err == nil {
|
||||
if reuseKey == "" {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
reuseKey = strings.ToLower(hex.EncodeToString(sum[:]))
|
||||
}
|
||||
ok, err := a.db.UseToken(reuseKey, token)
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.authorizeToken: failed when attempting to store token")
|
||||
}
|
||||
if !ok {
|
||||
return errs.Unauthorized("authority.authorizeToken: token already used")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authorize grabs the method from the context and authorizes the request by
|
||||
// validating the one-time-token.
|
||||
func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
|
@ -159,7 +262,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
|||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
||||
}
|
||||
if err = p.AuthorizeRevoke(ctx, token); err != nil {
|
||||
if err := p.AuthorizeRevoke(ctx, token); err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
||||
}
|
||||
return nil
|
||||
|
@ -171,10 +274,19 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
|||
//
|
||||
// TODO(mariano): should we authorize by default?
|
||||
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||
var err error
|
||||
var isRevoked bool
|
||||
var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
|
||||
|
||||
// Check the passive revocation table.
|
||||
isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String())
|
||||
serial := cert.SerialNumber.String()
|
||||
if lca, ok := a.adminDB.(interface {
|
||||
IsRevoked(string) (bool, error)
|
||||
}); ok {
|
||||
isRevoked, err = lca.IsRevoked(serial)
|
||||
} else {
|
||||
isRevoked, err = a.db.IsRevoked(serial)
|
||||
}
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||
}
|
||||
|
@ -192,6 +304,28 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// authorizeSSHCertificate returns an error if the given certificate is revoked.
|
||||
func (a *Authority) authorizeSSHCertificate(ctx context.Context, cert *ssh.Certificate) error {
|
||||
var err error
|
||||
var isRevoked bool
|
||||
|
||||
serial := strconv.FormatUint(cert.Serial, 10)
|
||||
if lca, ok := a.adminDB.(interface {
|
||||
IsSSHRevoked(string) (bool, error)
|
||||
}); ok {
|
||||
isRevoked, err = lca.IsSSHRevoked(serial)
|
||||
} else {
|
||||
isRevoked, err = a.db.IsSSHRevoked(serial)
|
||||
}
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHCertificate", errs.WithKeyVal("serialNumber", serial))
|
||||
}
|
||||
if isRevoked {
|
||||
return errs.Unauthorized("authority.authorizeSSHCertificate: certificate has been revoked", errs.WithKeyVal("serialNumber", serial))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorizeSSHSign loads the provisioner from the token, checks that it has not
|
||||
// been used again and calls the provisioner AuthorizeSSHSign method. Returns a
|
||||
// list of methods to apply to the signing flow.
|
||||
|
|
|
@ -822,7 +822,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
cert: renewDisabledCrt,
|
||||
err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner renew_disabled:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk"),
|
||||
err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -917,7 +917,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err = cert.SignCert(rand.Reader, signer); err != nil {
|
||||
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return cert, jwk, nil
|
||||
|
|
|
@ -1,298 +1,46 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
import "github.com/smallstep/certificates/authority/config"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
cas "github.com/smallstep/certificates/cas/apiv1"
|
||||
"github.com/smallstep/certificates/db"
|
||||
kms "github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
)
|
||||
// Config is an alias to support older APIs.
|
||||
type Config = config.Config
|
||||
|
||||
var (
|
||||
// DefaultTLSOptions represents the default TLS version as well as the cipher
|
||||
// suites used in the TLS certificates.
|
||||
DefaultTLSOptions = TLSOptions{
|
||||
CipherSuites: CipherSuites{
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
||||
},
|
||||
MinVersion: 1.2,
|
||||
MaxVersion: 1.2,
|
||||
Renegotiation: false,
|
||||
}
|
||||
defaultBackdate = time.Minute
|
||||
defaultDisableRenewal = false
|
||||
defaultEnableSSHCA = false
|
||||
globalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
|
||||
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &defaultEnableSSHCA,
|
||||
}
|
||||
)
|
||||
// LoadConfiguration is an alias to support older APIs.
|
||||
var LoadConfiguration = config.LoadConfiguration
|
||||
|
||||
// Config represents the CA configuration and it's mapped to a JSON object.
|
||||
type Config struct {
|
||||
Root multiString `json:"root"`
|
||||
FederatedRoots []string `json:"federatedRoots"`
|
||||
IntermediateCert string `json:"crt"`
|
||||
IntermediateKey string `json:"key"`
|
||||
Address string `json:"address"`
|
||||
InsecureAddress string `json:"insecureAddress"`
|
||||
DNSNames []string `json:"dnsNames"`
|
||||
KMS *kms.Options `json:"kms,omitempty"`
|
||||
SSH *SSHConfig `json:"ssh,omitempty"`
|
||||
Logger json.RawMessage `json:"logger,omitempty"`
|
||||
DB *db.Config `json:"db,omitempty"`
|
||||
Monitoring json.RawMessage `json:"monitoring,omitempty"`
|
||||
AuthorityConfig *AuthConfig `json:"authority,omitempty"`
|
||||
TLS *TLSOptions `json:"tls,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Templates *templates.Templates `json:"templates,omitempty"`
|
||||
}
|
||||
// AuthConfig is an alias to support older APIs.
|
||||
type AuthConfig = config.AuthConfig
|
||||
|
||||
// ASN1DN contains ASN1.DN attributes that are used in Subject and Issuer
|
||||
// x509 Certificate blocks.
|
||||
type ASN1DN struct {
|
||||
Country string `json:"country,omitempty" step:"country"`
|
||||
Organization string `json:"organization,omitempty" step:"organization"`
|
||||
OrganizationalUnit string `json:"organizationalUnit,omitempty" step:"organizationalUnit"`
|
||||
Locality string `json:"locality,omitempty" step:"locality"`
|
||||
Province string `json:"province,omitempty" step:"province"`
|
||||
StreetAddress string `json:"streetAddress,omitempty" step:"streetAddress"`
|
||||
CommonName string `json:"commonName,omitempty" step:"commonName"`
|
||||
}
|
||||
// TLS
|
||||
|
||||
// AuthConfig represents the configuration options for the authority. An
|
||||
// underlaying registration authority can also be configured using the
|
||||
// cas.Options.
|
||||
type AuthConfig struct {
|
||||
*cas.Options
|
||||
Provisioners provisioner.List `json:"provisioners"`
|
||||
Template *ASN1DN `json:"template,omitempty"`
|
||||
Claims *provisioner.Claims `json:"claims,omitempty"`
|
||||
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
|
||||
Backdate *provisioner.Duration `json:"backdate,omitempty"`
|
||||
}
|
||||
// ASN1DN is an alias to support older APIs.
|
||||
type ASN1DN = config.ASN1DN
|
||||
|
||||
// init initializes the required fields in the AuthConfig if they are not
|
||||
// provided.
|
||||
func (c *AuthConfig) init() {
|
||||
if c.Provisioners == nil {
|
||||
c.Provisioners = provisioner.List{}
|
||||
}
|
||||
if c.Template == nil {
|
||||
c.Template = &ASN1DN{}
|
||||
}
|
||||
if c.Backdate == nil {
|
||||
c.Backdate = &provisioner.Duration{
|
||||
Duration: defaultBackdate,
|
||||
}
|
||||
}
|
||||
}
|
||||
// DefaultTLSOptions is an alias to support older APIs.
|
||||
var DefaultTLSOptions = config.DefaultTLSOptions
|
||||
|
||||
// Validate validates the authority configuration.
|
||||
func (c *AuthConfig) Validate(audiences provisioner.Audiences) error {
|
||||
if c == nil {
|
||||
return errors.New("authority cannot be undefined")
|
||||
}
|
||||
// TLSOptions is an alias to support older APIs.
|
||||
type TLSOptions = config.TLSOptions
|
||||
|
||||
// Initialize required fields.
|
||||
c.init()
|
||||
// CipherSuites is an alias to support older APIs.
|
||||
type CipherSuites = config.CipherSuites
|
||||
|
||||
// Check that only one K8sSA is enabled
|
||||
var k8sCount int
|
||||
for _, p := range c.Provisioners {
|
||||
if p.GetType() == provisioner.TypeK8sSA {
|
||||
k8sCount++
|
||||
}
|
||||
}
|
||||
if k8sCount > 1 {
|
||||
return errors.New("cannot have more than one kubernetes service account provisioner")
|
||||
}
|
||||
// SSH
|
||||
|
||||
if c.Backdate.Duration < 0 {
|
||||
return errors.New("authority.backdate cannot be less than 0")
|
||||
}
|
||||
// SSHConfig is an alias to support older APIs.
|
||||
type SSHConfig = config.SSHConfig
|
||||
|
||||
return nil
|
||||
}
|
||||
// Bastion is an alias to support older APIs.
|
||||
type Bastion = config.Bastion
|
||||
|
||||
// LoadConfiguration parses the given filename in JSON format and returns the
|
||||
// configuration struct.
|
||||
func LoadConfiguration(filename string) (*Config, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error opening %s", filename)
|
||||
}
|
||||
defer f.Close()
|
||||
// HostTag is an alias to support older APIs.
|
||||
type HostTag = config.HostTag
|
||||
|
||||
var c Config
|
||||
if err := json.NewDecoder(f).Decode(&c); err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing %s", filename)
|
||||
}
|
||||
// Host is an alias to support older APIs.
|
||||
type Host = config.Host
|
||||
|
||||
c.init()
|
||||
// SSHPublicKey is an alias to support older APIs.
|
||||
type SSHPublicKey = config.SSHPublicKey
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// initializes the minimal configuration required to create an authority. This
|
||||
// is mainly used on embedded authorities.
|
||||
func (c *Config) init() {
|
||||
if c.DNSNames == nil {
|
||||
c.DNSNames = []string{"localhost", "127.0.0.1", "::1"}
|
||||
}
|
||||
if c.TLS == nil {
|
||||
c.TLS = &DefaultTLSOptions
|
||||
}
|
||||
if c.AuthorityConfig == nil {
|
||||
c.AuthorityConfig = &AuthConfig{}
|
||||
}
|
||||
c.AuthorityConfig.init()
|
||||
}
|
||||
|
||||
// Save saves the configuration to the given filename.
|
||||
func (c *Config) Save(filename string) error {
|
||||
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error opening %s", filename)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", "\t")
|
||||
return errors.Wrapf(enc.Encode(c), "error writing %s", filename)
|
||||
}
|
||||
|
||||
// Validate validates the configuration.
|
||||
func (c *Config) Validate() error {
|
||||
switch {
|
||||
case c.Address == "":
|
||||
return errors.New("address cannot be empty")
|
||||
|
||||
case len(c.DNSNames) == 0:
|
||||
return errors.New("dnsNames cannot be empty")
|
||||
}
|
||||
|
||||
// Options holds the RA/CAS configuration.
|
||||
ra := c.AuthorityConfig.Options
|
||||
// The default RA/CAS requires root, crt and key.
|
||||
if ra.Is(cas.SoftCAS) {
|
||||
switch {
|
||||
case c.Root.HasEmpties():
|
||||
return errors.New("root cannot be empty")
|
||||
case c.IntermediateCert == "":
|
||||
return errors.New("crt cannot be empty")
|
||||
case c.IntermediateKey == "":
|
||||
return errors.New("key cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate address (a port is required)
|
||||
if _, _, err := net.SplitHostPort(c.Address); err != nil {
|
||||
return errors.Errorf("invalid address %s", c.Address)
|
||||
}
|
||||
|
||||
// Validate insecure address if it is configured
|
||||
if c.InsecureAddress != "" {
|
||||
if _, _, err := net.SplitHostPort(c.InsecureAddress); err != nil {
|
||||
return errors.Errorf("invalid address %s", c.InsecureAddress)
|
||||
}
|
||||
}
|
||||
|
||||
if c.TLS == nil {
|
||||
c.TLS = &DefaultTLSOptions
|
||||
} else {
|
||||
if len(c.TLS.CipherSuites) == 0 {
|
||||
c.TLS.CipherSuites = DefaultTLSOptions.CipherSuites
|
||||
}
|
||||
if c.TLS.MaxVersion == 0 {
|
||||
c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion
|
||||
}
|
||||
if c.TLS.MinVersion == 0 {
|
||||
c.TLS.MinVersion = c.TLS.MaxVersion
|
||||
}
|
||||
if c.TLS.MinVersion > c.TLS.MaxVersion {
|
||||
return errors.New("tls minVersion cannot exceed tls maxVersion")
|
||||
}
|
||||
c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation
|
||||
}
|
||||
|
||||
// Validate KMS options, nil is ok.
|
||||
if err := c.KMS.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate RA/CAS options, nil is ok.
|
||||
if err := ra.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate ssh: nil is ok
|
||||
if err := c.SSH.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate templates: nil is ok
|
||||
if err := c.Templates.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.AuthorityConfig.Validate(c.getAudiences())
|
||||
}
|
||||
|
||||
// getAudiences returns the legacy and possible urls without the ports that will
|
||||
// be used as the default provisioner audiences. The CA might have proxies in
|
||||
// front so we cannot rely on the port.
|
||||
func (c *Config) getAudiences() provisioner.Audiences {
|
||||
audiences := provisioner.Audiences{
|
||||
Sign: []string{legacyAuthority},
|
||||
Revoke: []string{legacyAuthority},
|
||||
SSHSign: []string{},
|
||||
SSHRevoke: []string{},
|
||||
SSHRenew: []string{},
|
||||
}
|
||||
|
||||
for _, name := range c.DNSNames {
|
||||
audiences.Sign = append(audiences.Sign,
|
||||
fmt.Sprintf("https://%s/1.0/sign", name),
|
||||
fmt.Sprintf("https://%s/sign", name),
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/ssh/sign", name))
|
||||
audiences.Revoke = append(audiences.Revoke,
|
||||
fmt.Sprintf("https://%s/1.0/revoke", name),
|
||||
fmt.Sprintf("https://%s/revoke", name))
|
||||
audiences.SSHSign = append(audiences.SSHSign,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/1.0/sign", name),
|
||||
fmt.Sprintf("https://%s/sign", name))
|
||||
audiences.SSHRevoke = append(audiences.SSHRevoke,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/revoke", name),
|
||||
fmt.Sprintf("https://%s/ssh/revoke", name))
|
||||
audiences.SSHRenew = append(audiences.SSHRenew,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/renew", name),
|
||||
fmt.Sprintf("https://%s/ssh/renew", name))
|
||||
audiences.SSHRekey = append(audiences.SSHRekey,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/rekey", name),
|
||||
fmt.Sprintf("https://%s/ssh/rekey", name))
|
||||
}
|
||||
|
||||
return audiences
|
||||
}
|
||||
// SSHKeys is an alias to support older APIs.
|
||||
type SSHKeys = config.SSHKeys
|
||||
|
|
297
authority/config/config.go
Normal file
297
authority/config/config.go
Normal file
|
@ -0,0 +1,297 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
cas "github.com/smallstep/certificates/cas/apiv1"
|
||||
"github.com/smallstep/certificates/db"
|
||||
kms "github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
const (
|
||||
legacyAuthority = "step-certificate-authority"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultBackdate length of time to backdate certificates to avoid
|
||||
// clock skew validation issues.
|
||||
DefaultBackdate = time.Minute
|
||||
// DefaultDisableRenewal disables renewals per provisioner.
|
||||
DefaultDisableRenewal = false
|
||||
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally
|
||||
// for all provisioners.
|
||||
DefaultEnableSSHCA = false
|
||||
// GlobalProvisionerClaims default claims for the Authority. Can be overridden
|
||||
// by provisioner specific claims.
|
||||
GlobalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &DefaultDisableRenewal,
|
||||
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
|
||||
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &DefaultEnableSSHCA,
|
||||
}
|
||||
)
|
||||
|
||||
// Config represents the CA configuration and it's mapped to a JSON object.
|
||||
type Config struct {
|
||||
Root multiString `json:"root"`
|
||||
FederatedRoots []string `json:"federatedRoots"`
|
||||
IntermediateCert string `json:"crt"`
|
||||
IntermediateKey string `json:"key"`
|
||||
Address string `json:"address"`
|
||||
InsecureAddress string `json:"insecureAddress"`
|
||||
DNSNames []string `json:"dnsNames"`
|
||||
KMS *kms.Options `json:"kms,omitempty"`
|
||||
SSH *SSHConfig `json:"ssh,omitempty"`
|
||||
Logger json.RawMessage `json:"logger,omitempty"`
|
||||
DB *db.Config `json:"db,omitempty"`
|
||||
Monitoring json.RawMessage `json:"monitoring,omitempty"`
|
||||
AuthorityConfig *AuthConfig `json:"authority,omitempty"`
|
||||
TLS *TLSOptions `json:"tls,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Templates *templates.Templates `json:"templates,omitempty"`
|
||||
}
|
||||
|
||||
// ASN1DN contains ASN1.DN attributes that are used in Subject and Issuer
|
||||
// x509 Certificate blocks.
|
||||
type ASN1DN struct {
|
||||
Country string `json:"country,omitempty"`
|
||||
Organization string `json:"organization,omitempty"`
|
||||
OrganizationalUnit string `json:"organizationalUnit,omitempty"`
|
||||
Locality string `json:"locality,omitempty"`
|
||||
Province string `json:"province,omitempty"`
|
||||
StreetAddress string `json:"streetAddress,omitempty"`
|
||||
SerialNumber string `json:"serialNumber,omitempty"`
|
||||
CommonName string `json:"commonName,omitempty"`
|
||||
}
|
||||
|
||||
// AuthConfig represents the configuration options for the authority. An
|
||||
// underlaying registration authority can also be configured using the
|
||||
// cas.Options.
|
||||
type AuthConfig struct {
|
||||
*cas.Options
|
||||
AuthorityID string `json:"authorityId,omitempty"`
|
||||
DeploymentType string `json:"deploymentType,omitempty"`
|
||||
Provisioners provisioner.List `json:"provisioners,omitempty"`
|
||||
Admins []*linkedca.Admin `json:"-"`
|
||||
Template *ASN1DN `json:"template,omitempty"`
|
||||
Claims *provisioner.Claims `json:"claims,omitempty"`
|
||||
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
|
||||
Backdate *provisioner.Duration `json:"backdate,omitempty"`
|
||||
EnableAdmin bool `json:"enableAdmin,omitempty"`
|
||||
}
|
||||
|
||||
// init initializes the required fields in the AuthConfig if they are not
|
||||
// provided.
|
||||
func (c *AuthConfig) init() {
|
||||
if c.Provisioners == nil {
|
||||
c.Provisioners = provisioner.List{}
|
||||
}
|
||||
if c.Template == nil {
|
||||
c.Template = &ASN1DN{}
|
||||
}
|
||||
if c.Backdate == nil {
|
||||
c.Backdate = &provisioner.Duration{
|
||||
Duration: DefaultBackdate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the authority configuration.
|
||||
func (c *AuthConfig) Validate(audiences provisioner.Audiences) error {
|
||||
if c == nil {
|
||||
return errors.New("authority cannot be undefined")
|
||||
}
|
||||
|
||||
// Initialize required fields.
|
||||
c.init()
|
||||
|
||||
// Check that only one K8sSA is enabled
|
||||
var k8sCount int
|
||||
for _, p := range c.Provisioners {
|
||||
if p.GetType() == provisioner.TypeK8sSA {
|
||||
k8sCount++
|
||||
}
|
||||
}
|
||||
if k8sCount > 1 {
|
||||
return errors.New("cannot have more than one kubernetes service account provisioner")
|
||||
}
|
||||
|
||||
if c.Backdate.Duration < 0 {
|
||||
return errors.New("authority.backdate cannot be less than 0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfiguration parses the given filename in JSON format and returns the
|
||||
// configuration struct.
|
||||
func LoadConfiguration(filename string) (*Config, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error opening %s", filename)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var c Config
|
||||
if err := json.NewDecoder(f).Decode(&c); err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing %s", filename)
|
||||
}
|
||||
|
||||
c.Init()
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// Init initializes the minimal configuration required to create an authority. This
|
||||
// is mainly used on embedded authorities.
|
||||
func (c *Config) Init() {
|
||||
if c.DNSNames == nil {
|
||||
c.DNSNames = []string{"localhost", "127.0.0.1", "::1"}
|
||||
}
|
||||
if c.TLS == nil {
|
||||
c.TLS = &DefaultTLSOptions
|
||||
}
|
||||
if c.AuthorityConfig == nil {
|
||||
c.AuthorityConfig = &AuthConfig{}
|
||||
}
|
||||
c.AuthorityConfig.init()
|
||||
}
|
||||
|
||||
// Save saves the configuration to the given filename.
|
||||
func (c *Config) Save(filename string) error {
|
||||
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error opening %s", filename)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", "\t")
|
||||
return errors.Wrapf(enc.Encode(c), "error writing %s", filename)
|
||||
}
|
||||
|
||||
// Validate validates the configuration.
|
||||
func (c *Config) Validate() error {
|
||||
switch {
|
||||
case c.Address == "":
|
||||
return errors.New("address cannot be empty")
|
||||
case len(c.DNSNames) == 0:
|
||||
return errors.New("dnsNames cannot be empty")
|
||||
case c.AuthorityConfig == nil:
|
||||
return errors.New("authority cannot be nil")
|
||||
}
|
||||
|
||||
// Options holds the RA/CAS configuration.
|
||||
ra := c.AuthorityConfig.Options
|
||||
// The default RA/CAS requires root, crt and key.
|
||||
if ra.Is(cas.SoftCAS) {
|
||||
switch {
|
||||
case c.Root.HasEmpties():
|
||||
return errors.New("root cannot be empty")
|
||||
case c.IntermediateCert == "":
|
||||
return errors.New("crt cannot be empty")
|
||||
case c.IntermediateKey == "":
|
||||
return errors.New("key cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate address (a port is required)
|
||||
if _, _, err := net.SplitHostPort(c.Address); err != nil {
|
||||
return errors.Errorf("invalid address %s", c.Address)
|
||||
}
|
||||
|
||||
if c.TLS == nil {
|
||||
c.TLS = &DefaultTLSOptions
|
||||
} else {
|
||||
if len(c.TLS.CipherSuites) == 0 {
|
||||
c.TLS.CipherSuites = DefaultTLSOptions.CipherSuites
|
||||
}
|
||||
if c.TLS.MaxVersion == 0 {
|
||||
c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion
|
||||
}
|
||||
if c.TLS.MinVersion == 0 {
|
||||
c.TLS.MinVersion = DefaultTLSOptions.MinVersion
|
||||
}
|
||||
if c.TLS.MinVersion > c.TLS.MaxVersion {
|
||||
return errors.New("tls minVersion cannot exceed tls maxVersion")
|
||||
}
|
||||
c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation
|
||||
}
|
||||
|
||||
// Validate KMS options, nil is ok.
|
||||
if err := c.KMS.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate RA/CAS options, nil is ok.
|
||||
if err := ra.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate ssh: nil is ok
|
||||
if err := c.SSH.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate templates: nil is ok
|
||||
if err := c.Templates.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.AuthorityConfig.Validate(c.GetAudiences())
|
||||
}
|
||||
|
||||
// GetAudiences returns the legacy and possible urls without the ports that will
|
||||
// be used as the default provisioner audiences. The CA might have proxies in
|
||||
// front so we cannot rely on the port.
|
||||
func (c *Config) GetAudiences() provisioner.Audiences {
|
||||
audiences := provisioner.Audiences{
|
||||
Sign: []string{legacyAuthority},
|
||||
Revoke: []string{legacyAuthority},
|
||||
SSHSign: []string{},
|
||||
SSHRevoke: []string{},
|
||||
SSHRenew: []string{},
|
||||
}
|
||||
|
||||
for _, name := range c.DNSNames {
|
||||
audiences.Sign = append(audiences.Sign,
|
||||
fmt.Sprintf("https://%s/1.0/sign", name),
|
||||
fmt.Sprintf("https://%s/sign", name),
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/ssh/sign", name))
|
||||
audiences.Revoke = append(audiences.Revoke,
|
||||
fmt.Sprintf("https://%s/1.0/revoke", name),
|
||||
fmt.Sprintf("https://%s/revoke", name))
|
||||
audiences.SSHSign = append(audiences.SSHSign,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/1.0/sign", name),
|
||||
fmt.Sprintf("https://%s/sign", name))
|
||||
audiences.SSHRevoke = append(audiences.SSHRevoke,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/revoke", name),
|
||||
fmt.Sprintf("https://%s/ssh/revoke", name))
|
||||
audiences.SSHRenew = append(audiences.SSHRenew,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/renew", name),
|
||||
fmt.Sprintf("https://%s/ssh/renew", name))
|
||||
audiences.SSHRekey = append(audiences.SSHRekey,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/rekey", name),
|
||||
fmt.Sprintf("https://%s/ssh/rekey", name))
|
||||
}
|
||||
|
||||
return audiences
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package authority
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -8,12 +8,14 @@ import (
|
|||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
_ "github.com/smallstep/certificates/cas"
|
||||
)
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
maxjwk, err := jose.ReadKey("testdata/secrets/max_pub.jwk")
|
||||
maxjwk, err := jose.ReadKey("../testdata/secrets/max_pub.jwk")
|
||||
assert.FatalError(t, err)
|
||||
clijwk, err := jose.ReadKey("testdata/secrets/step_cli_key_pub.jwk")
|
||||
clijwk, err := jose.ReadKey("../testdata/secrets/step_cli_key_pub.jwk")
|
||||
assert.FatalError(t, err)
|
||||
ac := &AuthConfig{
|
||||
Provisioners: provisioner.List{
|
||||
|
@ -39,9 +41,9 @@ func TestConfigValidate(t *testing.T) {
|
|||
"empty-address": func(t *testing.T) ConfigValidateTest {
|
||||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
Root: []string{"../testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "../testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "../testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
AuthorityConfig: ac,
|
||||
|
@ -53,9 +55,9 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
Root: []string{"../testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "../testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "../testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
AuthorityConfig: ac,
|
||||
|
@ -67,8 +69,8 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
IntermediateCert: "../testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "../testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
AuthorityConfig: ac,
|
||||
|
@ -80,8 +82,8 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
Root: []string{"../testdata/secrets/root_ca.crt"},
|
||||
IntermediateKey: "../testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
AuthorityConfig: ac,
|
||||
|
@ -93,8 +95,8 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
Root: []string{"../testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "../testdata/secrets/intermediate_ca.crt",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
AuthorityConfig: ac,
|
||||
|
@ -106,9 +108,9 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
Root: []string{"../testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "../testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "../testdata/secrets/intermediate_ca_key",
|
||||
Password: "pass",
|
||||
AuthorityConfig: ac,
|
||||
},
|
||||
|
@ -119,9 +121,9 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
Root: []string{"../testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "../testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "../testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
AuthorityConfig: ac,
|
||||
|
@ -133,9 +135,9 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
Root: []string{"../testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "../testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "../testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
AuthorityConfig: ac,
|
||||
|
@ -148,9 +150,9 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
Root: []string{"../testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "../testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "../testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
AuthorityConfig: ac,
|
||||
|
@ -177,9 +179,9 @@ func TestConfigValidate(t *testing.T) {
|
|||
return ConfigValidateTest{
|
||||
config: &Config{
|
||||
Address: "127.0.0.1:443",
|
||||
Root: []string{"testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "testdata/secrets/intermediate_ca_key",
|
||||
Root: []string{"../testdata/secrets/root_ca.crt"},
|
||||
IntermediateCert: "../testdata/secrets/intermediate_ca.crt",
|
||||
IntermediateKey: "../testdata/secrets/intermediate_ca_key",
|
||||
DNSNames: []string{"test.smallstep.com"},
|
||||
Password: "pass",
|
||||
AuthorityConfig: ac,
|
||||
|
@ -207,6 +209,8 @@ func TestConfigValidate(t *testing.T) {
|
|||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
fmt.Printf("tc.tls = %+v\n", tc.tls)
|
||||
fmt.Printf("*tc.config.TLS = %+v\n", *tc.config.TLS)
|
||||
assert.Equals(t, *tc.config.TLS, tc.tls)
|
||||
}
|
||||
}
|
||||
|
@ -224,9 +228,9 @@ func TestAuthConfigValidate(t *testing.T) {
|
|||
CommonName: "test",
|
||||
}
|
||||
|
||||
maxjwk, err := jose.ReadKey("testdata/secrets/max_pub.jwk")
|
||||
maxjwk, err := jose.ReadKey("../testdata/secrets/max_pub.jwk")
|
||||
assert.FatalError(t, err)
|
||||
clijwk, err := jose.ReadKey("testdata/secrets/step_cli_key_pub.jwk")
|
||||
clijwk, err := jose.ReadKey("../testdata/secrets/step_cli_key_pub.jwk")
|
||||
assert.FatalError(t, err)
|
||||
p := provisioner.List{
|
||||
&provisioner.JWK{
|
94
authority/config/ssh.go
Normal file
94
authority/config/ssh.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// SSHConfig contains the user and host keys.
|
||||
type SSHConfig struct {
|
||||
HostKey string `json:"hostKey"`
|
||||
UserKey string `json:"userKey"`
|
||||
Keys []*SSHPublicKey `json:"keys,omitempty"`
|
||||
AddUserPrincipal string `json:"addUserPrincipal,omitempty"`
|
||||
AddUserCommand string `json:"addUserCommand,omitempty"`
|
||||
Bastion *Bastion `json:"bastion,omitempty"`
|
||||
}
|
||||
|
||||
// Bastion contains the custom properties used on bastion.
|
||||
type Bastion struct {
|
||||
Hostname string `json:"hostname"`
|
||||
User string `json:"user,omitempty"`
|
||||
Port string `json:"port,omitempty"`
|
||||
Command string `json:"cmd,omitempty"`
|
||||
Flags string `json:"flags,omitempty"`
|
||||
}
|
||||
|
||||
// HostTag are tagged with k,v pairs. These tags are how a user is ultimately
|
||||
// associated with a host.
|
||||
type HostTag struct {
|
||||
ID string
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// Host defines expected attributes for an ssh host.
|
||||
type Host struct {
|
||||
HostID string `json:"hid"`
|
||||
HostTags []HostTag `json:"host_tags"`
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
|
||||
// Validate checks the fields in SSHConfig.
|
||||
func (c *SSHConfig) Validate() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for _, k := range c.Keys {
|
||||
if err := k.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SSHPublicKey contains a public key used by federated CAs to keep old signing
|
||||
// keys for this ca.
|
||||
type SSHPublicKey struct {
|
||||
Type string `json:"type"`
|
||||
Federated bool `json:"federated"`
|
||||
Key jose.JSONWebKey `json:"key"`
|
||||
publicKey ssh.PublicKey
|
||||
}
|
||||
|
||||
// Validate checks the fields in SSHPublicKey.
|
||||
func (k *SSHPublicKey) Validate() error {
|
||||
switch {
|
||||
case k.Type == "":
|
||||
return errors.New("type cannot be empty")
|
||||
case k.Type != provisioner.SSHHostCert && k.Type != provisioner.SSHUserCert:
|
||||
return errors.Errorf("invalid type %s, it must be user or host", k.Type)
|
||||
case !k.Key.IsPublic():
|
||||
return errors.New("invalid key type, it must be a public key")
|
||||
}
|
||||
|
||||
key, err := ssh.NewPublicKey(k.Key.Key)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error creating ssh key")
|
||||
}
|
||||
k.publicKey = key
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublicKey returns the ssh public key.
|
||||
func (k *SSHPublicKey) PublicKey() ssh.PublicKey {
|
||||
return k.publicKey
|
||||
}
|
||||
|
||||
// SSHKeys represents the SSH User and Host public keys.
|
||||
type SSHKeys struct {
|
||||
UserKeys []ssh.PublicKey
|
||||
HostKeys []ssh.PublicKey
|
||||
}
|
73
authority/config/ssh_test.go
Normal file
73
authority/config/ssh_test.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"go.step.sm/crypto/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestSSHPublicKey_Validate(t *testing.T) {
|
||||
key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
Type string
|
||||
Federated bool
|
||||
Key jose.JSONWebKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"user", fields{"user", true, key.Public()}, false},
|
||||
{"host", fields{"host", false, key.Public()}, false},
|
||||
{"empty", fields{"", true, key.Public()}, true},
|
||||
{"badType", fields{"bad", false, key.Public()}, true},
|
||||
{"badKey", fields{"user", false, *key}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &SSHPublicKey{
|
||||
Type: tt.fields.Type,
|
||||
Federated: tt.fields.Federated,
|
||||
Key: tt.fields.Key,
|
||||
}
|
||||
if err := k.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("SSHPublicKey.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHPublicKey_PublicKey(t *testing.T) {
|
||||
key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
pub, err := ssh.NewPublicKey(key.Public().Key)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
publicKey ssh.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want ssh.PublicKey
|
||||
}{
|
||||
{"ok", fields{pub}, pub},
|
||||
{"nil", fields{nil}, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &SSHPublicKey{
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
if got := k.PublicKey(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SSHPublicKey.PublicKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package authority
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
@ -15,8 +15,9 @@ var (
|
|||
// DefaultTLSRenegotiation default TLS connection renegotiation policy.
|
||||
DefaultTLSRenegotiation = false // Never regnegotiate.
|
||||
// DefaultTLSCipherSuites specifies default step ciphersuite(s).
|
||||
// These are TLS 1.0 - 1.2 cipher suites.
|
||||
DefaultTLSCipherSuites = CipherSuites{
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
||||
}
|
||||
// ApprovedTLSCipherSuites smallstep approved ciphersuites.
|
||||
|
@ -26,13 +27,21 @@ var (
|
|||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
|
||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305",
|
||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
|
||||
}
|
||||
// DefaultTLSOptions represents the default TLS version as well as the cipher
|
||||
// suites used in the TLS certificates.
|
||||
DefaultTLSOptions = TLSOptions{
|
||||
CipherSuites: DefaultTLSCipherSuites,
|
||||
MinVersion: DefaultTLSMinVersion,
|
||||
MaxVersion: DefaultTLSMaxVersion,
|
||||
Renegotiation: DefaultTLSRenegotiation,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -107,27 +116,38 @@ func (c CipherSuites) Value() []uint16 {
|
|||
|
||||
// cipherSuites has the list of supported cipher suites.
|
||||
var cipherSuites = map[string]uint16{
|
||||
"TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA,
|
||||
"TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
// TLS 1.0 - 1.2 cipher suites.
|
||||
"TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA,
|
||||
"TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
"TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||
"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
|
||||
// TLS 1.3 cipher sutes.
|
||||
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,
|
||||
"TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384,
|
||||
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
|
||||
// Legacy names.
|
||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
}
|
||||
|
||||
// TLSOptions represents the TLS options that can be specified on *tls.Config
|
|
@ -1,4 +1,4 @@
|
|||
package authority
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
|
@ -1,4 +1,4 @@
|
|||
package authority
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
@ -25,7 +25,7 @@ func (s multiString) HasEmpties() bool {
|
|||
return true
|
||||
}
|
||||
for _, ss := range s {
|
||||
if len(ss) == 0 {
|
||||
if ss == "" {
|
||||
return true
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package authority
|
||||
package config
|
||||
|
||||
import (
|
||||
"reflect"
|
284
authority/export.go
Normal file
284
authority/export.go
Normal file
|
@ -0,0 +1,284 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/cli-utils/config"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
// Export creates a linkedca configuration form the current ca.json and loaded
|
||||
// authorities.
|
||||
//
|
||||
// Note that export will not export neither the pki password nor the certificate
|
||||
// issuer password.
|
||||
func (a *Authority) Export() (c *linkedca.Configuration, err error) {
|
||||
// Recover from panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
files := make(map[string][]byte)
|
||||
|
||||
// The exported configuration should not include the password in it.
|
||||
c = &linkedca.Configuration{
|
||||
Version: "1.0",
|
||||
Root: mustReadFilesOrURIs(a.config.Root, files),
|
||||
FederatedRoots: mustReadFilesOrURIs(a.config.FederatedRoots, files),
|
||||
Intermediate: mustReadFileOrURI(a.config.IntermediateCert, files),
|
||||
IntermediateKey: mustReadFileOrURI(a.config.IntermediateKey, files),
|
||||
Address: a.config.Address,
|
||||
InsecureAddress: a.config.InsecureAddress,
|
||||
DnsNames: a.config.DNSNames,
|
||||
Db: mustMarshalToStruct(a.config.DB),
|
||||
Logger: mustMarshalToStruct(a.config.Logger),
|
||||
Monitoring: mustMarshalToStruct(a.config.Monitoring),
|
||||
Authority: &linkedca.Authority{
|
||||
Id: a.config.AuthorityConfig.AuthorityID,
|
||||
EnableAdmin: a.config.AuthorityConfig.EnableAdmin,
|
||||
DisableIssuedAtCheck: a.config.AuthorityConfig.DisableIssuedAtCheck,
|
||||
Backdate: mustDuration(a.config.AuthorityConfig.Backdate),
|
||||
DeploymentType: a.config.AuthorityConfig.DeploymentType,
|
||||
},
|
||||
Files: files,
|
||||
}
|
||||
|
||||
// SSH
|
||||
if v := a.config.SSH; v != nil {
|
||||
c.Ssh = &linkedca.SSH{
|
||||
HostKey: mustReadFileOrURI(v.HostKey, files),
|
||||
UserKey: mustReadFileOrURI(v.UserKey, files),
|
||||
AddUserPrincipal: v.AddUserPrincipal,
|
||||
AddUserCommand: v.AddUserCommand,
|
||||
}
|
||||
for _, k := range v.Keys {
|
||||
typ, ok := linkedca.SSHPublicKey_Type_value[strings.ToUpper(k.Type)]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unsupported ssh key type %s", k.Type)
|
||||
}
|
||||
c.Ssh.Keys = append(c.Ssh.Keys, &linkedca.SSHPublicKey{
|
||||
Type: linkedca.SSHPublicKey_Type(typ),
|
||||
Federated: k.Federated,
|
||||
Key: mustMarshalToStruct(k),
|
||||
})
|
||||
}
|
||||
if b := v.Bastion; b != nil {
|
||||
c.Ssh.Bastion = &linkedca.Bastion{
|
||||
Hostname: b.Hostname,
|
||||
User: b.User,
|
||||
Port: b.Port,
|
||||
Command: b.Command,
|
||||
Flags: b.Flags,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// KMS
|
||||
if v := a.config.KMS; v != nil {
|
||||
var typ int32
|
||||
var ok bool
|
||||
if v.Type == "" {
|
||||
typ = int32(linkedca.KMS_SOFTKMS)
|
||||
} else {
|
||||
typ, ok = linkedca.KMS_Type_value[strings.ToUpper(v.Type)]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unsupported kms type %s", v.Type)
|
||||
}
|
||||
}
|
||||
c.Kms = &linkedca.KMS{
|
||||
Type: linkedca.KMS_Type(typ),
|
||||
CredentialsFile: v.CredentialsFile,
|
||||
Uri: v.URI,
|
||||
Pin: v.Pin,
|
||||
ManagementKey: v.ManagementKey,
|
||||
Region: v.Region,
|
||||
Profile: v.Profile,
|
||||
}
|
||||
}
|
||||
|
||||
// Authority
|
||||
// cas options
|
||||
if v := a.config.AuthorityConfig.Options; v != nil {
|
||||
c.Authority.Type = 0
|
||||
c.Authority.CertificateAuthority = v.CertificateAuthority
|
||||
c.Authority.CertificateAuthorityFingerprint = v.CertificateAuthorityFingerprint
|
||||
c.Authority.CredentialsFile = v.CredentialsFile
|
||||
if iss := v.CertificateIssuer; iss != nil {
|
||||
typ, ok := linkedca.CertificateIssuer_Type_value[strings.ToUpper(iss.Type)]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unknown certificate issuer type %s", iss.Type)
|
||||
}
|
||||
// The exported certificate issuer should not include the password.
|
||||
c.Authority.CertificateIssuer = &linkedca.CertificateIssuer{
|
||||
Type: linkedca.CertificateIssuer_Type(typ),
|
||||
Provisioner: iss.Provisioner,
|
||||
Certificate: mustReadFileOrURI(iss.Certificate, files),
|
||||
Key: mustReadFileOrURI(iss.Key, files),
|
||||
}
|
||||
}
|
||||
}
|
||||
// admins
|
||||
for {
|
||||
list, cursor := a.admins.Find("", 100)
|
||||
c.Authority.Admins = append(c.Authority.Admins, list...)
|
||||
if cursor == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
// provisioners
|
||||
for {
|
||||
list, cursor := a.provisioners.Find("", 100)
|
||||
for _, p := range list {
|
||||
lp, err := ProvisionerToLinkedca(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Authority.Provisioners = append(c.Authority.Provisioners, lp)
|
||||
}
|
||||
if cursor == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
// global claims
|
||||
c.Authority.Claims = claimsToLinkedca(a.config.AuthorityConfig.Claims)
|
||||
// Distinguished names template
|
||||
if v := a.config.AuthorityConfig.Template; v != nil {
|
||||
c.Authority.Template = &linkedca.DistinguishedName{
|
||||
Country: v.Country,
|
||||
Organization: v.Organization,
|
||||
OrganizationalUnit: v.OrganizationalUnit,
|
||||
Locality: v.Locality,
|
||||
Province: v.Province,
|
||||
StreetAddress: v.StreetAddress,
|
||||
SerialNumber: v.SerialNumber,
|
||||
CommonName: v.CommonName,
|
||||
}
|
||||
}
|
||||
|
||||
// TLS
|
||||
if v := a.config.TLS; v != nil {
|
||||
c.Tls = &linkedca.TLS{
|
||||
MinVersion: v.MinVersion.String(),
|
||||
MaxVersion: v.MaxVersion.String(),
|
||||
Renegotiation: v.Renegotiation,
|
||||
}
|
||||
for _, cs := range v.CipherSuites.Value() {
|
||||
c.Tls.CipherSuites = append(c.Tls.CipherSuites, linkedca.TLS_CiperSuite(cs))
|
||||
}
|
||||
}
|
||||
|
||||
// Templates
|
||||
if v := a.config.Templates; v != nil {
|
||||
c.Templates = &linkedca.ConfigTemplates{
|
||||
Ssh: &linkedca.SSHConfigTemplate{},
|
||||
Data: mustMarshalToStruct(v.Data),
|
||||
}
|
||||
// Remove automatically loaded vars
|
||||
if c.Templates.Data != nil && c.Templates.Data.Fields != nil {
|
||||
delete(c.Templates.Data.Fields, "Step")
|
||||
}
|
||||
for _, t := range v.SSH.Host {
|
||||
typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unsupported template type %s", t.Type)
|
||||
}
|
||||
c.Templates.Ssh.Hosts = append(c.Templates.Ssh.Hosts, &linkedca.ConfigTemplate{
|
||||
Type: linkedca.ConfigTemplate_Type(typ),
|
||||
Name: t.Name,
|
||||
Template: mustReadFileOrURI(t.TemplatePath, files),
|
||||
Path: t.Path,
|
||||
Comment: t.Comment,
|
||||
Requires: t.RequiredData,
|
||||
Content: t.Content,
|
||||
})
|
||||
}
|
||||
for _, t := range v.SSH.User {
|
||||
typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unsupported template type %s", t.Type)
|
||||
}
|
||||
c.Templates.Ssh.Users = append(c.Templates.Ssh.Users, &linkedca.ConfigTemplate{
|
||||
Type: linkedca.ConfigTemplate_Type(typ),
|
||||
Name: t.Name,
|
||||
Template: mustReadFileOrURI(t.TemplatePath, files),
|
||||
Path: t.Path,
|
||||
Comment: t.Comment,
|
||||
Requires: t.RequiredData,
|
||||
Content: t.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func mustDuration(d *provisioner.Duration) string {
|
||||
if d == nil || d.Duration == 0 {
|
||||
return ""
|
||||
}
|
||||
return d.String()
|
||||
}
|
||||
|
||||
func mustMarshalToStruct(v interface{}) *structpb.Struct {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
panic(errors.Wrapf(err, "error marshaling %T", v))
|
||||
}
|
||||
var r *structpb.Struct
|
||||
if err := json.Unmarshal(b, &r); err != nil {
|
||||
panic(errors.Wrapf(err, "error unmarshaling %T", v))
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func mustReadFileOrURI(fn string, m map[string][]byte) string {
|
||||
if fn == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
stepPath := filepath.ToSlash(config.StepPath())
|
||||
if !strings.HasSuffix(stepPath, "/") {
|
||||
stepPath += "/"
|
||||
}
|
||||
|
||||
fn = strings.TrimPrefix(filepath.ToSlash(fn), stepPath)
|
||||
|
||||
ok, err := isFilename(fn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if ok {
|
||||
b, err := ioutil.ReadFile(config.StepAbs(fn))
|
||||
if err != nil {
|
||||
panic(errors.Wrapf(err, "error reading %s", fn))
|
||||
}
|
||||
m[fn] = b
|
||||
return fn
|
||||
}
|
||||
return fn
|
||||
}
|
||||
|
||||
func mustReadFilesOrURIs(fns []string, m map[string][]byte) []string {
|
||||
var result []string
|
||||
for _, fn := range fns {
|
||||
result = append(result, mustReadFileOrURI(fn, m))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func isFilename(fn string) (bool, error) {
|
||||
u, err := url.Parse(fn)
|
||||
if err != nil {
|
||||
return false, errors.Wrapf(err, "error parsing %s", fn)
|
||||
}
|
||||
return u.Scheme == "" || u.Scheme == "file", nil
|
||||
}
|
490
authority/linkedca.go
Normal file
490
authority/linkedca.go
Normal file
|
@ -0,0 +1,490 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/tlsutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"go.step.sm/linkedca"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$"
|
||||
|
||||
type linkedCaClient struct {
|
||||
renewer *tlsutil.Renewer
|
||||
client linkedca.MajordomoClient
|
||||
authorityID string
|
||||
}
|
||||
|
||||
type linkedCAClaims struct {
|
||||
jose.Claims
|
||||
SANs []string `json:"sans"`
|
||||
SHA string `json:"sha"`
|
||||
}
|
||||
|
||||
func newLinkedCAClient(token string) (*linkedCaClient, error) {
|
||||
tok, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing token")
|
||||
}
|
||||
|
||||
var claims linkedCAClaims
|
||||
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing token")
|
||||
}
|
||||
// Validate claims
|
||||
if len(claims.Audience) != 1 {
|
||||
return nil, errors.New("error parsing token: invalid aud claim")
|
||||
}
|
||||
if claims.SHA == "" {
|
||||
return nil, errors.New("error parsing token: invalid sha claim")
|
||||
}
|
||||
// Get linkedCA endpoint from audience.
|
||||
u, err := url.Parse(claims.Audience[0])
|
||||
if err != nil {
|
||||
return nil, errors.New("error parsing token: invalid aud claim")
|
||||
}
|
||||
// Get authority from SANs
|
||||
authority, err := getAuthority(claims.SANs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create csr to login with
|
||||
signer, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
csr, err := x509util.CreateCertificateRequest(claims.Subject, claims.SANs, signer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get and verify root certificate
|
||||
root, err := getRootCertificate(u.Host, claims.SHA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(root)
|
||||
|
||||
// Login with majordomo and get certificates
|
||||
cert, tlsConfig, err := login(authority, token, csr, signer, u.Host, pool)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start TLS renewer and set the GetClientCertificate callback to it.
|
||||
renewer, err := tlsutil.NewRenewer(cert, tlsConfig, func() (*tls.Certificate, *tls.Config, error) {
|
||||
return login(authority, token, csr, signer, u.Host, pool)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||
|
||||
// Start mTLS client
|
||||
conn, err := grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error connecting %s", u.Host)
|
||||
}
|
||||
|
||||
return &linkedCaClient{
|
||||
renewer: renewer,
|
||||
client: linkedca.NewMajordomoClient(conn),
|
||||
authorityID: authority,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) Run() {
|
||||
c.renewer.Run()
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) Stop() {
|
||||
c.renewer.Stop()
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
resp, err := c.client.CreateProvisioner(ctx, &linkedca.CreateProvisionerRequest{
|
||||
Type: prov.Type,
|
||||
Name: prov.Name,
|
||||
Details: prov.Details,
|
||||
Claims: prov.Claims,
|
||||
X509Template: prov.X509Template,
|
||||
SshTemplate: prov.SshTemplate,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error creating provisioner")
|
||||
}
|
||||
prov.Id = resp.Id
|
||||
prov.AuthorityId = resp.AuthorityId
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
resp, err := c.client.GetProvisioner(ctx, &linkedca.GetProvisionerRequest{
|
||||
Id: id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting provisioners")
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
|
||||
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
||||
AuthorityId: c.authorityID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting provisioners")
|
||||
}
|
||||
return resp.Provisioners, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
_, err := c.client.UpdateProvisioner(ctx, &linkedca.UpdateProvisionerRequest{
|
||||
Id: prov.Id,
|
||||
Name: prov.Name,
|
||||
Details: prov.Details,
|
||||
Claims: prov.Claims,
|
||||
X509Template: prov.X509Template,
|
||||
SshTemplate: prov.SshTemplate,
|
||||
})
|
||||
return errors.Wrap(err, "error updating provisioner")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) DeleteProvisioner(ctx context.Context, id string) error {
|
||||
_, err := c.client.DeleteProvisioner(ctx, &linkedca.DeleteProvisionerRequest{
|
||||
Id: id,
|
||||
})
|
||||
return errors.Wrap(err, "error deleting provisioner")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) CreateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||
resp, err := c.client.CreateAdmin(ctx, &linkedca.CreateAdminRequest{
|
||||
Subject: adm.Subject,
|
||||
ProvisionerId: adm.ProvisionerId,
|
||||
Type: adm.Type,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error creating admin")
|
||||
}
|
||||
adm.Id = resp.Id
|
||||
adm.AuthorityId = resp.AuthorityId
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) {
|
||||
resp, err := c.client.GetAdmin(ctx, &linkedca.GetAdminRequest{
|
||||
Id: id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting admins")
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
||||
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
||||
AuthorityId: c.authorityID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting admins")
|
||||
}
|
||||
return resp.Admins, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||
_, err := c.client.UpdateAdmin(ctx, &linkedca.UpdateAdminRequest{
|
||||
Id: adm.Id,
|
||||
Type: adm.Type,
|
||||
})
|
||||
return errors.Wrap(err, "error updating admin")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
|
||||
_, err := c.client.DeleteAdmin(ctx, &linkedca.DeleteAdminRequest{
|
||||
Id: id,
|
||||
})
|
||||
return errors.Wrap(err, "error deleting admin")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
|
||||
PemCertificate: serializeCertificateChain(fullchain[0]),
|
||||
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
|
||||
})
|
||||
return errors.Wrap(err, "error posting certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullchain ...*x509.Certificate) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
|
||||
PemCertificate: serializeCertificateChain(fullchain[0]),
|
||||
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
|
||||
PemParentCertificate: serializeCertificateChain(parent),
|
||||
})
|
||||
return errors.Wrap(err, "error posting certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) StoreSSHCertificate(crt *ssh.Certificate) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
|
||||
Certificate: string(ssh.MarshalAuthorizedKey(crt)),
|
||||
})
|
||||
return errors.Wrap(err, "error posting ssh certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.RevokeCertificate(ctx, &linkedca.RevokeCertificateRequest{
|
||||
Serial: rci.Serial,
|
||||
PemCertificate: serializeCertificate(crt),
|
||||
Reason: rci.Reason,
|
||||
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
|
||||
Passive: true,
|
||||
})
|
||||
|
||||
return errors.Wrap(err, "error revoking certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) RevokeSSH(cert *ssh.Certificate, rci *db.RevokedCertificateInfo) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.RevokeSSHCertificate(ctx, &linkedca.RevokeSSHCertificateRequest{
|
||||
Serial: rci.Serial,
|
||||
Certificate: serializeSSHCertificate(cert),
|
||||
Reason: rci.Reason,
|
||||
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
|
||||
Passive: true,
|
||||
})
|
||||
|
||||
return errors.Wrap(err, "error revoking ssh certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) IsRevoked(serial string) (bool, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
resp, err := c.client.GetCertificateStatus(ctx, &linkedca.GetCertificateStatusRequest{
|
||||
Serial: serial,
|
||||
})
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "error getting certificate status")
|
||||
}
|
||||
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
resp, err := c.client.GetSSHCertificateStatus(ctx, &linkedca.GetSSHCertificateStatusRequest{
|
||||
Serial: serial,
|
||||
})
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "error getting certificate status")
|
||||
}
|
||||
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
|
||||
}
|
||||
|
||||
func serializeCertificate(crt *x509.Certificate) string {
|
||||
if crt == nil {
|
||||
return ""
|
||||
}
|
||||
return string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: crt.Raw,
|
||||
}))
|
||||
}
|
||||
|
||||
func serializeCertificateChain(fullchain ...*x509.Certificate) string {
|
||||
var chain string
|
||||
for _, crt := range fullchain {
|
||||
chain += string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: crt.Raw,
|
||||
}))
|
||||
}
|
||||
return chain
|
||||
}
|
||||
|
||||
func serializeSSHCertificate(crt *ssh.Certificate) string {
|
||||
if crt == nil {
|
||||
return ""
|
||||
}
|
||||
return string(ssh.MarshalAuthorizedKey(crt))
|
||||
}
|
||||
|
||||
func getAuthority(sans []string) (string, error) {
|
||||
for _, s := range sans {
|
||||
if strings.HasPrefix(s, "urn:smallstep:authority:") {
|
||||
if regexp.MustCompile(uuidPattern).MatchString(s[24:]) {
|
||||
return s[24:], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("error parsing token: invalid sans claim")
|
||||
}
|
||||
|
||||
// getRootCertificate creates an insecure majordomo client and returns the
|
||||
// verified root certificate.
|
||||
func getRootCertificate(endpoint, fingerprint string) (*x509.Certificate, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(ctx, endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error connecting %s", endpoint)
|
||||
}
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client := linkedca.NewMajordomoClient(conn)
|
||||
resp, err := client.GetRootCertificate(ctx, &linkedca.GetRootCertificateRequest{
|
||||
Fingerprint: fingerprint,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting root certificate: %w", err)
|
||||
}
|
||||
|
||||
var block *pem.Block
|
||||
b := []byte(resp.PemCertificate)
|
||||
for len(b) > 0 {
|
||||
block, b = pem.Decode(b)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %w", err)
|
||||
}
|
||||
|
||||
// verify the sha256
|
||||
sum := sha256.Sum256(cert.Raw)
|
||||
if !strings.EqualFold(fingerprint, hex.EncodeToString(sum[:])) {
|
||||
return nil, fmt.Errorf("error verifying certificate: SHA256 fingerprint does not match")
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("error getting root certificate: certificate not found")
|
||||
}
|
||||
|
||||
// login creates a new majordomo client with just the root ca pool and returns
|
||||
// the signed certificate and tls configuration.
|
||||
func login(authority, token string, csr *x509.CertificateRequest, signer crypto.PrivateKey, endpoint string, rootCAs *x509.CertPool) (*tls.Certificate, *tls.Config, error) {
|
||||
// Connect to majordomo
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(ctx, endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
})))
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "error connecting %s", endpoint)
|
||||
}
|
||||
|
||||
// Login to get the signed certificate
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client := linkedca.NewMajordomoClient(conn)
|
||||
resp, err := client.Login(ctx, &linkedca.LoginRequest{
|
||||
AuthorityId: authority,
|
||||
Token: token,
|
||||
PemCertificateRequest: string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE REQUEST",
|
||||
Bytes: csr.Raw,
|
||||
})),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "error logging in %s", endpoint)
|
||||
}
|
||||
|
||||
// Parse login response
|
||||
var block *pem.Block
|
||||
var bundle []*x509.Certificate
|
||||
rest := []byte(resp.PemCertificateChain)
|
||||
for {
|
||||
block, rest = pem.Decode(rest)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "CERTIFICATE" {
|
||||
return nil, nil, errors.New("error decoding login response: pemCertificateChain is not a certificate bundle")
|
||||
}
|
||||
crt, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error parsing login response")
|
||||
}
|
||||
bundle = append(bundle, crt)
|
||||
}
|
||||
if len(bundle) == 0 {
|
||||
return nil, nil, errors.New("error decoding login response: pemCertificateChain should not be empty")
|
||||
}
|
||||
|
||||
// Build tls.Certificate with PemCertificate and intermediates in the
|
||||
// PemCertificateChain
|
||||
cert := &tls.Certificate{
|
||||
PrivateKey: signer,
|
||||
}
|
||||
rest = []byte(resp.PemCertificate)
|
||||
for {
|
||||
block, rest = pem.Decode(rest)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type == "CERTIFICATE" {
|
||||
leaf, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error parsing pemCertificate")
|
||||
}
|
||||
cert.Certificate = append(cert.Certificate, block.Bytes)
|
||||
cert.Leaf = leaf
|
||||
}
|
||||
}
|
||||
|
||||
// Add intermediates to the tls.Certificate
|
||||
last := len(bundle) - 1
|
||||
for i := 0; i < last; i++ {
|
||||
cert.Certificate = append(cert.Certificate, bundle[i].Raw)
|
||||
}
|
||||
|
||||
// Add root to the pool if it's not there yet
|
||||
rootCAs.AddCert(bundle[last])
|
||||
|
||||
return cert, &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
}, nil
|
||||
}
|
|
@ -7,6 +7,8 @@ import (
|
|||
"encoding/pem"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/cas"
|
||||
casapi "github.com/smallstep/certificates/cas/apiv1"
|
||||
|
@ -20,9 +22,9 @@ type Option func(*Authority) error
|
|||
|
||||
// WithConfig replaces the current config with the given one. No validation is
|
||||
// performed in the given value.
|
||||
func WithConfig(config *Config) Option {
|
||||
func WithConfig(cfg *config.Config) Option {
|
||||
return func(a *Authority) error {
|
||||
a.config = config
|
||||
a.config = cfg
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -31,16 +33,52 @@ func WithConfig(config *Config) Option {
|
|||
// the current one. No validation is performed in the given configuration.
|
||||
func WithConfigFile(filename string) Option {
|
||||
return func(a *Authority) (err error) {
|
||||
a.config, err = LoadConfiguration(filename)
|
||||
a.config, err = config.LoadConfiguration(filename)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithPassword set the password to decrypt the intermediate key as well as the
|
||||
// ssh host and user keys if they are not overridden by other options.
|
||||
func WithPassword(password []byte) Option {
|
||||
return func(a *Authority) (err error) {
|
||||
a.password = password
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithSSHHostPassword set the password to decrypt the key used to sign SSH host
|
||||
// certificates.
|
||||
func WithSSHHostPassword(password []byte) Option {
|
||||
return func(a *Authority) (err error) {
|
||||
a.sshHostPassword = password
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithSSHUserPassword set the password to decrypt the key used to sign SSH user
|
||||
// certificates.
|
||||
func WithSSHUserPassword(password []byte) Option {
|
||||
return func(a *Authority) (err error) {
|
||||
a.sshUserPassword = password
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithIssuerPassword set the password to decrypt the certificate issuer private
|
||||
// key used in RA mode.
|
||||
func WithIssuerPassword(password []byte) Option {
|
||||
return func(a *Authority) (err error) {
|
||||
a.issuerPassword = password
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithDatabase sets an already initialized authority database to a new
|
||||
// authority. This option is intended to be use on graceful reloads.
|
||||
func WithDatabase(db db.AuthDB) Option {
|
||||
func WithDatabase(d db.AuthDB) Option {
|
||||
return func(a *Authority) error {
|
||||
a.db = db
|
||||
a.db = d
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -56,7 +94,7 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e
|
|||
|
||||
// WithSSHBastionFunc sets a custom function to get the bastion for a
|
||||
// given user-host pair.
|
||||
func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*Bastion, error)) Option {
|
||||
func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option {
|
||||
return func(a *Authority) error {
|
||||
a.sshBastionFunc = fn
|
||||
return nil
|
||||
|
@ -65,7 +103,7 @@ func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*Bastio
|
|||
|
||||
// WithSSHGetHosts sets a custom function to get the bastion for a
|
||||
// given user-host pair.
|
||||
func WithSSHGetHosts(fn func(ctx context.Context, cert *x509.Certificate) ([]Host, error)) Option {
|
||||
func WithSSHGetHosts(fn func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)) Option {
|
||||
return func(a *Authority) error {
|
||||
a.sshGetHostsFunc = fn
|
||||
return nil
|
||||
|
@ -186,6 +224,23 @@ func WithX509FederatedBundle(pemCerts []byte) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithAdminDB is an option to set the database backing the admin APIs.
|
||||
func WithAdminDB(d admin.DB) Option {
|
||||
return func(a *Authority) error {
|
||||
a.adminDB = d
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithLinkedCAToken is an option to set the authentication token used to enable
|
||||
// linked ca.
|
||||
func WithLinkedCAToken(token string) Option {
|
||||
return func(a *Authority) error {
|
||||
a.linkedCAToken = token
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) {
|
||||
var block *pem.Block
|
||||
var certs []*x509.Certificate
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
// provisioning flow.
|
||||
type ACME struct {
|
||||
*base
|
||||
ID string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
ForceCN bool `json:"forceCN,omitempty"`
|
||||
|
@ -23,6 +24,15 @@ type ACME struct {
|
|||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
func (p ACME) GetID() string {
|
||||
if p.ID != "" {
|
||||
return p.ID
|
||||
}
|
||||
return p.GetIDForToken()
|
||||
}
|
||||
|
||||
// GetIDForToken returns an identifier that will be used to load the provisioner
|
||||
// from a token.
|
||||
func (p *ACME) GetIDForToken() string {
|
||||
return "acme/" + p.Name
|
||||
}
|
||||
|
||||
|
@ -95,7 +105,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -61,7 +61,7 @@ func TestACME_Init(t *testing.T) {
|
|||
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{Name: "foo", Type: "bar", Claims: &Claims{DefaultTLSDur: &Duration{0}}},
|
||||
err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"),
|
||||
err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||
|
@ -110,7 +110,7 @@ func TestACME_AuthorizeRenew(t *testing.T) {
|
|||
p: p,
|
||||
cert: &x509.Certificate{},
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner %s", p.GetID()),
|
||||
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
|
|
@ -252,6 +252,7 @@ type awsInstanceIdentityDocument struct {
|
|||
// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html
|
||||
type AWS struct {
|
||||
*base
|
||||
ID string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Accounts []string `json:"accounts"`
|
||||
|
@ -269,6 +270,15 @@ type AWS struct {
|
|||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
func (p *AWS) GetID() string {
|
||||
if p.ID != "" {
|
||||
return p.ID
|
||||
}
|
||||
return p.GetIDForToken()
|
||||
}
|
||||
|
||||
// GetIDForToken returns an identifier that will be used to load the provisioner
|
||||
// from a token.
|
||||
func (p *AWS) GetIDForToken() string {
|
||||
return "aws/" + p.Name
|
||||
}
|
||||
|
||||
|
@ -286,7 +296,7 @@ func (p *AWS) GetTokenID(token string) (string, error) {
|
|||
}
|
||||
|
||||
// Use provisioner + instance-id as the identifier.
|
||||
unique := fmt.Sprintf("%s.%s", p.GetID(), payload.document.InstanceID)
|
||||
unique := fmt.Sprintf("%s.%s", p.GetIDForToken(), payload.document.InstanceID)
|
||||
sum := sha256.Sum256([]byte(unique))
|
||||
return strings.ToLower(hex.EncodeToString(sum[:])), nil
|
||||
}
|
||||
|
@ -302,7 +312,7 @@ func (p *AWS) GetType() Type {
|
|||
}
|
||||
|
||||
// GetEncryptedKey is not available in an AWS provisioner.
|
||||
func (p *AWS) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||
func (p *AWS) GetEncryptedKey() (kid, key string, ok bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
|
@ -334,7 +344,7 @@ func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
audience, err := generateSignAudience(caURL, p.GetID())
|
||||
audience, err := generateSignAudience(caURL, p.GetIDForToken())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -342,7 +352,7 @@ func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) {
|
|||
// Create unique ID for Trust On First Use (TOFU). Only the first instance
|
||||
// per provisioner is allowed as we don't have a way to trust the given
|
||||
// sans.
|
||||
unique := fmt.Sprintf("%s.%s", p.GetID(), idoc.InstanceID)
|
||||
unique := fmt.Sprintf("%s.%s", p.GetIDForToken(), idoc.InstanceID)
|
||||
sum := sha256.Sum256([]byte(unique))
|
||||
|
||||
// Create a JWT from the identity document
|
||||
|
@ -397,7 +407,7 @@ func (p *AWS) Init(config Config) (err error) {
|
|||
if p.config, err = newAWSConfig(p.IIDRoots); err != nil {
|
||||
return err
|
||||
}
|
||||
p.audiences = config.Audiences.WithFragment(p.GetID())
|
||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
|
||||
// validate IMDS versions
|
||||
if len(p.IMDSVersions) == 0 {
|
||||
|
@ -439,13 +449,15 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// There's no way to trust them other than TOFU.
|
||||
var so []SignOption
|
||||
if p.DisableCustomSANs {
|
||||
dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region)
|
||||
so = append(so, dnsNamesValidator([]string{dnsName}))
|
||||
so = append(so, ipAddressesValidator([]net.IP{
|
||||
net.ParseIP(doc.PrivateIP),
|
||||
}))
|
||||
so = append(so, emailAddressesValidator(nil))
|
||||
so = append(so, urisValidator(nil))
|
||||
dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region)
|
||||
so = append(so,
|
||||
dnsNamesValidator([]string{dnsName}),
|
||||
ipAddressesValidator([]net.IP{
|
||||
net.ParseIP(doc.PrivateIP),
|
||||
}),
|
||||
emailAddressesValidator(nil),
|
||||
urisValidator(nil),
|
||||
)
|
||||
|
||||
// Template options
|
||||
data.SetSANs([]string{dnsName, doc.PrivateIP})
|
||||
|
@ -474,7 +486,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -505,6 +517,11 @@ func (p *AWS) readURL(url string) ([]byte, error) {
|
|||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
// Initialize IMDS versions when this is called from the cli.
|
||||
if len(p.IMDSVersions) == 0 {
|
||||
p.IMDSVersions = []string{"v2", "v1"}
|
||||
}
|
||||
|
||||
for _, v := range p.IMDSVersions {
|
||||
switch v {
|
||||
case "v1":
|
||||
|
@ -654,7 +671,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
if p.DisableCustomSANs {
|
||||
if payload.Subject != doc.InstanceID &&
|
||||
payload.Subject != doc.PrivateIP &&
|
||||
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
|
||||
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region) {
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
|
||||
}
|
||||
}
|
||||
|
@ -687,7 +704,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
|
||||
}
|
||||
claims, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
|
@ -705,7 +722,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validated principals.
|
||||
principals := []string{
|
||||
doc.PrivateIP,
|
||||
fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region),
|
||||
fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region),
|
||||
}
|
||||
|
||||
// Only enforce known principals if disable custom sans is true.
|
||||
|
|
|
@ -141,6 +141,12 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
|||
p7.config.signatureURL = p1.config.signatureURL
|
||||
p7.config.tokenURL = p1.config.tokenURL
|
||||
|
||||
p8, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
p8.IMDSVersions = nil
|
||||
p8.Accounts = p1.Accounts
|
||||
p8.config = p1.config
|
||||
|
||||
caURL := "https://ca.smallstep.com"
|
||||
u, err := url.Parse(caURL)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -156,6 +162,7 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
|||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{"foo.local", caURL}, false},
|
||||
{"ok no imds", p8, args{"foo.local", caURL}, false},
|
||||
{"fail ca url", p1, args{"foo.local", "://ca.smallstep.com"}, true},
|
||||
{"fail identityURL", p2, args{"foo.local", caURL}, true},
|
||||
{"fail signatureURL", p3, args{"foo.local", caURL}, true},
|
||||
|
@ -656,15 +663,15 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
switch got, err := tt.aws.AuthorizeSign(ctx, tt.args.token); {
|
||||
case (err != nil) != tt.wantErr:
|
||||
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
} else if err != nil {
|
||||
case err != nil:
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
} else {
|
||||
default:
|
||||
assert.Len(t, tt.wantLen, got)
|
||||
for _, o := range got {
|
||||
switch v := o.(type) {
|
||||
|
|
|
@ -84,6 +84,7 @@ type azurePayload struct {
|
|||
// and https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service
|
||||
type Azure struct {
|
||||
*base
|
||||
ID string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
TenantID string `json:"tenantID"`
|
||||
|
@ -101,6 +102,15 @@ type Azure struct {
|
|||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
func (p *Azure) GetID() string {
|
||||
if p.ID != "" {
|
||||
return p.ID
|
||||
}
|
||||
return p.GetIDForToken()
|
||||
}
|
||||
|
||||
// GetIDForToken returns an identifier that will be used to load the provisioner
|
||||
// from a token.
|
||||
func (p *Azure) GetIDForToken() string {
|
||||
return p.TenantID
|
||||
}
|
||||
|
||||
|
@ -121,9 +131,10 @@ func (p *Azure) GetTokenID(token string) (string, error) {
|
|||
return "", errors.Wrap(err, "error verifying claims")
|
||||
}
|
||||
|
||||
// If TOFU is disabled create return the token kid
|
||||
// If TOFU is disabled then allow token re-use. Azure caches the token for
|
||||
// 24h and without allowing the re-use we cannot use it twice.
|
||||
if p.DisableTrustOnFirstUse {
|
||||
return claims.ID, nil
|
||||
return "", ErrAllowTokenReuse
|
||||
}
|
||||
|
||||
sum := sha256.Sum256([]byte(claims.XMSMirID))
|
||||
|
@ -141,7 +152,7 @@ func (p *Azure) GetType() Type {
|
|||
}
|
||||
|
||||
// GetEncryptedKey is not available in an Azure provisioner.
|
||||
func (p *Azure) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||
func (p *Azure) GetEncryptedKey() (kid, key string, ok bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
|
@ -292,11 +303,13 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
var so []SignOption
|
||||
if p.DisableCustomSANs {
|
||||
// name will work only inside the virtual network
|
||||
so = append(so, commonNameValidator(name))
|
||||
so = append(so, dnsNamesValidator([]string{name}))
|
||||
so = append(so, ipAddressesValidator(nil))
|
||||
so = append(so, emailAddressesValidator(nil))
|
||||
so = append(so, urisValidator(nil))
|
||||
so = append(so,
|
||||
commonNameValidator(name),
|
||||
dnsNamesValidator([]string{name}),
|
||||
ipAddressesValidator(nil),
|
||||
emailAddressesValidator(nil),
|
||||
urisValidator(nil),
|
||||
)
|
||||
|
||||
// Enforce SANs in the template.
|
||||
data.SetSANs([]string{name})
|
||||
|
@ -324,7 +337,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -332,7 +345,7 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName())
|
||||
}
|
||||
|
||||
_, name, _, err := p.authorizeToken(token)
|
||||
|
|
|
@ -72,7 +72,7 @@ func TestAzure_GetTokenID(t *testing.T) {
|
|||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, w1, false},
|
||||
{"ok no TOFU", p2, args{t2}, "the-jti", false},
|
||||
{"ok no TOFU", p2, args{t2}, "", true},
|
||||
{"fail token", p1, args{"bad-token"}, "", true},
|
||||
{"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true},
|
||||
}
|
||||
|
@ -446,15 +446,15 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
switch got, err := tt.azure.AuthorizeSign(ctx, tt.args.token); {
|
||||
case (err != nil) != tt.wantErr:
|
||||
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
} else if err != nil {
|
||||
case err != nil:
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
} else {
|
||||
default:
|
||||
assert.Len(t, tt.wantLen, got)
|
||||
for _, o := range got {
|
||||
switch v := o.(type) {
|
||||
|
|
|
@ -71,6 +71,9 @@ func (c *Claimer) DefaultTLSCertDuration() time.Duration {
|
|||
// minimum from the authority configuration will be used.
|
||||
func (c *Claimer) MinTLSCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.MinTLSDur == nil {
|
||||
if c.claims != nil && c.claims.DefaultTLSDur != nil && c.claims.DefaultTLSDur.Duration < c.global.MinTLSDur.Duration {
|
||||
return c.claims.DefaultTLSDur.Duration
|
||||
}
|
||||
return c.global.MinTLSDur.Duration
|
||||
}
|
||||
return c.claims.MinTLSDur.Duration
|
||||
|
@ -81,6 +84,9 @@ func (c *Claimer) MinTLSCertDuration() time.Duration {
|
|||
// maximum from the authority configuration will be used.
|
||||
func (c *Claimer) MaxTLSCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.MaxTLSDur == nil {
|
||||
if c.claims != nil && c.claims.DefaultTLSDur != nil && c.claims.DefaultTLSDur.Duration > c.global.MaxTLSDur.Duration {
|
||||
return c.claims.DefaultTLSDur.Duration
|
||||
}
|
||||
return c.global.MaxTLSDur.Duration
|
||||
}
|
||||
return c.claims.MaxTLSDur.Duration
|
||||
|
@ -126,6 +132,9 @@ func (c *Claimer) DefaultUserSSHCertDuration() time.Duration {
|
|||
// global minimum from the authority configuration will be used.
|
||||
func (c *Claimer) MinUserSSHCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.MinUserSSHDur == nil {
|
||||
if c.claims != nil && c.claims.DefaultUserSSHDur != nil && c.claims.DefaultUserSSHDur.Duration < c.global.MinUserSSHDur.Duration {
|
||||
return c.claims.DefaultUserSSHDur.Duration
|
||||
}
|
||||
return c.global.MinUserSSHDur.Duration
|
||||
}
|
||||
return c.claims.MinUserSSHDur.Duration
|
||||
|
@ -136,6 +145,9 @@ func (c *Claimer) MinUserSSHCertDuration() time.Duration {
|
|||
// global maximum from the authority configuration will be used.
|
||||
func (c *Claimer) MaxUserSSHCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.MaxUserSSHDur == nil {
|
||||
if c.claims != nil && c.claims.DefaultUserSSHDur != nil && c.claims.DefaultUserSSHDur.Duration > c.global.MaxUserSSHDur.Duration {
|
||||
return c.claims.DefaultUserSSHDur.Duration
|
||||
}
|
||||
return c.global.MaxUserSSHDur.Duration
|
||||
}
|
||||
return c.claims.MaxUserSSHDur.Duration
|
||||
|
@ -156,6 +168,9 @@ func (c *Claimer) DefaultHostSSHCertDuration() time.Duration {
|
|||
// global minimum from the authority configuration will be used.
|
||||
func (c *Claimer) MinHostSSHCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.MinHostSSHDur == nil {
|
||||
if c.claims != nil && c.claims.DefaultHostSSHDur != nil && c.claims.DefaultHostSSHDur.Duration < c.global.MinHostSSHDur.Duration {
|
||||
return c.claims.DefaultHostSSHDur.Duration
|
||||
}
|
||||
return c.global.MinHostSSHDur.Duration
|
||||
}
|
||||
return c.claims.MinHostSSHDur.Duration
|
||||
|
@ -166,6 +181,9 @@ func (c *Claimer) MinHostSSHCertDuration() time.Duration {
|
|||
// global maximum from the authority configuration will be used.
|
||||
func (c *Claimer) MaxHostSSHCertDuration() time.Duration {
|
||||
if c.claims == nil || c.claims.MaxHostSSHDur == nil {
|
||||
if c.claims != nil && c.claims.DefaultHostSSHDur != nil && c.claims.DefaultHostSSHDur.Duration > c.global.MaxHostSSHDur.Duration {
|
||||
return c.claims.DefaultHostSSHDur.Duration
|
||||
}
|
||||
return c.global.MaxHostSSHDur.Duration
|
||||
}
|
||||
return c.claims.MaxHostSSHDur.Duration
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
|
@ -37,14 +37,17 @@ func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
|||
// provisioner.
|
||||
type loadByTokenPayload struct {
|
||||
jose.Claims
|
||||
AuthorizedParty string `json:"azp"` // OIDC client id
|
||||
TenantID string `json:"tid"` // Microsoft Azure tenant id
|
||||
Email string `json:"email"` // OIDC email
|
||||
AuthorizedParty string `json:"azp"` // OIDC client id
|
||||
TenantID string `json:"tid"` // Microsoft Azure tenant id
|
||||
}
|
||||
|
||||
// Collection is a memory map of provisioners.
|
||||
type Collection struct {
|
||||
byID *sync.Map
|
||||
byKey *sync.Map
|
||||
byName *sync.Map
|
||||
byTokenID *sync.Map
|
||||
sorted provisionerSlice
|
||||
audiences Audiences
|
||||
}
|
||||
|
@ -55,6 +58,8 @@ func NewCollection(audiences Audiences) *Collection {
|
|||
return &Collection{
|
||||
byID: new(sync.Map),
|
||||
byKey: new(sync.Map),
|
||||
byName: new(sync.Map),
|
||||
byTokenID: new(sync.Map),
|
||||
audiences: audiences,
|
||||
}
|
||||
}
|
||||
|
@ -64,6 +69,18 @@ func (c *Collection) Load(id string) (Interface, bool) {
|
|||
return loadProvisioner(c.byID, id)
|
||||
}
|
||||
|
||||
// LoadByName a provisioner by name.
|
||||
func (c *Collection) LoadByName(name string) (Interface, bool) {
|
||||
return loadProvisioner(c.byName, name)
|
||||
}
|
||||
|
||||
// LoadByTokenID a provisioner by identifier found in token.
|
||||
// For different provisioner types this identifier may be found in in different
|
||||
// attributes of the token.
|
||||
func (c *Collection) LoadByTokenID(tokenProvisionerID string) (Interface, bool) {
|
||||
return loadProvisioner(c.byTokenID, tokenProvisionerID)
|
||||
}
|
||||
|
||||
// LoadByToken parses the token claims and loads the provisioner associated.
|
||||
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) {
|
||||
var audiences []string
|
||||
|
@ -79,11 +96,12 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
|
|||
if matchesAudience(claims.Audience, audiences) {
|
||||
// Use fragment to get provisioner name (GCP, AWS, SSHPOP)
|
||||
if fragment != "" {
|
||||
return c.Load(fragment)
|
||||
return c.LoadByTokenID(fragment)
|
||||
}
|
||||
// If matches with stored audiences it will be a JWT token (default), and
|
||||
// the id would be <issuer>:<kid>.
|
||||
return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
|
||||
// TODO: is this ok?
|
||||
return c.LoadByTokenID(claims.Issuer + ":" + token.Headers[0].KeyID)
|
||||
}
|
||||
|
||||
// The ID will be just the clientID stored in azp, aud or tid.
|
||||
|
@ -94,7 +112,7 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
|
|||
|
||||
// Kubernetes Service Account tokens.
|
||||
if payload.Issuer == k8sSAIssuer {
|
||||
if p, ok := c.Load(K8sSAID); ok {
|
||||
if p, ok := c.LoadByTokenID(K8sSAID); ok {
|
||||
return p, ok
|
||||
}
|
||||
// Kubernetes service account provisioner not found
|
||||
|
@ -108,18 +126,26 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
|
|||
|
||||
// Try with azp (OIDC)
|
||||
if len(payload.AuthorizedParty) > 0 {
|
||||
if p, ok := c.Load(payload.AuthorizedParty); ok {
|
||||
if p, ok := c.LoadByTokenID(payload.AuthorizedParty); ok {
|
||||
return p, ok
|
||||
}
|
||||
}
|
||||
// Try with tid (Azure)
|
||||
// Try with tid (Azure, Azure OIDC)
|
||||
if payload.TenantID != "" {
|
||||
if p, ok := c.Load(payload.TenantID); ok {
|
||||
// Try to load an OIDC provisioner first.
|
||||
if payload.Email != "" {
|
||||
if p, ok := c.LoadByTokenID(payload.Audience[0]); ok {
|
||||
return p, ok
|
||||
}
|
||||
}
|
||||
// Try to load an Azure provisioner.
|
||||
if p, ok := c.LoadByTokenID(payload.TenantID); ok {
|
||||
return p, ok
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to aud
|
||||
return c.Load(payload.Audience[0])
|
||||
return c.LoadByTokenID(payload.Audience[0])
|
||||
}
|
||||
|
||||
// LoadByCertificate looks for the provisioner extension and extracts the
|
||||
|
@ -131,24 +157,7 @@ func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool)
|
|||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
switch Type(provisioner.Type) {
|
||||
case TypeJWK:
|
||||
return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID))
|
||||
case TypeAWS:
|
||||
return c.Load("aws/" + string(provisioner.Name))
|
||||
case TypeGCP:
|
||||
return c.Load("gcp/" + string(provisioner.Name))
|
||||
case TypeACME:
|
||||
return c.Load("acme/" + string(provisioner.Name))
|
||||
case TypeSCEP:
|
||||
return c.Load("scep/" + string(provisioner.Name))
|
||||
case TypeX5C:
|
||||
return c.Load("x5c/" + string(provisioner.Name))
|
||||
case TypeK8sSA:
|
||||
return c.Load(K8sSAID)
|
||||
default:
|
||||
return c.Load(string(provisioner.CredentialID))
|
||||
}
|
||||
return c.LoadByName(string(provisioner.Name))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -173,7 +182,21 @@ func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
|
|||
func (c *Collection) Store(p Interface) error {
|
||||
// Store provisioner always in byID. ID must be unique.
|
||||
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded {
|
||||
return errors.New("cannot add multiple provisioners with the same id")
|
||||
return admin.NewError(admin.ErrorBadRequestType,
|
||||
"cannot add multiple provisioners with the same id")
|
||||
}
|
||||
// Store provisioner always by name.
|
||||
if _, loaded := c.byName.LoadOrStore(p.GetName(), p); loaded {
|
||||
c.byID.Delete(p.GetID())
|
||||
return admin.NewError(admin.ErrorBadRequestType,
|
||||
"cannot add multiple provisioners with the same name")
|
||||
}
|
||||
// Store provisioner always by ID presented in token.
|
||||
if _, loaded := c.byTokenID.LoadOrStore(p.GetIDForToken(), p); loaded {
|
||||
c.byID.Delete(p.GetID())
|
||||
c.byName.Delete(p.GetName())
|
||||
return admin.NewError(admin.ErrorBadRequestType,
|
||||
"cannot add multiple provisioners with the same token identifier")
|
||||
}
|
||||
|
||||
// Store provisioner in byKey if EncryptedKey is defined.
|
||||
|
@ -197,6 +220,66 @@ func (c *Collection) Store(p Interface) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Remove deletes an provisioner from all associated collections and lists.
|
||||
func (c *Collection) Remove(id string) error {
|
||||
prov, ok := c.Load(id)
|
||||
if !ok {
|
||||
return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id)
|
||||
}
|
||||
|
||||
var found bool
|
||||
for i, elem := range c.sorted {
|
||||
if elem.provisioner.GetID() != id {
|
||||
continue
|
||||
}
|
||||
// Remove index in sorted list
|
||||
copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index.
|
||||
c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value).
|
||||
c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice.
|
||||
found = true
|
||||
break
|
||||
}
|
||||
if !found {
|
||||
return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName())
|
||||
}
|
||||
|
||||
c.byID.Delete(id)
|
||||
c.byName.Delete(prov.GetName())
|
||||
c.byTokenID.Delete(prov.GetIDForToken())
|
||||
if kid, _, ok := prov.GetEncryptedKey(); ok {
|
||||
c.byKey.Delete(kid)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update updates the given provisioner in all related lists and collections.
|
||||
func (c *Collection) Update(nu Interface) error {
|
||||
old, ok := c.Load(nu.GetID())
|
||||
if !ok {
|
||||
return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", nu.GetID())
|
||||
}
|
||||
|
||||
if old.GetName() != nu.GetName() {
|
||||
if _, ok := c.LoadByName(nu.GetName()); ok {
|
||||
return admin.NewError(admin.ErrorBadRequestType,
|
||||
"provisioner with name %s already exists", nu.GetName())
|
||||
}
|
||||
}
|
||||
if old.GetIDForToken() != nu.GetIDForToken() {
|
||||
if _, ok := c.LoadByTokenID(nu.GetIDForToken()); ok {
|
||||
return admin.NewError(admin.ErrorBadRequestType,
|
||||
"provisioner with Token ID %s already exists", nu.GetIDForToken())
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.Remove(old.GetID()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Store(nu)
|
||||
}
|
||||
|
||||
// Find implements pagination on a list of sorted provisioners.
|
||||
func (c *Collection) Find(cursor string, limit int) (List, string) {
|
||||
switch {
|
||||
|
|
|
@ -132,6 +132,7 @@ func TestCollection_LoadByToken(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
byTokenID: tt.fields.byID,
|
||||
audiences: tt.fields.audiences,
|
||||
}
|
||||
got, got1 := c.LoadByToken(tt.args.token, tt.args.claims)
|
||||
|
@ -153,10 +154,10 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
|||
p3, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p1.GetID(), p1)
|
||||
byID.Store(p2.GetID(), p2)
|
||||
byID.Store(p3.GetID(), p3)
|
||||
byName := new(sync.Map)
|
||||
byName.Store(p1.GetName(), p1)
|
||||
byName.Store(p2.GetName(), p2)
|
||||
byName.Store(p3.GetName(), p3)
|
||||
|
||||
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -186,7 +187,7 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
|||
}
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
byName *sync.Map
|
||||
audiences Audiences
|
||||
}
|
||||
type args struct {
|
||||
|
@ -199,17 +200,17 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
|||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
|
||||
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
|
||||
{"ok3", fields{byID, testAudiences}, args{ok3Cert}, p3, true},
|
||||
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
|
||||
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
|
||||
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},
|
||||
{"ok1", fields{byName, testAudiences}, args{ok1Cert}, p1, true},
|
||||
{"ok2", fields{byName, testAudiences}, args{ok2Cert}, p2, true},
|
||||
{"ok3", fields{byName, testAudiences}, args{ok3Cert}, p3, true},
|
||||
{"noExtension", fields{byName, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
|
||||
{"notFound", fields{byName, testAudiences}, args{notFoundCert}, nil, false},
|
||||
{"badCert", fields{byName, testAudiences}, args{badCert}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
byName: tt.fields.byName,
|
||||
audiences: tt.fields.audiences,
|
||||
}
|
||||
got, got1 := c.LoadByCertificate(tt.args.cert)
|
||||
|
|
|
@ -78,6 +78,7 @@ func newGCPConfig() *gcpConfig {
|
|||
// https://cloud.google.com/compute/docs/instances/verifying-instance-identity
|
||||
type GCP struct {
|
||||
*base
|
||||
ID string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
ServiceAccounts []string `json:"serviceAccounts"`
|
||||
|
@ -96,6 +97,16 @@ type GCP struct {
|
|||
// GetID returns the provisioner unique identifier. The name should uniquely
|
||||
// identify any GCP provisioner.
|
||||
func (p *GCP) GetID() string {
|
||||
if p.ID != "" {
|
||||
return p.ID
|
||||
}
|
||||
return p.GetIDForToken()
|
||||
|
||||
}
|
||||
|
||||
// GetIDForToken returns an identifier that will be used to load the provisioner
|
||||
// from a token.
|
||||
func (p *GCP) GetIDForToken() string {
|
||||
return "gcp/" + p.Name
|
||||
}
|
||||
|
||||
|
@ -123,7 +134,7 @@ func (p *GCP) GetTokenID(token string) (string, error) {
|
|||
// Create unique ID for Trust On First Use (TOFU). Only the first instance
|
||||
// per provisioner is allowed as we don't have a way to trust the given
|
||||
// sans.
|
||||
unique := fmt.Sprintf("%s.%s", p.GetID(), claims.Google.ComputeEngine.InstanceID)
|
||||
unique := fmt.Sprintf("%s.%s", p.GetIDForToken(), claims.Google.ComputeEngine.InstanceID)
|
||||
sum := sha256.Sum256([]byte(unique))
|
||||
return strings.ToLower(hex.EncodeToString(sum[:])), nil
|
||||
}
|
||||
|
@ -139,7 +150,7 @@ func (p *GCP) GetType() Type {
|
|||
}
|
||||
|
||||
// GetEncryptedKey is not available in a GCP provisioner.
|
||||
func (p *GCP) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||
func (p *GCP) GetEncryptedKey() (kid, key string, ok bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
|
@ -157,7 +168,7 @@ func (p *GCP) GetIdentityURL(audience string) string {
|
|||
|
||||
// GetIdentityToken does an HTTP request to the identity url.
|
||||
func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
|
||||
audience, err := generateSignAudience(caURL, p.GetID())
|
||||
audience, err := generateSignAudience(caURL, p.GetIDForToken())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -205,7 +216,7 @@ func (p *GCP) Init(config Config) error {
|
|||
return err
|
||||
}
|
||||
|
||||
p.audiences = config.Audiences.WithFragment(p.GetID())
|
||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -233,15 +244,17 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
if p.DisableCustomSANs {
|
||||
dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID)
|
||||
dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID)
|
||||
so = append(so, commonNameSliceValidator([]string{
|
||||
ce.InstanceName, ce.InstanceID, dnsName1, dnsName2,
|
||||
}))
|
||||
so = append(so, dnsNamesValidator([]string{
|
||||
dnsName1, dnsName2,
|
||||
}))
|
||||
so = append(so, ipAddressesValidator(nil))
|
||||
so = append(so, emailAddressesValidator(nil))
|
||||
so = append(so, urisValidator(nil))
|
||||
so = append(so,
|
||||
commonNameSliceValidator([]string{
|
||||
ce.InstanceName, ce.InstanceID, dnsName1, dnsName2,
|
||||
}),
|
||||
dnsNamesValidator([]string{
|
||||
dnsName1, dnsName2,
|
||||
}),
|
||||
ipAddressesValidator(nil),
|
||||
emailAddressesValidator(nil),
|
||||
urisValidator(nil),
|
||||
)
|
||||
|
||||
// Template SANs
|
||||
data.SetSANs([]string{dnsName1, dnsName2})
|
||||
|
@ -266,7 +279,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -371,7 +384,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName())
|
||||
}
|
||||
claims, err := p.authorizeToken(token)
|
||||
if err != nil {
|
||||
|
|
|
@ -535,15 +535,15 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
switch got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token); {
|
||||
case (err != nil) != tt.wantErr:
|
||||
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
} else if err != nil {
|
||||
case err != nil:
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
} else {
|
||||
default:
|
||||
assert.Len(t, tt.wantLen, got)
|
||||
for _, o := range got {
|
||||
switch v := o.(type) {
|
||||
|
|
|
@ -28,6 +28,7 @@ type stepPayload struct {
|
|||
// signature requests.
|
||||
type JWK struct {
|
||||
*base
|
||||
ID string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
|
@ -41,6 +42,15 @@ type JWK struct {
|
|||
// GetID returns the provisioner unique identifier. The name and credential id
|
||||
// should uniquely identify any JWK provisioner.
|
||||
func (p *JWK) GetID() string {
|
||||
if p.ID != "" {
|
||||
return p.ID
|
||||
}
|
||||
return p.GetIDForToken()
|
||||
}
|
||||
|
||||
// GetIDForToken returns an identifier that will be used to load the provisioner
|
||||
// from a token.
|
||||
func (p *JWK) GetIDForToken() string {
|
||||
return p.Name + ":" + p.Key.KeyID
|
||||
}
|
||||
|
||||
|
@ -184,7 +194,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// certificate was configured to allow renewals.
|
||||
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -192,7 +202,7 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName())
|
||||
}
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
if err != nil {
|
||||
|
|
|
@ -77,7 +77,7 @@ func TestJWK_Init(t *testing.T) {
|
|||
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
|
||||
err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"),
|
||||
err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||
|
|
|
@ -42,6 +42,7 @@ type k8sSAPayload struct {
|
|||
// entity trusted to make signature requests.
|
||||
type K8sSA struct {
|
||||
*base
|
||||
ID string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
PubKeys []byte `json:"publicKeys,omitempty"`
|
||||
|
@ -56,6 +57,15 @@ type K8sSA struct {
|
|||
// GetID returns the provisioner unique identifier. The name and credential id
|
||||
// should uniquely identify any K8sSA provisioner.
|
||||
func (p *K8sSA) GetID() string {
|
||||
if p.ID != "" {
|
||||
return p.ID
|
||||
}
|
||||
return p.GetIDForToken()
|
||||
}
|
||||
|
||||
// GetIDForToken returns an identifier that will be used to load the provisioner
|
||||
// from a token.
|
||||
func (p *K8sSA) GetIDForToken() string {
|
||||
return K8sSAID
|
||||
}
|
||||
|
||||
|
@ -101,12 +111,12 @@ func (p *K8sSA) Init(config Config) (err error) {
|
|||
}
|
||||
key, err := pemutil.ParseKey(pem.EncodeToMemory(block))
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error parsing public key in provisioner %s", p.GetID())
|
||||
return errors.Wrapf(err, "error parsing public key in provisioner '%s'", p.GetName())
|
||||
}
|
||||
switch q := key.(type) {
|
||||
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
|
||||
default:
|
||||
return errors.Errorf("Unexpected public key type %T in provisioner %s", q, p.GetID())
|
||||
return errors.Errorf("Unexpected public key type %T in provisioner '%s'", q, p.GetName())
|
||||
}
|
||||
p.pubKeys = append(p.pubKeys, key)
|
||||
}
|
||||
|
@ -240,7 +250,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID())
|
||||
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -248,7 +258,7 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
|
|||
// AuthorizeSSHSign validates an request for an SSH certificate.
|
||||
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID())
|
||||
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName())
|
||||
}
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
if err != nil {
|
||||
|
|
|
@ -198,7 +198,7 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
|
|||
p: p,
|
||||
cert: &x509.Certificate{},
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner %s", p.GetID()),
|
||||
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
@ -319,7 +319,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
|
|||
p: p,
|
||||
token: "foo",
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner %s", p.GetID()),
|
||||
err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()),
|
||||
}
|
||||
},
|
||||
"fail/invalid-token": func(t *testing.T) test {
|
||||
|
|
|
@ -18,7 +18,7 @@ const (
|
|||
defaultCacheJitter = 1 * time.Hour
|
||||
)
|
||||
|
||||
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]+)")
|
||||
var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)
|
||||
|
||||
type keyStore struct {
|
||||
sync.RWMutex
|
||||
|
|
|
@ -14,6 +14,10 @@ func (p *noop) GetID() string {
|
|||
return "noop"
|
||||
}
|
||||
|
||||
func (p *noop) GetIDForToken() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
func (p *noop) GetTokenID(token string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
@ -25,7 +29,7 @@ func (p *noop) GetType() Type {
|
|||
return noopType
|
||||
}
|
||||
|
||||
func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||
func (p *noop) GetEncryptedKey() (kid, key string, ok bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
|
|
|
@ -49,11 +49,35 @@ type openIDPayload struct {
|
|||
Groups []string `json:"groups"`
|
||||
}
|
||||
|
||||
func (o *openIDPayload) IsAdmin(admins []string) bool {
|
||||
if o.Email != "" {
|
||||
email := sanitizeEmail(o.Email)
|
||||
for _, e := range admins {
|
||||
if email == sanitizeEmail(e) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The groups and emails can be in the same array for now, but consider
|
||||
// making a specialized option later.
|
||||
for _, name := range o.Groups {
|
||||
for _, admin := range admins {
|
||||
if name == admin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// OIDC represents an OAuth 2.0 OpenID Connect provider.
|
||||
//
|
||||
// ClientSecret is mandatory, but it can be an empty string.
|
||||
type OIDC struct {
|
||||
*base
|
||||
ID string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
ClientID string `json:"clientID"`
|
||||
|
@ -72,35 +96,6 @@ type OIDC struct {
|
|||
getIdentityFunc GetIdentityFunc
|
||||
}
|
||||
|
||||
// IsAdmin returns true if the given email is in the Admins allowlist, false
|
||||
// otherwise.
|
||||
func (o *OIDC) IsAdmin(email string) bool {
|
||||
if email != "" {
|
||||
email = sanitizeEmail(email)
|
||||
for _, e := range o.Admins {
|
||||
if email == sanitizeEmail(e) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsAdminGroup returns true if the one group in the given list is in the Admins
|
||||
// allowlist, false otherwise.
|
||||
func (o *OIDC) IsAdminGroup(groups []string) bool {
|
||||
for _, g := range groups {
|
||||
// The groups and emails can be in the same array for now, but consider
|
||||
// making a specialized option later.
|
||||
for _, gadmin := range o.Admins {
|
||||
if g == gadmin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sanitizeEmail(email string) string {
|
||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||
email = email[:i] + strings.ToLower(email[i:])
|
||||
|
@ -111,6 +106,15 @@ func sanitizeEmail(email string) string {
|
|||
// GetID returns the provisioner unique identifier, the OIDC provisioner the
|
||||
// uses the clientID for this.
|
||||
func (o *OIDC) GetID() string {
|
||||
if o.ID != "" {
|
||||
return o.ID
|
||||
}
|
||||
return o.GetIDForToken()
|
||||
}
|
||||
|
||||
// GetIDForToken returns an identifier that will be used to load the provisioner
|
||||
// from a token.
|
||||
func (o *OIDC) GetIDForToken() string {
|
||||
return o.ClientID
|
||||
}
|
||||
|
||||
|
@ -144,7 +148,7 @@ func (o *OIDC) GetType() Type {
|
|||
}
|
||||
|
||||
// GetEncryptedKey is not available in an OIDC provisioner.
|
||||
func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) {
|
||||
func (o *OIDC) GetEncryptedKey() (kid, key string, ok bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
|
@ -189,7 +193,7 @@ func (o *OIDC) Init(config Config) (err error) {
|
|||
}
|
||||
// Replace {tenantid} with the configured one
|
||||
if o.TenantID != "" {
|
||||
o.configuration.Issuer = strings.Replace(o.configuration.Issuer, "{tenantid}", o.TenantID, -1)
|
||||
o.configuration.Issuer = strings.ReplaceAll(o.configuration.Issuer, "{tenantid}", o.TenantID)
|
||||
}
|
||||
// Get JWK key set
|
||||
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
|
||||
|
@ -224,7 +228,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
|||
}
|
||||
|
||||
// Validate domains (case-insensitive)
|
||||
if p.Email != "" && len(o.Domains) > 0 && !o.IsAdmin(p.Email) {
|
||||
if p.Email != "" && len(o.Domains) > 0 && !p.IsAdmin(o.Admins) {
|
||||
email := sanitizeEmail(p.Email)
|
||||
var found bool
|
||||
for _, d := range o.Domains {
|
||||
|
@ -303,9 +307,10 @@ func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error {
|
|||
}
|
||||
|
||||
// Only admins can revoke certificates.
|
||||
if o.IsAdmin(claims.Email) {
|
||||
if claims.IsAdmin(o.Admins) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")
|
||||
}
|
||||
|
||||
|
@ -341,7 +346,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// Use the default template unless no-templates are configured and email is
|
||||
// an admin, in that case we will use the CR template.
|
||||
defaultTemplate := x509util.DefaultLeafTemplate
|
||||
if !o.Options.GetX509Options().HasTemplate() && o.IsAdmin(claims.Email) {
|
||||
if !o.Options.GetX509Options().HasTemplate() && claims.IsAdmin(o.Admins) {
|
||||
defaultTemplate = x509util.DefaultAdminLeafTemplate
|
||||
}
|
||||
|
||||
|
@ -367,7 +372,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// certificate was configured to allow renewals.
|
||||
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if o.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner %s", o.GetID())
|
||||
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -375,7 +380,7 @@ func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !o.claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner %s", o.GetID())
|
||||
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName())
|
||||
}
|
||||
claims, err := o.authorizeToken(token)
|
||||
if err != nil {
|
||||
|
@ -410,10 +415,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
|||
|
||||
// Use the default template unless no-templates are configured and email is
|
||||
// an admin, in that case we will use the parameters in the request.
|
||||
isAdmin := o.IsAdmin(claims.Email)
|
||||
if !isAdmin && len(claims.Groups) > 0 {
|
||||
isAdmin = o.IsAdminGroup(claims.Groups)
|
||||
}
|
||||
isAdmin := claims.IsAdmin(o.Admins)
|
||||
defaultTemplate := sshutil.DefaultTemplate
|
||||
if isAdmin && !o.Options.GetSSHOptions().HasTemplate() {
|
||||
defaultTemplate = sshutil.DefaultAdminTemplate
|
||||
|
@ -461,10 +463,11 @@ func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
|||
}
|
||||
|
||||
// Only admins can revoke certificates.
|
||||
if !o.IsAdmin(claims.Email) {
|
||||
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
|
||||
if claims.IsAdmin(o.Admins) {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
|
||||
}
|
||||
|
||||
func getAndDecode(uri string, v interface{}) error {
|
||||
|
|
|
@ -321,32 +321,26 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
|||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
if assert.NotNil(t, got) {
|
||||
if tt.name == "admin" {
|
||||
assert.Len(t, 5, got)
|
||||
} else {
|
||||
assert.Len(t, 5, got)
|
||||
}
|
||||
for _, o := range got {
|
||||
switch v := o.(type) {
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeOIDC))
|
||||
assert.Equals(t, v.Name, tt.prov.GetName())
|
||||
assert.Equals(t, v.CredentialID, tt.prov.ClientID)
|
||||
assert.Len(t, 0, v.KeyValuePairs)
|
||||
case profileDefaultDuration:
|
||||
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
|
||||
case defaultPublicKeyValidator:
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
|
||||
case emailOnlyIdentity:
|
||||
assert.Equals(t, string(v), "name@smallstep.com")
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
} else if assert.NotNil(t, got) {
|
||||
assert.Len(t, 5, got)
|
||||
for _, o := range got {
|
||||
switch v := o.(type) {
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeOIDC))
|
||||
assert.Equals(t, v.Name, tt.prov.GetName())
|
||||
assert.Equals(t, v.CredentialID, tt.prov.ClientID)
|
||||
assert.Len(t, 0, v.KeyValuePairs)
|
||||
case profileDefaultDuration:
|
||||
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
|
||||
case defaultPublicKeyValidator:
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
|
||||
case emailOnlyIdentity:
|
||||
assert.Equals(t, string(v), "name@smallstep.com")
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -698,3 +692,39 @@ func Test_sanitizeEmail(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_openIDPayload_IsAdmin(t *testing.T) {
|
||||
type fields struct {
|
||||
Email string
|
||||
Groups []string
|
||||
}
|
||||
type args struct {
|
||||
admins []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{"ok email", fields{"admin@smallstep.com", nil}, args{[]string{"admin@smallstep.com"}}, true},
|
||||
{"ok email multiple", fields{"admin@smallstep.com", []string{"admin", "eng"}}, args{[]string{"eng@smallstep.com", "admin@smallstep.com"}}, true},
|
||||
{"ok email sanitized", fields{"admin@Smallstep.com", nil}, args{[]string{"admin@smallStep.com"}}, true},
|
||||
{"ok group", fields{"", []string{"admin"}}, args{[]string{"admin"}}, true},
|
||||
{"ok group multiple", fields{"admin@smallstep.com", []string{"engineering", "admin"}}, args{[]string{"admin"}}, true},
|
||||
{"fail missing", fields{"eng@smallstep.com", []string{"admin"}}, args{[]string{"admin@smallstep.com"}}, false},
|
||||
{"fail email letter case", fields{"Admin@smallstep.com", []string{}}, args{[]string{"admin@smallstep.com"}}, false},
|
||||
{"fail group letter case", fields{"", []string{"Admin"}}, args{[]string{"admin"}}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := &openIDPayload{
|
||||
Email: tt.fields.Email,
|
||||
Groups: tt.fields.Groups,
|
||||
}
|
||||
if got := o.IsAdmin(tt.args.admins); got != tt.want {
|
||||
t.Errorf("openIDPayload.IsAdmin() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -138,7 +138,7 @@ func unsafeParseSigned(s string) (map[string]interface{}, error) {
|
|||
return nil, err
|
||||
}
|
||||
claims := make(map[string]interface{})
|
||||
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
if err := token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return claims, nil
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
stderrors "errors"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
@ -17,6 +18,7 @@ import (
|
|||
// Interface is the interface that all provisioner types must implement.
|
||||
type Interface interface {
|
||||
GetID() string
|
||||
GetIDForToken() string
|
||||
GetTokenID(token string) (string, error)
|
||||
GetName() string
|
||||
GetType() Type
|
||||
|
@ -31,6 +33,17 @@ type Interface interface {
|
|||
AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error)
|
||||
}
|
||||
|
||||
// ErrAllowTokenReuse is an error that is returned by provisioners that allows
|
||||
// the reuse of tokens.
|
||||
//
|
||||
// This is, for example, returned by the Azure provisioner when
|
||||
// DisableTrustOnFirstUse is set to true. Azure caches tokens for up to 24hr and
|
||||
// has no mechanism for getting a different token - this can be an issue when
|
||||
// rebooting a VM. In contrast, AWS and GCP have facilities for requesting a new
|
||||
// token. Therefore, for the Azure provisioner we are enabling token reuse, with
|
||||
// the understanding that we are not following security best practices
|
||||
var ErrAllowTokenReuse = stderrors.New("allow token reuse")
|
||||
|
||||
// Audiences stores all supported audiences by request type.
|
||||
type Audiences struct {
|
||||
Sign []string
|
||||
|
@ -110,7 +123,7 @@ func (a Audiences) WithFragment(fragment string) Audiences {
|
|||
|
||||
// generateSignAudience generates a sign audience with the format
|
||||
// https://<host>/1.0/sign#provisionerID
|
||||
func generateSignAudience(caURL string, provisionerID string) (string, error) {
|
||||
func generateSignAudience(caURL, provisionerID string) (string, error) {
|
||||
u, err := url.Parse(caURL)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "error parsing %s", caURL)
|
||||
|
@ -394,6 +407,7 @@ type MockProvisioner struct {
|
|||
Mret1, Mret2, Mret3 interface{}
|
||||
Merr error
|
||||
MgetID func() string
|
||||
MgetIDForToken func() string
|
||||
MgetTokenID func(string) (string, error)
|
||||
MgetName func() string
|
||||
MgetType func() Type
|
||||
|
@ -416,6 +430,14 @@ func (m *MockProvisioner) GetID() string {
|
|||
return m.Mret1.(string)
|
||||
}
|
||||
|
||||
// GetIDForToken mock
|
||||
func (m *MockProvisioner) GetIDForToken() string {
|
||||
if m.MgetIDForToken != nil {
|
||||
return m.MgetIDForToken()
|
||||
}
|
||||
return m.Mret1.(string)
|
||||
}
|
||||
|
||||
// GetTokenID mock
|
||||
func (m *MockProvisioner) GetTokenID(token string) (string, error) {
|
||||
if m.MgetTokenID != nil {
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
// SCEP provisioning flow
|
||||
type SCEP struct {
|
||||
*base
|
||||
ID string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
|
||||
|
@ -27,7 +28,16 @@ type SCEP struct {
|
|||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
func (s SCEP) GetID() string {
|
||||
func (s *SCEP) GetID() string {
|
||||
if s.ID != "" {
|
||||
return s.ID
|
||||
}
|
||||
return s.GetIDForToken()
|
||||
}
|
||||
|
||||
// GetIDForToken returns an identifier that will be used to load the provisioner
|
||||
// from a token.
|
||||
func (s *SCEP) GetIDForToken() string {
|
||||
return "scep/" + s.Name
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -455,10 +456,10 @@ func containsAllMembers(group, subgroup []string) bool {
|
|||
}
|
||||
visit := make(map[string]struct{}, lg)
|
||||
for i := 0; i < lg; i++ {
|
||||
visit[group[i]] = struct{}{}
|
||||
visit[strings.ToLower(group[i])] = struct{}{}
|
||||
}
|
||||
for i := 0; i < lsg; i++ {
|
||||
if _, ok := visit[subgroup[i]]; !ok {
|
||||
if _, ok := visit[strings.ToLower(subgroup[i])]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ func TestSSHOptions_Modify(t *testing.T) {
|
|||
valid func(*ssh.Certificate)
|
||||
err error
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
tests := map[string]func() test{
|
||||
"fail/unexpected-cert-type": func() test {
|
||||
return test{
|
||||
so: SignSSHOptions{CertType: "foo"},
|
||||
|
@ -117,7 +117,7 @@ func TestSSHOptions_Match(t *testing.T) {
|
|||
cmp SignSSHOptions
|
||||
err error
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
tests := map[string]func() test{
|
||||
"fail/cert-type": func() test {
|
||||
return test{
|
||||
so: SignSSHOptions{CertType: "foo"},
|
||||
|
@ -208,7 +208,7 @@ func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
|
|||
cert *ssh.Certificate
|
||||
expected []string
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
tests := map[string]func() test{
|
||||
"ok": func() test {
|
||||
a := []string{"foo", "bar"}
|
||||
return test{
|
||||
|
@ -234,7 +234,7 @@ func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
|
|||
cert *ssh.Certificate
|
||||
expected string
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
tests := map[string]func() test{
|
||||
"ok": func() test {
|
||||
a := "foo"
|
||||
return test{
|
||||
|
@ -260,7 +260,7 @@ func Test_sshCertTypeModifier_Modify(t *testing.T) {
|
|||
cert *ssh.Certificate
|
||||
expected uint32
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
tests := map[string]func() test{
|
||||
"ok/user": func() test {
|
||||
return test{
|
||||
modifier: sshCertTypeModifier("user"),
|
||||
|
@ -299,7 +299,7 @@ func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
|
|||
cert *ssh.Certificate
|
||||
expected uint64
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
tests := map[string]func() test{
|
||||
"ok": func() test {
|
||||
return test{
|
||||
modifier: sshCertValidAfterModifier(15),
|
||||
|
@ -324,7 +324,7 @@ func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
|
|||
cert *ssh.Certificate
|
||||
valid func(*ssh.Certificate)
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
tests := map[string]func() test{
|
||||
"ok/changes": func() test {
|
||||
n := time.Now()
|
||||
va := NewTimeDuration(n.Add(1 * time.Minute))
|
||||
|
@ -388,7 +388,7 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
|
|||
valid func(*ssh.Certificate)
|
||||
err error
|
||||
}
|
||||
tests := map[string](func() test){
|
||||
tests := map[string]func() test{
|
||||
"fail/unexpected-cert-type": func() test {
|
||||
cert := &ssh.Certificate{CertType: 3}
|
||||
return test{
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/crypto/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
@ -26,10 +25,10 @@ type sshPOPPayload struct {
|
|||
// signature requests.
|
||||
type SSHPOP struct {
|
||||
*base
|
||||
ID string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
db db.AuthDB
|
||||
claimer *Claimer
|
||||
audiences Audiences
|
||||
sshPubKeys *SSHKeys
|
||||
|
@ -38,6 +37,15 @@ type SSHPOP struct {
|
|||
// GetID returns the provisioner unique identifier. The name and credential id
|
||||
// should uniquely identify any SSH-POP provisioner.
|
||||
func (p *SSHPOP) GetID() string {
|
||||
if p.ID != "" {
|
||||
return p.ID
|
||||
}
|
||||
return p.GetIDForToken()
|
||||
}
|
||||
|
||||
// GetIDForToken returns an identifier that will be used to load the provisioner
|
||||
// from a token.
|
||||
func (p *SSHPOP) GetIDForToken() string {
|
||||
return "sshpop/" + p.Name
|
||||
}
|
||||
|
||||
|
@ -91,8 +99,7 @@ func (p *SSHPOP) Init(config Config) error {
|
|||
return err
|
||||
}
|
||||
|
||||
p.audiences = config.Audiences.WithFragment(p.GetID())
|
||||
p.db = config.DB
|
||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
p.sshPubKeys = config.SSHKeys
|
||||
return nil
|
||||
}
|
||||
|
@ -100,6 +107,8 @@ func (p *SSHPOP) Init(config Config) error {
|
|||
// authorizeToken performs common jwt authorization actions and returns the
|
||||
// claims for case specific downstream parsing.
|
||||
// e.g. a Sign request will auth/validate different fields than a Revoke request.
|
||||
//
|
||||
// Checking for certificate revocation has been moved to the authority package.
|
||||
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) {
|
||||
sshCert, jwt, err := ExtractSSHPOPCert(token)
|
||||
if err != nil {
|
||||
|
@ -107,14 +116,6 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
|||
"sshpop.authorizeToken; error extracting sshpop header from token")
|
||||
}
|
||||
|
||||
// Check for revocation.
|
||||
if isRevoked, err := p.db.IsSSHRevoked(strconv.FormatUint(sshCert.Serial, 10)); err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||
"sshpop.authorizeToken; error checking checking sshpop cert revocation")
|
||||
} else if isRevoked {
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked")
|
||||
}
|
||||
|
||||
// Check validity period of the certificate.
|
||||
n := time.Now()
|
||||
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
|
@ -47,7 +46,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err = cert.SignCert(rand.Reader, signer); err != nil {
|
||||
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return cert, jwk, nil
|
||||
|
@ -83,52 +82,9 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
|||
err: errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
|
||||
}
|
||||
},
|
||||
"fail/error-revoked-db-check": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, errors.New("force")
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusInternalServerError,
|
||||
err: errors.New("sshpop.authorizeToken; error checking checking sshpop cert revocation: force"),
|
||||
}
|
||||
},
|
||||
"fail/cert-already-revoked": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.New("sshpop.authorizeToken; sshpop certificate is revoked"),
|
||||
}
|
||||
},
|
||||
"fail/cert-not-yet-valid": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{
|
||||
CertType: ssh.UserCert,
|
||||
ValidAfter: uint64(time.Now().Add(time.Minute).Unix()),
|
||||
|
@ -146,11 +102,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
|||
"fail/cert-past-validity": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{
|
||||
CertType: ssh.UserCert,
|
||||
ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()),
|
||||
|
@ -168,11 +119,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
|||
"fail/no-signer-found": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||
|
@ -187,11 +133,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
|||
"fail/error-parsing-claims-bad-sig": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
|
@ -208,11 +149,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
|||
"fail/invalid-claims-issuer": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "",
|
||||
|
@ -228,11 +164,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
|||
"fail/invalid-audience": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), "invalid-aud", "",
|
||||
|
@ -248,11 +179,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
|||
"fail/empty-subject": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
|
||||
|
@ -268,11 +194,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
|||
"ok": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||
|
@ -293,10 +214,8 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
|||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.NotNil(t, claims)
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.NotNil(t, claims)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -330,11 +249,6 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
|
|||
"fail/subject-not-equal-serial": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "",
|
||||
|
@ -350,11 +264,6 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
|
|||
"ok": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "",
|
||||
|
@ -419,11 +328,6 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
|
|||
"fail/not-host-cert": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "",
|
||||
|
@ -439,11 +343,6 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
|
|||
"ok": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "",
|
||||
|
@ -511,11 +410,6 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
|
|||
"fail/not-host-cert": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "",
|
||||
|
@ -531,11 +425,6 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
|
|||
"ok": func(t *testing.T) test {
|
||||
p, err := generateSSHPOP()
|
||||
assert.FatalError(t, err)
|
||||
p.db = &db.MockAuthDB{
|
||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "",
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue