forked from TrueCloudLab/certificates
Compare commits
51 commits
master
...
ssh-config
Author | SHA1 | Date | |
---|---|---|---|
|
907c13a7be | ||
|
191aa90ac5 | ||
|
861a653769 | ||
|
bd2c764afe | ||
|
8ef9b020ed | ||
|
85218a93b0 | ||
|
665cf90353 | ||
|
e6b1439c20 | ||
|
7712f186e8 | ||
|
b49444327c | ||
|
519ed1ffb9 | ||
|
4669bef8c7 | ||
|
9cfd84324e | ||
|
8b7415f021 | ||
|
a25d4e8b04 | ||
|
d35cebc1fb | ||
|
d6de129513 | ||
|
325d77fce9 | ||
|
8890c8edd7 | ||
|
c563143e6d | ||
|
dc24f27647 | ||
|
15390fc636 | ||
|
7a8ec397ac | ||
|
829a86a5d7 | ||
|
74c1b4e771 | ||
|
3e913d374f | ||
|
962800676f | ||
|
9f13a92b9e | ||
|
50c8b10a4f | ||
|
eb210ccc70 | ||
|
7b8f0327bd | ||
|
92bbadba32 | ||
|
43771b341f | ||
|
b2fc1a324a | ||
|
954db8792d | ||
|
caaba4a80d | ||
|
b9b0c2e2d6 | ||
|
823c190ccd | ||
|
f2ec20d53f | ||
|
1364dd9654 | ||
|
db68bf1081 | ||
|
32ae027ed9 | ||
|
9893c00e1b | ||
|
42fe92fd5c | ||
|
6148f8f775 | ||
|
cba1799a43 | ||
|
40f002a97d | ||
|
624dd0bb9d | ||
|
4dc2410134 | ||
|
3d3598a1a9 | ||
|
0a61c4035b |
437 changed files with 18700 additions and 98934 deletions
|
@ -1,5 +0,0 @@
|
|||
bin
|
||||
coverage.txt
|
||||
*.test
|
||||
*.out
|
||||
.travis-releases
|
56
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
56
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
|
@ -1,56 +0,0 @@
|
|||
name: Bug Report
|
||||
description: File a bug report
|
||||
title: "[Bug]: "
|
||||
labels: ["bug", "needs triage"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for taking the time to fill out this bug report!
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: Tell us how to reproduce this issue.
|
||||
placeholder: These are the steps!
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: your-env
|
||||
attributes:
|
||||
label: Your Environment
|
||||
value: |-
|
||||
* OS -
|
||||
* `step-ca` Version -
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: What did you expect to happen?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: actual-behavior
|
||||
attributes:
|
||||
label: Actual Behavior
|
||||
description: What happens instead?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: context
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Add any other context about the problem here.
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: contributing
|
||||
attributes:
|
||||
label: Contributing
|
||||
value: |
|
||||
Vote on this issue by adding a 👍 reaction.
|
||||
To contribute a fix for this issue, leave a comment (and link to your pull request, if you've opened one already).
|
||||
validations:
|
||||
required: false
|
24
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
24
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
|
@ -0,0 +1,24 @@
|
|||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
|
||||
---
|
||||
|
||||
### Subject of the issue
|
||||
Describe your issue here.
|
||||
|
||||
### Your environment
|
||||
* OS -
|
||||
* Version -
|
||||
|
||||
### Steps to reproduce
|
||||
Tell us how to reproduce this issue. Please provide a working demo, you can use [this template](https://plnkr.co/edit/XorWgI?p=preview) as a base.
|
||||
|
||||
### Expected behaviour
|
||||
Tell us what should happen
|
||||
|
||||
### Actual behaviour
|
||||
Tell us what happens instead
|
||||
|
||||
### Additional context
|
||||
Add any other context about the problem here.
|
9
.github/ISSUE_TEMPLATE/config.yml
vendored
9
.github/ISSUE_TEMPLATE/config.yml
vendored
|
@ -1,9 +0,0 @@
|
|||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: Ask on Discord
|
||||
url: https://discord.gg/7xgjhVAg6g
|
||||
about: You can ask for help here!
|
||||
- name: Want to contribute to step certificates?
|
||||
url: https://github.com/smallstep/certificates/blob/master/docs/CONTRIBUTING.md
|
||||
about: Be sure to read contributing guidelines!
|
||||
|
22
.github/ISSUE_TEMPLATE/documentation-request.md
vendored
22
.github/ISSUE_TEMPLATE/documentation-request.md
vendored
|
@ -1,22 +0,0 @@
|
|||
---
|
||||
name: Documentation Request
|
||||
about: Request documentation for a feature
|
||||
title: '[Docs]:'
|
||||
labels: docs, needs triage
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Hello!
|
||||
<!-- Please leave this section as-is, it's designed to help others in the community know how to interact with our GitHub issues. -->
|
||||
|
||||
- Vote on this issue by adding a 👍 reaction
|
||||
- If you want to document this feature, comment to let us know (we'll work with you on design, scheduling, etc.)
|
||||
|
||||
## Affected area/feature
|
||||
|
||||
<!---
|
||||
Tell us which feature you'd like to see documented.
|
||||
- Where would you like that documentation to live (command line usage output, website, github markdown on the repo)?
|
||||
- If there are specific attributes or options you'd like to see documented, please include those in the request.
|
||||
-->
|
23
.github/ISSUE_TEMPLATE/enhancement.md
vendored
23
.github/ISSUE_TEMPLATE/enhancement.md
vendored
|
@ -1,24 +1,11 @@
|
|||
---
|
||||
name: Enhancement
|
||||
about: Suggest an enhancement to step-ca
|
||||
title: ''
|
||||
labels: enhancement, needs triage
|
||||
assignees: ''
|
||||
|
||||
name: Certificates Enhancement
|
||||
about: Suggest an enhancement to step certificates
|
||||
labels: area/cert-management enhancement
|
||||
---
|
||||
|
||||
## Hello!
|
||||
<!-- Please leave this section as-is,
|
||||
it's designed to help others in the community know how to interact with our GitHub issues. -->
|
||||
### What would you like to be added
|
||||
|
||||
- Vote on this issue by adding a 👍 reaction
|
||||
- If you want to implement this feature, comment to let us know (we'll work with you on design, scheduling, etc.)
|
||||
|
||||
## Issue details
|
||||
### Why this is needed
|
||||
|
||||
<!-- Enhancement requests are most helpful when they describe the problem you're having
|
||||
as well as articulating the potential solution you'd like to see built. -->
|
||||
|
||||
## Why is this needed?
|
||||
|
||||
<!-- Let us know why you think this enhancement would be good for the project or community. -->
|
||||
|
|
20
.github/PULL_REQUEST_TEMPLATE
vendored
20
.github/PULL_REQUEST_TEMPLATE
vendored
|
@ -1,20 +1,4 @@
|
|||
<!---
|
||||
Please provide answers in the spaces below each prompt, where applicable.
|
||||
Not every PR requires responses for each prompt.
|
||||
Use your discretion.
|
||||
-->
|
||||
#### Name of feature:
|
||||
|
||||
#### Pain or issue this feature alleviates:
|
||||
|
||||
#### Why is this important to the project (if not answered above):
|
||||
|
||||
#### Is there documentation on how to use this feature? If so, where?
|
||||
|
||||
#### In what environments or workflows is this feature supported?
|
||||
|
||||
#### In what environments or workflows is this feature explicitly NOT supported (if any)?
|
||||
|
||||
#### Supporting links/other PRs/issues:
|
||||
### Description
|
||||
Please describe your pull request.
|
||||
|
||||
💔Thank you!
|
||||
|
|
11
.github/dependabot.yml
vendored
11
.github/dependabot.yml
vendored
|
@ -1,11 +0,0 @@
|
|||
# To get started with Dependabot version updates, you'll need to specify which
|
||||
# package ecosystems to update and where the package manifests are located.
|
||||
# Please see the documentation for all configuration options:
|
||||
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
|
||||
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "gomod" # See documentation for possible values
|
||||
directory: "/" # Location of package manifests
|
||||
schedule:
|
||||
interval: "weekly"
|
27
.github/workflows/ci.yml
vendored
27
.github/workflows/ci.yml
vendored
|
@ -1,27 +0,0 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
tags-ignore:
|
||||
- 'v*'
|
||||
branches:
|
||||
- "master"
|
||||
pull_request:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: true
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ci:
|
||||
uses: smallstep/workflows/.github/workflows/goCI.yml@main
|
||||
with:
|
||||
only-latest-golang: false
|
||||
os-dependencies: 'libpcsclite-dev'
|
||||
run-codeql: true
|
||||
test-command: 'V=1 make test'
|
||||
secrets: inherit
|
9
.github/workflows/code-scan-cron.yml
vendored
9
.github/workflows/code-scan-cron.yml
vendored
|
@ -1,9 +0,0 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
|
||||
jobs:
|
||||
code-scan:
|
||||
uses: smallstep/workflows/.github/workflows/code-scan.yml@main
|
||||
secrets:
|
||||
GITLEAKS_LICENSE_KEY: ${{ secrets.GITLEAKS_LICENSE_KEY }}
|
22
.github/workflows/dependabot-auto-merge.yml
vendored
22
.github/workflows/dependabot-auto-merge.yml
vendored
|
@ -1,22 +0,0 @@
|
|||
name: Dependabot auto-merge
|
||||
on: pull_request
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
dependabot:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.actor == 'dependabot[bot]' }}
|
||||
steps:
|
||||
- name: Dependabot metadata
|
||||
id: metadata
|
||||
uses: dependabot/fetch-metadata@v1.1.1
|
||||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
- name: Enable auto-merge for Dependabot PRs
|
||||
run: gh pr merge --auto --merge "$PR_URL"
|
||||
env:
|
||||
PR_URL: ${{github.event.pull_request.html_url}}
|
||||
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
|
91
.github/workflows/release.yml
vendored
91
.github/workflows/release.yml
vendored
|
@ -1,91 +0,0 @@
|
|||
name: Create Release & Upload Assets
|
||||
|
||||
on:
|
||||
push:
|
||||
# Sequence of patterns matched against refs/tags
|
||||
tags:
|
||||
- 'v*' # Push events to matching v*, i.e. v1.0, v20.15.10
|
||||
|
||||
jobs:
|
||||
ci:
|
||||
uses: smallstep/certificates/.github/workflows/ci.yml@master
|
||||
secrets: inherit
|
||||
|
||||
create_release:
|
||||
name: Create Release
|
||||
needs: ci
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DOCKER_IMAGE: smallstep/step-ca
|
||||
outputs:
|
||||
version: ${{ steps.extract-tag.outputs.VERSION }}
|
||||
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||
docker_tags: ${{ env.DOCKER_TAGS }}
|
||||
docker_tags_hsm: ${{ env.DOCKER_TAGS_HSM }}
|
||||
steps:
|
||||
- name: Is Pre-release
|
||||
id: is_prerelease
|
||||
run: |
|
||||
set +e
|
||||
echo ${{ github.ref }} | grep "\-rc.*"
|
||||
OUT=$?
|
||||
if [ $OUT -eq 0 ]; then IS_PRERELEASE=true; else IS_PRERELEASE=false; fi
|
||||
echo "IS_PRERELEASE=${IS_PRERELEASE}" >> ${GITHUB_OUTPUT}
|
||||
- name: Extract Tag Names
|
||||
id: extract-tag
|
||||
run: |
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
echo "VERSION=${VERSION}" >> ${GITHUB_OUTPUT}
|
||||
echo "DOCKER_TAGS=${{ env.DOCKER_IMAGE }}:${VERSION}" >> ${GITHUB_ENV}
|
||||
echo "DOCKER_TAGS_HSM=${{ env.DOCKER_IMAGE }}:${VERSION}-hsm" >> ${GITHUB_ENV}
|
||||
- name: Add Latest Tag
|
||||
if: steps.is_prerelease.outputs.IS_PRERELEASE == 'false'
|
||||
run: |
|
||||
echo "DOCKER_TAGS=${{ env.DOCKER_TAGS }},${{ env.DOCKER_IMAGE }}:latest" >> ${GITHUB_ENV}
|
||||
echo "DOCKER_TAGS_HSM=${{ env.DOCKER_TAGS_HSM }},${{ env.DOCKER_IMAGE }}:hsm" >> ${GITHUB_ENV}
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ github.ref }}
|
||||
release_name: Release ${{ github.ref }}
|
||||
draft: false
|
||||
prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||
|
||||
goreleaser:
|
||||
needs: create_release
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
uses: smallstep/workflows/.github/workflows/goreleaser.yml@main
|
||||
secrets: inherit
|
||||
|
||||
build_upload_docker:
|
||||
name: Build & Upload Docker Images
|
||||
needs: create_release
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
uses: smallstep/workflows/.github/workflows/docker-buildx-push.yml@main
|
||||
with:
|
||||
platforms: linux/amd64,linux/386,linux/arm,linux/arm64
|
||||
tags: ${{ needs.create_release.outputs.docker_tags }}
|
||||
docker_image: smallstep/step-ca
|
||||
docker_file: docker/Dockerfile
|
||||
secrets: inherit
|
||||
|
||||
build_upload_docker_hsm:
|
||||
name: Build & Upload HSM Enabled Docker Images
|
||||
needs: create_release
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
uses: smallstep/workflows/.github/workflows/docker-buildx-push.yml@main
|
||||
with:
|
||||
platforms: linux/amd64,linux/386,linux/arm,linux/arm64
|
||||
tags: ${{ needs.create_release.outputs.docker_tags_hsm }}
|
||||
docker_image: smallstep/step-ca
|
||||
docker_file: docker/Dockerfile.hsm
|
||||
secrets: inherit
|
16
.github/workflows/triage.yml
vendored
16
.github/workflows/triage.yml
vendored
|
@ -1,16 +0,0 @@
|
|||
name: Add Issues and PRs to Triage
|
||||
|
||||
on:
|
||||
issues:
|
||||
types:
|
||||
- opened
|
||||
- reopened
|
||||
pull_request_target:
|
||||
types:
|
||||
- opened
|
||||
- reopened
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
uses: smallstep/workflows/.github/workflows/triage.yml@main
|
||||
secrets: inherit
|
10
.gitignore
vendored
10
.gitignore
vendored
|
@ -6,10 +6,6 @@
|
|||
*.so
|
||||
*.dylib
|
||||
|
||||
# Go Workspaces
|
||||
go.work
|
||||
go.work.sum
|
||||
|
||||
# Test binary, build with `go test -c`
|
||||
*.test
|
||||
|
||||
|
@ -18,9 +14,7 @@ go.work.sum
|
|||
|
||||
# Others
|
||||
*.swp
|
||||
.releases
|
||||
.travis-releases
|
||||
coverage.txt
|
||||
output
|
||||
vendor
|
||||
.idea
|
||||
.envrc
|
||||
output
|
||||
|
|
|
@ -1,18 +0,0 @@
|
|||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:85
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:107
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:108
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:129
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:131
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:136
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:138
|
||||
7c9ab9814fb676cb3c125c3dac4893271f1b7ae5:README.md:generic-api-key:282
|
||||
fb7140444ac8f1fa1245a80e49d17e206f7435f3:docs/provisioners.md:generic-api-key:110
|
||||
e4de7f07e82118b3f926716666b620db058fa9f7:docs/revocation.md:generic-api-key:73
|
||||
e4de7f07e82118b3f926716666b620db058fa9f7:docs/revocation.md:generic-api-key:113
|
||||
e4de7f07e82118b3f926716666b620db058fa9f7:docs/revocation.md:generic-api-key:151
|
||||
8b2de42e9cf6ce99f53a5049881e1d6077d5d66e:docs/docker.md:generic-api-key:152
|
||||
3939e855264117e81531df777a642ea953d325a7:autocert/init/ca/intermediate_ca_key:private-key:1
|
||||
e72f08703753facfa05f2d8c68f9f6a3745824b8:README.md:generic-api-key:244
|
||||
e70a5dae7de0b6ca40a0393c09c28872d4cfa071:autocert/README.md:generic-api-key:365
|
||||
e70a5dae7de0b6ca40a0393c09c28872d4cfa071:autocert/README.md:generic-api-key:366
|
||||
c284a2c0ab1c571a46443104be38c873ef0c7c6d:config.json:generic-api-key:10
|
71
.golangci.yml
Normal file
71
.golangci.yml
Normal file
|
@ -0,0 +1,71 @@
|
|||
linters-settings:
|
||||
govet:
|
||||
check-shadowing: true
|
||||
settings:
|
||||
printf:
|
||||
funcs:
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Infof
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf
|
||||
golint:
|
||||
min-confidence: 0
|
||||
gocyclo:
|
||||
min-complexity: 10
|
||||
maligned:
|
||||
suggest-new: true
|
||||
dupl:
|
||||
threshold: 100
|
||||
goconst:
|
||||
min-len: 2
|
||||
min-occurrences: 2
|
||||
depguard:
|
||||
list-type: blacklist
|
||||
packages:
|
||||
# logging is allowed only by logutils.Log, logrus
|
||||
# is allowed to use only in logutils package
|
||||
- github.com/sirupsen/logrus
|
||||
misspell:
|
||||
locale: US
|
||||
lll:
|
||||
line-length: 140
|
||||
goimports:
|
||||
local-prefixes: github.com/golangci/golangci-lint
|
||||
gocritic:
|
||||
enabled-tags:
|
||||
- performance
|
||||
- style
|
||||
- experimental
|
||||
disabled-checks:
|
||||
- wrapperFunc
|
||||
- dupImport # https://github.com/go-critic/go-critic/issues/845
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- gofmt
|
||||
- golint
|
||||
- vet
|
||||
- misspell
|
||||
- ineffassign
|
||||
- deadcode
|
||||
- staticcheck
|
||||
- unused
|
||||
- gosimple
|
||||
|
||||
run:
|
||||
skip-dirs:
|
||||
- pkg
|
||||
|
||||
issues:
|
||||
exclude:
|
||||
- can't lint
|
||||
- declaration of "err" shadows declaration at line
|
||||
- should have a package comment, unless it's in another file for this package
|
||||
- error strings should not be capitalized or end with punctuation or a newline
|
||||
# golangci.com configuration
|
||||
# https://github.com/golangci/golangci/wiki/Configuration
|
||||
service:
|
||||
golangci-lint-version: 1.18.x # use the fixed version to not introduce new linters unexpectedly
|
||||
prepare:
|
||||
- echo "here I can run custom commands, but no preparation needed for this repo"
|
238
.goreleaser.yml
238
.goreleaser.yml
|
@ -1,238 +0,0 @@
|
|||
# This is an example .goreleaser.yml file with some sane defaults.
|
||||
# Make sure to check the documentation at http://goreleaser.com
|
||||
project_name: step-ca
|
||||
|
||||
before:
|
||||
hooks:
|
||||
# You may remove this if you don't use go modules.
|
||||
- go mod download
|
||||
|
||||
builds:
|
||||
-
|
||||
id: step-ca
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
targets:
|
||||
- darwin_amd64
|
||||
- darwin_arm64
|
||||
- freebsd_amd64
|
||||
- linux_386
|
||||
- linux_amd64
|
||||
- linux_arm64
|
||||
- linux_arm_5
|
||||
- linux_arm_6
|
||||
- linux_arm_7
|
||||
- windows_amd64
|
||||
flags:
|
||||
- -trimpath
|
||||
main: ./cmd/step-ca/main.go
|
||||
binary: step-ca
|
||||
ldflags:
|
||||
- -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}}
|
||||
|
||||
archives:
|
||||
- &ARCHIVE
|
||||
# Can be used to change the archive formats for specific GOOSs.
|
||||
# Most common use case is to archive as zip on Windows.
|
||||
# Default is empty.
|
||||
name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Version }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
|
||||
rlcp: true
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
format: zip
|
||||
wrap_in_directory: "{{ .ProjectName }}_{{ .Version }}"
|
||||
files:
|
||||
- README.md
|
||||
- LICENSE
|
||||
allow_different_binary_count: true
|
||||
-
|
||||
<< : *ARCHIVE
|
||||
id: unversioned
|
||||
name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
|
||||
|
||||
|
||||
nfpms:
|
||||
# Configure nFPM for .deb and .rpm releases
|
||||
#
|
||||
# See https://nfpm.goreleaser.com/configuration/
|
||||
# and https://goreleaser.com/customization/nfpm/
|
||||
#
|
||||
# Useful tools for debugging .debs:
|
||||
# List file contents: dpkg -c dist/step_...deb
|
||||
# Package metadata: dpkg --info dist/step_....deb
|
||||
#
|
||||
- &NFPM
|
||||
builds:
|
||||
- step-ca
|
||||
package_name: step-ca
|
||||
file_name_template: "{{ .PackageName }}_{{ .Version }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
|
||||
vendor: Smallstep Labs
|
||||
homepage: https://github.com/smallstep/certificates
|
||||
maintainer: Smallstep <techadmin@smallstep.com>
|
||||
description: >
|
||||
step-ca is an online certificate authority for secure, automated certificate management.
|
||||
license: Apache 2.0
|
||||
section: utils
|
||||
formats:
|
||||
- deb
|
||||
- rpm
|
||||
priority: optional
|
||||
bindir: /usr/bin
|
||||
contents:
|
||||
- src: debian/copyright
|
||||
dst: /usr/share/doc/step-ca/copyright
|
||||
-
|
||||
<< : *NFPM
|
||||
id: unversioned
|
||||
file_name_template: "{{ .PackageName }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
|
||||
|
||||
source:
|
||||
enabled: true
|
||||
rlcp: true
|
||||
name_template: '{{ .ProjectName }}_{{ .Version }}'
|
||||
|
||||
checksum:
|
||||
name_template: 'checksums.txt'
|
||||
extra_files:
|
||||
- glob: ./.releases/*
|
||||
|
||||
signs:
|
||||
- cmd: cosign
|
||||
signature: "${artifact}.sig"
|
||||
certificate: "${artifact}.pem"
|
||||
args: ["sign-blob", "--oidc-issuer=https://token.actions.githubusercontent.com", "--output-certificate=${certificate}", "--output-signature=${signature}", "${artifact}"]
|
||||
artifacts: all
|
||||
|
||||
snapshot:
|
||||
name_template: "{{ .Tag }}-next"
|
||||
|
||||
release:
|
||||
# Repo in which the release will be created.
|
||||
# Default is extracted from the origin remote URL or empty if its private hosted.
|
||||
# Note: it can only be one: either github, gitlab or gitea
|
||||
github:
|
||||
owner: smallstep
|
||||
name: certificates
|
||||
|
||||
# IDs of the archives to use.
|
||||
# Defaults to all.
|
||||
#ids:
|
||||
# - foo
|
||||
# - bar
|
||||
|
||||
# If set to true, will not auto-publish the release.
|
||||
# Default is false.
|
||||
draft: false
|
||||
|
||||
# If set to auto, will mark the release as not ready for production
|
||||
# in case there is an indicator for this in the tag e.g. v1.0.0-rc1
|
||||
# If set to true, will mark the release as not ready for production.
|
||||
# Default is false.
|
||||
prerelease: auto
|
||||
|
||||
# You can change the name of the release.
|
||||
# Default is `{{.Tag}}`
|
||||
name_template: "Step CA {{ .Tag }} ({{ .Env.RELEASE_DATE }})"
|
||||
|
||||
# Header template for the release body.
|
||||
# Defaults to empty.
|
||||
header: |
|
||||
## Official Release Artifacts
|
||||
|
||||
#### Linux
|
||||
|
||||
- 📦 [step-ca_linux_{{ .Version }}_amd64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_linux_{{ .Version }}_amd64.tar.gz)
|
||||
- 📦 [step-ca_{{ .Version }}_amd64.deb](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ .Version }}_amd64.deb)
|
||||
|
||||
#### OSX Darwin
|
||||
|
||||
- 📦 [step-ca_darwin_{{ .Version }}_amd64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_amd64.tar.gz)
|
||||
- 📦 [step-ca_darwin_{{ .Version }}_arm64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_arm64.tar.gz)
|
||||
|
||||
#### Windows
|
||||
|
||||
- 📦 [step-ca_windows_{{ .Version }}_amd64.zip](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_windows_{{ .Version }}_amd64.zip)
|
||||
|
||||
For more builds across platforms and architectures, see the `Assets` section below.
|
||||
And for packaged versions (Docker, k8s, Homebrew), see our [installation docs](https://smallstep.com/docs/step-ca/installation).
|
||||
|
||||
Don't see the artifact you need? Open an issue [here](https://github.com/smallstep/certificates/issues/new/choose).
|
||||
|
||||
## Signatures and Checksums
|
||||
|
||||
`step-ca` uses [sigstore/cosign](https://github.com/sigstore/cosign) for signing and verifying release artifacts.
|
||||
|
||||
Below is an example using `cosign` to verify a release artifact:
|
||||
|
||||
```
|
||||
cosign verify-blob \
|
||||
--certificate ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig.pem \
|
||||
--signature ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig \
|
||||
--certificate-identity-regexp "https://github\.com/smallstep/certificates/.*" \
|
||||
--certificate-oidc-issuer https://token.actions.githubusercontent.com \
|
||||
~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz
|
||||
```
|
||||
|
||||
The `checksums.txt` file (in the `Assets` section below) contains a checksum for every artifact in the release.
|
||||
|
||||
# Footer template for the release body.
|
||||
# Defaults to empty.
|
||||
footer: |
|
||||
## Thanks!
|
||||
|
||||
Those were the changes on {{ .Tag }}!
|
||||
|
||||
Come join us on [Discord](https://discord.gg/X2RKGwEbV9) to ask questions, chat about PKI, or get a sneak peak at the freshest PKI memes.
|
||||
|
||||
# You can disable this pipe in order to not upload any artifacts.
|
||||
# Defaults to false.
|
||||
#disable: true
|
||||
|
||||
# You can add extra pre-existing files to the release.
|
||||
# The filename on the release will be the last part of the path (base). If
|
||||
# another file with the same name exists, the latest one found will be used.
|
||||
# Defaults to empty.
|
||||
extra_files:
|
||||
- glob: ./.releases/*
|
||||
#extra_files:
|
||||
# - glob: ./path/to/file.txt
|
||||
# - glob: ./glob/**/to/**/file/**/*
|
||||
# - glob: ./glob/foo/to/bar/file/foobar/override_from_previous
|
||||
|
||||
scoops:
|
||||
-
|
||||
ids: [ default ]
|
||||
# Template for the url which is determined by the given Token (github or gitlab)
|
||||
# Default for github is "https://github.com/<repo_owner>/<repo_name>/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
|
||||
# Default for gitlab is "https://gitlab.com/<repo_owner>/<repo_name>/uploads/{{ .ArtifactUploadHash }}/{{ .ArtifactName }}"
|
||||
# Default for gitea is "https://gitea.com/<repo_owner>/<repo_name>/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
|
||||
url_template: "http://github.com/smallstep/certificates/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
|
||||
# Repository to push the app manifest to.
|
||||
bucket:
|
||||
owner: smallstep
|
||||
name: scoop-bucket
|
||||
|
||||
# Git author used to commit to the repository.
|
||||
# Defaults are shown.
|
||||
commit_author:
|
||||
name: goreleaserbot
|
||||
email: goreleaser@smallstep.com
|
||||
|
||||
# The project name and current git tag are used in the format string.
|
||||
commit_msg_template: "Scoop update for {{ .ProjectName }} version {{ .Tag }}"
|
||||
|
||||
# Your app's homepage.
|
||||
# Default is empty.
|
||||
homepage: "https://smallstep.com/docs/step-ca"
|
||||
|
||||
# Skip uploads for prerelease.
|
||||
skip_upload: auto
|
||||
|
||||
# Your app's description.
|
||||
# Default is empty.
|
||||
description: "A private certificate authority (X.509 & SSH) & ACME server for secure automated certificate management, so you can use TLS everywhere & SSO for SSH."
|
||||
|
||||
# Your app's license
|
||||
# Default is empty.
|
||||
license: "Apache-2.0"
|
||||
|
33
.travis.yml
Normal file
33
.travis.yml
Normal file
|
@ -0,0 +1,33 @@
|
|||
language: go
|
||||
go:
|
||||
- 1.13.x
|
||||
addons:
|
||||
apt:
|
||||
packages:
|
||||
- debhelper
|
||||
- fakeroot
|
||||
- bash-completion
|
||||
env:
|
||||
global:
|
||||
- V=1
|
||||
before_script:
|
||||
- make bootstrap
|
||||
script:
|
||||
- make
|
||||
- make artifacts
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash) -t "$CODECOV_TOKEN" || echo "Codecov did
|
||||
not collect coverage reports"
|
||||
notifications:
|
||||
email: false
|
||||
stage: Github Release
|
||||
deploy:
|
||||
provider: releases
|
||||
skip_cleanup: true
|
||||
api_key:
|
||||
secure: EVV43Vkqn67hhKGYn4WhQp2YO6KFmUDSkLXjYXYGX07Fm8p5KjRFBPOz9LV83QrvVmLigvg0CtR8Jqqcnq2SUhus3nhZaN2g19NhMypZLioyOVP0kAkas8ocuvxkwz3YxIK/yMrmTKbQ7JGXtbc8IjAox9ovNo1fFIQmVMAzPfu++OWBJ0j+gUqKtpaNA7gzsSv8UOw3/T3hNm6E1IbpWxl9BPSOzUOE9F/QOThANzifGfdxvqNJFkAgqu5DVPz8zQNbMrz4zH+KwASKxd6hjhzSSMzouKzOEHTA/elDCHEjke0Jos29MkGWHcIydLtCD95DGecqM8BFSC9f2acHDjmUO1rdfoLA3Pt+UiZJuTwyQm/jrHHhRnH8oJpK15G5LvxSqzY9YDWpAk38+jMw/udW6wt7BGAU8FEXLbq0bsFL3yfTepeWjmzT5WS0YXdiBz2SEK+Og9R2bSdtl4owghRzKNio5DNPuYAbqbpi+jqzqQVLj27x7LWoQ0MHvZcz9U+oO00r6M1tDCmFVRdtfgb2H+MIDY69qYGo5qoGMfH1btCWR8bA9wSYB/Z7hW/xZT9r7f/d5/P40k8yKINmTZqyUTQeplrE3y4BPVzKksclczBZa67syIUQ49I35QppnH4GFQHUwlra7r3W9zfZRvaLnp5qOIKAQe3MAIZqtLg=
|
||||
file_glob: true
|
||||
file: .travis-releases/*
|
||||
on:
|
||||
repo: smallstep/certificates
|
||||
tags: true
|
|
@ -1,4 +1,4 @@
|
|||
#!/usr/bin/env sh
|
||||
#!/usr/bin/env bash
|
||||
read -r firstline < .VERSION
|
||||
last_half="${firstline##*tag: }"
|
||||
if [[ ${last_half::1} == "v" ]]; then
|
||||
|
|
396
CHANGELOG.md
396
CHANGELOG.md
|
@ -1,407 +1,13 @@
|
|||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
||||
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## TEMPLATE -- do not alter or remove
|
||||
|
||||
---
|
||||
|
||||
## [x.y.z] - aaaa-bb-cc
|
||||
|
||||
## [Unreleased - 0.0.1] - DATE
|
||||
### Added
|
||||
|
||||
### Changed
|
||||
|
||||
### Deprecated
|
||||
|
||||
### Removed
|
||||
|
||||
### Fixed
|
||||
|
||||
### Security
|
||||
|
||||
---
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Fixed
|
||||
|
||||
- Improved authentication for ACME requests using kid and provisioner name
|
||||
(smallstep/certificates#1386).
|
||||
|
||||
|
||||
## [v0.24.2] - 2023-05-11
|
||||
|
||||
### Added
|
||||
|
||||
- Log SSH certificates (smallstep/certificates#1374)
|
||||
- CRL endpoints on the HTTP server (smallstep/certificates#1372)
|
||||
- Dynamic SCEP challenge validation using webhooks (smallstep/certificates#1366)
|
||||
- For Docker deployments, added DOCKER_STEPCA_INIT_PASSWORD_FILE. Useful for pointing to a Docker Secret in the container (smallstep/certificates#1384)
|
||||
|
||||
### Changed
|
||||
|
||||
- Depend on [smallstep/go-attestation](https://github.com/smallstep/go-attestation) instead of [google/go-attestation](https://github.com/google/go-attestation)
|
||||
- Render CRLs into http.ResponseWriter instead of memory (smallstep/certificates#1373)
|
||||
- Redaction of SCEP static challenge when listing provisioners (smallstep/certificates#1204)
|
||||
|
||||
### Fixed
|
||||
|
||||
- VaultCAS certificate lifetime (smallstep/certificates#1376)
|
||||
|
||||
## [v0.24.1] - 2023-04-14
|
||||
|
||||
### Fixed
|
||||
|
||||
- Docker image name for HSM support (smallstep/certificates#1348)
|
||||
|
||||
## [v0.24.0] - 2023-04-12
|
||||
|
||||
### Added
|
||||
|
||||
- Add ACME `device-attest-01` support with TPM 2.0
|
||||
(smallstep/certificates#1063).
|
||||
- Add support for new Azure SDK, sovereign clouds, and HSM keys on Azure KMS
|
||||
(smallstep/crypto#192, smallstep/crypto#197, smallstep/crypto#198,
|
||||
smallstep/certificates#1323, smallstep/certificates#1309).
|
||||
- Add support for ASN.1 functions on certificate templates
|
||||
(smallstep/crypto#208, smallstep/certificates#1345)
|
||||
- Add `DOCKER_STEPCA_INIT_ADDRESS` to configure the address to use in a docker
|
||||
container (smallstep/certificates#1262).
|
||||
- Make sure that the CSR used matches the attested key when using AME
|
||||
`device-attest-01` challenge (smallstep/certificates#1265).
|
||||
- Add support for compacting the Badger DB (smallstep/certificates#1298).
|
||||
- Build and release cleanups (smallstep/certificates#1322,
|
||||
smallstep/certificates#1329, smallstep/certificates#1340).
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fix support for PKCS #7 RSA-OAEP decryption through
|
||||
[smallstep/pkcs7#4](https://github.com/smallstep/pkcs7/pull/4), as used in
|
||||
SCEP.
|
||||
- Fix RA installation using `scripts/install-step-ra.sh`
|
||||
(smallstep/certificates#1255).
|
||||
- Clarify error messages on policy errors (smallstep/certificates#1287,
|
||||
smallstep/certificates#1278).
|
||||
- Clarify error message on OIDC email validation (smallstep/certificates#1290).
|
||||
- Mark the IDP critical in the generated CRL data (smallstep/certificates#1293).
|
||||
- Disable database if CA is initialized with the `--no-db` flag
|
||||
(smallstep/certificates#1294).
|
||||
|
||||
## [v0.23.2] - 2023-02-02
|
||||
|
||||
### Added
|
||||
|
||||
- Added [`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) to
|
||||
docker images, and a new image, `smallstep/step-ca-hsm`, compiled with cgo
|
||||
(smallstep/certificates#1243).
|
||||
- Added [`scoop`](https://scoop.sh) packages back to the release
|
||||
(smallstep/certificates#1250).
|
||||
- Added optional flag `--pidfile` which allows passing a filename where step-ca
|
||||
will write its process id (smallstep/certificates#1251).
|
||||
- Added helpful message on CA startup when config can't be opened
|
||||
(smallstep/certificates#1252).
|
||||
- Improved validation and error messages on `device-attest-01` orders
|
||||
(smallstep/certificates#1235).
|
||||
|
||||
### Removed
|
||||
|
||||
- The deprecated CLI utils `step-awskms-init`, `step-cloudkms-init`,
|
||||
`step-pkcs11-init`, `step-yubikey-init` have been removed.
|
||||
[`step`](https://github.com/smallstep/cli) and
|
||||
[`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) should be
|
||||
used instead (smallstep/certificates#1240).
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed remote management flags in docker images (smallstep/certificates#1228).
|
||||
|
||||
## [v0.23.1] - 2023-01-10
|
||||
|
||||
### Added
|
||||
|
||||
- Added configuration property `.crl.idpURL` to be able to set a custom Issuing
|
||||
Distribution Point in the CRL (smallstep/certificates#1178).
|
||||
- Added WithContext methods to the CA client (smallstep/certificates#1211).
|
||||
- Docker: Added environment variables for enabling Remote Management and ACME
|
||||
provisioner (smallstep/certificates#1201).
|
||||
- Docker: The entrypoint script now generates and displays an initial JWK
|
||||
provisioner password by default when the CA is being initialized
|
||||
(smallstep/certificates#1223).
|
||||
|
||||
### Changed
|
||||
|
||||
- Ignore SSH principals validation when using an OIDC provisioner. The
|
||||
provisioner will ignore the principals passed and set the defaults or the ones
|
||||
including using WebHooks or templates (smallstep/certificates#1206).
|
||||
|
||||
## [v0.23.0] - 2022-11-11
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for ACME device-attest-01 challenge on iOS, iPadOS, tvOS and
|
||||
YubiKey.
|
||||
- Ability to disable ACME challenges and attestation formats.
|
||||
- Added flags to change ACME challenge ports for testing purposes.
|
||||
- Added name constraints evaluation and enforcement when issuing or renewing
|
||||
X.509 certificates.
|
||||
- Added provisioner webhooks for augmenting template data and authorizing
|
||||
certificate requests before signing.
|
||||
- Added automatic migration of provisioners when enabling remote management.
|
||||
- Added experimental support for CRLs.
|
||||
- Add certificate renewal support on RA mode. The `step ca renew` command must
|
||||
use the flag `--mtls=false` to use the token renewal flow.
|
||||
- Added support for initializing remote management using `step ca init`.
|
||||
- Added support for renewing X.509 certificates on RAs.
|
||||
- Added support for using SCEP with keys in a KMS.
|
||||
- Added client support to set the dialer's local address with the environment variable
|
||||
`STEP_CLIENT_ADDR`.
|
||||
|
||||
### Changed
|
||||
|
||||
- Remove the email requirement for issuing SSH certificates with an OIDC
|
||||
provisioner.
|
||||
- Root files can contain more than one certificate.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed MySQL DSN parsing issues with an upgrade to
|
||||
[smallstep/nosql@v0.5.0](https://github.com/smallstep/nosql/releases/tag/v0.5.0).
|
||||
- Fixed renewal of certificates with missing subject attributes.
|
||||
- Fixed ACME support with [ejabberd](https://github.com/processone/ejabberd).
|
||||
|
||||
### Deprecated
|
||||
|
||||
- The CLIs `step-awskms-init`, `step-cloudkms-init`, `step-pkcs11-init`,
|
||||
`step-yubikey-init` are deprecated. Now you can use
|
||||
[`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) in
|
||||
combination with `step certificates create` to initialize your PKI.
|
||||
|
||||
## [0.22.1] - 2022-08-31
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed signature algorithm on EC (root) + RSA (intermediate) PKIs.
|
||||
|
||||
## [0.22.0] - 2022-08-26
|
||||
|
||||
### Added
|
||||
|
||||
- Added automatic configuration of Linked RAs.
|
||||
- Send provisioner configuration on Linked RAs.
|
||||
|
||||
### Changed
|
||||
|
||||
- Certificates signed by an issuer using an RSA key will be signed using the
|
||||
same algorithm used to sign the issuer certificate. The signature will no
|
||||
longer default to PKCS #1. For example, if the issuer certificate was signed
|
||||
using RSA-PSS with SHA-256, a new certificate will also be signed using
|
||||
RSA-PSS with SHA-256.
|
||||
- Support two latest versions of Go (1.18, 1.19).
|
||||
- Validate revocation serial number (either base 10 or prefixed with an
|
||||
appropriate base).
|
||||
- Sanitize TLS options.
|
||||
|
||||
## [0.20.0] - 2022-05-26
|
||||
|
||||
### Added
|
||||
|
||||
- Added Kubernetes auth method for Vault RAs.
|
||||
- Added support for reporting provisioners to linkedca.
|
||||
- Added support for certificate policies on authority level.
|
||||
- Added a Dockerfile with a step-ca build with HSM support.
|
||||
- A few new WithXX methods for instantiating authorities
|
||||
|
||||
### Changed
|
||||
|
||||
- Context usage in HTTP APIs.
|
||||
- Changed authentication for Vault RAs.
|
||||
- Error message returned to client when authenticating with expired certificate.
|
||||
- Strip padding from ACME CSRs.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- HTTP API handler types.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed SSH revocation.
|
||||
- CA client dial context for js/wasm target.
|
||||
- Incomplete `extraNames` support in templates.
|
||||
- SCEP GET request support.
|
||||
- Large SCEP request handling.
|
||||
|
||||
## [0.19.0] - 2022-04-19
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for certificate renewals after expiry using the claim `allowRenewalAfterExpiry`.
|
||||
- Added support for `extraNames` in X.509 templates.
|
||||
- Added `armv5` builds.
|
||||
- Added RA support using a Vault instance as the CA.
|
||||
- Added `WithX509SignerFunc` authority option.
|
||||
- Added a new `/roots.pem` endpoint to download the CA roots in PEM format.
|
||||
- Added support for Azure `Managed Identity` tokens.
|
||||
- Added support for automatic configuration of linked RAs.
|
||||
- Added support for the `--context` flag. It's now possible to start the
|
||||
CA with `step-ca --context=abc` to use the configuration from context `abc`.
|
||||
When a context has been configured and no configuration file is provided
|
||||
on startup, the configuration for the current context is used.
|
||||
- Added startup info logging and option to skip it (`--quiet`).
|
||||
- Added support for renaming the CA (Common Name).
|
||||
|
||||
### Changed
|
||||
|
||||
- Made SCEP CA URL paths dynamic.
|
||||
- Support two latest versions of Go (1.17, 1.18).
|
||||
- Upgrade go.step.sm/crypto to v0.16.1.
|
||||
- Upgrade go.step.sm/linkedca to v0.15.0.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Go 1.16 support.
|
||||
|
||||
### Removed
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed admin credentials on RAs.
|
||||
- Fixed ACME HTTP-01 challenges for IPv6 identifiers.
|
||||
- Various improvements under the hood.
|
||||
|
||||
### Security
|
||||
|
||||
## [0.18.2] - 2022-03-01
|
||||
|
||||
### Added
|
||||
|
||||
- Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner.
|
||||
- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering
|
||||
out database drivers using Go tags. For example, using the Go flag
|
||||
`--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx
|
||||
driver for PostgreSQL.
|
||||
|
||||
### Changed
|
||||
|
||||
- IPv6 addresses are normalized as IP addresses instead of hostnames.
|
||||
- More descriptive JWK decryption error message.
|
||||
- Make the X5C leaf certificate available to the templates using `{{ .AuthorizationCrt }}`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- During provisioner add - validate provisioner configuration before storing to DB.
|
||||
|
||||
## [0.18.1] - 2022-02-03
|
||||
|
||||
### Added
|
||||
|
||||
- Support for ACME revocation.
|
||||
- Replace hash function with an RSA SSH CA to "rsa-sha2-256".
|
||||
- Support Nebula provisioners.
|
||||
- Example Ansible configurations.
|
||||
- Support PKCS#11 as a decrypter, as used by SCEP.
|
||||
|
||||
### Changed
|
||||
|
||||
- Automatically create database directory on `step ca init`.
|
||||
- Slightly improve errors reported when a template has invalid content.
|
||||
- Error reporting in logs and to clients.
|
||||
|
||||
### Fixed
|
||||
|
||||
- SCEP renewal using HTTPS on macOS.
|
||||
|
||||
## [0.18.0] - 2021-11-17
|
||||
|
||||
### Added
|
||||
|
||||
- Support for multiple certificate authority contexts.
|
||||
- Support for generating extractable keys and certificates on a pkcs#11 module.
|
||||
|
||||
### Changed
|
||||
|
||||
- Support two latest versions of Go (1.16, 1.17)
|
||||
|
||||
### Deprecated
|
||||
|
||||
- go 1.15 support
|
||||
|
||||
## [0.17.6] - 2021-10-20
|
||||
|
||||
### Notes
|
||||
|
||||
- 0.17.5 failed in CI/CD
|
||||
|
||||
## [0.17.5] - 2021-10-20
|
||||
|
||||
### Added
|
||||
|
||||
- Support for Azure Key Vault as a KMS.
|
||||
- Adapt `pki` package to support key managers.
|
||||
- gocritic linter
|
||||
|
||||
### Fixed
|
||||
|
||||
- gocritic warnings
|
||||
|
||||
## [0.17.4] - 2021-09-28
|
||||
|
||||
### Fixed
|
||||
|
||||
- Support host-only or user-only SSH CA.
|
||||
|
||||
## [0.17.3] - 2021-09-24
|
||||
|
||||
### Added
|
||||
|
||||
- go 1.17 to github action test matrix
|
||||
- Support for CloudKMS RSA-PSS signers without using templates.
|
||||
- Add flags to support individual passwords for the intermediate and SSH keys.
|
||||
- Global support for group admins in the OIDC provisioner.
|
||||
|
||||
### Changed
|
||||
|
||||
- Using go 1.17 for binaries
|
||||
|
||||
### Fixed
|
||||
|
||||
- Upgrade go-jose.v2 to fix a bug in the JWK fingerprint of Ed25519 keys.
|
||||
|
||||
### Security
|
||||
|
||||
- Use cosign to sign and upload signatures for multi-arch Docker container.
|
||||
- Add debian checksum
|
||||
|
||||
## [0.17.2] - 2021-08-30
|
||||
|
||||
### Added
|
||||
|
||||
- Additional way to distinguish Azure IID and Azure OIDC tokens.
|
||||
|
||||
### Security
|
||||
|
||||
- Sign over all goreleaser github artifacts using cosign
|
||||
|
||||
## [0.17.1] - 2021-08-26
|
||||
|
||||
## [0.17.0] - 2021-08-25
|
||||
|
||||
### Added
|
||||
|
||||
- Add support for Linked CAs using protocol buffers and gRPC
|
||||
- `step-ca init` adds support for
|
||||
- configuring a StepCAS RA
|
||||
- configuring a Linked CA
|
||||
- congifuring a `step-ca` using Helm
|
||||
|
||||
### Changed
|
||||
|
||||
- Update badger driver to use v2 by default
|
||||
- Update TLS cipher suites to include 1.3
|
||||
|
||||
### Security
|
||||
|
||||
- Fix key version when SHA512WithRSA is used. There was a typo creating RSA keys with SHA256 digests instead of SHA512.
|
||||
|
|
209
LICENSE
209
LICENSE
|
@ -1,202 +1,13 @@
|
|||
Copyright (c) 2019 Smallstep Labs, Inc.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 Smallstep Labs, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
|
253
Makefile
253
Makefile
|
@ -6,23 +6,18 @@ Q=$(if $V,,@)
|
|||
PREFIX?=
|
||||
SRC=$(shell find . -type f -name '*.go' -not -path "./vendor/*")
|
||||
GOOS_OVERRIDE ?=
|
||||
OUTPUT_ROOT=output/
|
||||
|
||||
all: lint test build
|
||||
all: build test lint
|
||||
|
||||
ci: testcgo build
|
||||
|
||||
.PHONY: all ci
|
||||
.PHONY: all
|
||||
|
||||
#########################################
|
||||
# Bootstrapping
|
||||
#########################################
|
||||
|
||||
bootstra%:
|
||||
$Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $$(go env GOPATH)/bin latest
|
||||
$Q go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
$Q go install gotest.tools/gotestsum@latest
|
||||
$Q go install github.com/goreleaser/goreleaser@latest
|
||||
$Q go install github.com/sigstore/cosign/v2/cmd/cosign@latest
|
||||
$Q GO111MODULE=on go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.18.0
|
||||
|
||||
.PHONY: bootstra%
|
||||
|
||||
|
@ -30,29 +25,23 @@ bootstra%:
|
|||
# Determine the type of `push` and `version`
|
||||
#################################################
|
||||
|
||||
# GITHUB Actions
|
||||
ifdef GITHUB_REF
|
||||
VERSION ?= $(shell echo $(GITHUB_REF) | sed 's/^refs\/tags\///')
|
||||
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
|
||||
ifeq ($(NOT_RC),)
|
||||
PUSHTYPE := release-candidate
|
||||
else
|
||||
PUSHTYPE := release
|
||||
endif
|
||||
else
|
||||
# Version flags to embed in the binaries
|
||||
VERSION ?= $(shell [ -d .git ] && git describe --tags --always --dirty="-dev")
|
||||
# If we are not in an active git dir then try reading the version from .VERSION.
|
||||
# .VERSION contains a slug populated by `git archive`.
|
||||
VERSION := $(or $(VERSION),$(shell ./.version.sh .VERSION))
|
||||
PUSHTYPE := branch
|
||||
endif
|
||||
|
||||
VERSION := $(shell echo $(VERSION) | sed 's/^v//')
|
||||
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
|
||||
|
||||
ifdef V
|
||||
$(info GITHUB_REF is $(GITHUB_REF))
|
||||
$(info VERSION is $(VERSION))
|
||||
$(info PUSHTYPE is $(PUSHTYPE))
|
||||
# If TRAVIS_TAG is set then we know this ref has been tagged.
|
||||
ifdef TRAVIS_TAG
|
||||
ifeq ($(NOT_RC),)
|
||||
PUSHTYPE=release-candidate
|
||||
else
|
||||
PUSHTYPE=release
|
||||
endif
|
||||
else
|
||||
PUSHTYPE=master
|
||||
endif
|
||||
|
||||
#########################################
|
||||
|
@ -61,23 +50,7 @@ endif
|
|||
|
||||
DATE := $(shell date -u '+%Y-%m-%d %H:%M UTC')
|
||||
LDFLAGS := -ldflags='-w -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"'
|
||||
|
||||
# Always explicitly enable or disable cgo,
|
||||
# so that go doesn't silently fall back on
|
||||
# non-cgo when gcc is not found.
|
||||
ifeq (,$(findstring CGO_ENABLED,$(GO_ENVS)))
|
||||
ifneq ($(origin GOFLAGS),undefined)
|
||||
# This section is for backward compatibility with
|
||||
#
|
||||
# $ make build GOFLAGS=""
|
||||
#
|
||||
# which is how we recommended building step-ca with cgo support
|
||||
# until June 2023.
|
||||
GO_ENVS := $(GO_ENVS) CGO_ENABLED=1
|
||||
else
|
||||
GO_ENVS := $(GO_ENVS) CGO_ENABLED=0
|
||||
endif
|
||||
endif
|
||||
GOFLAGS := CGO_ENABLED=0
|
||||
|
||||
download:
|
||||
$Q go mod download
|
||||
|
@ -87,10 +60,13 @@ build: $(PREFIX)bin/$(BINNAME)
|
|||
|
||||
$(PREFIX)bin/$(BINNAME): download $(call rwildcard,*.go)
|
||||
$Q mkdir -p $(@D)
|
||||
$Q $(GOOS_OVERRIDE) GOFLAGS="$(GOFLAGS)" $(GO_ENVS) go build -v -o $(PREFIX)bin/$(BINNAME) $(LDFLAGS) $(PKG)
|
||||
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(BINNAME) $(LDFLAGS) $(PKG)
|
||||
|
||||
# Target to force a build of step-ca without running tests
|
||||
simple: build
|
||||
simple:
|
||||
$Q mkdir -p $(PREFIX)bin
|
||||
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(BINNAME) $(LDFLAGS) $(PKG)
|
||||
@echo "Build Complete!"
|
||||
|
||||
.PHONY: download build simple
|
||||
|
||||
|
@ -106,26 +82,15 @@ generate:
|
|||
#########################################
|
||||
# Test
|
||||
#########################################
|
||||
test: testdefault testtpmsimulator combinecoverage
|
||||
test:
|
||||
$Q $(GOFLAGS) go test -short -coverprofile=coverage.out ./...
|
||||
|
||||
testdefault:
|
||||
$Q $(GO_ENVS) gotestsum -- -coverprofile=defaultcoverage.out -short -covermode=atomic ./...
|
||||
|
||||
testtpmsimulator:
|
||||
$Q CGO_ENABLED=1 gotestsum -- -coverprofile=tpmsimulatorcoverage.out -short -covermode=atomic -tags tpmsimulator ./acme
|
||||
|
||||
testcgo:
|
||||
$Q gotestsum -- -coverprofile=coverage.out -short -covermode=atomic ./...
|
||||
|
||||
combinecoverage:
|
||||
cat defaultcoverage.out tpmsimulatorcoverage.out > coverage.out
|
||||
|
||||
.PHONY: test testdefault testtpmsimulator testcgo combinecoverage
|
||||
.PHONY: test
|
||||
|
||||
integrate: integration
|
||||
|
||||
integration: bin/$(BINNAME)
|
||||
$Q $(GO_ENVS) gotestsum -- -tags=integration ./integration/...
|
||||
$Q $(GOFLAGS) go test -tags=integration ./integration/...
|
||||
|
||||
.PHONY: integrate integration
|
||||
|
||||
|
@ -134,14 +99,12 @@ integration: bin/$(BINNAME)
|
|||
#########################################
|
||||
|
||||
fmt:
|
||||
$Q goimports -l -w $(SRC)
|
||||
$Q gofmt -l -w $(SRC)
|
||||
|
||||
lint: SHELL:=/bin/bash
|
||||
lint:
|
||||
$Q LOG_LEVEL=error golangci-lint run --config <(curl -s https://raw.githubusercontent.com/smallstep/workflows/master/.golangci.yml) --timeout=30m
|
||||
$Q govulncheck ./...
|
||||
$Q LOG_LEVEL=error golangci-lint run
|
||||
|
||||
.PHONY: fmt lint
|
||||
.PHONY: lint fmt
|
||||
|
||||
#########################################
|
||||
# Install
|
||||
|
@ -169,11 +132,163 @@ endif
|
|||
.PHONY: clean
|
||||
|
||||
#########################################
|
||||
# Dev
|
||||
# Building Docker Image
|
||||
#
|
||||
# Builds a dockerfile for step by building a linux version of the step-cli and
|
||||
# then copying the specific binary when building the container.
|
||||
#
|
||||
# This ensures the container is as small as possible without having to deal
|
||||
# with getting access to private repositories inside the container during build
|
||||
# time.
|
||||
#########################################
|
||||
|
||||
run:
|
||||
$Q go run cmd/step-ca/main.go $(shell step path)/config/ca.json
|
||||
# XXX We put the output for the build in 'output' so we don't mess with how we
|
||||
# do rule overriding from the base Makefile (if you name it 'build' it messes up
|
||||
# the wildcarding).
|
||||
DOCKER_OUTPUT=$(OUTPUT_ROOT)docker/
|
||||
|
||||
.PHONY: run
|
||||
DOCKER_MAKE=V=$V GOOS_OVERRIDE='GOOS=linux GOARCH=amd64' PREFIX=$(1) make $(1)bin/$(2)
|
||||
DOCKER_BUILD=$Q docker build -t smallstep/$(1):latest -f docker/$(2) --build-arg BINPATH=$(DOCKER_OUTPUT)bin/$(1) .
|
||||
|
||||
docker: docker-make docker/Dockerfile.step-ca
|
||||
$(call DOCKER_BUILD,step-ca,Dockerfile.step-ca)
|
||||
|
||||
docker-make:
|
||||
mkdir -p $(DOCKER_OUTPUT)
|
||||
$(call DOCKER_MAKE,$(DOCKER_OUTPUT),step-ca)
|
||||
|
||||
.PHONY: docker docker-make
|
||||
|
||||
#################################################
|
||||
# Releasing Docker Images
|
||||
#
|
||||
# Using the docker build infrastructure, this section is responsible for
|
||||
# logging into docker hub and pushing the built docker containers up with the
|
||||
# appropriate tags.
|
||||
#################################################
|
||||
|
||||
DOCKER_TAG=docker tag smallstep/$(1):latest smallstep/$(1):$(2)
|
||||
DOCKER_PUSH=docker push smallstep/$(1):$(2)
|
||||
|
||||
docker-tag:
|
||||
$(call DOCKER_TAG,step-ca,$(VERSION))
|
||||
|
||||
docker-push-tag: docker-tag
|
||||
$(call DOCKER_PUSH,step-ca,$(VERSION))
|
||||
|
||||
docker-push-tag-latest:
|
||||
$(call DOCKER_PUSH,step-ca,latest)
|
||||
|
||||
# Rely on DOCKER_USERNAME and DOCKER_PASSWORD being set inside the CI or
|
||||
# equivalent environment
|
||||
docker-login:
|
||||
$Q docker login -u="$(DOCKER_USERNAME)" -p="$(DOCKER_PASSWORD)"
|
||||
|
||||
.PHONY: docker-login docker-tag docker-push-tag docker-push-tag-latest
|
||||
|
||||
#################################################
|
||||
# Targets for pushing the docker images
|
||||
#################################################
|
||||
|
||||
# For all builds we build the docker container
|
||||
docker-master: docker
|
||||
|
||||
# For all builds with a release candidate tag
|
||||
docker-release-candidate: docker-master docker-login docker-push-tag
|
||||
|
||||
# For all builds with a release tag
|
||||
docker-release: docker-release-candidate docker-push-tag-latest
|
||||
|
||||
.PHONY: docker-master docker-release-candidate docker-release
|
||||
|
||||
#########################################
|
||||
# Debian
|
||||
#########################################
|
||||
|
||||
changelog:
|
||||
$Q echo "step-certificates ($(VERSION)) unstable; urgency=medium" > debian/changelog
|
||||
$Q echo >> debian/changelog
|
||||
$Q echo " * See https://github.com/smallstep/certificates/releases" >> debian/changelog
|
||||
$Q echo >> debian/changelog
|
||||
$Q echo " -- Smallstep Labs, Inc. <techadmin@smallstep.com> $(shell date -uR)" >> debian/changelog
|
||||
|
||||
debian: changelog
|
||||
$Q mkdir -p $(RELEASE); \
|
||||
OUTPUT=../step-certificates_*.deb; \
|
||||
rm $$OUTPUT; \
|
||||
dpkg-buildpackage -b -rfakeroot -us -uc && cp $$OUTPUT $(RELEASE)/
|
||||
|
||||
distclean: clean
|
||||
|
||||
.PHONY: changelog debian distclean
|
||||
|
||||
#################################################
|
||||
# Build statically compiled step binary for various operating systems
|
||||
#################################################
|
||||
|
||||
BINARY_OUTPUT=$(OUTPUT_ROOT)binary/
|
||||
BUNDLE_MAKE=v=$v GOOS_OVERRIDE='GOOS=$(1) GOARCH=$(2)' PREFIX=$(3) make $(3)bin/$(BINNAME)
|
||||
RELEASE=./.travis-releases
|
||||
|
||||
binary-linux:
|
||||
$(call BUNDLE_MAKE,linux,amd64,$(BINARY_OUTPUT)linux/)
|
||||
|
||||
binary-darwin:
|
||||
$(call BUNDLE_MAKE,darwin,amd64,$(BINARY_OUTPUT)darwin/)
|
||||
|
||||
define BUNDLE
|
||||
$(q)BUNDLE_DIR=$(BINARY_OUTPUT)$(1)/bundle; \
|
||||
stepName=step-certificates_$(2); \
|
||||
mkdir -p $$BUNDLE_DIR $(RELEASE); \
|
||||
TMP=$$(mktemp -d $$BUNDLE_DIR/tmp.XXXX); \
|
||||
trap "rm -rf $$TMP" EXIT INT QUIT TERM; \
|
||||
newdir=$$TMP/$$stepName; \
|
||||
mkdir -p $$newdir/bin; \
|
||||
cp $(BINARY_OUTPUT)$(1)/bin/$(BINNAME) $$newdir/bin/; \
|
||||
cp README.md $$newdir/; \
|
||||
NEW_BUNDLE=$(RELEASE)/step-certificates_$(2)_$(1)_$(3).tar.gz; \
|
||||
rm -f $$NEW_BUNDLE; \
|
||||
tar -zcvf $$NEW_BUNDLE -C $$TMP $$stepName;
|
||||
endef
|
||||
|
||||
bundle-linux: binary-linux
|
||||
$(call BUNDLE,linux,$(VERSION),amd64)
|
||||
|
||||
bundle-darwin: binary-darwin
|
||||
$(call BUNDLE,darwin,$(VERSION),amd64)
|
||||
|
||||
.PHONY: binary-linux binary-darwin bundle-linux bundle-darwin
|
||||
|
||||
#################################################
|
||||
# Targets for creating OS specific artifacts and archives
|
||||
#################################################
|
||||
|
||||
artifacts-linux-tag: bundle-linux debian
|
||||
|
||||
artifacts-darwin-tag: bundle-darwin
|
||||
|
||||
artifacts-archive-tag:
|
||||
$Q mkdir -p $(RELEASE)
|
||||
$Q git archive v$(VERSION) | gzip > $(RELEASE)/step-certificates_$(VERSION).tar.gz
|
||||
|
||||
artifacts-tag: artifacts-linux-tag artifacts-darwin-tag artifacts-archive-tag
|
||||
|
||||
.PHONY: artifacts-linux-tag artifacts-darwin-tag artifacts-archive-tag artifacts-tag
|
||||
|
||||
#################################################
|
||||
# Targets for creating step artifacts
|
||||
#################################################
|
||||
|
||||
# For all builds that are not tagged
|
||||
artifacts-master:
|
||||
|
||||
# For all builds with a release-candidate (-rc) tag
|
||||
artifacts-release-candidate: artifacts-tag
|
||||
|
||||
# For all builds with a release tag
|
||||
artifacts-release: artifacts-tag
|
||||
|
||||
# This command is called by travis directly *after* a successful build
|
||||
artifacts: artifacts-$(PUSHTYPE) docker-$(PUSHTYPE)
|
||||
|
||||
.PHONY: artifacts-master artifacts-release-candidate artifacts-release artifacts
|
||||
|
|
430
README.md
430
README.md
|
@ -1,131 +1,367 @@
|
|||
# Step Certificates
|
||||
|
||||
`step-ca` is an online certificate authority for secure, automated certificate management. It's the server counterpart to the [`step` CLI tool](https://github.com/smallstep/cli).
|
||||
An online certificate authority and related tools for secure automated certificate management, so you can use TLS everywhere.
|
||||
|
||||
You can use it to:
|
||||
- Issue X.509 certificates for your internal infrastructure:
|
||||
- HTTPS certificates that [work in browsers](https://smallstep.com/blog/step-v0-8-6-valid-HTTPS-certificates-for-dev-pre-prod.html) ([RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliance)
|
||||
- TLS certificates for VMs, containers, APIs, mobile clients, database connections, printers, wifi networks, toaster ovens...
|
||||
- Client certificates to [enable mutual TLS (mTLS)](https://smallstep.com/hello-mtls) in your infra. mTLS is an optional feature in TLS where both client and server authenticate each other. Why add the complexity of a VPN when you can safely use mTLS over the public internet?
|
||||
- Issue SSH certificates:
|
||||
- For people, in exchange for single sign-on ID tokens
|
||||
- For hosts, in exchange for cloud instance identity documents
|
||||
- Easily automate certificate management:
|
||||
- It's an ACME v2 server
|
||||
- It has a JSON API
|
||||
- It comes with a [Go wrapper](./examples#user-content-basic-client-usage)
|
||||
- ... and there's a [command-line client](https://github.com/smallstep/cli) you can use in scripts!
|
||||
This repository is for `step-ca`, a certificate authority that exposes an API for automated certificate management. It also contains a [golang SDK](https://github.com/smallstep/certificates/tree/master/examples#basic-client-usage) for interacting with `step-ca` programatically. However, you'll probably want to use the [`step` command-line tool](https://github.com/smallstep/cli) to operate `step-ca` and get certificates, instead of using this low-level SDK directly.
|
||||
|
||||
Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [safe, sane defaults](https://smallstep.com/docs/step-ca/certificate-authority-server-production#sane-cryptographic-defaults).
|
||||
**Questions? Find us [on gitter](https://gitter.im/smallstep/community).**
|
||||
|
||||
---
|
||||
[Website](https://smallstep.com) |
|
||||
[Documentation](#documentation) |
|
||||
[Installation Guide](#installation-guide) |
|
||||
[Getting Started](./docs/GETTING_STARTED.md) |
|
||||
[Contribution Guide](./docs/CONTRIBUTING.md)
|
||||
|
||||
**Don't want to run your own CA?**
|
||||
To get up and running quickly, or as an alternative to running your own `step-ca` server, consider creating a [free hosted smallstep Certificate Manager authority](https://info.smallstep.com/certificate-manager-early-access-mvp/).
|
||||
|
||||
---
|
||||
|
||||
**Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions) or [Join our Discord](https://u.step.sm/discord).**
|
||||
|
||||
[Website](https://smallstep.com/certificates) |
|
||||
[Documentation](https://smallstep.com/docs) |
|
||||
[Installation](https://smallstep.com/docs/step-ca/installation) |
|
||||
[Getting Started](https://smallstep.com/docs/step-ca/getting-started) |
|
||||
[Contributor's Guide](./docs/CONTRIBUTING.md)
|
||||
|
||||
[](https://github.com/smallstep/certificates/releases/latest)
|
||||
[](https://github.com/smallstep/certificates/releases)
|
||||
[](https://gitter.im/smallstep/community)
|
||||
[](https://microbadger.com/images/smallstep/step-ca)
|
||||
[](https://goreportcard.com/report/github.com/smallstep/certificates)
|
||||
[](https://github.com/smallstep/certificates)
|
||||
[](https://travis-ci.com/smallstep/certificates)
|
||||
[](https://opensource.org/licenses/Apache-2.0)
|
||||
[](https://cla-assistant.io/smallstep/certificates)
|
||||
|
||||
[](https://github.com/smallstep/certificates/stargazers)
|
||||
[](https://twitter.com/intent/follow?screen_name=smallsteplabs)
|
||||
|
||||

|
||||

|
||||
|
||||
## Features
|
||||
|
||||
### 🦾 A fast, stable, flexible private CA
|
||||
It's super easy to get started and to operate `step-ca` thanks to [streamlined initialization](https://github.com/smallstep/certificates#lets-get-started) and [safe, sane defaults](https://github.com/smallstep/certificates/blob/master/docs/defaults.md). **Get started in 15 minutes.**
|
||||
|
||||
Setting up a *public key infrastructure* (PKI) is out of reach for many small teams. `step-ca` makes it easier.
|
||||
### A private certificate authority you run yourself
|
||||
|
||||
- Choose key types (RSA, ECDSA, EdDSA) and lifetimes to suit your needs
|
||||
- [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation
|
||||
- Capable of high availability (HA) deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries
|
||||
- Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca)
|
||||
- [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
|
||||
- Issue client and server certificates to VMs, containers, devices, and people using internal hostnames and emails
|
||||
- [RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliant certificates that work **for TLS and HTTPS**
|
||||
- Choose key types (RSA, ECDSA, EdDSA) & lifetimes to suit your needs
|
||||
- [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with **fully automated** enrollment, renewal, and revocation
|
||||
- Fast, stable, and capable of high availability deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries
|
||||
- Operate as an online intermediate for an existing root CA
|
||||
- [Pluggable database backends](https://github.com/smallstep/certificates/blob/master/docs/database.md) for persistence
|
||||
- [Helm charts](https://hub.helm.sh/charts/smallstep/step-certificates), [autocert](https://github.com/smallstep/autocert), and [cert-manager integration](https://github.com/smallstep/step-issuer) for kubernetes
|
||||
|
||||
### ⚙️ Many ways to automate
|
||||
### Lots of (automatable) ways to get certificates
|
||||
|
||||
There are several ways to authorize a request with the CA and establish a chain of trust that suits your flow.
|
||||
- [Single sign-on](https://smallstep.com/blog/easily-curl-services-secured-by-https-tls.html) using Okta, GSuite, Active Directory, or any other OAuth OIDC identity provider
|
||||
- [Instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/) for VMs on AWS, GCP, and Azure
|
||||
- [Single-use short-lived tokens](https://smallstep.com/docs/design-doc.html#jwk-provisioner) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
|
||||
- Use an existing certificate from another CA (e.g., using a device certificate like [Twilio's Trust OnBoard](https://www.twilio.com/wireless/trust-onboard)) *coming soon*
|
||||
|
||||
You can issue certificates in exchange for:
|
||||
- [ACME challenge responses](#your-own-private-acme-server) from any ACMEv2 client
|
||||
- [OAuth OIDC single sign-on tokens](https://smallstep.com/blog/easily-curl-services-secured-by-https-tls.html), eg:
|
||||
- ID tokens from Okta, GSuite, Azure AD, Auth0.
|
||||
- ID tokens from an OAuth OIDC service that you host, like [Keycloak](https://www.keycloak.org/) or [Dex](https://github.com/dexidp/dex)
|
||||
- [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure
|
||||
- [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
|
||||
- A trusted X.509 certificate (X5C provisioner)
|
||||
- A host certificate from your Nebula network
|
||||
- A SCEP challenge (SCEP provisioner)
|
||||
- An SSH host certificates needing renewal (the SSHPOP provisioner)
|
||||
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners)
|
||||
### [Your own private ACME Server](https://smallstep.com/blog/private-acme-server/)
|
||||
- Issue certificates using ACMEv2 ([RFC8555](https://tools.ietf.org/html/rfc8555)), **the protocol used by Let's Encrypt**
|
||||
- Great for [using ACME in development & pre-production](https://smallstep.com/blog/private-acme-server/#local-development-pre-production)
|
||||
- Supports the `http-01` and `dns-01` ACME challenge types
|
||||
- Works with any compliant ACME client including [certbot](https://smallstep.com/blog/private-acme-server/#certbot-uploads-acme-certbot-png-certbot-example), [acme.sh](https://smallstep.com/blog/private-acme-server/#acme-sh-uploads-acme-acme-sh-png-acme-sh-example), [Caddy](https://smallstep.com/blog/private-acme-server/#caddy-uploads-acme-caddy-png-caddy-example), and [traefik](https://smallstep.com/blog/private-acme-server/#traefik-uploads-acme-traefik-png-traefik-example)
|
||||
- Get certificates programmatically (e.g., in [Go](https://smallstep.com/blog/private-acme-server/#golang-uploads-acme-golang-png-go-example), [Python](https://smallstep.com/blog/private-acme-server/#python-uploads-acme-python-png-python-example), [Node.js](https://smallstep.com/blog/private-acme-server/#node-js-uploads-acme-node-js-png-node-js-example))
|
||||
|
||||
### 🏔 Your own private ACME server
|
||||
### [SSH Certificates](https://smallstep.com/blog/use-ssh-certificates/)
|
||||
|
||||
ACME is the protocol used by Let's Encrypt to automate the issuance of HTTPS certificates. It's _super easy_ to issue certificates to any ACMEv2 ([RFC8555](https://tools.ietf.org/html/rfc8555)) client.
|
||||
- Use [certificate authentication for SSH](https://smallstep.com/blog/use-ssh-certificates/): connect SSH to SSO, improve security, and eliminate warnings & errors
|
||||
- Issue SSH user certificates using OAuth OIDC
|
||||
- Issue SSH host certificates to cloud VMs using instance identity documents
|
||||
|
||||
- [Use ACME in development & pre-production](https://smallstep.com/blog/private-acme-server/#local-development--pre-production)
|
||||
- Supports the most popular [ACME challenge types](https://letsencrypt.org/docs/challenge-types/):
|
||||
- For `http-01`, place a token at a well-known URL to prove that you control the web server
|
||||
- For `dns-01`, add a `TXT` record to prove that you control the DNS record set
|
||||
- For `tls-alpn-01`, respond to the challenge at the TLS layer ([as Caddy does](https://caddy.community/t/caddy-supports-the-acme-tls-alpn-challenge/4860)) to prove that you control the web server
|
||||
|
||||
- Works with any ACME client. We've written examples for:
|
||||
- [certbot](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#certbot)
|
||||
- [acme.sh](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#acmesh)
|
||||
- [win-acme](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#win-acme)
|
||||
- [Caddy](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#caddy-v2)
|
||||
- [Traefik](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#traefik)
|
||||
- [Apache](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#apache)
|
||||
- [nginx](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#nginx)
|
||||
- Get certificates programmatically using ACME, using these libraries:
|
||||
- [`lego`](https://github.com/go-acme/lego) for Golang ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#golang))
|
||||
- certbot's [`acme` module](https://github.com/certbot/certbot/tree/master/acme) for Python ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#python))
|
||||
- [`acme-client`](https://github.com/publishlab/node-acme-client) for Node.js ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#node))
|
||||
- Our own [`step` CLI tool](https://github.com/smallstep/cli) is also an ACME client!
|
||||
- See our [ACME tutorial](https://smallstep.com/docs/tutorials/acme-challenge) for more
|
||||
|
||||
### 👩🏽💻 An online SSH Certificate Authority
|
||||
|
||||
- Delegate SSH authentication to `step-ca` by using [SSH certificates](https://smallstep.com/blog/use-ssh-certificates/) instead of public keys and `authorized_keys` files
|
||||
- For user certificates, [connect SSH to your single sign-on provider](https://smallstep.com/blog/diy-single-sign-on-for-ssh/), to improve security with short-lived certificates and MFA (or other security policies) via any OAuth OIDC provider.
|
||||
- For host certificates, improve security, [eliminate TOFU warnings](https://smallstep.com/blog/use-ssh-certificates/), and set up automated host certificate renewal.
|
||||
|
||||
### 🤓 A general purpose PKI tool, via [`step` CLI](https://github.com/smallstep/cli) [integration](https://smallstep.com/docs/step-cli/reference/ca/)
|
||||
### Easy certificate management and automation via [`step` CLI](https://github.com/smallstep/cli) [integration](https://smallstep.com/docs/cli/ca/)
|
||||
|
||||
- Generate key pairs where they're needed so private keys are never transmitted across the network
|
||||
- [Authenticate and obtain a certificate](https://smallstep.com/docs/step-cli/reference/ca/certificate/) using any provisioner supported by `step-ca`
|
||||
- Securely [distribute root certificates](https://smallstep.com/docs/step-cli/reference/ca/root/) and [bootstrap](https://smallstep.com/docs/step-cli/reference/ca/bootstrap/) PKI relying parties
|
||||
- [Renew](https://smallstep.com/docs/step-cli/reference/ca/renew/) and [revoke](https://smallstep.com/docs/step-cli/reference/ca/revoke/) certificates issued by `step-ca`
|
||||
- [Install root certificates](https://smallstep.com/docs/step-cli/reference/certificate/install/) on your machine and browsers, so your CA is trusted
|
||||
- [Inspect](https://smallstep.com/docs/step-cli/reference/certificate/inspect/) and [lint](https://smallstep.com/docs/step-cli/reference/certificate/lint/) certificates
|
||||
- [Authenticate and obtain a certificate](https://smallstep.com/docs/cli/ca/certificate/) using any enrollment mechanism supported by `step-ca`
|
||||
- Securely [distribute root certificates](https://smallstep.com/docs/cli/ca/root/) and [bootstrap](https://smallstep.com/docs/cli/ca/bootstrap/) PKI relying parties
|
||||
- [Renew](https://smallstep.com/docs/cli/ca/renew/) and [revoke](https://smallstep.com/docs/cli/ca/revoke/) certificates issued by `step-ca`
|
||||
- [Install root certificates](https://smallstep.com/docs/cli/certificate/install/) so your CA is trusted by default (issue development certificates **that [work in browsers](https://smallstep.com/blog/step-v0-8-6-valid-HTTPS-certificates-for-dev-pre-prod.html)**)
|
||||
- [Inspect](https://smallstep.com/docs/cli/certificate/inspect/) and [lint](https://smallstep.com/docs/cli/certificate/lint/) certificates
|
||||
|
||||
## Installation
|
||||
## Motivation
|
||||
|
||||
See our installation docs [here](https://smallstep.com/docs/step-ca/installation).
|
||||
Managing your own *public key infrastructure* (PKI) can be tedious and error
|
||||
prone. Good security hygiene is hard. Setting up simple PKI is out of reach for
|
||||
many small teams, and following best practices like proper certificate
|
||||
revocation and rolling is challenging even for experts.
|
||||
|
||||
Amongst numerous use cases, proper PKI makes it easy to use mTLS (mutual TLS)
|
||||
to improve security and to make it possible to connect services across the
|
||||
public internet. Unlike VPNs & SDNs, deploying and scaling mTLS is pretty
|
||||
easy. You're (hopefully) already using TLS, and your existing tools and
|
||||
standard libraries will provide most of what you need. If you know how to
|
||||
operate DNS and reverse proxies, you know how to operate mTLS
|
||||
infrastructure.
|
||||
|
||||

|
||||
|
||||
There's just one problem: **you need certificates issued by your own
|
||||
certificate authority (CA)**. Building and operating a CA, issuing
|
||||
certificates, and making sure they're renewed before they expire is tricky.
|
||||
This project provides the infrastructure, automations, and workflows you'll
|
||||
need.
|
||||
|
||||
`step certificates` is part of smallstep's broader security architecture, which
|
||||
makes it much easier to implement good security practices early, and
|
||||
incrementally improve them as your system matures.
|
||||
|
||||
For more information and [docs](https://smallstep.com/docs) see [the smallstep
|
||||
website](https://smallstep.com/certificates) and the [blog
|
||||
post](https://smallstep.com/blog/step-certificates.html) announcing this project.
|
||||
|
||||
## Installation Guide
|
||||
|
||||
These instructions will install an OS specific version of the `step-ca` binary on
|
||||
your local machine.
|
||||
|
||||
While `step` is not required to run `step-ca`, it will make your life easier so you'll probably want to [install it](https://github.com/smallstep/cli#installation-guide) too.
|
||||
|
||||
### Mac OS
|
||||
|
||||
Install `step` and `step-ca` together via [Homebrew](https://brew.sh/):
|
||||
|
||||
<pre><code><b>$ brew install step</b>
|
||||
|
||||
# Test installation ...
|
||||
<b>$ step certificate inspect https://smallstep.com</b>
|
||||
Certificate:
|
||||
Data:
|
||||
Version: 3 (0x2)
|
||||
Serial Number: 326381749415081530968054238478851085504954 (0x3bf265673332db2d0c70e48a163fb7d11ba)
|
||||
Signature Algorithm: SHA256-RSA
|
||||
Issuer: C=US,O=Let's Encrypt,CN=Let's Encrypt Authority X3
|
||||
...</code></pre>
|
||||
|
||||
> Note: If you have installed `step` previously through the `smallstep/smallstep`
|
||||
> tap you will need to run the following commands before installing:
|
||||
>
|
||||
> ```
|
||||
> $ brew untap smallstep/smallstep
|
||||
> $ brew uninstall step
|
||||
> ```
|
||||
|
||||
### Linux
|
||||
|
||||
#### Debian
|
||||
|
||||
1. [Optional] Install `step`.
|
||||
|
||||
Download the latest Debian package from
|
||||
[`step` releases](https://github.com/smallstep/cli/releases):
|
||||
|
||||
```
|
||||
$ wget https://github.com/smallstep/cli/releases/download/vX.Y.Z/step-cli_X.Y.Z_amd64.deb
|
||||
```
|
||||
|
||||
Install the Debian package:
|
||||
|
||||
```
|
||||
$ sudo dpkg -i step-cli_X.Y.Z_amd64.deb
|
||||
```
|
||||
|
||||
2. Install `step-ca`.
|
||||
|
||||
Download the latest Debian package from [releases](https://github.com/smallstep/certificates/releases):
|
||||
|
||||
```
|
||||
$ wget https://github.com/smallstep/certificates/releases/download/vX.Y.Z/step-certificates_X.Y.Z_amd64.deb
|
||||
```
|
||||
|
||||
Install the Debian package:
|
||||
|
||||
```
|
||||
$ sudo dpkg -i step-certificates_X.Y.Z_amd64.deb
|
||||
```
|
||||
|
||||
#### Arch Linux
|
||||
|
||||
We are using the [Arch User Repository](https://aur.archlinux.org) to distribute
|
||||
`step` binaries for Arch Linux.
|
||||
|
||||
* [Optional] The `step` binary tarball can be found [here](https://aur.archlinux.org/packages/step-cli-bin/).
|
||||
* The `step-ca` binary tarball can be found [here](https://aur.archlinux.org/packages/step-ca-bin/).
|
||||
|
||||
You can use [pacman](https://www.archlinux.org/pacman/) to install the packages.
|
||||
|
||||
### Kubernetes
|
||||
|
||||
We publish [helm charts](https://hub.helm.sh/charts/smallstep/step-certificates) for easy installation on kubernetes:
|
||||
|
||||
```
|
||||
helm install step-certificates
|
||||
```
|
||||
|
||||
> <a href="https://github.com/smallstep/autocert"><img width="25%" src="https://raw.githubusercontent.com/smallstep/autocert/master/autocert-logo.png"></a>
|
||||
>
|
||||
> If you're using Kubernetes, make sure you [check out
|
||||
> autocert](https://github.com/smallstep/autocert): a kubernetes add-on that builds on `step
|
||||
> certificates` to automatically inject TLS/HTTPS certificates into your containers.
|
||||
|
||||
### Test
|
||||
|
||||
<pre><code><b>$ step version</b>
|
||||
Smallstep CLI/0.10.0 (darwin/amd64)
|
||||
Release Date: 2019-04-30 19:01 UTC
|
||||
|
||||
<b>$ step-ca version</b>
|
||||
Smallstep CA/0.10.0 (darwin/amd64)
|
||||
Release Date: 2019-04-30 19:02 UTC</code></pre>
|
||||
|
||||
## Quickstart
|
||||
|
||||
In the following guide we'll run a simple `hello` server that requires clients
|
||||
to connect over an authorized and encrypted channel using HTTPS. `step-ca`
|
||||
will issue certificates to our server, allowing it to authenticate and encrypt
|
||||
communication. Let's get started!
|
||||
|
||||
### Prerequisites
|
||||
|
||||
* [`step`](#installation-guide)
|
||||
* [golang](https://golang.org/doc/install)
|
||||
|
||||
### Let's get started!
|
||||
|
||||
#### 1. Run `step ca init` to create your CA's keys & certificates and configure `step-ca`:
|
||||
|
||||
<pre><code><b>$ step ca init</b>
|
||||
✔ What would you like to name your new PKI? (e.g. Smallstep): <b>Example Inc.</b>
|
||||
✔ What DNS names or IP addresses would you like to add to your new CA? (e.g. ca.smallstep.com[,1.1.1.1,etc.]): <b>localhost</b>
|
||||
✔ What address will your new CA listen at? (e.g. :443): <b>127.0.0.1:8080</b>
|
||||
✔ What would you like to name the first provisioner for your new CA? (e.g. you@smallstep.com): <b>bob@example.com</b>
|
||||
✔ What do you want your password to be? [leave empty and we'll generate one]: <b>abc123</b>
|
||||
|
||||
Generating root certificate...
|
||||
all done!
|
||||
|
||||
Generating intermediate certificate...
|
||||
all done!
|
||||
|
||||
✔ Root certificate: /Users/bob/src/github.com/smallstep/step/.step/certs/root_ca.crt
|
||||
✔ Root private key: /Users/bob/src/github.com/smallstep/step/.step/secrets/root_ca_key
|
||||
✔ Root fingerprint: 702a094e239c9eec6f0dcd0a5f65e595bf7ed6614012825c5fe3d1ae1b2fd6ee
|
||||
✔ Intermediate certificate: /Users/bob/src/github.com/smallstep/step/.step/certs/intermediate_ca.crt
|
||||
✔ Intermediate private key: /Users/bob/src/github.com/smallstep/step/.step/secrets/intermediate_ca_key
|
||||
✔ Default configuration: /Users/bob/src/github.com/smallstep/step/.step/config/defaults.json
|
||||
✔ Certificate Authority configuration: /Users/bob/src/github.com/smallstep/step/.step/config/ca.json
|
||||
|
||||
Your PKI is ready to go. To generate certificates for individual services see 'step help ca'.</code></pre>
|
||||
|
||||
This command will:
|
||||
|
||||
- Generate [password protected](https://github.com/smallstep/certificates/blob/master/docs/GETTING_STARTED.md#passwords) private keys for your CA to sign certificates
|
||||
- Generate a root and [intermediate signing certificate](https://security.stackexchange.com/questions/128779/why-is-it-more-secure-to-use-intermediate-ca-certificates) for your CA
|
||||
- Create a JSON configuration file for `step-ca` (see [getting started](./docs/GETTING_STARTED.md) for details)
|
||||
|
||||
You can find these artifacts in `$STEPPATH` (or `~/.step` by default).
|
||||
|
||||
#### 2. Start `step-ca`:
|
||||
|
||||
You'll be prompted for your password from the previous step, to decrypt the CA's private signing key:
|
||||
|
||||
<pre><code><b>$ step-ca $(step path)/config/ca.json</b>
|
||||
Please enter the password to decrypt /Users/bob/src/github.com/smallstep/step/.step/secrets/intermediate_ca_key: <b>abc123</b>
|
||||
2019/02/18 13:28:58 Serving HTTPS on 127.0.0.1:8080 ...</code></pre>
|
||||
|
||||
#### 3. Copy our `hello world` golang server.
|
||||
|
||||
```
|
||||
$ cat > srv.go <<EOF
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"log"
|
||||
)
|
||||
|
||||
func HiHandler(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Write([]byte("Hello, world!\n"))
|
||||
}
|
||||
|
||||
func main() {
|
||||
http.HandleFunc("/hi", HiHandler)
|
||||
err := http.ListenAndServeTLS(":8443", "srv.crt", "srv.key", nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
EOF
|
||||
```
|
||||
|
||||
#### 4. Get an identity for your server from the Step CA.
|
||||
|
||||
<pre><code><b>$ step ca certificate localhost srv.crt srv.key</b>
|
||||
✔ Key ID: rQxROEr7Kx9TNjSQBTETtsu3GKmuW9zm02dMXZ8GUEk (bob@example.com)
|
||||
✔ Please enter the password to decrypt the provisioner key: abc123
|
||||
✔ CA: https://localhost:8080/1.0/sign
|
||||
✔ Certificate: srv.crt
|
||||
✔ Private Key: srv.key
|
||||
|
||||
<b>$ step certificate inspect --bundle srv.crt</b>
|
||||
Certificate:
|
||||
Data:
|
||||
Version: 3 (0x2)
|
||||
Serial Number: 140439335711218707689123407681832384336 (0x69a7a1d7f6f22f68059d2d9088307750)
|
||||
Signature Algorithm: ECDSA-SHA256
|
||||
Issuer: CN=Example Inc. Intermediate CA
|
||||
Validity
|
||||
Not Before: Feb 18 21:32:35 2019 UTC
|
||||
Not After : Feb 19 21:32:35 2019 UTC
|
||||
Subject: CN=localhost
|
||||
...
|
||||
Certificate:
|
||||
Data:
|
||||
Version: 3 (0x2)
|
||||
Serial Number: 207035091234452090159026162349261226844 (0x9bc18217bd560cf07db23178ed90835c)
|
||||
Signature Algorithm: ECDSA-SHA256
|
||||
Issuer: CN=Example Inc. Root CA
|
||||
Validity
|
||||
Not Before: Feb 18 21:27:21 2019 UTC
|
||||
Not After : Feb 15 21:27:21 2029 UTC
|
||||
Subject: CN=Example Inc. Intermediate CA
|
||||
...</code></pre>
|
||||
|
||||
Note that `step` and `step-ca` handle details like [certificate bundling](https://smallstep.com/blog/everything-pki.html#intermediates-chains-and-bundling) for you.
|
||||
|
||||
#### 5. Run the simple server.
|
||||
|
||||
<pre><code><b>$ go run srv.go &</b></code></pre>
|
||||
|
||||
#### 6. Get the root certificate from the Step CA.
|
||||
|
||||
In a new Terminal window:
|
||||
|
||||
<pre><code><b>$ step ca root root.crt</b>
|
||||
The root certificate has been saved in root.crt.</code></pre>
|
||||
|
||||
#### 7. Make an authenticated, encrypted curl request to your server using HTTP over TLS.
|
||||
|
||||
<pre><code><b>$ curl --cacert root.crt https://localhost:8443/hi</b>
|
||||
Hello, world!</code></pre>
|
||||
|
||||
*All Done!*
|
||||
|
||||
Check out the [Getting Started](./docs/GETTING_STARTED.md) guide for more examples
|
||||
and best practices on running Step CA in production.
|
||||
|
||||
## Documentation
|
||||
|
||||
* [Official documentation](https://smallstep.com/docs/step-ca) is on smallstep.com
|
||||
* The `step` command reference is available via `step help`,
|
||||
[on smallstep.com](https://smallstep.com/docs/step-cli/reference/),
|
||||
or by running `step help --http=:8080` from the command line
|
||||
Documentation can be found in a handful of different places:
|
||||
|
||||
1. The [docs](./docs/README.md) sub-repo has an index of documentation and tutorials.
|
||||
|
||||
2. On the command line with `step ca help xxx` where `xxx` is the subcommand
|
||||
you are interested in. Ex: `step help ca provisioners list`.
|
||||
|
||||
3. On the web at https://smallstep.com/docs/certificates.
|
||||
|
||||
4. On your browser by running `step help --http=:8080 ca` from the command line
|
||||
and visiting http://localhost:8080.
|
||||
|
||||
## Feedback?
|
||||
|
||||
* Tell us what you like and don't like about managing your PKI - we're eager to help solve problems in this space.
|
||||
* Tell us about a feature you'd like to see! [Add a feature request Issue](https://github.com/smallstep/certificates/issues/new?assignees=&labels=enhancement%2C+needs+triage&template=enhancement.md&title=), [ask on Discussions](https://github.com/smallstep/certificates/discussions), or hit us up on [Twitter](https://twitter.com/smallsteplabs).
|
||||
## The Future
|
||||
|
||||
We plan to build more tools that facilitate the use and management of zero trust
|
||||
networks.
|
||||
|
||||
* Tell us what you like and don't like about managing your PKI - we're eager to
|
||||
help solve problems in this space.
|
||||
* Tell us what features you'd like to see - open issues or hit us on
|
||||
[Twitter](https://twitter.com/smallsteplabs).
|
||||
|
||||
## Further Reading
|
||||
|
||||
Check out the [Getting Started](https://smallstep.com/docs/getting-started/) guide for more examples
|
||||
and best practices on running Step CA in production.
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
We appreciate any effort to discover and disclose security vulnerabilities responsibly.
|
||||
|
||||
If you would like to report a vulnerability in one of our projects, or have security concerns regarding Smallstep software, please email security@smallstep.com.
|
||||
|
||||
In order for us to best respond to your report, please include any of the following:
|
||||
* Steps to reproduce or proof-of-concept
|
||||
* Any relevant tools, including versions used
|
||||
* Tool output
|
254
acme/account.go
254
acme/account.go
|
@ -1,134 +1,214 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Account is a subset of the internal account type containing only those
|
||||
// attributes required for responses in the ACME protocol.
|
||||
type Account struct {
|
||||
ID string `json:"-"`
|
||||
Key *jose.JSONWebKey `json:"-"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status Status `json:"status"`
|
||||
OrdersURL string `json:"orders"`
|
||||
ExternalAccountBinding interface{} `json:"externalAccountBinding,omitempty"`
|
||||
LocationPrefix string `json:"-"`
|
||||
ProvisionerName string `json:"-"`
|
||||
}
|
||||
|
||||
// GetLocation returns the URL location of the given account.
|
||||
func (a *Account) GetLocation() string {
|
||||
if a.LocationPrefix == "" {
|
||||
return ""
|
||||
}
|
||||
return a.LocationPrefix + a.ID
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Orders string `json:"orders"`
|
||||
ID string `json:"-"`
|
||||
Key *jose.JSONWebKey `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (a *Account) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, WrapErrorISE(err, "error marshaling account for logging")
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the account ID.
|
||||
func (a *Account) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
// GetKey returns the JWK associated with the account.
|
||||
func (a *Account) GetKey() *jose.JSONWebKey {
|
||||
return a.Key
|
||||
}
|
||||
|
||||
// IsValid returns true if the Account is valid.
|
||||
func (a *Account) IsValid() bool {
|
||||
return a.Status == StatusValid
|
||||
}
|
||||
|
||||
// KeyToID converts a JWK to a thumbprint.
|
||||
func KeyToID(jwk *jose.JSONWebKey) (string, error) {
|
||||
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||
// AccountOptions are the options needed to create a new ACME account.
|
||||
type AccountOptions struct {
|
||||
Key *jose.JSONWebKey
|
||||
Contact []string
|
||||
}
|
||||
|
||||
// account represents an ACME account.
|
||||
type account struct {
|
||||
ID string `json:"id"`
|
||||
Created time.Time `json:"created"`
|
||||
Deactivated time.Time `json:"deactivated"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// newAccount returns a new acme account type.
|
||||
func newAccount(db nosql.DB, ops AccountOptions) (*account, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return "", WrapErrorISE(err, "error generating jwk thumbprint")
|
||||
return nil, err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(kid), nil
|
||||
|
||||
a := &account{
|
||||
ID: id,
|
||||
Key: ops.Key,
|
||||
Contact: ops.Contact,
|
||||
Status: "valid",
|
||||
Created: clock.Now(),
|
||||
}
|
||||
return a, a.saveNew(db)
|
||||
}
|
||||
|
||||
// PolicyNames contains ACME account level policy names
|
||||
type PolicyNames struct {
|
||||
DNSNames []string `json:"dns"`
|
||||
IPRanges []string `json:"ips"`
|
||||
// toACME converts the internal Account type into the public acmeAccount
|
||||
// type for presentation in the ACME protocol.
|
||||
func (a *account) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Account, error) {
|
||||
return &Account{
|
||||
Status: a.Status,
|
||||
Contact: a.Contact,
|
||||
Orders: dir.getLink(OrdersByAccountLink, URLSafeProvisionerName(p), true, a.ID),
|
||||
Key: a.Key,
|
||||
ID: a.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// X509Policy contains ACME account level X.509 policy
|
||||
type X509Policy struct {
|
||||
Allowed PolicyNames `json:"allow"`
|
||||
Denied PolicyNames `json:"deny"`
|
||||
AllowWildcardNames bool `json:"allowWildcardNames"`
|
||||
}
|
||||
// save writes the Account to the DB.
|
||||
// If the account is new then the necessary indices will be created.
|
||||
// Else, the account in the DB will be updated.
|
||||
func (a *account) saveNew(db nosql.DB) error {
|
||||
kid, err := keyToID(a.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
kidB := []byte(kid)
|
||||
|
||||
// Policy is an ACME Account level policy
|
||||
type Policy struct {
|
||||
X509 X509Policy `json:"x509"`
|
||||
}
|
||||
|
||||
func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions {
|
||||
if p == nil {
|
||||
// Set the jwkID -> acme account ID index
|
||||
_, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID))
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("key-id to account-id index already exists"))
|
||||
default:
|
||||
if err = a.save(db, nil); err != nil {
|
||||
db.Del(accountByKeyIDTable, kidB)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &policy.X509NameOptions{
|
||||
DNSDomains: p.X509.Allowed.DNSNames,
|
||||
IPRanges: p.X509.Allowed.IPRanges,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions {
|
||||
if p == nil {
|
||||
func (a *account) save(db nosql.DB, old *account) error {
|
||||
var (
|
||||
err error
|
||||
oldB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
||||
}
|
||||
}
|
||||
|
||||
b, err := json.Marshal(*a)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error marshaling new account object")
|
||||
}
|
||||
// Set the Account
|
||||
_, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error storing account"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.New("error storing account; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return &policy.X509NameOptions{
|
||||
DNSDomains: p.X509.Denied.DNSNames,
|
||||
IPRanges: p.X509.Denied.IPRanges,
|
||||
}
|
||||
|
||||
// update updates the acme account object stored in the database if,
|
||||
// and only if, the account has not changed since the last read.
|
||||
func (a *account) update(db nosql.DB, contact []string) (*account, error) {
|
||||
b := *a
|
||||
b.Contact = contact
|
||||
if err := b.save(db, a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
// AreWildcardNamesAllowed returns if wildcard names
|
||||
// like *.example.com are allowed to be signed.
|
||||
// Defaults to false.
|
||||
func (p *Policy) AreWildcardNamesAllowed() bool {
|
||||
if p == nil {
|
||||
return false
|
||||
// deactivate deactivates the acme account.
|
||||
func (a *account) deactivate(db nosql.DB) (*account, error) {
|
||||
b := *a
|
||||
b.Status = StatusDeactivated
|
||||
b.Deactivated = clock.Now()
|
||||
if err := b.save(db, a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.X509.AllowWildcardNames
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
// ExternalAccountKey is an ACME External Account Binding key.
|
||||
type ExternalAccountKey struct {
|
||||
ID string `json:"id"`
|
||||
ProvisionerID string `json:"provisionerID"`
|
||||
Reference string `json:"reference"`
|
||||
AccountID string `json:"-"`
|
||||
HmacKey []byte `json:"-"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
BoundAt time.Time `json:"boundAt,omitempty"`
|
||||
Policy *Policy `json:"policy,omitempty"`
|
||||
}
|
||||
|
||||
// AlreadyBound returns whether this EAK is already bound to
|
||||
// an ACME Account or not.
|
||||
func (eak *ExternalAccountKey) AlreadyBound() bool {
|
||||
return !eak.BoundAt.IsZero()
|
||||
}
|
||||
|
||||
// BindTo binds the EAK to an Account.
|
||||
// It returns an error if it's already bound.
|
||||
func (eak *ExternalAccountKey) BindTo(account *Account) error {
|
||||
if eak.AlreadyBound() {
|
||||
return NewError(ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", eak.ID, eak.AccountID, eak.BoundAt)
|
||||
// getAccountByID retrieves the account with the given ID.
|
||||
func getAccountByID(db nosql.DB, id string) (*account, error) {
|
||||
ab, err := db.Get(accountTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id))
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id))
|
||||
}
|
||||
eak.AccountID = account.ID
|
||||
eak.BoundAt = time.Now()
|
||||
eak.HmacKey = []byte{} // clearing the key bytes; can only be used once
|
||||
return nil
|
||||
|
||||
a := new(account)
|
||||
if err = json.Unmarshal(ab, a); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account"))
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// getAccountByKeyID retrieves Id associated with the given Kid.
|
||||
func getAccountByKeyID(db nosql.DB, kid string) (*account, error) {
|
||||
id, err := db.Get(accountByKeyIDTable, []byte(kid))
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid))
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index"))
|
||||
}
|
||||
return getAccountByID(db, string(id))
|
||||
}
|
||||
|
||||
// getOrderIDsByAccount retrieves a list of Order IDs that were created by the
|
||||
// account.
|
||||
func getOrderIDsByAccount(db nosql.DB, id string) ([]string, error) {
|
||||
b, err := db.Get(ordersByAccountIDTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return []string{}, nil
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", id))
|
||||
}
|
||||
var orderIDs []string
|
||||
if err := json.Unmarshal(b, &orderIDs); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", id))
|
||||
}
|
||||
return orderIDs, nil
|
||||
}
|
||||
|
|
|
@ -1,163 +1,843 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestKeyToID(t *testing.T) {
|
||||
var (
|
||||
defaultDisableRenewal = false
|
||||
globalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}
|
||||
)
|
||||
|
||||
func newProv() provisioner.Interface {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func newAcc() (*account, error) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
return newAccount(mockdb, AccountOptions{
|
||||
Key: jwk, Contact: []string{"foo", "bar"},
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAccountByID(t *testing.T) {
|
||||
type test struct {
|
||||
jwk *jose.JSONWebKey
|
||||
exp string
|
||||
id string
|
||||
db nosql.DB
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/error-generating-thumbprint": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
jwk.Key = "foo"
|
||||
return test{
|
||||
jwk: jwk,
|
||||
err: NewErrorISE("error generating jwk thumbprint: square/go-jose: unknown key type 'string'"),
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
jwk: jwk,
|
||||
exp: base64.RawURLEncoding.EncodeToString(kid),
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if id, err := KeyToID(tc.jwk); err != nil {
|
||||
if acc, err := getAccountByID(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
var k *Error
|
||||
if errors.As(err, &k) {
|
||||
assert.Equals(t, k.Type, tc.err.Type)
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
assert.Equals(t, k.Status, tc.err.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
} else {
|
||||
assert.FatalError(t, errors.New("unexpected error type"))
|
||||
}
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, id, tc.exp)
|
||||
assert.Equals(t, tc.acc.ID, acc.ID)
|
||||
assert.Equals(t, tc.acc.Status, acc.Status)
|
||||
assert.Equals(t, tc.acc.Created, acc.Created)
|
||||
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
||||
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
||||
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccount_GetLocation(t *testing.T) {
|
||||
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
|
||||
func TestGetAccountByKeyID(t *testing.T) {
|
||||
type test struct {
|
||||
acc *Account
|
||||
exp string
|
||||
kid string
|
||||
db nosql.DB
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]test{
|
||||
"empty": {acc: &Account{LocationPrefix: ""}, exp: ""},
|
||||
"not-empty": {acc: &Account{ID: "bar", LocationPrefix: locationPrefix}, exp: locationPrefix + "bar"},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert.Equals(t, tc.acc.GetLocation(), tc.exp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccount_IsValid(t *testing.T) {
|
||||
type test struct {
|
||||
acc *Account
|
||||
exp bool
|
||||
}
|
||||
tests := map[string]test{
|
||||
"valid": {acc: &Account{Status: StatusValid}, exp: true},
|
||||
"invalid": {acc: &Account{Status: StatusInvalid}, exp: false},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert.Equals(t, tc.acc.IsValid(), tc.exp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAccountKey_BindTo(t *testing.T) {
|
||||
boundAt := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
eak *ExternalAccountKey
|
||||
acct *Account
|
||||
err *Error
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
eak: &ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
},
|
||||
acct: &Account{
|
||||
ID: "accountID",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "fail/already-bound",
|
||||
eak: &ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
AccountID: "someAccountID",
|
||||
BoundAt: boundAt,
|
||||
},
|
||||
acct: &Account{
|
||||
ID: "accountID",
|
||||
},
|
||||
err: NewError(ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", "eakID", "someAccountID", boundAt),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
eak := tt.eak
|
||||
acct := tt.acct
|
||||
err := eak.BindTo(acct)
|
||||
wantErr := tt.err != nil
|
||||
gotErr := err != nil
|
||||
if wantErr != gotErr {
|
||||
t.Errorf("ExternalAccountKey.BindTo() error = %v, wantErr %v", err, tt.err)
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/kid-not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("account with key id foo not found: not found")),
|
||||
}
|
||||
if wantErr {
|
||||
assert.NotNil(t, err)
|
||||
var ae *Error
|
||||
if assert.True(t, errors.As(err, &ae)) {
|
||||
assert.Equals(t, ae.Type, tt.err.Type)
|
||||
assert.Equals(t, ae.Detail, tt.err.Detail)
|
||||
assert.Equals(t, ae.Subproblems, tt.err.Subproblems)
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading key-account index: force")),
|
||||
}
|
||||
},
|
||||
"fail/getAccount-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte("foo"))
|
||||
count++
|
||||
return []byte("bar"), nil
|
||||
}
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading account bar: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
kid: acc.Key.KeyID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
var ret []byte
|
||||
switch count {
|
||||
case 0:
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(acc.Key.KeyID))
|
||||
ret = []byte(acc.ID)
|
||||
case 1:
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
ret = b
|
||||
}
|
||||
count++
|
||||
return ret, nil
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Equals(t, eak.AccountID, acct.ID)
|
||||
assert.Equals(t, eak.HmacKey, []byte{})
|
||||
assert.NotNil(t, eak.BoundAt)
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, acc.ID)
|
||||
assert.Equals(t, tc.acc.Status, acc.Status)
|
||||
assert.Equals(t, tc.acc.Created, acc.Created)
|
||||
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
||||
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
||||
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountIDsByAccount(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
res []string
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
res: []string{},
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||
assert.Equals(t, key, []byte("foo"))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
oids := []string{"foo", "bar", "baz"}
|
||||
b, err := json.Marshal(oids)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||
assert.Equals(t, key, []byte("foo"))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
res: oids,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if oids, err := getOrderIDsByAccount(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.res, oids)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountToACME(t *testing.T) {
|
||||
dir := newDirectory("ca.smallstep.com", "acme")
|
||||
prov := newProv()
|
||||
|
||||
type test struct {
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{acc: acc}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acmeAccount, err := tc.acc.toACME(nil, dir, prov)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acmeAccount.ID, tc.acc.ID)
|
||||
assert.Equals(t, acmeAccount.Status, tc.acc.Status)
|
||||
assert.Equals(t, acmeAccount.Contact, tc.acc.Contact)
|
||||
assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID)
|
||||
assert.Equals(t, acmeAccount.Orders,
|
||||
fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s/orders", URLSafeProvisionerName(prov), tc.acc.ID))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSave(t *testing.T) {
|
||||
type test struct {
|
||||
acc, old *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok/old-nil": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, []byte(acc.ID), key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/old-not-nil": func(t *testing.T) test {
|
||||
oldAcc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
oldb, err := json.Marshal(oldAcc)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: oldAcc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, []byte(acc.ID), key)
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.acc.save(tc.db, tc.old); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSaveNew(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/keyToID-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
acc.Key.Key = "foo"
|
||||
return test{
|
||||
acc: acc,
|
||||
err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")),
|
||||
}
|
||||
},
|
||||
"fail/swap-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
||||
}
|
||||
},
|
||||
"fail/swap-false": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("key-id to account-id index already exists")),
|
||||
}
|
||||
},
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
count++
|
||||
return nil, true, nil
|
||||
}
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
MDel: func(bucket, key []byte) error {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
return nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
count++
|
||||
return nil, true, nil
|
||||
}
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.acc.saveNew(tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountUpdate(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
contact []string
|
||||
db nosql.DB
|
||||
res []byte
|
||||
err *Error
|
||||
}
|
||||
contact := []string{"foo", "bar"}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_acc := *acc
|
||||
clone := &_acc
|
||||
clone.Contact = contact
|
||||
b, err := json.Marshal(clone)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
contact: contact,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_acc := *acc
|
||||
clone := &_acc
|
||||
clone.Contact = contact
|
||||
b, err := json.Marshal(clone)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
contact: contact,
|
||||
res: b,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
acc, err := tc.acc.update(tc.db, tc.contact)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, b, tc.res)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountDeactivate(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
acc, err := tc.acc.deactivate(tc.db)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.acc.ID)
|
||||
assert.Equals(t, acc.Contact, tc.acc.Contact)
|
||||
assert.Equals(t, acc.Status, StatusDeactivated)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID)
|
||||
assert.Equals(t, acc.Created, tc.acc.Created)
|
||||
|
||||
assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAccount(t *testing.T) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(jwk)
|
||||
assert.FatalError(t, err)
|
||||
ops := AccountOptions{
|
||||
Key: jwk,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
type test struct {
|
||||
ops AccountOptions
|
||||
db nosql.DB
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/store-error": func(t *testing.T) test {
|
||||
return test{
|
||||
ops: ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var _id string
|
||||
id := &_id
|
||||
count := 0
|
||||
return test{
|
||||
ops: ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
switch count {
|
||||
case 0:
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
case 1:
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
*id = string(key)
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acc, err := newAccount(tc.db, tc.ops)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, *tc.id)
|
||||
assert.Equals(t, acc.Status, StatusValid)
|
||||
assert.Equals(t, acc.Contact, ops.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID)
|
||||
|
||||
assert.True(t, acc.Deactivated.IsZero())
|
||||
|
||||
assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,30 +1,27 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// NewAccountRequest represents the payload for a new account request.
|
||||
type NewAccountRequest struct {
|
||||
Contact []string `json:"contact"`
|
||||
OnlyReturnExisting bool `json:"onlyReturnExisting"`
|
||||
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"`
|
||||
ExternalAccountBinding *ExternalAccountBinding `json:"externalAccountBinding,omitempty"`
|
||||
Contact []string `json:"contact"`
|
||||
OnlyReturnExisting bool `json:"onlyReturnExisting"`
|
||||
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"`
|
||||
}
|
||||
|
||||
func validateContacts(cs []string) error {
|
||||
for _, c := range cs {
|
||||
if c == "" {
|
||||
return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string")
|
||||
if len(c) == 0 {
|
||||
return acme.MalformedErr(errors.New("contact cannot be empty string"))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -33,23 +30,29 @@ func validateContacts(cs []string) error {
|
|||
// Validate validates a new-account request body.
|
||||
func (n *NewAccountRequest) Validate() error {
|
||||
if n.OnlyReturnExisting && len(n.Contact) > 0 {
|
||||
return acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone")
|
||||
return acme.MalformedErr(errors.New("incompatible input; onlyReturnExisting must be alone"))
|
||||
}
|
||||
return validateContacts(n.Contact)
|
||||
}
|
||||
|
||||
// UpdateAccountRequest represents an update-account request.
|
||||
type UpdateAccountRequest struct {
|
||||
Contact []string `json:"contact"`
|
||||
Status acme.Status `json:"status"`
|
||||
Contact []string `json:"contact"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// IsDeactivateRequest returns true if the update request is a deactivation
|
||||
// request, false otherwise.
|
||||
func (u *UpdateAccountRequest) IsDeactivateRequest() bool {
|
||||
return u.Status == acme.StatusDeactivated
|
||||
}
|
||||
|
||||
// Validate validates a update-account request body.
|
||||
func (u *UpdateAccountRequest) Validate() error {
|
||||
switch {
|
||||
case len(u.Status) > 0 && len(u.Contact) > 0:
|
||||
return acme.NewError(acme.ErrorMalformedType, "incompatible input; contact and "+
|
||||
"status updates are mutually exclusive")
|
||||
return acme.MalformedErr(errors.New("incompatible input; contact and " +
|
||||
"status updates are mutually exclusive"))
|
||||
case len(u.Contact) > 0:
|
||||
if err := validateContacts(u.Contact); err != nil {
|
||||
return err
|
||||
|
@ -57,162 +60,117 @@ func (u *UpdateAccountRequest) Validate() error {
|
|||
return nil
|
||||
case len(u.Status) > 0:
|
||||
if u.Status != acme.StatusDeactivated {
|
||||
return acme.NewError(acme.ErrorMalformedType, "cannot update account "+
|
||||
"status to %s, only deactivated", u.Status)
|
||||
return acme.MalformedErr(errors.Errorf("cannot update account "+
|
||||
"status to %s, only deactivated", u.Status))
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
// According to the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2)
|
||||
// accountUpdate should ignore any fields not recognized by the server.
|
||||
return nil
|
||||
return acme.MalformedErr(errors.Errorf("empty update request"))
|
||||
}
|
||||
}
|
||||
|
||||
// getAccountLocationPath returns the current account URL location.
|
||||
// Returned location will be of the form: https://<ca-url>/acme/<provisioner>/account/<accID>
|
||||
func getAccountLocationPath(ctx context.Context, linker acme.Linker, accID string) string {
|
||||
return linker.GetLink(ctx, acme.AccountLinkType, accID)
|
||||
}
|
||||
|
||||
// NewAccount is the handler resource for creating new ACME accounts.
|
||||
func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
var nar NewAccountRequest
|
||||
if err := json.Unmarshal(payload.value, &nar); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
"failed to unmarshal new-account request payload"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
|
||||
"failed to unmarshal new-account request payload")))
|
||||
return
|
||||
}
|
||||
if err := nar.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
httpStatus := http.StatusCreated
|
||||
acc, err := accountFromContext(ctx)
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if !errors.As(err, &acmeErr) || acmeErr.Status != http.StatusBadRequest {
|
||||
acmeErr, ok := err.(*acme.Error)
|
||||
if !ok || acmeErr.Status != http.StatusNotFound {
|
||||
// Something went wrong ...
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Account does not exist //
|
||||
if nar.OnlyReturnExisting {
|
||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
|
||||
"account does not exist"))
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
||||
return
|
||||
}
|
||||
|
||||
jwk, err := jwkFromContext(ctx)
|
||||
jwk, err := jwkFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
eak, err := validateExternalAccountBinding(ctx, &nar)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
if acc, err = h.Auth.NewAccount(prov, acme.AccountOptions{
|
||||
Key: jwk,
|
||||
Contact: nar.Contact,
|
||||
}); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
acc = &acme.Account{
|
||||
Key: jwk,
|
||||
Contact: nar.Contact,
|
||||
Status: acme.StatusValid,
|
||||
LocationPrefix: getAccountLocationPath(ctx, linker, ""),
|
||||
ProvisionerName: prov.GetName(),
|
||||
}
|
||||
if err := db.CreateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
|
||||
return
|
||||
}
|
||||
|
||||
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
|
||||
if err := eak.BindTo(acc); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
|
||||
return
|
||||
}
|
||||
acc.ExternalAccountBinding = nar.ExternalAccountBinding
|
||||
}
|
||||
} else {
|
||||
// Account exists
|
||||
// Account exists //
|
||||
httpStatus = http.StatusOK
|
||||
}
|
||||
|
||||
linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", getAccountLocationPath(ctx, linker, acc.ID))
|
||||
render.JSONStatus(w, acc, httpStatus)
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink,
|
||||
acme.URLSafeProvisionerName(prov), true, acc.GetID()))
|
||||
api.JSONStatus(w, acc, httpStatus)
|
||||
}
|
||||
|
||||
// GetOrUpdateAccount is the api for updating an ACME account.
|
||||
func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
// GetUpdateAccount is the api for updating an ACME account.
|
||||
func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(ctx)
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// If PostAsGet just respond with the account, otherwise process like a
|
||||
// normal Post request.
|
||||
if !payload.isPostAsGet {
|
||||
var uar UpdateAccountRequest
|
||||
if err := json.Unmarshal(payload.value, &uar); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
"failed to unmarshal new-account request payload"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal new-account request payload")))
|
||||
return
|
||||
}
|
||||
if err := uar.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if len(uar.Status) > 0 || len(uar.Contact) > 0 {
|
||||
if len(uar.Status) > 0 {
|
||||
acc.Status = uar.Status
|
||||
} else if len(uar.Contact) > 0 {
|
||||
acc.Contact = uar.Contact
|
||||
}
|
||||
|
||||
if err := db.UpdateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
|
||||
return
|
||||
}
|
||||
var err error
|
||||
if uar.IsDeactivateRequest() {
|
||||
acc, err = h.Auth.DeactivateAccount(prov, acc.GetID())
|
||||
} else {
|
||||
acc, err = h.Auth.UpdateAccount(prov, acc.GetID(), uar.Contact)
|
||||
}
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
|
||||
render.JSON(w, acc)
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, acc.GetID()))
|
||||
api.JSON(w, acc)
|
||||
}
|
||||
|
||||
func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
||||
|
@ -224,31 +182,29 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
|||
}
|
||||
}
|
||||
|
||||
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
|
||||
func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account.
|
||||
func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
accID := chi.URLParam(r, "accID")
|
||||
if acc.ID != accID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param")))
|
||||
return
|
||||
}
|
||||
|
||||
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
|
||||
orders, err := h.Auth.GetOrdersByAccount(prov, acc.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrdersByAccountID(ctx, orders)
|
||||
|
||||
render.JSON(w, orders)
|
||||
api.JSON(w, orders)
|
||||
logOrdersByAccount(w, orders)
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
168
acme/api/eab.go
168
acme/api/eab.go
|
@ -1,168 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
)
|
||||
|
||||
// ExternalAccountBinding represents the ACME externalAccountBinding JWS
|
||||
type ExternalAccountBinding struct {
|
||||
Protected string `json:"protected"`
|
||||
Payload string `json:"payload"`
|
||||
Sig string `json:"signature"`
|
||||
}
|
||||
|
||||
// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account.
|
||||
func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context")
|
||||
}
|
||||
|
||||
if !acmeProv.RequireEAB {
|
||||
//nolint:nilnil // legacy
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if nar.ExternalAccountBinding == nil {
|
||||
return nil, acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided")
|
||||
}
|
||||
|
||||
eabJSONBytes, err := json.Marshal(nar.ExternalAccountBinding)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error marshaling externalAccountBinding into bytes")
|
||||
}
|
||||
|
||||
eabJWS, err := jose.ParseJWS(string(eabJSONBytes))
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error parsing externalAccountBinding jws")
|
||||
}
|
||||
|
||||
// TODO(hs): implement strategy pattern to allow for different ways of verification (i.e. webhook call) based on configuration?
|
||||
|
||||
keyID, acmeErr := validateEABJWS(ctx, eabJWS)
|
||||
if acmeErr != nil {
|
||||
return nil, acmeErr
|
||||
}
|
||||
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
||||
if err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key")
|
||||
}
|
||||
return nil, acme.WrapErrorISE(err, "error retrieving external account key")
|
||||
}
|
||||
|
||||
if externalAccountKey == nil {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key")
|
||||
}
|
||||
|
||||
if len(externalAccountKey.HmacKey) == 0 {
|
||||
return nil, acme.NewError(acme.ErrorServerInternalType, "external account binding key with id '%s' does not have secret bytes", keyID)
|
||||
}
|
||||
|
||||
if externalAccountKey.AlreadyBound() {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt)
|
||||
}
|
||||
|
||||
payload, err := eabJWS.Verify(externalAccountKey.HmacKey)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature")
|
||||
}
|
||||
|
||||
jwk, err := jwkFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var payloadJWK *jose.JSONWebKey
|
||||
if err = json.Unmarshal(payload, &payloadJWK); err != nil {
|
||||
return nil, acme.WrapError(acme.ErrorMalformedType, err, "error unmarshaling payload into jwk")
|
||||
}
|
||||
|
||||
if !keysAreEqual(jwk, payloadJWK) {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "keys in jws and eab payload do not match")
|
||||
}
|
||||
|
||||
return externalAccountKey, nil
|
||||
}
|
||||
|
||||
// keysAreEqual performs an equality check on two JWKs by comparing
|
||||
// the (base64 encoding) of the Key IDs.
|
||||
func keysAreEqual(x, y *jose.JSONWebKey) bool {
|
||||
if x == nil || y == nil {
|
||||
return false
|
||||
}
|
||||
digestX, errX := acme.KeyToID(x)
|
||||
digestY, errY := acme.KeyToID(y)
|
||||
if errX != nil || errY != nil {
|
||||
return false
|
||||
}
|
||||
return digestX == digestY
|
||||
}
|
||||
|
||||
// validateEABJWS verifies the contents of the External Account Binding JWS.
|
||||
// The protected header of the JWS MUST meet the following criteria:
|
||||
//
|
||||
// - The "alg" field MUST indicate a MAC-based algorithm
|
||||
// - The "kid" field MUST contain the key identifier provided by the CA
|
||||
// - The "nonce" field MUST NOT be present
|
||||
// - The "url" field MUST be set to the same value as the outer JWS
|
||||
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
|
||||
if jws == nil {
|
||||
return "", acme.NewErrorISE("no JWS provided")
|
||||
}
|
||||
|
||||
if len(jws.Signatures) != 1 {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "JWS must have one signature")
|
||||
}
|
||||
|
||||
header := jws.Signatures[0].Protected
|
||||
algorithm := header.Algorithm
|
||||
keyID := header.KeyID
|
||||
nonce := header.Nonce
|
||||
|
||||
if !(algorithm == jose.HS256 || algorithm == jose.HS384 || algorithm == jose.HS512) {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'alg' field set to invalid algorithm '%s'", algorithm)
|
||||
}
|
||||
|
||||
if keyID == "" {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'kid' field is required")
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'nonce' must not be present")
|
||||
}
|
||||
|
||||
jwsURL, ok := header.ExtraHeaders["url"]
|
||||
if !ok {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'url' field is required")
|
||||
}
|
||||
|
||||
outerJWS, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
return "", acme.WrapErrorISE(err, "could not retrieve outer JWS from context")
|
||||
}
|
||||
|
||||
if len(outerJWS.Signatures) != 1 {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "outer JWS must have one signature")
|
||||
}
|
||||
|
||||
outerJWSURL, ok := outerJWS.Signatures[0].Protected.ExtraHeaders["url"]
|
||||
if !ok {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'url' field must be set in outer JWS")
|
||||
}
|
||||
|
||||
if jwsURL != outerJWSURL {
|
||||
return "", acme.NewError(acme.ErrorMalformedType, "'url' field is not the same value as the outer JWS")
|
||||
}
|
||||
|
||||
return keyID, nil
|
||||
}
|
1154
acme/api/eab_test.go
1154
acme/api/eab_test.go
File diff suppressed because it is too large
Load diff
|
@ -1,36 +1,30 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func link(url, typ string) string {
|
||||
return fmt.Sprintf("<%s>;rel=%q", url, typ)
|
||||
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ)
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock struct{}
|
||||
type contextKey string
|
||||
|
||||
// Now returns the UTC time rounded to seconds.
|
||||
func (c *Clock) Now() time.Time {
|
||||
return time.Now().UTC().Truncate(time.Second)
|
||||
}
|
||||
|
||||
var clock Clock
|
||||
const (
|
||||
accContextKey = contextKey("acc")
|
||||
jwsContextKey = contextKey("jws")
|
||||
jwkContextKey = contextKey("jwk")
|
||||
payloadContextKey = contextKey("payload")
|
||||
provisionerContextKey = contextKey("provisioner")
|
||||
)
|
||||
|
||||
type payloadInfo struct {
|
||||
value []byte
|
||||
|
@ -38,152 +32,82 @@ type payloadInfo struct {
|
|||
isEmptyJSON bool
|
||||
}
|
||||
|
||||
// HandlerOptions required to create a new ACME API request handler.
|
||||
type HandlerOptions struct {
|
||||
// DB storage backend that implements the acme.DB interface.
|
||||
//
|
||||
// Deprecated: use acme.NewContex(context.Context, acme.DB)
|
||||
DB acme.DB
|
||||
|
||||
// CA is the certificate authority interface.
|
||||
//
|
||||
// Deprecated: use authority.NewContext(context.Context, *authority.Authority)
|
||||
CA acme.CertificateAuthority
|
||||
|
||||
// Backdate is the duration that the CA will subtract from the current time
|
||||
// to set the NotBefore in the certificate.
|
||||
Backdate provisioner.Duration
|
||||
|
||||
// DNS the host used to generate accurate ACME links. By default the authority
|
||||
// will use the Host from the request, so this value will only be used if
|
||||
// request.Host is empty.
|
||||
DNS string
|
||||
|
||||
// Prefix is a URL path prefix under which the ACME api is served. This
|
||||
// prefix is required to generate accurate ACME links.
|
||||
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
|
||||
// "acme" is the prefix from which the ACME api is accessed.
|
||||
Prefix string
|
||||
|
||||
// PrerequisitesChecker checks if all prerequisites for serving ACME are
|
||||
// met by the CA configuration.
|
||||
PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
}
|
||||
|
||||
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// handler is the ACME API request handler.
|
||||
type handler struct {
|
||||
opts *HandlerOptions
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface. For backward compatibility
|
||||
// this route adds will add a new middleware that will set the ACME components
|
||||
// on the context.
|
||||
//
|
||||
// Note: this method is deprecated in step-ca, other applications can still use
|
||||
// this to support ACME, but the recommendation is to use use
|
||||
// api.Route(api.Router) and acme.NewContext() instead.
|
||||
func (h *handler) Route(r api.Router) {
|
||||
client := acme.NewClient()
|
||||
linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix)
|
||||
route(r, func(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
|
||||
ctx = authority.NewContext(ctx, ca)
|
||||
}
|
||||
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// NewHandler returns a new ACME API handler.
|
||||
//
|
||||
// Note: this method is deprecated in step-ca, other applications can still use
|
||||
// this to support ACME, but the recommendation is to use use
|
||||
// api.Route(api.Router) and acme.NewContext() instead.
|
||||
func NewHandler(opts HandlerOptions) api.RouterHandler {
|
||||
return &handler{
|
||||
opts: &opts,
|
||||
func accountFromContext(r *http.Request) (*acme.Account, error) {
|
||||
val, ok := r.Context().Value(accContextKey).(*acme.Account)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.AccountDoesNotExistErr(nil)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func jwkFromContext(r *http.Request) (*jose.JSONWebKey, error) {
|
||||
val, ok := r.Context().Value(jwkContextKey).(*jose.JSONWebKey)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("jwk expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func jwsFromContext(r *http.Request) (*jose.JSONWebSignature, error) {
|
||||
val, ok := r.Context().Value(jwsContextKey).(*jose.JSONWebSignature)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("jws expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func payloadFromContext(r *http.Request) (*payloadInfo, error) {
|
||||
val, ok := r.Context().Value(payloadContextKey).(*payloadInfo)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func provisionerFromContext(r *http.Request) (provisioner.Interface, error) {
|
||||
val, ok := r.Context().Value(provisionerContextKey).(provisioner.Interface)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("provisioner expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface. This method requires that
|
||||
// all the acme components, authority, db, client, linker, and prerequisite
|
||||
// checker to be present in the context.
|
||||
func Route(r api.Router) {
|
||||
route(r, nil)
|
||||
// New returns a new ACME API router.
|
||||
func New(acmeAuth acme.Interface) api.RouterHandler {
|
||||
return &Handler{acmeAuth}
|
||||
}
|
||||
|
||||
func route(r api.Router, middleware func(next nextHTTP) nextHTTP) {
|
||||
commonMiddleware := func(next nextHTTP) nextHTTP {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
// Linker middleware gets the provisioner and current url from the
|
||||
// request and sets them in the context.
|
||||
linker := acme.MustLinkerFromContext(r.Context())
|
||||
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
|
||||
}
|
||||
if middleware != nil {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
return handler
|
||||
}
|
||||
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
||||
return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
|
||||
}
|
||||
// Handler is the ACME request handler.
|
||||
type Handler struct {
|
||||
Auth acme.Interface
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
getLink := h.Auth.GetLink
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce)))
|
||||
r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce)))
|
||||
r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory)))
|
||||
r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory)))
|
||||
|
||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
|
||||
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))
|
||||
}
|
||||
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
|
||||
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))
|
||||
}
|
||||
|
||||
getPath := acme.GetUnescapedPathSuffix
|
||||
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||
r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||
r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||
commonMiddleware(GetDirectory))
|
||||
r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||
commonMiddleware(GetDirectory))
|
||||
|
||||
r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
|
||||
extractPayloadByJWK(NewAccount))
|
||||
r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(GetOrUpdateAccount))
|
||||
r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(NotImplemented))
|
||||
r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
|
||||
extractPayloadByKid(NewOrder))
|
||||
r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetOrder)))
|
||||
r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
|
||||
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
|
||||
extractPayloadByKid(FinalizeOrder))
|
||||
r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetAuthorization)))
|
||||
r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
|
||||
extractPayloadByKid(GetChallenge))
|
||||
r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetCertificate)))
|
||||
r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
|
||||
extractPayloadByKidOrJWK(RevokeCert))
|
||||
r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false), extractPayloadByJWK(h.NewAccount))
|
||||
r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.GetUpdateAccount))
|
||||
r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false), extractPayloadByKid(h.NewOrder))
|
||||
r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
|
||||
r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount)))
|
||||
r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
|
||||
r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz)))
|
||||
r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, "{chID}"), extractPayloadByKid(h.GetChallenge))
|
||||
r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
||||
}
|
||||
|
||||
// GetNonce just sets the right header since a Nonce is added to each response
|
||||
// by middleware by default.
|
||||
func GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
|
@ -191,209 +115,95 @@ func GetNonce(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
type Meta struct {
|
||||
TermsOfService string `json:"termsOfService,omitempty"`
|
||||
Website string `json:"website,omitempty"`
|
||||
CaaIdentities []string `json:"caaIdentities,omitempty"`
|
||||
ExternalAccountRequired bool `json:"externalAccountRequired,omitempty"`
|
||||
}
|
||||
|
||||
// Directory represents an ACME directory for configuring clients.
|
||||
type Directory struct {
|
||||
NewNonce string `json:"newNonce"`
|
||||
NewAccount string `json:"newAccount"`
|
||||
NewOrder string `json:"newOrder"`
|
||||
RevokeCert string `json:"revokeCert"`
|
||||
KeyChange string `json:"keyChange"`
|
||||
Meta *Meta `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging for the Directory type.
|
||||
func (d *Directory) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error marshaling directory for logging")
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetDirectory is the ACME resource for returning a directory configuration
|
||||
// for client configuration.
|
||||
func GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
render.JSON(w, &Directory{
|
||||
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
|
||||
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
|
||||
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
|
||||
RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
|
||||
KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
|
||||
Meta: createMetaObject(acmeProv),
|
||||
})
|
||||
dir := h.Auth.GetDirectory(prov)
|
||||
api.JSON(w, dir)
|
||||
}
|
||||
|
||||
// createMetaObject creates a Meta object if the ACME provisioner
|
||||
// has one or more properties that are written in the ACME directory output.
|
||||
// It returns nil if none of the properties are set.
|
||||
func createMetaObject(p *provisioner.ACME) *Meta {
|
||||
if shouldAddMetaObject(p) {
|
||||
return &Meta{
|
||||
TermsOfService: p.TermsOfService,
|
||||
Website: p.Website,
|
||||
CaaIdentities: p.CaaIdentities,
|
||||
ExternalAccountRequired: p.RequireEAB,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldAddMetaObject returns whether or not the ACME provisioner
|
||||
// has properties configured that must be added to the ACME directory object.
|
||||
func shouldAddMetaObject(p *provisioner.ACME) bool {
|
||||
switch {
|
||||
case p.TermsOfService != "":
|
||||
return true
|
||||
case p.Website != "":
|
||||
return true
|
||||
case len(p.CaaIdentities) > 0:
|
||||
return true
|
||||
case p.RequireEAB:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
||||
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
||||
func NotImplemented(w http.ResponseWriter, _ *http.Request) {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
||||
}
|
||||
|
||||
// GetAuthorization ACME api for retrieving an Authz.
|
||||
func GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
// GetAuthz ACME api for retrieving an Authz.
|
||||
func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if acc.ID != az.AccountID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
|
||||
return
|
||||
}
|
||||
if err = az.UpdateStatus(ctx, db); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
|
||||
authz, err := h.Auth.GetAuthz(prov, acc.GetID(), chi.URLParam(r, "authzID"))
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkAuthorization(ctx, az)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
|
||||
render.JSON(w, az)
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, authz.GetID()))
|
||||
api.JSON(w, authz)
|
||||
}
|
||||
|
||||
// GetChallenge ACME api for retrieving a Challenge.
|
||||
func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
// Just verify that the payload was set, since we're not strictly adhering
|
||||
// to ACME V2 spec for reasons specified below.
|
||||
_, err = payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// NOTE: We should be checking that the request is either a POST-as-GET, or
|
||||
// that for all challenges except for device-attest-01, the payload is an
|
||||
// empty JSON block ({}). However, older ACME clients still send a vestigial
|
||||
// body (rather than an empty JSON block) and strict enforcement would
|
||||
// render these clients broken.
|
||||
|
||||
azID := chi.URLParam(r, "authzID")
|
||||
ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||
// that the payload is an empty JSON block ({}). However, older ACME clients
|
||||
// still send a vestigial body (rather than an empty JSON block) and
|
||||
// strict enforcement would render these clients broken. For the time being
|
||||
// we'll just ignore the body.
|
||||
var (
|
||||
ch *acme.Challenge
|
||||
chID = chi.URLParam(r, "chID")
|
||||
)
|
||||
ch, err = h.Auth.ValidateChallenge(prov, acc.GetID(), chID, acc.GetKey())
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
|
||||
return
|
||||
}
|
||||
ch.AuthorizationID = azID
|
||||
if acc.ID != ch.AccountID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account '%s' does not own challenge '%s'", acc.ID, ch.ID))
|
||||
return
|
||||
}
|
||||
jwk, err := jwkFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if err = ch.Validate(ctx, db, jwk, payload.value); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkChallenge(ctx, ch, azID)
|
||||
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
|
||||
render.JSON(w, ch)
|
||||
getLink := h.Auth.GetLink
|
||||
w.Header().Add("Link", link(getLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, ch.GetAuthzID()), "up"))
|
||||
w.Header().Set("Location", getLink(acme.ChallengeLink, acme.URLSafeProvisionerName(prov), true, ch.GetID()))
|
||||
api.JSON(w, ch)
|
||||
}
|
||||
|
||||
// GetCertificate ACME api for retrieving a Certificate.
|
||||
func GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
certID := chi.URLParam(r, "certID")
|
||||
cert, err := db.GetCertificate(ctx, certID)
|
||||
certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
|
||||
return
|
||||
}
|
||||
if cert.AccountID != acc.ID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account '%s' does not own certificate '%s'", acc.ID, certID))
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var certBytes []byte
|
||||
for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) {
|
||||
certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: c.Raw,
|
||||
})...)
|
||||
}
|
||||
|
||||
api.LogCertificate(w, cert.Leaf)
|
||||
w.Header().Set("Content-Type", "application/pem-certificate-chain")
|
||||
w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
|
||||
w.Write(certBytes)
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -3,20 +3,20 @@ package api
|
|||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/cli/crypto/keys"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||
|
@ -31,79 +31,74 @@ func logNonce(w http.ResponseWriter, nonce string) {
|
|||
}
|
||||
|
||||
// addNonce is a middleware that adds a nonce to the response header.
|
||||
func addNonce(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
db := acme.MustDatabaseFromContext(r.Context())
|
||||
nonce, err := db.CreateNonce(r.Context())
|
||||
nonce, err := h.Auth.NewNonce()
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Replay-Nonce", string(nonce))
|
||||
w.Header().Set("Replay-Nonce", nonce)
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
logNonce(w, string(nonce))
|
||||
logNonce(w, nonce)
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// addDirLink is a middleware that adds a 'Link' response reader with the
|
||||
// directory index url.
|
||||
func addDirLink(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
w.Header().Add("Link", link(h.Auth.GetLink(acme.DirectoryLink, acme.URLSafeProvisionerName(prov), true), "index"))
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// verifyContentType is a middleware that verifies that content type is
|
||||
// application/jose+json.
|
||||
func verifyContentType(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
p, err := provisionerFromContext(r.Context())
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
u := &url.URL{
|
||||
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
|
||||
}
|
||||
|
||||
ct := r.Header.Get("Content-Type")
|
||||
var expected []string
|
||||
if strings.Contains(r.URL.String(), u.EscapedPath()) {
|
||||
if strings.Contains(r.URL.Path, h.Auth.GetLink(acme.CertificateLink, acme.URLSafeProvisionerName(prov), false, "")) {
|
||||
// GET /certificate requests allow a greater range of content types.
|
||||
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||
} else {
|
||||
// By default every request should have content-type applictaion/jose+json.
|
||||
expected = []string{"application/jose+json"}
|
||||
}
|
||||
|
||||
ct := r.Header.Get("Content-Type")
|
||||
for _, e := range expected {
|
||||
if ct == e {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"expected content-type to be in %s, but got %s", expected, ct))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf(
|
||||
"expected content-type to be in %s, but got %s", expected, ct)))
|
||||
}
|
||||
}
|
||||
|
||||
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
|
||||
func parseJWS(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "failed to read request body"))
|
||||
api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body")))
|
||||
return
|
||||
}
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body")))
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
|
||||
|
@ -119,29 +114,26 @@ func parseJWS(next nextHTTP) nextHTTP {
|
|||
// The JWS Unprotected Header [RFC7515] MUST NOT be used
|
||||
// The JWS Payload MUST NOT be detached
|
||||
// The JWS Protected Header MUST include the following fields:
|
||||
// - “alg” (Algorithm).
|
||||
// This field MUST NOT contain “none” or a Message Authentication Code
|
||||
// (MAC) algorithm (e.g. one in which the algorithm registry description
|
||||
// mentions MAC/HMAC).
|
||||
// - “nonce” (defined in Section 6.5)
|
||||
// - “url” (defined in Section 6.4)
|
||||
// - Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||
func validateJWS(next nextHTTP) nextHTTP {
|
||||
// * “alg” (Algorithm)
|
||||
// * This field MUST NOT contain “none” or a Message Authentication Code
|
||||
// (MAC) algorithm (e.g. one in which the algorithm registry description
|
||||
// mentions MAC/HMAC).
|
||||
// * “nonce” (defined in Section 6.5)
|
||||
// * “url” (defined in Section 6.4)
|
||||
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if len(jws.Signatures) == 0 {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature")))
|
||||
return
|
||||
}
|
||||
if len(jws.Signatures) > 1 {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature")))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -152,59 +144,57 @@ func validateJWS(next nextHTTP) nextHTTP {
|
|||
len(uh.Algorithm) > 0 ||
|
||||
len(uh.Nonce) > 0 ||
|
||||
len(uh.ExtraHeaders) > 0 {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used")))
|
||||
return
|
||||
}
|
||||
hdr := sig.Protected
|
||||
switch hdr.Algorithm {
|
||||
case jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512:
|
||||
case jose.RS256, jose.RS384, jose.RS512:
|
||||
if hdr.JSONWebKey != nil {
|
||||
switch k := hdr.JSONWebKey.Key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if k.Size() < keyutil.MinRSAKeyBytes {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"rsa keys must be at least %d bits (%d bytes) in size",
|
||||
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
|
||||
if k.Size() < keys.MinRSAKeyBytes {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+
|
||||
"keys must be at least %d bits (%d bytes) in size",
|
||||
8*keys.MinRSAKeyBytes, keys.MinRSAKeyBytes)))
|
||||
return
|
||||
}
|
||||
default:
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"jws key type and algorithm do not match"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")))
|
||||
return
|
||||
}
|
||||
}
|
||||
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
|
||||
// we good
|
||||
default:
|
||||
render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm)))
|
||||
return
|
||||
}
|
||||
|
||||
// Check the validity/freshness of the Nonce.
|
||||
if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||
render.Error(w, err)
|
||||
if err := h.Auth.UseNonce(hdr.Nonce); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the JWS url matches the requested url.
|
||||
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header")))
|
||||
return
|
||||
}
|
||||
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
|
||||
if jwsURL != reqURL.String() {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)))
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")))
|
||||
return
|
||||
}
|
||||
if hdr.JSONWebKey == nil && hdr.KeyID == "" {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
||||
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
|
@ -214,48 +204,40 @@ func validateJWS(next nextHTTP) nextHTTP {
|
|||
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
|
||||
// in the context. Make sure to parse and validate the JWS before running this
|
||||
// middleware.
|
||||
func extractJWK(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jwk := jws.Signatures[0].Protected.JSONWebKey
|
||||
if jwk == nil {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header")))
|
||||
return
|
||||
}
|
||||
if !jwk.Valid() {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header")))
|
||||
return
|
||||
}
|
||||
|
||||
// Overwrite KeyID with the JWK thumbprint.
|
||||
jwk.KeyID, err = acme.KeyToID(jwk)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
|
||||
return
|
||||
}
|
||||
|
||||
// Store the JWK in the context.
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
|
||||
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key
|
||||
acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||
acc, err := h.Auth.GetAccountByKey(prov, jwk)
|
||||
switch {
|
||||
case errors.Is(err, acme.ErrNotFound):
|
||||
// For NewAccount and Revoke requests ...
|
||||
case nosql.IsErrNotFound(err):
|
||||
// For NewAccount requests ...
|
||||
break
|
||||
case err != nil:
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
default:
|
||||
if !acc.IsValid() {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
|
@ -264,101 +246,71 @@ func extractJWK(next nextHTTP) nextHTTP {
|
|||
}
|
||||
}
|
||||
|
||||
// checkPrerequisites checks if all prerequisites for serving ACME
|
||||
// are met by the CA configuration.
|
||||
func checkPrerequisites(next nextHTTP) nextHTTP {
|
||||
// lookupProvisioner loads the provisioner associated with the request.
|
||||
// Responsds 404 if the provisioner does not exist.
|
||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
// If the function is not set assume that all prerequisites are met.
|
||||
checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
|
||||
if ok {
|
||||
ok, err := checkFunc(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
return
|
||||
}
|
||||
|
||||
name := chi.URLParam(r, "provisionerID")
|
||||
provID, err := url.PathUnescape(name)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name)))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
p, err := h.Auth.LoadProvisionerByID("acme/" + provID)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if p.GetType() != provisioner.TypeACME {
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, p)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// lookupJWK loads the JWK associated with the acme account referenced by the
|
||||
// kid parameter of the signed payload.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func lookupJWK(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
kidPrefix := h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, "")
|
||||
kid := jws.Signatures[0].Protected.KeyID
|
||||
if kid == "" {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"))
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+
|
||||
"required prefix; expected %s, but got %s", kidPrefix, kid)))
|
||||
return
|
||||
}
|
||||
|
||||
accID := path.Base(kid)
|
||||
acc, err := db.GetAccount(ctx, accID)
|
||||
accID := strings.TrimPrefix(kid, kidPrefix)
|
||||
acc, err := h.Auth.GetAccount(prov, accID)
|
||||
switch {
|
||||
case acme.IsErrNotFound(err):
|
||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||
case nosql.IsErrNotFound(err):
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
||||
return
|
||||
case err != nil:
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
default:
|
||||
if !acc.IsValid() {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
||||
return
|
||||
}
|
||||
|
||||
if storedLocation := acc.GetLocation(); storedLocation != "" {
|
||||
if kid != storedLocation {
|
||||
// ACME accounts should have a stored location equivalent to the
|
||||
// kid in the ACME request.
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"kid does not match stored account location; expected %s, but got %s",
|
||||
storedLocation, kid))
|
||||
return
|
||||
}
|
||||
|
||||
// Verify that the provisioner with which the account was created
|
||||
// matches the provisioner in the request URL.
|
||||
reqProv := acme.MustProvisionerFromContext(ctx)
|
||||
reqProvName := reqProv.GetName()
|
||||
accProvName := acc.ProvisionerName
|
||||
if reqProvName != accProvName {
|
||||
// Provisioner in the URL must match the provisioner with
|
||||
// which the account was created.
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s",
|
||||
accProvName, reqProvName))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// This code will only execute for old ACME accounts that do
|
||||
// not have a cached location. The following validation was
|
||||
// the original implementation of the `kid` check which has
|
||||
// since been deprecated. However, the code will remain to
|
||||
// ensure consistent behavior for old ACME accounts.
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"kid does not have required prefix; expected %s, but got %s",
|
||||
kidPrefix, kid))
|
||||
return
|
||||
}
|
||||
}
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
|
||||
next(w, r.WithContext(ctx))
|
||||
|
@ -367,69 +319,32 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
|||
}
|
||||
}
|
||||
|
||||
// extractOrLookupJWK forwards handling to either extractJWK or
|
||||
// lookupJWK based on the presence of a JWK or a KID, respectively.
|
||||
func extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// at this point the JWS has already been verified (if correctly configured in middleware),
|
||||
// and it can be used to check if a JWK exists. This flow is used when the ACME client
|
||||
// signed the payload with a certificate private key.
|
||||
if canExtractJWKFrom(jws) {
|
||||
extractJWK(next)(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// default to looking up the JWK based on KeyID. This flow is used when the ACME client
|
||||
// signed the payload with an account private key.
|
||||
lookupJWK(next)(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// canExtractJWKFrom checks if the JWS has a JWK that can be extracted
|
||||
func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
|
||||
if jws == nil {
|
||||
return false
|
||||
}
|
||||
if len(jws.Signatures) == 0 {
|
||||
return false
|
||||
}
|
||||
return jws.Signatures[0].Protected.JSONWebKey != nil
|
||||
}
|
||||
|
||||
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jwk, err := jwkFromContext(ctx)
|
||||
jwk, err := jwkFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
||||
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||
api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match")))
|
||||
return
|
||||
}
|
||||
payload, err := jws.Verify(jwk)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
|
||||
ctx := context.WithValue(r.Context(), payloadContextKey, &payloadInfo{
|
||||
value: payload,
|
||||
isPostAsGet: len(payload) == 0,
|
||||
isPostAsGet: string(payload) == "",
|
||||
isEmptyJSON: string(payload) == "{}",
|
||||
})
|
||||
next(w, r.WithContext(ctx))
|
||||
|
@ -437,95 +352,17 @@ func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
|
||||
func isPostAsGet(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, err := payloadFromContext(r.Context())
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if !payload.isPostAsGet {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET")))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// ContextKey is the key type for storing and searching for ACME request
|
||||
// essentials in the context of a request.
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// accContextKey account key
|
||||
accContextKey = ContextKey("acc")
|
||||
// jwsContextKey jws key
|
||||
jwsContextKey = ContextKey("jws")
|
||||
// jwkContextKey jwk key
|
||||
jwkContextKey = ContextKey("jwk")
|
||||
// payloadContextKey payload key
|
||||
payloadContextKey = ContextKey("payload")
|
||||
)
|
||||
|
||||
// accountFromContext searches the context for an ACME account. Returns the
|
||||
// account or an error.
|
||||
func accountFromContext(ctx context.Context) (*acme.Account, error) {
|
||||
val, ok := ctx.Value(accContextKey).(*acme.Account)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
|
||||
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
|
||||
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.NewErrorISE("jwk expected in request context")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// jwsFromContext searches the context for a JWS. Returns the JWS or an error.
|
||||
func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
|
||||
val, ok := ctx.Value(jwsContextKey).(*jose.JSONWebSignature)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.NewErrorISE("jws expected in request context")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// provisionerFromContext searches the context for a provisioner. Returns the
|
||||
// provisioner or an error.
|
||||
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
|
||||
p, ok := acme.ProvisionerFromContext(ctx)
|
||||
if !ok || p == nil {
|
||||
return nil, acme.NewErrorISE("provisioner expected in request context")
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
|
||||
// pointer to an ACME provisioner or an error.
|
||||
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
|
||||
p, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ap, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||
}
|
||||
|
||||
return ap, nil
|
||||
}
|
||||
|
||||
// payloadFromContext searches the context for a payload. Returns the payload
|
||||
// or an error.
|
||||
func payloadFromContext(ctx context.Context) (*payloadInfo, error) {
|
||||
val, ok := ctx.Value(payloadContextKey).(*payloadInfo)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.NewErrorISE("payload expected in request context")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,24 +1,16 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"go.step.sm/crypto/randutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/api"
|
||||
)
|
||||
|
||||
// NewOrderRequest represents the body for a NewOrder request.
|
||||
|
@ -31,29 +23,12 @@ type NewOrderRequest struct {
|
|||
// Validate validates a new-order request body.
|
||||
func (n *NewOrderRequest) Validate() error {
|
||||
if len(n.Identifiers) == 0 {
|
||||
return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty")
|
||||
return acme.MalformedErr(errors.Errorf("identifiers list cannot be empty"))
|
||||
}
|
||||
for _, id := range n.Identifiers {
|
||||
switch id.Type {
|
||||
case acme.IP:
|
||||
if net.ParseIP(id.Value) == nil {
|
||||
return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value)
|
||||
}
|
||||
case acme.DNS:
|
||||
value, _ := trimIfWildcard(id.Value)
|
||||
if _, err := x509util.SanitizeName(value); err != nil {
|
||||
return acme.NewError(acme.ErrorMalformedType, "invalid DNS name: %s", id.Value)
|
||||
}
|
||||
case acme.PermanentIdentifier:
|
||||
if id.Value == "" {
|
||||
return acme.NewError(acme.ErrorMalformedType, "permanent identifier cannot be empty")
|
||||
}
|
||||
default:
|
||||
return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type)
|
||||
if id.Type != "dns" {
|
||||
return acme.MalformedErr(errors.Errorf("identifier type unsupported: %s", id.Type))
|
||||
}
|
||||
|
||||
// TODO(hs): add some validations for DNS domains?
|
||||
// TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -67,341 +42,120 @@ type FinalizeRequest struct {
|
|||
// Validate validates a finalize request body.
|
||||
func (f *FinalizeRequest) Validate() error {
|
||||
var err error
|
||||
// RFC 8555 isn't 100% conclusive about using raw base64-url encoding for the
|
||||
// CSR specifically, instead of "normal" base64-url encoding (incl. padding).
|
||||
// By trimming the padding from CSRs submitted by ACME clients that use
|
||||
// base64-url encoding instead of raw base64-url encoding, these are also
|
||||
// supported. This was reported in https://github.com/smallstep/certificates/issues/939
|
||||
// to be the case for a Synology DSM NAS system.
|
||||
csrBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(f.CSR, "="))
|
||||
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
|
||||
if err != nil {
|
||||
return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr")
|
||||
return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr"))
|
||||
}
|
||||
f.csr, err = x509.ParseCertificateRequest(csrBytes)
|
||||
if err != nil {
|
||||
return acme.WrapError(acme.ErrorMalformedType, err, "unable to parse csr")
|
||||
return acme.MalformedErr(errors.Wrap(err, "unable to parse csr"))
|
||||
}
|
||||
if err = f.csr.CheckSignature(); err != nil {
|
||||
return acme.WrapError(acme.ErrorMalformedType, err, "csr failed signature check")
|
||||
return acme.MalformedErr(errors.Wrap(err, "csr failed signature check"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var defaultOrderExpiry = time.Hour * 24
|
||||
var defaultOrderBackdate = time.Minute
|
||||
|
||||
// NewOrder ACME api for creating a new order.
|
||||
func NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ca := mustAuthority(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(ctx)
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var nor NewOrderRequest
|
||||
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
"failed to unmarshal new-order request payload"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
|
||||
"failed to unmarshal new-order request payload")))
|
||||
return
|
||||
}
|
||||
|
||||
if err := nor.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(hs): gather all errors, so that we can build one response with ACME subproblems
|
||||
// include the nor.Validate() error here too, like in the example in the ACME RFC?
|
||||
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
o, err := h.Auth.NewOrder(prov, acme.OrderOptions{
|
||||
AccountID: acc.GetID(),
|
||||
Identifiers: nor.Identifiers,
|
||||
NotBefore: nor.NotBefore,
|
||||
NotAfter: nor.NotAfter,
|
||||
})
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var eak *acme.ExternalAccountKey
|
||||
if acmeProv.RequireEAB {
|
||||
if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
acmePolicy, err := newACMEPolicyEngine(eak)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine"))
|
||||
return
|
||||
}
|
||||
|
||||
for _, identifier := range nor.Identifiers {
|
||||
// evaluate the ACME account level policy
|
||||
if err = isIdentifierAllowed(acmePolicy, identifier); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
// evaluate the provisioner level policy
|
||||
orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
|
||||
if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
// evaluate the authority level policy
|
||||
if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
// New order.
|
||||
o := &acme.Order{
|
||||
AccountID: acc.ID,
|
||||
ProvisionerID: prov.GetID(),
|
||||
Status: acme.StatusPending,
|
||||
Identifiers: nor.Identifiers,
|
||||
ExpiresAt: now.Add(defaultOrderExpiry),
|
||||
AuthorizationIDs: make([]string, len(nor.Identifiers)),
|
||||
NotBefore: nor.NotBefore,
|
||||
NotAfter: nor.NotAfter,
|
||||
}
|
||||
|
||||
for i, identifier := range o.Identifiers {
|
||||
az := &acme.Authorization{
|
||||
AccountID: acc.ID,
|
||||
Identifier: identifier,
|
||||
ExpiresAt: o.ExpiresAt,
|
||||
Status: acme.StatusPending,
|
||||
}
|
||||
if err := newAuthorization(ctx, az); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
o.AuthorizationIDs[i] = az.ID
|
||||
}
|
||||
|
||||
if o.NotBefore.IsZero() {
|
||||
o.NotBefore = now
|
||||
}
|
||||
if o.NotAfter.IsZero() {
|
||||
o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration())
|
||||
}
|
||||
// If request NotBefore was empty then backdate the order.NotBefore (now)
|
||||
// to avoid timing issues.
|
||||
if nor.NotBefore.IsZero() {
|
||||
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
|
||||
}
|
||||
|
||||
if err := db.CreateOrder(ctx, o); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSONStatus(w, o, http.StatusCreated)
|
||||
}
|
||||
|
||||
func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error {
|
||||
if acmePolicy == nil {
|
||||
return nil
|
||||
}
|
||||
return acmePolicy.AreSANsAllowed([]string{identifier.Value})
|
||||
}
|
||||
|
||||
func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error) {
|
||||
if eak == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return policy.NewX509PolicyEngine(eak.Policy)
|
||||
}
|
||||
|
||||
func trimIfWildcard(value string) (string, bool) {
|
||||
if strings.HasPrefix(value, "*.") {
|
||||
return strings.TrimPrefix(value, "*."), true
|
||||
}
|
||||
return value, false
|
||||
}
|
||||
|
||||
func newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
value, isWildcard := trimIfWildcard(az.Identifier.Value)
|
||||
az.Wildcard = isWildcard
|
||||
az.Identifier = acme.Identifier{
|
||||
Value: value,
|
||||
Type: az.Identifier.Type,
|
||||
}
|
||||
|
||||
chTypes := challengeTypes(az)
|
||||
|
||||
var err error
|
||||
az.Token, err = randutil.Alphanumeric(32)
|
||||
if err != nil {
|
||||
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
|
||||
}
|
||||
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
prov := acme.MustProvisionerFromContext(ctx)
|
||||
az.Challenges = make([]*acme.Challenge, 0, len(chTypes))
|
||||
for _, typ := range chTypes {
|
||||
if !prov.IsChallengeEnabled(ctx, provisioner.ACMEChallenge(typ)) {
|
||||
continue
|
||||
}
|
||||
|
||||
ch := &acme.Challenge{
|
||||
AccountID: az.AccountID,
|
||||
Value: az.Identifier.Value,
|
||||
Type: typ,
|
||||
Token: az.Token,
|
||||
Status: acme.StatusPending,
|
||||
}
|
||||
if err := db.CreateChallenge(ctx, ch); err != nil {
|
||||
return acme.WrapErrorISE(err, "error creating challenge")
|
||||
}
|
||||
az.Challenges = append(az.Challenges, ch)
|
||||
}
|
||||
if err = db.CreateAuthorization(ctx, az); err != nil {
|
||||
return acme.WrapErrorISE(err, "error creating authorization")
|
||||
}
|
||||
return nil
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID()))
|
||||
api.JSONStatus(w, o, http.StatusCreated)
|
||||
}
|
||||
|
||||
// GetOrder ACME api for retrieving an order.
|
||||
func GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
oid := chi.URLParam(r, "ordID")
|
||||
o, err := h.Auth.GetOrder(prov, acc.GetID(), oid)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
return
|
||||
}
|
||||
if acc.ID != o.AccountID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account '%s' does not own order '%s'", acc.ID, o.ID))
|
||||
return
|
||||
}
|
||||
if prov.GetID() != o.ProvisionerID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||
return
|
||||
}
|
||||
if err = o.UpdateStatus(ctx, db); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID()))
|
||||
api.JSON(w, o)
|
||||
}
|
||||
|
||||
// FinalizeOrder attempts to finalize an order and create a certificate.
|
||||
func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
||||
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(ctx)
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
var fr FinalizeRequest
|
||||
if err := json.Unmarshal(payload.value, &fr); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
"failed to unmarshal finalize-order request payload"))
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload")))
|
||||
return
|
||||
}
|
||||
if err := fr.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
oid := chi.URLParam(r, "ordID")
|
||||
o, err := h.Auth.FinalizeOrder(prov, acc.GetID(), oid, fr.csr)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
return
|
||||
}
|
||||
if acc.ID != o.AccountID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account '%s' does not own order '%s'", acc.ID, o.ID))
|
||||
return
|
||||
}
|
||||
if prov.GetID() != o.ProvisionerID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ca := mustAuthority(ctx)
|
||||
if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
}
|
||||
|
||||
// challengeTypes determines the types of challenges that should be used
|
||||
// for the ACME authorization request.
|
||||
func challengeTypes(az *acme.Authorization) []acme.ChallengeType {
|
||||
var chTypes []acme.ChallengeType
|
||||
|
||||
switch az.Identifier.Type {
|
||||
case acme.IP:
|
||||
chTypes = []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}
|
||||
case acme.DNS:
|
||||
chTypes = []acme.ChallengeType{acme.DNS01}
|
||||
// HTTP and TLS challenges can only be used for identifiers without wildcards.
|
||||
if !az.Wildcard {
|
||||
chTypes = append(chTypes, []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}...)
|
||||
}
|
||||
case acme.PermanentIdentifier:
|
||||
chTypes = []acme.ChallengeType{acme.DEVICEATTEST01}
|
||||
default:
|
||||
chTypes = []acme.ChallengeType{}
|
||||
}
|
||||
|
||||
return chTypes
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.ID))
|
||||
api.JSON(w, o)
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,291 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
type revokePayload struct {
|
||||
Certificate string `json:"certificate"`
|
||||
ReasonCode *int `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// RevokeCert attempts to revoke a certificate.
|
||||
func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var p revokePayload
|
||||
err = json.Unmarshal(payload.value, &p)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
|
||||
return
|
||||
}
|
||||
|
||||
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
|
||||
if err != nil {
|
||||
// in this case the most likely cause is a client that didn't properly encode the certificate
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
|
||||
return
|
||||
}
|
||||
|
||||
certToBeRevoked, err := x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
// in this case a client may have encoded something different than a certificate
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
serial := certToBeRevoked.SerialNumber.String()
|
||||
dbCert, err := db.GetCertificateBySerial(ctx, serial)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
|
||||
// this should never happen
|
||||
render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal"))
|
||||
return
|
||||
}
|
||||
|
||||
if shouldCheckAccountFrom(jws) {
|
||||
account, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||
if acmeErr != nil {
|
||||
render.Error(w, acmeErr)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// if account doesn't need to be checked, the JWS should be verified to be signed by the
|
||||
// private key that belongs to the public key in the certificate to be revoked.
|
||||
_, err := jws.Verify(certToBeRevoked.PublicKey)
|
||||
if err != nil {
|
||||
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
|
||||
render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ca := mustAuthority(ctx)
|
||||
hasBeenRevokedBefore, err := ca.IsRevoked(serial)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
if hasBeenRevokedBefore {
|
||||
render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
|
||||
return
|
||||
}
|
||||
|
||||
reasonCode := p.ReasonCode
|
||||
acmeErr := validateReasonCode(reasonCode)
|
||||
if acmeErr != nil {
|
||||
render.Error(w, acmeErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Authorize revocation by ACME provisioner
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
|
||||
err = prov.AuthorizeRevoke(ctx, "")
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
|
||||
return
|
||||
}
|
||||
|
||||
options := revokeOptions(serial, certToBeRevoked, reasonCode)
|
||||
err = ca.Revoke(ctx, options)
|
||||
if err != nil {
|
||||
render.Error(w, wrapRevokeErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
logRevoke(w, options)
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||
w.Write(nil)
|
||||
}
|
||||
|
||||
// isAccountAuthorized checks if an ACME account that was retrieved earlier is authorized
|
||||
// to revoke the certificate. An Account must always be valid in order to revoke a certificate.
|
||||
// In case the certificate retrieved from the database belongs to the Account, the Account is
|
||||
// authorized. If the certificate retrieved from the database doesn't belong to the Account,
|
||||
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
|
||||
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
|
||||
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
|
||||
func isAccountAuthorized(_ context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
||||
if !account.IsValid() {
|
||||
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
|
||||
}
|
||||
certificateBelongsToAccount := dbCert.AccountID == account.ID
|
||||
if certificateBelongsToAccount {
|
||||
return nil // return early
|
||||
}
|
||||
|
||||
// TODO(hs): according to RFC8555: 7.6, a server MUST consider the following accounts authorized
|
||||
// to revoke a certificate:
|
||||
//
|
||||
// o the account that issued the certificate.
|
||||
// o an account that holds authorizations for all of the identifiers in the certificate.
|
||||
//
|
||||
// We currently only support the first case. The second might result in step going OOM when
|
||||
// large numbers of Authorizations are involved when the current nosql interface is in use.
|
||||
// We want to protect users from this failure scenario, so that's why it hasn't been added yet.
|
||||
// This issue is tracked in https://github.com/smallstep/certificates/issues/767
|
||||
|
||||
// not authorized; fail closed.
|
||||
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' is not authorized", account.ID), nil)
|
||||
}
|
||||
|
||||
// wrapRevokeErr is a best effort implementation to transform an error during
|
||||
// revocation into an ACME error, so that clients can understand the error.
|
||||
func wrapRevokeErr(err error) *acme.Error {
|
||||
t := err.Error()
|
||||
if strings.Contains(t, "is already revoked") {
|
||||
return acme.NewError(acme.ErrorAlreadyRevokedType, t)
|
||||
}
|
||||
return acme.WrapErrorISE(err, "error when revoking certificate")
|
||||
}
|
||||
|
||||
// unauthorizedError returns an ACME error indicating the request was
|
||||
// not authorized to revoke the certificate.
|
||||
func wrapUnauthorizedError(cert *x509.Certificate, unauthorizedIdentifiers []acme.Identifier, msg string, err error) *acme.Error {
|
||||
var acmeErr *acme.Error
|
||||
if err == nil {
|
||||
acmeErr = acme.NewError(acme.ErrorUnauthorizedType, msg)
|
||||
} else {
|
||||
acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg)
|
||||
}
|
||||
acmeErr.Status = http.StatusForbidden // RFC8555 7.6 shows example with 403
|
||||
|
||||
switch {
|
||||
case len(unauthorizedIdentifiers) > 0:
|
||||
identifier := unauthorizedIdentifiers[0] // picking the first; compound may be an option too?
|
||||
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", identifier.Value)
|
||||
case cert.Subject.String() != "":
|
||||
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.CommonName)
|
||||
default:
|
||||
acmeErr.Detail = "No authorization provided"
|
||||
}
|
||||
|
||||
return acmeErr
|
||||
}
|
||||
|
||||
// logRevoke logs successful revocation of certificate
|
||||
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"serial": ri.Serial,
|
||||
"reasonCode": ri.ReasonCode,
|
||||
"reason": ri.Reason,
|
||||
"passiveOnly": ri.PassiveOnly,
|
||||
"ACME": ri.ACME,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateReasonCode validates the revocation reason
|
||||
func validateReasonCode(reasonCode *int) *acme.Error {
|
||||
if reasonCode != nil && ((*reasonCode < ocsp.Unspecified || *reasonCode > ocsp.AACompromise) || *reasonCode == 7) {
|
||||
return acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds")
|
||||
}
|
||||
// NOTE: it's possible to add additional requirements to the reason code:
|
||||
// The server MAY disallow a subset of reasonCodes from being
|
||||
// used by the user. If a request contains a disallowed reasonCode,
|
||||
// then the server MUST reject it with the error type
|
||||
// "urn:ietf:params:acme:error:badRevocationReason"
|
||||
// No additional checks have been implemented so far.
|
||||
return nil
|
||||
}
|
||||
|
||||
// revokeOptions determines the RevokeOptions for the Authority to use in revocation
|
||||
func revokeOptions(serial string, certToBeRevoked *x509.Certificate, reasonCode *int) *authority.RevokeOptions {
|
||||
opts := &authority.RevokeOptions{
|
||||
Serial: serial,
|
||||
ACME: true,
|
||||
Crt: certToBeRevoked,
|
||||
}
|
||||
if reasonCode != nil { // NOTE: when implementing CRL and/or OCSP, and reason code is missing, CRL entry extension should be omitted
|
||||
opts.Reason = reason(*reasonCode)
|
||||
opts.ReasonCode = *reasonCode
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// reason transforms an integer reason code to a
|
||||
// textual description of the revocation reason.
|
||||
func reason(reasonCode int) string {
|
||||
switch reasonCode {
|
||||
case ocsp.Unspecified:
|
||||
return "unspecified reason"
|
||||
case ocsp.KeyCompromise:
|
||||
return "key compromised"
|
||||
case ocsp.CACompromise:
|
||||
return "ca compromised"
|
||||
case ocsp.AffiliationChanged:
|
||||
return "affiliation changed"
|
||||
case ocsp.Superseded:
|
||||
return "superseded"
|
||||
case ocsp.CessationOfOperation:
|
||||
return "cessation of operation"
|
||||
case ocsp.CertificateHold:
|
||||
return "certificate hold"
|
||||
case ocsp.RemoveFromCRL:
|
||||
return "remove from crl"
|
||||
case ocsp.PrivilegeWithdrawn:
|
||||
return "privilege withdrawn"
|
||||
case ocsp.AACompromise:
|
||||
return "aa compromised"
|
||||
default:
|
||||
return "unspecified reason"
|
||||
}
|
||||
}
|
||||
|
||||
// shouldCheckAccountFrom indicates whether an account should be
|
||||
// retrieved from the context, so that it can be used for
|
||||
// additional checks. This should only be done when no JWK
|
||||
// can be extracted from the request, as that would indicate
|
||||
// that the revocation request was signed with a certificate
|
||||
// key pair (and not an account key pair). Looking up such
|
||||
// a JWK would result in no Account being found.
|
||||
func shouldCheckAccountFrom(jws *jose.JSONWebSignature) bool {
|
||||
return !canExtractJWKFrom(jws)
|
||||
}
|
File diff suppressed because it is too large
Load diff
288
acme/authority.go
Normal file
288
acme/authority.go
Normal file
|
@ -0,0 +1,288 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
database "github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Interface is the acme authority interface.
|
||||
type Interface interface {
|
||||
DeactivateAccount(provisioner.Interface, string) (*Account, error)
|
||||
FinalizeOrder(provisioner.Interface, string, string, *x509.CertificateRequest) (*Order, error)
|
||||
GetAccount(provisioner.Interface, string) (*Account, error)
|
||||
GetAccountByKey(provisioner.Interface, *jose.JSONWebKey) (*Account, error)
|
||||
GetAuthz(provisioner.Interface, string, string) (*Authz, error)
|
||||
GetCertificate(string, string) ([]byte, error)
|
||||
GetDirectory(provisioner.Interface) *Directory
|
||||
GetLink(Link, string, bool, ...string) string
|
||||
GetOrder(provisioner.Interface, string, string) (*Order, error)
|
||||
GetOrdersByAccount(provisioner.Interface, string) ([]string, error)
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
NewAccount(provisioner.Interface, AccountOptions) (*Account, error)
|
||||
NewNonce() (string, error)
|
||||
NewOrder(provisioner.Interface, OrderOptions) (*Order, error)
|
||||
UpdateAccount(provisioner.Interface, string, []string) (*Account, error)
|
||||
UseNonce(string) error
|
||||
ValidateChallenge(provisioner.Interface, string, string, *jose.JSONWebKey) (*Challenge, error)
|
||||
}
|
||||
|
||||
// Authority is the layer that handles all ACME interactions.
|
||||
type Authority struct {
|
||||
db nosql.DB
|
||||
dir *directory
|
||||
signAuth SignAuthority
|
||||
}
|
||||
|
||||
var (
|
||||
accountTable = []byte("acme_accounts")
|
||||
accountByKeyIDTable = []byte("acme_keyID_accountID_index")
|
||||
authzTable = []byte("acme_authzs")
|
||||
challengeTable = []byte("acme_challenges")
|
||||
nonceTable = []byte("nonces")
|
||||
orderTable = []byte("acme_orders")
|
||||
ordersByAccountIDTable = []byte("acme_account-orders-index")
|
||||
certTable = []byte("acme_certs")
|
||||
)
|
||||
|
||||
// NewAuthority returns a new Authority that implements the ACME interface.
|
||||
func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) {
|
||||
if _, ok := db.(*database.SimpleDB); !ok {
|
||||
// If it's not a SimpleDB then go ahead and bootstrap the DB with the
|
||||
// necessary ACME tables. SimpleDB should ONLY be used for testing.
|
||||
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
|
||||
challengeTable, nonceTable, orderTable, ordersByAccountIDTable,
|
||||
certTable}
|
||||
for _, b := range tables {
|
||||
if err := db.CreateTable(b); err != nil {
|
||||
return nil, errors.Wrapf(err, "error creating table %s",
|
||||
string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
return &Authority{
|
||||
db: db, dir: newDirectory(dns, prefix), signAuth: signAuth,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetLink returns the requested link from the directory.
|
||||
func (a *Authority) GetLink(typ Link, provID string, abs bool, inputs ...string) string {
|
||||
return a.dir.getLink(typ, provID, abs, inputs...)
|
||||
}
|
||||
|
||||
// GetDirectory returns the ACME directory object.
|
||||
func (a *Authority) GetDirectory(p provisioner.Interface) *Directory {
|
||||
name := url.PathEscape(p.GetName())
|
||||
return &Directory{
|
||||
NewNonce: a.dir.getLink(NewNonceLink, name, true),
|
||||
NewAccount: a.dir.getLink(NewAccountLink, name, true),
|
||||
NewOrder: a.dir.getLink(NewOrderLink, name, true),
|
||||
RevokeCert: a.dir.getLink(RevokeCertLink, name, true),
|
||||
KeyChange: a.dir.getLink(KeyChangeLink, name, true),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadProvisionerByID calls out to the SignAuthority interface to load a
|
||||
// provisioner by ID.
|
||||
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||
return a.signAuth.LoadProvisionerByID(id)
|
||||
}
|
||||
|
||||
// NewNonce generates, stores, and returns a new ACME nonce.
|
||||
func (a *Authority) NewNonce() (string, error) {
|
||||
n, err := newNonce(a.db)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return n.ID, nil
|
||||
}
|
||||
|
||||
// UseNonce consumes the given nonce if it is valid, returns error otherwise.
|
||||
func (a *Authority) UseNonce(nonce string) error {
|
||||
return useNonce(a.db, nonce)
|
||||
}
|
||||
|
||||
// NewAccount creates, stores, and returns a new ACME account.
|
||||
func (a *Authority) NewAccount(p provisioner.Interface, ao AccountOptions) (*Account, error) {
|
||||
acc, err := newAccount(a.db, ao)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// UpdateAccount updates an ACME account.
|
||||
func (a *Authority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(err)
|
||||
}
|
||||
if acc, err = acc.update(a.db, contact); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetAccount returns an ACME account.
|
||||
func (a *Authority) GetAccount(p provisioner.Interface, id string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// DeactivateAccount deactivates an ACME account.
|
||||
func (a *Authority) DeactivateAccount(p provisioner.Interface, id string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if acc, err = acc.deactivate(a.db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
func keyToID(jwk *jose.JSONWebKey) (string, error) {
|
||||
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint"))
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(kid), nil
|
||||
}
|
||||
|
||||
// GetAccountByKey returns the ACME associated with the jwk id.
|
||||
func (a *Authority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*Account, error) {
|
||||
kid, err := keyToID(jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
acc, err := getAccountByKeyID(a.db, kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetOrder returns an ACME order.
|
||||
func (a *Authority) GetOrder(p provisioner.Interface, accID, orderID string) (*Order, error) {
|
||||
o, err := getOrder(a.db, orderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != o.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
||||
}
|
||||
if o, err = o.updateStatus(a.db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return o.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetOrdersByAccount returns the list of order urls owned by the account.
|
||||
func (a *Authority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) {
|
||||
oids, err := getOrderIDsByAccount(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret = []string{}
|
||||
for _, oid := range oids {
|
||||
o, err := getOrder(a.db, oid)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(err)
|
||||
}
|
||||
if o.Status == StatusInvalid {
|
||||
continue
|
||||
}
|
||||
ret = append(ret, a.dir.getLink(OrderLink, URLSafeProvisionerName(p), true, o.ID))
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// NewOrder generates, stores, and returns a new ACME order.
|
||||
func (a *Authority) NewOrder(p provisioner.Interface, ops OrderOptions) (*Order, error) {
|
||||
order, err := newOrder(a.db, ops)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating order")
|
||||
}
|
||||
return order.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// FinalizeOrder attempts to finalize an order and generate a new certificate.
|
||||
func (a *Authority) FinalizeOrder(p provisioner.Interface, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) {
|
||||
o, err := getOrder(a.db, orderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != o.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
||||
}
|
||||
o, err = o.finalize(a.db, csr, a.signAuth, p)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error finalizing order")
|
||||
}
|
||||
return o.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetAuthz retrieves and attempts to update the status on an ACME authz
|
||||
// before returning.
|
||||
func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*Authz, error) {
|
||||
az, err := getAuthz(a.db, authzID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != az.getAccountID() {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own authz"))
|
||||
}
|
||||
az, err = az.updateStatus(a.db)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error updating authz status")
|
||||
}
|
||||
return az.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// ValidateChallenge attempts to validate the challenge.
|
||||
func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) {
|
||||
ch, err := getChallenge(a.db, chID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != ch.getAccountID() {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own challenge"))
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: time.Duration(30 * time.Second),
|
||||
}
|
||||
ch, err = ch.validate(a.db, jwk, validateOptions{
|
||||
httpGet: client.Get,
|
||||
lookupTxt: net.LookupTXT,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error attempting challenge validation")
|
||||
}
|
||||
return ch.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetCertificate retrieves the Certificate by ID.
|
||||
func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) {
|
||||
cert, err := getCert(a.db, certID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != cert.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own certificate"))
|
||||
}
|
||||
return cert.toACME(a.db, a.dir)
|
||||
}
|
1517
acme/authority_test.go
Normal file
1517
acme/authority_test.go
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,70 +0,0 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Authorization representst an ACME Authorization.
|
||||
type Authorization struct {
|
||||
ID string `json:"-"`
|
||||
AccountID string `json:"-"`
|
||||
Token string `json:"-"`
|
||||
Fingerprint string `json:"-"`
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status Status `json:"status"`
|
||||
Challenges []*Challenge `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
ExpiresAt time.Time `json:"expires"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (az *Authorization) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(az)
|
||||
if err != nil {
|
||||
return nil, WrapErrorISE(err, "error marshaling authz for logging")
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the ACME Authorization Status if necessary.
|
||||
// Changes to the Authorization are saved using the database interface.
|
||||
func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
|
||||
now := clock.Now()
|
||||
|
||||
switch az.Status {
|
||||
case StatusInvalid:
|
||||
return nil
|
||||
case StatusValid:
|
||||
return nil
|
||||
case StatusPending:
|
||||
// check expiry
|
||||
if now.After(az.ExpiresAt) {
|
||||
az.Status = StatusInvalid
|
||||
break
|
||||
}
|
||||
|
||||
var isValid = false
|
||||
for _, ch := range az.Challenges {
|
||||
if ch.Status == StatusValid {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
return nil
|
||||
}
|
||||
az.Status = StatusValid
|
||||
az.Error = nil
|
||||
default:
|
||||
return NewErrorISE("unrecognized authorization status: %s", az.Status)
|
||||
}
|
||||
|
||||
if err := db.UpdateAuthorization(ctx, az); err != nil {
|
||||
return WrapErrorISE(err, "error updating authorization")
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,150 +0,0 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func TestAuthorization_UpdateStatus(t *testing.T) {
|
||||
type test struct {
|
||||
az *Authorization
|
||||
err *Error
|
||||
db DB
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok/already-invalid": func(t *testing.T) test {
|
||||
az := &Authorization{
|
||||
Status: StatusInvalid,
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
}
|
||||
},
|
||||
"ok/already-valid": func(t *testing.T) test {
|
||||
az := &Authorization{
|
||||
Status: StatusInvalid,
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
}
|
||||
},
|
||||
"fail/error-unexpected-status": func(t *testing.T) test {
|
||||
az := &Authorization{
|
||||
Status: "foo",
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
err: NewErrorISE("unrecognized authorization status: %s", az.Status),
|
||||
}
|
||||
},
|
||||
"ok/expired": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
az := &Authorization{
|
||||
ID: "azID",
|
||||
AccountID: "accID",
|
||||
Status: StatusPending,
|
||||
ExpiresAt: now.Add(-5 * time.Minute),
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
db: &MockDB{
|
||||
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
|
||||
assert.Equals(t, updaz.ID, az.ID)
|
||||
assert.Equals(t, updaz.AccountID, az.AccountID)
|
||||
assert.Equals(t, updaz.Status, StatusInvalid)
|
||||
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/db.UpdateAuthorization-error": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
az := &Authorization{
|
||||
ID: "azID",
|
||||
AccountID: "accID",
|
||||
Status: StatusPending,
|
||||
ExpiresAt: now.Add(-5 * time.Minute),
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
db: &MockDB{
|
||||
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
|
||||
assert.Equals(t, updaz.ID, az.ID)
|
||||
assert.Equals(t, updaz.AccountID, az.AccountID)
|
||||
assert.Equals(t, updaz.Status, StatusInvalid)
|
||||
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
err: NewErrorISE("error updating authorization: force"),
|
||||
}
|
||||
},
|
||||
"ok/no-valid-challenges": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
az := &Authorization{
|
||||
ID: "azID",
|
||||
AccountID: "accID",
|
||||
Status: StatusPending,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Challenges: []*Challenge{
|
||||
{Status: StatusPending}, {Status: StatusPending}, {Status: StatusPending},
|
||||
},
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
}
|
||||
},
|
||||
"ok/valid": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
az := &Authorization{
|
||||
ID: "azID",
|
||||
AccountID: "accID",
|
||||
Status: StatusPending,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Challenges: []*Challenge{
|
||||
{Status: StatusPending}, {Status: StatusPending}, {Status: StatusValid},
|
||||
},
|
||||
}
|
||||
return test{
|
||||
az: az,
|
||||
db: &MockDB{
|
||||
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
|
||||
assert.Equals(t, updaz.ID, az.ID)
|
||||
assert.Equals(t, updaz.AccountID, az.AccountID)
|
||||
assert.Equals(t, updaz.Status, StatusValid)
|
||||
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
|
||||
assert.Equals(t, updaz.Error, nil)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.az.UpdateStatus(context.Background(), tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
var k *Error
|
||||
if errors.As(err, &k) {
|
||||
assert.Equals(t, k.Type, tc.err.Type)
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
assert.Equals(t, k.Status, tc.err.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
} else {
|
||||
assert.FatalError(t, errors.New("unexpected error type"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
}
|
337
acme/authz.go
Normal file
337
acme/authz.go
Normal file
|
@ -0,0 +1,337 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
var defaultExpiryDuration = time.Hour * 24
|
||||
|
||||
// Authz is a subset of the Authz type containing only those attributes
|
||||
// required for responses in the ACME protocol.
|
||||
type Authz struct {
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status string `json:"status"`
|
||||
Expires string `json:"expires"`
|
||||
Challenges []*Challenge `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
ID string `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (a *Authz) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the Authz ID.
|
||||
func (a *Authz) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
// authz is the interface that the various authz types must implement.
|
||||
type authz interface {
|
||||
save(nosql.DB, authz) error
|
||||
clone() *baseAuthz
|
||||
getID() string
|
||||
getAccountID() string
|
||||
getType() string
|
||||
getIdentifier() Identifier
|
||||
getStatus() string
|
||||
getExpiry() time.Time
|
||||
getWildcard() bool
|
||||
getChallenges() []string
|
||||
getCreated() time.Time
|
||||
updateStatus(db nosql.DB) (authz, error)
|
||||
toACME(nosql.DB, *directory, provisioner.Interface) (*Authz, error)
|
||||
}
|
||||
|
||||
// baseAuthz is the base authz type that others build from.
|
||||
type baseAuthz struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status string `json:"status"`
|
||||
Expires time.Time `json:"expires"`
|
||||
Challenges []string `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
Created time.Time `json:"created"`
|
||||
Error *Error `json:"error"`
|
||||
}
|
||||
|
||||
func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
ba := &baseAuthz{
|
||||
ID: id,
|
||||
AccountID: accID,
|
||||
Status: StatusPending,
|
||||
Created: now,
|
||||
Expires: now.Add(defaultExpiryDuration),
|
||||
Identifier: identifier,
|
||||
}
|
||||
|
||||
if strings.HasPrefix(identifier.Value, "*.") {
|
||||
ba.Wildcard = true
|
||||
ba.Identifier = Identifier{
|
||||
Value: strings.TrimPrefix(identifier.Value, "*."),
|
||||
Type: identifier.Type,
|
||||
}
|
||||
}
|
||||
|
||||
return ba, nil
|
||||
}
|
||||
|
||||
// getID returns the ID of the authz.
|
||||
func (ba *baseAuthz) getID() string {
|
||||
return ba.ID
|
||||
}
|
||||
|
||||
// getAccountID returns the Account ID that created the authz.
|
||||
func (ba *baseAuthz) getAccountID() string {
|
||||
return ba.AccountID
|
||||
}
|
||||
|
||||
// getType returns the type of the authz.
|
||||
func (ba *baseAuthz) getType() string {
|
||||
return ba.Identifier.Type
|
||||
}
|
||||
|
||||
// getIdentifier returns the identifier for the authz.
|
||||
func (ba *baseAuthz) getIdentifier() Identifier {
|
||||
return ba.Identifier
|
||||
}
|
||||
|
||||
// getStatus returns the status of the authz.
|
||||
func (ba *baseAuthz) getStatus() string {
|
||||
return ba.Status
|
||||
}
|
||||
|
||||
// getWildcard returns true if the authz identifier has a '*', false otherwise.
|
||||
func (ba *baseAuthz) getWildcard() bool {
|
||||
return ba.Wildcard
|
||||
}
|
||||
|
||||
// getChallenges returns the authz challenge IDs.
|
||||
func (ba *baseAuthz) getChallenges() []string {
|
||||
return ba.Challenges
|
||||
}
|
||||
|
||||
// getExpiry returns the expiration time of the authz.
|
||||
func (ba *baseAuthz) getExpiry() time.Time {
|
||||
return ba.Expires
|
||||
}
|
||||
|
||||
// getCreated returns the created time of the authz.
|
||||
func (ba *baseAuthz) getCreated() time.Time {
|
||||
return ba.Created
|
||||
}
|
||||
|
||||
// toACME converts the internal Authz type into the public acmeAuthz type for
|
||||
// presentation in the ACME protocol.
|
||||
func (ba *baseAuthz) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Authz, error) {
|
||||
var chs = make([]*Challenge, len(ba.Challenges))
|
||||
for i, chID := range ba.Challenges {
|
||||
ch, err := getChallenge(db, chID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chs[i], err = ch.toACME(db, dir, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &Authz{
|
||||
Identifier: ba.Identifier,
|
||||
Status: ba.getStatus(),
|
||||
Challenges: chs,
|
||||
Wildcard: ba.getWildcard(),
|
||||
Expires: ba.Expires.Format(time.RFC3339),
|
||||
ID: ba.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) save(db nosql.DB, old authz) error {
|
||||
var (
|
||||
err error
|
||||
oldB, newB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old authz"))
|
||||
}
|
||||
}
|
||||
if newB, err = json.Marshal(ba); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new authz"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error storing authz"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("error storing authz; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) clone() *baseAuthz {
|
||||
u := *ba
|
||||
return &u
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) parent() authz {
|
||||
return &dnsAuthz{ba}
|
||||
}
|
||||
|
||||
// updateStatus attempts to update the status on a baseAuthz and stores the
|
||||
// updating object if necessary.
|
||||
func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) {
|
||||
newAuthz := ba.clone()
|
||||
|
||||
now := time.Now().UTC()
|
||||
switch ba.Status {
|
||||
case StatusInvalid:
|
||||
return ba.parent(), nil
|
||||
case StatusValid:
|
||||
return ba.parent(), nil
|
||||
case StatusPending:
|
||||
// check expiry
|
||||
if now.After(ba.Expires) {
|
||||
newAuthz.Status = StatusInvalid
|
||||
newAuthz.Error = MalformedErr(errors.New("authz has expired"))
|
||||
break
|
||||
}
|
||||
|
||||
var isValid = false
|
||||
for _, chID := range ba.Challenges {
|
||||
ch, err := getChallenge(db, chID)
|
||||
if err != nil {
|
||||
return ba, err
|
||||
}
|
||||
if ch.getStatus() == StatusValid {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
return ba.parent(), nil
|
||||
}
|
||||
newAuthz.Status = StatusValid
|
||||
newAuthz.Error = nil
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status))
|
||||
}
|
||||
|
||||
if err := newAuthz.save(db, ba); err != nil {
|
||||
return ba, err
|
||||
}
|
||||
return newAuthz.parent(), nil
|
||||
}
|
||||
|
||||
// unmarshalAuthz unmarshals an authz type into the correct sub-type.
|
||||
func unmarshalAuthz(data []byte) (authz, error) {
|
||||
var getType struct {
|
||||
Identifier Identifier `json:"identifier"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &getType); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type"))
|
||||
}
|
||||
|
||||
switch getType.Identifier.Type {
|
||||
case "dns":
|
||||
var ba baseAuthz
|
||||
if err := json.Unmarshal(data, &ba); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz"))
|
||||
}
|
||||
return &dnsAuthz{&ba}, nil
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s",
|
||||
getType.Identifier.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// dnsAuthz represents a dns acme authorization.
|
||||
type dnsAuthz struct {
|
||||
*baseAuthz
|
||||
}
|
||||
|
||||
// newAuthz returns a new acme authorization object based on the identifier
|
||||
// type.
|
||||
func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) {
|
||||
switch identifier.Type {
|
||||
case "dns":
|
||||
a, err = newDNSAuthz(db, accID, identifier)
|
||||
default:
|
||||
err = MalformedErr(errors.Errorf("unexpected authz type %s",
|
||||
identifier.Type))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// newDNSAuthz returns a new dns acme authorization object.
|
||||
func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) {
|
||||
ba, err := newBaseAuthz(accID, identifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ba.Challenges = []string{}
|
||||
if !ba.Wildcard {
|
||||
// http challenges are only permitted if the DNS is not a wildcard dns.
|
||||
ch1, err := newHTTP01Challenge(db, ChallengeOptions{
|
||||
AccountID: accID,
|
||||
AuthzID: ba.ID,
|
||||
Identifier: ba.Identifier})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating http challenge")
|
||||
}
|
||||
ba.Challenges = append(ba.Challenges, ch1.getID())
|
||||
}
|
||||
ch2, err := newDNS01Challenge(db, ChallengeOptions{
|
||||
AccountID: accID,
|
||||
AuthzID: ba.ID,
|
||||
Identifier: identifier})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating dns challenge")
|
||||
}
|
||||
ba.Challenges = append(ba.Challenges, ch2.getID())
|
||||
|
||||
da := &dnsAuthz{ba}
|
||||
if err := da.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return da, nil
|
||||
}
|
||||
|
||||
// getAuthz retrieves and unmarshals an ACME authz type from the database.
|
||||
func getAuthz(db nosql.DB, id string) (authz, error) {
|
||||
b, err := db.Get(authzTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id))
|
||||
}
|
||||
az, err := unmarshalAuthz(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return az, nil
|
||||
}
|
809
acme/authz_test.go
Normal file
809
acme/authz_test.go
Normal file
|
@ -0,0 +1,809 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func newAz() (authz, error) {
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
}
|
||||
return newAuthz(mockdb, "1234", Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAuthz(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
az authz
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Identifier.Type = "foo"
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, key, []byte(az.getID()))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, key, []byte(az.getID()))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if az, err := getAuthz(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.getID(), az.getID())
|
||||
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||||
assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier())
|
||||
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||||
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzClone(t *testing.T) {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
clone := az.clone()
|
||||
|
||||
assert.Equals(t, clone.getID(), az.getID())
|
||||
assert.Equals(t, clone.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, clone.getStatus(), az.getStatus())
|
||||
assert.Equals(t, clone.getIdentifier(), az.getIdentifier())
|
||||
assert.Equals(t, clone.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, clone.getCreated(), az.getCreated())
|
||||
assert.Equals(t, clone.getChallenges(), az.getChallenges())
|
||||
|
||||
clone.Status = StatusValid
|
||||
|
||||
assert.NotEquals(t, clone.getStatus(), az.getStatus())
|
||||
}
|
||||
|
||||
func TestNewAuthz(t *testing.T) {
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
accID := "1234"
|
||||
type test struct {
|
||||
iden Identifier
|
||||
db nosql.DB
|
||||
err *Error
|
||||
resChs *([]string)
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/unexpected-type": func(t *testing.T) test {
|
||||
return test{
|
||||
iden: Identifier{Type: "foo", Value: "acme.example.com"},
|
||||
err: MalformedErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"fail/new-http-chall-error": func(t *testing.T) test {
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")),
|
||||
}
|
||||
},
|
||||
"fail/new-dns-chall-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 1 {
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")),
|
||||
}
|
||||
},
|
||||
"fail/save-authz-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 2 {
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
chs := &([]string{})
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 2 {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
az, err := unmarshalAuthz(newval)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, az.getID(), string(key))
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
assert.Equals(t, az.getIdentifier(), iden)
|
||||
assert.Equals(t, az.getWildcard(), false)
|
||||
|
||||
*chs = az.getChallenges()
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
resChs: chs,
|
||||
}
|
||||
},
|
||||
"ok/wildcard": func(t *testing.T) test {
|
||||
chs := &([]string{})
|
||||
count := 0
|
||||
_iden := Identifier{Type: "dns", Value: "*.acme.example.com"}
|
||||
return test{
|
||||
iden: _iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 1 {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
az, err := unmarshalAuthz(newval)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, az.getID(), string(key))
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
assert.Equals(t, az.getIdentifier(), iden)
|
||||
assert.Equals(t, az.getWildcard(), true)
|
||||
|
||||
*chs = az.getChallenges()
|
||||
// Verify that we only have 1 challenge instead of 2.
|
||||
assert.True(t, len(*chs) == 1)
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
resChs: chs,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
az, err := newAuthz(tc.db, accID, tc.iden)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getType(), "dns")
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
|
||||
assert.Equals(t, az.getChallenges(), *(tc.resChs))
|
||||
|
||||
if strings.HasPrefix(tc.iden.Value, "*.") {
|
||||
assert.True(t, az.getWildcard())
|
||||
assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*."))
|
||||
} else {
|
||||
assert.False(t, az.getWildcard())
|
||||
assert.Equals(t, az.getIdentifier().Value, tc.iden.Value)
|
||||
}
|
||||
|
||||
assert.True(t, az.getID() != "")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzToACME(t *testing.T) {
|
||||
dir := newDirectory("ca.smallstep.com", "acme")
|
||||
|
||||
var (
|
||||
ch1, ch2 challenge
|
||||
ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||||
err error
|
||||
)
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
ch1, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
} else if count == 1 {
|
||||
*ch2Bytes = newval
|
||||
ch2, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
count++
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
prov := newProv()
|
||||
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/getChallenge1-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"fail/getChallenge2-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 1 {
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
return *ch2Bytes, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acmeAz, err := az.toACME(tc.db, dir, prov)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acmeAz.ID, az.getID())
|
||||
assert.Equals(t, acmeAz.Identifier, iden)
|
||||
assert.Equals(t, acmeAz.Status, StatusPending)
|
||||
|
||||
acmeCh1, err := ch1.toACME(nil, dir, prov)
|
||||
assert.FatalError(t, err)
|
||||
acmeCh2, err := ch2.toACME(nil, dir, prov)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, acmeAz.Challenges[0], acmeCh1)
|
||||
assert.Equals(t, acmeAz.Challenges[1], acmeCh2)
|
||||
|
||||
expiry, err := time.Parse(time.RFC3339, acmeAz.Expires)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, expiry.String(), az.getExpiry().String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzSave(t *testing.T) {
|
||||
type test struct {
|
||||
az, old authz
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok/old-nil": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, []byte(az.getID()), key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/old-not-nil": func(t *testing.T) test {
|
||||
oldAz, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
oldb, err := json.Marshal(oldAz)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: oldAz,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, []byte(az.getID()), key)
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.az.save(tc.db, tc.old); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzUnmarshal(t *testing.T) {
|
||||
type test struct {
|
||||
az authz
|
||||
azb []byte
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/nil": func(t *testing.T) test {
|
||||
return test{
|
||||
azb: nil,
|
||||
err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/unexpected-type": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Identifier.Type = "foo"
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
azb: b,
|
||||
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"ok/dns": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
azb: b,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if az, err := unmarshalAuthz(tc.azb); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.getID(), az.getID())
|
||||
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||||
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||||
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, tc.az.getWildcard(), az.getWildcard())
|
||||
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzUpdateStatus(t *testing.T) {
|
||||
type test struct {
|
||||
az, res authz
|
||||
err *Error
|
||||
db nosql.DB
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/already-invalid": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusInvalid
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
}
|
||||
},
|
||||
"fail/already-valid": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusValid
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
}
|
||||
},
|
||||
"fail/unexpected-status": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusReady
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
err: ServerInternalErr(errors.New("unrecognized authz status: ready")),
|
||||
}
|
||||
},
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"ok/expired": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||||
|
||||
clone := az.clone()
|
||||
clone.Error = MalformedErr(errors.New("authz has expired"))
|
||||
clone.Status = StatusInvalid
|
||||
return test{
|
||||
az: az,
|
||||
res: clone.parent(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/get-challenge-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"ok/valid": func(t *testing.T) test {
|
||||
var (
|
||||
ch2 challenge
|
||||
ch1Bytes = &([]byte{})
|
||||
err error
|
||||
)
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
} else if count == 1 {
|
||||
ch2, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Error = MalformedErr(nil)
|
||||
|
||||
_ch, ok := ch2.(*dns01Challenge)
|
||||
assert.Fatal(t, ok)
|
||||
_ch.baseChallenge.Status = StatusValid
|
||||
chb, err := json.Marshal(ch2)
|
||||
|
||||
clone := az.clone()
|
||||
clone.Status = StatusValid
|
||||
clone.Error = nil
|
||||
|
||||
count = 0
|
||||
return test{
|
||||
az: az,
|
||||
res: clone.parent(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
count++
|
||||
return chb, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/still-pending": func(t *testing.T) test {
|
||||
var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
} else if count == 1 {
|
||||
*ch2Bytes = newval
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
count = 0
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
count++
|
||||
return *ch2Bytes, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
az, err := tc.az.updateStatus(tc.db)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
expB, err := json.Marshal(tc.res)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, expB, b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,13 +2,88 @@ package acme
|
|||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Certificate options with which to create and store a cert object.
|
||||
type Certificate struct {
|
||||
ID string
|
||||
type certificate struct {
|
||||
ID string `json:"id"`
|
||||
Created time.Time `json:"created"`
|
||||
AccountID string `json:"accountID"`
|
||||
OrderID string `json:"orderID"`
|
||||
Leaf []byte `json:"leaf"`
|
||||
Intermediates []byte `json:"intermediates"`
|
||||
}
|
||||
|
||||
// CertOptions options with which to create and store a cert object.
|
||||
type CertOptions struct {
|
||||
AccountID string
|
||||
OrderID string
|
||||
Leaf *x509.Certificate
|
||||
Intermediates []*x509.Certificate
|
||||
}
|
||||
|
||||
func newCert(db nosql.DB, ops CertOptions) (*certificate, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leaf := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: ops.Leaf.Raw,
|
||||
})
|
||||
var intermediates []byte
|
||||
for _, cert := range ops.Intermediates {
|
||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
|
||||
cert := &certificate{
|
||||
ID: id,
|
||||
AccountID: ops.AccountID,
|
||||
OrderID: ops.OrderID,
|
||||
Leaf: leaf,
|
||||
Intermediates: intermediates,
|
||||
Created: time.Now().UTC(),
|
||||
}
|
||||
certB, err := json.Marshal(cert)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling certificate"))
|
||||
}
|
||||
|
||||
_, swapped, err := db.CmpAndSwap(certTable, []byte(id), nil, certB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error storing certificate"))
|
||||
case !swapped:
|
||||
return nil, ServerInternalErr(errors.New("error storing certificate; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return cert, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *certificate) toACME(db nosql.DB, dir *directory) ([]byte, error) {
|
||||
return append(c.Leaf, c.Intermediates...), nil
|
||||
}
|
||||
|
||||
func getCert(db nosql.DB, id string) (*certificate, error) {
|
||||
b, err := db.Get(certTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate"))
|
||||
}
|
||||
var cert certificate
|
||||
if err := json.Unmarshal(b, &cert); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate"))
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
|
253
acme/certificate_test.go
Normal file
253
acme/certificate_test.go
Normal file
|
@ -0,0 +1,253 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func defaultCertOps() (*CertOptions, error) {
|
||||
crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CertOptions{
|
||||
AccountID: "accID",
|
||||
OrderID: "ordID",
|
||||
Leaf: crt,
|
||||
Intermediates: []*x509.Certificate{inter, root},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newcert() (*certificate, error) {
|
||||
ops, err := defaultCertOps()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
return newCert(mockdb, *ops)
|
||||
}
|
||||
|
||||
func TestNewCert(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
ops CertOptions
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing certificate: force")),
|
||||
}
|
||||
},
|
||||
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
var _id string
|
||||
id := &_id
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
*id = string(key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if cert, err := newCert(tc.db, tc.ops); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, cert.ID, *tc.id)
|
||||
assert.Equals(t, cert.AccountID, tc.ops.AccountID)
|
||||
assert.Equals(t, cert.OrderID, tc.ops.OrderID)
|
||||
|
||||
leaf := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: tc.ops.Leaf.Raw,
|
||||
})
|
||||
var intermediates []byte
|
||||
for _, cert := range tc.ops.Intermediates {
|
||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, intermediates)
|
||||
|
||||
assert.True(t, cert.Created.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, cert.Created.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCert(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
cert *certificate
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading certificate: force")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(cert)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if cert, err := getCert(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.cert.ID, cert.ID)
|
||||
assert.Equals(t, tc.cert.AccountID, cert.AccountID)
|
||||
assert.Equals(t, tc.cert.OrderID, cert.OrderID)
|
||||
assert.Equals(t, tc.cert.Created, cert.Created)
|
||||
assert.Equals(t, tc.cert.Leaf, cert.Leaf)
|
||||
assert.Equals(t, tc.cert.Intermediates, cert.Intermediates)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateToACME(t *testing.T) {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
acmeCert, err := cert.toACME(nil, nil)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert)
|
||||
}
|
1405
acme/challenge.go
1405
acme/challenge.go
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -1,860 +0,0 @@
|
|||
//go:build tpmsimulator
|
||||
// +build tpmsimulator
|
||||
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/smallstep/go-attestation/attest"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/minica"
|
||||
"go.step.sm/crypto/tpm"
|
||||
"go.step.sm/crypto/tpm/simulator"
|
||||
tpmstorage "go.step.sm/crypto/tpm/storage"
|
||||
"go.step.sm/crypto/x509util"
|
||||
)
|
||||
|
||||
func newSimulatedTPM(t *testing.T) *tpm.TPM {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
tpm, err := tpm.New(withSimulator(t), tpm.WithStore(tpmstorage.NewDirstore(tmpDir))) // TODO: provide in-memory storage implementation instead
|
||||
require.NoError(t, err)
|
||||
return tpm
|
||||
}
|
||||
|
||||
func withSimulator(t *testing.T) tpm.NewTPMOption {
|
||||
t.Helper()
|
||||
var sim simulator.Simulator
|
||||
t.Cleanup(func() {
|
||||
if sim == nil {
|
||||
return
|
||||
}
|
||||
err := sim.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
sim, err := simulator.New()
|
||||
require.NoError(t, err)
|
||||
err = sim.Open()
|
||||
require.NoError(t, err)
|
||||
return tpm.WithSimulator(sim)
|
||||
}
|
||||
|
||||
func generateKeyID(t *testing.T, pub crypto.PublicKey) []byte {
|
||||
t.Helper()
|
||||
b, err := x509.MarshalPKIXPublicKey(pub)
|
||||
require.NoError(t, err)
|
||||
hash := sha256.Sum256(b)
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
func mustAttestTPM(t *testing.T, keyAuthorization string, permanentIdentifiers []string) ([]byte, crypto.Signer, *x509.Certificate) {
|
||||
t.Helper()
|
||||
aca, err := minica.New(
|
||||
minica.WithName("TPM Testing"),
|
||||
minica.WithGetSignerFunc(
|
||||
func() (crypto.Signer, error) {
|
||||
return keyutil.GenerateSigner("RSA", "", 2048)
|
||||
},
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// prepare simulated TPM and create an AK
|
||||
stpm := newSimulatedTPM(t)
|
||||
eks, err := stpm.GetEKs(context.Background())
|
||||
require.NoError(t, err)
|
||||
ak, err := stpm.CreateAK(context.Background(), "first-ak")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ak)
|
||||
|
||||
// extract the AK public key // TODO(hs): replace this when there's a simpler method to get the AK public key (e.g. ak.Public())
|
||||
ap, err := ak.AttestationParameters(context.Background())
|
||||
require.NoError(t, err)
|
||||
akp, err := attest.ParseAKPublic(attest.TPMVersion20, ap.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create template and sign certificate for the AK public key
|
||||
keyID := generateKeyID(t, eks[0].Public())
|
||||
template := &x509.Certificate{
|
||||
PublicKey: akp.Public,
|
||||
IsCA: false,
|
||||
UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate},
|
||||
}
|
||||
sans := []x509util.SubjectAlternativeName{}
|
||||
uris := []*url.URL{{Scheme: "urn", Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID)}}
|
||||
for _, pi := range permanentIdentifiers {
|
||||
sans = append(sans, x509util.SubjectAlternativeName{
|
||||
Type: x509util.PermanentIdentifierType,
|
||||
Value: pi,
|
||||
})
|
||||
}
|
||||
asn1Value := []byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMModel, "SLB 9670 TPM2.0", oidTPMVersion, "7.55"))
|
||||
sans = append(sans, x509util.SubjectAlternativeName{
|
||||
Type: x509util.DirectoryNameType,
|
||||
ASN1Value: asn1Value,
|
||||
})
|
||||
ext, err := createSubjectAltNameExtension(nil, nil, nil, uris, sans, true)
|
||||
require.NoError(t, err)
|
||||
ext.Set(template)
|
||||
akCert, err := aca.Sign(template)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, akCert)
|
||||
|
||||
// create a new key attested by the AK, while including
|
||||
// the key authorization bytes as qualifying data.
|
||||
keyAuthSum := sha256.Sum256([]byte(keyAuthorization))
|
||||
config := tpm.AttestKeyConfig{
|
||||
Algorithm: "RSA",
|
||||
Size: 2048,
|
||||
QualifyingData: keyAuthSum[:],
|
||||
}
|
||||
key, err := stpm.AttestKey(context.Background(), "first-ak", "first-key", config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
require.Equal(t, "first-key", key.Name())
|
||||
require.NotEqual(t, 0, len(key.Data()))
|
||||
require.Equal(t, "first-ak", key.AttestedBy())
|
||||
require.True(t, key.WasAttested())
|
||||
require.True(t, key.WasAttestedBy(ak))
|
||||
|
||||
signer, err := key.Signer(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// prepare the attestation object with the AK certificate chain,
|
||||
// the attested key, its metadata and the signature signed by the
|
||||
// AK.
|
||||
params, err := key.CertificationParameters(context.Background())
|
||||
require.NoError(t, err)
|
||||
attObj, err := cbor.Marshal(struct {
|
||||
Format string `json:"fmt"`
|
||||
AttStatement map[string]interface{} `json:"attStmt,omitempty"`
|
||||
}{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// marshal the ACME payload
|
||||
payload, err := json.Marshal(struct {
|
||||
AttObj string `json:"attObj"`
|
||||
}{
|
||||
AttObj: base64.RawURLEncoding.EncodeToString(attObj),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return payload, signer, aca.Root
|
||||
}
|
||||
|
||||
func Test_deviceAttest01ValidateWithTPMSimulator(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
ch *Challenge
|
||||
db DB
|
||||
jwk *jose.JSONWebKey
|
||||
payload []byte
|
||||
}
|
||||
type test struct {
|
||||
args args
|
||||
wantErr *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok/doTPMAttestationFormat-storeError": func(t *testing.T) test {
|
||||
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
|
||||
payload, _, root := mustAttestTPM(t, keyAuth, nil) // TODO: value(s) for AK cert?
|
||||
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
|
||||
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
|
||||
|
||||
// parse payload, set invalid "ver", remarshal
|
||||
var p payloadType
|
||||
err := json.Unmarshal(payload, &p)
|
||||
require.NoError(t, err)
|
||||
attObj, err := base64.RawURLEncoding.DecodeString(p.AttObj)
|
||||
require.NoError(t, err)
|
||||
att := attestationObject{}
|
||||
err = cbor.Unmarshal(attObj, &att)
|
||||
require.NoError(t, err)
|
||||
att.AttStatement["ver"] = "bogus"
|
||||
attObj, err = cbor.Marshal(struct {
|
||||
Format string `json:"fmt"`
|
||||
AttStatement map[string]interface{} `json:"attStmt,omitempty"`
|
||||
}{
|
||||
Format: "tpm",
|
||||
AttStatement: att.AttStatement,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
payload, err = json.Marshal(struct {
|
||||
AttObj string `json:"attObj"`
|
||||
}{
|
||||
AttObj: base64.RawURLEncoding.EncodeToString(attObj),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
jwk: jwk,
|
||||
ch: &Challenge{
|
||||
ID: "chID",
|
||||
AuthorizationID: "azID",
|
||||
Token: "token",
|
||||
Type: "device-attest-01",
|
||||
Status: StatusPending,
|
||||
Value: "device.id.12345678",
|
||||
},
|
||||
payload: payload,
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
assert.Equal(t, "azID", id)
|
||||
return &Authorization{ID: "azID"}, nil
|
||||
},
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
assert.Equal(t, "chID", updch.ID)
|
||||
assert.Equal(t, "token", updch.Token)
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "device.id.12345678", updch.Value)
|
||||
|
||||
err := NewError(ErrorBadAttestationStatementType, `version "bogus" is not supported`)
|
||||
|
||||
assert.EqualError(t, updch.Error.Err, err.Err.Error())
|
||||
assert.Equal(t, err.Type, updch.Error.Type)
|
||||
assert.Equal(t, err.Detail, updch.Error.Detail)
|
||||
assert.Equal(t, err.Status, updch.Error.Status)
|
||||
assert.Equal(t, err.Subproblems, updch.Error.Subproblems)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
}
|
||||
},
|
||||
"ok with invalid PermanentIdentifier SAN": func(t *testing.T) test {
|
||||
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
|
||||
payload, _, root := mustAttestTPM(t, keyAuth, []string{"device.id.12345678"}) // TODO: value(s) for AK cert?
|
||||
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
|
||||
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
|
||||
return test{
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
jwk: jwk,
|
||||
ch: &Challenge{
|
||||
ID: "chID",
|
||||
AuthorizationID: "azID",
|
||||
Token: "token",
|
||||
Type: "device-attest-01",
|
||||
Status: StatusPending,
|
||||
Value: "device.id.99999999",
|
||||
},
|
||||
payload: payload,
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
assert.Equal(t, "azID", id)
|
||||
return &Authorization{ID: "azID"}, nil
|
||||
},
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
assert.Equal(t, "chID", updch.ID)
|
||||
assert.Equal(t, "token", updch.Token)
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "device.id.99999999", updch.Value)
|
||||
|
||||
err := NewError(ErrorRejectedIdentifierType, `permanent identifier does not match`).
|
||||
AddSubproblems(NewSubproblemWithIdentifier(
|
||||
ErrorMalformedType,
|
||||
Identifier{Type: "permanent-identifier", Value: "device.id.99999999"},
|
||||
`challenge identifier "device.id.99999999" doesn't match any of the attested hardware identifiers ["device.id.12345678"]`,
|
||||
))
|
||||
|
||||
assert.EqualError(t, updch.Error.Err, err.Err.Error())
|
||||
assert.Equal(t, err.Type, updch.Error.Type)
|
||||
assert.Equal(t, err.Detail, updch.Error.Detail)
|
||||
assert.Equal(t, err.Status, updch.Error.Status)
|
||||
assert.Equal(t, err.Subproblems, updch.Error.Subproblems)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
|
||||
payload, signer, root := mustAttestTPM(t, keyAuth, nil) // TODO: value(s) for AK cert?
|
||||
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
|
||||
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
|
||||
return test{
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
jwk: jwk,
|
||||
ch: &Challenge{
|
||||
ID: "chID",
|
||||
AuthorizationID: "azID",
|
||||
Token: "token",
|
||||
Type: "device-attest-01",
|
||||
Status: StatusPending,
|
||||
Value: "device.id.12345678",
|
||||
},
|
||||
payload: payload,
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
assert.Equal(t, "azID", id)
|
||||
return &Authorization{ID: "azID"}, nil
|
||||
},
|
||||
MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error {
|
||||
fingerprint, err := keyutil.Fingerprint(signer.Public())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "azID", az.ID)
|
||||
assert.Equal(t, fingerprint, az.Fingerprint)
|
||||
return nil
|
||||
},
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
assert.Equal(t, "chID", updch.ID)
|
||||
assert.Equal(t, "token", updch.Token)
|
||||
assert.Equal(t, StatusValid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "device.id.12345678", updch.Value)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
}
|
||||
},
|
||||
"ok with PermanentIdentifier SAN": func(t *testing.T) test {
|
||||
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
|
||||
payload, signer, root := mustAttestTPM(t, keyAuth, []string{"device.id.12345678"}) // TODO: value(s) for AK cert?
|
||||
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
|
||||
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
|
||||
return test{
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
jwk: jwk,
|
||||
ch: &Challenge{
|
||||
ID: "chID",
|
||||
AuthorizationID: "azID",
|
||||
Token: "token",
|
||||
Type: "device-attest-01",
|
||||
Status: StatusPending,
|
||||
Value: "device.id.12345678",
|
||||
},
|
||||
payload: payload,
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
assert.Equal(t, "azID", id)
|
||||
return &Authorization{ID: "azID"}, nil
|
||||
},
|
||||
MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error {
|
||||
fingerprint, err := keyutil.Fingerprint(signer.Public())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "azID", az.ID)
|
||||
assert.Equal(t, fingerprint, az.Fingerprint)
|
||||
return nil
|
||||
},
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
assert.Equal(t, "chID", updch.ID)
|
||||
assert.Equal(t, "token", updch.Token)
|
||||
assert.Equal(t, StatusValid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "device.id.12345678", updch.Value)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
if err := deviceAttest01Validate(tc.args.ctx, tc.args.ch, tc.args.db, tc.args.jwk, tc.args.payload); err != nil {
|
||||
assert.Error(t, tc.wantErr)
|
||||
assert.EqualError(t, err, tc.wantErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.Nil(t, tc.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newBadAttestationStatementError(msg string) *Error {
|
||||
return &Error{
|
||||
Type: "urn:ietf:params:acme:error:badAttestationStatement",
|
||||
Status: 400,
|
||||
Err: errors.New(msg),
|
||||
}
|
||||
}
|
||||
|
||||
func newInternalServerError(msg string) *Error {
|
||||
return &Error{
|
||||
Type: "urn:ietf:params:acme:error:serverInternal",
|
||||
Status: 500,
|
||||
Err: errors.New(msg),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
oidPermanentIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}
|
||||
oidHardwareModuleNameIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 4}
|
||||
)
|
||||
|
||||
func Test_doTPMAttestationFormat(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
aca, err := minica.New(
|
||||
minica.WithName("TPM Testing"),
|
||||
minica.WithGetSignerFunc(
|
||||
func() (crypto.Signer, error) {
|
||||
return keyutil.GenerateSigner("RSA", "", 2048)
|
||||
},
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
acaRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: aca.Root.Raw})
|
||||
|
||||
// prepare simulated TPM and create an AK
|
||||
stpm := newSimulatedTPM(t)
|
||||
eks, err := stpm.GetEKs(context.Background())
|
||||
require.NoError(t, err)
|
||||
ak, err := stpm.CreateAK(context.Background(), "first-ak")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ak)
|
||||
|
||||
// extract the AK public key // TODO(hs): replace this when there's a simpler method to get the AK public key (e.g. ak.Public())
|
||||
ap, err := ak.AttestationParameters(context.Background())
|
||||
require.NoError(t, err)
|
||||
akp, err := attest.ParseAKPublic(attest.TPMVersion20, ap.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create template and sign certificate for the AK public key
|
||||
keyID := generateKeyID(t, eks[0].Public())
|
||||
template := &x509.Certificate{
|
||||
PublicKey: akp.Public,
|
||||
IsCA: false,
|
||||
UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate},
|
||||
}
|
||||
sans := []x509util.SubjectAlternativeName{}
|
||||
uris := []*url.URL{{Scheme: "urn", Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID)}}
|
||||
asn1Value := []byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMModel, "SLB 9670 TPM2.0", oidTPMVersion, "7.55"))
|
||||
sans = append(sans, x509util.SubjectAlternativeName{
|
||||
Type: x509util.DirectoryNameType,
|
||||
ASN1Value: asn1Value,
|
||||
})
|
||||
ext, err := createSubjectAltNameExtension(nil, nil, nil, uris, sans, true)
|
||||
require.NoError(t, err)
|
||||
ext.Set(template)
|
||||
akCert, err := aca.Sign(template)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, akCert)
|
||||
|
||||
invalidTemplate := &x509.Certificate{
|
||||
PublicKey: akp.Public,
|
||||
IsCA: false,
|
||||
UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate},
|
||||
}
|
||||
invalidAKCert, err := aca.Sign(invalidTemplate)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, invalidAKCert)
|
||||
|
||||
// generate a JWK and the key authorization value
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
require.NoError(t, err)
|
||||
keyAuthorization, err := KeyAuthorization("token", jwk)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create a new key attested by the AK, while including
|
||||
// the key authorization bytes as qualifying data.
|
||||
keyAuthSum := sha256.Sum256([]byte(keyAuthorization))
|
||||
config := tpm.AttestKeyConfig{
|
||||
Algorithm: "RSA",
|
||||
Size: 2048,
|
||||
QualifyingData: keyAuthSum[:],
|
||||
}
|
||||
key, err := stpm.AttestKey(context.Background(), "first-ak", "first-key", config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
params, err := key.CertificationParameters(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
signer, err := key.Signer(context.Background())
|
||||
require.NoError(t, err)
|
||||
fingerprint, err := keyutil.Fingerprint(signer.Public())
|
||||
require.NoError(t, err)
|
||||
|
||||
// attest another key and get its certification parameters
|
||||
anotherKey, err := stpm.AttestKey(context.Background(), "first-ak", "another-key", config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
anotherKeyParams, err := anotherKey.CertificationParameters(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
prov Provisioner
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
att *attestationObject
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *tpmAttestationData
|
||||
expErr *Error
|
||||
}{
|
||||
{"ok", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, nil},
|
||||
{"fail ver not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("ver not present")},
|
||||
{"fail ver type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": []interface{}{},
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("ver not present")},
|
||||
{"fail bogus ver", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "bogus",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError(`version "bogus" is not supported`)},
|
||||
{"fail x5c not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c not present")},
|
||||
{"fail x5c type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": [][]byte{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c not present")},
|
||||
{"fail x5c empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is empty")},
|
||||
{"fail leaf type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "step",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{"leaf", aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is malformed")},
|
||||
{"fail leaf parse", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "step",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw[:100], aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is malformed: x509: malformed certificate")},
|
||||
{"fail intermediate type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "step",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, "intermediate"},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is malformed")},
|
||||
{"fail intermediate parse", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "step",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw[:100]},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is malformed: x509: malformed certificate")},
|
||||
{"fail roots", args{ctx, mustAttestationProvisioner(t, nil), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newInternalServerError("no root CA bundle available to verify the attestation certificate")},
|
||||
{"fail verify", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "step",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is not valid: x509: certificate signed by unknown authority")},
|
||||
{"fail validateAKCertificate", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{invalidAKCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("AK certificate is not valid: missing TPM manufacturer")},
|
||||
{"fail pubArea not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid pubArea in attestation statement")},
|
||||
{"fail pubArea type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": []interface{}{},
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid pubArea in attestation statement")},
|
||||
{"fail pubArea empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": []byte{},
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("pubArea is empty")},
|
||||
{"fail sig not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid sig in attestation statement")},
|
||||
{"fail sig type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": []interface{}{},
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid sig in attestation statement")},
|
||||
{"fail sig empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": []byte{},
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("sig is empty")},
|
||||
{"fail certInfo not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid certInfo in attestation statement")},
|
||||
{"fail certInfo type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": []interface{}{},
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid certInfo in attestation statement")},
|
||||
{"fail certInfo empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": []byte{},
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("certInfo is empty")},
|
||||
{"fail alg not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid alg in attestation statement")},
|
||||
{"fail alg type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(0), // invalid alg
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid alg 0 in attestation statement")},
|
||||
{"fail attestation verification", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": anotherKeyParams.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid certification parameters: certification refers to a different key")},
|
||||
{"fail keyAuthorization", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, &jose.JSONWebKey{Key: []byte("not an asymmetric key")}, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newInternalServerError("failed creating key auth digest: error generating JWK thumbprint: square/go-jose: unknown key type '[]uint8'")},
|
||||
{"fail different keyAuthorization", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "aDifferentToken"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), //
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("key authorization does not match")},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := doTPMAttestationFormat(tt.args.ctx, tt.args.prov, tt.args.ch, tt.args.jwk, tt.args.att)
|
||||
if tt.expErr != nil {
|
||||
var ae *Error
|
||||
if assert.True(t, errors.As(err, &ae)) {
|
||||
assert.EqualError(t, err, tt.expErr.Error())
|
||||
assert.Equal(t, ae.StatusCode(), tt.expErr.StatusCode())
|
||||
assert.Equal(t, ae.Type, tt.expErr.Type)
|
||||
}
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
if assert.NotNil(t, got) {
|
||||
assert.Equal(t, akCert, got.Certificate)
|
||||
assert.Equal(t, [][]*x509.Certificate{
|
||||
{
|
||||
akCert, aca.Intermediate, aca.Root,
|
||||
},
|
||||
}, got.VerifiedChains)
|
||||
assert.Equal(t, fingerprint, got.Fingerprint)
|
||||
assert.Empty(t, got.PermanentIdentifiers) // currently expected to be always empty
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,80 +0,0 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client is the interface used to verify ACME challenges.
|
||||
type Client interface {
|
||||
// Get issues an HTTP GET to the specified URL.
|
||||
Get(url string) (*http.Response, error)
|
||||
|
||||
// LookupTXT returns the DNS TXT records for the given domain name.
|
||||
LookupTxt(name string) ([]string, error)
|
||||
|
||||
// TLSDial connects to the given network address using net.Dialer and then
|
||||
// initiates a TLS handshake, returning the resulting TLS connection.
|
||||
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
type clientKey struct{}
|
||||
|
||||
// NewClientContext adds the given client to the context.
|
||||
func NewClientContext(ctx context.Context, c Client) context.Context {
|
||||
return context.WithValue(ctx, clientKey{}, c)
|
||||
}
|
||||
|
||||
// ClientFromContext returns the current client from the given context.
|
||||
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
|
||||
c, ok = ctx.Value(clientKey{}).(Client)
|
||||
return
|
||||
}
|
||||
|
||||
// MustClientFromContext returns the current client from the given context. It will
|
||||
// return a new instance of the client if it does not exist.
|
||||
func MustClientFromContext(ctx context.Context) Client {
|
||||
c, ok := ClientFromContext(ctx)
|
||||
if !ok {
|
||||
return NewClient()
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type client struct {
|
||||
http *http.Client
|
||||
dialer *net.Dialer
|
||||
}
|
||||
|
||||
// NewClient returns an implementation of Client for verifying ACME challenges.
|
||||
func NewClient() Client {
|
||||
return &client{
|
||||
http: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
//nolint:gosec // used on tls-alpn-01 challenge
|
||||
InsecureSkipVerify: true, // lgtm[go/disabled-certificate-check]
|
||||
},
|
||||
},
|
||||
},
|
||||
dialer: &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Get(url string) (*http.Response, error) {
|
||||
return c.http.Get(url)
|
||||
}
|
||||
|
||||
func (c *client) LookupTxt(name string) ([]string, error) {
|
||||
return net.LookupTXT(name)
|
||||
}
|
||||
|
||||
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(c.dialer, network, addr, config)
|
||||
}
|
230
acme/common.go
230
acme/common.go
|
@ -1,199 +1,65 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
)
|
||||
|
||||
// SignAuthority is the interface implemented by a CA authority.
|
||||
type SignAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// Identifier encodes the type that an order pertains to.
|
||||
type Identifier struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
var (
|
||||
// StatusValid -- valid
|
||||
StatusValid = "valid"
|
||||
// StatusInvalid -- invalid
|
||||
StatusInvalid = "invalid"
|
||||
// StatusPending -- pending; e.g. an Order that is not ready to be finalized.
|
||||
StatusPending = "pending"
|
||||
// StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
|
||||
StatusDeactivated = "deactivated"
|
||||
// StatusReady -- ready; e.g. for an Order that is ready to be finalized.
|
||||
StatusReady = "ready"
|
||||
//statusExpired = "expired"
|
||||
//statusActive = "active"
|
||||
//statusProcessing = "processing"
|
||||
)
|
||||
|
||||
var idLen = 32
|
||||
|
||||
func randID() (val string, err error) {
|
||||
val, err = randutil.Alphanumeric(idLen)
|
||||
if err != nil {
|
||||
return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock struct{}
|
||||
type Clock int
|
||||
|
||||
// Now returns the UTC time rounded to seconds.
|
||||
func (c *Clock) Now() time.Time {
|
||||
return time.Now().UTC().Truncate(time.Second)
|
||||
return time.Now().UTC().Round(time.Second)
|
||||
}
|
||||
|
||||
var clock Clock
|
||||
var clock = new(Clock)
|
||||
|
||||
// CertificateAuthority is the interface implemented by a CA authority.
|
||||
type CertificateAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
AreSANsAllowed(ctx context.Context, sans []string) error
|
||||
IsRevoked(sn string) (bool, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// NewContext adds the given acme components to the context.
|
||||
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
|
||||
ctx = NewDatabaseContext(ctx, db)
|
||||
ctx = NewClientContext(ctx, client)
|
||||
ctx = NewLinkerContext(ctx, linker)
|
||||
// Prerequisite checker is optional.
|
||||
if fn != nil {
|
||||
ctx = NewPrerequisitesCheckerContext(ctx, fn)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// PrerequisitesChecker is a function that checks if all prerequisites for
|
||||
// serving ACME are met by the CA configuration.
|
||||
type PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
|
||||
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
|
||||
// always true.
|
||||
func DefaultPrerequisitesChecker(context.Context) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type prerequisitesKey struct{}
|
||||
|
||||
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
|
||||
// context.
|
||||
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
|
||||
return context.WithValue(ctx, prerequisitesKey{}, fn)
|
||||
}
|
||||
|
||||
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
|
||||
// context.
|
||||
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
|
||||
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
|
||||
return fn, ok && fn != nil
|
||||
}
|
||||
|
||||
// Provisioner is an interface that implements a subset of the provisioner.Interface --
|
||||
// only those methods required by the ACME api/authority.
|
||||
type Provisioner interface {
|
||||
AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error
|
||||
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
|
||||
AuthorizeRevoke(ctx context.Context, token string) error
|
||||
IsChallengeEnabled(ctx context.Context, challenge provisioner.ACMEChallenge) bool
|
||||
IsAttestationFormatEnabled(ctx context.Context, format provisioner.ACMEAttestationFormat) bool
|
||||
GetAttestationRoots() (*x509.CertPool, bool)
|
||||
GetID() string
|
||||
GetName() string
|
||||
DefaultTLSCertDuration() time.Duration
|
||||
GetOptions() *provisioner.Options
|
||||
}
|
||||
|
||||
type provisionerKey struct{}
|
||||
|
||||
// NewProvisionerContext adds the given provisioner to the context.
|
||||
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
|
||||
return context.WithValue(ctx, provisionerKey{}, v)
|
||||
}
|
||||
|
||||
// ProvisionerFromContext returns the current provisioner from the given context.
|
||||
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
|
||||
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current provisioner from the given context.
|
||||
// It will panic if it's not in the context.
|
||||
func MustProvisionerFromContext(ctx context.Context) Provisioner {
|
||||
if v, ok := ProvisionerFromContext(ctx); !ok {
|
||||
panic("acme provisioner is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// MockProvisioner for testing
|
||||
type MockProvisioner struct {
|
||||
Mret1 interface{}
|
||||
Merr error
|
||||
MgetID func() string
|
||||
MgetName func() string
|
||||
MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error
|
||||
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
MauthorizeRevoke func(ctx context.Context, token string) error
|
||||
MisChallengeEnabled func(ctx context.Context, challenge provisioner.ACMEChallenge) bool
|
||||
MisAttFormatEnabled func(ctx context.Context, format provisioner.ACMEAttestationFormat) bool
|
||||
MgetAttestationRoots func() (*x509.CertPool, bool)
|
||||
MdefaultTLSCertDuration func() time.Duration
|
||||
MgetOptions func() *provisioner.Options
|
||||
}
|
||||
|
||||
// GetName mock
|
||||
func (m *MockProvisioner) GetName() string {
|
||||
if m.MgetName != nil {
|
||||
return m.MgetName()
|
||||
}
|
||||
return m.Mret1.(string)
|
||||
}
|
||||
|
||||
// AuthorizeOrderIdentifiers mock
|
||||
func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
|
||||
if m.MauthorizeOrderIdentifier != nil {
|
||||
return m.MauthorizeOrderIdentifier(ctx, identifier)
|
||||
}
|
||||
return m.Merr
|
||||
}
|
||||
|
||||
// AuthorizeSign mock
|
||||
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
if m.MauthorizeSign != nil {
|
||||
return m.MauthorizeSign(ctx, ott)
|
||||
}
|
||||
return m.Mret1.([]provisioner.SignOption), m.Merr
|
||||
}
|
||||
|
||||
// AuthorizeRevoke mock
|
||||
func (m *MockProvisioner) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
if m.MauthorizeRevoke != nil {
|
||||
return m.MauthorizeRevoke(ctx, token)
|
||||
}
|
||||
return m.Merr
|
||||
}
|
||||
|
||||
// IsChallengeEnabled mock
|
||||
func (m *MockProvisioner) IsChallengeEnabled(ctx context.Context, challenge provisioner.ACMEChallenge) bool {
|
||||
if m.MisChallengeEnabled != nil {
|
||||
return m.MisChallengeEnabled(ctx, challenge)
|
||||
}
|
||||
return m.Merr == nil
|
||||
}
|
||||
|
||||
// IsAttestationFormatEnabled mock
|
||||
func (m *MockProvisioner) IsAttestationFormatEnabled(ctx context.Context, format provisioner.ACMEAttestationFormat) bool {
|
||||
if m.MisAttFormatEnabled != nil {
|
||||
return m.MisAttFormatEnabled(ctx, format)
|
||||
}
|
||||
return m.Merr == nil
|
||||
}
|
||||
|
||||
func (m *MockProvisioner) GetAttestationRoots() (*x509.CertPool, bool) {
|
||||
if m.MgetAttestationRoots != nil {
|
||||
return m.MgetAttestationRoots()
|
||||
}
|
||||
return m.Mret1.(*x509.CertPool), m.Mret1 != nil
|
||||
}
|
||||
|
||||
// DefaultTLSCertDuration mock
|
||||
func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration {
|
||||
if m.MdefaultTLSCertDuration != nil {
|
||||
return m.MdefaultTLSCertDuration()
|
||||
}
|
||||
return m.Mret1.(time.Duration)
|
||||
}
|
||||
|
||||
// GetOptions mock
|
||||
func (m *MockProvisioner) GetOptions() *provisioner.Options {
|
||||
if m.MgetOptions != nil {
|
||||
return m.MgetOptions()
|
||||
}
|
||||
return m.Mret1.(*provisioner.Options)
|
||||
}
|
||||
|
||||
// GetID mock
|
||||
func (m *MockProvisioner) GetID() string {
|
||||
if m.MgetID != nil {
|
||||
return m.MgetID()
|
||||
}
|
||||
return m.Mret1.(string)
|
||||
// URLSafeProvisionerName returns a path escaped version of the ACME provisioner
|
||||
// ID that is safe to use in URL paths.
|
||||
func URLSafeProvisionerName(p provisioner.Interface) string {
|
||||
return url.PathEscape(p.GetName())
|
||||
}
|
||||
|
|
390
acme/db.go
390
acme/db.go
|
@ -1,390 +0,0 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ErrNotFound is an error that should be used by the acme.DB interface to
|
||||
// indicate that an entity does not exist. For example, in the new-account
|
||||
// endpoint, if GetAccountByKeyID returns ErrNotFound we will create the new
|
||||
// account.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// IsErrNotFound returns true if the error is a "not found" error. Returns false
|
||||
// otherwise.
|
||||
func IsErrNotFound(err error) bool {
|
||||
return errors.Is(err, ErrNotFound)
|
||||
}
|
||||
|
||||
// DB is the DB interface expected by the step-ca ACME API.
|
||||
type DB interface {
|
||||
CreateAccount(ctx context.Context, acc *Account) error
|
||||
GetAccount(ctx context.Context, id string) (*Account, error)
|
||||
GetAccountByKeyID(ctx context.Context, kid string) (*Account, error)
|
||||
UpdateAccount(ctx context.Context, acc *Account) error
|
||||
|
||||
CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
|
||||
GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
|
||||
GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
|
||||
GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
|
||||
GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
|
||||
DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error
|
||||
UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
|
||||
|
||||
CreateNonce(ctx context.Context) (Nonce, error)
|
||||
DeleteNonce(ctx context.Context, nonce Nonce) error
|
||||
|
||||
CreateAuthorization(ctx context.Context, az *Authorization) error
|
||||
GetAuthorization(ctx context.Context, id string) (*Authorization, error)
|
||||
UpdateAuthorization(ctx context.Context, az *Authorization) error
|
||||
GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error)
|
||||
|
||||
CreateCertificate(ctx context.Context, cert *Certificate) error
|
||||
GetCertificate(ctx context.Context, id string) (*Certificate, error)
|
||||
GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error)
|
||||
|
||||
CreateChallenge(ctx context.Context, ch *Challenge) error
|
||||
GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error)
|
||||
UpdateChallenge(ctx context.Context, ch *Challenge) error
|
||||
|
||||
CreateOrder(ctx context.Context, o *Order) error
|
||||
GetOrder(ctx context.Context, id string) (*Order, error)
|
||||
GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error)
|
||||
UpdateOrder(ctx context.Context, o *Order) error
|
||||
}
|
||||
|
||||
type dbKey struct{}
|
||||
|
||||
// NewDatabaseContext adds the given acme database to the context.
|
||||
func NewDatabaseContext(ctx context.Context, db DB) context.Context {
|
||||
return context.WithValue(ctx, dbKey{}, db)
|
||||
}
|
||||
|
||||
// DatabaseFromContext returns the current acme database from the given context.
|
||||
func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
|
||||
db, ok = ctx.Value(dbKey{}).(DB)
|
||||
return
|
||||
}
|
||||
|
||||
// MustDatabaseFromContext returns the current database from the given context.
|
||||
// It will panic if it's not in the context.
|
||||
func MustDatabaseFromContext(ctx context.Context) DB {
|
||||
if db, ok := DatabaseFromContext(ctx); !ok {
|
||||
panic("acme database is not in the context")
|
||||
} else {
|
||||
return db
|
||||
}
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
// a mock in tests.
|
||||
type MockDB struct {
|
||||
MockCreateAccount func(ctx context.Context, acc *Account) error
|
||||
MockGetAccount func(ctx context.Context, id string) (*Account, error)
|
||||
MockGetAccountByKeyID func(ctx context.Context, kid string) (*Account, error)
|
||||
MockUpdateAccount func(ctx context.Context, acc *Account) error
|
||||
|
||||
MockCreateExternalAccountKey func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
|
||||
MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
|
||||
MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
|
||||
MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
|
||||
MockGetExternalAccountKeyByAccountID func(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
|
||||
MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error
|
||||
MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
|
||||
|
||||
MockCreateNonce func(ctx context.Context) (Nonce, error)
|
||||
MockDeleteNonce func(ctx context.Context, nonce Nonce) error
|
||||
|
||||
MockCreateAuthorization func(ctx context.Context, az *Authorization) error
|
||||
MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error)
|
||||
MockUpdateAuthorization func(ctx context.Context, az *Authorization) error
|
||||
MockGetAuthorizationsByAccountID func(ctx context.Context, accountID string) ([]*Authorization, error)
|
||||
|
||||
MockCreateCertificate func(ctx context.Context, cert *Certificate) error
|
||||
MockGetCertificate func(ctx context.Context, id string) (*Certificate, error)
|
||||
MockGetCertificateBySerial func(ctx context.Context, serial string) (*Certificate, error)
|
||||
|
||||
MockCreateChallenge func(ctx context.Context, ch *Challenge) error
|
||||
MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error)
|
||||
MockUpdateChallenge func(ctx context.Context, ch *Challenge) error
|
||||
|
||||
MockCreateOrder func(ctx context.Context, o *Order) error
|
||||
MockGetOrder func(ctx context.Context, id string) (*Order, error)
|
||||
MockGetOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error)
|
||||
MockUpdateOrder func(ctx context.Context, o *Order) error
|
||||
|
||||
MockRet1 interface{}
|
||||
MockError error
|
||||
}
|
||||
|
||||
// CreateAccount mock.
|
||||
func (m *MockDB) CreateAccount(ctx context.Context, acc *Account) error {
|
||||
if m.MockCreateAccount != nil {
|
||||
return m.MockCreateAccount(ctx, acc)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetAccount mock.
|
||||
func (m *MockDB) GetAccount(ctx context.Context, id string) (*Account, error) {
|
||||
if m.MockGetAccount != nil {
|
||||
return m.MockGetAccount(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Account), m.MockError
|
||||
}
|
||||
|
||||
// GetAccountByKeyID mock
|
||||
func (m *MockDB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) {
|
||||
if m.MockGetAccountByKeyID != nil {
|
||||
return m.MockGetAccountByKeyID(ctx, kid)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Account), m.MockError
|
||||
}
|
||||
|
||||
// UpdateAccount mock
|
||||
func (m *MockDB) UpdateAccount(ctx context.Context, acc *Account) error {
|
||||
if m.MockUpdateAccount != nil {
|
||||
return m.MockUpdateAccount(ctx, acc)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateExternalAccountKey mock
|
||||
func (m *MockDB) CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) {
|
||||
if m.MockCreateExternalAccountKey != nil {
|
||||
return m.MockCreateExternalAccountKey(ctx, provisionerID, reference)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*ExternalAccountKey), m.MockError
|
||||
}
|
||||
|
||||
// GetExternalAccountKey mock
|
||||
func (m *MockDB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) {
|
||||
if m.MockGetExternalAccountKey != nil {
|
||||
return m.MockGetExternalAccountKey(ctx, provisionerID, keyID)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*ExternalAccountKey), m.MockError
|
||||
}
|
||||
|
||||
// GetExternalAccountKeys mock
|
||||
func (m *MockDB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) {
|
||||
if m.MockGetExternalAccountKeys != nil {
|
||||
return m.MockGetExternalAccountKeys(ctx, provisionerID, cursor, limit)
|
||||
} else if m.MockError != nil {
|
||||
return nil, "", m.MockError
|
||||
}
|
||||
return m.MockRet1.([]*ExternalAccountKey), "", m.MockError
|
||||
}
|
||||
|
||||
// GetExternalAccountKeyByReference mock
|
||||
func (m *MockDB) GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) {
|
||||
if m.MockGetExternalAccountKeyByReference != nil {
|
||||
return m.MockGetExternalAccountKeyByReference(ctx, provisionerID, reference)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*ExternalAccountKey), m.MockError
|
||||
}
|
||||
|
||||
// GetExternalAccountKeyByAccountID mock
|
||||
func (m *MockDB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) {
|
||||
if m.MockGetExternalAccountKeyByAccountID != nil {
|
||||
return m.MockGetExternalAccountKeyByAccountID(ctx, provisionerID, accountID)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*ExternalAccountKey), m.MockError
|
||||
}
|
||||
|
||||
// DeleteExternalAccountKey mock
|
||||
func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error {
|
||||
if m.MockDeleteExternalAccountKey != nil {
|
||||
return m.MockDeleteExternalAccountKey(ctx, provisionerID, keyID)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// UpdateExternalAccountKey mock
|
||||
func (m *MockDB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error {
|
||||
if m.MockUpdateExternalAccountKey != nil {
|
||||
return m.MockUpdateExternalAccountKey(ctx, provisionerID, eak)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateNonce mock
|
||||
func (m *MockDB) CreateNonce(ctx context.Context) (Nonce, error) {
|
||||
if m.MockCreateNonce != nil {
|
||||
return m.MockCreateNonce(ctx)
|
||||
} else if m.MockError != nil {
|
||||
return Nonce(""), m.MockError
|
||||
}
|
||||
return m.MockRet1.(Nonce), m.MockError
|
||||
}
|
||||
|
||||
// DeleteNonce mock
|
||||
func (m *MockDB) DeleteNonce(ctx context.Context, nonce Nonce) error {
|
||||
if m.MockDeleteNonce != nil {
|
||||
return m.MockDeleteNonce(ctx, nonce)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateAuthorization mock
|
||||
func (m *MockDB) CreateAuthorization(ctx context.Context, az *Authorization) error {
|
||||
if m.MockCreateAuthorization != nil {
|
||||
return m.MockCreateAuthorization(ctx, az)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetAuthorization mock
|
||||
func (m *MockDB) GetAuthorization(ctx context.Context, id string) (*Authorization, error) {
|
||||
if m.MockGetAuthorization != nil {
|
||||
return m.MockGetAuthorization(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Authorization), m.MockError
|
||||
}
|
||||
|
||||
// UpdateAuthorization mock
|
||||
func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) error {
|
||||
if m.MockUpdateAuthorization != nil {
|
||||
return m.MockUpdateAuthorization(ctx, az)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetAuthorizationsByAccountID mock
|
||||
func (m *MockDB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) {
|
||||
if m.MockGetAuthorizationsByAccountID != nil {
|
||||
return m.MockGetAuthorizationsByAccountID(ctx, accountID)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return nil, m.MockError
|
||||
}
|
||||
|
||||
// CreateCertificate mock
|
||||
func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error {
|
||||
if m.MockCreateCertificate != nil {
|
||||
return m.MockCreateCertificate(ctx, cert)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetCertificate mock
|
||||
func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, error) {
|
||||
if m.MockGetCertificate != nil {
|
||||
return m.MockGetCertificate(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Certificate), m.MockError
|
||||
}
|
||||
|
||||
// GetCertificateBySerial mock
|
||||
func (m *MockDB) GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error) {
|
||||
if m.MockGetCertificateBySerial != nil {
|
||||
return m.MockGetCertificateBySerial(ctx, serial)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Certificate), m.MockError
|
||||
}
|
||||
|
||||
// CreateChallenge mock
|
||||
func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error {
|
||||
if m.MockCreateChallenge != nil {
|
||||
return m.MockCreateChallenge(ctx, ch)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetChallenge mock
|
||||
func (m *MockDB) GetChallenge(ctx context.Context, chID, azID string) (*Challenge, error) {
|
||||
if m.MockGetChallenge != nil {
|
||||
return m.MockGetChallenge(ctx, chID, azID)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Challenge), m.MockError
|
||||
}
|
||||
|
||||
// UpdateChallenge mock
|
||||
func (m *MockDB) UpdateChallenge(ctx context.Context, ch *Challenge) error {
|
||||
if m.MockUpdateChallenge != nil {
|
||||
return m.MockUpdateChallenge(ctx, ch)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateOrder mock
|
||||
func (m *MockDB) CreateOrder(ctx context.Context, o *Order) error {
|
||||
if m.MockCreateOrder != nil {
|
||||
return m.MockCreateOrder(ctx, o)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetOrder mock
|
||||
func (m *MockDB) GetOrder(ctx context.Context, id string) (*Order, error) {
|
||||
if m.MockGetOrder != nil {
|
||||
return m.MockGetOrder(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*Order), m.MockError
|
||||
}
|
||||
|
||||
// UpdateOrder mock
|
||||
func (m *MockDB) UpdateOrder(ctx context.Context, o *Order) error {
|
||||
if m.MockUpdateOrder != nil {
|
||||
return m.MockUpdateOrder(ctx, o)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetOrdersByAccountID mock
|
||||
func (m *MockDB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
|
||||
if m.MockGetOrdersByAccountID != nil {
|
||||
return m.MockGetOrdersByAccountID(ctx, accID)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.([]string), m.MockError
|
||||
}
|
|
@ -1,142 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
nosqlDB "github.com/smallstep/nosql"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
// dbAccount represents an ACME account.
|
||||
type dbAccount struct {
|
||||
ID string `json:"id"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status acme.Status `json:"status"`
|
||||
LocationPrefix string `json:"locationPrefix"`
|
||||
ProvisionerName string `json:"provisionerName"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
DeactivatedAt time.Time `json:"deactivatedAt"`
|
||||
}
|
||||
|
||||
func (dba *dbAccount) clone() *dbAccount {
|
||||
nu := *dba
|
||||
return &nu
|
||||
}
|
||||
|
||||
func (db *DB) getAccountIDByKeyID(_ context.Context, kid string) (string, error) {
|
||||
id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
|
||||
if err != nil {
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
return "", acme.ErrNotFound
|
||||
}
|
||||
return "", errors.Wrapf(err, "error loading key-account index for key %s", kid)
|
||||
}
|
||||
return string(id), nil
|
||||
}
|
||||
|
||||
// getDBAccount retrieves and unmarshals dbAccount.
|
||||
func (db *DB) getDBAccount(_ context.Context, id string) (*dbAccount, error) {
|
||||
data, err := db.db.Get(accountTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
return nil, acme.ErrNotFound
|
||||
}
|
||||
return nil, errors.Wrapf(err, "error loading account %s", id)
|
||||
}
|
||||
|
||||
dbacc := new(dbAccount)
|
||||
if err = json.Unmarshal(data, dbacc); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling account %s into dbAccount", id)
|
||||
}
|
||||
return dbacc, nil
|
||||
}
|
||||
|
||||
// GetAccount retrieves an ACME account by ID.
|
||||
func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
|
||||
dbacc, err := db.getDBAccount(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &acme.Account{
|
||||
Status: dbacc.Status,
|
||||
Contact: dbacc.Contact,
|
||||
Key: dbacc.Key,
|
||||
ID: dbacc.ID,
|
||||
LocationPrefix: dbacc.LocationPrefix,
|
||||
ProvisionerName: dbacc.ProvisionerName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK).
|
||||
func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) {
|
||||
id, err := db.getAccountIDByKeyID(ctx, kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.GetAccount(ctx, id)
|
||||
}
|
||||
|
||||
// CreateAccount imlements the AcmeDB.CreateAccount interface.
|
||||
func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
|
||||
var err error
|
||||
acc.ID, err = randID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dba := &dbAccount{
|
||||
ID: acc.ID,
|
||||
Key: acc.Key,
|
||||
Contact: acc.Contact,
|
||||
Status: acc.Status,
|
||||
CreatedAt: clock.Now(),
|
||||
LocationPrefix: acc.LocationPrefix,
|
||||
ProvisionerName: acc.ProvisionerName,
|
||||
}
|
||||
|
||||
kid, err := acme.KeyToID(dba.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
kidB := []byte(kid)
|
||||
|
||||
// Set the jwkID -> acme account ID index
|
||||
_, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(acc.ID))
|
||||
switch {
|
||||
case err != nil:
|
||||
return errors.Wrap(err, "error storing keyID to accountID index")
|
||||
case !swapped:
|
||||
return errors.Errorf("key-id to account-id index already exists")
|
||||
default:
|
||||
if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil {
|
||||
db.db.Del(accountByKeyIDTable, kidB)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateAccount imlements the AcmeDB.UpdateAccount interface.
|
||||
func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error {
|
||||
old, err := db.getDBAccount(ctx, acc.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
nu.Contact = acc.Contact
|
||||
nu.Status = acc.Status
|
||||
|
||||
// If the status has changed to 'deactivated', then set deactivatedAt timestamp.
|
||||
if acc.Status == acme.StatusDeactivated && old.Status != acme.StatusDeactivated {
|
||||
nu.DeactivatedAt = clock.Now()
|
||||
}
|
||||
|
||||
return db.save(ctx, old.ID, nu, old, "account", accountTable)
|
||||
}
|
|
@ -1,715 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
nosqldb "github.com/smallstep/nosql/database"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
func TestDB_getDBAccount(t *testing.T) {
|
||||
accID := "accID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbacc *dbAccount
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: acme.ErrNotFound,
|
||||
}
|
||||
},
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading account accID: force"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return []byte("foo"), nil
|
||||
},
|
||||
},
|
||||
err: errors.New("error unmarshaling account accID into dbAccount"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
dbacc := &dbAccount{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
b, err := json.Marshal(dbacc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
dbacc: dbacc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if dbacc, err := d.getDBAccount(context.Background(), accID); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbacc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, dbacc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
|
||||
assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
|
||||
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_getAccountIDByKeyID(t *testing.T) {
|
||||
accID := "accID"
|
||||
kid := "kid"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, string(key), kid)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: acme.ErrNotFound,
|
||||
}
|
||||
},
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, string(key), kid)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading key-account index for key kid: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, string(key), kid)
|
||||
|
||||
return []byte(accID), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if retAccID, err := d.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, retAccID, accID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetAccount(t *testing.T) {
|
||||
accID := "accID"
|
||||
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
|
||||
provisionerName := "foo"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbacc *dbAccount
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading account accID: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
dbacc := &dbAccount{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
LocationPrefix: locationPrefix,
|
||||
ProvisionerName: provisionerName,
|
||||
}
|
||||
b, err := json.Marshal(dbacc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
dbacc: dbacc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if acc, err := d.GetAccount(context.Background(), accID); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, acc.LocationPrefix, tc.dbacc.LocationPrefix)
|
||||
assert.Equals(t, acc.ProvisionerName, tc.dbacc.ProvisionerName)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||
accID := "accID"
|
||||
kid := "kid"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbacc *dbAccount
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.getAccountIDByKeyID-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, string(bucket), string(accountByKeyIDTable))
|
||||
assert.Equals(t, string(key), kid)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading key-account index for key kid: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetAccount-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(accountByKeyIDTable):
|
||||
assert.Equals(t, string(key), kid)
|
||||
return []byte(accID), nil
|
||||
case string(accountTable):
|
||||
assert.Equals(t, string(key), accID)
|
||||
return nil, errors.New("force")
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading account accID: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
dbacc := &dbAccount{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
b, err := json.Marshal(dbacc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(accountByKeyIDTable):
|
||||
assert.Equals(t, string(key), kid)
|
||||
return []byte(accID), nil
|
||||
case string(accountTable):
|
||||
assert.Equals(t, string(key), accID)
|
||||
return b, nil
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
dbacc: dbacc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if acc, err := d.GetAccountByKeyID(context.Background(), kid); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_CreateAccount(t *testing.T) {
|
||||
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
acc *acme.Account
|
||||
err error
|
||||
_id *string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/keyID-cmpAndSwap-error": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
LocationPrefix: locationPrefix,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, string(key), jwk.KeyID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
assert.Equals(t, nu, []byte(acc.ID))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
err: errors.New("error storing keyID to accountID index: force"),
|
||||
}
|
||||
},
|
||||
"fail/keyID-cmpAndSwap-false": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
LocationPrefix: locationPrefix,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, string(key), jwk.KeyID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
assert.Equals(t, nu, []byte(acc.ID))
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
err: errors.New("key-id to account-id index already exists"),
|
||||
}
|
||||
},
|
||||
"fail/account-save-error": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
LocationPrefix: locationPrefix,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
switch string(bucket) {
|
||||
case string(accountByKeyIDTable):
|
||||
assert.Equals(t, string(key), jwk.KeyID)
|
||||
assert.Equals(t, old, nil)
|
||||
return nu, true, nil
|
||||
case string(accountTable):
|
||||
assert.Equals(t, string(key), acc.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbacc := new(dbAccount)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||
assert.Equals(t, dbacc.ID, string(key))
|
||||
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||
assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix)
|
||||
assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName)
|
||||
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||
assert.True(t, dbacc.DeactivatedAt.IsZero())
|
||||
return nil, false, errors.New("force")
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
err: errors.New("error saving acme account: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var (
|
||||
id string
|
||||
idPtr = &id
|
||||
)
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
LocationPrefix: locationPrefix,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
id = string(key)
|
||||
switch string(bucket) {
|
||||
case string(accountByKeyIDTable):
|
||||
assert.Equals(t, string(key), jwk.KeyID)
|
||||
assert.Equals(t, old, nil)
|
||||
return nu, true, nil
|
||||
case string(accountTable):
|
||||
assert.Equals(t, string(key), acc.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbacc := new(dbAccount)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||
assert.Equals(t, dbacc.ID, string(key))
|
||||
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||
assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix)
|
||||
assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName)
|
||||
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||
assert.True(t, dbacc.DeactivatedAt.IsZero())
|
||||
return nu, true, nil
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
_id: idPtr,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.CreateAccount(context.Background(), tc.acc); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, *tc._id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_UpdateAccount(t *testing.T) {
|
||||
accID := "accID"
|
||||
now := clock.Now()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
dbacc := &dbAccount{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
LocationPrefix: "foo",
|
||||
ProvisionerName: "alpha",
|
||||
Key: jwk,
|
||||
}
|
||||
b, err := json.Marshal(dbacc)
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
acc *acme.Account
|
||||
err error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
acc: &acme.Account{
|
||||
ID: accID,
|
||||
},
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading account accID: force"),
|
||||
}
|
||||
},
|
||||
"fail/already-deactivated": func(t *testing.T) test {
|
||||
clone := dbacc.clone()
|
||||
clone.Status = acme.StatusDeactivated
|
||||
clone.DeactivatedAt = now
|
||||
dbaccb, err := json.Marshal(clone)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return dbaccb, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbNew := new(dbAccount)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, clone.ID)
|
||||
assert.Equals(t, dbNew.Status, clone.Status)
|
||||
assert.Equals(t, dbNew.Contact, clone.Contact)
|
||||
assert.Equals(t, dbNew.Key.KeyID, clone.Key.KeyID)
|
||||
assert.Equals(t, dbNew.CreatedAt, clone.CreatedAt)
|
||||
assert.Equals(t, dbNew.DeactivatedAt, clone.DeactivatedAt)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme account: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbNew := new(dbAccount)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbacc.ID)
|
||||
assert.Equals(t, dbNew.Status, acc.Status)
|
||||
assert.Equals(t, dbNew.Contact, dbacc.Contact)
|
||||
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
||||
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
||||
assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme account: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
Contact: []string{"baz", "zap"},
|
||||
LocationPrefix: "bar",
|
||||
ProvisionerName: "beta",
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, string(key), accID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbNew := new(dbAccount)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbacc.ID)
|
||||
assert.Equals(t, dbNew.Status, acc.Status)
|
||||
assert.Equals(t, dbNew.Contact, acc.Contact)
|
||||
// LocationPrefix should not change.
|
||||
assert.Equals(t, dbNew.LocationPrefix, dbacc.LocationPrefix)
|
||||
assert.Equals(t, dbNew.ProvisionerName, dbacc.ProvisionerName)
|
||||
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
||||
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
||||
assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now))
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.UpdateAccount(context.Background(), tc.acc); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,156 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// dbAuthz is the base authz type that others build from.
|
||||
type dbAuthz struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Identifier acme.Identifier `json:"identifier"`
|
||||
Status acme.Status `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Fingerprint string `json:"fingerprint,omitempty"`
|
||||
ChallengeIDs []string `json:"challengeIDs"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
Error *acme.Error `json:"error"`
|
||||
}
|
||||
|
||||
func (ba *dbAuthz) clone() *dbAuthz {
|
||||
u := *ba
|
||||
return &u
|
||||
}
|
||||
|
||||
// getDBAuthz retrieves and unmarshals a database representation of the
|
||||
// ACME Authorization type.
|
||||
func (db *DB) getDBAuthz(_ context.Context, id string) (*dbAuthz, error) {
|
||||
data, err := db.db.Get(authzTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "authz %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading authz %s", id)
|
||||
}
|
||||
|
||||
var dbaz dbAuthz
|
||||
if err = json.Unmarshal(data, &dbaz); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling authz %s into dbAuthz", id)
|
||||
}
|
||||
return &dbaz, nil
|
||||
}
|
||||
|
||||
// GetAuthorization retrieves and unmarshals an ACME authz type from the database.
|
||||
// Implements acme.DB GetAuthorization interface.
|
||||
func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorization, error) {
|
||||
dbaz, err := db.getDBAuthz(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var chs = make([]*acme.Challenge, len(dbaz.ChallengeIDs))
|
||||
for i, chID := range dbaz.ChallengeIDs {
|
||||
chs[i], err = db.GetChallenge(ctx, chID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &acme.Authorization{
|
||||
ID: dbaz.ID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Identifier: dbaz.Identifier,
|
||||
Status: dbaz.Status,
|
||||
Challenges: chs,
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Token: dbaz.Token,
|
||||
Fingerprint: dbaz.Fingerprint,
|
||||
Error: dbaz.Error,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateAuthorization creates an entry in the database for the Authorization.
|
||||
// Implements the acme.DB.CreateAuthorization interface.
|
||||
func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
var err error
|
||||
az.ID, err = randID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chIDs := make([]string, len(az.Challenges))
|
||||
for i, ch := range az.Challenges {
|
||||
chIDs[i] = ch.ID
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: az.ID,
|
||||
AccountID: az.AccountID,
|
||||
Status: az.Status,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: az.ExpiresAt,
|
||||
Identifier: az.Identifier,
|
||||
ChallengeIDs: chIDs,
|
||||
Token: az.Token,
|
||||
Fingerprint: az.Fingerprint,
|
||||
Wildcard: az.Wildcard,
|
||||
}
|
||||
|
||||
return db.save(ctx, az.ID, dbaz, nil, "authz", authzTable)
|
||||
}
|
||||
|
||||
// UpdateAuthorization saves an updated ACME Authorization to the database.
|
||||
func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
old, err := db.getDBAuthz(ctx, az.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
nu.Status = az.Status
|
||||
nu.Fingerprint = az.Fingerprint
|
||||
nu.Error = az.Error
|
||||
return db.save(ctx, old.ID, nu, old, "authz", authzTable)
|
||||
}
|
||||
|
||||
// GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database.
|
||||
func (db *DB) GetAuthorizationsByAccountID(_ context.Context, accountID string) ([]*acme.Authorization, error) {
|
||||
entries, err := db.db.List(authzTable)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error listing authz")
|
||||
}
|
||||
authzs := []*acme.Authorization{}
|
||||
for _, entry := range entries {
|
||||
dbaz := new(dbAuthz)
|
||||
if err = json.Unmarshal(entry.Value, dbaz); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling dbAuthz key '%s' into dbAuthz struct", string(entry.Key))
|
||||
}
|
||||
// Filter out all dbAuthzs that don't belong to the accountID. This
|
||||
// could be made more efficient with additional data structures mapping the
|
||||
// Account ID to authorizations. Not trivial to do, though.
|
||||
if dbaz.AccountID != accountID {
|
||||
continue
|
||||
}
|
||||
authzs = append(authzs, &acme.Authorization{
|
||||
ID: dbaz.ID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Identifier: dbaz.Identifier,
|
||||
Status: dbaz.Status,
|
||||
Challenges: nil, // challenges not required for current use case
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Token: dbaz.Token,
|
||||
Fingerprint: dbaz.Fingerprint,
|
||||
Error: dbaz.Error,
|
||||
})
|
||||
}
|
||||
|
||||
return authzs, nil
|
||||
}
|
|
@ -1,772 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
nosqldb "github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestDB_getDBAuthz(t *testing.T) {
|
||||
azID := "azID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbaz *dbAuthz
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"),
|
||||
}
|
||||
},
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading authz azID: force"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return []byte("foo"), nil
|
||||
},
|
||||
},
|
||||
err: errors.New("error unmarshaling authz azID into dbAuthz"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Error: acme.NewErrorISE("The server experienced an internal error"),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
dbaz: dbaz,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if dbaz, err := d.getDBAuthz(context.Background(), azID); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbaz.ID, tc.dbaz.ID)
|
||||
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
|
||||
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
|
||||
assert.Equals(t, dbaz.Status, tc.dbaz.Status)
|
||||
assert.Equals(t, dbaz.Token, tc.dbaz.Token)
|
||||
assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt)
|
||||
assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
|
||||
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetAuthorization(t *testing.T) {
|
||||
azID := "azID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbaz *dbAuthz
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading authz azID: force"),
|
||||
}
|
||||
},
|
||||
"fail/forward-acme-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetChallenge-error": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Error: acme.NewErrorISE("force"),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(authzTable):
|
||||
assert.Equals(t, string(key), azID)
|
||||
return b, nil
|
||||
case string(challengeTable):
|
||||
assert.Equals(t, string(key), "foo")
|
||||
return nil, errors.New("force")
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading acme challenge foo: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetChallenge-not-found": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Error: acme.NewErrorISE("force"),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(authzTable):
|
||||
assert.Equals(t, string(key), azID)
|
||||
return b, nil
|
||||
case string(challengeTable):
|
||||
assert.Equals(t, string(key), "foo")
|
||||
return nil, nosqldb.ErrNotFound
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge foo not found"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Error: acme.NewErrorISE("The server experienced an internal error"),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
chCount := 0
|
||||
fooChb, err := json.Marshal(&dbChallenge{ID: "foo"})
|
||||
assert.FatalError(t, err)
|
||||
barChb, err := json.Marshal(&dbChallenge{ID: "bar"})
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
switch string(bucket) {
|
||||
case string(authzTable):
|
||||
assert.Equals(t, string(key), azID)
|
||||
return b, nil
|
||||
case string(challengeTable):
|
||||
if chCount == 0 {
|
||||
chCount++
|
||||
assert.Equals(t, string(key), "foo")
|
||||
return fooChb, nil
|
||||
}
|
||||
assert.Equals(t, string(key), "bar")
|
||||
return barChb, nil
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
},
|
||||
dbaz: dbaz,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if az, err := d.GetAuthorization(context.Background(), azID); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, az.ID, tc.dbaz.ID)
|
||||
assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
|
||||
assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
|
||||
assert.Equals(t, az.Status, tc.dbaz.Status)
|
||||
assert.Equals(t, az.Token, tc.dbaz.Token)
|
||||
assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard)
|
||||
assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||
assert.Equals(t, az.Challenges, []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
})
|
||||
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_CreateAuthorization(t *testing.T) {
|
||||
azID := "azID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
az *acme.Authorization
|
||||
err error
|
||||
_id *string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
az := &acme.Authorization{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Challenges: []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
},
|
||||
Wildcard: true,
|
||||
Error: acme.NewErrorISE("force"),
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), az.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbaz := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbaz))
|
||||
assert.Equals(t, dbaz.ID, string(key))
|
||||
assert.Equals(t, dbaz.AccountID, az.AccountID)
|
||||
assert.Equals(t, dbaz.Identifier, acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
})
|
||||
assert.Equals(t, dbaz.Status, az.Status)
|
||||
assert.Equals(t, dbaz.Token, az.Token)
|
||||
assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"})
|
||||
assert.Equals(t, dbaz.Wildcard, az.Wildcard)
|
||||
assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt)
|
||||
assert.Nil(t, dbaz.Error)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
az: az,
|
||||
err: errors.New("error saving acme authz: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var (
|
||||
id string
|
||||
idPtr = &id
|
||||
now = clock.Now()
|
||||
az = &acme.Authorization{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Challenges: []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
},
|
||||
Wildcard: true,
|
||||
Error: acme.NewErrorISE("force"),
|
||||
}
|
||||
)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
*idPtr = string(key)
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), az.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbaz := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbaz))
|
||||
assert.Equals(t, dbaz.ID, string(key))
|
||||
assert.Equals(t, dbaz.AccountID, az.AccountID)
|
||||
assert.Equals(t, dbaz.Identifier, acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
})
|
||||
assert.Equals(t, dbaz.Status, az.Status)
|
||||
assert.Equals(t, dbaz.Token, az.Token)
|
||||
assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"})
|
||||
assert.Equals(t, dbaz.Wildcard, az.Wildcard)
|
||||
assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt)
|
||||
assert.Nil(t, dbaz.Error)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt))
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
az: az,
|
||||
_id: idPtr,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.CreateAuthorization(context.Background(), tc.az); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.ID, *tc._id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_UpdateAuthorization(t *testing.T) {
|
||||
azID := "azID"
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "accountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
Fingerprint: "fingerprint",
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
az *acme.Authorization
|
||||
err error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
az: &acme.Authorization{
|
||||
ID: azID,
|
||||
},
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading authz azID: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
||||
updAz := &acme.Authorization{
|
||||
ID: azID,
|
||||
Status: acme.StatusValid,
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||
}
|
||||
return test{
|
||||
az: updAz,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbOld := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||
assert.Equals(t, dbaz, dbOld)
|
||||
|
||||
dbNew := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbaz.ID)
|
||||
assert.Equals(t, dbNew.AccountID, dbaz.AccountID)
|
||||
assert.Equals(t, dbNew.Identifier, dbaz.Identifier)
|
||||
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
||||
assert.Equals(t, dbNew.Token, dbaz.Token)
|
||||
assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs)
|
||||
assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
|
||||
assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme authz: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
updAz := &acme.Authorization{
|
||||
ID: azID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Status: acme.StatusValid,
|
||||
Identifier: dbaz.Identifier,
|
||||
Challenges: []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
},
|
||||
Token: dbaz.Token,
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Fingerprint: "fingerprint",
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||
}
|
||||
return test{
|
||||
az: updAz,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, string(key), azID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbOld := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||
assert.Equals(t, dbaz, dbOld)
|
||||
|
||||
dbNew := new(dbAuthz)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbaz.ID)
|
||||
assert.Equals(t, dbNew.AccountID, dbaz.AccountID)
|
||||
assert.Equals(t, dbNew.Identifier, dbaz.Identifier)
|
||||
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
||||
assert.Equals(t, dbNew.Token, dbaz.Token)
|
||||
assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs)
|
||||
assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
|
||||
assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
|
||||
assert.Equals(t, dbNew.Fingerprint, dbaz.Fingerprint)
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.UpdateAuthorization(context.Background(), tc.az); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.ID, dbaz.ID)
|
||||
assert.Equals(t, tc.az.AccountID, dbaz.AccountID)
|
||||
assert.Equals(t, tc.az.Identifier, dbaz.Identifier)
|
||||
assert.Equals(t, tc.az.Status, acme.StatusValid)
|
||||
assert.Equals(t, tc.az.Wildcard, dbaz.Wildcard)
|
||||
assert.Equals(t, tc.az.Token, dbaz.Token)
|
||||
assert.Equals(t, tc.az.ExpiresAt, dbaz.ExpiresAt)
|
||||
assert.Equals(t, tc.az.Challenges, []*acme.Challenge{
|
||||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
})
|
||||
assert.Equals(t, tc.az.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetAuthorizationsByAccountID(t *testing.T) {
|
||||
azID := "azID"
|
||||
accountID := "accountID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
authzs []*acme.Authorization
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.List-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error listing authz: force"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal": func(t *testing.T) test {
|
||||
b := []byte(`{malformed}`)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
return []*nosqldb.Entry{
|
||||
{
|
||||
Bucket: bucket,
|
||||
Key: []byte(azID),
|
||||
Value: b,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
authzs: nil,
|
||||
err: fmt.Errorf("error unmarshaling dbAuthz key '%s' into dbAuthz struct", azID),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: accountID,
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusValid,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
return []*nosqldb.Entry{
|
||||
{
|
||||
Bucket: bucket,
|
||||
Key: []byte(azID),
|
||||
Value: b,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
authzs: []*acme.Authorization{
|
||||
{
|
||||
ID: dbaz.ID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Token: dbaz.Token,
|
||||
Identifier: dbaz.Identifier,
|
||||
Status: dbaz.Status,
|
||||
Challenges: nil,
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Error: dbaz.Error,
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/skip-different-account": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "differentAccountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusValid,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
return []*nosqldb.Entry{
|
||||
{
|
||||
Bucket: bucket,
|
||||
Key: []byte(azID),
|
||||
Value: b,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
authzs: []*acme.Authorization{},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if azs, err := d.GetAuthorizationsByAccountID(context.Background(), accountID); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
if !cmp.Equal(azs, tc.authzs) {
|
||||
t.Errorf("db.GetAuthorizationsByAccountID() diff =\n%s", cmp.Diff(azs, tc.authzs))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,141 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type dbCert struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
AccountID string `json:"accountID"`
|
||||
OrderID string `json:"orderID"`
|
||||
Leaf []byte `json:"leaf"`
|
||||
Intermediates []byte `json:"intermediates"`
|
||||
}
|
||||
|
||||
type dbSerial struct {
|
||||
Serial string `json:"serial"`
|
||||
CertificateID string `json:"certificateID"`
|
||||
}
|
||||
|
||||
// CreateCertificate creates and stores an ACME certificate type.
|
||||
func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error {
|
||||
var err error
|
||||
cert.ID, err = randID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
leaf := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Leaf.Raw,
|
||||
})
|
||||
var intermediates []byte
|
||||
for _, cert := range cert.Intermediates {
|
||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
|
||||
dbch := &dbCert{
|
||||
ID: cert.ID,
|
||||
AccountID: cert.AccountID,
|
||||
OrderID: cert.OrderID,
|
||||
Leaf: leaf,
|
||||
Intermediates: intermediates,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = db.save(ctx, cert.ID, dbch, nil, "certificate", certTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serial := cert.Leaf.SerialNumber.String()
|
||||
dbSerial := &dbSerial{
|
||||
Serial: serial,
|
||||
CertificateID: cert.ID,
|
||||
}
|
||||
return db.save(ctx, serial, dbSerial, nil, "serial", certBySerialTable)
|
||||
}
|
||||
|
||||
// GetCertificate retrieves and unmarshals an ACME certificate type from the
|
||||
// datastore.
|
||||
func (db *DB) GetCertificate(_ context.Context, id string) (*acme.Certificate, error) {
|
||||
b, err := db.db.Get(certTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading certificate %s", id)
|
||||
}
|
||||
dbC := new(dbCert)
|
||||
if err := json.Unmarshal(b, dbC); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling certificate %s", id)
|
||||
}
|
||||
|
||||
certs, err := parseBundle(append(dbC.Leaf, dbC.Intermediates...))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing certificate chain for ACME certificate with ID %s", id)
|
||||
}
|
||||
|
||||
return &acme.Certificate{
|
||||
ID: dbC.ID,
|
||||
AccountID: dbC.AccountID,
|
||||
OrderID: dbC.OrderID,
|
||||
Leaf: certs[0],
|
||||
Intermediates: certs[1:],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetCertificateBySerial retrieves and unmarshals an ACME certificate type from the
|
||||
// datastore based on a certificate serial number.
|
||||
func (db *DB) GetCertificateBySerial(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
b, err := db.db.Get(certBySerialTable, []byte(serial))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading certificate ID for serial %s", serial)
|
||||
}
|
||||
|
||||
dbSerial := new(dbSerial)
|
||||
if err := json.Unmarshal(b, dbSerial); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling certificate with serial %s", serial)
|
||||
}
|
||||
|
||||
return db.GetCertificate(ctx, dbSerial.CertificateID)
|
||||
}
|
||||
|
||||
func parseBundle(b []byte) ([]*x509.Certificate, error) {
|
||||
var (
|
||||
err error
|
||||
block *pem.Block
|
||||
bundle []*x509.Certificate
|
||||
)
|
||||
for len(b) > 0 {
|
||||
block, b = pem.Decode(b)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "CERTIFICATE" {
|
||||
return nil, errors.New("error decoding PEM: data contains block that is not a certificate")
|
||||
}
|
||||
var crt *x509.Certificate
|
||||
crt, err = x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing x509 certificate")
|
||||
}
|
||||
bundle = append(bundle, crt)
|
||||
}
|
||||
if len(b) > 0 {
|
||||
return nil, errors.New("error decoding PEM: unexpected data")
|
||||
}
|
||||
return bundle, nil
|
||||
}
|
|
@ -1,470 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
nosqldb "github.com/smallstep/nosql/database"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestDB_CreateCertificate(t *testing.T) {
|
||||
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
cert *acme.Certificate
|
||||
err error
|
||||
_id *string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
cert := &acme.Certificate{
|
||||
AccountID: "accountID",
|
||||
OrderID: "orderID",
|
||||
Leaf: leaf,
|
||||
Intermediates: []*x509.Certificate{inter, root},
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbc := new(dbCert)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||
assert.Equals(t, dbc.ID, string(key))
|
||||
assert.Equals(t, dbc.ID, cert.ID)
|
||||
assert.Equals(t, dbc.AccountID, cert.AccountID)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
cert: cert,
|
||||
err: errors.New("error saving acme certificate: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
cert := &acme.Certificate{
|
||||
AccountID: "accountID",
|
||||
OrderID: "orderID",
|
||||
Leaf: leaf,
|
||||
Intermediates: []*x509.Certificate{inter, root},
|
||||
}
|
||||
var (
|
||||
id string
|
||||
idPtr = &id
|
||||
)
|
||||
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
if !bytes.Equal(bucket, certTable) && !bytes.Equal(bucket, certBySerialTable) {
|
||||
t.Fail()
|
||||
}
|
||||
if bytes.Equal(bucket, certTable) {
|
||||
*idPtr = string(key)
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbc := new(dbCert)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||
assert.Equals(t, dbc.ID, string(key))
|
||||
assert.Equals(t, dbc.ID, cert.ID)
|
||||
assert.Equals(t, dbc.AccountID, cert.AccountID)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||
}
|
||||
if bytes.Equal(bucket, certBySerialTable) {
|
||||
assert.Equals(t, bucket, certBySerialTable)
|
||||
assert.Equals(t, key, []byte(cert.Leaf.SerialNumber.String()))
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbs := new(dbSerial)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbs))
|
||||
assert.Equals(t, dbs.Serial, string(key))
|
||||
assert.Equals(t, dbs.CertificateID, cert.ID)
|
||||
|
||||
*idPtr = cert.ID
|
||||
}
|
||||
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
_id: idPtr,
|
||||
cert: cert,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.CreateCertificate(context.Background(), tc.cert); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.cert.ID, *tc._id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetCertificate(t *testing.T) {
|
||||
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
certID := "certID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, string(key), certID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate certID not found"),
|
||||
}
|
||||
},
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, string(key), certID)
|
||||
|
||||
return nil, errors.Errorf("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading certificate certID: force"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, string(key), certID)
|
||||
|
||||
return []byte("foobar"), nil
|
||||
},
|
||||
},
|
||||
err: errors.New("error unmarshaling certificate certID"),
|
||||
}
|
||||
},
|
||||
"fail/parseBundle-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, string(key), certID)
|
||||
|
||||
cert := dbCert{
|
||||
ID: certID,
|
||||
AccountID: "accountID",
|
||||
OrderID: "orderID",
|
||||
Leaf: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "Public Key",
|
||||
Bytes: leaf.Raw,
|
||||
}),
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(cert)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
err: errors.Errorf("error parsing certificate chain for ACME certificate with ID certID"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, string(key), certID)
|
||||
|
||||
cert := dbCert{
|
||||
ID: certID,
|
||||
AccountID: "accountID",
|
||||
OrderID: "orderID",
|
||||
Leaf: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: leaf.Raw,
|
||||
}),
|
||||
Intermediates: append(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: inter.Raw,
|
||||
}), pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: root.Raw,
|
||||
})...),
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(cert)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
cert, err := d.GetCertificate(context.Background(), certID)
|
||||
if err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, cert.ID, certID)
|
||||
assert.Equals(t, cert.AccountID, "accountID")
|
||||
assert.Equals(t, cert.OrderID, "orderID")
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseBundle(t *testing.T) {
|
||||
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
var certs []byte
|
||||
for _, cert := range []*x509.Certificate{leaf, inter, root} {
|
||||
certs = append(certs, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
|
||||
type test struct {
|
||||
b []byte
|
||||
err error
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"fail/bad-type-error": {
|
||||
b: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "Public Key",
|
||||
Bytes: leaf.Raw,
|
||||
}),
|
||||
err: errors.Errorf("error decoding PEM: data contains block that is not a certificate"),
|
||||
},
|
||||
"fail/bad-pem-error": {
|
||||
b: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: []byte("foo"),
|
||||
}),
|
||||
err: errors.Errorf("error parsing x509 certificate"),
|
||||
},
|
||||
"fail/unexpected-data": {
|
||||
b: append(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: leaf.Raw,
|
||||
}), []byte("foo")...),
|
||||
err: errors.Errorf("error decoding PEM: unexpected data"),
|
||||
},
|
||||
"ok": {
|
||||
b: certs,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ret, err := parseBundle(tc.b)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, ret, []*x509.Certificate{leaf, inter, root})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetCertificateBySerial(t *testing.T) {
|
||||
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
certID := "certID"
|
||||
serial := ""
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if bytes.Equal(bucket, certBySerialTable) {
|
||||
return nil, nosqldb.ErrNotFound
|
||||
}
|
||||
return nil, errors.New("wrong table")
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if bytes.Equal(bucket, certBySerialTable) {
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
return nil, errors.New("wrong table")
|
||||
},
|
||||
},
|
||||
err: fmt.Errorf("error loading certificate ID for serial %s", serial),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-dbSerial": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if bytes.Equal(bucket, certBySerialTable) {
|
||||
return []byte(`{"serial":malformed!}`), nil
|
||||
}
|
||||
return nil, errors.New("wrong table")
|
||||
},
|
||||
},
|
||||
err: fmt.Errorf("error unmarshaling certificate with serial %s", serial),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
|
||||
if bytes.Equal(bucket, certBySerialTable) {
|
||||
certSerial := dbSerial{
|
||||
Serial: serial,
|
||||
CertificateID: certID,
|
||||
}
|
||||
|
||||
b, err := json.Marshal(certSerial)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
if bytes.Equal(bucket, certTable) {
|
||||
cert := dbCert{
|
||||
ID: certID,
|
||||
AccountID: "accountID",
|
||||
OrderID: "orderID",
|
||||
Leaf: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: leaf.Raw,
|
||||
}),
|
||||
Intermediates: append(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: inter.Raw,
|
||||
}), pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: root.Raw,
|
||||
})...),
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(cert)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return b, nil
|
||||
}
|
||||
return nil, errors.New("wrong table")
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
cert, err := d.GetCertificateBySerial(context.Background(), serial)
|
||||
if err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, cert.ID, certID)
|
||||
assert.Equals(t, cert.AccountID, "accountID")
|
||||
assert.Equals(t, cert.OrderID, "orderID")
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,106 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/nosql"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
)
|
||||
|
||||
type dbChallenge struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Type acme.ChallengeType `json:"type"`
|
||||
Status acme.Status `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Value string `json:"value"`
|
||||
ValidatedAt string `json:"validatedAt"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Error *acme.Error `json:"error"` // TODO(hs): a bit dangerous; should become db-specific type
|
||||
}
|
||||
|
||||
func (dbc *dbChallenge) clone() *dbChallenge {
|
||||
u := *dbc
|
||||
return &u
|
||||
}
|
||||
|
||||
func (db *DB) getDBChallenge(_ context.Context, id string) (*dbChallenge, error) {
|
||||
data, err := db.db.Get(challengeTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading acme challenge %s", id)
|
||||
}
|
||||
|
||||
dbch := new(dbChallenge)
|
||||
if err := json.Unmarshal(data, dbch); err != nil {
|
||||
return nil, errors.Wrap(err, "error unmarshaling dbChallenge")
|
||||
}
|
||||
return dbch, nil
|
||||
}
|
||||
|
||||
// CreateChallenge creates a new ACME challenge data structure in the database.
|
||||
// Implements acme.DB.CreateChallenge interface.
|
||||
func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error {
|
||||
var err error
|
||||
ch.ID, err = randID()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error generating random id for ACME challenge")
|
||||
}
|
||||
|
||||
dbch := &dbChallenge{
|
||||
ID: ch.ID,
|
||||
AccountID: ch.AccountID,
|
||||
Value: ch.Value,
|
||||
Status: acme.StatusPending,
|
||||
Token: ch.Token,
|
||||
CreatedAt: clock.Now(),
|
||||
Type: ch.Type,
|
||||
}
|
||||
|
||||
return db.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable)
|
||||
}
|
||||
|
||||
// GetChallenge retrieves and unmarshals an ACME challenge type from the database.
|
||||
// Implements the acme.DB GetChallenge interface.
|
||||
func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) {
|
||||
_ = authzID // unused input
|
||||
dbch, err := db.getDBChallenge(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch := &acme.Challenge{
|
||||
ID: dbch.ID,
|
||||
AccountID: dbch.AccountID,
|
||||
Type: dbch.Type,
|
||||
Value: dbch.Value,
|
||||
Status: dbch.Status,
|
||||
Token: dbch.Token,
|
||||
Error: dbch.Error,
|
||||
ValidatedAt: dbch.ValidatedAt,
|
||||
}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// UpdateChallenge updates an ACME challenge type in the database.
|
||||
func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error {
|
||||
old, err := db.getDBChallenge(ctx, ch.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
|
||||
// These should be the only values changing in an Update request.
|
||||
nu.Status = ch.Status
|
||||
nu.Error = ch.Error
|
||||
nu.ValidatedAt = ch.ValidatedAt
|
||||
|
||||
return db.save(ctx, old.ID, nu, old, "challenge", challengeTable)
|
||||
}
|
|
@ -1,460 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
nosqldb "github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestDB_getDBChallenge(t *testing.T) {
|
||||
chID := "chID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbc *dbChallenge
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
|
||||
}
|
||||
},
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading acme challenge chID: force"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return []byte("foo"), nil
|
||||
},
|
||||
},
|
||||
err: errors.New("error unmarshaling dbChallenge"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
dbc := &dbChallenge{
|
||||
ID: chID,
|
||||
AccountID: "accountID",
|
||||
Type: "dns-01",
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
Value: "test.ca.smallstep.com",
|
||||
CreatedAt: clock.Now(),
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewErrorISE("The server experienced an internal error"),
|
||||
}
|
||||
b, err := json.Marshal(dbc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
dbc: dbc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if ch, err := d.getDBChallenge(context.Background(), chID); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_CreateChallenge(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
ch *acme.Challenge
|
||||
err error
|
||||
_id *string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
ch := &acme.Challenge{
|
||||
AccountID: "accountID",
|
||||
Type: "dns-01",
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
Value: "test.ca.smallstep.com",
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), ch.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbc := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||
assert.Equals(t, dbc.ID, string(key))
|
||||
assert.Equals(t, dbc.AccountID, ch.AccountID)
|
||||
assert.Equals(t, dbc.Type, ch.Type)
|
||||
assert.Equals(t, dbc.Status, ch.Status)
|
||||
assert.Equals(t, dbc.Token, ch.Token)
|
||||
assert.Equals(t, dbc.Value, ch.Value)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
ch: ch,
|
||||
err: errors.New("error saving acme challenge: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var (
|
||||
id string
|
||||
idPtr = &id
|
||||
ch = &acme.Challenge{
|
||||
AccountID: "accountID",
|
||||
Type: "dns-01",
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
Value: "test.ca.smallstep.com",
|
||||
}
|
||||
)
|
||||
|
||||
return test{
|
||||
ch: ch,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
*idPtr = string(key)
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), ch.ID)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbc := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||
assert.Equals(t, dbc.ID, string(key))
|
||||
assert.Equals(t, dbc.AccountID, ch.AccountID)
|
||||
assert.Equals(t, dbc.Type, ch.Type)
|
||||
assert.Equals(t, dbc.Status, ch.Status)
|
||||
assert.Equals(t, dbc.Token, ch.Token)
|
||||
assert.Equals(t, dbc.Value, ch.Value)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
_id: idPtr,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.CreateChallenge(context.Background(), tc.ch); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.ch.ID, *tc._id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetChallenge(t *testing.T) {
|
||||
chID := "chID"
|
||||
azID := "azID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
dbc *dbChallenge
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading acme challenge chID: force"),
|
||||
}
|
||||
},
|
||||
"fail/forward-acme-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return nil, nosqldb.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
dbc := &dbChallenge{
|
||||
ID: chID,
|
||||
AccountID: "accountID",
|
||||
Type: "dns-01",
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
Value: "test.ca.smallstep.com",
|
||||
CreatedAt: clock.Now(),
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewErrorISE("The server experienced an internal error"),
|
||||
}
|
||||
b, err := json.Marshal(dbc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
dbc: dbc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if ch, err := d.GetChallenge(context.Background(), chID, azID); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_UpdateChallenge(t *testing.T) {
|
||||
chID := "chID"
|
||||
dbc := &dbChallenge{
|
||||
ID: chID,
|
||||
AccountID: "accountID",
|
||||
Type: "dns-01",
|
||||
Status: acme.StatusPending,
|
||||
Token: "token",
|
||||
Value: "test.ca.smallstep.com",
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(dbc)
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
ch *acme.Challenge
|
||||
err error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.Get-error": func(t *testing.T) test {
|
||||
return test{
|
||||
ch: &acme.Challenge{
|
||||
ID: chID,
|
||||
},
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error loading acme challenge chID: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
||||
updCh := &acme.Challenge{
|
||||
ID: chID,
|
||||
Status: acme.StatusValid,
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "The request message was malformed"),
|
||||
}
|
||||
return test{
|
||||
ch: updCh,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbOld := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||
assert.Equals(t, dbc, dbOld)
|
||||
|
||||
dbNew := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbc.ID)
|
||||
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
|
||||
assert.Equals(t, dbNew.Type, dbc.Type)
|
||||
assert.Equals(t, dbNew.Status, updCh.Status)
|
||||
assert.Equals(t, dbNew.Token, dbc.Token)
|
||||
assert.Equals(t, dbNew.Value, dbc.Value)
|
||||
assert.Equals(t, dbNew.Error.Error(), updCh.Error.Error())
|
||||
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
|
||||
assert.Equals(t, dbNew.ValidatedAt, updCh.ValidatedAt)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme challenge: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
updCh := &acme.Challenge{
|
||||
ID: dbc.ID,
|
||||
AccountID: dbc.AccountID,
|
||||
Type: dbc.Type,
|
||||
Token: dbc.Token,
|
||||
Value: dbc.Value,
|
||||
Status: acme.StatusValid,
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||
}
|
||||
return test{
|
||||
ch: updCh,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), chID)
|
||||
|
||||
return b, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, old, b)
|
||||
|
||||
dbOld := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||
assert.Equals(t, dbc, dbOld)
|
||||
|
||||
dbNew := new(dbChallenge)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbc.ID)
|
||||
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
|
||||
assert.Equals(t, dbNew.Type, dbc.Type)
|
||||
assert.Equals(t, dbNew.Token, dbc.Token)
|
||||
assert.Equals(t, dbNew.Value, dbc.Value)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
|
||||
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
||||
assert.Equals(t, dbNew.ValidatedAt, "foobar")
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.UpdateChallenge(context.Background(), tc.ch); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.ch.ID, dbc.ID)
|
||||
assert.Equals(t, tc.ch.AccountID, dbc.AccountID)
|
||||
assert.Equals(t, tc.ch.Type, dbc.Type)
|
||||
assert.Equals(t, tc.ch.Token, dbc.Token)
|
||||
assert.Equals(t, tc.ch.Value, dbc.Value)
|
||||
assert.Equals(t, tc.ch.ValidatedAt, "foobar")
|
||||
assert.Equals(t, tc.ch.Status, acme.StatusValid)
|
||||
assert.Equals(t, tc.ch.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,387 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
nosqlDB "github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// externalAccountKeyMutex for read/write locking of EAK operations.
|
||||
var externalAccountKeyMutex sync.RWMutex
|
||||
|
||||
// referencesByProvisionerIndexMutex for locking referencesByProvisioner index operations.
|
||||
var referencesByProvisionerIndexMutex sync.Mutex
|
||||
|
||||
type dbExternalAccountKey struct {
|
||||
ID string `json:"id"`
|
||||
ProvisionerID string `json:"provisionerID"`
|
||||
Reference string `json:"reference"`
|
||||
AccountID string `json:"accountID,omitempty"`
|
||||
HmacKey []byte `json:"key"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
BoundAt time.Time `json:"boundAt"`
|
||||
}
|
||||
|
||||
type dbExternalAccountKeyReference struct {
|
||||
Reference string `json:"reference"`
|
||||
ExternalAccountKeyID string `json:"externalAccountKeyID"`
|
||||
}
|
||||
|
||||
// getDBExternalAccountKey retrieves and unmarshals dbExternalAccountKey.
|
||||
func (db *DB) getDBExternalAccountKey(_ context.Context, id string) (*dbExternalAccountKey, error) {
|
||||
data, err := db.db.Get(externalAccountKeyTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
return nil, acme.ErrNotFound
|
||||
}
|
||||
return nil, errors.Wrapf(err, "error loading external account key %s", id)
|
||||
}
|
||||
|
||||
dbeak := new(dbExternalAccountKey)
|
||||
if err = json.Unmarshal(data, dbeak); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling external account key %s into dbExternalAccountKey", id)
|
||||
}
|
||||
|
||||
return dbeak, nil
|
||||
}
|
||||
|
||||
// CreateExternalAccountKey creates a new External Account Binding key with a name
|
||||
func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
externalAccountKeyMutex.Lock()
|
||||
defer externalAccountKeyMutex.Unlock()
|
||||
|
||||
keyID, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
random := make([]byte, 32)
|
||||
_, err = rand.Read(random)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbeak := &dbExternalAccountKey{
|
||||
ID: keyID,
|
||||
ProvisionerID: provisionerID,
|
||||
Reference: reference,
|
||||
HmacKey: random,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
|
||||
if err := db.save(ctx, keyID, dbeak, nil, "external_account_key", externalAccountKeyTable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := db.addEAKID(ctx, provisionerID, dbeak.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbeak.Reference != "" {
|
||||
dbExternalAccountKeyReference := &dbExternalAccountKeyReference{
|
||||
Reference: dbeak.Reference,
|
||||
ExternalAccountKeyID: dbeak.ID,
|
||||
}
|
||||
if err := db.save(ctx, referenceKey(provisionerID, dbeak.Reference), dbExternalAccountKeyReference, nil, "external_account_key_reference", externalAccountKeyIDsByReferenceTable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &acme.ExternalAccountKey{
|
||||
ID: dbeak.ID,
|
||||
ProvisionerID: dbeak.ProvisionerID,
|
||||
Reference: dbeak.Reference,
|
||||
AccountID: dbeak.AccountID,
|
||||
HmacKey: dbeak.HmacKey,
|
||||
CreatedAt: dbeak.CreatedAt,
|
||||
BoundAt: dbeak.BoundAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetExternalAccountKey retrieves an External Account Binding key by KeyID
|
||||
func (db *DB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
externalAccountKeyMutex.RLock()
|
||||
defer externalAccountKeyMutex.RUnlock()
|
||||
|
||||
dbeak, err := db.getDBExternalAccountKey(ctx, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbeak.ProvisionerID != provisionerID {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created")
|
||||
}
|
||||
|
||||
return &acme.ExternalAccountKey{
|
||||
ID: dbeak.ID,
|
||||
ProvisionerID: dbeak.ProvisionerID,
|
||||
Reference: dbeak.Reference,
|
||||
AccountID: dbeak.AccountID,
|
||||
HmacKey: dbeak.HmacKey,
|
||||
CreatedAt: dbeak.CreatedAt,
|
||||
BoundAt: dbeak.BoundAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (db *DB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error {
|
||||
externalAccountKeyMutex.Lock()
|
||||
defer externalAccountKeyMutex.Unlock()
|
||||
|
||||
dbeak, err := db.getDBExternalAccountKey(ctx, keyID)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error loading ACME EAB Key with Key ID %s", keyID)
|
||||
}
|
||||
|
||||
if dbeak.ProvisionerID != provisionerID {
|
||||
return errors.New("provisioner does not match provisioner for which the EAB key was created")
|
||||
}
|
||||
|
||||
if dbeak.Reference != "" {
|
||||
if err := db.db.Del(externalAccountKeyIDsByReferenceTable, []byte(referenceKey(provisionerID, dbeak.Reference))); err != nil {
|
||||
return errors.Wrapf(err, "error deleting ACME EAB Key reference with Key ID %s and reference %s", keyID, dbeak.Reference)
|
||||
}
|
||||
}
|
||||
if err := db.db.Del(externalAccountKeyTable, []byte(keyID)); err != nil {
|
||||
return errors.Wrapf(err, "error deleting ACME EAB Key with Key ID %s", keyID)
|
||||
}
|
||||
if err := db.deleteEAKID(ctx, provisionerID, keyID); err != nil {
|
||||
return errors.Wrapf(err, "error removing ACME EAB Key ID %s", keyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExternalAccountKeys retrieves all External Account Binding keys for a provisioner
|
||||
func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*acme.ExternalAccountKey, string, error) {
|
||||
_, _ = cursor, limit // unused input
|
||||
|
||||
externalAccountKeyMutex.RLock()
|
||||
defer externalAccountKeyMutex.RUnlock()
|
||||
|
||||
// cursor and limit are ignored in open source, at least for now.
|
||||
|
||||
var eakIDs []string
|
||||
r, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID))
|
||||
if err != nil {
|
||||
if !nosqlDB.IsErrNotFound(err) {
|
||||
return nil, "", errors.Wrapf(err, "error loading ACME EAB Key IDs for provisioner %s", provisionerID)
|
||||
}
|
||||
// it may happen that no record is found; we'll continue with an empty slice
|
||||
} else {
|
||||
if err := json.Unmarshal(r, &eakIDs); err != nil {
|
||||
return nil, "", errors.Wrapf(err, "error unmarshaling ACME EAB Key IDs for provisioner %s", provisionerID)
|
||||
}
|
||||
}
|
||||
|
||||
keys := []*acme.ExternalAccountKey{}
|
||||
for _, eakID := range eakIDs {
|
||||
if eakID == "" {
|
||||
continue // shouldn't happen; just in case
|
||||
}
|
||||
eak, err := db.getDBExternalAccountKey(ctx, eakID)
|
||||
if err != nil {
|
||||
if !nosqlDB.IsErrNotFound(err) {
|
||||
return nil, "", errors.Wrapf(err, "error retrieving ACME EAB Key for provisioner %s and keyID %s", provisionerID, eakID)
|
||||
}
|
||||
}
|
||||
keys = append(keys, &acme.ExternalAccountKey{
|
||||
ID: eak.ID,
|
||||
HmacKey: eak.HmacKey,
|
||||
ProvisionerID: eak.ProvisionerID,
|
||||
Reference: eak.Reference,
|
||||
AccountID: eak.AccountID,
|
||||
CreatedAt: eak.CreatedAt,
|
||||
BoundAt: eak.BoundAt,
|
||||
})
|
||||
}
|
||||
|
||||
return keys, "", nil
|
||||
}
|
||||
|
||||
// GetExternalAccountKeyByReference retrieves an External Account Binding key with unique reference
|
||||
func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
externalAccountKeyMutex.RLock()
|
||||
defer externalAccountKeyMutex.RUnlock()
|
||||
|
||||
if reference == "" {
|
||||
//nolint:nilnil // legacy
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
k, err := db.db.Get(externalAccountKeyIDsByReferenceTable, []byte(referenceKey(provisionerID, reference)))
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
return nil, acme.ErrNotFound
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading ACME EAB key for reference %s", reference)
|
||||
}
|
||||
dbExternalAccountKeyReference := new(dbExternalAccountKeyReference)
|
||||
if err := json.Unmarshal(k, dbExternalAccountKeyReference); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling ACME EAB key for reference %s", reference)
|
||||
}
|
||||
|
||||
return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID)
|
||||
}
|
||||
|
||||
func (db *DB) GetExternalAccountKeyByAccountID(context.Context, string, string) (*acme.ExternalAccountKey, error) {
|
||||
//nolint:nilnil // legacy
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||
externalAccountKeyMutex.Lock()
|
||||
defer externalAccountKeyMutex.Unlock()
|
||||
|
||||
old, err := db.getDBExternalAccountKey(ctx, eak.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if old.ProvisionerID != provisionerID {
|
||||
return errors.New("provisioner does not match provisioner for which the EAB key was created")
|
||||
}
|
||||
|
||||
if old.ProvisionerID != eak.ProvisionerID {
|
||||
return errors.New("cannot change provisioner for an existing ACME EAB Key")
|
||||
}
|
||||
|
||||
if old.Reference != eak.Reference {
|
||||
return errors.New("cannot change reference for an existing ACME EAB Key")
|
||||
}
|
||||
|
||||
nu := dbExternalAccountKey{
|
||||
ID: eak.ID,
|
||||
ProvisionerID: eak.ProvisionerID,
|
||||
Reference: eak.Reference,
|
||||
AccountID: eak.AccountID,
|
||||
HmacKey: eak.HmacKey,
|
||||
CreatedAt: eak.CreatedAt,
|
||||
BoundAt: eak.BoundAt,
|
||||
}
|
||||
|
||||
return db.save(ctx, nu.ID, nu, old, "external_account_key", externalAccountKeyTable)
|
||||
}
|
||||
|
||||
func (db *DB) addEAKID(ctx context.Context, provisionerID, eakID string) error {
|
||||
referencesByProvisionerIndexMutex.Lock()
|
||||
defer referencesByProvisionerIndexMutex.Unlock()
|
||||
|
||||
if eakID == "" {
|
||||
return errors.Errorf("can't add empty eakID for provisioner %s", provisionerID)
|
||||
}
|
||||
|
||||
var eakIDs []string
|
||||
b, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID))
|
||||
if err != nil {
|
||||
if !nosqlDB.IsErrNotFound(err) {
|
||||
return errors.Wrapf(err, "error loading eakIDs for provisioner %s", provisionerID)
|
||||
}
|
||||
// it may happen that no record is found; we'll continue with an empty slice
|
||||
} else {
|
||||
if err := json.Unmarshal(b, &eakIDs); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshaling eakIDs for provisioner %s", provisionerID)
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range eakIDs {
|
||||
if id == eakID {
|
||||
// return an error when a duplicate ID is found
|
||||
return errors.Errorf("eakID %s already exists for provisioner %s", eakID, provisionerID)
|
||||
}
|
||||
}
|
||||
|
||||
var newEAKIDs []string
|
||||
newEAKIDs = append(newEAKIDs, eakIDs...)
|
||||
newEAKIDs = append(newEAKIDs, eakID)
|
||||
|
||||
var (
|
||||
_old interface{} = eakIDs
|
||||
_new interface{} = newEAKIDs
|
||||
)
|
||||
|
||||
// ensure that the DB gets the expected value when the slice is empty; otherwise
|
||||
// it'll return with an error that indicates that the DBs view of the data is
|
||||
// different from the last read (i.e. _old is different from what the DB has).
|
||||
if len(eakIDs) == 0 {
|
||||
_old = nil
|
||||
}
|
||||
|
||||
if err = db.save(ctx, provisionerID, _new, _old, "externalAccountKeyIDsByProvisionerID", externalAccountKeyIDsByProvisionerIDTable); err != nil {
|
||||
return errors.Wrapf(err, "error saving eakIDs index for provisioner %s", provisionerID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) deleteEAKID(ctx context.Context, provisionerID, eakID string) error {
|
||||
referencesByProvisionerIndexMutex.Lock()
|
||||
defer referencesByProvisionerIndexMutex.Unlock()
|
||||
|
||||
var eakIDs []string
|
||||
b, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID))
|
||||
if err != nil {
|
||||
if !nosqlDB.IsErrNotFound(err) {
|
||||
return errors.Wrapf(err, "error loading eakIDs for provisioner %s", provisionerID)
|
||||
}
|
||||
// it may happen that no record is found; we'll continue with an empty slice
|
||||
} else {
|
||||
if err := json.Unmarshal(b, &eakIDs); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshaling eakIDs for provisioner %s", provisionerID)
|
||||
}
|
||||
}
|
||||
|
||||
newEAKIDs := removeElement(eakIDs, eakID)
|
||||
var (
|
||||
_old interface{} = eakIDs
|
||||
_new interface{} = newEAKIDs
|
||||
)
|
||||
|
||||
// ensure that the DB gets the expected value when the slice is empty; otherwise
|
||||
// it'll return with an error that indicates that the DBs view of the data is
|
||||
// different from the last read (i.e. _old is different from what the DB has).
|
||||
if len(eakIDs) == 0 {
|
||||
_old = nil
|
||||
}
|
||||
|
||||
if err = db.save(ctx, provisionerID, _new, _old, "externalAccountKeyIDsByProvisionerID", externalAccountKeyIDsByProvisionerIDTable); err != nil {
|
||||
return errors.Wrapf(err, "error saving eakIDs index for provisioner %s", provisionerID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// referenceKey returns a unique key for a reference per provisioner
|
||||
func referenceKey(provisionerID, reference string) string {
|
||||
return provisionerID + "." + reference
|
||||
}
|
||||
|
||||
// sliceIndex finds the index of item in slice
|
||||
func sliceIndex(slice []string, item string) int {
|
||||
for i := range slice {
|
||||
if slice[i] == item {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// removeElement deletes the item if it exists in the
|
||||
// slice. It returns a new slice, keeping the old one intact.
|
||||
func removeElement(slice []string, item string) []string {
|
||||
newSlice := make([]string, 0)
|
||||
index := sliceIndex(slice, item)
|
||||
if index < 0 {
|
||||
newSlice = append(newSlice, slice...)
|
||||
return newSlice
|
||||
}
|
||||
|
||||
newSlice = append(newSlice, slice[:index]...)
|
||||
|
||||
return append(newSlice, slice[index+1:]...)
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,66 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
// dbNonce contains nonce metadata used in the ACME protocol.
|
||||
type dbNonce struct {
|
||||
ID string
|
||||
CreatedAt time.Time
|
||||
DeletedAt time.Time
|
||||
}
|
||||
|
||||
// CreateNonce creates, stores, and returns an ACME replay-nonce.
|
||||
// Implements the acme.DB interface.
|
||||
func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
|
||||
_id, err := randID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
|
||||
n := &dbNonce{
|
||||
ID: id,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
if err := db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return acme.Nonce(id), nil
|
||||
}
|
||||
|
||||
// DeleteNonce verifies that the nonce is valid (by checking if it exists),
|
||||
// and if so, consumes the nonce resource by deleting it from the database.
|
||||
func (db *DB) DeleteNonce(_ context.Context, nonce acme.Nonce) error {
|
||||
err := db.db.Update(&database.Tx{
|
||||
Operations: []*database.TxEntry{
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Get,
|
||||
},
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Delete,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
return acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", string(nonce))
|
||||
case err != nil:
|
||||
return errors.Wrapf(err, "error deleting nonce %s", string(nonce))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -1,168 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestDB_CreateNonce(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
_id *string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbn := new(dbNonce)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbn))
|
||||
assert.Equals(t, dbn.ID, string(key))
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme nonce: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var (
|
||||
id string
|
||||
idPtr = &id
|
||||
)
|
||||
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
*idPtr = string(key)
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
dbn := new(dbNonce)
|
||||
assert.FatalError(t, json.Unmarshal(nu, dbn))
|
||||
assert.Equals(t, dbn.ID, string(key))
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
_id: idPtr,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if n, err := d.CreateNonce(context.Background()); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, string(n), *tc._id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_DeleteNonce(t *testing.T) {
|
||||
|
||||
nonceID := "nonceID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return database.ErrNotFound
|
||||
},
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", nonceID),
|
||||
}
|
||||
},
|
||||
"fail/db.Update-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error deleting nonce nonceID: force"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,103 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
nosqlDB "github.com/smallstep/nosql"
|
||||
"go.step.sm/crypto/randutil"
|
||||
)
|
||||
|
||||
var (
|
||||
accountTable = []byte("acme_accounts")
|
||||
accountByKeyIDTable = []byte("acme_keyID_accountID_index")
|
||||
authzTable = []byte("acme_authzs")
|
||||
challengeTable = []byte("acme_challenges")
|
||||
nonceTable = []byte("nonces")
|
||||
orderTable = []byte("acme_orders")
|
||||
ordersByAccountIDTable = []byte("acme_account_orders_index")
|
||||
certTable = []byte("acme_certs")
|
||||
certBySerialTable = []byte("acme_serial_certs_index")
|
||||
externalAccountKeyTable = []byte("acme_external_account_keys")
|
||||
externalAccountKeyIDsByReferenceTable = []byte("acme_external_account_keyID_reference_index")
|
||||
externalAccountKeyIDsByProvisionerIDTable = []byte("acme_external_account_keyID_provisionerID_index")
|
||||
)
|
||||
|
||||
// DB is a struct that implements the AcmeDB interface.
|
||||
type DB struct {
|
||||
db nosqlDB.DB
|
||||
}
|
||||
|
||||
// New configures and returns a new ACME DB backend implemented using a nosql DB.
|
||||
func New(db nosqlDB.DB) (*DB, error) {
|
||||
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
|
||||
challengeTable, nonceTable, orderTable, ordersByAccountIDTable,
|
||||
certTable, certBySerialTable, externalAccountKeyTable,
|
||||
externalAccountKeyIDsByReferenceTable, externalAccountKeyIDsByProvisionerIDTable,
|
||||
}
|
||||
for _, b := range tables {
|
||||
if err := db.CreateTable(b); err != nil {
|
||||
return nil, errors.Wrapf(err, "error creating table %s",
|
||||
string(b))
|
||||
}
|
||||
}
|
||||
return &DB{db}, nil
|
||||
}
|
||||
|
||||
// save writes the new data to the database, overwriting the old data if it
|
||||
// existed.
|
||||
func (db *DB) save(_ context.Context, id string, nu, old interface{}, typ string, table []byte) error {
|
||||
var (
|
||||
err error
|
||||
newB []byte
|
||||
)
|
||||
if nu == nil {
|
||||
newB = nil
|
||||
} else {
|
||||
newB, err = json.Marshal(nu)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu)
|
||||
}
|
||||
}
|
||||
var oldB []byte
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
oldB, err = json.Marshal(old)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old)
|
||||
}
|
||||
}
|
||||
|
||||
_, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return errors.Wrapf(err, "error saving acme %s", typ)
|
||||
case !swapped:
|
||||
return errors.Errorf("error saving acme %s; changed since last read", typ)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var idLen = 32
|
||||
|
||||
func randID() (val string, err error) {
|
||||
val, err = randutil.Alphanumeric(idLen)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error generating random alphanumeric ID")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock struct{}
|
||||
|
||||
// Now returns the UTC time rounded to seconds.
|
||||
func (c *Clock) Now() time.Time {
|
||||
return time.Now().UTC().Truncate(time.Second)
|
||||
}
|
||||
|
||||
var clock = new(Clock)
|
|
@ -1,139 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"fail/db.CreateTable-error": {
|
||||
db: &db.MockNoSQLDB{
|
||||
MCreateTable: func(bucket []byte) error {
|
||||
assert.Equals(t, string(bucket), string(accountTable))
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.Errorf("error creating table %s: force", string(accountTable)),
|
||||
},
|
||||
"ok": {
|
||||
db: &db.MockNoSQLDB{
|
||||
MCreateTable: func(bucket []byte) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if _, err := New(tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type errorThrower string
|
||||
|
||||
func (et errorThrower) MarshalJSON() ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
|
||||
func TestDB_save(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
nu interface{}
|
||||
old interface{}
|
||||
err error
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"fail/error-marshaling-new": {
|
||||
nu: errorThrower("foo"),
|
||||
err: errors.New("error marshaling acme type: challenge"),
|
||||
},
|
||||
"fail/error-marshaling-old": {
|
||||
nu: "new",
|
||||
old: errorThrower("foo"),
|
||||
err: errors.New("error marshaling acme type: challenge"),
|
||||
},
|
||||
"fail/db.CmpAndSwap-error": {
|
||||
nu: "new",
|
||||
old: "old",
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), "id")
|
||||
assert.Equals(t, string(old), "\"old\"")
|
||||
assert.Equals(t, string(nu), "\"new\"")
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme challenge: force"),
|
||||
},
|
||||
"fail/db.CmpAndSwap-false-marshaling-old": {
|
||||
nu: "new",
|
||||
old: "old",
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), "id")
|
||||
assert.Equals(t, string(old), "\"old\"")
|
||||
assert.Equals(t, string(nu), "\"new\"")
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: errors.New("error saving acme challenge; changed since last read"),
|
||||
},
|
||||
"ok": {
|
||||
nu: "new",
|
||||
old: "old",
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), "id")
|
||||
assert.Equals(t, string(old), "\"old\"")
|
||||
assert.Equals(t, string(nu), "\"new\"")
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
"ok/nils": {
|
||||
nu: nil,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, challengeTable)
|
||||
assert.Equals(t, string(key), "id")
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, nu, nil)
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := &DB{db: tc.db}
|
||||
if err := d.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,187 +0,0 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Mutex for locking ordersByAccount index operations.
|
||||
var ordersByAccountMux sync.Mutex
|
||||
|
||||
type dbOrder struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
ProvisionerID string `json:"provisionerID"`
|
||||
Identifiers []acme.Identifier `json:"identifiers"`
|
||||
AuthorizationIDs []string `json:"authorizationIDs"`
|
||||
Status acme.Status `json:"status"`
|
||||
NotBefore time.Time `json:"notBefore,omitempty"`
|
||||
NotAfter time.Time `json:"notAfter,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
ExpiresAt time.Time `json:"expiresAt,omitempty"`
|
||||
CertificateID string `json:"certificate,omitempty"`
|
||||
Error *acme.Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (a *dbOrder) clone() *dbOrder {
|
||||
b := *a
|
||||
return &b
|
||||
}
|
||||
|
||||
// getDBOrder retrieves and unmarshals an ACME Order type from the database.
|
||||
func (db *DB) getDBOrder(_ context.Context, id string) (*dbOrder, error) {
|
||||
b, err := db.db.Get(orderTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading order %s", id)
|
||||
}
|
||||
o := new(dbOrder)
|
||||
if err := json.Unmarshal(b, &o); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling order %s into dbOrder", id)
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// GetOrder retrieves an ACME Order from the database.
|
||||
func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) {
|
||||
dbo, err := db.getDBOrder(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o := &acme.Order{
|
||||
ID: dbo.ID,
|
||||
AccountID: dbo.AccountID,
|
||||
ProvisionerID: dbo.ProvisionerID,
|
||||
CertificateID: dbo.CertificateID,
|
||||
Status: dbo.Status,
|
||||
ExpiresAt: dbo.ExpiresAt,
|
||||
Identifiers: dbo.Identifiers,
|
||||
NotBefore: dbo.NotBefore,
|
||||
NotAfter: dbo.NotAfter,
|
||||
AuthorizationIDs: dbo.AuthorizationIDs,
|
||||
Error: dbo.Error,
|
||||
}
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// CreateOrder creates ACME Order resources and saves them to the DB.
|
||||
func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error {
|
||||
var err error
|
||||
o.ID, err = randID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
dbo := &dbOrder{
|
||||
ID: o.ID,
|
||||
AccountID: o.AccountID,
|
||||
ProvisionerID: o.ProvisionerID,
|
||||
Status: o.Status,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: o.ExpiresAt,
|
||||
Identifiers: o.Identifiers,
|
||||
NotBefore: o.NotBefore,
|
||||
NotAfter: o.NotAfter,
|
||||
AuthorizationIDs: o.AuthorizationIDs,
|
||||
}
|
||||
if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.updateAddOrderIDs(ctx, o.AccountID, o.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateOrder saves an updated ACME Order to the database.
|
||||
func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error {
|
||||
old, err := db.getDBOrder(ctx, o.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
|
||||
nu.Status = o.Status
|
||||
nu.Error = o.Error
|
||||
nu.CertificateID = o.CertificateID
|
||||
return db.save(ctx, old.ID, nu, old, "order", orderTable)
|
||||
}
|
||||
|
||||
func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) {
|
||||
ordersByAccountMux.Lock()
|
||||
defer ordersByAccountMux.Unlock()
|
||||
|
||||
var oldOids []string
|
||||
b, err := db.db.Get(ordersByAccountIDTable, []byte(accID))
|
||||
if err != nil {
|
||||
if !nosql.IsErrNotFound(err) {
|
||||
return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID)
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(b, &oldOids); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove any order that is not in PENDING state and update the stored list
|
||||
// before returning.
|
||||
//
|
||||
// According to RFC 8555:
|
||||
// The server SHOULD include pending orders and SHOULD NOT include orders
|
||||
// that are invalid in the array of URLs.
|
||||
pendOids := []string{}
|
||||
for _, oid := range oldOids {
|
||||
o, err := db.GetOrder(ctx, oid)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error loading order %s for account %s", oid, accID)
|
||||
}
|
||||
if err = o.UpdateStatus(ctx, db); err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error updating order %s for account %s", oid, accID)
|
||||
}
|
||||
if o.Status == acme.StatusPending {
|
||||
pendOids = append(pendOids, oid)
|
||||
}
|
||||
}
|
||||
pendOids = append(pendOids, addOids...)
|
||||
var (
|
||||
_old interface{} = oldOids
|
||||
_new interface{} = pendOids
|
||||
)
|
||||
switch {
|
||||
case len(oldOids) == 0 && len(pendOids) == 0:
|
||||
// If list has not changed from empty, then no need to write the DB.
|
||||
return []string{}, nil
|
||||
case len(oldOids) == 0:
|
||||
_old = nil
|
||||
case len(pendOids) == 0:
|
||||
_new = nil
|
||||
}
|
||||
if err = db.save(ctx, accID, _new, _old, "orderIDsByAccountID", ordersByAccountIDTable); err != nil {
|
||||
// Delete all orders that may have been previously stored if orderIDsByAccountID update fails.
|
||||
for _, oid := range addOids {
|
||||
// Ignore error from delete -- we tried our best.
|
||||
// TODO when we have logging w/ request ID tracking, logging this error.
|
||||
db.db.Del(orderTable, []byte(oid))
|
||||
}
|
||||
return nil, errors.Wrapf(err, "error saving orderIDs index for account %s", accID)
|
||||
}
|
||||
return pendOids, nil
|
||||
}
|
||||
|
||||
// GetOrdersByAccountID returns a list of order IDs owned by the account.
|
||||
func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
|
||||
return db.updateAddOrderIDs(ctx, accID)
|
||||
}
|
File diff suppressed because it is too large
Load diff
120
acme/directory.go
Normal file
120
acme/directory.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Directory represents an ACME directory for configuring clients.
|
||||
type Directory struct {
|
||||
NewNonce string `json:"newNonce,omitempty"`
|
||||
NewAccount string `json:"newAccount,omitempty"`
|
||||
NewOrder string `json:"newOrder,omitempty"`
|
||||
NewAuthz string `json:"newAuthz,omitempty"`
|
||||
RevokeCert string `json:"revokeCert,omitempty"`
|
||||
KeyChange string `json:"keyChange,omitempty"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging for the Directory type.
|
||||
func (d *Directory) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
type directory struct {
|
||||
prefix, dns string
|
||||
}
|
||||
|
||||
// newDirectory returns a new Directory type.
|
||||
func newDirectory(dns, prefix string) *directory {
|
||||
return &directory{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Link captures the link type.
|
||||
type Link int
|
||||
|
||||
const (
|
||||
// NewNonceLink new-nonce
|
||||
NewNonceLink Link = iota
|
||||
// NewAccountLink new-account
|
||||
NewAccountLink
|
||||
// AccountLink account
|
||||
AccountLink
|
||||
// OrderLink order
|
||||
OrderLink
|
||||
// NewOrderLink new-order
|
||||
NewOrderLink
|
||||
// OrdersByAccountLink list of orders owned by account
|
||||
OrdersByAccountLink
|
||||
// FinalizeLink finalize order
|
||||
FinalizeLink
|
||||
// NewAuthzLink authz
|
||||
NewAuthzLink
|
||||
// AuthzLink new-authz
|
||||
AuthzLink
|
||||
// ChallengeLink challenge
|
||||
ChallengeLink
|
||||
// CertificateLink certificate
|
||||
CertificateLink
|
||||
// DirectoryLink directory
|
||||
DirectoryLink
|
||||
// RevokeCertLink revoke certificate
|
||||
RevokeCertLink
|
||||
// KeyChangeLink key rollover
|
||||
KeyChangeLink
|
||||
)
|
||||
|
||||
func (l Link) String() string {
|
||||
switch l {
|
||||
case NewNonceLink:
|
||||
return "new-nonce"
|
||||
case NewAccountLink:
|
||||
return "new-account"
|
||||
case AccountLink:
|
||||
return "account"
|
||||
case NewOrderLink:
|
||||
return "new-order"
|
||||
case OrderLink:
|
||||
return "order"
|
||||
case NewAuthzLink:
|
||||
return "new-authz"
|
||||
case AuthzLink:
|
||||
return "authz"
|
||||
case ChallengeLink:
|
||||
return "challenge"
|
||||
case CertificateLink:
|
||||
return "certificate"
|
||||
case DirectoryLink:
|
||||
return "directory"
|
||||
case RevokeCertLink:
|
||||
return "revoke-cert"
|
||||
case KeyChangeLink:
|
||||
return "key-change"
|
||||
default:
|
||||
return "unexpected"
|
||||
}
|
||||
}
|
||||
|
||||
// getLink returns an absolute or partial path to the given resource.
|
||||
func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs ...string) string {
|
||||
var link string
|
||||
switch typ {
|
||||
case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink:
|
||||
link = fmt.Sprintf("/%s/%s", provisionerName, typ.String())
|
||||
case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0])
|
||||
case OrdersByAccountLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0])
|
||||
case FinalizeLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0])
|
||||
}
|
||||
if abs {
|
||||
return fmt.Sprintf("https://%s/%s%s", d.dns, d.prefix, link)
|
||||
}
|
||||
return link
|
||||
}
|
60
acme/directory_test.go
Normal file
60
acme/directory_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func TestDirectoryGetLink(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
prefix := "acme"
|
||||
dir := newDirectory(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := newProv()
|
||||
provID := URLSafeProvisionerName(prov)
|
||||
|
||||
assert.Equals(t, dir.getLink(NewNonceLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", provID))
|
||||
assert.Equals(t, dir.getLink(NewNonceLink, provID, false), fmt.Sprintf("/%s/new-nonce", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(NewAccountLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID))
|
||||
assert.Equals(t, dir.getLink(NewAccountLink, provID, false), fmt.Sprintf("/%s/new-account", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(AccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provID))
|
||||
assert.Equals(t, dir.getLink(AccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(NewOrderLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", provID))
|
||||
assert.Equals(t, dir.getLink(NewOrderLink, provID, false), fmt.Sprintf("/%s/new-order", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(OrderLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234", provID))
|
||||
assert.Equals(t, dir.getLink(OrderLink, provID, false, id), fmt.Sprintf("/%s/order/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234/orders", provID))
|
||||
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234/orders", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(FinalizeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234/finalize", provID))
|
||||
assert.Equals(t, dir.getLink(FinalizeLink, provID, false, id), fmt.Sprintf("/%s/order/1234/finalize", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(NewAuthzLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-authz", provID))
|
||||
assert.Equals(t, dir.getLink(NewAuthzLink, provID, false), fmt.Sprintf("/%s/new-authz", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(AuthzLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/authz/1234", provID))
|
||||
assert.Equals(t, dir.getLink(AuthzLink, provID, false, id), fmt.Sprintf("/%s/authz/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(DirectoryLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/directory", provID))
|
||||
assert.Equals(t, dir.getLink(DirectoryLink, provID, false), fmt.Sprintf("/%s/directory", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(RevokeCertLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", provID))
|
||||
assert.Equals(t, dir.getLink(RevokeCertLink, provID, false), fmt.Sprintf("/%s/revoke-cert", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(KeyChangeLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", provID))
|
||||
assert.Equals(t, dir.getLink(KeyChangeLink, provID, false), fmt.Sprintf("/%s/key-change", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(ChallengeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/1234", provID))
|
||||
assert.Equals(t, dir.getLink(ChallengeLink, provID, false, id), fmt.Sprintf("/%s/challenge/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(CertificateLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/1234", provID))
|
||||
assert.Equals(t, dir.getLink(CertificateLink, provID, false, id), fmt.Sprintf("/%s/certificate/1234", provID))
|
||||
}
|
773
acme/errors.go
773
acme/errors.go
|
@ -1,381 +1,385 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
)
|
||||
|
||||
// ProblemType is the type of the ACME problem.
|
||||
type ProblemType int
|
||||
|
||||
const (
|
||||
// ErrorAccountDoesNotExistType request specified an account that does not exist
|
||||
ErrorAccountDoesNotExistType ProblemType = iota
|
||||
// ErrorAlreadyRevokedType request specified a certificate to be revoked that has already been revoked
|
||||
ErrorAlreadyRevokedType
|
||||
// ErrorBadAttestationStatementType WebAuthn attestation statement could not be verified
|
||||
ErrorBadAttestationStatementType
|
||||
// ErrorBadCSRType CSR is unacceptable (e.g., due to a short key)
|
||||
ErrorBadCSRType
|
||||
// ErrorBadNonceType client sent an unacceptable anti-replay nonce
|
||||
ErrorBadNonceType
|
||||
// ErrorBadPublicKeyType JWS was signed by a public key the server does not support
|
||||
ErrorBadPublicKeyType
|
||||
// ErrorBadRevocationReasonType revocation reason provided is not allowed by the server
|
||||
ErrorBadRevocationReasonType
|
||||
// ErrorBadSignatureAlgorithmType JWS was signed with an algorithm the server does not support
|
||||
ErrorBadSignatureAlgorithmType
|
||||
// ErrorCaaType Authority Authorization (CAA) records forbid the CA from issuing a certificate
|
||||
ErrorCaaType
|
||||
// ErrorCompoundType error conditions are indicated in the “subproblems” array.
|
||||
ErrorCompoundType
|
||||
// ErrorConnectionType server could not connect to validation target
|
||||
ErrorConnectionType
|
||||
// ErrorDNSType was a problem with a DNS query during identifier validation
|
||||
ErrorDNSType
|
||||
// ErrorExternalAccountRequiredType request must include a value for the “externalAccountBinding” field
|
||||
ErrorExternalAccountRequiredType
|
||||
// ErrorIncorrectResponseType received didn’t match the challenge’s requirements
|
||||
ErrorIncorrectResponseType
|
||||
// ErrorInvalidContactType URL for an account was invalid
|
||||
ErrorInvalidContactType
|
||||
// ErrorMalformedType request message was malformed
|
||||
ErrorMalformedType
|
||||
// ErrorOrderNotReadyType request attempted to finalize an order that is not ready to be finalized
|
||||
ErrorOrderNotReadyType
|
||||
// ErrorRateLimitedType request exceeds a rate limit
|
||||
ErrorRateLimitedType
|
||||
// ErrorRejectedIdentifierType server will not issue certificates for the identifier
|
||||
ErrorRejectedIdentifierType
|
||||
// ErrorServerInternalType server experienced an internal error
|
||||
ErrorServerInternalType
|
||||
// ErrorTLSType server received a TLS error during validation
|
||||
ErrorTLSType
|
||||
// ErrorUnauthorizedType client lacks sufficient authorization
|
||||
ErrorUnauthorizedType
|
||||
// ErrorUnsupportedContactType URL for an account used an unsupported protocol scheme
|
||||
ErrorUnsupportedContactType
|
||||
// ErrorUnsupportedIdentifierType identifier is of an unsupported type
|
||||
ErrorUnsupportedIdentifierType
|
||||
// ErrorUserActionRequiredType the “instance” URL and take actions specified there
|
||||
ErrorUserActionRequiredType
|
||||
// ErrorNotImplementedType operation is not implemented
|
||||
ErrorNotImplementedType
|
||||
)
|
||||
|
||||
// String returns the string representation of the acme problem type,
|
||||
// fulfilling the Stringer interface.
|
||||
func (ap ProblemType) String() string {
|
||||
switch ap {
|
||||
case ErrorAccountDoesNotExistType:
|
||||
return "accountDoesNotExist"
|
||||
case ErrorAlreadyRevokedType:
|
||||
return "alreadyRevoked"
|
||||
case ErrorBadAttestationStatementType:
|
||||
return "badAttestationStatement"
|
||||
case ErrorBadCSRType:
|
||||
return "badCSR"
|
||||
case ErrorBadNonceType:
|
||||
return "badNonce"
|
||||
case ErrorBadPublicKeyType:
|
||||
return "badPublicKey"
|
||||
case ErrorBadRevocationReasonType:
|
||||
return "badRevocationReason"
|
||||
case ErrorBadSignatureAlgorithmType:
|
||||
return "badSignatureAlgorithm"
|
||||
case ErrorCaaType:
|
||||
return "caa"
|
||||
case ErrorCompoundType:
|
||||
return "compound"
|
||||
case ErrorConnectionType:
|
||||
return "connection"
|
||||
case ErrorDNSType:
|
||||
return "dns"
|
||||
case ErrorExternalAccountRequiredType:
|
||||
return "externalAccountRequired"
|
||||
case ErrorInvalidContactType:
|
||||
return "incorrectResponse"
|
||||
case ErrorMalformedType:
|
||||
return "malformed"
|
||||
case ErrorOrderNotReadyType:
|
||||
return "orderNotReady"
|
||||
case ErrorRateLimitedType:
|
||||
return "rateLimited"
|
||||
case ErrorRejectedIdentifierType:
|
||||
return "rejectedIdentifier"
|
||||
case ErrorServerInternalType:
|
||||
return "serverInternal"
|
||||
case ErrorTLSType:
|
||||
return "tls"
|
||||
case ErrorUnauthorizedType:
|
||||
return "unauthorized"
|
||||
case ErrorUnsupportedContactType:
|
||||
return "unsupportedContact"
|
||||
case ErrorUnsupportedIdentifierType:
|
||||
return "unsupportedIdentifier"
|
||||
case ErrorUserActionRequiredType:
|
||||
return "userActionRequired"
|
||||
case ErrorNotImplementedType:
|
||||
return "notImplemented"
|
||||
default:
|
||||
return fmt.Sprintf("unsupported type ACME error type '%d'", int(ap))
|
||||
}
|
||||
}
|
||||
|
||||
type errorMetadata struct {
|
||||
details string
|
||||
status int
|
||||
typ string
|
||||
String string
|
||||
}
|
||||
|
||||
var (
|
||||
officialACMEPrefix = "urn:ietf:params:acme:error:"
|
||||
errorServerInternalMetadata = errorMetadata{
|
||||
typ: officialACMEPrefix + ErrorServerInternalType.String(),
|
||||
details: "The server experienced an internal error",
|
||||
status: 500,
|
||||
}
|
||||
errorMap = map[ProblemType]errorMetadata{
|
||||
ErrorAccountDoesNotExistType: {
|
||||
typ: officialACMEPrefix + ErrorAccountDoesNotExistType.String(),
|
||||
details: "Account does not exist",
|
||||
status: 400,
|
||||
},
|
||||
ErrorAlreadyRevokedType: {
|
||||
typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(),
|
||||
details: "Certificate already revoked",
|
||||
status: 400,
|
||||
},
|
||||
ErrorBadCSRType: {
|
||||
typ: officialACMEPrefix + ErrorBadCSRType.String(),
|
||||
details: "The CSR is unacceptable",
|
||||
status: 400,
|
||||
},
|
||||
ErrorBadNonceType: {
|
||||
typ: officialACMEPrefix + ErrorBadNonceType.String(),
|
||||
details: "Unacceptable anti-replay nonce",
|
||||
status: 400,
|
||||
},
|
||||
ErrorBadPublicKeyType: {
|
||||
typ: officialACMEPrefix + ErrorBadPublicKeyType.String(),
|
||||
details: "The jws was signed by a public key the server does not support",
|
||||
status: 400,
|
||||
},
|
||||
ErrorBadRevocationReasonType: {
|
||||
typ: officialACMEPrefix + ErrorBadRevocationReasonType.String(),
|
||||
details: "The revocation reason provided is not allowed by the server",
|
||||
status: 400,
|
||||
},
|
||||
ErrorBadSignatureAlgorithmType: {
|
||||
typ: officialACMEPrefix + ErrorBadSignatureAlgorithmType.String(),
|
||||
details: "The JWS was signed with an algorithm the server does not support",
|
||||
status: 400,
|
||||
},
|
||||
ErrorBadAttestationStatementType: {
|
||||
typ: officialACMEPrefix + ErrorBadAttestationStatementType.String(),
|
||||
details: "Attestation statement cannot be verified",
|
||||
status: 400,
|
||||
},
|
||||
ErrorCaaType: {
|
||||
typ: officialACMEPrefix + ErrorCaaType.String(),
|
||||
details: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
|
||||
status: 400,
|
||||
},
|
||||
ErrorCompoundType: {
|
||||
typ: officialACMEPrefix + ErrorCompoundType.String(),
|
||||
details: "Specific error conditions are indicated in the “subproblems” array",
|
||||
status: 400,
|
||||
},
|
||||
ErrorConnectionType: {
|
||||
typ: officialACMEPrefix + ErrorConnectionType.String(),
|
||||
details: "The server could not connect to validation target",
|
||||
status: 400,
|
||||
},
|
||||
ErrorDNSType: {
|
||||
typ: officialACMEPrefix + ErrorDNSType.String(),
|
||||
details: "There was a problem with a DNS query during identifier validation",
|
||||
status: 400,
|
||||
},
|
||||
ErrorExternalAccountRequiredType: {
|
||||
typ: officialACMEPrefix + ErrorExternalAccountRequiredType.String(),
|
||||
details: "The request must include a value for the \"externalAccountBinding\" field",
|
||||
status: 400,
|
||||
},
|
||||
ErrorIncorrectResponseType: {
|
||||
typ: officialACMEPrefix + ErrorIncorrectResponseType.String(),
|
||||
details: "Response received didn't match the challenge's requirements",
|
||||
status: 400,
|
||||
},
|
||||
ErrorInvalidContactType: {
|
||||
typ: officialACMEPrefix + ErrorInvalidContactType.String(),
|
||||
details: "A contact URL for an account was invalid",
|
||||
status: 400,
|
||||
},
|
||||
ErrorMalformedType: {
|
||||
typ: officialACMEPrefix + ErrorMalformedType.String(),
|
||||
details: "The request message was malformed",
|
||||
status: 400,
|
||||
},
|
||||
ErrorOrderNotReadyType: {
|
||||
typ: officialACMEPrefix + ErrorOrderNotReadyType.String(),
|
||||
details: "The request attempted to finalize an order that is not ready to be finalized",
|
||||
status: 400,
|
||||
},
|
||||
ErrorRateLimitedType: {
|
||||
typ: officialACMEPrefix + ErrorRateLimitedType.String(),
|
||||
details: "The request exceeds a rate limit",
|
||||
status: 400,
|
||||
},
|
||||
ErrorRejectedIdentifierType: {
|
||||
typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(),
|
||||
details: "The server will not issue certificates for the identifier",
|
||||
status: 400,
|
||||
},
|
||||
ErrorNotImplementedType: {
|
||||
typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(),
|
||||
details: "The requested operation is not implemented",
|
||||
status: 501,
|
||||
},
|
||||
ErrorTLSType: {
|
||||
typ: officialACMEPrefix + ErrorTLSType.String(),
|
||||
details: "The server received a TLS error during validation",
|
||||
status: 400,
|
||||
},
|
||||
ErrorUnauthorizedType: {
|
||||
typ: officialACMEPrefix + ErrorUnauthorizedType.String(),
|
||||
details: "The client lacks sufficient authorization",
|
||||
status: 401,
|
||||
},
|
||||
ErrorUnsupportedContactType: {
|
||||
typ: officialACMEPrefix + ErrorUnsupportedContactType.String(),
|
||||
details: "A contact URL for an account used an unsupported protocol scheme",
|
||||
status: 400,
|
||||
},
|
||||
ErrorUnsupportedIdentifierType: {
|
||||
typ: officialACMEPrefix + ErrorUnsupportedIdentifierType.String(),
|
||||
details: "An identifier is of an unsupported type",
|
||||
status: 400,
|
||||
},
|
||||
ErrorUserActionRequiredType: {
|
||||
typ: officialACMEPrefix + ErrorUserActionRequiredType.String(),
|
||||
details: "Visit the “instance” URL and take actions specified there",
|
||||
status: 400,
|
||||
},
|
||||
ErrorServerInternalType: errorServerInternalMetadata,
|
||||
}
|
||||
)
|
||||
|
||||
// Error represents an ACME Error
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Subproblems []Subproblem `json:"subproblems,omitempty"`
|
||||
Err error `json:"-"`
|
||||
Status int `json:"-"`
|
||||
}
|
||||
|
||||
// Subproblem represents an ACME subproblem. It's fairly
|
||||
// similar to an ACME error, but differs in that it can't
|
||||
// include subproblems itself, the error is reflected
|
||||
// in the Detail property and doesn't have a Status.
|
||||
type Subproblem struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
// The "identifier" field MUST NOT be present at the top level in ACME
|
||||
// problem documents. It can only be present in subproblems.
|
||||
// Subproblems need not all have the same type, and they do not need to
|
||||
// match the top level type.
|
||||
Identifier *Identifier `json:"identifier,omitempty"`
|
||||
}
|
||||
|
||||
// AddSubproblems adds the Subproblems to Error. It
|
||||
// returns the Error, allowing for fluent addition.
|
||||
func (e *Error) AddSubproblems(subproblems ...Subproblem) *Error {
|
||||
e.Subproblems = append(e.Subproblems, subproblems...)
|
||||
return e
|
||||
}
|
||||
|
||||
// NewError creates a new Error type.
|
||||
func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
|
||||
return newError(pt, errors.Errorf(msg, args...))
|
||||
}
|
||||
|
||||
// NewSubproblem creates a new Subproblem. The msg and args
|
||||
// are used to create a new error, which is set as the Detail, allowing
|
||||
// for more detailed error messages to be returned to the ACME client.
|
||||
func NewSubproblem(pt ProblemType, msg string, args ...interface{}) Subproblem {
|
||||
e := newError(pt, fmt.Errorf(msg, args...))
|
||||
s := Subproblem{
|
||||
Type: e.Type,
|
||||
Detail: e.Err.Error(),
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// NewSubproblemWithIdentifier creates a new Subproblem with a specific ACME
|
||||
// Identifier. It calls NewSubproblem and sets the Identifier.
|
||||
func NewSubproblemWithIdentifier(pt ProblemType, identifier Identifier, msg string, args ...interface{}) Subproblem {
|
||||
s := NewSubproblem(pt, msg, args...)
|
||||
s.Identifier = &identifier
|
||||
return s
|
||||
}
|
||||
|
||||
func newError(pt ProblemType, err error) *Error {
|
||||
meta, ok := errorMap[pt]
|
||||
if !ok {
|
||||
meta = errorServerInternalMetadata
|
||||
return &Error{
|
||||
Type: meta.typ,
|
||||
Detail: meta.details,
|
||||
Status: meta.status,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// AccountDoesNotExistErr returns a new acme error.
|
||||
func AccountDoesNotExistErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: meta.typ,
|
||||
Detail: meta.details,
|
||||
Status: meta.status,
|
||||
Type: accountDoesNotExistErr,
|
||||
Detail: "Account does not exist",
|
||||
Status: 404,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorISE creates a new ErrorServerInternalType Error.
|
||||
func NewErrorISE(msg string, args ...interface{}) *Error {
|
||||
return NewError(ErrorServerInternalType, msg, args...)
|
||||
}
|
||||
|
||||
// WrapError attempts to wrap the internal error.
|
||||
func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error {
|
||||
var e *Error
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.As(err, &e):
|
||||
if e.Err == nil {
|
||||
e.Err = errors.Errorf(msg+"; "+e.Detail, args...)
|
||||
} else {
|
||||
e.Err = errors.Wrapf(e.Err, msg, args...)
|
||||
}
|
||||
return e
|
||||
default:
|
||||
return newError(typ, errors.Wrapf(err, msg, args...))
|
||||
// AlreadyRevokedErr returns a new acme error.
|
||||
func AlreadyRevokedErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: alreadyRevokedErr,
|
||||
Detail: "Certificate already revoked",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// WrapErrorISE shortcut to wrap an internal server error type.
|
||||
func WrapErrorISE(err error, msg string, args ...interface{}) *Error {
|
||||
return WrapError(ErrorServerInternalType, err, msg, args...)
|
||||
// BadCSRErr returns a new acme error.
|
||||
func BadCSRErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: badCSRErr,
|
||||
Detail: "The CSR is unacceptable",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// StatusCode returns the status code and implements the StatusCoder interface.
|
||||
func (e *Error) StatusCode() int {
|
||||
return e.Status
|
||||
// BadNonceErr returns a new acme error.
|
||||
func BadNonceErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: badNonceErr,
|
||||
Detail: "Unacceptable anti-replay nonce",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// BadPublicKeyErr returns a new acme error.
|
||||
func BadPublicKeyErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: badPublicKeyErr,
|
||||
Detail: "The jws was signed by a public key the server does not support",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// BadRevocationReasonErr returns a new acme error.
|
||||
func BadRevocationReasonErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: badRevocationReasonErr,
|
||||
Detail: "The revocation reason provided is not allowed by the server",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// BadSignatureAlgorithmErr returns a new acme error.
|
||||
func BadSignatureAlgorithmErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: badSignatureAlgorithmErr,
|
||||
Detail: "The JWS was signed with an algorithm the server does not support",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// CaaErr returns a new acme error.
|
||||
func CaaErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: caaErr,
|
||||
Detail: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// CompoundErr returns a new acme error.
|
||||
func CompoundErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: compoundErr,
|
||||
Detail: "Specific error conditions are indicated in the “subproblems” array",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectionErr returns a new acme error.
|
||||
func ConnectionErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: connectionErr,
|
||||
Detail: "The server could not connect to validation target",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// DNSErr returns a new acme error.
|
||||
func DNSErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: dnsErr,
|
||||
Detail: "There was a problem with a DNS query during identifier validation",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// ExternalAccountRequiredErr returns a new acme error.
|
||||
func ExternalAccountRequiredErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: externalAccountRequiredErr,
|
||||
Detail: "The request must include a value for the \"externalAccountBinding\" field",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// IncorrectResponseErr returns a new acme error.
|
||||
func IncorrectResponseErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: incorrectResponseErr,
|
||||
Detail: "Response received didn't match the challenge's requirements",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidContactErr returns a new acme error.
|
||||
func InvalidContactErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: invalidContactErr,
|
||||
Detail: "A contact URL for an account was invalid",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// MalformedErr returns a new acme error.
|
||||
func MalformedErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: malformedErr,
|
||||
Detail: "The request message was malformed",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// OrderNotReadyErr returns a new acme error.
|
||||
func OrderNotReadyErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: orderNotReadyErr,
|
||||
Detail: "The request attempted to finalize an order that is not ready to be finalized",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitedErr returns a new acme error.
|
||||
func RateLimitedErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: rateLimitedErr,
|
||||
Detail: "The request exceeds a rate limit",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// RejectedIdentifierErr returns a new acme error.
|
||||
func RejectedIdentifierErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: rejectedIdentifierErr,
|
||||
Detail: "The server will not issue certificates for the identifier",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// ServerInternalErr returns a new acme error.
|
||||
func ServerInternalErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: serverInternalErr,
|
||||
Detail: "The server experienced an internal error",
|
||||
Status: 500,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// TLSErr returns a new acme error.
|
||||
func TLSErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: tlsErr,
|
||||
Detail: "The server received a TLS error during validation",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// UnauthorizedErr returns a new acme error.
|
||||
func UnauthorizedErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: unauthorizedErr,
|
||||
Detail: "The client lacks sufficient authorization",
|
||||
Status: 401,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// UnsupportedContactErr returns a new acme error.
|
||||
func UnsupportedContactErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: unsupportedContactErr,
|
||||
Detail: "A contact URL for an account used an unsupported protocol scheme",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// UnsupportedIdentifierErr returns a new acme error.
|
||||
func UnsupportedIdentifierErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: unsupportedIdentifierErr,
|
||||
Detail: "An identifier is of an unsupported type",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// UserActionRequiredErr returns a new acme error.
|
||||
func UserActionRequiredErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: userActionRequiredErr,
|
||||
Detail: "Visit the “instance” URL and take actions specified there",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// ProbType is the type of the ACME problem.
|
||||
type ProbType int
|
||||
|
||||
const (
|
||||
// The request specified an account that does not exist
|
||||
accountDoesNotExistErr ProbType = iota
|
||||
// The request specified a certificate to be revoked that has already been revoked
|
||||
alreadyRevokedErr
|
||||
// The CSR is unacceptable (e.g., due to a short key)
|
||||
badCSRErr
|
||||
// The client sent an unacceptable anti-replay nonce
|
||||
badNonceErr
|
||||
// The JWS was signed by a public key the server does not support
|
||||
badPublicKeyErr
|
||||
// The revocation reason provided is not allowed by the server
|
||||
badRevocationReasonErr
|
||||
// The JWS was signed with an algorithm the server does not support
|
||||
badSignatureAlgorithmErr
|
||||
// Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate
|
||||
caaErr
|
||||
// Specific error conditions are indicated in the “subproblems” array.
|
||||
compoundErr
|
||||
// The server could not connect to validation target
|
||||
connectionErr
|
||||
// There was a problem with a DNS query during identifier validation
|
||||
dnsErr
|
||||
// The request must include a value for the “externalAccountBinding” field
|
||||
externalAccountRequiredErr
|
||||
// Response received didn’t match the challenge’s requirements
|
||||
incorrectResponseErr
|
||||
// A contact URL for an account was invalid
|
||||
invalidContactErr
|
||||
// The request message was malformed
|
||||
malformedErr
|
||||
// The request attempted to finalize an order that is not ready to be finalized
|
||||
orderNotReadyErr
|
||||
// The request exceeds a rate limit
|
||||
rateLimitedErr
|
||||
// The server will not issue certificates for the identifier
|
||||
rejectedIdentifierErr
|
||||
// The server experienced an internal error
|
||||
serverInternalErr
|
||||
// The server received a TLS error during validation
|
||||
tlsErr
|
||||
// The client lacks sufficient authorization
|
||||
unauthorizedErr
|
||||
// A contact URL for an account used an unsupported protocol scheme
|
||||
unsupportedContactErr
|
||||
// An identifier is of an unsupported type
|
||||
unsupportedIdentifierErr
|
||||
// Visit the “instance” URL and take actions specified there
|
||||
userActionRequiredErr
|
||||
)
|
||||
|
||||
// String returns the string representation of the acme problem type,
|
||||
// fulfilling the Stringer interface.
|
||||
func (ap ProbType) String() string {
|
||||
switch ap {
|
||||
case accountDoesNotExistErr:
|
||||
return "accountDoesNotExist"
|
||||
case alreadyRevokedErr:
|
||||
return "alreadyRevoked"
|
||||
case badCSRErr:
|
||||
return "badCSR"
|
||||
case badNonceErr:
|
||||
return "badNonce"
|
||||
case badPublicKeyErr:
|
||||
return "badPublicKey"
|
||||
case badRevocationReasonErr:
|
||||
return "badRevocationReason"
|
||||
case badSignatureAlgorithmErr:
|
||||
return "badSignatureAlgorithm"
|
||||
case caaErr:
|
||||
return "caa"
|
||||
case compoundErr:
|
||||
return "compound"
|
||||
case connectionErr:
|
||||
return "connection"
|
||||
case dnsErr:
|
||||
return "dns"
|
||||
case externalAccountRequiredErr:
|
||||
return "externalAccountRequired"
|
||||
case incorrectResponseErr:
|
||||
return "incorrectResponse"
|
||||
case invalidContactErr:
|
||||
return "invalidContact"
|
||||
case malformedErr:
|
||||
return "malformed"
|
||||
case orderNotReadyErr:
|
||||
return "orderNotReady"
|
||||
case rateLimitedErr:
|
||||
return "rateLimited"
|
||||
case rejectedIdentifierErr:
|
||||
return "rejectedIdentifier"
|
||||
case serverInternalErr:
|
||||
return "serverInternal"
|
||||
case tlsErr:
|
||||
return "tls"
|
||||
case unauthorizedErr:
|
||||
return "unauthorized"
|
||||
case unsupportedContactErr:
|
||||
return "unsupportedContact"
|
||||
case unsupportedIdentifierErr:
|
||||
return "unsupportedIdentifier"
|
||||
case userActionRequiredErr:
|
||||
return "userActionRequired"
|
||||
default:
|
||||
return "unsupported type"
|
||||
}
|
||||
}
|
||||
|
||||
// Error is an ACME error type complete with problem document.
|
||||
type Error struct {
|
||||
Type ProbType
|
||||
Detail string
|
||||
Err error
|
||||
Status int
|
||||
Sub []*Error
|
||||
Identifier *Identifier
|
||||
}
|
||||
|
||||
// Wrap attempts to wrap the internal error.
|
||||
func Wrap(err error, wrap string) *Error {
|
||||
switch e := err.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case *Error:
|
||||
if e.Err == nil {
|
||||
e.Err = errors.New(wrap + "; " + e.Detail)
|
||||
} else {
|
||||
e.Err = errors.Wrap(e.Err, wrap)
|
||||
}
|
||||
return e
|
||||
default:
|
||||
return ServerInternalErr(errors.Wrap(err, wrap))
|
||||
}
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
|
@ -394,17 +398,42 @@ func (e *Error) Cause() error {
|
|||
return e.Err
|
||||
}
|
||||
|
||||
// ToLog implements the EnableLogger interface.
|
||||
func (e *Error) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return nil, WrapErrorISE(err, "error marshaling acme.Error for logging")
|
||||
// ToACME returns an acme representation of the problem type.
|
||||
func (e *Error) ToACME() *AError {
|
||||
ae := &AError{
|
||||
Type: "urn:ietf:params:acme:error:" + e.Type.String(),
|
||||
Detail: e.Error(),
|
||||
Status: e.Status,
|
||||
}
|
||||
return string(b), nil
|
||||
if e.Identifier != nil {
|
||||
ae.Identifier = *e.Identifier
|
||||
}
|
||||
for _, p := range e.Sub {
|
||||
ae.Subproblems = append(ae.Subproblems, p.ToACME())
|
||||
}
|
||||
return ae
|
||||
}
|
||||
|
||||
// Render implements render.RenderableError for Error.
|
||||
func (e *Error) Render(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/problem+json")
|
||||
render.JSONStatus(w, e, e.StatusCode())
|
||||
// StatusCode returns the status code and implements the StatusCode interface.
|
||||
func (e *Error) StatusCode() int {
|
||||
return e.Status
|
||||
}
|
||||
|
||||
// AError is the error type as seen in acme request/responses.
|
||||
type AError struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Identifier interface{} `json:"identifier,omitempty"`
|
||||
Subproblems []interface{} `json:"subproblems,omitempty"`
|
||||
Status int `json:"-"`
|
||||
}
|
||||
|
||||
// Error allows AError to implement the error interface.
|
||||
func (ae *AError) Error() string {
|
||||
return ae.Detail
|
||||
}
|
||||
|
||||
// StatusCode returns the status code and implements the StatusCode interface.
|
||||
func (ae *AError) StatusCode() int {
|
||||
return ae.Status
|
||||
}
|
||||
|
|
263
acme/linker.go
263
acme/linker.go
|
@ -1,263 +0,0 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// LinkType captures the link type.
|
||||
type LinkType int
|
||||
|
||||
const (
|
||||
// NewNonceLinkType new-nonce
|
||||
NewNonceLinkType LinkType = iota
|
||||
// NewAccountLinkType new-account
|
||||
NewAccountLinkType
|
||||
// AccountLinkType account
|
||||
AccountLinkType
|
||||
// OrderLinkType order
|
||||
OrderLinkType
|
||||
// NewOrderLinkType new-order
|
||||
NewOrderLinkType
|
||||
// OrdersByAccountLinkType list of orders owned by account
|
||||
OrdersByAccountLinkType
|
||||
// FinalizeLinkType finalize order
|
||||
FinalizeLinkType
|
||||
// NewAuthzLinkType authz
|
||||
NewAuthzLinkType
|
||||
// AuthzLinkType new-authz
|
||||
AuthzLinkType
|
||||
// ChallengeLinkType challenge
|
||||
ChallengeLinkType
|
||||
// CertificateLinkType certificate
|
||||
CertificateLinkType
|
||||
// DirectoryLinkType directory
|
||||
DirectoryLinkType
|
||||
// RevokeCertLinkType revoke certificate
|
||||
RevokeCertLinkType
|
||||
// KeyChangeLinkType key rollover
|
||||
KeyChangeLinkType
|
||||
)
|
||||
|
||||
func (l LinkType) String() string {
|
||||
switch l {
|
||||
case NewNonceLinkType:
|
||||
return "new-nonce"
|
||||
case NewAccountLinkType:
|
||||
return "new-account"
|
||||
case AccountLinkType:
|
||||
return "account"
|
||||
case NewOrderLinkType:
|
||||
return "new-order"
|
||||
case OrderLinkType:
|
||||
return "order"
|
||||
case NewAuthzLinkType:
|
||||
return "new-authz"
|
||||
case AuthzLinkType:
|
||||
return "authz"
|
||||
case ChallengeLinkType:
|
||||
return "challenge"
|
||||
case CertificateLinkType:
|
||||
return "certificate"
|
||||
case DirectoryLinkType:
|
||||
return "directory"
|
||||
case RevokeCertLinkType:
|
||||
return "revoke-cert"
|
||||
case KeyChangeLinkType:
|
||||
return "key-change"
|
||||
default:
|
||||
return fmt.Sprintf("unexpected LinkType '%d'", int(l))
|
||||
}
|
||||
}
|
||||
|
||||
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||
switch typ {
|
||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||
case ChallengeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||
case OrdersByAccountLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||
case FinalizeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// NewLinker returns a new Directory type.
|
||||
func NewLinker(dns, prefix string) Linker {
|
||||
_, _, err := net.SplitHostPort(dns)
|
||||
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
||||
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
||||
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
||||
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
||||
// these cases, then the input dns is not changed.
|
||||
lastIndex := strings.LastIndex(dns, ":")
|
||||
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
||||
if ip := net.ParseIP(hostPart); ip != nil {
|
||||
dns = "[" + hostPart + "]:" + portPart
|
||||
} else if ip := net.ParseIP(dns); ip != nil {
|
||||
dns = "[" + dns + "]"
|
||||
}
|
||||
}
|
||||
return &linker{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Linker interface for generating links for ACME resources.
|
||||
type Linker interface {
|
||||
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||
Middleware(http.Handler) http.Handler
|
||||
LinkOrder(ctx context.Context, o *Order)
|
||||
LinkAccount(ctx context.Context, o *Account)
|
||||
LinkChallenge(ctx context.Context, o *Challenge, azID string)
|
||||
LinkAuthorization(ctx context.Context, o *Authorization)
|
||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||
}
|
||||
|
||||
type linkerKey struct{}
|
||||
|
||||
// NewLinkerContext adds the given linker to the context.
|
||||
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
|
||||
return context.WithValue(ctx, linkerKey{}, v)
|
||||
}
|
||||
|
||||
// LinkerFromContext returns the current linker from the given context.
|
||||
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
|
||||
v, ok = ctx.Value(linkerKey{}).(Linker)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current linker from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustLinkerFromContext(ctx context.Context) Linker {
|
||||
if v, ok := LinkerFromContext(ctx); !ok {
|
||||
panic("acme linker is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
type baseURLKey struct{}
|
||||
|
||||
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
|
||||
var u *url.URL
|
||||
if r.Host != "" {
|
||||
u = &url.URL{Scheme: "https", Host: r.Host}
|
||||
}
|
||||
return context.WithValue(ctx, baseURLKey{}, u)
|
||||
}
|
||||
|
||||
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
|
||||
return u
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// linker generates ACME links.
|
||||
type linker struct {
|
||||
prefix string
|
||||
dns string
|
||||
}
|
||||
|
||||
// Middleware gets the provisioner and current url from the request and sets
|
||||
// them in the context so we can use the linker to create ACME links.
|
||||
func (l *linker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add base url to the context.
|
||||
ctx := newBaseURLContext(r.Context(), r)
|
||||
|
||||
// Add provisioner to the context.
|
||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||
name, err := url.PathUnescape(nameEscaped)
|
||||
if err != nil {
|
||||
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
return
|
||||
}
|
||||
|
||||
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
acmeProv, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// GetLink is a helper for GetLinkExplicit.
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||
var name string
|
||||
if p, ok := ProvisionerFromContext(ctx); ok {
|
||||
name = p.GetName()
|
||||
}
|
||||
|
||||
var u url.URL
|
||||
if baseURL := baseURLFromContext(ctx); baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
if u.Host == "" {
|
||||
u.Host = l.dns
|
||||
}
|
||||
|
||||
u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// LinkOrder sets the ACME links required by an ACME order.
|
||||
func (l *linker) LinkOrder(ctx context.Context, o *Order) {
|
||||
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
|
||||
for i, azID := range o.AuthorizationIDs {
|
||||
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
|
||||
}
|
||||
o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, o.ID)
|
||||
if o.CertificateID != "" {
|
||||
o.CertificateURL = l.GetLink(ctx, CertificateLinkType, o.CertificateID)
|
||||
}
|
||||
}
|
||||
|
||||
// LinkAccount sets the ACME links required by an ACME account.
|
||||
func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
|
||||
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
|
||||
}
|
||||
|
||||
// LinkChallenge sets the ACME links required by an ACME challenge.
|
||||
func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
|
||||
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
|
||||
}
|
||||
|
||||
// LinkAuthorization sets the ACME links required by an ACME authorization.
|
||||
func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
|
||||
for _, ch := range az.Challenges {
|
||||
l.LinkChallenge(ctx, ch, az.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// LinkOrdersByAccountID converts each order ID to an ACME link.
|
||||
func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) {
|
||||
for i, id := range orders {
|
||||
orders[i] = l.GetLink(ctx, OrderLinkType, id)
|
||||
}
|
||||
}
|
|
@ -1,380 +0,0 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
func mockProvisioner(t *testing.T) Provisioner {
|
||||
t.Helper()
|
||||
var defaultDisableRenewal = false
|
||||
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-<test>provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func TestGetUnescapedPathSuffix(t *testing.T) {
|
||||
getPath := GetUnescapedPathSuffix
|
||||
|
||||
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
|
||||
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
|
||||
assert.Equals(t, getPath(NewAccountLinkType, "{provisionerID}"), "/{provisionerID}/new-account")
|
||||
assert.Equals(t, getPath(AccountLinkType, "{provisionerID}", "{accID}"), "/{provisionerID}/account/{accID}")
|
||||
assert.Equals(t, getPath(KeyChangeLinkType, "{provisionerID}"), "/{provisionerID}/key-change")
|
||||
assert.Equals(t, getPath(NewOrderLinkType, "{provisionerID}"), "/{provisionerID}/new-order")
|
||||
assert.Equals(t, getPath(OrderLinkType, "{provisionerID}", "{ordID}"), "/{provisionerID}/order/{ordID}")
|
||||
assert.Equals(t, getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), "/{provisionerID}/account/{accID}/orders")
|
||||
assert.Equals(t, getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), "/{provisionerID}/order/{ordID}/finalize")
|
||||
assert.Equals(t, getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), "/{provisionerID}/authz/{authzID}")
|
||||
assert.Equals(t, getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), "/{provisionerID}/challenge/{authzID}/{chID}")
|
||||
assert.Equals(t, getPath(CertificateLinkType, "{provisionerID}", "{certID}"), "/{provisionerID}/certificate/{certID}")
|
||||
}
|
||||
|
||||
func TestLinker_DNS(t *testing.T) {
|
||||
prov := mockProvisioner(t)
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
type test struct {
|
||||
name string
|
||||
dns string
|
||||
prefix string
|
||||
expectedDirectoryLink string
|
||||
}
|
||||
tests := []test{
|
||||
{
|
||||
name: "domain",
|
||||
dns: "ca.smallstep.com",
|
||||
prefix: "acme",
|
||||
expectedDirectoryLink: fmt.Sprintf("https://ca.smallstep.com/acme/%s/directory", escProvName),
|
||||
},
|
||||
{
|
||||
name: "domain-port",
|
||||
dns: "ca.smallstep.com:8443",
|
||||
prefix: "acme",
|
||||
expectedDirectoryLink: fmt.Sprintf("https://ca.smallstep.com:8443/acme/%s/directory", escProvName),
|
||||
},
|
||||
{
|
||||
name: "ipv4",
|
||||
dns: "127.0.0.1",
|
||||
prefix: "acme",
|
||||
expectedDirectoryLink: fmt.Sprintf("https://127.0.0.1/acme/%s/directory", escProvName),
|
||||
},
|
||||
{
|
||||
name: "ipv4-port",
|
||||
dns: "127.0.0.1:8443",
|
||||
prefix: "acme",
|
||||
expectedDirectoryLink: fmt.Sprintf("https://127.0.0.1:8443/acme/%s/directory", escProvName),
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
dns: "[::1]",
|
||||
prefix: "acme",
|
||||
expectedDirectoryLink: fmt.Sprintf("https://[::1]/acme/%s/directory", escProvName),
|
||||
},
|
||||
{
|
||||
name: "ipv6-port",
|
||||
dns: "[::1]:8443",
|
||||
prefix: "acme",
|
||||
expectedDirectoryLink: fmt.Sprintf("https://[::1]:8443/acme/%s/directory", escProvName),
|
||||
},
|
||||
{
|
||||
name: "ipv6-no-brackets",
|
||||
dns: "::1",
|
||||
prefix: "acme",
|
||||
expectedDirectoryLink: fmt.Sprintf("https://[::1]/acme/%s/directory", escProvName),
|
||||
},
|
||||
{
|
||||
name: "ipv6-port-no-brackets",
|
||||
dns: "::1:8443",
|
||||
prefix: "acme",
|
||||
expectedDirectoryLink: fmt.Sprintf("https://[::1]:8443/acme/%s/directory", escProvName),
|
||||
},
|
||||
{
|
||||
name: "ipv6-long-no-brackets",
|
||||
dns: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
|
||||
prefix: "acme",
|
||||
expectedDirectoryLink: fmt.Sprintf("https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]/acme/%s/directory", escProvName),
|
||||
},
|
||||
{
|
||||
name: "ipv6-long-port-no-brackets",
|
||||
dns: "2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443",
|
||||
prefix: "acme",
|
||||
expectedDirectoryLink: fmt.Sprintf("https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8443/acme/%s/directory", escProvName),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
linker := NewLinker(tt.dns, tt.prefix)
|
||||
assert.Equals(t, tt.expectedDirectoryLink, linker.GetLink(ctx, DirectoryLinkType))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinker_GetLink(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
prefix := "acme"
|
||||
linker := NewLinker(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := mockProvisioner(t)
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
// No provisioner and no BaseURL from request
|
||||
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
|
||||
// Provisioner: yes, BaseURL: no
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
|
||||
// Provisioner: no, BaseURL: yes
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, NewAccountLinkType), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, AccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, NewOrderLinkType), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, OrderLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, OrdersByAccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, FinalizeLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, NewAuthzLinkType), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, AuthzLinkType, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, DirectoryLinkType), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, RevokeCertLinkType, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, KeyChangeLinkType), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, ChallengeLinkType, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, CertificateLinkType, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName))
|
||||
}
|
||||
|
||||
func TestLinker_LinkOrder(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
oid := "orderID"
|
||||
certID := "certID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
o *Order
|
||||
validate func(o *Order)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"no-authz-and-no-cert": {
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
},
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{})
|
||||
assert.Equals(t, o.CertificateURL, "")
|
||||
},
|
||||
},
|
||||
"one-authz-and-cert": {
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo"},
|
||||
},
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
})
|
||||
assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
|
||||
},
|
||||
},
|
||||
"many-authz": {
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo", "bar", "zap"},
|
||||
},
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "bar"),
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "zap"),
|
||||
})
|
||||
assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l.LinkOrder(ctx, tc.o)
|
||||
tc.validate(tc.o)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinker_LinkAccount(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
accID := "accountID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
a *Account
|
||||
validate func(o *Account)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
a: &Account{
|
||||
ID: accID,
|
||||
},
|
||||
validate: func(a *Account) {
|
||||
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l.LinkAccount(ctx, tc.a)
|
||||
tc.validate(tc.a)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinker_LinkChallenge(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
chID := "chID"
|
||||
azID := "azID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
ch *Challenge
|
||||
validate func(o *Challenge)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
ch: &Challenge{
|
||||
ID: chID,
|
||||
},
|
||||
validate: func(ch *Challenge) {
|
||||
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l.LinkChallenge(ctx, tc.ch, azID)
|
||||
tc.validate(tc.ch)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinker_LinkAuthorization(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
chID0 := "chID-0"
|
||||
chID1 := "chID-1"
|
||||
chID2 := "chID-2"
|
||||
azID := "azID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
az *Authorization
|
||||
validate func(o *Authorization)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
az: &Authorization{
|
||||
ID: azID,
|
||||
Challenges: []*Challenge{
|
||||
{ID: chID0},
|
||||
{ID: chID1},
|
||||
{ID: chID2},
|
||||
},
|
||||
},
|
||||
validate: func(az *Authorization) {
|
||||
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
|
||||
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
|
||||
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l.LinkAuthorization(ctx, tc.az)
|
||||
tc.validate(tc.az)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
oids []string
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
oids: []string{"foo", "bar", "baz"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
l.LinkOrdersByAccountID(ctx, tc.oids)
|
||||
assert.Equals(t, tc.oids, []string{
|
||||
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "bar"),
|
||||
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "baz"),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,9 +1,73 @@
|
|||
package acme
|
||||
|
||||
// Nonce represents an ACME nonce type.
|
||||
type Nonce string
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
// String implements the ToString interface.
|
||||
func (n Nonce) String() string {
|
||||
return string(n)
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
// nonce contains nonce metadata used in the ACME protocol.
|
||||
type nonce struct {
|
||||
ID string
|
||||
Created time.Time
|
||||
}
|
||||
|
||||
// newNonce creates, stores, and returns an ACME replay-nonce.
|
||||
func newNonce(db nosql.DB) (*nonce, error) {
|
||||
_id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
|
||||
n := &nonce{
|
||||
ID: id,
|
||||
Created: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce"))
|
||||
case !swapped:
|
||||
return nil, ServerInternalErr(errors.New("error storing nonce; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
// useNonce verifies that the nonce is valid (by checking if it exists),
|
||||
// and if so, consumes the nonce resource by deleting it from the database.
|
||||
func useNonce(db nosql.DB, nonce string) error {
|
||||
err := db.Update(&database.Tx{
|
||||
Operations: []*database.TxEntry{
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Get,
|
||||
},
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Delete,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
return BadNonceErr(nil)
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
163
acme/nonce_test.go
Normal file
163
acme/nonce_test.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestNewNonce(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing nonce: force")),
|
||||
}
|
||||
},
|
||||
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var _id string
|
||||
id := &_id
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
*id = string(key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if n, err := newNonce(tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, n.ID, *tc.id)
|
||||
|
||||
assert.True(t, n.Created.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, n.Created.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseNonce(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/update-not-found": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return database.ErrNotFound
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
err: BadNonceErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/update-error": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := useNonce(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
630
acme/order.go
630
acme/order.go
|
@ -1,460 +1,342 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type IdentifierType string
|
||||
|
||||
const (
|
||||
// IP is the ACME ip identifier type
|
||||
IP IdentifierType = "ip"
|
||||
// DNS is the ACME dns identifier type
|
||||
DNS IdentifierType = "dns"
|
||||
// PermanentIdentifier is the ACME permanent-identifier identifier type
|
||||
// defined in https://datatracker.ietf.org/doc/html/draft-bweeks-acme-device-attest-00
|
||||
PermanentIdentifier IdentifierType = "permanent-identifier"
|
||||
)
|
||||
|
||||
// Identifier encodes the type that an order pertains to.
|
||||
type Identifier struct {
|
||||
Type IdentifierType `json:"type"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
var defaultOrderExpiry = time.Hour * 24
|
||||
|
||||
// Order contains order metadata for the ACME protocol order type.
|
||||
type Order struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"-"`
|
||||
ProvisionerID string `json:"-"`
|
||||
Status Status `json:"status"`
|
||||
ExpiresAt time.Time `json:"expires"`
|
||||
Identifiers []Identifier `json:"identifiers"`
|
||||
NotBefore time.Time `json:"notBefore"`
|
||||
NotAfter time.Time `json:"notAfter"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
AuthorizationIDs []string `json:"-"`
|
||||
AuthorizationURLs []string `json:"authorizations"`
|
||||
FinalizeURL string `json:"finalize"`
|
||||
CertificateID string `json:"-"`
|
||||
CertificateURL string `json:"certificate,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Expires string `json:"expires,omitempty"`
|
||||
Identifiers []Identifier `json:"identifiers"`
|
||||
NotBefore string `json:"notBefore,omitempty"`
|
||||
NotAfter string `json:"notAfter,omitempty"`
|
||||
Error interface{} `json:"error,omitempty"`
|
||||
Authorizations []string `json:"authorizations"`
|
||||
Finalize string `json:"finalize"`
|
||||
Certificate string `json:"certificate,omitempty"`
|
||||
ID string `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (o *Order) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return nil, WrapErrorISE(err, "error marshaling order for logging")
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the ACME Order Status if necessary.
|
||||
// Changes to the order are saved using the database interface.
|
||||
func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
|
||||
now := clock.Now()
|
||||
// GetID returns the Order ID.
|
||||
func (o *Order) GetID() string {
|
||||
return o.ID
|
||||
}
|
||||
|
||||
// OrderOptions options with which to create a new Order.
|
||||
type OrderOptions struct {
|
||||
AccountID string `json:"accID"`
|
||||
Identifiers []Identifier `json:"identifiers"`
|
||||
NotBefore time.Time `json:"notBefore"`
|
||||
NotAfter time.Time `json:"notAfter"`
|
||||
}
|
||||
|
||||
type order struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Created time.Time `json:"created"`
|
||||
Expires time.Time `json:"expires,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Identifiers []Identifier `json:"identifiers"`
|
||||
NotBefore time.Time `json:"notBefore,omitempty"`
|
||||
NotAfter time.Time `json:"notAfter,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
Authorizations []string `json:"authorizations"`
|
||||
Certificate string `json:"certificate,omitempty"`
|
||||
}
|
||||
|
||||
// newOrder returns a new Order type.
|
||||
func newOrder(db nosql.DB, ops OrderOptions) (*order, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authzs := make([]string, len(ops.Identifiers))
|
||||
for i, identifier := range ops.Identifiers {
|
||||
az, err := newAuthz(db, ops.AccountID, identifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authzs[i] = az.getID()
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
o := &order{
|
||||
ID: id,
|
||||
AccountID: ops.AccountID,
|
||||
Created: now,
|
||||
Status: StatusPending,
|
||||
Expires: now.Add(defaultOrderExpiry),
|
||||
Identifiers: ops.Identifiers,
|
||||
NotBefore: ops.NotBefore,
|
||||
NotAfter: ops.NotAfter,
|
||||
Authorizations: authzs,
|
||||
}
|
||||
if err := o.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update the "order IDs by account ID" index //
|
||||
oids, err := getOrderIDsByAccount(db, ops.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newOids := append(oids, o.ID)
|
||||
if err = orderIDs(newOids).save(db, oids, o.AccountID); err != nil {
|
||||
db.Del(orderTable, []byte(o.ID))
|
||||
return nil, err
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
type orderIDs []string
|
||||
|
||||
func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error {
|
||||
var (
|
||||
err error
|
||||
oldb []byte
|
||||
)
|
||||
if len(old) == 0 {
|
||||
oldb = nil
|
||||
} else {
|
||||
oldb, err = json.Marshal(old)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice"))
|
||||
}
|
||||
}
|
||||
newb, err := json.Marshal(oids)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("error storing order IDs "+
|
||||
"for account %s; order IDs changed since last read", accID))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (o *order) save(db nosql.DB, old *order) error {
|
||||
var (
|
||||
err error
|
||||
oldB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
||||
}
|
||||
}
|
||||
|
||||
newB, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order"))
|
||||
}
|
||||
|
||||
_, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error storing order"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.New("error storing order; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// updateStatus updates order status if necessary.
|
||||
func (o *order) updateStatus(db nosql.DB) (*order, error) {
|
||||
_newOrder := *o
|
||||
newOrder := &_newOrder
|
||||
|
||||
now := time.Now().UTC()
|
||||
switch o.Status {
|
||||
case StatusInvalid:
|
||||
return nil
|
||||
return o, nil
|
||||
case StatusValid:
|
||||
return nil
|
||||
return o, nil
|
||||
case StatusReady:
|
||||
// Check expiry
|
||||
if now.After(o.ExpiresAt) {
|
||||
o.Status = StatusInvalid
|
||||
o.Error = NewError(ErrorMalformedType, "order has expired")
|
||||
// check expiry
|
||||
if now.After(o.Expires) {
|
||||
newOrder.Status = StatusInvalid
|
||||
newOrder.Error = MalformedErr(errors.New("order has expired"))
|
||||
break
|
||||
}
|
||||
return nil
|
||||
return o, nil
|
||||
case StatusPending:
|
||||
// Check expiry
|
||||
if now.After(o.ExpiresAt) {
|
||||
o.Status = StatusInvalid
|
||||
o.Error = NewError(ErrorMalformedType, "order has expired")
|
||||
// check expiry
|
||||
if now.After(o.Expires) {
|
||||
newOrder.Status = StatusInvalid
|
||||
newOrder.Error = MalformedErr(errors.New("order has expired"))
|
||||
break
|
||||
}
|
||||
|
||||
var count = map[Status]int{
|
||||
var count = map[string]int{
|
||||
StatusValid: 0,
|
||||
StatusInvalid: 0,
|
||||
StatusPending: 0,
|
||||
}
|
||||
for _, azID := range o.AuthorizationIDs {
|
||||
az, err := db.GetAuthorization(ctx, azID)
|
||||
for _, azID := range o.Authorizations {
|
||||
az, err := getAuthz(db, azID)
|
||||
if err != nil {
|
||||
return WrapErrorISE(err, "error getting authorization ID %s", azID)
|
||||
return nil, err
|
||||
}
|
||||
if err = az.UpdateStatus(ctx, db); err != nil {
|
||||
return WrapErrorISE(err, "error updating authorization ID %s", azID)
|
||||
if az, err = az.updateStatus(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
st := az.Status
|
||||
st := az.getStatus()
|
||||
count[st]++
|
||||
}
|
||||
switch {
|
||||
case count[StatusInvalid] > 0:
|
||||
o.Status = StatusInvalid
|
||||
|
||||
// No change in the order status, so just return the order as is -
|
||||
// without writing any changes.
|
||||
newOrder.Status = StatusInvalid
|
||||
case count[StatusPending] > 0:
|
||||
return nil
|
||||
|
||||
case count[StatusValid] == len(o.AuthorizationIDs):
|
||||
o.Status = StatusReady
|
||||
|
||||
break
|
||||
case count[StatusValid] == len(o.Authorizations):
|
||||
newOrder.Status = StatusReady
|
||||
default:
|
||||
return NewErrorISE("unexpected authz status")
|
||||
return nil, ServerInternalErr(errors.New("unexpected authz status"))
|
||||
}
|
||||
default:
|
||||
return NewErrorISE("unrecognized order status: %s", o.Status)
|
||||
return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status))
|
||||
}
|
||||
if err := db.UpdateOrder(ctx, o); err != nil {
|
||||
return WrapErrorISE(err, "error updating order")
|
||||
|
||||
if err := newOrder.save(db, o); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil
|
||||
return newOrder, nil
|
||||
}
|
||||
|
||||
// getKeyFingerprint returns a fingerprint from the list of authorizations. This
|
||||
// fingerprint is used on the device-attest-01 flow to verify the attestation
|
||||
// certificate public key with the CSR public key.
|
||||
//
|
||||
// There's no point on reading all the authorizations as there will be only one
|
||||
// for a permanent identifier.
|
||||
func (o *Order) getAuthorizationFingerprint(ctx context.Context, db DB) (string, error) {
|
||||
for _, azID := range o.AuthorizationIDs {
|
||||
az, err := db.GetAuthorization(ctx, azID)
|
||||
if err != nil {
|
||||
return "", WrapErrorISE(err, "error getting authorization %q", azID)
|
||||
}
|
||||
// There's no point on reading all the authorizations as there will
|
||||
// be only one for a permanent identifier.
|
||||
if az.Fingerprint != "" {
|
||||
return az.Fingerprint, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Finalize signs a certificate if the necessary conditions for Order completion
|
||||
// finalize signs a certificate if the necessary conditions for Order completion
|
||||
// have been met.
|
||||
//
|
||||
// TODO(mariano): Here or in the challenge validation we should perform some
|
||||
// external validation using the identifier value and the attestation data. From
|
||||
// a validation service we can get the list of SANs to set in the final
|
||||
// certificate.
|
||||
func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth CertificateAuthority, p Provisioner) error {
|
||||
if err := o.UpdateStatus(ctx, db); err != nil {
|
||||
return err
|
||||
func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p provisioner.Interface) (*order, error) {
|
||||
var err error
|
||||
if o, err = o.updateStatus(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch o.Status {
|
||||
case StatusInvalid:
|
||||
return NewError(ErrorOrderNotReadyType, "order %s has been abandoned", o.ID)
|
||||
return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID))
|
||||
case StatusValid:
|
||||
return nil
|
||||
return o, nil
|
||||
case StatusPending:
|
||||
return NewError(ErrorOrderNotReadyType, "order %s is not ready", o.ID)
|
||||
return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID))
|
||||
case StatusReady:
|
||||
break
|
||||
default:
|
||||
return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID)
|
||||
return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID))
|
||||
}
|
||||
|
||||
// Get key fingerprint if any. And then compare it with the CSR fingerprint.
|
||||
//
|
||||
// In device-attest-01 challenges we should check that the keys in the CSR
|
||||
// and the attestation certificate are the same.
|
||||
fingerprint, err := o.getAuthorizationFingerprint(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
// Validate identifier names against CSR alternative names //
|
||||
csrNames := make(map[string]int)
|
||||
for _, n := range csr.DNSNames {
|
||||
csrNames[n] = 1
|
||||
}
|
||||
if fingerprint != "" {
|
||||
fp, err := keyutil.Fingerprint(csr.PublicKey)
|
||||
if err != nil {
|
||||
return WrapErrorISE(err, "error calculating key fingerprint")
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(fingerprint), []byte(fp)) == 0 {
|
||||
return NewError(ErrorUnauthorizedType, "order %s csr does not match the attested key", o.ID)
|
||||
}
|
||||
orderNames := make(map[string]int)
|
||||
for _, n := range o.Identifiers {
|
||||
orderNames[n.Value] = 1
|
||||
}
|
||||
|
||||
// canonicalize the CSR to allow for comparison
|
||||
csr = canonicalize(csr)
|
||||
|
||||
// Template data
|
||||
data := x509util.NewTemplateData()
|
||||
data.SetCommonName(csr.Subject.CommonName)
|
||||
|
||||
// Custom sign options passed to authority.Sign
|
||||
var extraOptions []provisioner.SignOption
|
||||
|
||||
// TODO: support for multiple identifiers?
|
||||
var permanentIdentifier string
|
||||
for i := range o.Identifiers {
|
||||
if o.Identifiers[i].Type == PermanentIdentifier {
|
||||
permanentIdentifier = o.Identifiers[i].Value
|
||||
// the first (and only) Permanent Identifier that gets added to the certificate
|
||||
// should be equal to the Subject Common Name if it's set. If not equal, the CSR
|
||||
// is rejected, because the Common Name hasn't been challenged in that case. This
|
||||
// could result in unauthorized access if a relying system relies on the Common
|
||||
// Name in its authorization logic.
|
||||
if csr.Subject.CommonName != "" && csr.Subject.CommonName != permanentIdentifier {
|
||||
return NewError(ErrorBadCSRType, "CSR Subject Common Name does not match identifiers exactly: "+
|
||||
"CSR Subject Common Name = %s, Order Permanent Identifier = %s", csr.Subject.CommonName, permanentIdentifier)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var defaultTemplate string
|
||||
if permanentIdentifier != "" {
|
||||
defaultTemplate = x509util.DefaultAttestedLeafTemplate
|
||||
data.SetSubjectAlternativeNames(x509util.SubjectAlternativeName{
|
||||
Type: x509util.PermanentIdentifierType,
|
||||
Value: permanentIdentifier,
|
||||
})
|
||||
extraOptions = append(extraOptions, provisioner.AttestationData{
|
||||
PermanentIdentifier: permanentIdentifier,
|
||||
})
|
||||
} else {
|
||||
defaultTemplate = x509util.DefaultLeafTemplate
|
||||
sans, err := o.sans(csr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data.SetSubjectAlternativeNames(sans...)
|
||||
if !reflect.DeepEqual(csrNames, orderNames) {
|
||||
return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly"))
|
||||
}
|
||||
|
||||
// Get authorizations from the ACME provisioner.
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
signOps, err := p.AuthorizeSign(ctx, "")
|
||||
if err != nil {
|
||||
return WrapErrorISE(err, "error retrieving authorization options from ACME provisioner")
|
||||
}
|
||||
// Unlike most of the provisioners, ACME's AuthorizeSign method doesn't
|
||||
// define the templates, and the template data used in WebHooks is not
|
||||
// available.
|
||||
for _, signOp := range signOps {
|
||||
if wc, ok := signOp.(*provisioner.WebhookController); ok {
|
||||
wc.TemplateData = data
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner"))
|
||||
}
|
||||
|
||||
templateOptions, err := provisioner.CustomTemplateOptions(p.GetOptions(), data, defaultTemplate)
|
||||
if err != nil {
|
||||
return WrapErrorISE(err, "error creating template options from ACME provisioner")
|
||||
}
|
||||
|
||||
// Build extra signing options.
|
||||
signOps = append(signOps, templateOptions)
|
||||
signOps = append(signOps, extraOptions...)
|
||||
|
||||
// Sign a new certificate.
|
||||
certChain, err := auth.Sign(csr, provisioner.SignOptions{
|
||||
// Create and store a new certificate.
|
||||
certChain, err := auth.Sign(csr, provisioner.Options{
|
||||
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
|
||||
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
|
||||
}, signOps...)
|
||||
if err != nil {
|
||||
return WrapErrorISE(err, "error signing certificate for order %s", o.ID)
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID))
|
||||
}
|
||||
|
||||
cert := &Certificate{
|
||||
cert, err := newCert(db, CertOptions{
|
||||
AccountID: o.AccountID,
|
||||
OrderID: o.ID,
|
||||
Leaf: certChain[0],
|
||||
Intermediates: certChain[1:],
|
||||
}
|
||||
if err := db.CreateCertificate(ctx, cert); err != nil {
|
||||
return WrapErrorISE(err, "error creating certificate for order %s", o.ID)
|
||||
}
|
||||
|
||||
o.CertificateID = cert.ID
|
||||
o.Status = StatusValid
|
||||
if err = db.UpdateOrder(ctx, o); err != nil {
|
||||
return WrapErrorISE(err, "error updating order %s", o.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativeName, error) {
|
||||
var sans []x509util.SubjectAlternativeName
|
||||
if len(csr.EmailAddresses) > 0 || len(csr.URIs) > 0 {
|
||||
return sans, NewError(ErrorBadCSRType, "Only DNS names and IP addresses are allowed")
|
||||
}
|
||||
|
||||
// order the DNS names and IP addresses, so that they can be compared against the canonicalized CSR
|
||||
orderNames := make([]string, numberOfIdentifierType(DNS, o.Identifiers))
|
||||
orderIPs := make([]net.IP, numberOfIdentifierType(IP, o.Identifiers))
|
||||
orderPIDs := make([]string, numberOfIdentifierType(PermanentIdentifier, o.Identifiers))
|
||||
indexDNS, indexIP, indexPID := 0, 0, 0
|
||||
for _, n := range o.Identifiers {
|
||||
switch n.Type {
|
||||
case DNS:
|
||||
orderNames[indexDNS] = n.Value
|
||||
indexDNS++
|
||||
case IP:
|
||||
orderIPs[indexIP] = net.ParseIP(n.Value) // NOTE: this assumes are all valid IPs at this time; or will result in nil entries
|
||||
indexIP++
|
||||
case PermanentIdentifier:
|
||||
orderPIDs[indexPID] = n.Value
|
||||
indexPID++
|
||||
default:
|
||||
return sans, NewErrorISE("unsupported identifier type in order: %s", n.Type)
|
||||
}
|
||||
}
|
||||
orderNames = uniqueSortedLowerNames(orderNames)
|
||||
orderIPs = uniqueSortedIPs(orderIPs)
|
||||
|
||||
totalNumberOfSANs := len(csr.DNSNames) + len(csr.IPAddresses)
|
||||
sans = make([]x509util.SubjectAlternativeName, totalNumberOfSANs)
|
||||
index := 0
|
||||
|
||||
// Validate identifier names against CSR alternative names.
|
||||
//
|
||||
// Note that with certificate templates we are not going to check for the
|
||||
// absence of other SANs as they will only be set if the template allows
|
||||
// them.
|
||||
if len(csr.DNSNames) != len(orderNames) {
|
||||
return sans, NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||
"CSR names = %v, Order names = %v", csr.DNSNames, orderNames)
|
||||
}
|
||||
|
||||
for i := range csr.DNSNames {
|
||||
if csr.DNSNames[i] != orderNames[i] {
|
||||
return sans, NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||
"CSR names = %v, Order names = %v", csr.DNSNames, orderNames)
|
||||
}
|
||||
sans[index] = x509util.SubjectAlternativeName{
|
||||
Type: x509util.DNSType,
|
||||
Value: csr.DNSNames[i],
|
||||
}
|
||||
index++
|
||||
}
|
||||
|
||||
if len(csr.IPAddresses) != len(orderIPs) {
|
||||
return sans, NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+
|
||||
"CSR IPs = %v, Order IPs = %v", csr.IPAddresses, orderIPs)
|
||||
}
|
||||
|
||||
for i := range csr.IPAddresses {
|
||||
if !ipsAreEqual(csr.IPAddresses[i], orderIPs[i]) {
|
||||
return sans, NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+
|
||||
"CSR IPs = %v, Order IPs = %v", csr.IPAddresses, orderIPs)
|
||||
}
|
||||
sans[index] = x509util.SubjectAlternativeName{
|
||||
Type: x509util.IPType,
|
||||
Value: csr.IPAddresses[i].String(),
|
||||
}
|
||||
index++
|
||||
}
|
||||
|
||||
return sans, nil
|
||||
}
|
||||
|
||||
// numberOfIdentifierType returns the number of Identifiers that
|
||||
// are of type typ.
|
||||
func numberOfIdentifierType(typ IdentifierType, ids []Identifier) int {
|
||||
c := 0
|
||||
for _, id := range ids {
|
||||
if id.Type == typ {
|
||||
c++
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// canonicalize canonicalizes a CSR so that it can be compared against an Order
|
||||
// NOTE: this effectively changes the order of SANs in the CSR, which may be OK,
|
||||
// but may not be expected. It also adds a Subject Common Name to either the IP
|
||||
// addresses or DNS names slice, depending on whether it can be parsed as an IP
|
||||
// or not. This might result in an additional SAN in the final certificate.
|
||||
func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.CertificateRequest) {
|
||||
// for clarity only; we're operating on the same object by pointer
|
||||
canonicalized = csr
|
||||
|
||||
// RFC8555: The CSR MUST indicate the exact same set of requested
|
||||
// identifiers as the initial newOrder request. Identifiers of type "dns"
|
||||
// MUST appear either in the commonName portion of the requested subject
|
||||
// name or in an extensionRequest attribute [RFC2985] requesting a
|
||||
// subjectAltName extension, or both. Subject Common Names that can be
|
||||
// parsed as an IP are included as an IP address for the equality check.
|
||||
// If these were excluded, a certificate could contain an IP as the
|
||||
// common name without having been challenged.
|
||||
if csr.Subject.CommonName != "" {
|
||||
if ip := net.ParseIP(csr.Subject.CommonName); ip != nil {
|
||||
canonicalized.IPAddresses = append(canonicalized.IPAddresses, ip)
|
||||
} else {
|
||||
canonicalized.DNSNames = append(canonicalized.DNSNames, csr.Subject.CommonName)
|
||||
}
|
||||
}
|
||||
|
||||
canonicalized.DNSNames = uniqueSortedLowerNames(canonicalized.DNSNames)
|
||||
canonicalized.IPAddresses = uniqueSortedIPs(canonicalized.IPAddresses)
|
||||
|
||||
return canonicalized
|
||||
}
|
||||
|
||||
// ipsAreEqual compares IPs to be equal. Nil values (i.e. invalid IPs) are
|
||||
// not considered equal. IPv6 representations of IPv4 addresses are
|
||||
// considered equal to the IPv4 address in this implementation, which is
|
||||
// standard Go behavior. An example is "::ffff:192.168.42.42", which
|
||||
// is equal to "192.168.42.42". This is considered a known issue within
|
||||
// step and is tracked here too: https://github.com/golang/go/issues/37921.
|
||||
func ipsAreEqual(x, y net.IP) bool {
|
||||
if x == nil || y == nil {
|
||||
return false
|
||||
}
|
||||
return x.Equal(y)
|
||||
}
|
||||
|
||||
// uniqueSortedLowerNames returns the set of all unique names in the input after all
|
||||
// of them are lowercased. The returned names will be in their lowercased form
|
||||
// and sorted alphabetically.
|
||||
func uniqueSortedLowerNames(names []string) (unique []string) {
|
||||
nameMap := make(map[string]int, len(names))
|
||||
for _, name := range names {
|
||||
nameMap[strings.ToLower(name)] = 1
|
||||
}
|
||||
unique = make([]string, 0, len(nameMap))
|
||||
for name := range nameMap {
|
||||
unique = append(unique, name)
|
||||
}
|
||||
sort.Strings(unique)
|
||||
return
|
||||
}
|
||||
|
||||
// uniqueSortedIPs returns the set of all unique net.IPs in the input. They
|
||||
// are sorted by their bytes (octet) representation.
|
||||
func uniqueSortedIPs(ips []net.IP) (unique []net.IP) {
|
||||
type entry struct {
|
||||
ip net.IP
|
||||
}
|
||||
ipEntryMap := make(map[string]entry, len(ips))
|
||||
for _, ip := range ips {
|
||||
// reparsing the IP results in the IP being represented using 16 bytes
|
||||
// for both IPv4 as well as IPv6, even when the ips slice contains IPs that
|
||||
// are represented by 4 bytes. This ensures a fair comparison and thus ordering.
|
||||
ipEntryMap[ip.String()] = entry{ip: net.ParseIP(ip.String())}
|
||||
}
|
||||
unique = make([]net.IP, 0, len(ipEntryMap))
|
||||
for _, entry := range ipEntryMap {
|
||||
unique = append(unique, entry.ip)
|
||||
}
|
||||
sort.Slice(unique, func(i, j int) bool {
|
||||
return bytes.Compare(unique[i], unique[j]) < 0
|
||||
})
|
||||
return
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_newOrder := *o
|
||||
newOrder := &_newOrder
|
||||
newOrder.Certificate = cert.ID
|
||||
newOrder.Status = StatusValid
|
||||
if err := newOrder.save(db, o); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newOrder, nil
|
||||
}
|
||||
|
||||
// getOrder retrieves and unmarshals an ACME Order type from the database.
|
||||
func getOrder(db nosql.DB, id string) (*order, error) {
|
||||
b, err := db.Get(orderTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id))
|
||||
}
|
||||
var o order
|
||||
if err := json.Unmarshal(b, &o); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order"))
|
||||
}
|
||||
return &o, nil
|
||||
}
|
||||
|
||||
// toACME converts the internal Order type into the public acmeOrder type for
|
||||
// presentation in the ACME protocol.
|
||||
func (o *order) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Order, error) {
|
||||
azs := make([]string, len(o.Authorizations))
|
||||
for i, aid := range o.Authorizations {
|
||||
azs[i] = dir.getLink(AuthzLink, URLSafeProvisionerName(p), true, aid)
|
||||
}
|
||||
ao := &Order{
|
||||
Status: o.Status,
|
||||
Expires: o.Expires.Format(time.RFC3339),
|
||||
Identifiers: o.Identifiers,
|
||||
NotBefore: o.NotBefore.Format(time.RFC3339),
|
||||
NotAfter: o.NotAfter.Format(time.RFC3339),
|
||||
Authorizations: azs,
|
||||
Finalize: dir.getLink(FinalizeLink, URLSafeProvisionerName(p), true, o.ID),
|
||||
ID: o.ID,
|
||||
}
|
||||
|
||||
if o.Certificate != "" {
|
||||
ao.Certificate = dir.getLink(CertificateLink, URLSafeProvisionerName(p), true, o.Certificate)
|
||||
}
|
||||
return ao, nil
|
||||
}
|
||||
|
|
2770
acme/order_test.go
2770
acme/order_test.go
File diff suppressed because it is too large
Load diff
|
@ -1,20 +0,0 @@
|
|||
package acme
|
||||
|
||||
// Status represents an ACME status.
|
||||
type Status string
|
||||
|
||||
var (
|
||||
// StatusValid -- valid
|
||||
StatusValid = Status("valid")
|
||||
// StatusInvalid -- invalid
|
||||
StatusInvalid = Status("invalid")
|
||||
// StatusPending -- pending; e.g. an Order that is not ready to be finalized.
|
||||
StatusPending = Status("pending")
|
||||
// StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
|
||||
StatusDeactivated = Status("deactivated")
|
||||
// StatusReady -- ready; e.g. for an Order that is ready to be finalized.
|
||||
StatusReady = Status("ready")
|
||||
//statusExpired = "expired"
|
||||
//statusActive = "active"
|
||||
//statusProcessing = "processing"
|
||||
)
|
422
api/api.go
422
api/api.go
|
@ -1,13 +1,11 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/dsa" //nolint:staticcheck // support legacy algorithms
|
||||
"crypto/dsa"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
|
@ -21,16 +19,10 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/smallstep/certificates/api/log"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/cli/crypto/tlsutil"
|
||||
)
|
||||
|
||||
// Authority is the interface implemented by a CA authority.
|
||||
|
@ -38,27 +30,18 @@ type Authority interface {
|
|||
SSHAuthority
|
||||
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||
GetTLSOptions() *config.TLSOptions
|
||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||
GetTLSOptions() *tlsutil.TLSOptions
|
||||
Root(shasum string) (*x509.Certificate, error)
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
|
||||
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
Revoke(*authority.RevokeOptions) error
|
||||
GetEncryptedKey(kid string) (string, error)
|
||||
GetRoots() ([]*x509.Certificate, error)
|
||||
GetRoots() (federation []*x509.Certificate, err error)
|
||||
GetFederation() ([]*x509.Certificate, error)
|
||||
Version() authority.Version
|
||||
GetCertificateRevocationList() ([]byte, error)
|
||||
}
|
||||
|
||||
// mustAuthority will be replaced on unit tests.
|
||||
var mustAuthority = func(ctx context.Context) Authority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// TimeDuration is an alias of provisioner.TimeDuration
|
||||
|
@ -88,13 +71,6 @@ func NewCertificate(cr *x509.Certificate) Certificate {
|
|||
}
|
||||
}
|
||||
|
||||
// reset sets the inner x509.CertificateRequest to nil
|
||||
func (c *Certificate) reset() {
|
||||
if c != nil {
|
||||
c.Certificate = nil
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface. The certificate is
|
||||
// quoted string using the PEM encoding.
|
||||
func (c Certificate) MarshalJSON() ([]byte, error) {
|
||||
|
@ -115,13 +91,6 @@ func (c *Certificate) UnmarshalJSON(data []byte) error {
|
|||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrap(err, "error decoding certificate")
|
||||
}
|
||||
|
||||
// Make sure the inner x509.Certificate is nil
|
||||
if s == "null" || s == "" {
|
||||
c.reset()
|
||||
return nil
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(s))
|
||||
if block == nil {
|
||||
return errors.New("error decoding certificate")
|
||||
|
@ -148,13 +117,6 @@ func NewCertificateRequest(cr *x509.CertificateRequest) CertificateRequest {
|
|||
}
|
||||
}
|
||||
|
||||
// reset sets the inner x509.CertificateRequest to nil
|
||||
func (c *CertificateRequest) reset() {
|
||||
if c != nil {
|
||||
c.CertificateRequest = nil
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface. The certificate request
|
||||
// is a quoted string using the PEM encoding.
|
||||
func (c CertificateRequest) MarshalJSON() ([]byte, error) {
|
||||
|
@ -175,13 +137,6 @@ func (c *CertificateRequest) UnmarshalJSON(data []byte) error {
|
|||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrap(err, "error decoding csr")
|
||||
}
|
||||
|
||||
// Make sure the inner x509.CertificateRequest is nil
|
||||
if s == "null" || s == "" {
|
||||
c.reset()
|
||||
return nil
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(s))
|
||||
if block == nil {
|
||||
return errors.New("error decoding csr")
|
||||
|
@ -207,13 +162,6 @@ type RouterHandler interface {
|
|||
Route(r Router)
|
||||
}
|
||||
|
||||
// VersionResponse is the response object that returns the version of the
|
||||
// server.
|
||||
type VersionResponse struct {
|
||||
Version string `json:"version"`
|
||||
RequireClientAuthentication bool `json:"requireClientAuthentication,omitempty"`
|
||||
}
|
||||
|
||||
// HealthResponse is the response object that returns the health of the server.
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
|
@ -224,42 +172,19 @@ type RootResponse struct {
|
|||
RootPEM Certificate `json:"ca"`
|
||||
}
|
||||
|
||||
// SignRequest is the request body for a certificate signature request.
|
||||
type SignRequest struct {
|
||||
CsrPEM CertificateRequest `json:"csr"`
|
||||
OTT string `json:"ott"`
|
||||
NotAfter TimeDuration `json:"notAfter"`
|
||||
NotBefore TimeDuration `json:"notBefore"`
|
||||
}
|
||||
|
||||
// ProvisionersResponse is the response object that returns the list of
|
||||
// provisioners.
|
||||
type ProvisionersResponse struct {
|
||||
Provisioners provisioner.List
|
||||
NextCursor string
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler. It marshals the ProvisionersResponse
|
||||
// into a byte slice.
|
||||
//
|
||||
// Special treatment is given to the SCEP provisioner, as it contains a
|
||||
// challenge secret that MUST NOT be leaked in (public) HTTP responses. The
|
||||
// challenge value is thus redacted in HTTP responses.
|
||||
func (p ProvisionersResponse) MarshalJSON() ([]byte, error) {
|
||||
for _, item := range p.Provisioners {
|
||||
scepProv, ok := item.(*provisioner.SCEP)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
old := scepProv.ChallengePassword
|
||||
scepProv.ChallengePassword = "*** REDACTED ***"
|
||||
defer func(p string) { //nolint:gocritic // defer in loop required to restore initial state of provisioners
|
||||
scepProv.ChallengePassword = p
|
||||
}(old)
|
||||
}
|
||||
|
||||
var list = struct {
|
||||
Provisioners []provisioner.Interface `json:"provisioners"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
}{
|
||||
Provisioners: []provisioner.Interface(p.Provisioners),
|
||||
NextCursor: p.NextCursor,
|
||||
}
|
||||
|
||||
return json.Marshal(list)
|
||||
Provisioners provisioner.List `json:"provisioners"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
}
|
||||
|
||||
// ProvisionerKeyResponse is the response object that returns the encrypted key
|
||||
|
@ -268,6 +193,31 @@ type ProvisionerKeyResponse struct {
|
|||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
// Validate checks the fields of the SignRequest and returns nil if they are ok
|
||||
// or an error if something is wrong.
|
||||
func (s *SignRequest) Validate() error {
|
||||
if s.CsrPEM.CertificateRequest == nil {
|
||||
return BadRequest(errors.New("missing csr"))
|
||||
}
|
||||
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
|
||||
return BadRequest(errors.Wrap(err, "invalid csr"))
|
||||
}
|
||||
if s.OTT == "" {
|
||||
return BadRequest(errors.New("missing ott"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SignResponse is the response object of the certificate signature request.
|
||||
type SignResponse struct {
|
||||
ServerPEM Certificate `json:"crt"`
|
||||
CaPEM Certificate `json:"ca"`
|
||||
CertChainPEM []Certificate `json:"certChain"`
|
||||
TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"`
|
||||
TLS *tls.ConnectionState `json:"-"`
|
||||
}
|
||||
|
||||
// RootsResponse is the response object of the roots request.
|
||||
type RootsResponse struct {
|
||||
Certificates []Certificate `json:"crts"`
|
||||
|
@ -283,78 +233,54 @@ type caHandler struct {
|
|||
Authority Authority
|
||||
}
|
||||
|
||||
// Route configures the http request router.
|
||||
func (h *caHandler) Route(r Router) {
|
||||
Route(r)
|
||||
}
|
||||
|
||||
// New creates a new RouterHandler with the CA endpoints.
|
||||
//
|
||||
// Deprecated: Use api.Route(r Router)
|
||||
func New(Authority) RouterHandler {
|
||||
return &caHandler{}
|
||||
func New(authority Authority) RouterHandler {
|
||||
return &caHandler{
|
||||
Authority: authority,
|
||||
}
|
||||
}
|
||||
|
||||
func Route(r Router) {
|
||||
r.MethodFunc("GET", "/version", Version)
|
||||
r.MethodFunc("GET", "/health", Health)
|
||||
r.MethodFunc("GET", "/root/{sha}", Root)
|
||||
r.MethodFunc("POST", "/sign", Sign)
|
||||
r.MethodFunc("POST", "/renew", Renew)
|
||||
r.MethodFunc("POST", "/rekey", Rekey)
|
||||
r.MethodFunc("POST", "/revoke", Revoke)
|
||||
r.MethodFunc("GET", "/crl", CRL)
|
||||
r.MethodFunc("GET", "/provisioners", Provisioners)
|
||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey)
|
||||
r.MethodFunc("GET", "/roots", Roots)
|
||||
r.MethodFunc("GET", "/roots.pem", RootsPEM)
|
||||
r.MethodFunc("GET", "/federation", Federation)
|
||||
func (h *caHandler) Route(r Router) {
|
||||
r.MethodFunc("GET", "/health", h.Health)
|
||||
r.MethodFunc("GET", "/root/{sha}", h.Root)
|
||||
r.MethodFunc("POST", "/sign", h.Sign)
|
||||
r.MethodFunc("POST", "/renew", h.Renew)
|
||||
r.MethodFunc("POST", "/revoke", h.Revoke)
|
||||
r.MethodFunc("GET", "/provisioners", h.Provisioners)
|
||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
|
||||
r.MethodFunc("GET", "/roots", h.Roots)
|
||||
r.MethodFunc("GET", "/federation", h.Federation)
|
||||
// SSH CA
|
||||
r.MethodFunc("POST", "/ssh/sign", SSHSign)
|
||||
r.MethodFunc("POST", "/ssh/renew", SSHRenew)
|
||||
r.MethodFunc("POST", "/ssh/revoke", SSHRevoke)
|
||||
r.MethodFunc("POST", "/ssh/rekey", SSHRekey)
|
||||
r.MethodFunc("GET", "/ssh/roots", SSHRoots)
|
||||
r.MethodFunc("GET", "/ssh/federation", SSHFederation)
|
||||
r.MethodFunc("POST", "/ssh/config", SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost)
|
||||
r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts)
|
||||
r.MethodFunc("POST", "/ssh/bastion", SSHBastion)
|
||||
r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
|
||||
r.MethodFunc("GET", "/ssh/roots", h.SSHRoots)
|
||||
r.MethodFunc("GET", "/ssh/federation", h.SSHFederation)
|
||||
r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
|
||||
|
||||
// For compatibility with old code:
|
||||
r.MethodFunc("POST", "/re-sign", Renew)
|
||||
r.MethodFunc("POST", "/sign-ssh", SSHSign)
|
||||
r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts)
|
||||
}
|
||||
|
||||
// Version is an HTTP handler that returns the version of the server.
|
||||
func Version(w http.ResponseWriter, r *http.Request) {
|
||||
v := mustAuthority(r.Context()).Version()
|
||||
render.JSON(w, VersionResponse{
|
||||
Version: v.Version,
|
||||
RequireClientAuthentication: v.RequireClientAuthentication,
|
||||
})
|
||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||
r.MethodFunc("POST", "/sign-ssh", h.SSHSign)
|
||||
}
|
||||
|
||||
// Health is an HTTP handler that returns the status of the server.
|
||||
func Health(w http.ResponseWriter, _ *http.Request) {
|
||||
render.JSON(w, HealthResponse{Status: "ok"})
|
||||
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
|
||||
JSON(w, HealthResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
|
||||
// certificate for the given SHA256.
|
||||
func Root(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
||||
sha := chi.URLParam(r, "sha")
|
||||
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
|
||||
sum := strings.ToLower(strings.Replace(sha, "-", "", -1))
|
||||
// Load root certificate with the
|
||||
cert, err := mustAuthority(r.Context()).Root(sum)
|
||||
cert, err := h.Authority.Root(sum)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||
WriteError(w, NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI)))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &RootResponse{RootPEM: Certificate{cert}})
|
||||
JSON(w, &RootResponse{RootPEM: Certificate{cert}})
|
||||
}
|
||||
|
||||
func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
||||
|
@ -365,43 +291,115 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
|||
return certChainPEM
|
||||
}
|
||||
|
||||
// Sign is an HTTP handler that reads a certificate request and an
|
||||
// one-time-token (ott) from the body and creates a new certificate with the
|
||||
// information in the certificate request.
|
||||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
opts := provisioner.Options{
|
||||
NotBefore: body.NotBefore,
|
||||
NotAfter: body.NotAfter,
|
||||
}
|
||||
|
||||
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
||||
if err != nil {
|
||||
WriteError(w, Unauthorized(err))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 0 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
logCertificate(w, certChain[0])
|
||||
JSONStatus(w, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
// new one.
|
||||
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
WriteError(w, BadRequest(errors.New("missing peer certificate")))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
|
||||
if err != nil {
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 0 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
|
||||
logCertificate(w, certChain[0])
|
||||
JSONStatus(w, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// Provisioners returns the list of provisioners configured in the authority.
|
||||
func Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := ParseCursor(r)
|
||||
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := parseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
WriteError(w, BadRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||
p, next, err := h.Authority.GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
WriteError(w, InternalServerError(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionersResponse{
|
||||
JSON(w, &ProvisionersResponse{
|
||||
Provisioners: p,
|
||||
NextCursor: next,
|
||||
})
|
||||
}
|
||||
|
||||
// ProvisionerKey returns the encrypted key of a provisioner by it's key id.
|
||||
func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||
kid := chi.URLParam(r, "kid")
|
||||
key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
|
||||
key, err := h.Authority.GetEncryptedKey(kid)
|
||||
if err != nil {
|
||||
render.Error(w, errs.NotFoundErr(err))
|
||||
WriteError(w, NotFound(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionerKeyResponse{key})
|
||||
JSON(w, &ProvisionerKeyResponse{key})
|
||||
}
|
||||
|
||||
// Roots returns all the root certificates for the CA.
|
||||
func Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := h.Authority.GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -410,39 +408,16 @@ func Roots(w http.ResponseWriter, r *http.Request) {
|
|||
certs[i] = Certificate{roots[i]}
|
||||
}
|
||||
|
||||
render.JSONStatus(w, &RootsResponse{
|
||||
JSONStatus(w, &RootsResponse{
|
||||
Certificates: certs,
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// RootsPEM returns all the root certificates for the CA in PEM format.
|
||||
func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/x-pem-file")
|
||||
|
||||
for _, root := range roots {
|
||||
block := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: root.Raw,
|
||||
})
|
||||
|
||||
if _, err := w.Write(block); err != nil {
|
||||
log.Error(w, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Federation returns all the public certificates in the federation.
|
||||
func Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := mustAuthority(r.Context()).GetFederation()
|
||||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := h.Authority.GetFederation()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -451,7 +426,7 @@ func Federation(w http.ResponseWriter, r *http.Request) {
|
|||
certs[i] = Certificate{federated[i]}
|
||||
}
|
||||
|
||||
render.JSONStatus(w, &FederationResponse{
|
||||
JSONStatus(w, &FederationResponse{
|
||||
Certificates: certs,
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
@ -472,11 +447,10 @@ func logOtt(w http.ResponseWriter, token string) {
|
|||
}
|
||||
}
|
||||
|
||||
// LogCertificate adds certificate fields to the log message.
|
||||
func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
|
||||
func logCertificate(w http.ResponseWriter, cert *x509.Certificate) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
m := map[string]interface{}{
|
||||
"serial": cert.SerialNumber.String(),
|
||||
"serial": cert.SerialNumber,
|
||||
"subject": cert.Subject.CommonName,
|
||||
"issuer": cert.Issuer.CommonName,
|
||||
"valid-from": cert.NotBefore.Format(time.RFC3339),
|
||||
|
@ -485,73 +459,33 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
|
|||
"certificate": base64.StdEncoding.EncodeToString(cert.Raw),
|
||||
}
|
||||
for _, ext := range cert.Extensions {
|
||||
if !ext.Id.Equal(oidStepProvisioner) {
|
||||
continue
|
||||
}
|
||||
val := &stepProvisioner{}
|
||||
rest, err := asn1.Unmarshal(ext.Value, val)
|
||||
if err != nil || len(rest) > 0 {
|
||||
if ext.Id.Equal(oidStepProvisioner) {
|
||||
val := &stepProvisioner{}
|
||||
rest, err := asn1.Unmarshal(ext.Value, val)
|
||||
if err != nil || len(rest) > 0 {
|
||||
break
|
||||
}
|
||||
m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID)
|
||||
break
|
||||
}
|
||||
if len(val.CredentialID) > 0 {
|
||||
m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID)
|
||||
} else {
|
||||
m["provisioner"] = string(val.Name)
|
||||
}
|
||||
break
|
||||
}
|
||||
rl.WithFields(m)
|
||||
}
|
||||
}
|
||||
|
||||
// LogSSHCertificate adds SSH certificate fields to the log message.
|
||||
func LogSSHCertificate(w http.ResponseWriter, cert *ssh.Certificate) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
mak := bytes.TrimSpace(ssh.MarshalAuthorizedKey(cert))
|
||||
var certificate string
|
||||
parts := strings.Split(string(mak), " ")
|
||||
if len(parts) > 1 {
|
||||
certificate = parts[1]
|
||||
}
|
||||
var userOrHost string
|
||||
if cert.CertType == ssh.HostCert {
|
||||
userOrHost = "host"
|
||||
} else {
|
||||
userOrHost = "user"
|
||||
}
|
||||
certificateType := fmt.Sprintf("%s %s certificate", parts[0], userOrHost) // e.g. ecdsa-sha2-nistp256-cert-v01@openssh.com user certificate
|
||||
m := map[string]interface{}{
|
||||
"serial": cert.Serial,
|
||||
"principals": cert.ValidPrincipals,
|
||||
"valid-from": time.Unix(int64(cert.ValidAfter), 0).Format(time.RFC3339),
|
||||
"valid-to": time.Unix(int64(cert.ValidBefore), 0).Format(time.RFC3339),
|
||||
"certificate": certificate,
|
||||
"certificate-type": certificateType,
|
||||
}
|
||||
fingerprint, err := sshutil.FormatFingerprint(mak, sshutil.DefaultFingerprint)
|
||||
if err == nil {
|
||||
fpParts := strings.Split(fingerprint, " ")
|
||||
if len(fpParts) > 3 {
|
||||
m["public-key"] = fmt.Sprintf("%s %s", fpParts[1], fpParts[len(fpParts)-1])
|
||||
}
|
||||
}
|
||||
rl.WithFields(m)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseCursor parses the cursor and limit from the request query params.
|
||||
func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
|
||||
func parseCursor(r *http.Request) (cursor string, limit int, err error) {
|
||||
q := r.URL.Query()
|
||||
cursor = q.Get("cursor")
|
||||
if v := q.Get("limit"); len(v) > 0 {
|
||||
limit, err = strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v)
|
||||
return "", 0, errors.Wrapf(err, "error converting %s to integer", v)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: add support for Ed25519 once it's supported
|
||||
func fmtPublicKey(cert *x509.Certificate) string {
|
||||
var params string
|
||||
switch pk := cert.PublicKey.(type) {
|
||||
|
@ -559,8 +493,6 @@ func fmtPublicKey(cert *x509.Certificate) string {
|
|||
params = pk.Curve.Params().Name
|
||||
case *rsa.PublicKey:
|
||||
params = strconv.Itoa(pk.Size() * 8)
|
||||
case ed25519.PublicKey:
|
||||
return cert.PublicKeyAlgorithm.String()
|
||||
case *dsa.PublicKey:
|
||||
params = strconv.Itoa(pk.Q.BitLen() * 8)
|
||||
default:
|
||||
|
|
1052
api/api_test.go
1052
api/api_test.go
File diff suppressed because it is too large
Load diff
32
api/crl.go
32
api/crl.go
|
@ -1,32 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
)
|
||||
|
||||
// CRL is an HTTP handler that returns the current CRL in DER or PEM format
|
||||
func CRL(w http.ResponseWriter, r *http.Request) {
|
||||
crlBytes, err := mustAuthority(r.Context()).GetCertificateRevocationList()
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, formatAsPEM := r.URL.Query()["pem"]
|
||||
if formatAsPEM {
|
||||
w.Header().Add("Content-Type", "application/x-pem-file")
|
||||
w.Header().Add("Content-Disposition", "attachment; filename=\"crl.pem\"")
|
||||
|
||||
_ = pem.Encode(w, &pem.Block{
|
||||
Type: "X509 CRL",
|
||||
Bytes: crlBytes,
|
||||
})
|
||||
} else {
|
||||
w.Header().Add("Content-Type", "application/pkix-crl")
|
||||
w.Header().Add("Content-Disposition", "attachment; filename=\"crl.der\"")
|
||||
w.Write(crlBytes)
|
||||
}
|
||||
}
|
154
api/errors.go
Normal file
154
api/errors.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// StatusCoder interface is used by errors that returns the HTTP response code.
|
||||
type StatusCoder interface {
|
||||
StatusCode() int
|
||||
}
|
||||
|
||||
// StackTracer must be by those errors that return an stack trace.
|
||||
type StackTracer interface {
|
||||
StackTrace() errors.StackTrace
|
||||
}
|
||||
|
||||
// Error represents the CA API errors.
|
||||
type Error struct {
|
||||
Status int
|
||||
Err error
|
||||
}
|
||||
|
||||
// ErrorResponse represents an error in JSON format.
|
||||
type ErrorResponse struct {
|
||||
Status int `json:"status"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Cause implements the errors.Causer interface and returns the original error.
|
||||
func (e *Error) Cause() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// Error implements the error interface and returns the error string.
|
||||
func (e *Error) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
// StatusCode implements the StatusCoder interface and returns the HTTP response
|
||||
// code.
|
||||
func (e *Error) StatusCode() int {
|
||||
return e.Status
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaller interface for the Error struct.
|
||||
func (e *Error) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(&ErrorResponse{Status: e.Status, Message: http.StatusText(e.Status)})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
|
||||
func (e *Error) UnmarshalJSON(data []byte) error {
|
||||
var er ErrorResponse
|
||||
if err := json.Unmarshal(data, &er); err != nil {
|
||||
return err
|
||||
}
|
||||
e.Status = er.Status
|
||||
e.Err = fmt.Errorf(er.Message)
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewError returns a new Error. If the given error implements the StatusCoder
|
||||
// interface we will ignore the given status.
|
||||
func NewError(status int, err error) error {
|
||||
if sc, ok := err.(StatusCoder); ok {
|
||||
return &Error{Status: sc.StatusCode(), Err: err}
|
||||
}
|
||||
cause := errors.Cause(err)
|
||||
if sc, ok := cause.(StatusCoder); ok {
|
||||
return &Error{Status: sc.StatusCode(), Err: err}
|
||||
}
|
||||
return &Error{Status: status, Err: err}
|
||||
}
|
||||
|
||||
// InternalServerError returns a 500 error with the given error.
|
||||
func InternalServerError(err error) error {
|
||||
return NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
// NotImplemented returns a 500 error with the given error.
|
||||
func NotImplemented(err error) error {
|
||||
return NewError(http.StatusNotImplemented, err)
|
||||
}
|
||||
|
||||
// BadRequest returns an 400 error with the given error.
|
||||
func BadRequest(err error) error {
|
||||
return NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
// Unauthorized returns an 401 error with the given error.
|
||||
func Unauthorized(err error) error {
|
||||
return NewError(http.StatusUnauthorized, err)
|
||||
}
|
||||
|
||||
// Forbidden returns an 403 error with the given error.
|
||||
func Forbidden(err error) error {
|
||||
return NewError(http.StatusForbidden, err)
|
||||
}
|
||||
|
||||
// NotFound returns an 404 error with the given error.
|
||||
func NotFound(err error) error {
|
||||
return NewError(http.StatusNotFound, err)
|
||||
}
|
||||
|
||||
// WriteError writes to w a JSON representation of the given error.
|
||||
func WriteError(w http.ResponseWriter, err error) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
w.Header().Set("Content-Type", "application/problem+json")
|
||||
err = k.ToACME()
|
||||
default:
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
cause := errors.Cause(err)
|
||||
if sc, ok := err.(StatusCoder); ok {
|
||||
w.WriteHeader(sc.StatusCode())
|
||||
} else {
|
||||
if sc, ok := cause.(StatusCoder); ok {
|
||||
w.WriteHeader(sc.StatusCode())
|
||||
} else {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// Write errors in the response writer
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"error": err,
|
||||
})
|
||||
if os.Getenv("STEPDEBUG") == "1" {
|
||||
if e, ok := err.(StackTracer); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"stack-trace": fmt.Sprintf("%+v", e),
|
||||
})
|
||||
} else {
|
||||
if e, ok := cause.(StackTracer); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"stack-trace": fmt.Sprintf("%+v", e),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
||||
LogError(w, err)
|
||||
}
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
// Package log implements API-related logging helpers.
|
||||
package log
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// StackTracedError is the set of errors implementing the StackTrace function.
|
||||
//
|
||||
// Errors implementing this interface have their stack traces logged when passed
|
||||
// to the Error function of this package.
|
||||
type StackTracedError interface {
|
||||
error
|
||||
|
||||
StackTrace() errors.StackTrace
|
||||
}
|
||||
|
||||
type fieldCarrier interface {
|
||||
WithFields(map[string]any)
|
||||
Fields() map[string]any
|
||||
}
|
||||
|
||||
// Error adds to the response writer the given error if it implements
|
||||
// logging.ResponseLogger. If it does not implement it, then writes the error
|
||||
// using the log package.
|
||||
func Error(rw http.ResponseWriter, err error) {
|
||||
fc, ok := rw.(fieldCarrier)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
fc.WithFields(map[string]any{
|
||||
"error": err,
|
||||
})
|
||||
|
||||
if os.Getenv("STEPDEBUG") != "1" {
|
||||
return
|
||||
}
|
||||
|
||||
var st StackTracedError
|
||||
if errors.As(err, &st) {
|
||||
fc.WithFields(map[string]any{
|
||||
"stack-trace": fmt.Sprintf("%+v", st.StackTrace()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// EnabledResponse log the response object if it implements the EnableLogger
|
||||
// interface.
|
||||
func EnabledResponse(rw http.ResponseWriter, v any) {
|
||||
type enableLogger interface {
|
||||
ToLog() (any, error)
|
||||
}
|
||||
|
||||
if el, ok := v.(enableLogger); ok {
|
||||
out, err := el.ToLog()
|
||||
if err != nil {
|
||||
Error(rw, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if rl, ok := rw.(fieldCarrier); ok {
|
||||
rl.WithFields(map[string]any{
|
||||
"response": out,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,79 +0,0 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
type stackTracedError struct{}
|
||||
|
||||
func (stackTracedError) Error() string {
|
||||
return "a stacktraced error"
|
||||
}
|
||||
|
||||
func (stackTracedError) StackTrace() pkgerrors.StackTrace {
|
||||
f := struct{}{}
|
||||
return pkgerrors.StackTrace{ // fake stacktrace
|
||||
pkgerrors.Frame(unsafe.Pointer(&f)),
|
||||
pkgerrors.Frame(unsafe.Pointer(&f)),
|
||||
}
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
error
|
||||
rw http.ResponseWriter
|
||||
isFieldCarrier bool
|
||||
stepDebug bool
|
||||
expectStackTrace bool
|
||||
}{
|
||||
{"noLogger", nil, nil, false, false, false},
|
||||
{"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false},
|
||||
{"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false},
|
||||
{"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false},
|
||||
{"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false},
|
||||
{"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true},
|
||||
{"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.stepDebug {
|
||||
t.Setenv("STEPDEBUG", "1")
|
||||
} else {
|
||||
t.Setenv("STEPDEBUG", "0")
|
||||
}
|
||||
|
||||
Error(tt.rw, tt.error)
|
||||
|
||||
// return early if test case doesn't use logger
|
||||
if !tt.isFieldCarrier {
|
||||
return
|
||||
}
|
||||
|
||||
fields := tt.rw.(logging.ResponseLogger).Fields()
|
||||
|
||||
// expect the error field to be (not) set and to be the same error that was fed to Error
|
||||
if tt.error == nil {
|
||||
assert.Nil(t, fields["error"])
|
||||
} else {
|
||||
assert.Same(t, tt.error, fields["error"])
|
||||
}
|
||||
|
||||
// check if stack-trace is set when expected
|
||||
if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] not set`)
|
||||
} else if !tt.expectStackTrace && hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] was set`)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
// Package read implements request object readers.
|
||||
package read
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// JSON reads JSON from the request body and stores it in the value
|
||||
// pointed to by v.
|
||||
func JSON(r io.Reader, v interface{}) error {
|
||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
||||
return errs.BadRequestErr(err, "error decoding json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProtoJSON reads JSON from the request body and stores it in the value
|
||||
// pointed to by m.
|
||||
func ProtoJSON(r io.Reader, m proto.Message) error {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return errs.BadRequestErr(err, "error reading request body")
|
||||
}
|
||||
|
||||
switch err := protojson.Unmarshal(data, m); {
|
||||
case errors.Is(err, proto.Error):
|
||||
return badProtoJSONError(err.Error())
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// badProtoJSONError is an error type that is returned by ProtoJSON
|
||||
// when a proto message cannot be unmarshaled. Usually this is caused
|
||||
// by an error in the request body.
|
||||
type badProtoJSONError string
|
||||
|
||||
// Error implements error for badProtoJSONError
|
||||
func (e badProtoJSONError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// Render implements render.RenderableError for badProtoJSONError
|
||||
func (e badProtoJSONError) Render(w http.ResponseWriter) {
|
||||
v := struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Type: "badRequest",
|
||||
Detail: "bad request",
|
||||
// trim the proto prefix for the message
|
||||
Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")),
|
||||
}
|
||||
render.JSONStatus(w, v, http.StatusBadRequest)
|
||||
}
|
|
@ -1,165 +0,0 @@
|
|||
package read
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
type args struct {
|
||||
r io.Reader
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
|
||||
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := JSON(tt.args.r, &tt.args.v)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
var e *errs.Error
|
||||
if errors.As(err, &e) {
|
||||
if code := e.StatusCode(); code != 400 {
|
||||
t.Errorf("error.StatusCode() = %v, wants 400", code)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("error type = %T, wants *Error", err)
|
||||
}
|
||||
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
|
||||
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtoJSON(t *testing.T) {
|
||||
|
||||
p := new(linkedca.Policy) // TODO(hs): can we use something different, so we don't need the import?
|
||||
|
||||
type args struct {
|
||||
r io.Reader
|
||||
m proto.Message
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "fail/io.ReadAll",
|
||||
args: args{
|
||||
r: iotest.ErrReader(errors.New("read error")),
|
||||
m: p,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail/proto",
|
||||
args: args{
|
||||
r: strings.NewReader(`{?}`),
|
||||
m: p,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
r: strings.NewReader(`{"x509":{}}`),
|
||||
m: p,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ProtoJSON(tt.args.r, tt.args.m)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
var (
|
||||
ee *errs.Error
|
||||
bpe badProtoJSONError
|
||||
)
|
||||
switch {
|
||||
case errors.As(err, &bpe):
|
||||
assert.Contains(t, err.Error(), "syntax error")
|
||||
case errors.As(err, &ee):
|
||||
assert.Equal(t, http.StatusBadRequest, ee.Status)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m))
|
||||
assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_badProtoJSONError_Render(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e badProtoJSONError
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "bad proto normal space",
|
||||
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
|
||||
expected: "syntax error (line 1:2): invalid value ?",
|
||||
},
|
||||
{
|
||||
name: "bad proto non breaking space",
|
||||
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
|
||||
expected: "syntax error (line 1:2): invalid value ?",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tt.e.Render(w)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
v := struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Message string `json:"message"`
|
||||
}{}
|
||||
|
||||
assert.NoError(t, json.Unmarshal(data, &v))
|
||||
assert.Equal(t, "badRequest", v.Type)
|
||||
assert.Equal(t, "bad request", v.Detail)
|
||||
assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
66
api/rekey.go
66
api/rekey.go
|
@ -1,66 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// RekeyRequest is the request body for a certificate rekey request.
|
||||
type RekeyRequest struct {
|
||||
CsrPEM CertificateRequest `json:"csr"`
|
||||
}
|
||||
|
||||
// Validate checks the fields of the RekeyRequest and returns nil if they are ok
|
||||
// or an error if something is wrong.
|
||||
func (s *RekeyRequest) Validate() error {
|
||||
if s.CsrPEM.CertificateRequest == nil {
|
||||
return errs.BadRequest("missing csr")
|
||||
}
|
||||
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
|
||||
return errs.BadRequestErr(err, "invalid csr")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
|
||||
func Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
render.Error(w, errs.BadRequest("missing client certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
var body RekeyRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
a := mustAuthority(r.Context())
|
||||
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 1 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
|
||||
LogCertificate(w, certChain[0])
|
||||
render.JSONStatus(w, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
|
@ -1,135 +0,0 @@
|
|||
// Package render implements functionality related to response rendering.
|
||||
package render
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/smallstep/certificates/api/log"
|
||||
)
|
||||
|
||||
// JSON is shorthand for JSONStatus(w, v, http.StatusOK).
|
||||
func JSON(w http.ResponseWriter, v interface{}) {
|
||||
JSONStatus(w, v, http.StatusOK)
|
||||
}
|
||||
|
||||
// JSONStatus marshals v into w. It additionally sets the status code of
|
||||
// w to the given one.
|
||||
//
|
||||
// JSONStatus sets the Content-Type of w to application/json unless one is
|
||||
// specified.
|
||||
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
||||
setContentTypeUnlessPresent(w, "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||
var errUnsupportedType *json.UnsupportedTypeError
|
||||
if errors.As(err, &errUnsupportedType) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var errUnsupportedValue *json.UnsupportedValueError
|
||||
if errors.As(err, &errUnsupportedValue) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var errMarshalError *json.MarshalerError
|
||||
if errors.As(err, &errMarshalError) {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
log.EnabledResponse(w, v)
|
||||
}
|
||||
|
||||
// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK).
|
||||
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
|
||||
ProtoJSONStatus(w, m, http.StatusOK)
|
||||
}
|
||||
|
||||
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
|
||||
// given status is written as the status code of the response.
|
||||
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
|
||||
b, err := protojson.Marshal(m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
setContentTypeUnlessPresent(w, "application/json")
|
||||
w.WriteHeader(status)
|
||||
_, _ = w.Write(b)
|
||||
}
|
||||
|
||||
func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) {
|
||||
const header = "Content-Type"
|
||||
|
||||
h := w.Header()
|
||||
if _, ok := h[header]; !ok {
|
||||
h.Set(header, contentType)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderableError is the set of errors that implement the basic Render method.
|
||||
//
|
||||
// Errors that implement this interface will use their own Render method when
|
||||
// being rendered into responses.
|
||||
type RenderableError interface {
|
||||
error
|
||||
|
||||
Render(http.ResponseWriter)
|
||||
}
|
||||
|
||||
// Error marshals the JSON representation of err to w. In case err implements
|
||||
// RenderableError its own Render method will be called instead.
|
||||
func Error(w http.ResponseWriter, err error) {
|
||||
log.Error(w, err)
|
||||
|
||||
var r RenderableError
|
||||
if errors.As(err, &r) {
|
||||
r.Render(w)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
JSONStatus(w, err, statusCodeFromError(err))
|
||||
}
|
||||
|
||||
// StatusCodedError is the set of errors that implement the basic StatusCode
|
||||
// function.
|
||||
//
|
||||
// Errors that implement this interface will use the code reported by StatusCode
|
||||
// as the HTTP response code when being rendered by this package.
|
||||
type StatusCodedError interface {
|
||||
error
|
||||
|
||||
StatusCode() int
|
||||
}
|
||||
|
||||
func statusCodeFromError(err error) (code int) {
|
||||
code = http.StatusInternalServerError
|
||||
|
||||
type causer interface {
|
||||
Cause() error
|
||||
}
|
||||
|
||||
for err != nil {
|
||||
var sc StatusCodedError
|
||||
if errors.As(err, &sc) {
|
||||
code = sc.StatusCode()
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
var c causer
|
||||
if !errors.As(err, &c) {
|
||||
break
|
||||
}
|
||||
err = c.Cause()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -1,150 +0,0 @@
|
|||
package render
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := logging.NewResponseLogger(rec)
|
||||
|
||||
JSON(rw, map[string]interface{}{"foo": "bar"})
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Result().StatusCode)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
assert.Equal(t, "{\"foo\":\"bar\"}\n", rec.Body.String())
|
||||
|
||||
assert.Empty(t, rw.Fields())
|
||||
}
|
||||
|
||||
func TestJSONPanicsOnUnsupportedType(t *testing.T) {
|
||||
jsonPanicTest[json.UnsupportedTypeError](t, make(chan struct{}))
|
||||
}
|
||||
|
||||
func TestJSONPanicsOnUnsupportedValue(t *testing.T) {
|
||||
jsonPanicTest[json.UnsupportedValueError](t, math.NaN())
|
||||
}
|
||||
|
||||
func TestJSONPanicsOnMarshalerError(t *testing.T) {
|
||||
var v erroneousJSONMarshaler
|
||||
jsonPanicTest[json.MarshalerError](t, v)
|
||||
}
|
||||
|
||||
type erroneousJSONMarshaler struct{}
|
||||
|
||||
func (erroneousJSONMarshaler) MarshalJSON() ([]byte, error) {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
|
||||
func jsonPanicTest[T json.UnsupportedTypeError | json.UnsupportedValueError | json.MarshalerError](t *testing.T, v any) {
|
||||
t.Helper()
|
||||
|
||||
defer func() {
|
||||
var err error
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic")
|
||||
} else if e, ok := r.(error); !ok {
|
||||
t.Fatalf("did not panic with an error (%T)", r)
|
||||
} else {
|
||||
err = e
|
||||
}
|
||||
|
||||
var e *T
|
||||
assert.ErrorAs(t, err, &e)
|
||||
}()
|
||||
|
||||
JSON(httptest.NewRecorder(), v)
|
||||
}
|
||||
|
||||
type renderableError struct {
|
||||
Code int `json:"-"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (err renderableError) Error() string {
|
||||
return err.Message
|
||||
}
|
||||
|
||||
func (err renderableError) Render(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "something/custom")
|
||||
|
||||
JSONStatus(w, err, err.Code)
|
||||
}
|
||||
|
||||
type statusedError struct {
|
||||
Contents string
|
||||
}
|
||||
|
||||
func (err statusedError) Error() string { return err.Contents }
|
||||
|
||||
func (statusedError) StatusCode() int { return 432 }
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
cases := []struct {
|
||||
err error
|
||||
code int
|
||||
body string
|
||||
header string
|
||||
}{
|
||||
0: {
|
||||
err: renderableError{532, "some string"},
|
||||
code: 532,
|
||||
body: "{\"message\":\"some string\"}\n",
|
||||
header: "something/custom",
|
||||
},
|
||||
1: {
|
||||
err: statusedError{"123"},
|
||||
code: 432,
|
||||
body: "{\"Contents\":\"123\"}\n",
|
||||
header: "application/json",
|
||||
},
|
||||
}
|
||||
|
||||
for caseIndex := range cases {
|
||||
kase := cases[caseIndex]
|
||||
|
||||
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
Error(rec, kase.err)
|
||||
|
||||
assert.Equal(t, kase.code, rec.Result().StatusCode)
|
||||
assert.Equal(t, kase.body, rec.Body.String())
|
||||
assert.Equal(t, kase.header, rec.Header().Get("Content-Type"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type causedError struct {
|
||||
cause error
|
||||
}
|
||||
|
||||
func (err causedError) Error() string { return fmt.Sprintf("cause: %s", err.cause) }
|
||||
func (err causedError) Cause() error { return err.cause }
|
||||
|
||||
func TestStatusCodeFromError(t *testing.T) {
|
||||
cases := []struct {
|
||||
err error
|
||||
exp int
|
||||
}{
|
||||
0: {nil, http.StatusInternalServerError},
|
||||
1: {io.EOF, http.StatusInternalServerError},
|
||||
2: {statusedError{"123"}, 432},
|
||||
3: {causedError{statusedError{"432"}}, 432},
|
||||
}
|
||||
|
||||
for caseIndex, kase := range cases {
|
||||
assert.Equal(t, kase.exp, statusCodeFromError(kase.err), "case: %d", caseIndex)
|
||||
}
|
||||
}
|
68
api/renew.go
68
api/renew.go
|
@ -1,68 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
const (
|
||||
authorizationHeader = "Authorization"
|
||||
bearerScheme = "Bearer"
|
||||
)
|
||||
|
||||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
// new one.
|
||||
func Renew(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get the leaf certificate from the peer or the token.
|
||||
cert, token, err := getPeerCertificate(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// The token can be used by RAs to renew a certificate.
|
||||
if token != "" {
|
||||
ctx = authority.NewTokenContext(ctx, token)
|
||||
}
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
certChain, err := a.RenewContext(ctx, cert, nil)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 1 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
|
||||
LogCertificate(w, certChain[0])
|
||||
render.JSONStatus(w, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
return r.TLS.PeerCertificates[0], "", nil
|
||||
}
|
||||
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||
ctx := r.Context()
|
||||
peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||
return peer, parts[1], err
|
||||
}
|
||||
}
|
||||
return nil, "", errs.BadRequest("missing client certificate")
|
||||
}
|
|
@ -1,17 +1,12 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
)
|
||||
|
||||
// RevokeResponse is the response object that returns the health of the server.
|
||||
|
@ -32,18 +27,13 @@ type RevokeRequest struct {
|
|||
// or an error if something is wrong.
|
||||
func (r *RevokeRequest) Validate() (err error) {
|
||||
if r.Serial == "" {
|
||||
return errs.BadRequest("missing serial")
|
||||
return BadRequest(errors.New("missing serial"))
|
||||
}
|
||||
sn, ok := new(big.Int).SetString(r.Serial, 0)
|
||||
if !ok {
|
||||
return errs.BadRequest("'%s' is not a valid serial number - use a base 10 representation or a base 16 representation with '0x' prefix", r.Serial)
|
||||
}
|
||||
r.Serial = sn.String()
|
||||
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
|
||||
return errs.BadRequest("reasonCode out of bounds")
|
||||
return BadRequest(errors.New("reasonCode out of bounds"))
|
||||
}
|
||||
if !r.Passive {
|
||||
return errs.NotImplemented("non-passive revocation not implemented")
|
||||
return NotImplemented(errors.New("non-passive revocation not implemented"))
|
||||
}
|
||||
|
||||
return
|
||||
|
@ -54,15 +44,15 @@ func (r *RevokeRequest) Validate() (err error) {
|
|||
// NOTE: currently only Passive revocation is supported.
|
||||
//
|
||||
// TODO: Add CRL and OCSP support.
|
||||
func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body RevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -73,45 +63,31 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
PassiveOnly: body.Passive,
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod)
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
// A token indicates that we are using the api via a provisioner token,
|
||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
if len(body.OTT) > 0 {
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
} else {
|
||||
// If no token is present, then the request must be made over mTLS and
|
||||
// the client certificate Serial Number must match the serial number
|
||||
// being revoked.
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
render.Error(w, errs.BadRequest("missing ott or client certificate"))
|
||||
WriteError(w, BadRequest(errors.New("missing ott or peer certificate")))
|
||||
return
|
||||
}
|
||||
opts.Crt = r.TLS.PeerCertificates[0]
|
||||
if opts.Crt.SerialNumber.String() != opts.Serial {
|
||||
render.Error(w, errs.BadRequest("serial number in client certificate different than body"))
|
||||
return
|
||||
}
|
||||
// TODO: should probably be checking if the certificate was revoked here.
|
||||
// Will need to thread that request down to the authority, so will need
|
||||
// to add API for that.
|
||||
LogCertificate(w, opts.Crt)
|
||||
logCertificate(w, opts.Crt)
|
||||
opts.MTLS = true
|
||||
}
|
||||
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||
if err := h.Authority.Revoke(opts); err != nil {
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
|
||||
logRevoke(w, opts)
|
||||
render.JSON(w, &RevokeResponse{Status: "ok"})
|
||||
JSON(w, &RevokeResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
||||
|
|
|
@ -2,58 +2,51 @@ package api
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
func TestRevokeRequestValidate(t *testing.T) {
|
||||
type test struct {
|
||||
rr *RevokeRequest
|
||||
err *errs.Error
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]test{
|
||||
"error/missing serial": {
|
||||
rr: &RevokeRequest{},
|
||||
err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest},
|
||||
},
|
||||
"error/bad sn": {
|
||||
rr: &RevokeRequest{Serial: "sn"},
|
||||
err: &errs.Error{Err: errors.New("'sn' is not a valid serial number - use a base 10 representation or a base 16 representation with '0x' prefix"), Status: http.StatusBadRequest},
|
||||
err: &Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest},
|
||||
},
|
||||
"error/bad reasonCode": {
|
||||
rr: &RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 15,
|
||||
Passive: true,
|
||||
},
|
||||
err: &errs.Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest},
|
||||
err: &Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest},
|
||||
},
|
||||
"error/non-passive not implemented": {
|
||||
rr: &RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 8,
|
||||
Passive: false,
|
||||
},
|
||||
err: &errs.Error{Err: errors.New("non-passive revocation not implemented"), Status: http.StatusNotImplemented},
|
||||
err: &Error{Err: errors.New("non-passive revocation not implemented"), Status: http.StatusNotImplemented},
|
||||
},
|
||||
"ok": {
|
||||
rr: &RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 9,
|
||||
Passive: true,
|
||||
},
|
||||
|
@ -62,12 +55,12 @@ func TestRevokeRequestValidate(t *testing.T) {
|
|||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.rr.Validate(); err != nil {
|
||||
var ee *errs.Error
|
||||
if errors.As(err, &ee) {
|
||||
assert.HasPrefix(t, ee.Error(), tc.err.Error())
|
||||
assert.Equals(t, ee.StatusCode(), tc.err.Status)
|
||||
} else {
|
||||
t.Errorf("unexpected error type: %T", err)
|
||||
switch v := err.(type) {
|
||||
case *Error:
|
||||
assert.HasPrefix(t, v.Error(), tc.err.Error())
|
||||
assert.Equals(t, v.StatusCode(), tc.err.Status)
|
||||
default:
|
||||
t.Errorf("unexpected error type: %T", v)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
|
@ -101,7 +94,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
},
|
||||
"200/ott": func(t *testing.T) test {
|
||||
input, err := json.Marshal(RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 4,
|
||||
Reason: "foo",
|
||||
OTT: "valid",
|
||||
|
@ -112,13 +105,10 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusOK,
|
||||
auth: &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
revoke: func(opts *authority.RevokeOptions) error {
|
||||
assert.True(t, opts.PassiveOnly)
|
||||
assert.False(t, opts.MTLS)
|
||||
assert.Equals(t, opts.Serial, "10")
|
||||
assert.Equals(t, opts.Serial, "sn")
|
||||
assert.Equals(t, opts.ReasonCode, 4)
|
||||
assert.Equals(t, opts.Reason, "foo")
|
||||
return nil
|
||||
|
@ -129,7 +119,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
},
|
||||
"400/no OTT and no peer certificate": func(t *testing.T) test {
|
||||
input, err := json.Marshal(RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 4,
|
||||
Passive: true,
|
||||
})
|
||||
|
@ -156,10 +146,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
statusCode: http.StatusOK,
|
||||
tls: cs,
|
||||
auth: &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, ri *authority.RevokeOptions) error {
|
||||
revoke: func(ri *authority.RevokeOptions) error {
|
||||
assert.True(t, ri.PassiveOnly)
|
||||
assert.True(t, ri.MTLS)
|
||||
assert.Equals(t, ri.Serial, "1404354960355712309")
|
||||
|
@ -180,7 +167,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
},
|
||||
"500/ott authority.Revoke": func(t *testing.T) test {
|
||||
input, err := json.Marshal(RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 4,
|
||||
Reason: "foo",
|
||||
OTT: "valid",
|
||||
|
@ -191,18 +178,15 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusInternalServerError,
|
||||
auth: &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
return errs.InternalServer("force")
|
||||
revoke: func(opts *authority.RevokeOptions) error {
|
||||
return InternalServerError(errors.New("force"))
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"403/ott authority.Revoke": func(t *testing.T) test {
|
||||
input, err := json.Marshal(RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 4,
|
||||
Reason: "foo",
|
||||
OTT: "valid",
|
||||
|
@ -213,10 +197,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusForbidden,
|
||||
auth: &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
revoke: func(opts *authority.RevokeOptions) error {
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -227,18 +208,18 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
for name, _tc := range tests {
|
||||
tc := _tc(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := New(tc.auth).(*caHandler)
|
||||
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
|
||||
if tc.tls != nil {
|
||||
req.TLS = tc.tls
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
Revoke(logging.NewResponseLogger(w), req)
|
||||
h.Revoke(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
|
|
99
api/sign.go
99
api/sign.go
|
@ -1,99 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// SignRequest is the request body for a certificate signature request.
|
||||
type SignRequest struct {
|
||||
CsrPEM CertificateRequest `json:"csr"`
|
||||
OTT string `json:"ott"`
|
||||
NotAfter TimeDuration `json:"notAfter,omitempty"`
|
||||
NotBefore TimeDuration `json:"notBefore,omitempty"`
|
||||
TemplateData json.RawMessage `json:"templateData,omitempty"`
|
||||
}
|
||||
|
||||
// Validate checks the fields of the SignRequest and returns nil if they are ok
|
||||
// or an error if something is wrong.
|
||||
func (s *SignRequest) Validate() error {
|
||||
if s.CsrPEM.CertificateRequest == nil {
|
||||
return errs.BadRequest("missing csr")
|
||||
}
|
||||
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
|
||||
return errs.BadRequestErr(err, "invalid csr")
|
||||
}
|
||||
if s.OTT == "" {
|
||||
return errs.BadRequest("missing ott")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SignResponse is the response object of the certificate signature request.
|
||||
type SignResponse struct {
|
||||
ServerPEM Certificate `json:"crt"`
|
||||
CaPEM Certificate `json:"ca"`
|
||||
CertChainPEM []Certificate `json:"certChain"`
|
||||
TLSOptions *config.TLSOptions `json:"tlsOptions,omitempty"`
|
||||
TLS *tls.ConnectionState `json:"-"`
|
||||
}
|
||||
|
||||
// Sign is an HTTP handler that reads a certificate request and an
|
||||
// one-time-token (ott) from the body and creates a new certificate with the
|
||||
// information in the certificate request.
|
||||
func Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
opts := provisioner.SignOptions{
|
||||
NotBefore: body.NotBefore,
|
||||
NotAfter: body.NotAfter,
|
||||
TemplateData: body.TemplateData,
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
var caPEM Certificate
|
||||
if len(certChainPEM) > 1 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
|
||||
LogCertificate(w, certChain[0])
|
||||
render.JSONStatus(w, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
290
api/ssh.go
290
api/ssh.go
|
@ -2,77 +2,56 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// SSHAuthority is the interface implemented by a SSH CA authority.
|
||||
type SSHAuthority interface {
|
||||
SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
GetSSHRoots(ctx context.Context) (*config.SSHKeys, error)
|
||||
GetSSHFederation(ctx context.Context) (*config.SSHKeys, error)
|
||||
GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
|
||||
CheckSSHHost(ctx context.Context, principal string, token string) (bool, error)
|
||||
GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
|
||||
GetSSHBastion(ctx context.Context, user string, hostname string) (*config.Bastion, error)
|
||||
SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
GetSSHRoots() (*authority.SSHKeys, error)
|
||||
GetSSHFederation() (*authority.SSHKeys, error)
|
||||
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
|
||||
CheckSSHHost(principal string) (bool, error)
|
||||
}
|
||||
|
||||
// SSHSignRequest is the request body of an SSH certificate request.
|
||||
type SSHSignRequest struct {
|
||||
PublicKey []byte `json:"publicKey"` // base64 encoded
|
||||
OTT string `json:"ott"`
|
||||
CertType string `json:"certType,omitempty"`
|
||||
KeyID string `json:"keyID,omitempty"`
|
||||
Principals []string `json:"principals,omitempty"`
|
||||
ValidAfter TimeDuration `json:"validAfter,omitempty"`
|
||||
ValidBefore TimeDuration `json:"validBefore,omitempty"`
|
||||
AddUserPublicKey []byte `json:"addUserPublicKey,omitempty"`
|
||||
IdentityCSR CertificateRequest `json:"identityCSR,omitempty"`
|
||||
TemplateData json.RawMessage `json:"templateData,omitempty"`
|
||||
PublicKey []byte `json:"publicKey"` //base64 encoded
|
||||
OTT string `json:"ott"`
|
||||
CertType string `json:"certType,omitempty"`
|
||||
Principals []string `json:"principals,omitempty"`
|
||||
ValidAfter TimeDuration `json:"validAfter,omitempty"`
|
||||
ValidBefore TimeDuration `json:"validBefore,omitempty"`
|
||||
AddUserPublicKey []byte `json:"addUserPublicKey,omitempty"`
|
||||
}
|
||||
|
||||
// Validate validates the SSHSignRequest.
|
||||
func (s *SSHSignRequest) Validate() error {
|
||||
switch {
|
||||
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
|
||||
return errs.BadRequest("invalid certType '%s'", s.CertType)
|
||||
return errors.Errorf("unknown certType %s", s.CertType)
|
||||
case len(s.PublicKey) == 0:
|
||||
return errs.BadRequest("missing or empty publicKey")
|
||||
case s.OTT == "":
|
||||
return errs.BadRequest("missing or empty ott")
|
||||
return errors.New("missing or empty publicKey")
|
||||
case len(s.OTT) == 0:
|
||||
return errors.New("missing or empty ott")
|
||||
default:
|
||||
// Validate identity signature if provided
|
||||
if s.IdentityCSR.CertificateRequest != nil {
|
||||
if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil {
|
||||
return errs.BadRequestErr(err, "invalid identityCSR")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SSHSignResponse is the response object that returns the SSH certificate.
|
||||
type SSHSignResponse struct {
|
||||
Certificate SSHCertificate `json:"crt"`
|
||||
AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"`
|
||||
IdentityCertificate []Certificate `json:"identityCrt,omitempty"`
|
||||
Certificate SSHCertificate `json:"crt"`
|
||||
AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"`
|
||||
}
|
||||
|
||||
// SSHRootsResponse represents the response object that returns the SSH user and
|
||||
|
@ -87,12 +66,6 @@ type SSHCertificate struct {
|
|||
*ssh.Certificate `json:"omitempty"`
|
||||
}
|
||||
|
||||
// SSHGetHostsResponse is the response object that returns the list of valid
|
||||
// hosts for SSH.
|
||||
type SSHGetHostsResponse struct {
|
||||
Hosts []config.Host `json:"hosts"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface. Returns a quoted,
|
||||
// base64 encoded, openssh wire format version of the certificate.
|
||||
func (c SSHCertificate) MarshalJSON() ([]byte, error) {
|
||||
|
@ -188,7 +161,7 @@ func (r *SSHConfigRequest) Validate() error {
|
|||
case provisioner.SSHUserCert, provisioner.SSHHostCert:
|
||||
return nil
|
||||
default:
|
||||
return errs.BadRequest("invalid type '%s'", r.Type)
|
||||
return errors.Errorf("unsupported type %s", r.Type)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -204,16 +177,15 @@ type SSHConfigResponse struct {
|
|||
type SSHCheckPrincipalRequest struct {
|
||||
Type string `json:"type"`
|
||||
Principal string `json:"principal"`
|
||||
Token string `json:"token,omitempty"`
|
||||
}
|
||||
|
||||
// Validate checks the check principal request.
|
||||
func (r *SSHCheckPrincipalRequest) Validate() error {
|
||||
switch {
|
||||
case r.Type != provisioner.SSHHostCert:
|
||||
return errs.BadRequest("unsupported type '%s'", r.Type)
|
||||
return errors.Errorf("unsupported type %s", r.Type)
|
||||
case r.Principal == "":
|
||||
return errs.BadRequest("missing or empty principal")
|
||||
return errors.New("missing or empty principal")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
@ -225,47 +197,25 @@ type SSHCheckPrincipalResponse struct {
|
|||
Exists bool `json:"exists"`
|
||||
}
|
||||
|
||||
// SSHBastionRequest is the request body used to get the bastion for a given
|
||||
// host.
|
||||
type SSHBastionRequest struct {
|
||||
User string `json:"user"`
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
|
||||
// Validate checks the values of the SSHBastionRequest.
|
||||
func (r *SSHBastionRequest) Validate() error {
|
||||
if r.Hostname == "" {
|
||||
return errs.BadRequest("missing or empty hostname")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SSHBastionResponse is the response body used to return the bastion for a
|
||||
// given host.
|
||||
type SSHBastionResponse struct {
|
||||
Hostname string `json:"hostname"`
|
||||
Bastion *config.Bastion `json:"bastion,omitempty"`
|
||||
}
|
||||
|
||||
// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHSignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
WriteError(w, BadRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||
WriteError(w, BadRequest(errors.Wrap(err, "error parsing publicKey")))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -273,91 +223,59 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
if body.AddUserPublicKey != nil {
|
||||
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
|
||||
WriteError(w, BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
opts := provisioner.SignSSHOptions{
|
||||
CertType: body.CertType,
|
||||
KeyID: body.KeyID,
|
||||
Principals: body.Principals,
|
||||
ValidBefore: body.ValidBefore,
|
||||
ValidAfter: body.ValidAfter,
|
||||
TemplateData: body.TemplateData,
|
||||
opts := provisioner.SSHOptions{
|
||||
CertType: body.CertType,
|
||||
Principals: body.Principals,
|
||||
ValidBefore: body.ValidBefore,
|
||||
ValidAfter: body.ValidAfter,
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
WriteError(w, Unauthorized(err))
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
|
||||
var addUserCertificate *SSHCertificate
|
||||
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
||||
addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
|
||||
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
WriteError(w, Forbidden(err))
|
||||
return
|
||||
}
|
||||
addUserCertificate = &SSHCertificate{addUserCert}
|
||||
}
|
||||
|
||||
// Sign identity certificate if available.
|
||||
var identityCertificate []Certificate
|
||||
if cr := body.IdentityCSR.CertificateRequest; cr != nil {
|
||||
ctx := authority.NewContextWithSkipTokenReuse(r.Context())
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Enforce the same duration as ssh certificate.
|
||||
signOpts = append(signOpts, &identityModifier{
|
||||
NotBefore: time.Unix(int64(cert.ValidAfter), 0),
|
||||
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
|
||||
})
|
||||
|
||||
certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||
return
|
||||
}
|
||||
identityCertificate = certChainToPEM(certChain)
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, cert)
|
||||
render.JSONStatus(w, &SSHSignResponse{
|
||||
Certificate: SSHCertificate{cert},
|
||||
AddUserCertificate: addUserCertificate,
|
||||
IdentityCertificate: identityCertificate,
|
||||
}, http.StatusCreated)
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
JSON(w, &SSHSignResponse{
|
||||
Certificate: SSHCertificate{cert},
|
||||
AddUserCertificate: addUserCertificate,
|
||||
})
|
||||
}
|
||||
|
||||
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
|
||||
// certificates.
|
||||
func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
|
||||
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
WriteError(w, InternalServerError(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||
render.Error(w, errs.NotFound("no keys found"))
|
||||
WriteError(w, NotFound(errors.New("no keys found")))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -369,21 +287,20 @@ func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
|||
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
|
||||
}
|
||||
|
||||
render.JSON(w, resp)
|
||||
JSON(w, resp)
|
||||
}
|
||||
|
||||
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
||||
// for user and host certificates.
|
||||
func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
|
||||
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHFederation()
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
WriteError(w, InternalServerError(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||
render.Error(w, errs.NotFound("no keys found"))
|
||||
WriteError(w, NotFound(errors.New("no keys found")))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -395,117 +312,60 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
|||
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
|
||||
}
|
||||
|
||||
render.JSON(w, resp)
|
||||
JSON(w, resp)
|
||||
}
|
||||
|
||||
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
|
||||
// and servers.
|
||||
func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHConfigRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
WriteError(w, BadRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
|
||||
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
WriteError(w, InternalServerError(err))
|
||||
return
|
||||
}
|
||||
|
||||
var cfg SSHConfigResponse
|
||||
var config SSHConfigResponse
|
||||
switch body.Type {
|
||||
case provisioner.SSHUserCert:
|
||||
cfg.UserTemplates = ts
|
||||
config.UserTemplates = ts
|
||||
case provisioner.SSHHostCert:
|
||||
cfg.HostTemplates = ts
|
||||
config.HostTemplates = ts
|
||||
default:
|
||||
render.Error(w, errs.InternalServer("it should hot get here"))
|
||||
WriteError(w, InternalServerError(errors.New("it should hot get here")))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, cfg)
|
||||
JSON(w, config)
|
||||
}
|
||||
|
||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||
func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHCheckPrincipalRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
WriteError(w, BadRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
|
||||
exists, err := h.Authority.CheckSSHHost(body.Principal)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
WriteError(w, InternalServerError(err))
|
||||
return
|
||||
}
|
||||
render.JSON(w, &SSHCheckPrincipalResponse{
|
||||
JSON(w, &SSHCheckPrincipalResponse{
|
||||
Exists: exists,
|
||||
})
|
||||
}
|
||||
|
||||
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
||||
func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
var cert *x509.Certificate
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
cert = r.TLS.PeerCertificates[0]
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
render.JSON(w, &SSHGetHostsResponse{
|
||||
Hosts: hosts,
|
||||
})
|
||||
}
|
||||
|
||||
// SSHBastion provides returns the bastion configured if any.
|
||||
func SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHBastionRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &SSHBastionResponse{
|
||||
Hostname: body.Hostname,
|
||||
Bastion: bastion,
|
||||
})
|
||||
}
|
||||
|
||||
// identityModifier is a custom modifier used to force a fixed duration.
|
||||
type identityModifier struct {
|
||||
NotBefore time.Time
|
||||
NotAfter time.Time
|
||||
}
|
||||
|
||||
func (m *identityModifier) Enforce(cert *x509.Certificate) error {
|
||||
cert.NotBefore = m.NotBefore
|
||||
cert.NotAfter = m.NotAfter
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,97 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// SSHRekeyRequest is the request body of an SSH certificate request.
|
||||
type SSHRekeyRequest struct {
|
||||
OTT string `json:"ott"`
|
||||
PublicKey []byte `json:"publicKey"` //base64 encoded
|
||||
}
|
||||
|
||||
// Validate validates the SSHSignRekey.
|
||||
func (s *SSHRekeyRequest) Validate() error {
|
||||
switch {
|
||||
case s.OTT == "":
|
||||
return errs.BadRequest("missing or empty ott")
|
||||
case len(s.PublicKey) == 0:
|
||||
return errs.BadRequest("missing or empty public key")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SSHRekeyResponse is the response object that returns the SSH certificate.
|
||||
type SSHRekeyResponse struct {
|
||||
Certificate SSHCertificate `json:"crt"`
|
||||
IdentityCertificate []Certificate `json:"identityCrt,omitempty"`
|
||||
}
|
||||
|
||||
// SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRekeyRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
// Match identity cert with the SSH cert
|
||||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, newCert)
|
||||
render.JSONStatus(w, &SSHRekeyResponse{
|
||||
Certificate: SSHCertificate{newCert},
|
||||
IdentityCertificate: identity,
|
||||
}, http.StatusCreated)
|
||||
}
|
118
api/sshRenew.go
118
api/sshRenew.go
|
@ -1,118 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// SSHRenewRequest is the request body of an SSH certificate request.
|
||||
type SSHRenewRequest struct {
|
||||
OTT string `json:"ott"`
|
||||
}
|
||||
|
||||
// Validate validates the SSHSignRequest.
|
||||
func (s *SSHRenewRequest) Validate() error {
|
||||
switch {
|
||||
case s.OTT == "":
|
||||
return errs.BadRequest("missing or empty ott")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SSHRenewResponse is the response object that returns the SSH certificate.
|
||||
type SSHRenewResponse struct {
|
||||
Certificate SSHCertificate `json:"crt"`
|
||||
IdentityCertificate []Certificate `json:"identityCrt,omitempty"`
|
||||
}
|
||||
|
||||
// SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRenewRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
_, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
newCert, err := a.RenewSSH(ctx, oldCert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
// Match identity cert with the SSH cert
|
||||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, newCert)
|
||||
render.JSONStatus(w, &SSHSignResponse{
|
||||
Certificate: SSHCertificate{newCert},
|
||||
IdentityCertificate: identity,
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
|
||||
func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Clone the certificate as we can modify it.
|
||||
cert, err := x509.ParseCertificate(r.TLS.PeerCertificates[0].Raw)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing client certificate")
|
||||
}
|
||||
|
||||
// Enforce the cert to match another certificate, for example an ssh
|
||||
// certificate.
|
||||
if !notBefore.IsZero() {
|
||||
cert.NotBefore = notBefore
|
||||
}
|
||||
if !notAfter.IsZero() {
|
||||
cert.NotAfter = notAfter
|
||||
}
|
||||
|
||||
certChain, err := mustAuthority(r.Context()).Renew(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return certChainToPEM(certChain), nil
|
||||
}
|
103
api/sshRevoke.go
103
api/sshRevoke.go
|
@ -1,103 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// SSHRevokeResponse is the response object that returns the health of the server.
|
||||
type SSHRevokeResponse struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// SSHRevokeRequest is the request body for a revocation request.
|
||||
type SSHRevokeRequest struct {
|
||||
Serial string `json:"serial"`
|
||||
OTT string `json:"ott"`
|
||||
ReasonCode int `json:"reasonCode"`
|
||||
Reason string `json:"reason"`
|
||||
Passive bool `json:"passive"`
|
||||
}
|
||||
|
||||
// Validate checks the fields of the RevokeRequest and returns nil if they are ok
|
||||
// or an error if something is wrong.
|
||||
func (r *SSHRevokeRequest) Validate() (err error) {
|
||||
if r.Serial == "" {
|
||||
return errs.BadRequest("missing serial")
|
||||
}
|
||||
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
|
||||
return errs.BadRequest("reasonCode out of bounds")
|
||||
}
|
||||
if !r.Passive {
|
||||
return errs.NotImplemented("non-passive revocation not implemented")
|
||||
}
|
||||
if r.OTT == "" {
|
||||
return errs.BadRequest("missing ott")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Revoke supports handful of different methods that revoke a Certificate.
|
||||
//
|
||||
// NOTE: currently only Passive revocation is supported.
|
||||
func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
opts := &authority.RevokeOptions{
|
||||
Serial: body.Serial,
|
||||
Reason: body.Reason,
|
||||
ReasonCode: body.ReasonCode,
|
||||
PassiveOnly: body.Passive,
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
// A token indicates that we are using the api via a provisioner token,
|
||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
logOtt(w, body.OTT)
|
||||
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
logSSHRevoke(w, opts)
|
||||
render.JSON(w, &SSHRevokeResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"serial": ri.Serial,
|
||||
"reasonCode": ri.ReasonCode,
|
||||
"reason": ri.Reason,
|
||||
"passiveOnly": ri.PassiveOnly,
|
||||
"mTLS": ri.MTLS,
|
||||
"ssh": true,
|
||||
})
|
||||
}
|
||||
}
|
275
api/ssh_test.go
275
api/ssh_test.go
|
@ -2,15 +2,13 @@ package api
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
|
@ -18,13 +16,12 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -200,10 +197,6 @@ func TestSSHCertificate_UnmarshalJSON(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSignSSHRequest_Validate(t *testing.T) {
|
||||
csr := parseCertificateRequest(csrPEM)
|
||||
badCSR := parseCertificateRequest(csrPEM)
|
||||
badCSR.SignatureAlgorithm = x509.SHA1WithRSA
|
||||
|
||||
type fields struct {
|
||||
PublicKey []byte
|
||||
OTT string
|
||||
|
@ -212,24 +205,19 @@ func TestSignSSHRequest_Validate(t *testing.T) {
|
|||
ValidAfter TimeDuration
|
||||
ValidBefore TimeDuration
|
||||
AddUserPublicKey []byte
|
||||
KeyID string
|
||||
IdentityCSR CertificateRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
|
||||
{"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
|
||||
{"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
|
||||
{"ok-keyID", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{}}, false},
|
||||
{"ok-identityCSR", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{CertificateRequest: csr}}, false},
|
||||
{"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
|
||||
{"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
|
||||
{"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
|
||||
{"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
|
||||
{"identityCSR", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{CertificateRequest: badCSR}}, true},
|
||||
{"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||
{"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||
{"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, false},
|
||||
{"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||
{"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||
{"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||
{"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -241,8 +229,6 @@ func TestSignSSHRequest_Validate(t *testing.T) {
|
|||
ValidAfter: tt.fields.ValidAfter,
|
||||
ValidBefore: tt.fields.ValidBefore,
|
||||
AddUserPublicKey: tt.fields.AddUserPublicKey,
|
||||
KeyID: tt.fields.KeyID,
|
||||
IdentityCSR: tt.fields.IdentityCSR,
|
||||
}
|
||||
if err := s.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
@ -251,7 +237,7 @@ func TestSignSSHRequest_Validate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHSign(t *testing.T) {
|
||||
func Test_caHandler_SSHSign(t *testing.T) {
|
||||
user, err := getSignedUserCertificate()
|
||||
assert.FatalError(t, err)
|
||||
host, err := getSignedHostCertificate()
|
||||
|
@ -276,70 +262,53 @@ func Test_SSHSign(t *testing.T) {
|
|||
AddUserPublicKey: user.Key.Marshal(),
|
||||
})
|
||||
assert.FatalError(t, err)
|
||||
userIdentityReq, err := json.Marshal(SSHSignRequest{
|
||||
PublicKey: user.Key.Marshal(),
|
||||
OTT: "ott",
|
||||
IdentityCSR: CertificateRequest{parseCertificateRequest(csrPEM)},
|
||||
})
|
||||
assert.FatalError(t, err)
|
||||
identityCerts := []*x509.Certificate{
|
||||
parseCertificate(certPEM),
|
||||
}
|
||||
identityCertsPEM := []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req []byte
|
||||
authErr error
|
||||
signCert *ssh.Certificate
|
||||
signErr error
|
||||
addUserCert *ssh.Certificate
|
||||
addUserErr error
|
||||
tlsSignCerts []*x509.Certificate
|
||||
tlsSignErr error
|
||||
body []byte
|
||||
statusCode int
|
||||
name string
|
||||
req []byte
|
||||
authErr error
|
||||
signCert *ssh.Certificate
|
||||
signErr error
|
||||
addUserCert *ssh.Certificate
|
||||
addUserErr error
|
||||
body []byte
|
||||
statusCode int
|
||||
}{
|
||||
{"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q}`, userB64)), http.StatusCreated},
|
||||
{"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q}`, hostB64)), http.StatusCreated},
|
||||
{"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q,"addUserCrt":%q}`, userB64, userB64)), http.StatusCreated},
|
||||
{"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":%q,"identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated},
|
||||
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":%q,"ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized},
|
||||
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden},
|
||||
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
|
||||
{"fail-user-identity", userIdentityReq, nil, user, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden},
|
||||
{"ok-user", userReq, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated},
|
||||
{"ok-host", hostReq, nil, host, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated},
|
||||
{"ok-user-add", userAddReq, nil, user, nil, user, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated},
|
||||
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusUnauthorized},
|
||||
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
|
||||
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
h := New(&mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return []provisioner.SignOption{}, tt.authErr
|
||||
},
|
||||
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
return tt.signCert, tt.signErr
|
||||
},
|
||||
signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
return tt.addUserCert, tt.addUserErr
|
||||
},
|
||||
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
return tt.tlsSignCerts, tt.tlsSignErr
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
SSHSign(logging.NewResponseLogger(w), req)
|
||||
h.SSHSign(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.SignSSH StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.SignSSH unexpected error = %v", err)
|
||||
|
@ -353,7 +322,7 @@ func Test_SSHSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHRoots(t *testing.T) {
|
||||
func Test_caHandler_SSHRoots(t *testing.T) {
|
||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||
assert.FatalError(t, err)
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
|
@ -369,31 +338,31 @@ func Test_SSHRoots(t *testing.T) {
|
|||
body []byte
|
||||
statusCode int
|
||||
}{
|
||||
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q],"hostKey":[%q]}`, userB64, hostB64)), http.StatusOK},
|
||||
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q,%q],"hostKey":[%q,%q]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
|
||||
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q]}`, userB64)), http.StatusOK},
|
||||
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":[%q]}`, hostB64)), http.StatusOK},
|
||||
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK},
|
||||
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
|
||||
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK},
|
||||
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK},
|
||||
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
|
||||
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
h := New(&mockAuthority{
|
||||
getSSHRoots: func() (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
SSHRoots(logging.NewResponseLogger(w), req)
|
||||
h.SSHRoots(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.SSHRoots unexpected error = %v", err)
|
||||
|
@ -407,7 +376,7 @@ func Test_SSHRoots(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHFederation(t *testing.T) {
|
||||
func Test_caHandler_SSHFederation(t *testing.T) {
|
||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||
assert.FatalError(t, err)
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
|
@ -423,31 +392,31 @@ func Test_SSHFederation(t *testing.T) {
|
|||
body []byte
|
||||
statusCode int
|
||||
}{
|
||||
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q],"hostKey":[%q]}`, userB64, hostB64)), http.StatusOK},
|
||||
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q,%q],"hostKey":[%q,%q]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
|
||||
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q]}`, userB64)), http.StatusOK},
|
||||
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":[%q]}`, hostB64)), http.StatusOK},
|
||||
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK},
|
||||
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
|
||||
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK},
|
||||
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK},
|
||||
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
|
||||
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
h := New(&mockAuthority{
|
||||
getSSHFederation: func() (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
SSHFederation(logging.NewResponseLogger(w), req)
|
||||
h.SSHFederation(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.SSHFederation unexpected error = %v", err)
|
||||
|
@ -461,9 +430,9 @@ func Test_SSHFederation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHConfig(t *testing.T) {
|
||||
func Test_caHandler_SSHConfig(t *testing.T) {
|
||||
userOutput := []templates.Output{
|
||||
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
|
||||
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/config/ssh/known_hosts")},
|
||||
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
|
||||
}
|
||||
hostOutput := []templates.Output{
|
||||
|
@ -492,22 +461,22 @@ func Test_SSHConfig(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||
h := New(&mockAuthority{
|
||||
getSSHConfig: func(typ string, data map[string]string) ([]templates.Output, error) {
|
||||
return tt.output, tt.err
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
SSHConfig(logging.NewResponseLogger(w), req)
|
||||
h.SSHConfig(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.SSHConfig unexpected error = %v", err)
|
||||
|
@ -521,7 +490,7 @@ func Test_SSHConfig(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHCheckHost(t *testing.T) {
|
||||
func Test_caHandler_SSHCheckHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req string
|
||||
|
@ -539,22 +508,22 @@ func Test_SSHCheckHost(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
|
||||
h := New(&mockAuthority{
|
||||
checkSSHHost: func(_ string) (bool, error) {
|
||||
return tt.exists, tt.err
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||
h.SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err)
|
||||
|
@ -568,112 +537,6 @@ func Test_SSHCheckHost(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHGetHosts(t *testing.T) {
|
||||
hosts := []authority.Host{
|
||||
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
|
||||
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
|
||||
}
|
||||
hostsJSON, err := json.Marshal(hosts)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hosts []authority.Host
|
||||
err error
|
||||
body []byte
|
||||
statusCode int
|
||||
}{
|
||||
{"ok", hosts, nil, []byte(fmt.Sprintf(`{"hosts":%s}`, hostsJSON)), http.StatusOK},
|
||||
{"empty (array)", []authority.Host{}, nil, []byte(`{"hosts":[]}`), http.StatusOK},
|
||||
{"empty (nil)", nil, nil, []byte(`{"hosts":null}`), http.StatusOK},
|
||||
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
|
||||
return tt.hosts, tt.err
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
SSHGetHosts(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err)
|
||||
}
|
||||
if tt.statusCode < http.StatusBadRequest {
|
||||
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
||||
t.Errorf("caHandler.SSHGetHosts Body = %s, wants %s", body, tt.body)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SSHBastion(t *testing.T) {
|
||||
bastion := &authority.Bastion{
|
||||
Hostname: "bastion.local",
|
||||
}
|
||||
bastionPort := &authority.Bastion{
|
||||
Hostname: "bastion.local",
|
||||
Port: "2222",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
bastion *authority.Bastion
|
||||
bastionErr error
|
||||
req []byte
|
||||
body []byte
|
||||
statusCode int
|
||||
}{
|
||||
{"ok", bastion, nil, []byte(`{"hostname":"host.local"}`), []byte(`{"hostname":"host.local","bastion":{"hostname":"bastion.local"}}`), http.StatusOK},
|
||||
{"ok", bastionPort, nil, []byte(`{"hostname":"host.local","user":"user"}`), []byte(`{"hostname":"host.local","bastion":{"hostname":"bastion.local","port":"2222"}}`), http.StatusOK},
|
||||
{"empty", nil, nil, []byte(`{"hostname":"host.local"}`), []byte(`{"hostname":"host.local"}`), http.StatusOK},
|
||||
{"bad json", bastion, nil, []byte(`bad json`), nil, http.StatusBadRequest},
|
||||
{"bad request", bastion, nil, []byte(`{"hostname": ""}`), nil, http.StatusBadRequest},
|
||||
{"error", nil, fmt.Errorf("an error"), []byte(`{"hostname":"host.local"}`), nil, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||
return tt.bastion, tt.bastionErr
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
SSHBastion(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.SSHBastion StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.SSHBastion unexpected error = %v", err)
|
||||
}
|
||||
if tt.statusCode < http.StatusBadRequest {
|
||||
if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
|
||||
t.Errorf("caHandler.SSHBastion Body = %s, wants %s", body, tt.body)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHPublicKey_MarshalJSON(t *testing.T) {
|
||||
key, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||
assert.FatalError(t, err)
|
||||
|
|
74
api/utils.go
Normal file
74
api/utils.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// EnableLogger is an interface that enables response logging for an object.
|
||||
type EnableLogger interface {
|
||||
ToLog() (interface{}, error)
|
||||
}
|
||||
|
||||
// LogError adds to the response writer the given error if it implements
|
||||
// logging.ResponseLogger. If it does not implement it, then writes the error
|
||||
// using the log package.
|
||||
func LogError(rw http.ResponseWriter, err error) {
|
||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"error": err,
|
||||
})
|
||||
} else {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
// LogEnabledResponse log the response object if it implements the EnableLogger
|
||||
// interface.
|
||||
func LogEnabledResponse(rw http.ResponseWriter, v interface{}) {
|
||||
if el, ok := v.(EnableLogger); ok {
|
||||
out, err := el.ToLog()
|
||||
if err != nil {
|
||||
LogError(rw, err)
|
||||
return
|
||||
}
|
||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"response": out,
|
||||
})
|
||||
} else {
|
||||
log.Println(out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JSON writes the passed value into the http.ResponseWriter.
|
||||
func JSON(w http.ResponseWriter, v interface{}) {
|
||||
JSONStatus(w, v, http.StatusOK)
|
||||
}
|
||||
|
||||
// JSONStatus writes the given value into the http.ResponseWriter and the
|
||||
// given status is written as the status code of the response.
|
||||
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||
LogError(w, err)
|
||||
return
|
||||
}
|
||||
LogEnabledResponse(w, v)
|
||||
}
|
||||
|
||||
// ReadJSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
func ReadJSON(r io.Reader, v interface{}) error {
|
||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
||||
return BadRequest(errors.Wrap(err, "error decoding json"))
|
||||
}
|
||||
return nil
|
||||
}
|
124
api/utils_test.go
Normal file
124
api/utils_test.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
func TestLogError(t *testing.T) {
|
||||
theError := errors.New("the error")
|
||||
type args struct {
|
||||
rw http.ResponseWriter
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
withFields bool
|
||||
}{
|
||||
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
|
||||
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
LogError(tt.args.rw, tt.args.err)
|
||||
if tt.withFields {
|
||||
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
|
||||
fields := rl.Fields()
|
||||
if !reflect.DeepEqual(fields["error"], theError) {
|
||||
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
|
||||
}
|
||||
} else {
|
||||
t.Error("ResponseWriter does not implement logging.ResponseLogger")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
type args struct {
|
||||
rw http.ResponseWriter
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
ok bool
|
||||
}{
|
||||
{"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true},
|
||||
{"fail", args{httptest.NewRecorder(), make(chan int)}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rw := logging.NewResponseLogger(tt.args.rw)
|
||||
JSON(rw, tt.args.v)
|
||||
|
||||
rr, ok := tt.args.rw.(*httptest.ResponseRecorder)
|
||||
if !ok {
|
||||
t.Error("ResponseWriter does not implement *httptest.ResponseRecorder")
|
||||
return
|
||||
}
|
||||
|
||||
fields := rw.Fields()
|
||||
if tt.ok {
|
||||
if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" {
|
||||
t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body)
|
||||
}
|
||||
if len(fields) != 0 {
|
||||
t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields)
|
||||
}
|
||||
} else {
|
||||
if body := rr.Body.String(); body != "" {
|
||||
t.Errorf("Unexpected body = %s, want empty string", body)
|
||||
}
|
||||
if len(fields) != 1 {
|
||||
t.Errorf("ResponseLogger fields = %v, wants 1 element", fields)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadJSON(t *testing.T) {
|
||||
type args struct {
|
||||
r io.Reader
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
|
||||
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ReadJSON(tt.args.r, &tt.args.v)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr {
|
||||
e, ok := err.(*Error)
|
||||
if ok {
|
||||
if code := e.StatusCode(); code != 400 {
|
||||
t.Errorf("error.StatusCode() = %v, wants 400", code)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("error type = %T, wants *Error", err)
|
||||
}
|
||||
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
|
||||
t.Errorf("ReadJSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue