forked from TrueCloudLab/certificates
Compare commits
19 commits
master
...
panos/api/
Author | SHA1 | Date | |
---|---|---|---|
|
845e41967d | ||
|
c3cc60e211 | ||
|
2e729ebb26 | ||
|
a5c171e750 | ||
|
a715e57d04 | ||
|
2fd84227f0 | ||
|
e82b21c1cb | ||
|
4cdb38b2e8 | ||
|
23c81db95a | ||
|
d49c00b0d7 | ||
|
098c2e1134 | ||
|
6636e87fc7 | ||
|
9b6c1f608e | ||
|
3389e57c48 | ||
|
9aa480a09a | ||
|
833ea1e695 | ||
|
b79af0456c | ||
|
eae0211a3e | ||
|
a31feae6d4 |
373 changed files with 20314 additions and 40207 deletions
56
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
56
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
|
@ -1,56 +0,0 @@
|
|||
name: Bug Report
|
||||
description: File a bug report
|
||||
title: "[Bug]: "
|
||||
labels: ["bug", "needs triage"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for taking the time to fill out this bug report!
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: Tell us how to reproduce this issue.
|
||||
placeholder: These are the steps!
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: your-env
|
||||
attributes:
|
||||
label: Your Environment
|
||||
value: |-
|
||||
* OS -
|
||||
* `step-ca` Version -
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: What did you expect to happen?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: actual-behavior
|
||||
attributes:
|
||||
label: Actual Behavior
|
||||
description: What happens instead?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: context
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Add any other context about the problem here.
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: contributing
|
||||
attributes:
|
||||
label: Contributing
|
||||
value: |
|
||||
Vote on this issue by adding a 👍 reaction.
|
||||
To contribute a fix for this issue, leave a comment (and link to your pull request, if you've opened one already).
|
||||
validations:
|
||||
required: false
|
27
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
27
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
|
@ -0,0 +1,27 @@
|
|||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: bug, needs triage
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
### Subject of the issue
|
||||
Describe your issue here.
|
||||
|
||||
### Your environment
|
||||
* OS -
|
||||
* Version -
|
||||
|
||||
### Steps to reproduce
|
||||
Tell us how to reproduce this issue. Please provide a working demo, you can use [this template](https://plnkr.co/edit/XorWgI?p=preview) as a base.
|
||||
|
||||
### Expected behaviour
|
||||
Tell us what should happen
|
||||
|
||||
### Actual behaviour
|
||||
Tell us what happens instead
|
||||
|
||||
### Additional context
|
||||
Add any other context about the problem here.
|
12
.github/ISSUE_TEMPLATE/documentation-request.md
vendored
12
.github/ISSUE_TEMPLATE/documentation-request.md
vendored
|
@ -1,20 +1,12 @@
|
|||
---
|
||||
name: Documentation Request
|
||||
about: Request documentation for a feature
|
||||
title: '[Docs]:'
|
||||
labels: docs, needs triage
|
||||
title: ''
|
||||
labels: documentation, needs triage
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Hello!
|
||||
<!-- Please leave this section as-is, it's designed to help others in the community know how to interact with our GitHub issues. -->
|
||||
|
||||
- Vote on this issue by adding a 👍 reaction
|
||||
- If you want to document this feature, comment to let us know (we'll work with you on design, scheduling, etc.)
|
||||
|
||||
## Affected area/feature
|
||||
|
||||
<!---
|
||||
Tell us which feature you'd like to see documented.
|
||||
- Where would you like that documentation to live (command line usage output, website, github markdown on the repo)?
|
||||
|
|
17
.github/ISSUE_TEMPLATE/enhancement.md
vendored
17
.github/ISSUE_TEMPLATE/enhancement.md
vendored
|
@ -1,24 +1,13 @@
|
|||
---
|
||||
name: Enhancement
|
||||
about: Suggest an enhancement to step-ca
|
||||
about: Suggest an enhancement to step certificates
|
||||
title: ''
|
||||
labels: enhancement, needs triage
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Hello!
|
||||
<!-- Please leave this section as-is,
|
||||
it's designed to help others in the community know how to interact with our GitHub issues. -->
|
||||
### What would you like to be added
|
||||
|
||||
- Vote on this issue by adding a 👍 reaction
|
||||
- If you want to implement this feature, comment to let us know (we'll work with you on design, scheduling, etc.)
|
||||
|
||||
## Issue details
|
||||
|
||||
<!-- Enhancement requests are most helpful when they describe the problem you're having
|
||||
as well as articulating the potential solution you'd like to see built. -->
|
||||
|
||||
## Why is this needed?
|
||||
|
||||
<!-- Let us know why you think this enhancement would be good for the project or community. -->
|
||||
### Why this is needed
|
||||
|
|
20
.github/PULL_REQUEST_TEMPLATE
vendored
20
.github/PULL_REQUEST_TEMPLATE
vendored
|
@ -1,20 +1,4 @@
|
|||
<!---
|
||||
Please provide answers in the spaces below each prompt, where applicable.
|
||||
Not every PR requires responses for each prompt.
|
||||
Use your discretion.
|
||||
-->
|
||||
#### Name of feature:
|
||||
|
||||
#### Pain or issue this feature alleviates:
|
||||
|
||||
#### Why is this important to the project (if not answered above):
|
||||
|
||||
#### Is there documentation on how to use this feature? If so, where?
|
||||
|
||||
#### In what environments or workflows is this feature supported?
|
||||
|
||||
#### In what environments or workflows is this feature explicitly NOT supported (if any)?
|
||||
|
||||
#### Supporting links/other PRs/issues:
|
||||
### Description
|
||||
Please describe your pull request.
|
||||
|
||||
💔Thank you!
|
||||
|
|
11
.github/dependabot.yml
vendored
11
.github/dependabot.yml
vendored
|
@ -1,11 +0,0 @@
|
|||
# To get started with Dependabot version updates, you'll need to specify which
|
||||
# package ecosystems to update and where the package manifests are located.
|
||||
# Please see the documentation for all configuration options:
|
||||
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
|
||||
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "gomod" # See documentation for possible values
|
||||
directory: "/" # Location of package manifests
|
||||
schedule:
|
||||
interval: "weekly"
|
4
.github/labeler.yml
vendored
Normal file
4
.github/labeler.yml
vendored
Normal file
|
@ -0,0 +1,4 @@
|
|||
needs triage:
|
||||
- '**' # index.php | src/main.php
|
||||
- '.*' # .gitignore
|
||||
- '.*/**' # .github/workflows/label.yml
|
27
.github/workflows/ci.yml
vendored
27
.github/workflows/ci.yml
vendored
|
@ -1,27 +0,0 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
tags-ignore:
|
||||
- 'v*'
|
||||
branches:
|
||||
- "master"
|
||||
pull_request:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: true
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ci:
|
||||
uses: smallstep/workflows/.github/workflows/goCI.yml@main
|
||||
with:
|
||||
only-latest-golang: false
|
||||
os-dependencies: 'libpcsclite-dev'
|
||||
run-codeql: true
|
||||
test-command: 'V=1 make test'
|
||||
secrets: inherit
|
9
.github/workflows/code-scan-cron.yml
vendored
9
.github/workflows/code-scan-cron.yml
vendored
|
@ -1,9 +0,0 @@
|
|||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
|
||||
jobs:
|
||||
code-scan:
|
||||
uses: smallstep/workflows/.github/workflows/code-scan.yml@main
|
||||
secrets:
|
||||
GITLEAKS_LICENSE_KEY: ${{ secrets.GITLEAKS_LICENSE_KEY }}
|
22
.github/workflows/dependabot-auto-merge.yml
vendored
22
.github/workflows/dependabot-auto-merge.yml
vendored
|
@ -1,22 +0,0 @@
|
|||
name: Dependabot auto-merge
|
||||
on: pull_request
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
dependabot:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.actor == 'dependabot[bot]' }}
|
||||
steps:
|
||||
- name: Dependabot metadata
|
||||
id: metadata
|
||||
uses: dependabot/fetch-metadata@v1.1.1
|
||||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
- name: Enable auto-merge for Dependabot PRs
|
||||
run: gh pr merge --auto --merge "$PR_URL"
|
||||
env:
|
||||
PR_URL: ${{github.event.pull_request.html_url}}
|
||||
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
|
12
.github/workflows/labeler.yml
vendored
Normal file
12
.github/workflows/labeler.yml
vendored
Normal file
|
@ -0,0 +1,12 @@
|
|||
name: Pull Request Labeler
|
||||
on:
|
||||
pull_request_target
|
||||
|
||||
jobs:
|
||||
label:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v3.0.2
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
200
.github/workflows/release.yml
vendored
200
.github/workflows/release.yml
vendored
|
@ -7,43 +7,81 @@ on:
|
|||
- 'v*' # Push events to matching v*, i.e. v1.0, v20.15.10
|
||||
|
||||
jobs:
|
||||
ci:
|
||||
uses: smallstep/certificates/.github/workflows/ci.yml@master
|
||||
secrets: inherit
|
||||
test:
|
||||
name: Lint, Test, Build
|
||||
runs-on: ubuntu-20.04
|
||||
strategy:
|
||||
matrix:
|
||||
go: [ '1.15', '1.16', '1.17' ]
|
||||
outputs:
|
||||
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
-
|
||||
name: Setup Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
-
|
||||
name: Install Deps
|
||||
id: install-deps
|
||||
run: sudo apt-get -y install libpcsclite-dev
|
||||
-
|
||||
name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
with:
|
||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||
version: 'v1.44.0'
|
||||
|
||||
# Optional: working directory, useful for monorepos
|
||||
# working-directory: somedir
|
||||
|
||||
# Optional: golangci-lint command line arguments.
|
||||
args: --timeout=30m
|
||||
|
||||
# Optional: show only new issues if it's a pull request. The default value is `false`.
|
||||
# only-new-issues: true
|
||||
|
||||
# Optional: if set to true then the action will use pre-installed Go.
|
||||
# skip-go-installation: true
|
||||
|
||||
# Optional: if set to true then the action don't cache or restore ~/go/pkg.
|
||||
# skip-pkg-cache: true
|
||||
|
||||
# Optional: if set to true then the action don't cache or restore ~/.cache/go-build.
|
||||
# skip-build-cache: true
|
||||
-
|
||||
name: Test, Build
|
||||
id: lint_test_build
|
||||
run: V=1 make ci
|
||||
|
||||
create_release:
|
||||
name: Create Release
|
||||
needs: ci
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DOCKER_IMAGE: smallstep/step-ca
|
||||
needs: test
|
||||
runs-on: ubuntu-20.04
|
||||
outputs:
|
||||
version: ${{ steps.extract-tag.outputs.VERSION }}
|
||||
debversion: ${{ steps.extract-tag.outputs.DEB_VERSION }}
|
||||
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||
docker_tags: ${{ env.DOCKER_TAGS }}
|
||||
docker_tags_hsm: ${{ env.DOCKER_TAGS_HSM }}
|
||||
steps:
|
||||
- name: Is Pre-release
|
||||
-
|
||||
name: Extract Tag Names
|
||||
id: extract-tag
|
||||
run: |
|
||||
DEB_VERSION=$(echo ${GITHUB_REF#refs/tags/v} | sed 's/-/./')
|
||||
echo "::set-output name=DEB_VERSION::${DEB_VERSION}"
|
||||
-
|
||||
name: Is Pre-release
|
||||
id: is_prerelease
|
||||
run: |
|
||||
set +e
|
||||
echo ${{ github.ref }} | grep "\-rc.*"
|
||||
OUT=$?
|
||||
if [ $OUT -eq 0 ]; then IS_PRERELEASE=true; else IS_PRERELEASE=false; fi
|
||||
echo "IS_PRERELEASE=${IS_PRERELEASE}" >> ${GITHUB_OUTPUT}
|
||||
- name: Extract Tag Names
|
||||
id: extract-tag
|
||||
run: |
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
echo "VERSION=${VERSION}" >> ${GITHUB_OUTPUT}
|
||||
echo "DOCKER_TAGS=${{ env.DOCKER_IMAGE }}:${VERSION}" >> ${GITHUB_ENV}
|
||||
echo "DOCKER_TAGS_HSM=${{ env.DOCKER_IMAGE }}:${VERSION}-hsm" >> ${GITHUB_ENV}
|
||||
- name: Add Latest Tag
|
||||
if: steps.is_prerelease.outputs.IS_PRERELEASE == 'false'
|
||||
run: |
|
||||
echo "DOCKER_TAGS=${{ env.DOCKER_TAGS }},${{ env.DOCKER_IMAGE }}:latest" >> ${GITHUB_ENV}
|
||||
echo "DOCKER_TAGS_HSM=${{ env.DOCKER_TAGS_HSM }},${{ env.DOCKER_IMAGE }}:hsm" >> ${GITHUB_ENV}
|
||||
- name: Create Release
|
||||
echo "::set-output name=IS_PRERELEASE::${IS_PRERELEASE}"
|
||||
-
|
||||
name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
|
@ -55,37 +93,89 @@ jobs:
|
|||
prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||
|
||||
goreleaser:
|
||||
name: Upload Assets To Github w/ goreleaser
|
||||
runs-on: ubuntu-20.04
|
||||
needs: create_release
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
uses: smallstep/workflows/.github/workflows/goreleaser.yml@main
|
||||
secrets: inherit
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.17
|
||||
-
|
||||
name: APT Install
|
||||
id: aptInstall
|
||||
run: sudo apt-get -y install build-essential debhelper fakeroot
|
||||
-
|
||||
name: Build Debian package
|
||||
id: make_debian
|
||||
run: |
|
||||
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
|
||||
make debian
|
||||
# need to restore the git state otherwise goreleaser fails due to dirty state
|
||||
git restore debian/changelog
|
||||
git clean -fd
|
||||
-
|
||||
name: Install cosign
|
||||
uses: sigstore/cosign-installer@v1.1.0
|
||||
with:
|
||||
cosign-release: 'v1.1.0'
|
||||
-
|
||||
name: Write cosign key to disk
|
||||
id: write_key
|
||||
run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key"
|
||||
-
|
||||
name: Get Release Date
|
||||
id: release_date
|
||||
run: |
|
||||
RELEASE_DATE=$(date +"%y-%m-%d")
|
||||
echo "::set-output name=RELEASE_DATE::${RELEASE_DATE}"
|
||||
-
|
||||
name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@5a54d7e660bda43b405e8463261b3d25631ffe86 # v2.7.0
|
||||
with:
|
||||
version: latest
|
||||
args: release --rm-dist
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.PAT }}
|
||||
COSIGN_PWD: ${{ secrets.COSIGN_PWD }}
|
||||
DEB_VERSION: ${{ needs.create_release.outputs.debversion }}
|
||||
RELEASE_DATE: ${{ steps.release_date.outputs.RELEASE_DATE }}
|
||||
|
||||
build_upload_docker:
|
||||
name: Build & Upload Docker Images
|
||||
needs: create_release
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
uses: smallstep/workflows/.github/workflows/docker-buildx-push.yml@main
|
||||
with:
|
||||
platforms: linux/amd64,linux/386,linux/arm,linux/arm64
|
||||
tags: ${{ needs.create_release.outputs.docker_tags }}
|
||||
docker_image: smallstep/step-ca
|
||||
docker_file: docker/Dockerfile
|
||||
secrets: inherit
|
||||
|
||||
build_upload_docker_hsm:
|
||||
name: Build & Upload HSM Enabled Docker Images
|
||||
needs: create_release
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
uses: smallstep/workflows/.github/workflows/docker-buildx-push.yml@main
|
||||
with:
|
||||
platforms: linux/amd64,linux/386,linux/arm,linux/arm64
|
||||
tags: ${{ needs.create_release.outputs.docker_tags_hsm }}
|
||||
docker_image: smallstep/step-ca
|
||||
docker_file: docker/Dockerfile.hsm
|
||||
secrets: inherit
|
||||
runs-on: ubuntu-20.04
|
||||
needs: test
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
-
|
||||
name: Setup Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: '1.17'
|
||||
-
|
||||
name: Install cosign
|
||||
uses: sigstore/cosign-installer@v1.1.0
|
||||
with:
|
||||
cosign-release: 'v1.1.0'
|
||||
-
|
||||
name: Write cosign key to disk
|
||||
id: write_key
|
||||
run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key"
|
||||
-
|
||||
name: Build
|
||||
id: build
|
||||
run: |
|
||||
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
|
||||
make docker-artifacts
|
||||
env:
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
|
||||
COSIGN_PWD: ${{ secrets.COSIGN_PWD }}
|
||||
|
|
66
.github/workflows/test.yml
vendored
Normal file
66
.github/workflows/test.yml
vendored
Normal file
|
@ -0,0 +1,66 @@
|
|||
name: Lint, Test, Build
|
||||
|
||||
on:
|
||||
push:
|
||||
tags-ignore:
|
||||
- 'v*'
|
||||
branches:
|
||||
- "**"
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
lintTestBuild:
|
||||
name: Lint, Test, Build
|
||||
runs-on: ubuntu-20.04
|
||||
strategy:
|
||||
matrix:
|
||||
go: [ '1.16', '1.17' ]
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
-
|
||||
name: Setup Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
-
|
||||
name: Install Deps
|
||||
id: install-deps
|
||||
run: sudo apt-get -y install libpcsclite-dev
|
||||
-
|
||||
name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
with:
|
||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||
version: 'v1.44.0'
|
||||
|
||||
# Optional: working directory, useful for monorepos
|
||||
# working-directory: somedir
|
||||
|
||||
# Optional: golangci-lint command line arguments.
|
||||
args: --timeout=30m
|
||||
|
||||
# Optional: show only new issues if it's a pull request. The default value is `false`.
|
||||
# only-new-issues: true
|
||||
|
||||
# Optional: if set to true then the action will use pre-installed Go.
|
||||
# skip-go-installation: true
|
||||
|
||||
# Optional: if set to true then the action don't cache or restore ~/go/pkg.
|
||||
# skip-pkg-cache: true
|
||||
|
||||
# Optional: if set to true then the action don't cache or restore ~/.cache/go-build.
|
||||
# skip-build-cache: true
|
||||
-
|
||||
name: Test, Build
|
||||
id: lint_test_build
|
||||
run: V=1 make ci
|
||||
-
|
||||
name: Codecov
|
||||
if: matrix.go == '1.17'
|
||||
uses: codecov/codecov-action@v1.2.1
|
||||
with:
|
||||
file: ./coverage.out # optional
|
||||
name: codecov-umbrella # optional
|
||||
fail_ci_if_error: true # optional (default = false)
|
16
.github/workflows/triage.yml
vendored
16
.github/workflows/triage.yml
vendored
|
@ -1,16 +0,0 @@
|
|||
name: Add Issues and PRs to Triage
|
||||
|
||||
on:
|
||||
issues:
|
||||
types:
|
||||
- opened
|
||||
- reopened
|
||||
pull_request_target:
|
||||
types:
|
||||
- opened
|
||||
- reopened
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
uses: smallstep/workflows/.github/workflows/triage.yml@main
|
||||
secrets: inherit
|
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -6,10 +6,6 @@
|
|||
*.so
|
||||
*.dylib
|
||||
|
||||
# Go Workspaces
|
||||
go.work
|
||||
go.work.sum
|
||||
|
||||
# Test binary, build with `go test -c`
|
||||
*.test
|
||||
|
||||
|
|
|
@ -1,18 +0,0 @@
|
|||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:85
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:107
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:108
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:129
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:131
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:136
|
||||
deac15327f5605a1a963e50818760a95cee9d882:docs/kms.md:generic-api-key:138
|
||||
7c9ab9814fb676cb3c125c3dac4893271f1b7ae5:README.md:generic-api-key:282
|
||||
fb7140444ac8f1fa1245a80e49d17e206f7435f3:docs/provisioners.md:generic-api-key:110
|
||||
e4de7f07e82118b3f926716666b620db058fa9f7:docs/revocation.md:generic-api-key:73
|
||||
e4de7f07e82118b3f926716666b620db058fa9f7:docs/revocation.md:generic-api-key:113
|
||||
e4de7f07e82118b3f926716666b620db058fa9f7:docs/revocation.md:generic-api-key:151
|
||||
8b2de42e9cf6ce99f53a5049881e1d6077d5d66e:docs/docker.md:generic-api-key:152
|
||||
3939e855264117e81531df777a642ea953d325a7:autocert/init/ca/intermediate_ca_key:private-key:1
|
||||
e72f08703753facfa05f2d8c68f9f6a3745824b8:README.md:generic-api-key:244
|
||||
e70a5dae7de0b6ca40a0393c09c28872d4cfa071:autocert/README.md:generic-api-key:365
|
||||
e70a5dae7de0b6ca40a0393c09c28872d4cfa071:autocert/README.md:generic-api-key:366
|
||||
c284a2c0ab1c571a46443104be38c873ef0c7c6d:config.json:generic-api-key:10
|
75
.golangci.yml
Normal file
75
.golangci.yml
Normal file
|
@ -0,0 +1,75 @@
|
|||
linters-settings:
|
||||
govet:
|
||||
check-shadowing: true
|
||||
settings:
|
||||
printf:
|
||||
funcs:
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Infof
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf
|
||||
revive:
|
||||
min-confidence: 0
|
||||
gocyclo:
|
||||
min-complexity: 10
|
||||
maligned:
|
||||
suggest-new: true
|
||||
dupl:
|
||||
threshold: 100
|
||||
goconst:
|
||||
min-len: 2
|
||||
min-occurrences: 2
|
||||
depguard:
|
||||
list-type: blacklist
|
||||
packages:
|
||||
# logging is allowed only by logutils.Log, logrus
|
||||
# is allowed to use only in logutils package
|
||||
- github.com/sirupsen/logrus
|
||||
misspell:
|
||||
locale: US
|
||||
lll:
|
||||
line-length: 140
|
||||
goimports:
|
||||
local-prefixes: github.com/golangci/golangci-lint
|
||||
gocritic:
|
||||
enabled-tags:
|
||||
- performance
|
||||
- style
|
||||
- experimental
|
||||
- diagnostic
|
||||
disabled-checks:
|
||||
- commentFormatting
|
||||
- commentedOutCode
|
||||
- evalOrder
|
||||
- hugeParam
|
||||
- octalLiteral
|
||||
- rangeValCopy
|
||||
- tooManyResultsChecker
|
||||
- unnamedResult
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- deadcode
|
||||
- gocritic
|
||||
- gofmt
|
||||
- gosimple
|
||||
- govet
|
||||
- ineffassign
|
||||
- misspell
|
||||
- revive
|
||||
- staticcheck
|
||||
- unused
|
||||
|
||||
run:
|
||||
skip-dirs:
|
||||
- pkg
|
||||
|
||||
issues:
|
||||
exclude:
|
||||
- can't lint
|
||||
- declaration of "err" shadows declaration at line
|
||||
- should have a package comment, unless it's in another file for this package
|
||||
- error strings should not be capitalized or end with punctuation or a newline
|
||||
- Wrapf call needs 1 arg but has 2 args
|
||||
- cs.NegotiatedProtocolIsMutual is deprecated
|
206
.goreleaser.yml
206
.goreleaser.yml
|
@ -19,24 +19,62 @@ builds:
|
|||
- linux_386
|
||||
- linux_amd64
|
||||
- linux_arm64
|
||||
- linux_arm_5
|
||||
- linux_arm_6
|
||||
- linux_arm_7
|
||||
- windows_amd64
|
||||
flags:
|
||||
- -trimpath
|
||||
main: ./cmd/step-ca/main.go
|
||||
binary: step-ca
|
||||
binary: bin/step-ca
|
||||
ldflags:
|
||||
- -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}}
|
||||
-
|
||||
id: step-cloudkms-init
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
targets:
|
||||
- darwin_amd64
|
||||
- darwin_arm64
|
||||
- freebsd_amd64
|
||||
- linux_386
|
||||
- linux_amd64
|
||||
- linux_arm64
|
||||
- linux_arm_6
|
||||
- linux_arm_7
|
||||
- windows_amd64
|
||||
flags:
|
||||
- -trimpath
|
||||
main: ./cmd/step-cloudkms-init/main.go
|
||||
binary: bin/step-cloudkms-init
|
||||
ldflags:
|
||||
- -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}}
|
||||
-
|
||||
id: step-awskms-init
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
targets:
|
||||
- darwin_amd64
|
||||
- darwin_arm64
|
||||
- freebsd_amd64
|
||||
- linux_386
|
||||
- linux_amd64
|
||||
- linux_arm64
|
||||
- linux_arm_6
|
||||
- linux_arm_7
|
||||
- windows_amd64
|
||||
flags:
|
||||
- -trimpath
|
||||
main: ./cmd/step-awskms-init/main.go
|
||||
binary: bin/step-awskms-init
|
||||
ldflags:
|
||||
- -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}}
|
||||
|
||||
archives:
|
||||
- &ARCHIVE
|
||||
-
|
||||
# Can be used to change the archive formats for specific GOOSs.
|
||||
# Most common use case is to archive as zip on Windows.
|
||||
# Default is empty.
|
||||
name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Version }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
|
||||
rlcp: true
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
format: zip
|
||||
|
@ -44,51 +82,9 @@ archives:
|
|||
files:
|
||||
- README.md
|
||||
- LICENSE
|
||||
allow_different_binary_count: true
|
||||
-
|
||||
<< : *ARCHIVE
|
||||
id: unversioned
|
||||
name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
|
||||
|
||||
|
||||
nfpms:
|
||||
# Configure nFPM for .deb and .rpm releases
|
||||
#
|
||||
# See https://nfpm.goreleaser.com/configuration/
|
||||
# and https://goreleaser.com/customization/nfpm/
|
||||
#
|
||||
# Useful tools for debugging .debs:
|
||||
# List file contents: dpkg -c dist/step_...deb
|
||||
# Package metadata: dpkg --info dist/step_....deb
|
||||
#
|
||||
- &NFPM
|
||||
builds:
|
||||
- step-ca
|
||||
package_name: step-ca
|
||||
file_name_template: "{{ .PackageName }}_{{ .Version }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
|
||||
vendor: Smallstep Labs
|
||||
homepage: https://github.com/smallstep/certificates
|
||||
maintainer: Smallstep <techadmin@smallstep.com>
|
||||
description: >
|
||||
step-ca is an online certificate authority for secure, automated certificate management.
|
||||
license: Apache 2.0
|
||||
section: utils
|
||||
formats:
|
||||
- deb
|
||||
- rpm
|
||||
priority: optional
|
||||
bindir: /usr/bin
|
||||
contents:
|
||||
- src: debian/copyright
|
||||
dst: /usr/share/doc/step-ca/copyright
|
||||
-
|
||||
<< : *NFPM
|
||||
id: unversioned
|
||||
file_name_template: "{{ .PackageName }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
|
||||
|
||||
source:
|
||||
enabled: true
|
||||
rlcp: true
|
||||
name_template: '{{ .ProjectName }}_{{ .Version }}'
|
||||
|
||||
checksum:
|
||||
|
@ -98,9 +94,8 @@ checksum:
|
|||
|
||||
signs:
|
||||
- cmd: cosign
|
||||
signature: "${artifact}.sig"
|
||||
certificate: "${artifact}.pem"
|
||||
args: ["sign-blob", "--oidc-issuer=https://token.actions.githubusercontent.com", "--output-certificate=${certificate}", "--output-signature=${signature}", "${artifact}"]
|
||||
stdin: '{{ .Env.COSIGN_PWD }}'
|
||||
args: ["sign-blob", "-key=/tmp/cosign.key", "-output=${signature}", "${artifact}"]
|
||||
artifacts: all
|
||||
|
||||
snapshot:
|
||||
|
@ -141,17 +136,17 @@ release:
|
|||
|
||||
#### Linux
|
||||
|
||||
- 📦 [step-ca_linux_{{ .Version }}_amd64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_linux_{{ .Version }}_amd64.tar.gz)
|
||||
- 📦 [step-ca_{{ .Version }}_amd64.deb](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ .Version }}_amd64.deb)
|
||||
- 📦 [step-ca_linux_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_linux_{{ .Version }}_amd64.tar.gz)
|
||||
- 📦 [step-ca_{{ .Env.DEB_VERSION }}_amd64.deb](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ .Env.DEB_VERSION }}_amd64.deb)
|
||||
|
||||
#### OSX Darwin
|
||||
|
||||
- 📦 [step-ca_darwin_{{ .Version }}_amd64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_amd64.tar.gz)
|
||||
- 📦 [step-ca_darwin_{{ .Version }}_arm64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_arm64.tar.gz)
|
||||
- 📦 [step-ca_darwin_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_amd64.tar.gz)
|
||||
- 📦 [step-ca_darwin_{{ .Version }}_arm64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_arm64.tar.gz)
|
||||
|
||||
#### Windows
|
||||
|
||||
- 📦 [step-ca_windows_{{ .Version }}_amd64.zip](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_windows_{{ .Version }}_amd64.zip)
|
||||
- 📦 [step-ca_windows_{{ .Version }}_arm64.zip](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_windows_{{ .Version }}_amd64.zip)
|
||||
|
||||
For more builds across platforms and architectures, see the `Assets` section below.
|
||||
And for packaged versions (Docker, k8s, Homebrew), see our [installation docs](https://smallstep.com/docs/step-ca/installation).
|
||||
|
@ -166,10 +161,8 @@ release:
|
|||
|
||||
```
|
||||
cosign verify-blob \
|
||||
--certificate ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig.pem \
|
||||
--signature ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig \
|
||||
--certificate-identity-regexp "https://github\.com/smallstep/certificates/.*" \
|
||||
--certificate-oidc-issuer https://token.actions.githubusercontent.com \
|
||||
-key https://raw.githubusercontent.com/smallstep/certificates/master/cosign.pub \
|
||||
-signature ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig
|
||||
~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz
|
||||
```
|
||||
|
||||
|
@ -199,40 +192,77 @@ release:
|
|||
# - glob: ./glob/**/to/**/file/**/*
|
||||
# - glob: ./glob/foo/to/bar/file/foobar/override_from_previous
|
||||
|
||||
scoops:
|
||||
-
|
||||
ids: [ default ]
|
||||
# Template for the url which is determined by the given Token (github or gitlab)
|
||||
# Default for github is "https://github.com/<repo_owner>/<repo_name>/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
|
||||
# Default for gitlab is "https://gitlab.com/<repo_owner>/<repo_name>/uploads/{{ .ArtifactUploadHash }}/{{ .ArtifactName }}"
|
||||
# Default for gitea is "https://gitea.com/<repo_owner>/<repo_name>/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
|
||||
url_template: "http://github.com/smallstep/certificates/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
|
||||
# Repository to push the app manifest to.
|
||||
bucket:
|
||||
owner: smallstep
|
||||
name: scoop-bucket
|
||||
scoop:
|
||||
# Template for the url which is determined by the given Token (github or gitlab)
|
||||
# Default for github is "https://github.com/<repo_owner>/<repo_name>/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
|
||||
# Default for gitlab is "https://gitlab.com/<repo_owner>/<repo_name>/uploads/{{ .ArtifactUploadHash }}/{{ .ArtifactName }}"
|
||||
# Default for gitea is "https://gitea.com/<repo_owner>/<repo_name>/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
|
||||
url_template: "http://github.com/smallstep/certificates/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
|
||||
|
||||
# Git author used to commit to the repository.
|
||||
# Defaults are shown.
|
||||
commit_author:
|
||||
name: goreleaserbot
|
||||
email: goreleaser@smallstep.com
|
||||
# Repository to push the app manifest to.
|
||||
bucket:
|
||||
owner: smallstep
|
||||
name: scoop-bucket
|
||||
|
||||
# The project name and current git tag are used in the format string.
|
||||
commit_msg_template: "Scoop update for {{ .ProjectName }} version {{ .Tag }}"
|
||||
# Git author used to commit to the repository.
|
||||
# Defaults are shown.
|
||||
commit_author:
|
||||
name: goreleaserbot
|
||||
email: goreleaser@smallstep.com
|
||||
|
||||
# Your app's homepage.
|
||||
# Default is empty.
|
||||
homepage: "https://smallstep.com/docs/step-ca"
|
||||
# The project name and current git tag are used in the format string.
|
||||
commit_msg_template: "Scoop update for {{ .ProjectName }} version {{ .Tag }}"
|
||||
|
||||
# Skip uploads for prerelease.
|
||||
skip_upload: auto
|
||||
# Your app's homepage.
|
||||
# Default is empty.
|
||||
homepage: "https://smallstep.com/docs/step-ca"
|
||||
|
||||
# Your app's description.
|
||||
# Default is empty.
|
||||
description: "A private certificate authority (X.509 & SSH) & ACME server for secure automated certificate management, so you can use TLS everywhere & SSO for SSH."
|
||||
# Skip uploads for prerelease.
|
||||
skip_upload: auto
|
||||
|
||||
# Your app's license
|
||||
# Default is empty.
|
||||
license: "Apache-2.0"
|
||||
# Your app's description.
|
||||
# Default is empty.
|
||||
description: "A private certificate authority (X.509 & SSH) & ACME server for secure automated certificate management, so you can use TLS everywhere & SSO for SSH."
|
||||
|
||||
# Your app's license
|
||||
# Default is empty.
|
||||
license: "Apache-2.0"
|
||||
|
||||
#dockers:
|
||||
# - dockerfile: docker/Dockerfile
|
||||
# goos: linux
|
||||
# goarch: amd64
|
||||
# use_buildx: true
|
||||
# image_templates:
|
||||
# - "smallstep/step-cli:latest"
|
||||
# - "smallstep/step-cli:{{ .Tag }}"
|
||||
# build_flag_templates:
|
||||
# - "--platform=linux/amd64"
|
||||
# - dockerfile: docker/Dockerfile
|
||||
# goos: linux
|
||||
# goarch: 386
|
||||
# use_buildx: true
|
||||
# image_templates:
|
||||
# - "smallstep/step-cli:latest"
|
||||
# - "smallstep/step-cli:{{ .Tag }}"
|
||||
# build_flag_templates:
|
||||
# - "--platform=linux/386"
|
||||
# - dockerfile: docker/Dockerfile
|
||||
# goos: linux
|
||||
# goarch: arm
|
||||
# goarm: 7
|
||||
# use_buildx: true
|
||||
# image_templates:
|
||||
# - "smallstep/step-cli:latest"
|
||||
# - "smallstep/step-cli:{{ .Tag }}"
|
||||
# build_flag_templates:
|
||||
# - "--platform=linux/arm/v7"
|
||||
# - dockerfile: docker/Dockerfile
|
||||
# goos: linux
|
||||
# goarch: arm64
|
||||
# use_buildx: true
|
||||
# image_templates:
|
||||
# - "smallstep/step-cli:latest"
|
||||
# - "smallstep/step-cli:{{ .Tag }}"
|
||||
# build_flag_templates:
|
||||
# - "--platform=linux/arm64/v8"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#!/usr/bin/env sh
|
||||
#!/usr/bin/env bash
|
||||
read -r firstline < .VERSION
|
||||
last_half="${firstline##*tag: }"
|
||||
if [[ ${last_half::1} == "v" ]]; then
|
||||
|
|
312
CHANGELOG.md
312
CHANGELOG.md
|
@ -1,407 +1,103 @@
|
|||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
||||
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## TEMPLATE -- do not alter or remove
|
||||
|
||||
---
|
||||
|
||||
## [x.y.z] - aaaa-bb-cc
|
||||
|
||||
## [Unreleased - 0.18.3] - DATE
|
||||
### Added
|
||||
|
||||
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
|
||||
### Changed
|
||||
|
||||
- Made SCEP CA URL paths dynamic
|
||||
### Deprecated
|
||||
|
||||
### Removed
|
||||
|
||||
### Fixed
|
||||
|
||||
### Security
|
||||
|
||||
---
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Fixed
|
||||
|
||||
- Improved authentication for ACME requests using kid and provisioner name
|
||||
(smallstep/certificates#1386).
|
||||
|
||||
|
||||
## [v0.24.2] - 2023-05-11
|
||||
|
||||
### Added
|
||||
|
||||
- Log SSH certificates (smallstep/certificates#1374)
|
||||
- CRL endpoints on the HTTP server (smallstep/certificates#1372)
|
||||
- Dynamic SCEP challenge validation using webhooks (smallstep/certificates#1366)
|
||||
- For Docker deployments, added DOCKER_STEPCA_INIT_PASSWORD_FILE. Useful for pointing to a Docker Secret in the container (smallstep/certificates#1384)
|
||||
|
||||
### Changed
|
||||
|
||||
- Depend on [smallstep/go-attestation](https://github.com/smallstep/go-attestation) instead of [google/go-attestation](https://github.com/google/go-attestation)
|
||||
- Render CRLs into http.ResponseWriter instead of memory (smallstep/certificates#1373)
|
||||
- Redaction of SCEP static challenge when listing provisioners (smallstep/certificates#1204)
|
||||
|
||||
### Fixed
|
||||
|
||||
- VaultCAS certificate lifetime (smallstep/certificates#1376)
|
||||
|
||||
## [v0.24.1] - 2023-04-14
|
||||
|
||||
### Fixed
|
||||
|
||||
- Docker image name for HSM support (smallstep/certificates#1348)
|
||||
|
||||
## [v0.24.0] - 2023-04-12
|
||||
|
||||
### Added
|
||||
|
||||
- Add ACME `device-attest-01` support with TPM 2.0
|
||||
(smallstep/certificates#1063).
|
||||
- Add support for new Azure SDK, sovereign clouds, and HSM keys on Azure KMS
|
||||
(smallstep/crypto#192, smallstep/crypto#197, smallstep/crypto#198,
|
||||
smallstep/certificates#1323, smallstep/certificates#1309).
|
||||
- Add support for ASN.1 functions on certificate templates
|
||||
(smallstep/crypto#208, smallstep/certificates#1345)
|
||||
- Add `DOCKER_STEPCA_INIT_ADDRESS` to configure the address to use in a docker
|
||||
container (smallstep/certificates#1262).
|
||||
- Make sure that the CSR used matches the attested key when using AME
|
||||
`device-attest-01` challenge (smallstep/certificates#1265).
|
||||
- Add support for compacting the Badger DB (smallstep/certificates#1298).
|
||||
- Build and release cleanups (smallstep/certificates#1322,
|
||||
smallstep/certificates#1329, smallstep/certificates#1340).
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fix support for PKCS #7 RSA-OAEP decryption through
|
||||
[smallstep/pkcs7#4](https://github.com/smallstep/pkcs7/pull/4), as used in
|
||||
SCEP.
|
||||
- Fix RA installation using `scripts/install-step-ra.sh`
|
||||
(smallstep/certificates#1255).
|
||||
- Clarify error messages on policy errors (smallstep/certificates#1287,
|
||||
smallstep/certificates#1278).
|
||||
- Clarify error message on OIDC email validation (smallstep/certificates#1290).
|
||||
- Mark the IDP critical in the generated CRL data (smallstep/certificates#1293).
|
||||
- Disable database if CA is initialized with the `--no-db` flag
|
||||
(smallstep/certificates#1294).
|
||||
|
||||
## [v0.23.2] - 2023-02-02
|
||||
|
||||
### Added
|
||||
|
||||
- Added [`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) to
|
||||
docker images, and a new image, `smallstep/step-ca-hsm`, compiled with cgo
|
||||
(smallstep/certificates#1243).
|
||||
- Added [`scoop`](https://scoop.sh) packages back to the release
|
||||
(smallstep/certificates#1250).
|
||||
- Added optional flag `--pidfile` which allows passing a filename where step-ca
|
||||
will write its process id (smallstep/certificates#1251).
|
||||
- Added helpful message on CA startup when config can't be opened
|
||||
(smallstep/certificates#1252).
|
||||
- Improved validation and error messages on `device-attest-01` orders
|
||||
(smallstep/certificates#1235).
|
||||
|
||||
### Removed
|
||||
|
||||
- The deprecated CLI utils `step-awskms-init`, `step-cloudkms-init`,
|
||||
`step-pkcs11-init`, `step-yubikey-init` have been removed.
|
||||
[`step`](https://github.com/smallstep/cli) and
|
||||
[`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) should be
|
||||
used instead (smallstep/certificates#1240).
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed remote management flags in docker images (smallstep/certificates#1228).
|
||||
|
||||
## [v0.23.1] - 2023-01-10
|
||||
|
||||
### Added
|
||||
|
||||
- Added configuration property `.crl.idpURL` to be able to set a custom Issuing
|
||||
Distribution Point in the CRL (smallstep/certificates#1178).
|
||||
- Added WithContext methods to the CA client (smallstep/certificates#1211).
|
||||
- Docker: Added environment variables for enabling Remote Management and ACME
|
||||
provisioner (smallstep/certificates#1201).
|
||||
- Docker: The entrypoint script now generates and displays an initial JWK
|
||||
provisioner password by default when the CA is being initialized
|
||||
(smallstep/certificates#1223).
|
||||
|
||||
### Changed
|
||||
|
||||
- Ignore SSH principals validation when using an OIDC provisioner. The
|
||||
provisioner will ignore the principals passed and set the defaults or the ones
|
||||
including using WebHooks or templates (smallstep/certificates#1206).
|
||||
|
||||
## [v0.23.0] - 2022-11-11
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for ACME device-attest-01 challenge on iOS, iPadOS, tvOS and
|
||||
YubiKey.
|
||||
- Ability to disable ACME challenges and attestation formats.
|
||||
- Added flags to change ACME challenge ports for testing purposes.
|
||||
- Added name constraints evaluation and enforcement when issuing or renewing
|
||||
X.509 certificates.
|
||||
- Added provisioner webhooks for augmenting template data and authorizing
|
||||
certificate requests before signing.
|
||||
- Added automatic migration of provisioners when enabling remote management.
|
||||
- Added experimental support for CRLs.
|
||||
- Add certificate renewal support on RA mode. The `step ca renew` command must
|
||||
use the flag `--mtls=false` to use the token renewal flow.
|
||||
- Added support for initializing remote management using `step ca init`.
|
||||
- Added support for renewing X.509 certificates on RAs.
|
||||
- Added support for using SCEP with keys in a KMS.
|
||||
- Added client support to set the dialer's local address with the environment variable
|
||||
`STEP_CLIENT_ADDR`.
|
||||
|
||||
### Changed
|
||||
|
||||
- Remove the email requirement for issuing SSH certificates with an OIDC
|
||||
provisioner.
|
||||
- Root files can contain more than one certificate.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed MySQL DSN parsing issues with an upgrade to
|
||||
[smallstep/nosql@v0.5.0](https://github.com/smallstep/nosql/releases/tag/v0.5.0).
|
||||
- Fixed renewal of certificates with missing subject attributes.
|
||||
- Fixed ACME support with [ejabberd](https://github.com/processone/ejabberd).
|
||||
|
||||
### Deprecated
|
||||
|
||||
- The CLIs `step-awskms-init`, `step-cloudkms-init`, `step-pkcs11-init`,
|
||||
`step-yubikey-init` are deprecated. Now you can use
|
||||
[`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) in
|
||||
combination with `step certificates create` to initialize your PKI.
|
||||
|
||||
## [0.22.1] - 2022-08-31
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed signature algorithm on EC (root) + RSA (intermediate) PKIs.
|
||||
|
||||
## [0.22.0] - 2022-08-26
|
||||
|
||||
### Added
|
||||
|
||||
- Added automatic configuration of Linked RAs.
|
||||
- Send provisioner configuration on Linked RAs.
|
||||
|
||||
### Changed
|
||||
|
||||
- Certificates signed by an issuer using an RSA key will be signed using the
|
||||
same algorithm used to sign the issuer certificate. The signature will no
|
||||
longer default to PKCS #1. For example, if the issuer certificate was signed
|
||||
using RSA-PSS with SHA-256, a new certificate will also be signed using
|
||||
RSA-PSS with SHA-256.
|
||||
- Support two latest versions of Go (1.18, 1.19).
|
||||
- Validate revocation serial number (either base 10 or prefixed with an
|
||||
appropriate base).
|
||||
- Sanitize TLS options.
|
||||
|
||||
## [0.20.0] - 2022-05-26
|
||||
|
||||
### Added
|
||||
|
||||
- Added Kubernetes auth method for Vault RAs.
|
||||
- Added support for reporting provisioners to linkedca.
|
||||
- Added support for certificate policies on authority level.
|
||||
- Added a Dockerfile with a step-ca build with HSM support.
|
||||
- A few new WithXX methods for instantiating authorities
|
||||
|
||||
### Changed
|
||||
|
||||
- Context usage in HTTP APIs.
|
||||
- Changed authentication for Vault RAs.
|
||||
- Error message returned to client when authenticating with expired certificate.
|
||||
- Strip padding from ACME CSRs.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- HTTP API handler types.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed SSH revocation.
|
||||
- CA client dial context for js/wasm target.
|
||||
- Incomplete `extraNames` support in templates.
|
||||
- SCEP GET request support.
|
||||
- Large SCEP request handling.
|
||||
|
||||
## [0.19.0] - 2022-04-19
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for certificate renewals after expiry using the claim `allowRenewalAfterExpiry`.
|
||||
- Added support for `extraNames` in X.509 templates.
|
||||
- Added `armv5` builds.
|
||||
- Added RA support using a Vault instance as the CA.
|
||||
- Added `WithX509SignerFunc` authority option.
|
||||
- Added a new `/roots.pem` endpoint to download the CA roots in PEM format.
|
||||
- Added support for Azure `Managed Identity` tokens.
|
||||
- Added support for automatic configuration of linked RAs.
|
||||
- Added support for the `--context` flag. It's now possible to start the
|
||||
CA with `step-ca --context=abc` to use the configuration from context `abc`.
|
||||
When a context has been configured and no configuration file is provided
|
||||
on startup, the configuration for the current context is used.
|
||||
- Added startup info logging and option to skip it (`--quiet`).
|
||||
- Added support for renaming the CA (Common Name).
|
||||
|
||||
### Changed
|
||||
|
||||
- Made SCEP CA URL paths dynamic.
|
||||
- Support two latest versions of Go (1.17, 1.18).
|
||||
- Upgrade go.step.sm/crypto to v0.16.1.
|
||||
- Upgrade go.step.sm/linkedca to v0.15.0.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Go 1.16 support.
|
||||
|
||||
### Removed
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed admin credentials on RAs.
|
||||
- Fixed ACME HTTP-01 challenges for IPv6 identifiers.
|
||||
- Various improvements under the hood.
|
||||
|
||||
### Security
|
||||
|
||||
## [0.18.2] - 2022-03-01
|
||||
|
||||
### Added
|
||||
|
||||
- Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner.
|
||||
- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering
|
||||
out database drivers using Go tags. For example, using the Go flag
|
||||
`--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx
|
||||
driver for PostgreSQL.
|
||||
|
||||
### Changed
|
||||
|
||||
- IPv6 addresses are normalized as IP addresses instead of hostnames.
|
||||
- More descriptive JWK decryption error message.
|
||||
- Make the X5C leaf certificate available to the templates using `{{ .AuthorizationCrt }}`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- During provisioner add - validate provisioner configuration before storing to DB.
|
||||
|
||||
## [0.18.1] - 2022-02-03
|
||||
|
||||
### Added
|
||||
|
||||
- Support for ACME revocation.
|
||||
- Replace hash function with an RSA SSH CA to "rsa-sha2-256".
|
||||
- Support Nebula provisioners.
|
||||
- Example Ansible configurations.
|
||||
- Support PKCS#11 as a decrypter, as used by SCEP.
|
||||
|
||||
### Changed
|
||||
|
||||
- Automatically create database directory on `step ca init`.
|
||||
- Slightly improve errors reported when a template has invalid content.
|
||||
- Error reporting in logs and to clients.
|
||||
|
||||
### Fixed
|
||||
|
||||
- SCEP renewal using HTTPS on macOS.
|
||||
|
||||
## [0.18.0] - 2021-11-17
|
||||
|
||||
### Added
|
||||
|
||||
- Support for multiple certificate authority contexts.
|
||||
- Support for generating extractable keys and certificates on a pkcs#11 module.
|
||||
|
||||
### Changed
|
||||
|
||||
- Support two latest versions of Go (1.16, 1.17)
|
||||
|
||||
- Support two latest versions of golang (1.16, 1.17)
|
||||
### Deprecated
|
||||
|
||||
- go 1.15 support
|
||||
|
||||
## [0.17.6] - 2021-10-20
|
||||
|
||||
### Notes
|
||||
|
||||
- 0.17.5 failed in CI/CD
|
||||
|
||||
## [0.17.5] - 2021-10-20
|
||||
|
||||
### Added
|
||||
|
||||
- Support for Azure Key Vault as a KMS.
|
||||
- Adapt `pki` package to support key managers.
|
||||
- gocritic linter
|
||||
|
||||
### Fixed
|
||||
|
||||
- gocritic warnings
|
||||
|
||||
## [0.17.4] - 2021-09-28
|
||||
|
||||
### Fixed
|
||||
|
||||
- Support host-only or user-only SSH CA.
|
||||
|
||||
## [0.17.3] - 2021-09-24
|
||||
|
||||
### Added
|
||||
|
||||
- go 1.17 to github action test matrix
|
||||
- Support for CloudKMS RSA-PSS signers without using templates.
|
||||
- Add flags to support individual passwords for the intermediate and SSH keys.
|
||||
- Global support for group admins in the OIDC provisioner.
|
||||
|
||||
### Changed
|
||||
|
||||
- Using go 1.17 for binaries
|
||||
|
||||
### Fixed
|
||||
|
||||
- Upgrade go-jose.v2 to fix a bug in the JWK fingerprint of Ed25519 keys.
|
||||
|
||||
### Security
|
||||
|
||||
- Use cosign to sign and upload signatures for multi-arch Docker container.
|
||||
- Add debian checksum
|
||||
|
||||
## [0.17.2] - 2021-08-30
|
||||
|
||||
### Added
|
||||
|
||||
- Additional way to distinguish Azure IID and Azure OIDC tokens.
|
||||
|
||||
### Security
|
||||
|
||||
- Sign over all goreleaser github artifacts using cosign
|
||||
|
||||
## [0.17.1] - 2021-08-26
|
||||
|
||||
## [0.17.0] - 2021-08-25
|
||||
|
||||
### Added
|
||||
|
||||
- Add support for Linked CAs using protocol buffers and gRPC
|
||||
- `step-ca init` adds support for
|
||||
- configuring a StepCAS RA
|
||||
- configuring a Linked CA
|
||||
- congifuring a `step-ca` using Helm
|
||||
|
||||
### Changed
|
||||
|
||||
- Update badger driver to use v2 by default
|
||||
- Update TLS cipher suites to include 1.3
|
||||
|
||||
### Security
|
||||
|
||||
- Fix key version when SHA512WithRSA is used. There was a typo creating RSA keys with SHA256 digests instead of SHA512.
|
||||
|
|
150
Makefile
150
Makefile
|
@ -1,11 +1,21 @@
|
|||
PKG?=github.com/smallstep/certificates/cmd/step-ca
|
||||
BINNAME?=step-ca
|
||||
CLOUDKMS_BINNAME?=step-cloudkms-init
|
||||
CLOUDKMS_PKG?=github.com/smallstep/certificates/cmd/step-cloudkms-init
|
||||
AWSKMS_BINNAME?=step-awskms-init
|
||||
AWSKMS_PKG?=github.com/smallstep/certificates/cmd/step-awskms-init
|
||||
YUBIKEY_BINNAME?=step-yubikey-init
|
||||
YUBIKEY_PKG?=github.com/smallstep/certificates/cmd/step-yubikey-init
|
||||
PKCS11_BINNAME?=step-pkcs11-init
|
||||
PKCS11_PKG?=github.com/smallstep/certificates/cmd/step-pkcs11-init
|
||||
|
||||
# Set V to 1 for verbose output from the Makefile
|
||||
Q=$(if $V,,@)
|
||||
PREFIX?=
|
||||
SRC=$(shell find . -type f -name '*.go' -not -path "./vendor/*")
|
||||
GOOS_OVERRIDE ?=
|
||||
OUTPUT_ROOT=output/
|
||||
RELEASE=./.releases
|
||||
|
||||
all: lint test build
|
||||
|
||||
|
@ -18,11 +28,8 @@ ci: testcgo build
|
|||
#########################################
|
||||
|
||||
bootstra%:
|
||||
$Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $$(go env GOPATH)/bin latest
|
||||
$Q go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
$Q go install gotest.tools/gotestsum@latest
|
||||
$Q go install github.com/goreleaser/goreleaser@latest
|
||||
$Q go install github.com/sigstore/cosign/v2/cmd/cosign@latest
|
||||
# Using a released version of golangci-lint to take into account custom replacements in their go.mod
|
||||
$Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v1.42.0
|
||||
|
||||
.PHONY: bootstra%
|
||||
|
||||
|
@ -30,8 +37,17 @@ bootstra%:
|
|||
# Determine the type of `push` and `version`
|
||||
#################################################
|
||||
|
||||
# If TRAVIS_TAG is set then we know this ref has been tagged.
|
||||
ifdef TRAVIS_TAG
|
||||
VERSION ?= $(TRAVIS_TAG)
|
||||
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
|
||||
ifeq ($(NOT_RC),)
|
||||
PUSHTYPE := release-candidate
|
||||
else
|
||||
PUSHTYPE := release
|
||||
endif
|
||||
# GITHUB Actions
|
||||
ifdef GITHUB_REF
|
||||
else ifdef GITHUB_REF
|
||||
VERSION ?= $(shell echo $(GITHUB_REF) | sed 's/^refs\/tags\///')
|
||||
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
|
||||
ifeq ($(NOT_RC),)
|
||||
|
@ -44,50 +60,59 @@ VERSION ?= $(shell [ -d .git ] && git describe --tags --always --dirty="-dev")
|
|||
# If we are not in an active git dir then try reading the version from .VERSION.
|
||||
# .VERSION contains a slug populated by `git archive`.
|
||||
VERSION := $(or $(VERSION),$(shell ./.version.sh .VERSION))
|
||||
ifeq ($(TRAVIS_BRANCH),master)
|
||||
PUSHTYPE := master
|
||||
else
|
||||
PUSHTYPE := branch
|
||||
endif
|
||||
endif
|
||||
|
||||
VERSION := $(shell echo $(VERSION) | sed 's/^v//')
|
||||
DEB_VERSION := $(shell echo $(VERSION) | sed 's/-/./g')
|
||||
|
||||
ifdef V
|
||||
$(info TRAVIS_TAG is $(TRAVIS_TAG))
|
||||
$(info GITHUB_REF is $(GITHUB_REF))
|
||||
$(info VERSION is $(VERSION))
|
||||
$(info DEB_VERSION is $(DEB_VERSION))
|
||||
$(info PUSHTYPE is $(PUSHTYPE))
|
||||
endif
|
||||
|
||||
include make/docker.mk
|
||||
|
||||
#########################################
|
||||
# Build
|
||||
#########################################
|
||||
|
||||
DATE := $(shell date -u '+%Y-%m-%d %H:%M UTC')
|
||||
LDFLAGS := -ldflags='-w -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"'
|
||||
|
||||
# Always explicitly enable or disable cgo,
|
||||
# so that go doesn't silently fall back on
|
||||
# non-cgo when gcc is not found.
|
||||
ifeq (,$(findstring CGO_ENABLED,$(GO_ENVS)))
|
||||
ifneq ($(origin GOFLAGS),undefined)
|
||||
# This section is for backward compatibility with
|
||||
#
|
||||
# $ make build GOFLAGS=""
|
||||
#
|
||||
# which is how we recommended building step-ca with cgo support
|
||||
# until June 2023.
|
||||
GO_ENVS := $(GO_ENVS) CGO_ENABLED=1
|
||||
else
|
||||
GO_ENVS := $(GO_ENVS) CGO_ENABLED=0
|
||||
endif
|
||||
endif
|
||||
GOFLAGS := CGO_ENABLED=0
|
||||
|
||||
download:
|
||||
$Q go mod download
|
||||
|
||||
build: $(PREFIX)bin/$(BINNAME)
|
||||
build: $(PREFIX)bin/$(BINNAME) $(PREFIX)bin/$(CLOUDKMS_BINNAME) $(PREFIX)bin/$(AWSKMS_BINNAME) $(PREFIX)bin/$(YUBIKEY_BINNAME) $(PREFIX)bin/$(PKCS11_BINNAME)
|
||||
@echo "Build Complete!"
|
||||
|
||||
$(PREFIX)bin/$(BINNAME): download $(call rwildcard,*.go)
|
||||
$Q mkdir -p $(@D)
|
||||
$Q $(GOOS_OVERRIDE) GOFLAGS="$(GOFLAGS)" $(GO_ENVS) go build -v -o $(PREFIX)bin/$(BINNAME) $(LDFLAGS) $(PKG)
|
||||
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(BINNAME) $(LDFLAGS) $(PKG)
|
||||
|
||||
$(PREFIX)bin/$(CLOUDKMS_BINNAME): download $(call rwildcard,*.go)
|
||||
$Q mkdir -p $(@D)
|
||||
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(CLOUDKMS_BINNAME) $(LDFLAGS) $(CLOUDKMS_PKG)
|
||||
|
||||
$(PREFIX)bin/$(AWSKMS_BINNAME): download $(call rwildcard,*.go)
|
||||
$Q mkdir -p $(@D)
|
||||
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(AWSKMS_BINNAME) $(LDFLAGS) $(AWSKMS_PKG)
|
||||
|
||||
$(PREFIX)bin/$(YUBIKEY_BINNAME): download $(call rwildcard,*.go)
|
||||
$Q mkdir -p $(@D)
|
||||
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(YUBIKEY_BINNAME) $(LDFLAGS) $(YUBIKEY_PKG)
|
||||
|
||||
$(PREFIX)bin/$(PKCS11_BINNAME): download $(call rwildcard,*.go)
|
||||
$Q mkdir -p $(@D)
|
||||
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(PKCS11_BINNAME) $(LDFLAGS) $(PKCS11_PKG)
|
||||
|
||||
# Target to force a build of step-ca without running tests
|
||||
simple: build
|
||||
|
@ -106,26 +131,18 @@ generate:
|
|||
#########################################
|
||||
# Test
|
||||
#########################################
|
||||
test: testdefault testtpmsimulator combinecoverage
|
||||
|
||||
testdefault:
|
||||
$Q $(GO_ENVS) gotestsum -- -coverprofile=defaultcoverage.out -short -covermode=atomic ./...
|
||||
|
||||
testtpmsimulator:
|
||||
$Q CGO_ENABLED=1 gotestsum -- -coverprofile=tpmsimulatorcoverage.out -short -covermode=atomic -tags tpmsimulator ./acme
|
||||
test:
|
||||
$Q $(GOFLAGS) go test -short -coverprofile=coverage.out ./...
|
||||
|
||||
testcgo:
|
||||
$Q gotestsum -- -coverprofile=coverage.out -short -covermode=atomic ./...
|
||||
$Q go test -short -coverprofile=coverage.out ./...
|
||||
|
||||
combinecoverage:
|
||||
cat defaultcoverage.out tpmsimulatorcoverage.out > coverage.out
|
||||
|
||||
.PHONY: test testdefault testtpmsimulator testcgo combinecoverage
|
||||
.PHONY: test testcgo
|
||||
|
||||
integrate: integration
|
||||
|
||||
integration: bin/$(BINNAME)
|
||||
$Q $(GO_ENVS) gotestsum -- -tags=integration ./integration/...
|
||||
$Q $(GOFLAGS) go test -tags=integration ./integration/...
|
||||
|
||||
.PHONY: integrate integration
|
||||
|
||||
|
@ -134,14 +151,15 @@ integration: bin/$(BINNAME)
|
|||
#########################################
|
||||
|
||||
fmt:
|
||||
$Q goimports -l -w $(SRC)
|
||||
$Q gofmt -l -w $(SRC)
|
||||
|
||||
lint: SHELL:=/bin/bash
|
||||
lint:
|
||||
$Q LOG_LEVEL=error golangci-lint run --config <(curl -s https://raw.githubusercontent.com/smallstep/workflows/master/.golangci.yml) --timeout=30m
|
||||
$Q govulncheck ./...
|
||||
$Q golangci-lint run --timeout=30m
|
||||
|
||||
.PHONY: fmt lint
|
||||
lintcgo:
|
||||
$Q LOG_LEVEL=error golangci-lint run --timeout=30m
|
||||
|
||||
.PHONY: fmt lint lintcgo
|
||||
|
||||
#########################################
|
||||
# Install
|
||||
|
@ -149,11 +167,15 @@ lint:
|
|||
|
||||
INSTALL_PREFIX?=/usr/
|
||||
|
||||
install: $(PREFIX)bin/$(BINNAME)
|
||||
install: $(PREFIX)bin/$(BINNAME) $(PREFIX)bin/$(CLOUDKMS_BINNAME) $(PREFIX)bin/$(AWSKMS_BINNAME)
|
||||
$Q install -D $(PREFIX)bin/$(BINNAME) $(DESTDIR)$(INSTALL_PREFIX)bin/$(BINNAME)
|
||||
$Q install -D $(PREFIX)bin/$(CLOUDKMS_BINNAME) $(DESTDIR)$(INSTALL_PREFIX)bin/$(CLOUDKMS_BINNAME)
|
||||
$Q install -D $(PREFIX)bin/$(AWSKMS_BINNAME) $(DESTDIR)$(INSTALL_PREFIX)bin/$(AWSKMS_BINNAME)
|
||||
|
||||
uninstall:
|
||||
$Q rm -f $(DESTDIR)$(INSTALL_PREFIX)/bin/$(BINNAME)
|
||||
$Q rm -f $(DESTDIR)$(INSTALL_PREFIX)/bin/$(CLOUDKMS_BINNAME)
|
||||
$Q rm -f $(DESTDIR)$(INSTALL_PREFIX)/bin/$(AWSKMS_BINNAME)
|
||||
|
||||
.PHONY: install uninstall
|
||||
|
||||
|
@ -165,6 +187,18 @@ clean:
|
|||
ifneq ($(BINNAME),"")
|
||||
$Q rm -f bin/$(BINNAME)
|
||||
endif
|
||||
ifneq ($(CLOUDKMS_BINNAME),"")
|
||||
$Q rm -f bin/$(CLOUDKMS_BINNAME)
|
||||
endif
|
||||
ifneq ($(AWSKMS_BINNAME),"")
|
||||
$Q rm -f bin/$(AWSKMS_BINNAME)
|
||||
endif
|
||||
ifneq ($(YUBIKEY_BINNAME),"")
|
||||
$Q rm -f bin/$(YUBIKEY_BINNAME)
|
||||
endif
|
||||
ifneq ($(PKCS11_BINNAME),"")
|
||||
$Q rm -f bin/$(PKCS11_BINNAME)
|
||||
endif
|
||||
|
||||
.PHONY: clean
|
||||
|
||||
|
@ -177,3 +211,31 @@ run:
|
|||
|
||||
.PHONY: run
|
||||
|
||||
#########################################
|
||||
# Debian
|
||||
#########################################
|
||||
|
||||
changelog:
|
||||
$Q echo "step-ca ($(DEB_VERSION)) unstable; urgency=medium" > debian/changelog
|
||||
$Q echo >> debian/changelog
|
||||
$Q echo " * See https://github.com/smallstep/certificates/releases" >> debian/changelog
|
||||
$Q echo >> debian/changelog
|
||||
$Q echo " -- Smallstep Labs, Inc. <techadmin@smallstep.com> $(shell date -uR)" >> debian/changelog
|
||||
|
||||
debian: changelog
|
||||
$Q mkdir -p $(RELEASE); \
|
||||
OUTPUT=../step-ca*.deb; \
|
||||
rm $$OUTPUT; \
|
||||
dpkg-buildpackage -b -rfakeroot -us -uc && cp $$OUTPUT $(RELEASE)/
|
||||
|
||||
distclean: clean
|
||||
|
||||
.PHONY: changelog debian distclean
|
||||
|
||||
#################################################
|
||||
# Targets for creating step artifacts
|
||||
#################################################
|
||||
|
||||
docker-artifacts: docker-$(PUSHTYPE)
|
||||
|
||||
.PHONY: docker-artifacts
|
||||
|
|
19
README.md
19
README.md
|
@ -35,7 +35,7 @@ To get up and running quickly, or as an alternative to running your own `step-ca
|
|||
|
||||
[](https://github.com/smallstep/certificates/releases/latest)
|
||||
[](https://goreportcard.com/report/github.com/smallstep/certificates)
|
||||
[](https://github.com/smallstep/certificates)
|
||||
[](https://travis-ci.com/smallstep/certificates)
|
||||
[](https://opensource.org/licenses/Apache-2.0)
|
||||
[](https://cla-assistant.io/smallstep/certificates)
|
||||
|
||||
|
@ -54,7 +54,7 @@ Setting up a *public key infrastructure* (PKI) is out of reach for many small te
|
|||
- [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation
|
||||
- Capable of high availability (HA) deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries
|
||||
- Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca)
|
||||
- [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
|
||||
- [Badger, BoltDB, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
|
||||
|
||||
### ⚙️ Many ways to automate
|
||||
|
||||
|
@ -68,7 +68,6 @@ You can issue certificates in exchange for:
|
|||
- [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure
|
||||
- [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
|
||||
- A trusted X.509 certificate (X5C provisioner)
|
||||
- A host certificate from your Nebula network
|
||||
- A SCEP challenge (SCEP provisioner)
|
||||
- An SSH host certificates needing renewal (the SSHPOP provisioner)
|
||||
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners)
|
||||
|
@ -119,12 +118,18 @@ See our installation docs [here](https://smallstep.com/docs/step-ca/installation
|
|||
|
||||
## Documentation
|
||||
|
||||
* [Official documentation](https://smallstep.com/docs/step-ca) is on smallstep.com
|
||||
* The `step` command reference is available via `step help`,
|
||||
[on smallstep.com](https://smallstep.com/docs/step-cli/reference/),
|
||||
or by running `step help --http=:8080` from the command line
|
||||
Documentation can be found in a handful of different places:
|
||||
|
||||
1. On the web at https://smallstep.com/docs/step-ca.
|
||||
|
||||
2. On the command line with `step help ca xxx` where `xxx` is the subcommand
|
||||
you are interested in. Ex: `step help ca provisioner list`.
|
||||
|
||||
3. In your browser, by running `step help --http=:8080 ca` from the command line
|
||||
and visiting http://localhost:8080.
|
||||
|
||||
4. The [docs](./docs/README.md) folder is being deprecated, but it still has some documentation and tutorials.
|
||||
|
||||
## Feedback?
|
||||
|
||||
* Tell us what you like and don't like about managing your PKI - we're eager to help solve problems in this space.
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
We appreciate any effort to discover and disclose security vulnerabilities responsibly.
|
||||
|
||||
If you would like to report a vulnerability in one of our projects, or have security concerns regarding Smallstep software, please email security@smallstep.com.
|
||||
|
||||
In order for us to best respond to your report, please include any of the following:
|
||||
* Steps to reproduce or proof-of-concept
|
||||
* Any relevant tools, including versions used
|
||||
* Tool output
|
|
@ -7,8 +7,6 @@ import (
|
|||
"time"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
)
|
||||
|
||||
// Account is a subset of the internal account type containing only those
|
||||
|
@ -20,16 +18,6 @@ type Account struct {
|
|||
Status Status `json:"status"`
|
||||
OrdersURL string `json:"orders"`
|
||||
ExternalAccountBinding interface{} `json:"externalAccountBinding,omitempty"`
|
||||
LocationPrefix string `json:"-"`
|
||||
ProvisionerName string `json:"-"`
|
||||
}
|
||||
|
||||
// GetLocation returns the URL location of the given account.
|
||||
func (a *Account) GetLocation() string {
|
||||
if a.LocationPrefix == "" {
|
||||
return ""
|
||||
}
|
||||
return a.LocationPrefix + a.ID
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
|
@ -43,7 +31,7 @@ func (a *Account) ToLog() (interface{}, error) {
|
|||
|
||||
// IsValid returns true if the Account is valid.
|
||||
func (a *Account) IsValid() bool {
|
||||
return a.Status == StatusValid
|
||||
return Status(a.Status) == StatusValid
|
||||
}
|
||||
|
||||
// KeyToID converts a JWK to a thumbprint.
|
||||
|
@ -55,64 +43,15 @@ func KeyToID(jwk *jose.JSONWebKey) (string, error) {
|
|||
return base64.RawURLEncoding.EncodeToString(kid), nil
|
||||
}
|
||||
|
||||
// PolicyNames contains ACME account level policy names
|
||||
type PolicyNames struct {
|
||||
DNSNames []string `json:"dns"`
|
||||
IPRanges []string `json:"ips"`
|
||||
}
|
||||
|
||||
// X509Policy contains ACME account level X.509 policy
|
||||
type X509Policy struct {
|
||||
Allowed PolicyNames `json:"allow"`
|
||||
Denied PolicyNames `json:"deny"`
|
||||
AllowWildcardNames bool `json:"allowWildcardNames"`
|
||||
}
|
||||
|
||||
// Policy is an ACME Account level policy
|
||||
type Policy struct {
|
||||
X509 X509Policy `json:"x509"`
|
||||
}
|
||||
|
||||
func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &policy.X509NameOptions{
|
||||
DNSDomains: p.X509.Allowed.DNSNames,
|
||||
IPRanges: p.X509.Allowed.IPRanges,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &policy.X509NameOptions{
|
||||
DNSDomains: p.X509.Denied.DNSNames,
|
||||
IPRanges: p.X509.Denied.IPRanges,
|
||||
}
|
||||
}
|
||||
|
||||
// AreWildcardNamesAllowed returns if wildcard names
|
||||
// like *.example.com are allowed to be signed.
|
||||
// Defaults to false.
|
||||
func (p *Policy) AreWildcardNamesAllowed() bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
return p.X509.AllowWildcardNames
|
||||
}
|
||||
|
||||
// ExternalAccountKey is an ACME External Account Binding key.
|
||||
type ExternalAccountKey struct {
|
||||
ID string `json:"id"`
|
||||
ProvisionerID string `json:"provisionerID"`
|
||||
Reference string `json:"reference"`
|
||||
AccountID string `json:"-"`
|
||||
HmacKey []byte `json:"-"`
|
||||
KeyBytes []byte `json:"-"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
BoundAt time.Time `json:"boundAt,omitempty"`
|
||||
Policy *Policy `json:"policy,omitempty"`
|
||||
}
|
||||
|
||||
// AlreadyBound returns whether this EAK is already bound to
|
||||
|
@ -129,6 +68,6 @@ func (eak *ExternalAccountKey) BindTo(account *Account) error {
|
|||
}
|
||||
eak.AccountID = account.ID
|
||||
eak.BoundAt = time.Now()
|
||||
eak.HmacKey = []byte{} // clearing the key bytes; can only be used once
|
||||
eak.KeyBytes = []byte{} // clearing the key bytes; can only be used once
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -7,9 +7,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
func TestKeyToID(t *testing.T) {
|
||||
|
@ -46,14 +45,14 @@ func TestKeyToID(t *testing.T) {
|
|||
tc := run(t)
|
||||
if id, err := KeyToID(tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
var k *Error
|
||||
if errors.As(err, &k) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
assert.Equals(t, k.Type, tc.err.Type)
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
assert.Equals(t, k.Status, tc.err.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
} else {
|
||||
default:
|
||||
assert.FatalError(t, errors.New("unexpected error type"))
|
||||
}
|
||||
}
|
||||
|
@ -66,23 +65,6 @@ func TestKeyToID(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAccount_GetLocation(t *testing.T) {
|
||||
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
|
||||
type test struct {
|
||||
acc *Account
|
||||
exp string
|
||||
}
|
||||
tests := map[string]test{
|
||||
"empty": {acc: &Account{LocationPrefix: ""}, exp: ""},
|
||||
"not-empty": {acc: &Account{ID: "bar", LocationPrefix: locationPrefix}, exp: locationPrefix + "bar"},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert.Equals(t, tc.acc.GetLocation(), tc.exp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccount_IsValid(t *testing.T) {
|
||||
type test struct {
|
||||
acc *Account
|
||||
|
@ -113,7 +95,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
},
|
||||
acct: &Account{
|
||||
ID: "accountID",
|
||||
|
@ -126,7 +108,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
AccountID: "someAccountID",
|
||||
BoundAt: boundAt,
|
||||
},
|
||||
|
@ -148,15 +130,15 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
|
|||
}
|
||||
if wantErr {
|
||||
assert.NotNil(t, err)
|
||||
var ae *Error
|
||||
if assert.True(t, errors.As(err, &ae)) {
|
||||
assert.Equals(t, ae.Type, tt.err.Type)
|
||||
assert.Equals(t, ae.Detail, tt.err.Detail)
|
||||
assert.Equals(t, ae.Subproblems, tt.err.Subproblems)
|
||||
}
|
||||
assert.Type(t, &Error{}, err)
|
||||
ae, _ := err.(*Error)
|
||||
assert.Equals(t, ae.Type, tt.err.Type)
|
||||
assert.Equals(t, ae.Detail, tt.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tt.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tt.err.Subproblems)
|
||||
} else {
|
||||
assert.Equals(t, eak.AccountID, acct.ID)
|
||||
assert.Equals(t, eak.HmacKey, []byte{})
|
||||
assert.Equals(t, eak.KeyBytes, []byte{})
|
||||
assert.NotNil(t, eak.BoundAt)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
@ -68,18 +66,9 @@ func (u *UpdateAccountRequest) Validate() error {
|
|||
}
|
||||
}
|
||||
|
||||
// getAccountLocationPath returns the current account URL location.
|
||||
// Returned location will be of the form: https://<ca-url>/acme/<provisioner>/account/<accID>
|
||||
func getAccountLocationPath(ctx context.Context, linker acme.Linker, accID string) string {
|
||||
return linker.GetLink(ctx, acme.AccountLinkType, accID)
|
||||
}
|
||||
|
||||
// NewAccount is the handler resource for creating new ACME accounts.
|
||||
func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -105,8 +94,8 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
httpStatus := http.StatusCreated
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if !errors.As(err, &acmeErr) || acmeErr.Status != http.StatusBadRequest {
|
||||
acmeErr, ok := err.(*acme.Error)
|
||||
if !ok || acmeErr.Status != http.StatusBadRequest {
|
||||
// Something went wrong ...
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -125,30 +114,29 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
eak, err := validateExternalAccountBinding(ctx, &nar)
|
||||
eak, err := h.validateExternalAccountBinding(ctx, &nar)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
acc = &acme.Account{
|
||||
Key: jwk,
|
||||
Contact: nar.Contact,
|
||||
Status: acme.StatusValid,
|
||||
LocationPrefix: getAccountLocationPath(ctx, linker, ""),
|
||||
ProvisionerName: prov.GetName(),
|
||||
Key: jwk,
|
||||
Contact: nar.Contact,
|
||||
Status: acme.StatusValid,
|
||||
}
|
||||
if err := db.CreateAccount(ctx, acc); err != nil {
|
||||
if err := h.db.CreateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
|
||||
return
|
||||
}
|
||||
|
||||
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
|
||||
if err := eak.BindTo(acc); err != nil {
|
||||
err := eak.BindTo(acc)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
||||
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
|
||||
return
|
||||
}
|
||||
|
@ -159,18 +147,15 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
httpStatus = http.StatusOK
|
||||
}
|
||||
|
||||
linker.LinkAccount(ctx, acc)
|
||||
h.linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", getAccountLocationPath(ctx, linker, acc.ID))
|
||||
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
|
||||
render.JSONStatus(w, acc, httpStatus)
|
||||
}
|
||||
|
||||
// GetOrUpdateAccount is the api for updating an ACME account.
|
||||
func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -202,16 +187,16 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
|||
acc.Contact = uar.Contact
|
||||
}
|
||||
|
||||
if err := db.UpdateAccount(ctx, acc); err != nil {
|
||||
if err := h.db.UpdateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
linker.LinkAccount(ctx, acc)
|
||||
h.linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
|
||||
render.JSON(w, acc)
|
||||
}
|
||||
|
||||
|
@ -225,11 +210,8 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
|||
}
|
||||
|
||||
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
|
||||
func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -240,14 +222,13 @@ func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
||||
return
|
||||
}
|
||||
|
||||
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
|
||||
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrdersByAccountID(ctx, orders)
|
||||
h.linker.LinkOrdersByAccountID(ctx, orders)
|
||||
|
||||
render.JSON(w, orders)
|
||||
logOrdersByAccount(w, orders)
|
||||
|
|
|
@ -3,7 +3,6 @@ package api
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -14,12 +13,10 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -32,27 +29,6 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
type fakeProvisioner struct{}
|
||||
|
||||
func (*fakeProvisioner) AuthorizeOrderIdentifier(context.Context, provisioner.ACMEIdentifier) error {
|
||||
return nil
|
||||
}
|
||||
func (*fakeProvisioner) AuthorizeSign(context.Context, string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (*fakeProvisioner) IsChallengeEnabled(context.Context, provisioner.ACMEChallenge) bool {
|
||||
return true
|
||||
}
|
||||
func (*fakeProvisioner) IsAttestationFormatEnabled(context.Context, provisioner.ACMEAttestationFormat) bool {
|
||||
return true
|
||||
}
|
||||
func (*fakeProvisioner) GetAttestationRoots() (*x509.CertPool, bool) { return nil, false }
|
||||
func (*fakeProvisioner) AuthorizeRevoke(context.Context, string) error { return nil }
|
||||
func (*fakeProvisioner) GetID() string { return "" }
|
||||
func (*fakeProvisioner) GetName() string { return "" }
|
||||
func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 }
|
||||
func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil }
|
||||
|
||||
func newProv() acme.Provisioner {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
|
@ -65,19 +41,6 @@ func newProv() acme.Provisioner {
|
|||
return p
|
||||
}
|
||||
|
||||
func newProvWithOptions(options *provisioner.Options) acme.Provisioner {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-<test>provisioner.com",
|
||||
Options: options,
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func newACMEProv(t *testing.T) *provisioner.ACME {
|
||||
p := newProv()
|
||||
a, ok := p.(*provisioner.ACME)
|
||||
|
@ -87,15 +50,6 @@ func newACMEProv(t *testing.T) *provisioner.ACME {
|
|||
return a
|
||||
}
|
||||
|
||||
func newACMEProvWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME {
|
||||
p := newProvWithOptions(options)
|
||||
a, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
t.Fatal("not a valid ACME provisioner")
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func createEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) (*jose.JSONWebSignature, error) {
|
||||
signer, err := jose.NewSigner(
|
||||
jose.SigningKey{
|
||||
|
@ -190,12 +144,11 @@ func TestNewAccountRequest_Validate(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.nar.Validate(); err != nil {
|
||||
if assert.NotNil(t, err) {
|
||||
var ae *acme.Error
|
||||
if assert.True(t, errors.As(err, &ae)) {
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
ae, ok := err.(*acme.Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
|
@ -262,12 +215,11 @@ func TestUpdateAccountRequest_Validate(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.uar.Validate(); err != nil {
|
||||
if assert.NotNil(t, err) {
|
||||
var ae *acme.Error
|
||||
if assert.True(t, errors.As(err, &ae)) {
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
ae, ok := err.(*acme.Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
|
@ -344,9 +296,10 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: accID}
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
|
||||
|
@ -362,11 +315,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
GetOrdersByAccountID(w, req)
|
||||
h.GetOrdersByAccountID(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -381,6 +334,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -409,7 +363,6 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -418,7 +371,6 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
"fail/nil-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -427,7 +379,6 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "failed to "+
|
||||
|
@ -442,7 +393,6 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||
|
@ -455,9 +405,8 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -469,10 +418,9 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jwk expected in request context"),
|
||||
|
@ -484,11 +432,10 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jwk expected in request context"),
|
||||
|
@ -507,9 +454,9 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"),
|
||||
|
@ -524,7 +471,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -554,11 +501,18 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
scepProvisioner := &provisioner.SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "test@scep-<test>provisioner.com",
|
||||
}
|
||||
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"),
|
||||
|
@ -597,13 +551,14 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
eak := &acme.ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
return test{
|
||||
|
@ -644,7 +599,8 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
||||
|
@ -679,11 +635,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
acc: acc,
|
||||
statusCode: 200,
|
||||
|
@ -708,7 +664,8 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = false
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
||||
|
@ -762,7 +719,8 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -777,7 +735,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
},
|
||||
|
@ -801,11 +759,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
NewAccount(w, req)
|
||||
h.NewAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -820,6 +778,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -855,7 +814,6 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -864,7 +822,6 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -873,7 +830,6 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
"fail/no-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -883,7 +839,6 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -893,7 +848,6 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
|
||||
|
@ -908,7 +862,6 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||
|
@ -941,9 +894,10 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||
|
@ -960,11 +914,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
uar := &UpdateAccountRequest{}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
|
@ -975,9 +929,10 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||
|
@ -991,11 +946,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok/post-as-get": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
|
@ -1004,11 +959,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
GetOrUpdateAccount(w, req)
|
||||
h.GetOrUpdateAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1023,6 +978,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
|
|
@ -3,11 +3,9 @@ package api
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
// ExternalAccountBinding represents the ACME externalAccountBinding JWS
|
||||
|
@ -18,14 +16,13 @@ type ExternalAccountBinding struct {
|
|||
}
|
||||
|
||||
// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account.
|
||||
func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
|
||||
func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context")
|
||||
}
|
||||
|
||||
if !acmeProv.RequireEAB {
|
||||
//nolint:nilnil // legacy
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
@ -50,29 +47,19 @@ func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest)
|
|||
return nil, acmeErr
|
||||
}
|
||||
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
||||
externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
||||
if err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
if _, ok := err.(*acme.Error); ok {
|
||||
return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key")
|
||||
}
|
||||
return nil, acme.WrapErrorISE(err, "error retrieving external account key")
|
||||
}
|
||||
|
||||
if externalAccountKey == nil {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key")
|
||||
}
|
||||
|
||||
if len(externalAccountKey.HmacKey) == 0 {
|
||||
return nil, acme.NewError(acme.ErrorServerInternalType, "external account binding key with id '%s' does not have secret bytes", keyID)
|
||||
}
|
||||
|
||||
if externalAccountKey.AlreadyBound() {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt)
|
||||
}
|
||||
|
||||
payload, err := eabJWS.Verify(externalAccountKey.HmacKey)
|
||||
payload, err := eabJWS.Verify(externalAccountKey.KeyBytes)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature")
|
||||
}
|
||||
|
@ -110,12 +97,12 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool {
|
|||
|
||||
// validateEABJWS verifies the contents of the External Account Binding JWS.
|
||||
// The protected header of the JWS MUST meet the following criteria:
|
||||
//
|
||||
// - The "alg" field MUST indicate a MAC-based algorithm
|
||||
// - The "kid" field MUST contain the key identifier provided by the CA
|
||||
// - The "nonce" field MUST NOT be present
|
||||
// - The "url" field MUST be set to the same value as the outer JWS
|
||||
// o The "alg" field MUST indicate a MAC-based algorithm
|
||||
// o The "kid" field MUST contain the key identifier provided by the CA
|
||||
// o The "nonce" field MUST NOT be present
|
||||
// o The "url" field MUST be set to the same value as the outer JWS
|
||||
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
|
||||
|
||||
if jws == nil {
|
||||
return "", acme.NewErrorISE("no JWS provided")
|
||||
}
|
||||
|
|
|
@ -9,11 +9,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
func Test_keysAreEqual(t *testing.T) {
|
||||
|
@ -99,7 +98,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
prov := newACMEProv(t)
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
|
@ -143,7 +143,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
createdAt := time.Now()
|
||||
return test{
|
||||
|
@ -153,7 +154,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: createdAt,
|
||||
}, nil
|
||||
},
|
||||
|
@ -167,7 +168,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: createdAt,
|
||||
},
|
||||
err: nil,
|
||||
|
@ -188,10 +189,17 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
scepProvisioner := &provisioner.SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "test@scep-<test>provisioner.com",
|
||||
}
|
||||
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"),
|
||||
|
@ -210,7 +218,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
|
@ -255,7 +264,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
|
@ -300,7 +310,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -347,7 +358,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -396,7 +408,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -413,112 +426,6 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
err: acme.NewErrorISE("error retrieving external account key"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetExternalAccountKey-nil": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
|
||||
rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
|
||||
assert.FatalError(t, err)
|
||||
eab := &ExternalAccountBinding{}
|
||||
err = json.Unmarshal(rawEABJWS, &eab)
|
||||
assert.FatalError(t, err)
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
ExternalAccountBinding: eab,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
|
||||
so.WithHeader("url", url)
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||
Key: jwk.Key,
|
||||
}, so)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := signer.Sign(payloadBytes)
|
||||
assert.FatalError(t, err)
|
||||
raw, err := jws.CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
nar: &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
ExternalAccountBinding: eab,
|
||||
},
|
||||
eak: nil,
|
||||
err: acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetExternalAccountKey-no-keybytes": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
|
||||
rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
|
||||
assert.FatalError(t, err)
|
||||
eab := &ExternalAccountBinding{}
|
||||
err = json.Unmarshal(rawEABJWS, &eab)
|
||||
assert.FatalError(t, err)
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
ExternalAccountBinding: eab,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
|
||||
so.WithHeader("url", url)
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||
Key: jwk.Key,
|
||||
}, so)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := signer.Sign(payloadBytes)
|
||||
assert.FatalError(t, err)
|
||||
raw, err := jws.CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
createdAt := time.Now()
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
return &acme.ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
CreatedAt: createdAt,
|
||||
AccountID: "some-account-id",
|
||||
HmacKey: []byte{},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
nar: &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
ExternalAccountBinding: eab,
|
||||
},
|
||||
eak: nil,
|
||||
err: acme.NewError(acme.ErrorServerInternalType, "external account binding key with id 'eakID' does not have secret bytes"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetExternalAccountKey-wrong-provisioner": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -551,7 +458,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -598,7 +506,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
createdAt := time.Now()
|
||||
boundAt := time.Now().Add(1 * time.Second)
|
||||
|
@ -611,7 +520,6 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
Reference: "testeak",
|
||||
CreatedAt: createdAt,
|
||||
AccountID: "some-account-id",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
BoundAt: boundAt,
|
||||
}, nil
|
||||
},
|
||||
|
@ -657,7 +565,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -666,7 +575,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
HmacKey: []byte{1, 2, 3, 4},
|
||||
KeyBytes: []byte{1, 2, 3, 4},
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
},
|
||||
|
@ -714,7 +623,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -723,7 +633,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
},
|
||||
|
@ -768,7 +678,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -777,7 +688,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
},
|
||||
|
@ -823,7 +734,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, nil)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -832,7 +744,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
},
|
||||
|
@ -850,8 +762,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
|
||||
got, err := validateExternalAccountBinding(ctx, tc.nar)
|
||||
h := &Handler{
|
||||
db: tc.db,
|
||||
}
|
||||
got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar)
|
||||
wantErr := tc.err != nil
|
||||
gotErr := err != nil
|
||||
if wantErr != gotErr {
|
||||
|
@ -860,21 +774,20 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
if wantErr {
|
||||
assert.NotNil(t, err)
|
||||
assert.Type(t, &acme.Error{}, err)
|
||||
var ae *acme.Error
|
||||
if assert.True(t, errors.As(err, &ae)) {
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Status, tc.err.Status)
|
||||
assert.HasPrefix(t, ae.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
}
|
||||
ae, _ := err.(*acme.Error)
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Status, tc.err.Status)
|
||||
assert.HasPrefix(t, ae.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
} else {
|
||||
if got == nil {
|
||||
assert.Nil(t, tc.eak)
|
||||
} else {
|
||||
assert.NotNil(t, tc.eak)
|
||||
assert.Equals(t, got.ID, tc.eak.ID)
|
||||
assert.Equals(t, got.HmacKey, tc.eak.HmacKey)
|
||||
assert.Equals(t, got.KeyBytes, tc.eak.KeyBytes)
|
||||
assert.Equals(t, got.ProvisionerID, tc.eak.ProvisionerID)
|
||||
assert.Equals(t, got.Reference, tc.eak.Reference)
|
||||
assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt)
|
||||
|
@ -1144,6 +1057,7 @@ func Test_validateEABJWS(t *testing.T) {
|
|||
assert.Equals(t, tc.err.Status, err.Status)
|
||||
assert.HasPrefix(t, err.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
assert.Equals(t, tc.err.Identifier, err.Identifier)
|
||||
assert.Equals(t, tc.err.Subproblems, err.Subproblems)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
|
|
|
@ -2,10 +2,12 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -14,7 +16,6 @@ import (
|
|||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
|
@ -38,152 +39,111 @@ type payloadInfo struct {
|
|||
isEmptyJSON bool
|
||||
}
|
||||
|
||||
// Handler is the ACME API request handler.
|
||||
type Handler struct {
|
||||
db acme.DB
|
||||
backdate provisioner.Duration
|
||||
ca acme.CertificateAuthority
|
||||
linker Linker
|
||||
validateChallengeOptions *acme.ValidateChallengeOptions
|
||||
prerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
}
|
||||
|
||||
// HandlerOptions required to create a new ACME API request handler.
|
||||
type HandlerOptions struct {
|
||||
// DB storage backend that implements the acme.DB interface.
|
||||
//
|
||||
// Deprecated: use acme.NewContex(context.Context, acme.DB)
|
||||
DB acme.DB
|
||||
|
||||
// CA is the certificate authority interface.
|
||||
//
|
||||
// Deprecated: use authority.NewContext(context.Context, *authority.Authority)
|
||||
CA acme.CertificateAuthority
|
||||
|
||||
// Backdate is the duration that the CA will subtract from the current time
|
||||
// to set the NotBefore in the certificate.
|
||||
Backdate provisioner.Duration
|
||||
|
||||
// DB storage backend that impements the acme.DB interface.
|
||||
DB acme.DB
|
||||
// DNS the host used to generate accurate ACME links. By default the authority
|
||||
// will use the Host from the request, so this value will only be used if
|
||||
// request.Host is empty.
|
||||
DNS string
|
||||
|
||||
// Prefix is a URL path prefix under which the ACME api is served. This
|
||||
// prefix is required to generate accurate ACME links.
|
||||
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
|
||||
// "acme" is the prefix from which the ACME api is accessed.
|
||||
Prefix string
|
||||
|
||||
CA acme.CertificateAuthority
|
||||
// PrerequisitesChecker checks if all prerequisites for serving ACME are
|
||||
// met by the CA configuration.
|
||||
PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
}
|
||||
|
||||
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// handler is the ACME API request handler.
|
||||
type handler struct {
|
||||
opts *HandlerOptions
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface. For backward compatibility
|
||||
// this route adds will add a new middleware that will set the ACME components
|
||||
// on the context.
|
||||
//
|
||||
// Note: this method is deprecated in step-ca, other applications can still use
|
||||
// this to support ACME, but the recommendation is to use use
|
||||
// api.Route(api.Router) and acme.NewContext() instead.
|
||||
func (h *handler) Route(r api.Router) {
|
||||
client := acme.NewClient()
|
||||
linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix)
|
||||
route(r, func(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
|
||||
ctx = authority.NewContext(ctx, ca)
|
||||
}
|
||||
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// NewHandler returns a new ACME API handler.
|
||||
//
|
||||
// Note: this method is deprecated in step-ca, other applications can still use
|
||||
// this to support ACME, but the recommendation is to use use
|
||||
// api.Route(api.Router) and acme.NewContext() instead.
|
||||
func NewHandler(opts HandlerOptions) api.RouterHandler {
|
||||
return &handler{
|
||||
opts: &opts,
|
||||
func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
prerequisitesChecker := func(ctx context.Context) (bool, error) {
|
||||
// by default all prerequisites are met
|
||||
return true, nil
|
||||
}
|
||||
if ops.PrerequisitesChecker != nil {
|
||||
prerequisitesChecker = ops.PrerequisitesChecker
|
||||
}
|
||||
return &Handler{
|
||||
ca: ops.CA,
|
||||
db: ops.DB,
|
||||
backdate: ops.Backdate,
|
||||
linker: NewLinker(ops.DNS, ops.Prefix),
|
||||
validateChallengeOptions: &acme.ValidateChallengeOptions{
|
||||
HTTPGet: client.Get,
|
||||
LookupTxt: net.LookupTXT,
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(dialer, network, addr, config)
|
||||
},
|
||||
},
|
||||
prerequisitesChecker: prerequisitesChecker,
|
||||
}
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface. This method requires that
|
||||
// all the acme components, authority, db, client, linker, and prerequisite
|
||||
// checker to be present in the context.
|
||||
func Route(r api.Router) {
|
||||
route(r, nil)
|
||||
}
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
getPath := h.linker.GetUnescapedPathSuffix
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
||||
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
||||
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
||||
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
||||
|
||||
func route(r api.Router, middleware func(next nextHTTP) nextHTTP) {
|
||||
commonMiddleware := func(next nextHTTP) nextHTTP {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
// Linker middleware gets the provisioner and current url from the
|
||||
// request and sets them in the context.
|
||||
linker := acme.MustLinkerFromContext(r.Context())
|
||||
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
|
||||
}
|
||||
if middleware != nil {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
return handler
|
||||
}
|
||||
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
||||
return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
|
||||
return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))))
|
||||
}
|
||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
|
||||
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
|
||||
return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
|
||||
return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
|
||||
getPath := acme.GetUnescapedPathSuffix
|
||||
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||
r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||
r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||
commonMiddleware(GetDirectory))
|
||||
r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||
commonMiddleware(GetDirectory))
|
||||
|
||||
r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
|
||||
extractPayloadByJWK(NewAccount))
|
||||
r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(GetOrUpdateAccount))
|
||||
r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(NotImplemented))
|
||||
r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
|
||||
extractPayloadByKid(NewOrder))
|
||||
r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetOrder)))
|
||||
r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
|
||||
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
|
||||
extractPayloadByKid(FinalizeOrder))
|
||||
r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetAuthorization)))
|
||||
r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
|
||||
extractPayloadByKid(GetChallenge))
|
||||
r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetCertificate)))
|
||||
r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
|
||||
extractPayloadByKidOrJWK(RevokeCert))
|
||||
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
|
||||
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
|
||||
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented))
|
||||
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder))
|
||||
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
|
||||
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID)))
|
||||
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
|
||||
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
|
||||
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
|
||||
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
||||
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert))
|
||||
}
|
||||
|
||||
// GetNonce just sets the right header since a Nonce is added to each response
|
||||
// by middleware by default.
|
||||
func GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
|
@ -205,7 +165,7 @@ type Directory struct {
|
|||
NewOrder string `json:"newOrder"`
|
||||
RevokeCert string `json:"revokeCert"`
|
||||
KeyChange string `json:"keyChange"`
|
||||
Meta *Meta `json:"meta,omitempty"`
|
||||
Meta Meta `json:"meta"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging for the Directory type.
|
||||
|
@ -219,7 +179,7 @@ func (d *Directory) ToLog() (interface{}, error) {
|
|||
|
||||
// GetDirectory is the ACME resource for returning a directory configuration
|
||||
// for client configuration.
|
||||
func GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
|
@ -227,68 +187,33 @@ func GetDirectory(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
render.JSON(w, &Directory{
|
||||
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
|
||||
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
|
||||
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
|
||||
RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
|
||||
KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
|
||||
Meta: createMetaObject(acmeProv),
|
||||
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
|
||||
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
|
||||
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
|
||||
RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType),
|
||||
KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType),
|
||||
Meta: Meta{
|
||||
ExternalAccountRequired: acmeProv.RequireEAB,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// createMetaObject creates a Meta object if the ACME provisioner
|
||||
// has one or more properties that are written in the ACME directory output.
|
||||
// It returns nil if none of the properties are set.
|
||||
func createMetaObject(p *provisioner.ACME) *Meta {
|
||||
if shouldAddMetaObject(p) {
|
||||
return &Meta{
|
||||
TermsOfService: p.TermsOfService,
|
||||
Website: p.Website,
|
||||
CaaIdentities: p.CaaIdentities,
|
||||
ExternalAccountRequired: p.RequireEAB,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldAddMetaObject returns whether or not the ACME provisioner
|
||||
// has properties configured that must be added to the ACME directory object.
|
||||
func shouldAddMetaObject(p *provisioner.ACME) bool {
|
||||
switch {
|
||||
case p.TermsOfService != "":
|
||||
return true
|
||||
case p.Website != "":
|
||||
return true
|
||||
case len(p.CaaIdentities) > 0:
|
||||
return true
|
||||
case p.RequireEAB:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
||||
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
||||
func NotImplemented(w http.ResponseWriter, _ *http.Request) {
|
||||
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
||||
}
|
||||
|
||||
// GetAuthorization ACME api for retrieving an Authz.
|
||||
func GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
|
||||
return
|
||||
|
@ -298,43 +223,41 @@ func GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
|||
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
|
||||
return
|
||||
}
|
||||
if err = az.UpdateStatus(ctx, db); err != nil {
|
||||
if err = az.UpdateStatus(ctx, h.db); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkAuthorization(ctx, az)
|
||||
h.linker.LinkAuthorization(ctx, az)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
|
||||
render.JSON(w, az)
|
||||
}
|
||||
|
||||
// GetChallenge ACME api for retrieving a Challenge.
|
||||
func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
// Just verify that the payload was set, since we're not strictly adhering
|
||||
// to ACME V2 spec for reasons specified below.
|
||||
_, err = payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// NOTE: We should be checking that the request is either a POST-as-GET, or
|
||||
// that for all challenges except for device-attest-01, the payload is an
|
||||
// empty JSON block ({}). However, older ACME clients still send a vestigial
|
||||
// body (rather than an empty JSON block) and strict enforcement would
|
||||
// render these clients broken.
|
||||
// NOTE: We should be checking ^^^ that the request is either a POST-as-GET, or
|
||||
// that the payload is an empty JSON block ({}). However, older ACME clients
|
||||
// still send a vestigial body (rather than an empty JSON block) and
|
||||
// strict enforcement would render these clients broken. For the time being
|
||||
// we'll just ignore the body.
|
||||
|
||||
azID := chi.URLParam(r, "authzID")
|
||||
ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
|
||||
return
|
||||
|
@ -350,31 +273,29 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if err = ch.Validate(ctx, db, jwk, payload.value); err != nil {
|
||||
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkChallenge(ctx, ch, azID)
|
||||
h.linker.LinkChallenge(ctx, ch, azID)
|
||||
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
|
||||
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
|
||||
render.JSON(w, ch)
|
||||
}
|
||||
|
||||
// GetCertificate ACME api for retrieving a Certificate.
|
||||
func GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
certID := chi.URLParam(r, "certID")
|
||||
cert, err := db.GetCertificate(ctx, certID)
|
||||
|
||||
cert, err := h.db.GetCertificate(ctx, certID)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
|
||||
return
|
||||
|
@ -394,6 +315,6 @@ func GetCertificate(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
api.LogCertificate(w, cert.Leaf)
|
||||
w.Header().Set("Content-Type", "application/pem-certificate-chain")
|
||||
w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
|
||||
w.Write(certBytes)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package api
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
|
@ -18,38 +17,13 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
get func(url string) (*http.Response, error)
|
||||
lookupTxt func(name string) ([]string, error)
|
||||
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) }
|
||||
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
|
||||
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return m.tlsDial(network, addr, config)
|
||||
}
|
||||
|
||||
func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) {
|
||||
t.Helper()
|
||||
fn := mustAuthority
|
||||
t.Cleanup(func() {
|
||||
mustAuthority = fn
|
||||
})
|
||||
mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_GetNonce(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -64,10 +38,10 @@ func TestHandler_GetNonce(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// h := &Handler{}
|
||||
h := &Handler{}
|
||||
w := httptest.NewRecorder()
|
||||
req.Method = tt.name
|
||||
GetNonce(w, req)
|
||||
h.GetNonce(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -78,8 +52,7 @@ func TestHandler_GetNonce(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandler_GetDirectory(t *testing.T) {
|
||||
linker := acme.NewLinker("ca.smallstep.com", "acme")
|
||||
_ = linker
|
||||
linker := NewLinker("ca.smallstep.com", "acme")
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
|
@ -88,14 +61,23 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner is not in context"),
|
||||
err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"),
|
||||
}
|
||||
},
|
||||
"fail/different-provisioner": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{})
|
||||
prov := &provisioner.SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "test@scep-<test>provisioner.com",
|
||||
}
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
|
@ -106,7 +88,8 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
expDir := Directory{
|
||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||
|
@ -125,42 +108,15 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
expDir := Directory{
|
||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||
NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
|
||||
RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName),
|
||||
KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName),
|
||||
Meta: &Meta{
|
||||
ExternalAccountRequired: true,
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
dir: expDir,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/full-meta": func(t *testing.T) test {
|
||||
prov := newACMEProv(t)
|
||||
prov.TermsOfService = "https://terms.ca.local/"
|
||||
prov.Website = "https://ca.local/"
|
||||
prov.CaaIdentities = []string{"ca.local"}
|
||||
prov.RequireEAB = true
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
expDir := Directory{
|
||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||
NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
|
||||
RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName),
|
||||
KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName),
|
||||
Meta: &Meta{
|
||||
TermsOfService: "https://terms.ca.local/",
|
||||
Website: "https://ca.local/",
|
||||
CaaIdentities: []string{"ca.local"},
|
||||
Meta: Meta{
|
||||
ExternalAccountRequired: true,
|
||||
},
|
||||
}
|
||||
|
@ -174,11 +130,11 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
h := &Handler{linker: linker}
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
GetDirectory(w, req)
|
||||
h.GetDirectory(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -193,6 +149,7 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -262,7 +219,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
|
@ -328,9 +285,10 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
|
||||
|
@ -346,11 +304,11 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
GetAuthorization(w, req)
|
||||
h.GetAuthorization(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -365,6 +323,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -488,11 +447,11 @@ func TestHandler_GetCertificate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
|
||||
h := &Handler{db: tc.db}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
GetCertificate(w, req)
|
||||
h.GetCertificate(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -507,11 +466,12 @@ func TestHandler_GetCertificate(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.HasPrefix(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes))
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/pem-certificate-chain"})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/pem-certificate-chain; charset=utf-8"})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -531,7 +491,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
|
||||
type test struct {
|
||||
db acme.DB
|
||||
vc acme.Client
|
||||
vco *acme.ValidateChallengeOptions
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
ch *acme.Challenge
|
||||
|
@ -540,7 +500,6 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -548,7 +507,6 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), accContextKey, nil),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -558,7 +516,6 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -566,11 +523,10 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -578,7 +534,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/db.GetChallenge-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -597,7 +553,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -616,7 +572,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/no-jwk": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -635,7 +591,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/nil-jwk": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||
|
@ -655,7 +611,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/validate-challenge-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
|
@ -683,8 +639,8 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
return acme.NewErrorISE("force")
|
||||
},
|
||||
},
|
||||
vc: &mockClient{
|
||||
get: func(string) (*http.Response, error) {
|
||||
vco: &acme.ValidateChallengeOptions{
|
||||
HTTPGet: func(string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -695,13 +651,14 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
_pub := _jwk.Public()
|
||||
ctx = context.WithValue(ctx, jwkContextKey, &_pub)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -733,8 +690,8 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
URL: u,
|
||||
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
||||
},
|
||||
vc: &mockClient{
|
||||
get: func(string) (*http.Response, error) {
|
||||
vco: &acme.ValidateChallengeOptions{
|
||||
HTTPGet: func(string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -746,11 +703,11 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
GetChallenge(w, req)
|
||||
h.GetChallenge(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -765,6 +722,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -778,89 +736,3 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createMetaObject(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
p *provisioner.ACME
|
||||
want *Meta
|
||||
}{
|
||||
{
|
||||
name: "no-meta",
|
||||
p: &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "acme",
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "terms-of-service",
|
||||
p: &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "acme",
|
||||
TermsOfService: "https://terms.ca.local",
|
||||
},
|
||||
want: &Meta{
|
||||
TermsOfService: "https://terms.ca.local",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "website",
|
||||
p: &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "acme",
|
||||
Website: "https://ca.local",
|
||||
},
|
||||
want: &Meta{
|
||||
Website: "https://ca.local",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "caa",
|
||||
p: &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "acme",
|
||||
CaaIdentities: []string{"ca.local", "ca.remote"},
|
||||
},
|
||||
want: &Meta{
|
||||
CaaIdentities: []string{"ca.local", "ca.remote"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "require-eab",
|
||||
p: &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "acme",
|
||||
RequireEAB: true,
|
||||
},
|
||||
want: &Meta{
|
||||
ExternalAccountRequired: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full-meta",
|
||||
p: &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "acme",
|
||||
TermsOfService: "https://terms.ca.local",
|
||||
Website: "https://ca.local",
|
||||
CaaIdentities: []string{"ca.local", "ca.remote"},
|
||||
RequireEAB: true,
|
||||
},
|
||||
want: &Meta{
|
||||
TermsOfService: "https://terms.ca.local",
|
||||
Website: "https://ca.local",
|
||||
CaaIdentities: []string{"ca.local", "ca.remote"},
|
||||
ExternalAccountRequired: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := createMetaObject(tt.p)
|
||||
if !cmp.Equal(tt.want, got) {
|
||||
t.Errorf("createMetaObject() diff =\n%s", cmp.Diff(tt.want, got))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,19 +1,100 @@
|
|||
package acme
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
)
|
||||
|
||||
// NewLinker returns a new Directory type.
|
||||
func NewLinker(dns, prefix string) Linker {
|
||||
_, _, err := net.SplitHostPort(dns)
|
||||
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
||||
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
||||
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
||||
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
||||
// these cases, then the input dns is not changed.
|
||||
lastIndex := strings.LastIndex(dns, ":")
|
||||
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
||||
if ip := net.ParseIP(hostPart); ip != nil {
|
||||
dns = "[" + hostPart + "]:" + portPart
|
||||
} else if ip := net.ParseIP(dns); ip != nil {
|
||||
dns = "[" + dns + "]"
|
||||
}
|
||||
}
|
||||
return &linker{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Linker interface for generating links for ACME resources.
|
||||
type Linker interface {
|
||||
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
|
||||
|
||||
LinkOrder(ctx context.Context, o *acme.Order)
|
||||
LinkAccount(ctx context.Context, o *acme.Account)
|
||||
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
|
||||
LinkAuthorization(ctx context.Context, o *acme.Authorization)
|
||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||
}
|
||||
|
||||
// linker generates ACME links.
|
||||
type linker struct {
|
||||
prefix string
|
||||
dns string
|
||||
}
|
||||
|
||||
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||
switch typ {
|
||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||
case ChallengeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||
case OrdersByAccountLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||
case FinalizeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// GetLink is a helper for GetLinkExplicit
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||
var (
|
||||
provName string
|
||||
baseURL = baseURLFromContext(ctx)
|
||||
u = url.URL{}
|
||||
)
|
||||
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
|
||||
provName = p.GetName()
|
||||
}
|
||||
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
|
||||
if baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
|
||||
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
|
||||
|
||||
// If no Scheme is set, then default to https.
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
|
||||
// If no Host is set, then use the default (first DNS attr in the ca.json).
|
||||
if u.Host == "" {
|
||||
u.Host = l.dns
|
||||
}
|
||||
|
||||
u.Path = l.prefix + u.Path
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// LinkType captures the link type.
|
||||
type LinkType int
|
||||
|
||||
|
@ -79,155 +160,8 @@ func (l LinkType) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||
switch typ {
|
||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||
case ChallengeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||
case OrdersByAccountLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||
case FinalizeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// NewLinker returns a new Directory type.
|
||||
func NewLinker(dns, prefix string) Linker {
|
||||
_, _, err := net.SplitHostPort(dns)
|
||||
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
||||
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
||||
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
||||
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
||||
// these cases, then the input dns is not changed.
|
||||
lastIndex := strings.LastIndex(dns, ":")
|
||||
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
||||
if ip := net.ParseIP(hostPart); ip != nil {
|
||||
dns = "[" + hostPart + "]:" + portPart
|
||||
} else if ip := net.ParseIP(dns); ip != nil {
|
||||
dns = "[" + dns + "]"
|
||||
}
|
||||
}
|
||||
return &linker{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Linker interface for generating links for ACME resources.
|
||||
type Linker interface {
|
||||
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||
Middleware(http.Handler) http.Handler
|
||||
LinkOrder(ctx context.Context, o *Order)
|
||||
LinkAccount(ctx context.Context, o *Account)
|
||||
LinkChallenge(ctx context.Context, o *Challenge, azID string)
|
||||
LinkAuthorization(ctx context.Context, o *Authorization)
|
||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||
}
|
||||
|
||||
type linkerKey struct{}
|
||||
|
||||
// NewLinkerContext adds the given linker to the context.
|
||||
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
|
||||
return context.WithValue(ctx, linkerKey{}, v)
|
||||
}
|
||||
|
||||
// LinkerFromContext returns the current linker from the given context.
|
||||
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
|
||||
v, ok = ctx.Value(linkerKey{}).(Linker)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current linker from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustLinkerFromContext(ctx context.Context) Linker {
|
||||
if v, ok := LinkerFromContext(ctx); !ok {
|
||||
panic("acme linker is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
type baseURLKey struct{}
|
||||
|
||||
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
|
||||
var u *url.URL
|
||||
if r.Host != "" {
|
||||
u = &url.URL{Scheme: "https", Host: r.Host}
|
||||
}
|
||||
return context.WithValue(ctx, baseURLKey{}, u)
|
||||
}
|
||||
|
||||
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
|
||||
return u
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// linker generates ACME links.
|
||||
type linker struct {
|
||||
prefix string
|
||||
dns string
|
||||
}
|
||||
|
||||
// Middleware gets the provisioner and current url from the request and sets
|
||||
// them in the context so we can use the linker to create ACME links.
|
||||
func (l *linker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add base url to the context.
|
||||
ctx := newBaseURLContext(r.Context(), r)
|
||||
|
||||
// Add provisioner to the context.
|
||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||
name, err := url.PathUnescape(nameEscaped)
|
||||
if err != nil {
|
||||
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
return
|
||||
}
|
||||
|
||||
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
acmeProv, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// GetLink is a helper for GetLinkExplicit.
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||
var name string
|
||||
if p, ok := ProvisionerFromContext(ctx); ok {
|
||||
name = p.GetName()
|
||||
}
|
||||
|
||||
var u url.URL
|
||||
if baseURL := baseURLFromContext(ctx); baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
if u.Host == "" {
|
||||
u.Host = l.dns
|
||||
}
|
||||
|
||||
u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// LinkOrder sets the ACME links required by an ACME order.
|
||||
func (l *linker) LinkOrder(ctx context.Context, o *Order) {
|
||||
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
||||
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
|
||||
for i, azID := range o.AuthorizationIDs {
|
||||
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
|
||||
|
@ -239,17 +173,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *Order) {
|
|||
}
|
||||
|
||||
// LinkAccount sets the ACME links required by an ACME account.
|
||||
func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
|
||||
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
|
||||
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
|
||||
}
|
||||
|
||||
// LinkChallenge sets the ACME links required by an ACME challenge.
|
||||
func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
|
||||
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
|
||||
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
|
||||
}
|
||||
|
||||
// LinkAuthorization sets the ACME links required by an ACME authorization.
|
||||
func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
|
||||
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
|
||||
for _, ch := range az.Challenges {
|
||||
l.LinkChallenge(ctx, ch, az.ID)
|
||||
}
|
|
@ -1,38 +1,21 @@
|
|||
package acme
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
)
|
||||
|
||||
func mockProvisioner(t *testing.T) Provisioner {
|
||||
t.Helper()
|
||||
var defaultDisableRenewal = false
|
||||
func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
prefix := "acme"
|
||||
linker := NewLinker(dns, prefix)
|
||||
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-<test>provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func TestGetUnescapedPathSuffix(t *testing.T) {
|
||||
getPath := GetUnescapedPathSuffix
|
||||
getPath := linker.GetUnescapedPathSuffix
|
||||
|
||||
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
|
||||
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
|
||||
|
@ -49,9 +32,9 @@ func TestGetUnescapedPathSuffix(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLinker_DNS(t *testing.T) {
|
||||
prov := mockProvisioner(t)
|
||||
prov := newProv()
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
type test struct {
|
||||
name string
|
||||
dns string
|
||||
|
@ -134,19 +117,19 @@ func TestLinker_GetLink(t *testing.T) {
|
|||
linker := NewLinker(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := mockProvisioner(t)
|
||||
prov := newProv()
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
|
||||
// No provisioner and no BaseURL from request
|
||||
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
|
||||
// Provisioner: yes, BaseURL: no
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
|
||||
// Provisioner: no, BaseURL: yes
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
|
@ -180,37 +163,37 @@ func TestLinker_GetLink(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkOrder(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := mockProvisioner(t)
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
|
||||
oid := "orderID"
|
||||
certID := "certID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
o *Order
|
||||
validate func(o *Order)
|
||||
o *acme.Order
|
||||
validate func(o *acme.Order)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"no-authz-and-no-cert": {
|
||||
o: &Order{
|
||||
o: &acme.Order{
|
||||
ID: oid,
|
||||
},
|
||||
validate: func(o *Order) {
|
||||
validate: func(o *acme.Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{})
|
||||
assert.Equals(t, o.CertificateURL, "")
|
||||
},
|
||||
},
|
||||
"one-authz-and-cert": {
|
||||
o: &Order{
|
||||
o: &acme.Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo"},
|
||||
},
|
||||
validate: func(o *Order) {
|
||||
validate: func(o *acme.Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
|
@ -219,12 +202,12 @@ func TestLinker_LinkOrder(t *testing.T) {
|
|||
},
|
||||
},
|
||||
"many-authz": {
|
||||
o: &Order{
|
||||
o: &acme.Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo", "bar", "zap"},
|
||||
},
|
||||
validate: func(o *Order) {
|
||||
validate: func(o *acme.Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
|
@ -245,24 +228,24 @@ func TestLinker_LinkOrder(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkAccount(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := mockProvisioner(t)
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
|
||||
accID := "accountID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
a *Account
|
||||
validate func(o *Account)
|
||||
a *acme.Account
|
||||
validate func(o *acme.Account)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
a: &Account{
|
||||
a: &acme.Account{
|
||||
ID: accID,
|
||||
},
|
||||
validate: func(a *Account) {
|
||||
validate: func(a *acme.Account) {
|
||||
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
|
||||
},
|
||||
},
|
||||
|
@ -277,25 +260,25 @@ func TestLinker_LinkAccount(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkChallenge(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := mockProvisioner(t)
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
|
||||
chID := "chID"
|
||||
azID := "azID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
ch *Challenge
|
||||
validate func(o *Challenge)
|
||||
ch *acme.Challenge
|
||||
validate func(o *acme.Challenge)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
ch: &Challenge{
|
||||
ch: &acme.Challenge{
|
||||
ID: chID,
|
||||
},
|
||||
validate: func(ch *Challenge) {
|
||||
validate: func(ch *acme.Challenge) {
|
||||
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
|
||||
},
|
||||
},
|
||||
|
@ -310,10 +293,10 @@ func TestLinker_LinkChallenge(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkAuthorization(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := mockProvisioner(t)
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
|
||||
chID0 := "chID-0"
|
||||
chID1 := "chID-1"
|
||||
|
@ -322,20 +305,20 @@ func TestLinker_LinkAuthorization(t *testing.T) {
|
|||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
az *Authorization
|
||||
validate func(o *Authorization)
|
||||
az *acme.Authorization
|
||||
validate func(o *acme.Authorization)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
az: &Authorization{
|
||||
az: &acme.Authorization{
|
||||
ID: azID,
|
||||
Challenges: []*Challenge{
|
||||
Challenges: []*acme.Challenge{
|
||||
{ID: chID0},
|
||||
{ID: chID1},
|
||||
{ID: chID2},
|
||||
},
|
||||
},
|
||||
validate: func(az *Authorization) {
|
||||
validate: func(az *acme.Authorization) {
|
||||
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
|
||||
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
|
||||
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
|
||||
|
@ -352,10 +335,10 @@ func TestLinker_LinkAuthorization(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := mockProvisioner(t)
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
|
@ -7,9 +7,9 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
|
||||
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||
|
@ -30,11 +31,39 @@ func logNonce(w http.ResponseWriter, nonce string) {
|
|||
}
|
||||
}
|
||||
|
||||
// addNonce is a middleware that adds a nonce to the response header.
|
||||
func addNonce(next nextHTTP) nextHTTP {
|
||||
// baseURLFromRequest determines the base URL which should be used for
|
||||
// constructing link URLs in e.g. the ACME directory result by taking the
|
||||
// request Host into consideration.
|
||||
//
|
||||
// If the Request.Host is an empty string, we return an empty string, to
|
||||
// indicate that the configured URL values should be used instead. If this
|
||||
// function returns a non-empty result, then this should be used in
|
||||
// constructing ACME link URLs.
|
||||
func baseURLFromRequest(r *http.Request) *url.URL {
|
||||
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
|
||||
// for an implementation that allows HTTP requests using the x-forwarded-proto
|
||||
// header.
|
||||
|
||||
if r.Host == "" {
|
||||
return nil
|
||||
}
|
||||
return &url.URL{Scheme: "https", Host: r.Host}
|
||||
}
|
||||
|
||||
// baseURLFromRequest is a middleware that extracts and caches the baseURL
|
||||
// from the request.
|
||||
// E.g. https://ca.smallstep.com/
|
||||
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
db := acme.MustDatabaseFromContext(r.Context())
|
||||
nonce, err := db.CreateNonce(r.Context())
|
||||
ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// addNonce is a middleware that adds a nonce to the response header.
|
||||
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
nonce, err := h.db.CreateNonce(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -48,31 +77,25 @@ func addNonce(next nextHTTP) nextHTTP {
|
|||
|
||||
// addDirLink is a middleware that adds a 'Link' response reader with the
|
||||
// directory index url.
|
||||
func addDirLink(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||
w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index"))
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// verifyContentType is a middleware that verifies that content type is
|
||||
// application/jose+json.
|
||||
func verifyContentType(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var expected []string
|
||||
p, err := provisionerFromContext(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
u := &url.URL{
|
||||
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
|
||||
}
|
||||
|
||||
var expected []string
|
||||
u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")}
|
||||
if strings.Contains(r.URL.String(), u.EscapedPath()) {
|
||||
// GET /certificate requests allow a greater range of content types.
|
||||
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||
|
@ -94,7 +117,7 @@ func verifyContentType(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
|
||||
func parseJWS(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
|
@ -119,19 +142,17 @@ func parseJWS(next nextHTTP) nextHTTP {
|
|||
// The JWS Unprotected Header [RFC7515] MUST NOT be used
|
||||
// The JWS Payload MUST NOT be detached
|
||||
// The JWS Protected Header MUST include the following fields:
|
||||
// - “alg” (Algorithm).
|
||||
// This field MUST NOT contain “none” or a Message Authentication Code
|
||||
// (MAC) algorithm (e.g. one in which the algorithm registry description
|
||||
// mentions MAC/HMAC).
|
||||
// - “nonce” (defined in Section 6.5)
|
||||
// - “url” (defined in Section 6.4)
|
||||
// - Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||
func validateJWS(next nextHTTP) nextHTTP {
|
||||
// * “alg” (Algorithm)
|
||||
// * This field MUST NOT contain “none” or a Message Authentication Code
|
||||
// (MAC) algorithm (e.g. one in which the algorithm registry description
|
||||
// mentions MAC/HMAC).
|
||||
// * “nonce” (defined in Section 6.5)
|
||||
// * “url” (defined in Section 6.4)
|
||||
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
jws, err := jwsFromContext(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -181,7 +202,7 @@ func validateJWS(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// Check the validity/freshness of the Nonce.
|
||||
if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
@ -214,12 +235,10 @@ func validateJWS(next nextHTTP) nextHTTP {
|
|||
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
|
||||
// in the context. Make sure to parse and validate the JWS before running this
|
||||
// middleware.
|
||||
func extractJWK(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
jws, err := jwsFromContext(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -245,7 +264,7 @@ func extractJWK(next nextHTTP) nextHTTP {
|
|||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
|
||||
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key
|
||||
acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||
switch {
|
||||
case errors.Is(err, acme.ErrNotFound):
|
||||
// For NewAccount and Revoke requests ...
|
||||
|
@ -264,52 +283,75 @@ func extractJWK(next nextHTTP) nextHTTP {
|
|||
}
|
||||
}
|
||||
|
||||
// checkPrerequisites checks if all prerequisites for serving ACME
|
||||
// are met by the CA configuration.
|
||||
func checkPrerequisites(next nextHTTP) nextHTTP {
|
||||
// lookupProvisioner loads the provisioner associated with the request.
|
||||
// Responds 404 if the provisioner does not exist.
|
||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
// If the function is not set assume that all prerequisites are met.
|
||||
checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
|
||||
if ok {
|
||||
ok, err := checkFunc(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
return
|
||||
}
|
||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||
name, err := url.PathUnescape(nameEscaped)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
p, err := h.ca.LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
acmeProv, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// checkPrerequisites checks if all prerequisites for serving ACME
|
||||
// are met by the CA configuration.
|
||||
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ok, err := h.prerequisitesChecker(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
return
|
||||
}
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// lookupJWK loads the JWK associated with the acme account referenced by the
|
||||
// kid parameter of the signed payload.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func lookupJWK(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
|
||||
kid := jws.Signatures[0].Protected.KeyID
|
||||
if kid == "" {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"))
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"kid does not have required prefix; expected %s, but got %s",
|
||||
kidPrefix, kid))
|
||||
return
|
||||
}
|
||||
|
||||
accID := path.Base(kid)
|
||||
acc, err := db.GetAccount(ctx, accID)
|
||||
accID := strings.TrimPrefix(kid, kidPrefix)
|
||||
acc, err := h.db.GetAccount(ctx, accID)
|
||||
switch {
|
||||
case acme.IsErrNotFound(err):
|
||||
case nosql.IsErrNotFound(err):
|
||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||
return
|
||||
case err != nil:
|
||||
|
@ -320,45 +362,6 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
|||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||
return
|
||||
}
|
||||
|
||||
if storedLocation := acc.GetLocation(); storedLocation != "" {
|
||||
if kid != storedLocation {
|
||||
// ACME accounts should have a stored location equivalent to the
|
||||
// kid in the ACME request.
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"kid does not match stored account location; expected %s, but got %s",
|
||||
storedLocation, kid))
|
||||
return
|
||||
}
|
||||
|
||||
// Verify that the provisioner with which the account was created
|
||||
// matches the provisioner in the request URL.
|
||||
reqProv := acme.MustProvisionerFromContext(ctx)
|
||||
reqProvName := reqProv.GetName()
|
||||
accProvName := acc.ProvisionerName
|
||||
if reqProvName != accProvName {
|
||||
// Provisioner in the URL must match the provisioner with
|
||||
// which the account was created.
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s",
|
||||
accProvName, reqProvName))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// This code will only execute for old ACME accounts that do
|
||||
// not have a cached location. The following validation was
|
||||
// the original implementation of the `kid` check which has
|
||||
// since been deprecated. However, the code will remain to
|
||||
// ensure consistent behavior for old ACME accounts.
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"kid does not have required prefix; expected %s, but got %s",
|
||||
kidPrefix, kid))
|
||||
return
|
||||
}
|
||||
}
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
|
||||
next(w, r.WithContext(ctx))
|
||||
|
@ -369,7 +372,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
|||
|
||||
// extractOrLookupJWK forwards handling to either extractJWK or
|
||||
// lookupJWK based on the presence of a JWK or a KID, respectively.
|
||||
func extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
|
@ -382,13 +385,13 @@ func extractOrLookupJWK(next nextHTTP) nextHTTP {
|
|||
// and it can be used to check if a JWK exists. This flow is used when the ACME client
|
||||
// signed the payload with a certificate private key.
|
||||
if canExtractJWKFrom(jws) {
|
||||
extractJWK(next)(w, r)
|
||||
h.extractJWK(next)(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// default to looking up the JWK based on KeyID. This flow is used when the ACME client
|
||||
// signed the payload with an account private key.
|
||||
lookupJWK(next)(w, r)
|
||||
h.lookupJWK(next)(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -405,7 +408,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
|
|||
|
||||
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
|
@ -437,7 +440,7 @@ func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
|
||||
func isPostAsGet(next nextHTTP) nextHTTP {
|
||||
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, err := payloadFromContext(r.Context())
|
||||
if err != nil {
|
||||
|
@ -459,12 +462,16 @@ type ContextKey string
|
|||
const (
|
||||
// accContextKey account key
|
||||
accContextKey = ContextKey("acc")
|
||||
// baseURLContextKey baseURL key
|
||||
baseURLContextKey = ContextKey("baseURL")
|
||||
// jwsContextKey jws key
|
||||
jwsContextKey = ContextKey("jws")
|
||||
// jwkContextKey jwk key
|
||||
jwkContextKey = ContextKey("jwk")
|
||||
// payloadContextKey payload key
|
||||
payloadContextKey = ContextKey("payload")
|
||||
// provisionerContextKey provisioner key
|
||||
provisionerContextKey = ContextKey("provisioner")
|
||||
)
|
||||
|
||||
// accountFromContext searches the context for an ACME account. Returns the
|
||||
|
@ -477,6 +484,15 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) {
|
|||
return val, nil
|
||||
}
|
||||
|
||||
// baseURLFromContext returns the baseURL if one is stored in the context.
|
||||
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
|
||||
if !ok || val == nil {
|
||||
return nil
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
|
||||
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
|
||||
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
|
||||
|
@ -498,26 +514,29 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
|
|||
// provisionerFromContext searches the context for a provisioner. Returns the
|
||||
// provisioner or an error.
|
||||
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
|
||||
p, ok := acme.ProvisionerFromContext(ctx)
|
||||
if !ok || p == nil {
|
||||
val := ctx.Value(provisionerContextKey)
|
||||
if val == nil {
|
||||
return nil, acme.NewErrorISE("provisioner expected in request context")
|
||||
}
|
||||
return p, nil
|
||||
pval, ok := val.(acme.Provisioner)
|
||||
if !ok || pval == nil {
|
||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||
}
|
||||
return pval, nil
|
||||
}
|
||||
|
||||
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
|
||||
// pointer to an ACME provisioner or an error.
|
||||
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
|
||||
p, err := provisionerFromContext(ctx)
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ap, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
acmeProv, ok := prov.(*provisioner.ACME)
|
||||
if !ok || acmeProv == nil {
|
||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||
}
|
||||
|
||||
return ap, nil
|
||||
return acmeProv, nil
|
||||
}
|
||||
|
||||
// payloadFromContext searches the context for a payload. Returns the payload
|
||||
|
|
|
@ -17,28 +17,93 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql/database"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
)
|
||||
|
||||
var testBody = []byte("foo")
|
||||
|
||||
func testNext(w http.ResponseWriter, _ *http.Request) {
|
||||
func testNext(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(testBody)
|
||||
}
|
||||
|
||||
func newBaseContext(ctx context.Context, args ...interface{}) context.Context {
|
||||
for _, a := range args {
|
||||
switch v := a.(type) {
|
||||
case acme.DB:
|
||||
ctx = acme.NewDatabaseContext(ctx, v)
|
||||
case acme.Linker:
|
||||
ctx = acme.NewLinkerContext(ctx, v)
|
||||
case acme.PrerequisitesChecker:
|
||||
ctx = acme.NewPrerequisitesCheckerContext(ctx, v)
|
||||
func Test_baseURLFromRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
targetURL string
|
||||
expectedResult *url.URL
|
||||
requestPreparer func(*http.Request)
|
||||
}{
|
||||
{
|
||||
"HTTPS host pass-through failed.",
|
||||
"https://my.dummy.host",
|
||||
&url.URL{Scheme: "https", Host: "my.dummy.host"},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Port pass-through failed",
|
||||
"https://host.with.port:8080",
|
||||
&url.URL{Scheme: "https", Host: "host.with.port:8080"},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Explicit host from Request.Host was not used.",
|
||||
"https://some.target.host:8080",
|
||||
&url.URL{Scheme: "https", Host: "proxied.host"},
|
||||
func(r *http.Request) {
|
||||
r.Host = "proxied.host"
|
||||
},
|
||||
},
|
||||
{
|
||||
"Missing Request.Host value did not result in empty string result.",
|
||||
"https://some.host",
|
||||
nil,
|
||||
func(r *http.Request) {
|
||||
r.Host = ""
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
request := httptest.NewRequest("GET", tc.targetURL, nil)
|
||||
if tc.requestPreparer != nil {
|
||||
tc.requestPreparer(request)
|
||||
}
|
||||
result := baseURLFromRequest(request)
|
||||
if result == nil || tc.expectedResult == nil {
|
||||
assert.Equals(t, result, tc.expectedResult)
|
||||
} else if result.String() != tc.expectedResult.String() {
|
||||
t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_baseURLFromRequest(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req.Host = "test.ca.smallstep.com:8080"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
next := func(w http.ResponseWriter, r *http.Request) {
|
||||
bu := baseURLFromContext(r.Context())
|
||||
if assert.NotNil(t, bu) {
|
||||
assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080")
|
||||
assert.Equals(t, bu.Scheme, "https")
|
||||
}
|
||||
}
|
||||
return ctx
|
||||
|
||||
h.baseURLFromRequest(next)(w, req)
|
||||
|
||||
req = httptest.NewRequest("GET", "/foo", nil)
|
||||
req.Host = ""
|
||||
|
||||
next = func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, baseURLFromContext(r.Context()), nil)
|
||||
}
|
||||
|
||||
h.baseURLFromRequest(next)(w, req)
|
||||
}
|
||||
|
||||
func TestHandler_addNonce(t *testing.T) {
|
||||
|
@ -74,10 +139,10 @@ func TestHandler_addNonce(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := newBaseContext(context.Background(), tc.db)
|
||||
req := httptest.NewRequest("GET", u, nil).WithContext(ctx)
|
||||
h := &Handler{db: tc.db}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
w := httptest.NewRecorder()
|
||||
addNonce(testNext)(w, req)
|
||||
h.addNonce(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -92,6 +157,7 @@ func TestHandler_addNonce(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -109,15 +175,17 @@ func TestHandler_addDirLink(t *testing.T) {
|
|||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
type test struct {
|
||||
link string
|
||||
linker Linker
|
||||
statusCode int
|
||||
ctx context.Context
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"ok": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
|
||||
statusCode: 200,
|
||||
|
@ -127,10 +195,11 @@ func TestHandler_addDirLink(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: tc.linker}
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
addDirLink(testNext)(w, req)
|
||||
h.addDirLink(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -145,6 +214,7 @@ func TestHandler_addDirLink(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -161,6 +231,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
||||
type test struct {
|
||||
h Handler
|
||||
ctx context.Context
|
||||
contentType string
|
||||
err *acme.Error
|
||||
|
@ -170,6 +241,9 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/provisioner-not-set": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
url: u,
|
||||
ctx: context.Background(),
|
||||
contentType: "foo",
|
||||
|
@ -179,8 +253,11 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
},
|
||||
"fail/general-bad-content-type": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
url: u,
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
contentType: "foo",
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"),
|
||||
|
@ -188,7 +265,10 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
},
|
||||
"fail/certificate-bad-content-type": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
contentType: "foo",
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"),
|
||||
|
@ -196,28 +276,40 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
contentType: "application/jose+json",
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/certificate/pkix-cert": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
contentType: "application/pkix-cert",
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/certificate/jose+json": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
contentType: "application/jose+json",
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/certificate/pkcs7-mime": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
contentType: "application/pkcs7-mime",
|
||||
statusCode: 200,
|
||||
}
|
||||
|
@ -234,7 +326,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
req = req.WithContext(tc.ctx)
|
||||
req.Header.Add("Content-Type", tc.contentType)
|
||||
w := httptest.NewRecorder()
|
||||
verifyContentType(testNext)(w, req)
|
||||
tc.h.verifyContentType(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -249,6 +341,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -297,11 +390,11 @@ func TestHandler_isPostAsGet(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// h := &Handler{}
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
isPostAsGet(testNext)(w, req)
|
||||
h.isPostAsGet(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -316,6 +409,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -327,7 +421,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
|
|||
|
||||
type errReader int
|
||||
|
||||
func (errReader) Read([]byte) (int, error) {
|
||||
func (errReader) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New("force")
|
||||
}
|
||||
func (errReader) Close() error {
|
||||
|
@ -387,10 +481,10 @@ func TestHandler_parseJWS(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// h := &Handler{}
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest("GET", u, tc.body)
|
||||
w := httptest.NewRecorder()
|
||||
parseJWS(tc.next)(w, req)
|
||||
h.parseJWS(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -405,6 +499,7 @@ func TestHandler_parseJWS(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -512,6 +607,9 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok/empty-algorithm-in-jwk": func(t *testing.T) test {
|
||||
_pub := *pub
|
||||
clone := &_pub
|
||||
clone.Algorithm = ""
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, jwkContextKey, pub)
|
||||
return test{
|
||||
|
@ -581,11 +679,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// h := &Handler{}
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
verifyAndExtractJWSPayload(tc.next)(w, req)
|
||||
h.verifyAndExtractJWSPayload(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -600,6 +698,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -634,7 +733,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
linker acme.Linker
|
||||
linker Linker
|
||||
db acme.DB
|
||||
ctx context.Context
|
||||
next func(http.ResponseWriter, *http.Request)
|
||||
|
@ -644,19 +743,15 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
}
|
||||
},
|
||||
"fail/nil-jws": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -670,25 +765,50 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
_jws, err := _signer.Sign([]byte("baz"))
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"),
|
||||
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
|
||||
}
|
||||
},
|
||||
"fail/bad-kid-prefix": func(t *testing.T) test {
|
||||
_so := new(jose.SignerOptions)
|
||||
_so.WithHeader("kid", "foo")
|
||||
_signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||
Key: jwk.Key,
|
||||
}, _so)
|
||||
assert.FatalError(t, err)
|
||||
_jws, err := _signer.Sign([]byte("baz"))
|
||||
assert.FatalError(t, err)
|
||||
_raw, err := _jws.CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
_parsed, err := jose.ParseJWS(_raw)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _parsed)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
|
||||
}
|
||||
},
|
||||
"fail/account-not-found": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
linker: NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||
assert.Equals(t, accID, accID)
|
||||
return nil, acme.ErrNotFound
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
|
@ -697,10 +817,11 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/GetAccount-error": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
linker: NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
|
@ -714,10 +835,11 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-valid": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "deactivated"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
linker: NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
|
@ -729,82 +851,13 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"),
|
||||
}
|
||||
},
|
||||
"fail/account-with-location-prefix/bad-kid": func(t *testing.T) test {
|
||||
acc := &acme.Account{LocationPrefix: "foobar", Status: "valid"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
return acc, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: http.StatusUnauthorized,
|
||||
err: acme.NewError(acme.ErrorUnauthorizedType, "kid does not match stored account location; expected foobar, but %q", prefix+accID),
|
||||
}
|
||||
},
|
||||
"fail/account-with-location-prefix/bad-provisioner": func(t *testing.T) test {
|
||||
acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerName: "other"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
return acc, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
_acc, err := accountFromContext(r.Context())
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, _acc, acc)
|
||||
_jwk, err := jwkFromContext(r.Context())
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, _jwk, jwk)
|
||||
w.Write(testBody)
|
||||
},
|
||||
statusCode: http.StatusUnauthorized,
|
||||
err: acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account provisioner does not match requested provisioner; account provisioner = %s, reqested provisioner = %s",
|
||||
prov.GetName(), "other"),
|
||||
}
|
||||
},
|
||||
"ok/account-with-location-prefix": func(t *testing.T) test {
|
||||
acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerName: prov.GetName()}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
return acc, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
_acc, err := accountFromContext(r.Context())
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, _acc, acc)
|
||||
_jwk, err := jwkFromContext(r.Context())
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, _jwk, jwk)
|
||||
w.Write(testBody)
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
},
|
||||
"ok/account-without-location-prefix": func(t *testing.T) test {
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "valid", Key: jwk}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
linker: NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
|
@ -828,11 +881,11 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
|
||||
h := &Handler{db: tc.db, linker: tc.linker}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
lookupJWK(tc.next)(w, req)
|
||||
h.lookupJWK(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -847,6 +900,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -891,17 +945,15 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
}
|
||||
},
|
||||
"fail/nil-jws": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -917,10 +969,9 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
|
||||
|
@ -936,17 +987,16 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
|
||||
}
|
||||
},
|
||||
"fail/GetAccountByKey-error": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -962,7 +1012,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-valid": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "deactivated"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -978,7 +1028,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "valid"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -1001,7 +1051,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok/no-account": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -1027,11 +1077,11 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := newBaseContext(tc.ctx, tc.db)
|
||||
h := &Handler{db: tc.db}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
extractJWK(tc.next)(w, req)
|
||||
h.extractJWK(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1046,6 +1096,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -1067,7 +1118,6 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -1075,7 +1125,6 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
"fail/nil-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, nil),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -1083,7 +1132,6 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
"fail/no-signature": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
|
||||
|
@ -1097,7 +1145,6 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
|
||||
|
@ -1110,7 +1157,6 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
|
||||
|
@ -1123,7 +1169,6 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
|
||||
|
@ -1136,7 +1181,6 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
|
||||
|
@ -1171,8 +1215,6 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/rsa-key-too-small": func(t *testing.T) test {
|
||||
revert := keyutil.Insecure()
|
||||
defer revert()
|
||||
jwk, err := jose.GenerateJWK("RSA", "", "", "sig", "", 1024)
|
||||
assert.FatalError(t, err)
|
||||
pub := jwk.Public()
|
||||
|
@ -1402,11 +1444,11 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := newBaseContext(tc.ctx, tc.db)
|
||||
h := &Handler{db: tc.db}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
validateJWS(tc.next)(w, req)
|
||||
h.validateJWS(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1421,6 +1463,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -1499,7 +1542,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
u := "https://ca.smallstep.com/acme/account"
|
||||
type test struct {
|
||||
db acme.DB
|
||||
linker acme.Linker
|
||||
linker Linker
|
||||
statusCode int
|
||||
ctx context.Context
|
||||
err *acme.Error
|
||||
|
@ -1527,7 +1570,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
linker: NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
|
||||
assert.Equals(t, kid, pub.KeyID)
|
||||
|
@ -1563,10 +1606,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
linker: NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
|
@ -1584,11 +1628,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
|
||||
h := &Handler{db: tc.db, linker: tc.linker}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
extractOrLookupJWK(tc.next)(w, req)
|
||||
h.extractOrLookupJWK(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1603,6 +1647,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -1619,7 +1664,7 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||
baseURL, provName)
|
||||
type test struct {
|
||||
linker acme.Linker
|
||||
linker Linker
|
||||
ctx context.Context
|
||||
prerequisitesChecker func(context.Context) (bool, error)
|
||||
next func(http.ResponseWriter, *http.Request)
|
||||
|
@ -1628,9 +1673,10 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/error": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1641,9 +1687,10 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/prerequisites-nok": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1654,9 +1701,10 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1669,11 +1717,11 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker)
|
||||
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
checkPrerequisites(tc.next)(w, req)
|
||||
h.checkPrerequisites(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1687,6 +1735,7 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
|
|
@ -13,12 +13,9 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
|
||||
"go.step.sm/crypto/randutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// NewOrderRequest represents the body for a NewOrder request.
|
||||
|
@ -34,26 +31,12 @@ func (n *NewOrderRequest) Validate() error {
|
|||
return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty")
|
||||
}
|
||||
for _, id := range n.Identifiers {
|
||||
switch id.Type {
|
||||
case acme.IP:
|
||||
if net.ParseIP(id.Value) == nil {
|
||||
return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value)
|
||||
}
|
||||
case acme.DNS:
|
||||
value, _ := trimIfWildcard(id.Value)
|
||||
if _, err := x509util.SanitizeName(value); err != nil {
|
||||
return acme.NewError(acme.ErrorMalformedType, "invalid DNS name: %s", id.Value)
|
||||
}
|
||||
case acme.PermanentIdentifier:
|
||||
if id.Value == "" {
|
||||
return acme.NewError(acme.ErrorMalformedType, "permanent identifier cannot be empty")
|
||||
}
|
||||
default:
|
||||
if !(id.Type == acme.DNS || id.Type == acme.IP) {
|
||||
return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type)
|
||||
}
|
||||
|
||||
// TODO(hs): add some validations for DNS domains?
|
||||
// TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1
|
||||
if id.Type == acme.IP && net.ParseIP(id.Value) == nil {
|
||||
return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -67,13 +50,7 @@ type FinalizeRequest struct {
|
|||
// Validate validates a finalize request body.
|
||||
func (f *FinalizeRequest) Validate() error {
|
||||
var err error
|
||||
// RFC 8555 isn't 100% conclusive about using raw base64-url encoding for the
|
||||
// CSR specifically, instead of "normal" base64-url encoding (incl. padding).
|
||||
// By trimming the padding from CSRs submitted by ACME clients that use
|
||||
// base64-url encoding instead of raw base64-url encoding, these are also
|
||||
// supported. This was reported in https://github.com/smallstep/certificates/issues/939
|
||||
// to be the case for a Synology DSM NAS system.
|
||||
csrBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(f.CSR, "="))
|
||||
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
|
||||
if err != nil {
|
||||
return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr")
|
||||
}
|
||||
|
@ -91,12 +68,8 @@ var defaultOrderExpiry = time.Hour * 24
|
|||
var defaultOrderBackdate = time.Minute
|
||||
|
||||
// NewOrder ACME api for creating a new order.
|
||||
func NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ca := mustAuthority(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -112,7 +85,6 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var nor NewOrderRequest
|
||||
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
|
@ -125,48 +97,6 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// TODO(hs): gather all errors, so that we can build one response with ACME subproblems
|
||||
// include the nor.Validate() error here too, like in the example in the ACME RFC?
|
||||
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var eak *acme.ExternalAccountKey
|
||||
if acmeProv.RequireEAB {
|
||||
if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
acmePolicy, err := newACMEPolicyEngine(eak)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine"))
|
||||
return
|
||||
}
|
||||
|
||||
for _, identifier := range nor.Identifiers {
|
||||
// evaluate the ACME account level policy
|
||||
if err = isIdentifierAllowed(acmePolicy, identifier); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
// evaluate the provisioner level policy
|
||||
orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
|
||||
if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
// evaluate the authority level policy
|
||||
if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
// New order.
|
||||
o := &acme.Order{
|
||||
|
@ -187,7 +117,7 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
ExpiresAt: o.ExpiresAt,
|
||||
Status: acme.StatusPending,
|
||||
}
|
||||
if err := newAuthorization(ctx, az); err != nil {
|
||||
if err := h.newAuthorization(ctx, az); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
@ -206,44 +136,24 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
|
||||
}
|
||||
|
||||
if err := db.CreateOrder(ctx, o); err != nil {
|
||||
if err := h.db.CreateOrder(ctx, o); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrder(ctx, o)
|
||||
h.linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
render.JSONStatus(w, o, http.StatusCreated)
|
||||
}
|
||||
|
||||
func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error {
|
||||
if acmePolicy == nil {
|
||||
return nil
|
||||
}
|
||||
return acmePolicy.AreSANsAllowed([]string{identifier.Value})
|
||||
}
|
||||
|
||||
func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error) {
|
||||
if eak == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return policy.NewX509PolicyEngine(eak.Policy)
|
||||
}
|
||||
|
||||
func trimIfWildcard(value string) (string, bool) {
|
||||
if strings.HasPrefix(value, "*.") {
|
||||
return strings.TrimPrefix(value, "*."), true
|
||||
}
|
||||
return value, false
|
||||
}
|
||||
|
||||
func newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
value, isWildcard := trimIfWildcard(az.Identifier.Value)
|
||||
az.Wildcard = isWildcard
|
||||
az.Identifier = acme.Identifier{
|
||||
Value: value,
|
||||
Type: az.Identifier.Type,
|
||||
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
if strings.HasPrefix(az.Identifier.Value, "*.") {
|
||||
az.Wildcard = true
|
||||
az.Identifier = acme.Identifier{
|
||||
Value: strings.TrimPrefix(az.Identifier.Value, "*."),
|
||||
Type: az.Identifier.Type,
|
||||
}
|
||||
}
|
||||
|
||||
chTypes := challengeTypes(az)
|
||||
|
@ -253,15 +163,8 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
|||
if err != nil {
|
||||
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
|
||||
}
|
||||
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
prov := acme.MustProvisionerFromContext(ctx)
|
||||
az.Challenges = make([]*acme.Challenge, 0, len(chTypes))
|
||||
for _, typ := range chTypes {
|
||||
if !prov.IsChallengeEnabled(ctx, provisioner.ACMEChallenge(typ)) {
|
||||
continue
|
||||
}
|
||||
|
||||
az.Challenges = make([]*acme.Challenge, len(chTypes))
|
||||
for i, typ := range chTypes {
|
||||
ch := &acme.Challenge{
|
||||
AccountID: az.AccountID,
|
||||
Value: az.Identifier.Value,
|
||||
|
@ -269,23 +172,20 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
|||
Token: az.Token,
|
||||
Status: acme.StatusPending,
|
||||
}
|
||||
if err := db.CreateChallenge(ctx, ch); err != nil {
|
||||
if err := h.db.CreateChallenge(ctx, ch); err != nil {
|
||||
return acme.WrapErrorISE(err, "error creating challenge")
|
||||
}
|
||||
az.Challenges = append(az.Challenges, ch)
|
||||
az.Challenges[i] = ch
|
||||
}
|
||||
if err = db.CreateAuthorization(ctx, az); err != nil {
|
||||
if err = h.db.CreateAuthorization(ctx, az); err != nil {
|
||||
return acme.WrapErrorISE(err, "error creating authorization")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOrder ACME api for retrieving an order.
|
||||
func GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -296,8 +196,7 @@ func GetOrder(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
return
|
||||
|
@ -312,23 +211,20 @@ func GetOrder(w http.ResponseWriter, r *http.Request) {
|
|||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||
return
|
||||
}
|
||||
if err = o.UpdateStatus(ctx, db); err != nil {
|
||||
if err = o.UpdateStatus(ctx, h.db); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrder(ctx, o)
|
||||
h.linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
}
|
||||
|
||||
// FinalizeOrder attempts to finalize an order and create a certificate.
|
||||
func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
||||
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -355,7 +251,7 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
return
|
||||
|
@ -370,16 +266,14 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
|||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||
return
|
||||
}
|
||||
|
||||
ca := mustAuthority(ctx)
|
||||
if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil {
|
||||
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrder(ctx, o)
|
||||
h.linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
}
|
||||
|
||||
|
@ -397,8 +291,6 @@ func challengeTypes(az *acme.Authorization) []acme.ChallengeType {
|
|||
if !az.Wildcard {
|
||||
chTypes = append(chTypes, []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}...)
|
||||
}
|
||||
case acme.PermanentIdentifier:
|
||||
chTypes = []acme.ChallengeType{acme.DEVICEATTEST01}
|
||||
default:
|
||||
chTypes = []acme.ChallengeType{}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -26,11 +26,9 @@ type revokePayload struct {
|
|||
}
|
||||
|
||||
// RevokeCert attempts to revoke a certificate.
|
||||
func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -71,7 +69,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
serial := certToBeRevoked.SerialNumber.String()
|
||||
dbCert, err := db.GetCertificateBySerial(ctx, serial)
|
||||
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
||||
return
|
||||
|
@ -89,7 +87,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||
if acmeErr != nil {
|
||||
render.Error(w, acmeErr)
|
||||
return
|
||||
|
@ -105,8 +103,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
ca := mustAuthority(ctx)
|
||||
hasBeenRevokedBefore, err := ca.IsRevoked(serial)
|
||||
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
||||
return
|
||||
|
@ -133,14 +130,14 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
options := revokeOptions(serial, certToBeRevoked, reasonCode)
|
||||
err = ca.Revoke(ctx, options)
|
||||
err = h.ca.Revoke(ctx, options)
|
||||
if err != nil {
|
||||
render.Error(w, wrapRevokeErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
logRevoke(w, options)
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||
w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index"))
|
||||
w.Write(nil)
|
||||
}
|
||||
|
||||
|
@ -151,7 +148,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
|
||||
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
|
||||
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
|
||||
func isAccountAuthorized(_ context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
||||
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
||||
if !account.IsValid() {
|
||||
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
|
||||
}
|
||||
|
|
|
@ -24,16 +24,14 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
)
|
||||
|
||||
// v is a utility function to return the pointer to an integer
|
||||
|
@ -258,7 +256,7 @@ func jwkEncode(pub crypto.PublicKey) (string, error) {
|
|||
// jwsFinal constructs the final JWS object.
|
||||
// Implementation taken from github.com/mholt/acmez, which seems to be based on
|
||||
// https://github.com/golang/crypto/blob/master/acme/jws.go.
|
||||
func jwsFinal(_ crypto.Hash, sig []byte, phead, payload string) ([]byte, error) {
|
||||
func jwsFinal(sha crypto.Hash, sig []byte, phead, payload string) ([]byte, error) {
|
||||
enc := struct {
|
||||
Protected string `json:"protected"`
|
||||
Payload string `json:"payload"`
|
||||
|
@ -276,22 +274,14 @@ func jwsFinal(_ crypto.Hash, sig []byte, phead, payload string) ([]byte, error)
|
|||
}
|
||||
|
||||
type mockCA struct {
|
||||
MockIsRevoked func(sn string) (bool, error)
|
||||
MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error
|
||||
MockAreSANsallowed func(ctx context.Context, sans []string) error
|
||||
MockIsRevoked func(sn string) (bool, error)
|
||||
MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error
|
||||
}
|
||||
|
||||
func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error {
|
||||
if m.MockAreSANsallowed != nil {
|
||||
return m.MockAreSANsallowed(ctx, sans)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCA) IsRevoked(sn string) (bool, error) {
|
||||
if m.MockIsRevoked != nil {
|
||||
return m.MockIsRevoked(sn)
|
||||
|
@ -521,7 +511,6 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/no-jws": func(t *testing.T) test {
|
||||
ctx := context.Background()
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -530,7 +519,6 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/nil-jws": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -539,7 +527,6 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -547,9 +534,8 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = acme.NewProvisionerContext(ctx, nil)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -557,9 +543,8 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
|
@ -567,10 +552,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
|
@ -579,10 +563,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/unmarshal-payload": func(t *testing.T) test {
|
||||
malformedPayload := []byte(`{"payload":malformed?}`)
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("error unmarshaling payload"),
|
||||
|
@ -594,11 +577,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: &acme.Error{
|
||||
|
@ -614,11 +596,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
emptyPayloadBytes, err := json.Marshal(emptyPayload)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: &acme.Error{
|
||||
|
@ -629,7 +610,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/db.GetCertificateBySerial": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
db := &acme.MockDB{
|
||||
|
@ -647,7 +628,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/different-certificate-contents": func(t *testing.T) test {
|
||||
aDifferentCert, _, err := generateCertKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
db := &acme.MockDB{
|
||||
|
@ -666,7 +647,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
db := &acme.MockDB{
|
||||
|
@ -685,7 +666,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
|
@ -706,10 +687,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-valid": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -735,10 +717,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-authorized": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -788,9 +771,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
unauthorizedPayloadBytes, err := json.Marshal(jwsPayload)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -814,10 +798,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/certificate-revoked-check-fails": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -847,7 +832,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/certificate-already-revoked": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -885,7 +870,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -923,7 +908,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
}
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -955,7 +940,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/ca.Revoke": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -987,7 +972,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/ca.Revoke-already-revoked": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -1018,10 +1003,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"ok/using-account-key": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -1045,9 +1031,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(jwsBytes))
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -1070,12 +1057,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
for name, setup := range tests {
|
||||
tc := setup(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
mockMustAuthority(t, tc.ca)
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca}
|
||||
req := httptest.NewRequest("POST", revokeURL, nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
RevokeCert(w, req)
|
||||
h.RevokeCert(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1090,6 +1076,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
|
@ -1211,8 +1198,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) {
|
|||
for name, setup := range tests {
|
||||
tc := setup(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// h := &Handler{db: tc.db}
|
||||
acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
|
||||
h := &Handler{db: tc.db}
|
||||
acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
|
||||
|
||||
expectError := tc.err != nil
|
||||
gotError := acmeErr != nil
|
||||
|
@ -1229,6 +1216,7 @@ func TestHandler_isAccountAuthorized(t *testing.T) {
|
|||
assert.Equals(t, acmeErr.Type, tc.err.Type)
|
||||
assert.Equals(t, acmeErr.Status, tc.err.Status)
|
||||
assert.Equals(t, acmeErr.Detail, tc.err.Detail)
|
||||
assert.Equals(t, acmeErr.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, acmeErr.Subproblems, tc.err.Subproblems)
|
||||
|
||||
})
|
||||
|
@ -1321,6 +1309,7 @@ func Test_wrapUnauthorizedError(t *testing.T) {
|
|||
assert.Equals(t, acmeErr.Type, tc.want.Type)
|
||||
assert.Equals(t, acmeErr.Status, tc.want.Status)
|
||||
assert.Equals(t, acmeErr.Detail, tc.want.Detail)
|
||||
assert.Equals(t, acmeErr.Identifier, tc.want.Identifier)
|
||||
assert.Equals(t, acmeErr.Subproblems, tc.want.Subproblems)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -8,16 +8,15 @@ import (
|
|||
|
||||
// Authorization representst an ACME Authorization.
|
||||
type Authorization struct {
|
||||
ID string `json:"-"`
|
||||
AccountID string `json:"-"`
|
||||
Token string `json:"-"`
|
||||
Fingerprint string `json:"-"`
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status Status `json:"status"`
|
||||
Challenges []*Challenge `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
ExpiresAt time.Time `json:"expires"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
ID string `json:"-"`
|
||||
AccountID string `json:"-"`
|
||||
Token string `json:"-"`
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status Status `json:"status"`
|
||||
Challenges []*Challenge `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
ExpiresAt time.Time `json:"expires"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
|
|
|
@ -130,14 +130,14 @@ func TestAuthorization_UpdateStatus(t *testing.T) {
|
|||
tc := run(t)
|
||||
if err := tc.az.UpdateStatus(context.Background(), tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
var k *Error
|
||||
if errors.As(err, &k) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
assert.Equals(t, k.Type, tc.err.Type)
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
assert.Equals(t, k.Status, tc.err.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
} else {
|
||||
default:
|
||||
assert.FatalError(t, errors.New("unexpected error type"))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,14 +3,9 @@ package acme
|
|||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
|
@ -19,23 +14,13 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/google/go-tpm/tpm2"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/smallstep/go-attestation/attest"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
type ChallengeType string
|
||||
|
@ -47,20 +32,6 @@ const (
|
|||
DNS01 ChallengeType = "dns-01"
|
||||
// TLSALPN01 is the tls-alpn-01 ACME challenge type
|
||||
TLSALPN01 ChallengeType = "tls-alpn-01"
|
||||
// DEVICEATTEST01 is the device-attest-01 ACME challenge type
|
||||
DEVICEATTEST01 ChallengeType = "device-attest-01"
|
||||
)
|
||||
|
||||
var (
|
||||
// InsecurePortHTTP01 is the port used to verify http-01 challenges. If not set it
|
||||
// defaults to 80.
|
||||
InsecurePortHTTP01 int
|
||||
|
||||
// InsecurePortTLSALPN01 is the port used to verify tls-alpn-01 challenges. If not
|
||||
// set it defaults to 443.
|
||||
//
|
||||
// This variable can be used for testing purposes.
|
||||
InsecurePortTLSALPN01 int
|
||||
)
|
||||
|
||||
// Challenge represents an ACME response Challenge type.
|
||||
|
@ -86,39 +57,31 @@ func (ch *Challenge) ToLog() (interface{}, error) {
|
|||
return string(b), nil
|
||||
}
|
||||
|
||||
// Validate attempts to validate the Challenge. Stores changes to the Challenge
|
||||
// type using the DB interface. If the Challenge is validated, the 'status' and
|
||||
// 'validated' attributes are updated.
|
||||
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, payload []byte) error {
|
||||
// Validate attempts to validate the challenge. Stores changes to the Challenge
|
||||
// type using the DB interface.
|
||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||
// updated.
|
||||
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
// If already valid or invalid then return without performing validation.
|
||||
if ch.Status != StatusPending {
|
||||
return nil
|
||||
}
|
||||
switch ch.Type {
|
||||
case HTTP01:
|
||||
return http01Validate(ctx, ch, db, jwk)
|
||||
return http01Validate(ctx, ch, db, jwk, vo)
|
||||
case DNS01:
|
||||
return dns01Validate(ctx, ch, db, jwk)
|
||||
return dns01Validate(ctx, ch, db, jwk, vo)
|
||||
case TLSALPN01:
|
||||
return tlsalpn01Validate(ctx, ch, db, jwk)
|
||||
case DEVICEATTEST01:
|
||||
return deviceAttest01Validate(ctx, ch, db, jwk, payload)
|
||||
return tlsalpn01Validate(ctx, ch, db, jwk, vo)
|
||||
default:
|
||||
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||
|
||||
// Append insecure port if set.
|
||||
// Only used for testing purposes.
|
||||
if InsecurePortHTTP01 != 0 {
|
||||
u.Host += ":" + strconv.Itoa(InsecurePortHTTP01)
|
||||
}
|
||||
|
||||
vc := MustClientFromContext(ctx)
|
||||
resp, err := vc.Get(u.String())
|
||||
resp, err := vo.HTTPGet(u.String())
|
||||
if err != nil {
|
||||
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||
"error doing http GET for url %s", u))
|
||||
|
@ -156,17 +119,6 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
|
|||
return nil
|
||||
}
|
||||
|
||||
// http01ChallengeHost checks if a Challenge value is an IPv6 address
|
||||
// and adds square brackets if that's the case, so that it can be used
|
||||
// as a hostname. Returns the original Challenge value as the host to
|
||||
// use in other cases.
|
||||
func http01ChallengeHost(value string) string {
|
||||
if ip := net.ParseIP(value); ip != nil && ip.To4() == nil {
|
||||
value = "[" + value + "]"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func tlsAlert(err error) uint8 {
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
|
@ -178,7 +130,7 @@ func tlsAlert(err error) uint8 {
|
|||
return 0
|
||||
}
|
||||
|
||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
config := &tls.Config{
|
||||
NextProtos: []string{"acme-tls/1"},
|
||||
// https://tools.ietf.org/html/rfc8737#section-4
|
||||
|
@ -186,20 +138,12 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
// [RFC5246] or higher when connecting to clients for validation.
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: serverName(ch),
|
||||
InsecureSkipVerify: true, //nolint:gosec // we expect a self-signed challenge certificate
|
||||
InsecureSkipVerify: true, // we expect a self-signed challenge certificate
|
||||
}
|
||||
|
||||
var hostPort string
|
||||
hostPort := net.JoinHostPort(ch.Value, "443")
|
||||
|
||||
// Allow to change TLS port for testing purposes.
|
||||
if port := InsecurePortTLSALPN01; port == 0 {
|
||||
hostPort = net.JoinHostPort(ch.Value, "443")
|
||||
} else {
|
||||
hostPort = net.JoinHostPort(ch.Value, strconv.Itoa(port))
|
||||
}
|
||||
|
||||
vc := MustClientFromContext(ctx)
|
||||
conn, err := vc.TLSDial("tcp", hostPort, config)
|
||||
conn, err := vo.TLSDial("tcp", hostPort, config)
|
||||
if err != nil {
|
||||
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
|
||||
// client and server protocols. When this happens the connection is
|
||||
|
@ -298,15 +242,14 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
|
||||
}
|
||||
|
||||
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
// Normalize domain for wildcard DNS names
|
||||
// This is done to avoid making TXT lookups for domains like
|
||||
// _acme-challenge.*.example.com
|
||||
// Instead perform txt lookup for _acme-challenge.example.com
|
||||
domain := strings.TrimPrefix(ch.Value, "*.")
|
||||
|
||||
vc := MustClientFromContext(ctx)
|
||||
txtRecords, err := vc.LookupTxt("_acme-challenge." + domain)
|
||||
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain)
|
||||
if err != nil {
|
||||
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
|
||||
"error looking up TXT records for domain %s", domain))
|
||||
|
@ -341,706 +284,6 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK
|
|||
return nil
|
||||
}
|
||||
|
||||
type payloadType struct {
|
||||
AttObj string `json:"attObj"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type attestationObject struct {
|
||||
Format string `json:"fmt"`
|
||||
AttStatement map[string]interface{} `json:"attStmt,omitempty"`
|
||||
}
|
||||
|
||||
// TODO(bweeks): move attestation verification to a shared package.
|
||||
func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error {
|
||||
// Load authorization to store the key fingerprint.
|
||||
az, err := db.GetAuthorization(ctx, ch.AuthorizationID)
|
||||
if err != nil {
|
||||
return WrapErrorISE(err, "error loading authorization")
|
||||
}
|
||||
|
||||
// Parse payload.
|
||||
var p payloadType
|
||||
if err := json.Unmarshal(payload, &p); err != nil {
|
||||
return WrapErrorISE(err, "error unmarshalling JSON")
|
||||
}
|
||||
if p.Error != "" {
|
||||
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||
"payload contained error: %v", p.Error))
|
||||
}
|
||||
|
||||
attObj, err := base64.RawURLEncoding.DecodeString(p.AttObj)
|
||||
if err != nil {
|
||||
return WrapErrorISE(err, "error base64 decoding attObj")
|
||||
}
|
||||
|
||||
att := attestationObject{}
|
||||
if err := cbor.Unmarshal(attObj, &att); err != nil {
|
||||
return WrapErrorISE(err, "error unmarshalling CBOR")
|
||||
}
|
||||
|
||||
prov := MustProvisionerFromContext(ctx)
|
||||
if !prov.IsAttestationFormatEnabled(ctx, provisioner.ACMEAttestationFormat(att.Format)) {
|
||||
return storeError(ctx, db, ch, true,
|
||||
NewError(ErrorBadAttestationStatementType, "attestation format %q is not enabled", att.Format))
|
||||
}
|
||||
|
||||
switch att.Format {
|
||||
case "apple":
|
||||
data, err := doAppleAttestationFormat(ctx, prov, ch, &att)
|
||||
if err != nil {
|
||||
var acmeError *Error
|
||||
if errors.As(err, &acmeError) {
|
||||
if acmeError.Status == 500 {
|
||||
return acmeError
|
||||
}
|
||||
return storeError(ctx, db, ch, true, acmeError)
|
||||
}
|
||||
return WrapErrorISE(err, "error validating attestation")
|
||||
}
|
||||
// Validate nonce with SHA-256 of the token.
|
||||
if len(data.Nonce) != 0 {
|
||||
sum := sha256.Sum256([]byte(ch.Token))
|
||||
if subtle.ConstantTimeCompare(data.Nonce, sum[:]) != 1 {
|
||||
return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "challenge token does not match"))
|
||||
}
|
||||
}
|
||||
|
||||
// Validate Apple's ClientIdentifier (Identifier.Value) with device
|
||||
// identifiers.
|
||||
//
|
||||
// Note: We might want to use an external service for this.
|
||||
if data.UDID != ch.Value && data.SerialNumber != ch.Value {
|
||||
return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "permanent identifier does not match"))
|
||||
}
|
||||
|
||||
// Update attestation key fingerprint to compare against the CSR
|
||||
az.Fingerprint = data.Fingerprint
|
||||
case "step":
|
||||
data, err := doStepAttestationFormat(ctx, prov, ch, jwk, &att)
|
||||
if err != nil {
|
||||
var acmeError *Error
|
||||
if errors.As(err, &acmeError) {
|
||||
if acmeError.Status == 500 {
|
||||
return acmeError
|
||||
}
|
||||
return storeError(ctx, db, ch, true, acmeError)
|
||||
}
|
||||
return WrapErrorISE(err, "error validating attestation")
|
||||
}
|
||||
|
||||
// Validate the YubiKey serial number from the attestation
|
||||
// certificate with the challenged Order value.
|
||||
//
|
||||
// Note: We might want to use an external service for this.
|
||||
if data.SerialNumber != ch.Value {
|
||||
subproblem := NewSubproblemWithIdentifier(
|
||||
ErrorMalformedType,
|
||||
Identifier{Type: "permanent-identifier", Value: ch.Value},
|
||||
"challenge identifier %q doesn't match the attested hardware identifier %q", ch.Value, data.SerialNumber,
|
||||
)
|
||||
return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "permanent identifier does not match").AddSubproblems(subproblem))
|
||||
}
|
||||
|
||||
// Update attestation key fingerprint to compare against the CSR
|
||||
az.Fingerprint = data.Fingerprint
|
||||
|
||||
case "tpm":
|
||||
data, err := doTPMAttestationFormat(ctx, prov, ch, jwk, &att)
|
||||
if err != nil {
|
||||
// TODO(hs): we should provide more details in the error reported to the client;
|
||||
// "Attestation statement cannot be verified" is VERY generic. Also holds true for the other formats.
|
||||
var acmeError *Error
|
||||
if errors.As(err, &acmeError) {
|
||||
if acmeError.Status == 500 {
|
||||
return acmeError
|
||||
}
|
||||
return storeError(ctx, db, ch, true, acmeError)
|
||||
}
|
||||
return WrapErrorISE(err, "error validating attestation")
|
||||
}
|
||||
|
||||
// TODO(hs): currently this will allow a request for which no PermanentIdentifiers have been
|
||||
// extracted from the AK certificate. This is currently the case for AK certs from the CLI, as we
|
||||
// haven't implemented a way for AK certs requested by the CLI to always contain the requested
|
||||
// PermanentIdentifier. Omitting the check below doesn't allow just any request, as the Order can
|
||||
// still fail if the challenge value isn't equal to the CSR subject.
|
||||
if len(data.PermanentIdentifiers) > 0 && !slices.Contains(data.PermanentIdentifiers, ch.Value) { // TODO(hs): add support for HardwareModuleName
|
||||
subproblem := NewSubproblemWithIdentifier(
|
||||
ErrorMalformedType,
|
||||
Identifier{Type: "permanent-identifier", Value: ch.Value},
|
||||
"challenge identifier %q doesn't match any of the attested hardware identifiers %q", ch.Value, data.PermanentIdentifiers,
|
||||
)
|
||||
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "permanent identifier does not match").AddSubproblems(subproblem))
|
||||
}
|
||||
|
||||
// Update attestation key fingerprint to compare against the CSR
|
||||
az.Fingerprint = data.Fingerprint
|
||||
default:
|
||||
return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "unexpected attestation object format"))
|
||||
}
|
||||
|
||||
// Update and store the challenge.
|
||||
ch.Status = StatusValid
|
||||
ch.Error = nil
|
||||
ch.ValidatedAt = clock.Now().Format(time.RFC3339)
|
||||
|
||||
// Store the fingerprint in the authorization.
|
||||
//
|
||||
// TODO: add method to update authorization and challenge atomically.
|
||||
if az.Fingerprint != "" {
|
||||
if err := db.UpdateAuthorization(ctx, az); err != nil {
|
||||
return WrapErrorISE(err, "error updating authorization")
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.UpdateChallenge(ctx, ch); err != nil {
|
||||
return WrapErrorISE(err, "error updating challenge")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
oidSubjectAlternativeName = asn1.ObjectIdentifier{2, 5, 29, 17}
|
||||
)
|
||||
|
||||
type tpmAttestationData struct {
|
||||
Certificate *x509.Certificate
|
||||
VerifiedChains [][]*x509.Certificate
|
||||
PermanentIdentifiers []string
|
||||
Fingerprint string
|
||||
}
|
||||
|
||||
// coseAlgorithmIdentifier models a COSEAlgorithmIdentifier.
|
||||
// Also see https://www.w3.org/TR/webauthn-2/#sctn-alg-identifier.
|
||||
type coseAlgorithmIdentifier int32
|
||||
|
||||
const (
|
||||
coseAlgES256 coseAlgorithmIdentifier = -7
|
||||
coseAlgRS256 coseAlgorithmIdentifier = -257
|
||||
)
|
||||
|
||||
func doTPMAttestationFormat(_ context.Context, prov Provisioner, ch *Challenge, jwk *jose.JSONWebKey, att *attestationObject) (*tpmAttestationData, error) {
|
||||
ver, ok := att.AttStatement["ver"].(string)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "ver not present")
|
||||
}
|
||||
if ver != "2.0" {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "version %q is not supported", ver)
|
||||
}
|
||||
|
||||
x5c, ok := att.AttStatement["x5c"].([]interface{})
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "x5c not present")
|
||||
}
|
||||
if len(x5c) == 0 {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "x5c is empty")
|
||||
}
|
||||
|
||||
akCertBytes, ok := x5c[0].([]byte)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "x5c is malformed")
|
||||
}
|
||||
akCert, err := x509.ParseCertificate(akCertBytes)
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is malformed")
|
||||
}
|
||||
|
||||
intermediates := x509.NewCertPool()
|
||||
for _, v := range x5c[1:] {
|
||||
intCertBytes, vok := v.([]byte)
|
||||
if !vok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "x5c is malformed")
|
||||
}
|
||||
intCert, err := x509.ParseCertificate(intCertBytes)
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is malformed")
|
||||
}
|
||||
intermediates.AddCert(intCert)
|
||||
}
|
||||
|
||||
// TODO(hs): this can be removed when permanent-identifier/hardware-module-name are handled correctly in
|
||||
// the stdlib in https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/x509/parser.go;drc=b5b2cf519fe332891c165077f3723ee74932a647;l=362,
|
||||
// but I doubt that will happen.
|
||||
if len(akCert.UnhandledCriticalExtensions) > 0 {
|
||||
unhandledCriticalExtensions := akCert.UnhandledCriticalExtensions[:0]
|
||||
for _, extOID := range akCert.UnhandledCriticalExtensions {
|
||||
if !extOID.Equal(oidSubjectAlternativeName) {
|
||||
// critical extensions other than the Subject Alternative Name remain unhandled
|
||||
unhandledCriticalExtensions = append(unhandledCriticalExtensions, extOID)
|
||||
}
|
||||
}
|
||||
akCert.UnhandledCriticalExtensions = unhandledCriticalExtensions
|
||||
}
|
||||
|
||||
roots, ok := prov.GetAttestationRoots()
|
||||
if !ok {
|
||||
return nil, NewErrorISE("no root CA bundle available to verify the attestation certificate")
|
||||
}
|
||||
|
||||
// verify that the AK certificate was signed by a trusted root,
|
||||
// chained to by the intermediates provided by the client. As part
|
||||
// of building the verified certificate chain, the signature over the
|
||||
// AK certificate is checked to be a valid signature of one of the
|
||||
// provided intermediates. Signatures over the intermediates are in
|
||||
// turn also verified to be valid signatures from one of the trusted
|
||||
// roots.
|
||||
verifiedChains, err := akCert.Verify(x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
Intermediates: intermediates,
|
||||
CurrentTime: time.Now().Truncate(time.Second),
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is not valid")
|
||||
}
|
||||
|
||||
// validate additional AK certificate requirements
|
||||
if err := validateAKCertificate(akCert); err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "AK certificate is not valid")
|
||||
}
|
||||
|
||||
// TODO(hs): implement revocation check; Verify() doesn't perform CRL check nor OCSP lookup.
|
||||
|
||||
sans, err := x509util.ParseSubjectAlternativeNames(akCert)
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "failed parsing AK certificate Subject Alternative Names")
|
||||
}
|
||||
|
||||
permanentIdentifiers := make([]string, len(sans.PermanentIdentifiers))
|
||||
for i, pi := range sans.PermanentIdentifiers {
|
||||
permanentIdentifiers[i] = pi.Identifier
|
||||
}
|
||||
|
||||
// extract and validate pubArea, sig, certInfo and alg properties from the request body
|
||||
pubArea, ok := att.AttStatement["pubArea"].([]byte)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "invalid pubArea in attestation statement")
|
||||
}
|
||||
if len(pubArea) == 0 {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "pubArea is empty")
|
||||
}
|
||||
|
||||
sig, ok := att.AttStatement["sig"].([]byte)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "invalid sig in attestation statement")
|
||||
}
|
||||
if len(sig) == 0 {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "sig is empty")
|
||||
}
|
||||
|
||||
certInfo, ok := att.AttStatement["certInfo"].([]byte)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "invalid certInfo in attestation statement")
|
||||
}
|
||||
if len(certInfo) == 0 {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "certInfo is empty")
|
||||
}
|
||||
|
||||
alg, ok := att.AttStatement["alg"].(int64)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "invalid alg in attestation statement")
|
||||
}
|
||||
|
||||
// only RS256 and ES256 are allowed
|
||||
coseAlg := coseAlgorithmIdentifier(alg)
|
||||
if coseAlg != coseAlgRS256 && coseAlg != coseAlgES256 {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "invalid alg %d in attestation statement", alg)
|
||||
}
|
||||
|
||||
// set the hash algorithm to use to SHA256
|
||||
hash := crypto.SHA256
|
||||
|
||||
// recreate the generated key certification parameter values and verify
|
||||
// the attested key using the public key of the AK.
|
||||
certificationParameters := &attest.CertificationParameters{
|
||||
Public: pubArea, // the public key that was attested
|
||||
CreateAttestation: certInfo, // the attested properties of the key
|
||||
CreateSignature: sig, // signature over the attested properties
|
||||
}
|
||||
verifyOpts := attest.VerifyOpts{
|
||||
Public: akCert.PublicKey, // public key of the AK that attested the key
|
||||
Hash: hash,
|
||||
}
|
||||
if err = certificationParameters.Verify(verifyOpts); err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "invalid certification parameters")
|
||||
}
|
||||
|
||||
// decode the "certInfo" data. This won't fail, as it's also done as part of Verify().
|
||||
tpmCertInfo, err := tpm2.DecodeAttestationData(certInfo)
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "failed decoding attestation data")
|
||||
}
|
||||
|
||||
keyAuth, err := KeyAuthorization(ch.Token, jwk)
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "failed creating key auth digest")
|
||||
}
|
||||
hashedKeyAuth := sha256.Sum256([]byte(keyAuth))
|
||||
|
||||
// verify the WebAuthn object contains the expect key authorization digest, which is carried
|
||||
// within the encoded `certInfo` property of the attestation statement.
|
||||
if subtle.ConstantTimeCompare(hashedKeyAuth[:], []byte(tpmCertInfo.ExtraData)) == 0 {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "key authorization does not match")
|
||||
}
|
||||
|
||||
// decode the (attested) public key and determine its fingerprint. This won't fail, as it's also done as part of Verify().
|
||||
pub, err := tpm2.DecodePublic(pubArea)
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "failed decoding pubArea")
|
||||
}
|
||||
|
||||
publicKey, err := pub.Key()
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "failed getting public key")
|
||||
}
|
||||
|
||||
data := &tpmAttestationData{
|
||||
Certificate: akCert,
|
||||
VerifiedChains: verifiedChains,
|
||||
PermanentIdentifiers: permanentIdentifiers,
|
||||
}
|
||||
|
||||
if data.Fingerprint, err = keyutil.Fingerprint(publicKey); err != nil {
|
||||
return nil, WrapErrorISE(err, "error calculating key fingerprint")
|
||||
}
|
||||
|
||||
// TODO(hs): pass more attestation data, so that that can be used/recorded too?
|
||||
return data, nil
|
||||
}
|
||||
|
||||
var (
|
||||
oidExtensionExtendedKeyUsage = asn1.ObjectIdentifier{2, 5, 29, 37}
|
||||
oidTCGKpAIKCertificate = asn1.ObjectIdentifier{2, 23, 133, 8, 3}
|
||||
)
|
||||
|
||||
// validateAKCertifiate validates the X.509 AK certificate to be
|
||||
// in accordance with the required properties. The requirements come from:
|
||||
// https://www.w3.org/TR/webauthn-2/#sctn-tpm-cert-requirements.
|
||||
//
|
||||
// - Version MUST be set to 3.
|
||||
// - Subject field MUST be set to empty.
|
||||
// - The Subject Alternative Name extension MUST be set as defined
|
||||
// in [TPMv2-EK-Profile] section 3.2.9.
|
||||
// - The Extended Key Usage extension MUST contain the OID 2.23.133.8.3
|
||||
// ("joint-iso-itu-t(2) internationalorganizations(23) 133 tcg-kp(8) tcg-kp-AIKCertificate(3)").
|
||||
// - The Basic Constraints extension MUST have the CA component set to false.
|
||||
// - An Authority Information Access (AIA) extension with entry id-ad-ocsp
|
||||
// and a CRL Distribution Point extension [RFC5280] are both OPTIONAL as
|
||||
// the status of many attestation certificates is available through metadata
|
||||
// services. See, for example, the FIDO Metadata Service.
|
||||
func validateAKCertificate(c *x509.Certificate) error {
|
||||
if c.Version != 3 {
|
||||
return fmt.Errorf("AK certificate has invalid version %d; only version 3 is allowed", c.Version)
|
||||
}
|
||||
if c.Subject.String() != "" {
|
||||
return fmt.Errorf("AK certificate subject must be empty; got %q", c.Subject)
|
||||
}
|
||||
if c.IsCA {
|
||||
return errors.New("AK certificate must not be a CA")
|
||||
}
|
||||
if err := validateAKCertificateExtendedKeyUsage(c); err != nil {
|
||||
return err
|
||||
}
|
||||
return validateAKCertificateSubjectAlternativeNames(c)
|
||||
}
|
||||
|
||||
// validateAKCertificateSubjectAlternativeNames checks if the AK certificate
|
||||
// has TPM hardware details set.
|
||||
func validateAKCertificateSubjectAlternativeNames(c *x509.Certificate) error {
|
||||
sans, err := x509util.ParseSubjectAlternativeNames(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed parsing AK certificate Subject Alternative Names: %w", err)
|
||||
}
|
||||
|
||||
details := sans.TPMHardwareDetails
|
||||
manufacturer, model, version := details.Manufacturer, details.Model, details.Version
|
||||
|
||||
switch {
|
||||
case manufacturer == "":
|
||||
return errors.New("missing TPM manufacturer")
|
||||
case model == "":
|
||||
return errors.New("missing TPM model")
|
||||
case version == "":
|
||||
return errors.New("missing TPM version")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateAKCertificateExtendedKeyUsage checks if the AK certificate
|
||||
// has the "tcg-kp-AIKCertificate" Extended Key Usage set.
|
||||
func validateAKCertificateExtendedKeyUsage(c *x509.Certificate) error {
|
||||
var (
|
||||
valid = false
|
||||
ekus []asn1.ObjectIdentifier
|
||||
)
|
||||
for _, ext := range c.Extensions {
|
||||
if ext.Id.Equal(oidExtensionExtendedKeyUsage) {
|
||||
if _, err := asn1.Unmarshal(ext.Value, &ekus); err != nil || !ekus[0].Equal(oidTCGKpAIKCertificate) {
|
||||
return errors.New("AK certificate is missing Extended Key Usage value tcg-kp-AIKCertificate (2.23.133.8.3)")
|
||||
}
|
||||
valid = true
|
||||
}
|
||||
}
|
||||
|
||||
if !valid {
|
||||
return errors.New("AK certificate is missing Extended Key Usage extension")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apple Enterprise Attestation Root CA from
|
||||
// https://www.apple.com/certificateauthority/private/
|
||||
const appleEnterpriseAttestationRootCA = `-----BEGIN CERTIFICATE-----
|
||||
MIICJDCCAamgAwIBAgIUQsDCuyxyfFxeq/bxpm8frF15hzcwCgYIKoZIzj0EAwMw
|
||||
UTEtMCsGA1UEAwwkQXBwbGUgRW50ZXJwcmlzZSBBdHRlc3RhdGlvbiBSb290IENB
|
||||
MRMwEQYDVQQKDApBcHBsZSBJbmMuMQswCQYDVQQGEwJVUzAeFw0yMjAyMTYxOTAx
|
||||
MjRaFw00NzAyMjAwMDAwMDBaMFExLTArBgNVBAMMJEFwcGxlIEVudGVycHJpc2Ug
|
||||
QXR0ZXN0YXRpb24gUm9vdCBDQTETMBEGA1UECgwKQXBwbGUgSW5jLjELMAkGA1UE
|
||||
BhMCVVMwdjAQBgcqhkjOPQIBBgUrgQQAIgNiAAT6Jigq+Ps9Q4CoT8t8q+UnOe2p
|
||||
oT9nRaUfGhBTbgvqSGXPjVkbYlIWYO+1zPk2Sz9hQ5ozzmLrPmTBgEWRcHjA2/y7
|
||||
7GEicps9wn2tj+G89l3INNDKETdxSPPIZpPj8VmjQjBAMA8GA1UdEwEB/wQFMAMB
|
||||
Af8wHQYDVR0OBBYEFPNqTQGd8muBpV5du+UIbVbi+d66MA4GA1UdDwEB/wQEAwIB
|
||||
BjAKBggqhkjOPQQDAwNpADBmAjEA1xpWmTLSpr1VH4f8Ypk8f3jMUKYz4QPG8mL5
|
||||
8m9sX/b2+eXpTv2pH4RZgJjucnbcAjEA4ZSB6S45FlPuS/u4pTnzoz632rA+xW/T
|
||||
ZwFEh9bhKjJ+5VQ9/Do1os0u3LEkgN/r
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
var (
|
||||
oidAppleSerialNumber = asn1.ObjectIdentifier{1, 2, 840, 113635, 100, 8, 9, 1}
|
||||
oidAppleUniqueDeviceIdentifier = asn1.ObjectIdentifier{1, 2, 840, 113635, 100, 8, 9, 2}
|
||||
oidAppleSecureEnclaveProcessorOSVersion = asn1.ObjectIdentifier{1, 2, 840, 113635, 100, 8, 10, 2}
|
||||
oidAppleNonce = asn1.ObjectIdentifier{1, 2, 840, 113635, 100, 8, 11, 1}
|
||||
)
|
||||
|
||||
type appleAttestationData struct {
|
||||
Nonce []byte
|
||||
SerialNumber string
|
||||
UDID string
|
||||
SEPVersion string
|
||||
Certificate *x509.Certificate
|
||||
Fingerprint string
|
||||
}
|
||||
|
||||
func doAppleAttestationFormat(_ context.Context, prov Provisioner, _ *Challenge, att *attestationObject) (*appleAttestationData, error) {
|
||||
// Use configured or default attestation roots if none is configured.
|
||||
roots, ok := prov.GetAttestationRoots()
|
||||
if !ok {
|
||||
root, err := pemutil.ParseCertificate([]byte(appleEnterpriseAttestationRootCA))
|
||||
if err != nil {
|
||||
return nil, WrapErrorISE(err, "error parsing apple enterprise ca")
|
||||
}
|
||||
roots = x509.NewCertPool()
|
||||
roots.AddCert(root)
|
||||
}
|
||||
|
||||
x5c, ok := att.AttStatement["x5c"].([]interface{})
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "x5c not present")
|
||||
}
|
||||
if len(x5c) == 0 {
|
||||
return nil, NewError(ErrorRejectedIdentifierType, "x5c is empty")
|
||||
}
|
||||
|
||||
der, ok := x5c[0].([]byte)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "x5c is malformed")
|
||||
}
|
||||
leaf, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is malformed")
|
||||
}
|
||||
|
||||
intermediates := x509.NewCertPool()
|
||||
for _, v := range x5c[1:] {
|
||||
der, ok = v.([]byte)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "x5c is malformed")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is malformed")
|
||||
}
|
||||
intermediates.AddCert(cert)
|
||||
}
|
||||
|
||||
if _, err := leaf.Verify(x509.VerifyOptions{
|
||||
Intermediates: intermediates,
|
||||
Roots: roots,
|
||||
CurrentTime: time.Now().Truncate(time.Second),
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
|
||||
}); err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is not valid")
|
||||
}
|
||||
|
||||
data := &appleAttestationData{
|
||||
Certificate: leaf,
|
||||
}
|
||||
if data.Fingerprint, err = keyutil.Fingerprint(leaf.PublicKey); err != nil {
|
||||
return nil, WrapErrorISE(err, "error calculating key fingerprint")
|
||||
}
|
||||
for _, ext := range leaf.Extensions {
|
||||
switch {
|
||||
case ext.Id.Equal(oidAppleSerialNumber):
|
||||
data.SerialNumber = string(ext.Value)
|
||||
case ext.Id.Equal(oidAppleUniqueDeviceIdentifier):
|
||||
data.UDID = string(ext.Value)
|
||||
case ext.Id.Equal(oidAppleSecureEnclaveProcessorOSVersion):
|
||||
data.SEPVersion = string(ext.Value)
|
||||
case ext.Id.Equal(oidAppleNonce):
|
||||
data.Nonce = ext.Value
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Yubico PIV Root CA Serial 263751
|
||||
// https://developers.yubico.com/PIV/Introduction/piv-attestation-ca.pem
|
||||
const yubicoPIVRootCA = `-----BEGIN CERTIFICATE-----
|
||||
MIIDFzCCAf+gAwIBAgIDBAZHMA0GCSqGSIb3DQEBCwUAMCsxKTAnBgNVBAMMIFl1
|
||||
YmljbyBQSVYgUm9vdCBDQSBTZXJpYWwgMjYzNzUxMCAXDTE2MDMxNDAwMDAwMFoY
|
||||
DzIwNTIwNDE3MDAwMDAwWjArMSkwJwYDVQQDDCBZdWJpY28gUElWIFJvb3QgQ0Eg
|
||||
U2VyaWFsIDI2Mzc1MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMN2
|
||||
cMTNR6YCdcTFRxuPy31PabRn5m6pJ+nSE0HRWpoaM8fc8wHC+Tmb98jmNvhWNE2E
|
||||
ilU85uYKfEFP9d6Q2GmytqBnxZsAa3KqZiCCx2LwQ4iYEOb1llgotVr/whEpdVOq
|
||||
joU0P5e1j1y7OfwOvky/+AXIN/9Xp0VFlYRk2tQ9GcdYKDmqU+db9iKwpAzid4oH
|
||||
BVLIhmD3pvkWaRA2H3DA9t7H/HNq5v3OiO1jyLZeKqZoMbPObrxqDg+9fOdShzgf
|
||||
wCqgT3XVmTeiwvBSTctyi9mHQfYd2DwkaqxRnLbNVyK9zl+DzjSGp9IhVPiVtGet
|
||||
X02dxhQnGS7K6BO0Qe8CAwEAAaNCMEAwHQYDVR0OBBYEFMpfyvLEojGc6SJf8ez0
|
||||
1d8Cv4O/MA8GA1UdEwQIMAYBAf8CAQEwDgYDVR0PAQH/BAQDAgEGMA0GCSqGSIb3
|
||||
DQEBCwUAA4IBAQBc7Ih8Bc1fkC+FyN1fhjWioBCMr3vjneh7MLbA6kSoyWF70N3s
|
||||
XhbXvT4eRh0hvxqvMZNjPU/VlRn6gLVtoEikDLrYFXN6Hh6Wmyy1GTnspnOvMvz2
|
||||
lLKuym9KYdYLDgnj3BeAvzIhVzzYSeU77/Cupofj093OuAswW0jYvXsGTyix6B3d
|
||||
bW5yWvyS9zNXaqGaUmP3U9/b6DlHdDogMLu3VLpBB9bm5bjaKWWJYgWltCVgUbFq
|
||||
Fqyi4+JE014cSgR57Jcu3dZiehB6UtAPgad9L5cNvua/IWRmm+ANy3O2LH++Pyl8
|
||||
SREzU8onbBsjMg9QDiSf5oJLKvd/Ren+zGY7
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
// Serial number of the YubiKey, encoded as an integer.
|
||||
// https://developers.yubico.com/PIV/Introduction/PIV_attestation.html
|
||||
var oidYubicoSerialNumber = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 41482, 3, 7}
|
||||
|
||||
type stepAttestationData struct {
|
||||
Certificate *x509.Certificate
|
||||
SerialNumber string
|
||||
Fingerprint string
|
||||
}
|
||||
|
||||
func doStepAttestationFormat(_ context.Context, prov Provisioner, ch *Challenge, jwk *jose.JSONWebKey, att *attestationObject) (*stepAttestationData, error) {
|
||||
// Use configured or default attestation roots if none is configured.
|
||||
roots, ok := prov.GetAttestationRoots()
|
||||
if !ok {
|
||||
root, err := pemutil.ParseCertificate([]byte(yubicoPIVRootCA))
|
||||
if err != nil {
|
||||
return nil, WrapErrorISE(err, "error parsing root ca")
|
||||
}
|
||||
roots = x509.NewCertPool()
|
||||
roots.AddCert(root)
|
||||
}
|
||||
|
||||
// Extract x5c and verify certificate
|
||||
x5c, ok := att.AttStatement["x5c"].([]interface{})
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "x5c not present")
|
||||
}
|
||||
if len(x5c) == 0 {
|
||||
return nil, NewError(ErrorRejectedIdentifierType, "x5c is empty")
|
||||
}
|
||||
der, ok := x5c[0].([]byte)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "x5c is malformed")
|
||||
}
|
||||
leaf, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is malformed")
|
||||
}
|
||||
intermediates := x509.NewCertPool()
|
||||
for _, v := range x5c[1:] {
|
||||
der, ok = v.([]byte)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "x5c is malformed")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is malformed")
|
||||
}
|
||||
intermediates.AddCert(cert)
|
||||
}
|
||||
if _, err := leaf.Verify(x509.VerifyOptions{
|
||||
Intermediates: intermediates,
|
||||
Roots: roots,
|
||||
CurrentTime: time.Now().Truncate(time.Second),
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
|
||||
}); err != nil {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is not valid")
|
||||
}
|
||||
|
||||
// Verify proof of possession of private key validating the key
|
||||
// authorization. Per recommendation at
|
||||
// https://w3c.github.io/webauthn/#sctn-signature-attestation-types the
|
||||
// signature is CBOR-encoded.
|
||||
var sig []byte
|
||||
csig, ok := att.AttStatement["sig"].([]byte)
|
||||
if !ok {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "sig not present")
|
||||
}
|
||||
if err := cbor.Unmarshal(csig, &sig); err != nil {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "sig is malformed")
|
||||
}
|
||||
keyAuth, err := KeyAuthorization(ch.Token, jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch pub := leaf.PublicKey.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
if pub.Curve != elliptic.P256() {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "unsupported elliptic curve %s", pub.Curve)
|
||||
}
|
||||
sum := sha256.Sum256([]byte(keyAuth))
|
||||
if !ecdsa.VerifyASN1(pub, sum[:], sig) {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "failed to validate signature")
|
||||
}
|
||||
case *rsa.PublicKey:
|
||||
sum := sha256.Sum256([]byte(keyAuth))
|
||||
if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sig); err != nil {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "failed to validate signature")
|
||||
}
|
||||
case ed25519.PublicKey:
|
||||
if !ed25519.Verify(pub, []byte(keyAuth), sig) {
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "failed to validate signature")
|
||||
}
|
||||
default:
|
||||
return nil, NewError(ErrorBadAttestationStatementType, "unsupported public key type %T", pub)
|
||||
}
|
||||
|
||||
// Parse attestation data:
|
||||
// TODO(mariano): add support for other extensions.
|
||||
data := &stepAttestationData{
|
||||
Certificate: leaf,
|
||||
}
|
||||
if data.Fingerprint, err = keyutil.Fingerprint(leaf.PublicKey); err != nil {
|
||||
return nil, WrapErrorISE(err, "error calculating key fingerprint")
|
||||
}
|
||||
for _, ext := range leaf.Extensions {
|
||||
if !ext.Id.Equal(oidYubicoSerialNumber) {
|
||||
continue
|
||||
}
|
||||
var serialNumber int
|
||||
rest, err := asn1.Unmarshal(ext.Value, &serialNumber)
|
||||
if err != nil || len(rest) > 0 {
|
||||
return nil, WrapError(ErrorBadAttestationStatementType, err, "error parsing serial number")
|
||||
}
|
||||
data.SerialNumber = strconv.Itoa(serialNumber)
|
||||
break
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// serverName determines the SNI HostName to set based on an acme.Challenge
|
||||
// for TLS-ALPN-01 challenges RFC8738 states that, if HostName is an IP, it
|
||||
// should be the ARPA address https://datatracker.ietf.org/doc/html/rfc8738#section-6.
|
||||
|
@ -1088,10 +331,10 @@ func uitoa(val uint) string {
|
|||
var buf [20]byte // big enough for 64bit value base 10
|
||||
i := len(buf) - 1
|
||||
for val >= 10 {
|
||||
v := val / 10
|
||||
buf[i] = byte('0' + val - v*10)
|
||||
q := val / 10
|
||||
buf[i] = byte('0' + val - q*10)
|
||||
i--
|
||||
val = v
|
||||
val = q
|
||||
}
|
||||
// val < 10
|
||||
buf[i] = byte('0' + val)
|
||||
|
@ -1122,3 +365,14 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type httpGetter func(string) (*http.Response, error)
|
||||
type lookupTxt func(string) ([]string, error)
|
||||
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
|
||||
// ValidateChallengeOptions are ACME challenge validator functions.
|
||||
type ValidateChallengeOptions struct {
|
||||
HTTPGet httpGetter
|
||||
LookupTxt lookupTxt
|
||||
TLSDial tlsDialer
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,860 +0,0 @@
|
|||
//go:build tpmsimulator
|
||||
// +build tpmsimulator
|
||||
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/smallstep/go-attestation/attest"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/minica"
|
||||
"go.step.sm/crypto/tpm"
|
||||
"go.step.sm/crypto/tpm/simulator"
|
||||
tpmstorage "go.step.sm/crypto/tpm/storage"
|
||||
"go.step.sm/crypto/x509util"
|
||||
)
|
||||
|
||||
func newSimulatedTPM(t *testing.T) *tpm.TPM {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
tpm, err := tpm.New(withSimulator(t), tpm.WithStore(tpmstorage.NewDirstore(tmpDir))) // TODO: provide in-memory storage implementation instead
|
||||
require.NoError(t, err)
|
||||
return tpm
|
||||
}
|
||||
|
||||
func withSimulator(t *testing.T) tpm.NewTPMOption {
|
||||
t.Helper()
|
||||
var sim simulator.Simulator
|
||||
t.Cleanup(func() {
|
||||
if sim == nil {
|
||||
return
|
||||
}
|
||||
err := sim.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
sim, err := simulator.New()
|
||||
require.NoError(t, err)
|
||||
err = sim.Open()
|
||||
require.NoError(t, err)
|
||||
return tpm.WithSimulator(sim)
|
||||
}
|
||||
|
||||
func generateKeyID(t *testing.T, pub crypto.PublicKey) []byte {
|
||||
t.Helper()
|
||||
b, err := x509.MarshalPKIXPublicKey(pub)
|
||||
require.NoError(t, err)
|
||||
hash := sha256.Sum256(b)
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
func mustAttestTPM(t *testing.T, keyAuthorization string, permanentIdentifiers []string) ([]byte, crypto.Signer, *x509.Certificate) {
|
||||
t.Helper()
|
||||
aca, err := minica.New(
|
||||
minica.WithName("TPM Testing"),
|
||||
minica.WithGetSignerFunc(
|
||||
func() (crypto.Signer, error) {
|
||||
return keyutil.GenerateSigner("RSA", "", 2048)
|
||||
},
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// prepare simulated TPM and create an AK
|
||||
stpm := newSimulatedTPM(t)
|
||||
eks, err := stpm.GetEKs(context.Background())
|
||||
require.NoError(t, err)
|
||||
ak, err := stpm.CreateAK(context.Background(), "first-ak")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ak)
|
||||
|
||||
// extract the AK public key // TODO(hs): replace this when there's a simpler method to get the AK public key (e.g. ak.Public())
|
||||
ap, err := ak.AttestationParameters(context.Background())
|
||||
require.NoError(t, err)
|
||||
akp, err := attest.ParseAKPublic(attest.TPMVersion20, ap.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create template and sign certificate for the AK public key
|
||||
keyID := generateKeyID(t, eks[0].Public())
|
||||
template := &x509.Certificate{
|
||||
PublicKey: akp.Public,
|
||||
IsCA: false,
|
||||
UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate},
|
||||
}
|
||||
sans := []x509util.SubjectAlternativeName{}
|
||||
uris := []*url.URL{{Scheme: "urn", Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID)}}
|
||||
for _, pi := range permanentIdentifiers {
|
||||
sans = append(sans, x509util.SubjectAlternativeName{
|
||||
Type: x509util.PermanentIdentifierType,
|
||||
Value: pi,
|
||||
})
|
||||
}
|
||||
asn1Value := []byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMModel, "SLB 9670 TPM2.0", oidTPMVersion, "7.55"))
|
||||
sans = append(sans, x509util.SubjectAlternativeName{
|
||||
Type: x509util.DirectoryNameType,
|
||||
ASN1Value: asn1Value,
|
||||
})
|
||||
ext, err := createSubjectAltNameExtension(nil, nil, nil, uris, sans, true)
|
||||
require.NoError(t, err)
|
||||
ext.Set(template)
|
||||
akCert, err := aca.Sign(template)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, akCert)
|
||||
|
||||
// create a new key attested by the AK, while including
|
||||
// the key authorization bytes as qualifying data.
|
||||
keyAuthSum := sha256.Sum256([]byte(keyAuthorization))
|
||||
config := tpm.AttestKeyConfig{
|
||||
Algorithm: "RSA",
|
||||
Size: 2048,
|
||||
QualifyingData: keyAuthSum[:],
|
||||
}
|
||||
key, err := stpm.AttestKey(context.Background(), "first-ak", "first-key", config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
require.Equal(t, "first-key", key.Name())
|
||||
require.NotEqual(t, 0, len(key.Data()))
|
||||
require.Equal(t, "first-ak", key.AttestedBy())
|
||||
require.True(t, key.WasAttested())
|
||||
require.True(t, key.WasAttestedBy(ak))
|
||||
|
||||
signer, err := key.Signer(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// prepare the attestation object with the AK certificate chain,
|
||||
// the attested key, its metadata and the signature signed by the
|
||||
// AK.
|
||||
params, err := key.CertificationParameters(context.Background())
|
||||
require.NoError(t, err)
|
||||
attObj, err := cbor.Marshal(struct {
|
||||
Format string `json:"fmt"`
|
||||
AttStatement map[string]interface{} `json:"attStmt,omitempty"`
|
||||
}{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// marshal the ACME payload
|
||||
payload, err := json.Marshal(struct {
|
||||
AttObj string `json:"attObj"`
|
||||
}{
|
||||
AttObj: base64.RawURLEncoding.EncodeToString(attObj),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return payload, signer, aca.Root
|
||||
}
|
||||
|
||||
func Test_deviceAttest01ValidateWithTPMSimulator(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
ch *Challenge
|
||||
db DB
|
||||
jwk *jose.JSONWebKey
|
||||
payload []byte
|
||||
}
|
||||
type test struct {
|
||||
args args
|
||||
wantErr *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok/doTPMAttestationFormat-storeError": func(t *testing.T) test {
|
||||
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
|
||||
payload, _, root := mustAttestTPM(t, keyAuth, nil) // TODO: value(s) for AK cert?
|
||||
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
|
||||
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
|
||||
|
||||
// parse payload, set invalid "ver", remarshal
|
||||
var p payloadType
|
||||
err := json.Unmarshal(payload, &p)
|
||||
require.NoError(t, err)
|
||||
attObj, err := base64.RawURLEncoding.DecodeString(p.AttObj)
|
||||
require.NoError(t, err)
|
||||
att := attestationObject{}
|
||||
err = cbor.Unmarshal(attObj, &att)
|
||||
require.NoError(t, err)
|
||||
att.AttStatement["ver"] = "bogus"
|
||||
attObj, err = cbor.Marshal(struct {
|
||||
Format string `json:"fmt"`
|
||||
AttStatement map[string]interface{} `json:"attStmt,omitempty"`
|
||||
}{
|
||||
Format: "tpm",
|
||||
AttStatement: att.AttStatement,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
payload, err = json.Marshal(struct {
|
||||
AttObj string `json:"attObj"`
|
||||
}{
|
||||
AttObj: base64.RawURLEncoding.EncodeToString(attObj),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
jwk: jwk,
|
||||
ch: &Challenge{
|
||||
ID: "chID",
|
||||
AuthorizationID: "azID",
|
||||
Token: "token",
|
||||
Type: "device-attest-01",
|
||||
Status: StatusPending,
|
||||
Value: "device.id.12345678",
|
||||
},
|
||||
payload: payload,
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
assert.Equal(t, "azID", id)
|
||||
return &Authorization{ID: "azID"}, nil
|
||||
},
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
assert.Equal(t, "chID", updch.ID)
|
||||
assert.Equal(t, "token", updch.Token)
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "device.id.12345678", updch.Value)
|
||||
|
||||
err := NewError(ErrorBadAttestationStatementType, `version "bogus" is not supported`)
|
||||
|
||||
assert.EqualError(t, updch.Error.Err, err.Err.Error())
|
||||
assert.Equal(t, err.Type, updch.Error.Type)
|
||||
assert.Equal(t, err.Detail, updch.Error.Detail)
|
||||
assert.Equal(t, err.Status, updch.Error.Status)
|
||||
assert.Equal(t, err.Subproblems, updch.Error.Subproblems)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
}
|
||||
},
|
||||
"ok with invalid PermanentIdentifier SAN": func(t *testing.T) test {
|
||||
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
|
||||
payload, _, root := mustAttestTPM(t, keyAuth, []string{"device.id.12345678"}) // TODO: value(s) for AK cert?
|
||||
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
|
||||
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
|
||||
return test{
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
jwk: jwk,
|
||||
ch: &Challenge{
|
||||
ID: "chID",
|
||||
AuthorizationID: "azID",
|
||||
Token: "token",
|
||||
Type: "device-attest-01",
|
||||
Status: StatusPending,
|
||||
Value: "device.id.99999999",
|
||||
},
|
||||
payload: payload,
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
assert.Equal(t, "azID", id)
|
||||
return &Authorization{ID: "azID"}, nil
|
||||
},
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
assert.Equal(t, "chID", updch.ID)
|
||||
assert.Equal(t, "token", updch.Token)
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "device.id.99999999", updch.Value)
|
||||
|
||||
err := NewError(ErrorRejectedIdentifierType, `permanent identifier does not match`).
|
||||
AddSubproblems(NewSubproblemWithIdentifier(
|
||||
ErrorMalformedType,
|
||||
Identifier{Type: "permanent-identifier", Value: "device.id.99999999"},
|
||||
`challenge identifier "device.id.99999999" doesn't match any of the attested hardware identifiers ["device.id.12345678"]`,
|
||||
))
|
||||
|
||||
assert.EqualError(t, updch.Error.Err, err.Err.Error())
|
||||
assert.Equal(t, err.Type, updch.Error.Type)
|
||||
assert.Equal(t, err.Detail, updch.Error.Detail)
|
||||
assert.Equal(t, err.Status, updch.Error.Status)
|
||||
assert.Equal(t, err.Subproblems, updch.Error.Subproblems)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
|
||||
payload, signer, root := mustAttestTPM(t, keyAuth, nil) // TODO: value(s) for AK cert?
|
||||
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
|
||||
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
|
||||
return test{
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
jwk: jwk,
|
||||
ch: &Challenge{
|
||||
ID: "chID",
|
||||
AuthorizationID: "azID",
|
||||
Token: "token",
|
||||
Type: "device-attest-01",
|
||||
Status: StatusPending,
|
||||
Value: "device.id.12345678",
|
||||
},
|
||||
payload: payload,
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
assert.Equal(t, "azID", id)
|
||||
return &Authorization{ID: "azID"}, nil
|
||||
},
|
||||
MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error {
|
||||
fingerprint, err := keyutil.Fingerprint(signer.Public())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "azID", az.ID)
|
||||
assert.Equal(t, fingerprint, az.Fingerprint)
|
||||
return nil
|
||||
},
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
assert.Equal(t, "chID", updch.ID)
|
||||
assert.Equal(t, "token", updch.Token)
|
||||
assert.Equal(t, StatusValid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "device.id.12345678", updch.Value)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
}
|
||||
},
|
||||
"ok with PermanentIdentifier SAN": func(t *testing.T) test {
|
||||
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
|
||||
payload, signer, root := mustAttestTPM(t, keyAuth, []string{"device.id.12345678"}) // TODO: value(s) for AK cert?
|
||||
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
|
||||
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
|
||||
return test{
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
jwk: jwk,
|
||||
ch: &Challenge{
|
||||
ID: "chID",
|
||||
AuthorizationID: "azID",
|
||||
Token: "token",
|
||||
Type: "device-attest-01",
|
||||
Status: StatusPending,
|
||||
Value: "device.id.12345678",
|
||||
},
|
||||
payload: payload,
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
assert.Equal(t, "azID", id)
|
||||
return &Authorization{ID: "azID"}, nil
|
||||
},
|
||||
MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error {
|
||||
fingerprint, err := keyutil.Fingerprint(signer.Public())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "azID", az.ID)
|
||||
assert.Equal(t, fingerprint, az.Fingerprint)
|
||||
return nil
|
||||
},
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
assert.Equal(t, "chID", updch.ID)
|
||||
assert.Equal(t, "token", updch.Token)
|
||||
assert.Equal(t, StatusValid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "device.id.12345678", updch.Value)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
if err := deviceAttest01Validate(tc.args.ctx, tc.args.ch, tc.args.db, tc.args.jwk, tc.args.payload); err != nil {
|
||||
assert.Error(t, tc.wantErr)
|
||||
assert.EqualError(t, err, tc.wantErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.Nil(t, tc.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newBadAttestationStatementError(msg string) *Error {
|
||||
return &Error{
|
||||
Type: "urn:ietf:params:acme:error:badAttestationStatement",
|
||||
Status: 400,
|
||||
Err: errors.New(msg),
|
||||
}
|
||||
}
|
||||
|
||||
func newInternalServerError(msg string) *Error {
|
||||
return &Error{
|
||||
Type: "urn:ietf:params:acme:error:serverInternal",
|
||||
Status: 500,
|
||||
Err: errors.New(msg),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
oidPermanentIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}
|
||||
oidHardwareModuleNameIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 4}
|
||||
)
|
||||
|
||||
func Test_doTPMAttestationFormat(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
aca, err := minica.New(
|
||||
minica.WithName("TPM Testing"),
|
||||
minica.WithGetSignerFunc(
|
||||
func() (crypto.Signer, error) {
|
||||
return keyutil.GenerateSigner("RSA", "", 2048)
|
||||
},
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
acaRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: aca.Root.Raw})
|
||||
|
||||
// prepare simulated TPM and create an AK
|
||||
stpm := newSimulatedTPM(t)
|
||||
eks, err := stpm.GetEKs(context.Background())
|
||||
require.NoError(t, err)
|
||||
ak, err := stpm.CreateAK(context.Background(), "first-ak")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ak)
|
||||
|
||||
// extract the AK public key // TODO(hs): replace this when there's a simpler method to get the AK public key (e.g. ak.Public())
|
||||
ap, err := ak.AttestationParameters(context.Background())
|
||||
require.NoError(t, err)
|
||||
akp, err := attest.ParseAKPublic(attest.TPMVersion20, ap.Public)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create template and sign certificate for the AK public key
|
||||
keyID := generateKeyID(t, eks[0].Public())
|
||||
template := &x509.Certificate{
|
||||
PublicKey: akp.Public,
|
||||
IsCA: false,
|
||||
UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate},
|
||||
}
|
||||
sans := []x509util.SubjectAlternativeName{}
|
||||
uris := []*url.URL{{Scheme: "urn", Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID)}}
|
||||
asn1Value := []byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMModel, "SLB 9670 TPM2.0", oidTPMVersion, "7.55"))
|
||||
sans = append(sans, x509util.SubjectAlternativeName{
|
||||
Type: x509util.DirectoryNameType,
|
||||
ASN1Value: asn1Value,
|
||||
})
|
||||
ext, err := createSubjectAltNameExtension(nil, nil, nil, uris, sans, true)
|
||||
require.NoError(t, err)
|
||||
ext.Set(template)
|
||||
akCert, err := aca.Sign(template)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, akCert)
|
||||
|
||||
invalidTemplate := &x509.Certificate{
|
||||
PublicKey: akp.Public,
|
||||
IsCA: false,
|
||||
UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate},
|
||||
}
|
||||
invalidAKCert, err := aca.Sign(invalidTemplate)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, invalidAKCert)
|
||||
|
||||
// generate a JWK and the key authorization value
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
require.NoError(t, err)
|
||||
keyAuthorization, err := KeyAuthorization("token", jwk)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create a new key attested by the AK, while including
|
||||
// the key authorization bytes as qualifying data.
|
||||
keyAuthSum := sha256.Sum256([]byte(keyAuthorization))
|
||||
config := tpm.AttestKeyConfig{
|
||||
Algorithm: "RSA",
|
||||
Size: 2048,
|
||||
QualifyingData: keyAuthSum[:],
|
||||
}
|
||||
key, err := stpm.AttestKey(context.Background(), "first-ak", "first-key", config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
params, err := key.CertificationParameters(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
signer, err := key.Signer(context.Background())
|
||||
require.NoError(t, err)
|
||||
fingerprint, err := keyutil.Fingerprint(signer.Public())
|
||||
require.NoError(t, err)
|
||||
|
||||
// attest another key and get its certification parameters
|
||||
anotherKey, err := stpm.AttestKey(context.Background(), "first-ak", "another-key", config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
anotherKeyParams, err := anotherKey.CertificationParameters(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
prov Provisioner
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
att *attestationObject
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *tpmAttestationData
|
||||
expErr *Error
|
||||
}{
|
||||
{"ok", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, nil},
|
||||
{"fail ver not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("ver not present")},
|
||||
{"fail ver type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": []interface{}{},
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("ver not present")},
|
||||
{"fail bogus ver", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "bogus",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError(`version "bogus" is not supported`)},
|
||||
{"fail x5c not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c not present")},
|
||||
{"fail x5c type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": [][]byte{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c not present")},
|
||||
{"fail x5c empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is empty")},
|
||||
{"fail leaf type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "step",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{"leaf", aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is malformed")},
|
||||
{"fail leaf parse", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "step",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw[:100], aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is malformed: x509: malformed certificate")},
|
||||
{"fail intermediate type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "step",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, "intermediate"},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is malformed")},
|
||||
{"fail intermediate parse", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "step",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw[:100]},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is malformed: x509: malformed certificate")},
|
||||
{"fail roots", args{ctx, mustAttestationProvisioner(t, nil), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newInternalServerError("no root CA bundle available to verify the attestation certificate")},
|
||||
{"fail verify", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "step",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("x5c is not valid: x509: certificate signed by unknown authority")},
|
||||
{"fail validateAKCertificate", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{invalidAKCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("AK certificate is not valid: missing TPM manufacturer")},
|
||||
{"fail pubArea not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid pubArea in attestation statement")},
|
||||
{"fail pubArea type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": []interface{}{},
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid pubArea in attestation statement")},
|
||||
{"fail pubArea empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": []byte{},
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("pubArea is empty")},
|
||||
{"fail sig not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid sig in attestation statement")},
|
||||
{"fail sig type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": []interface{}{},
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid sig in attestation statement")},
|
||||
{"fail sig empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": []byte{},
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("sig is empty")},
|
||||
{"fail certInfo not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid certInfo in attestation statement")},
|
||||
{"fail certInfo type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": []interface{}{},
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid certInfo in attestation statement")},
|
||||
{"fail certInfo empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": []byte{},
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("certInfo is empty")},
|
||||
{"fail alg not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid alg in attestation statement")},
|
||||
{"fail alg type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(0), // invalid alg
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid alg 0 in attestation statement")},
|
||||
{"fail attestation verification", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": anotherKeyParams.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("invalid certification parameters: certification refers to a different key")},
|
||||
{"fail keyAuthorization", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, &jose.JSONWebKey{Key: []byte("not an asymmetric key")}, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), // RS256
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newInternalServerError("failed creating key auth digest: error generating JWK thumbprint: square/go-jose: unknown key type '[]uint8'")},
|
||||
{"fail different keyAuthorization", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "aDifferentToken"}, jwk, &attestationObject{
|
||||
Format: "tpm",
|
||||
AttStatement: map[string]interface{}{
|
||||
"ver": "2.0",
|
||||
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
|
||||
"alg": int64(-257), //
|
||||
"sig": params.CreateSignature,
|
||||
"certInfo": params.CreateAttestation,
|
||||
"pubArea": params.Public,
|
||||
},
|
||||
}}, nil, newBadAttestationStatementError("key authorization does not match")},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := doTPMAttestationFormat(tt.args.ctx, tt.args.prov, tt.args.ch, tt.args.jwk, tt.args.att)
|
||||
if tt.expErr != nil {
|
||||
var ae *Error
|
||||
if assert.True(t, errors.As(err, &ae)) {
|
||||
assert.EqualError(t, err, tt.expErr.Error())
|
||||
assert.Equal(t, ae.StatusCode(), tt.expErr.StatusCode())
|
||||
assert.Equal(t, ae.Type, tt.expErr.Type)
|
||||
}
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
if assert.NotNil(t, got) {
|
||||
assert.Equal(t, akCert, got.Certificate)
|
||||
assert.Equal(t, [][]*x509.Certificate{
|
||||
{
|
||||
akCert, aca.Intermediate, aca.Root,
|
||||
},
|
||||
}, got.VerifiedChains)
|
||||
assert.Equal(t, fingerprint, got.Fingerprint)
|
||||
assert.Empty(t, got.PermanentIdentifiers) // currently expected to be always empty
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,80 +0,0 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client is the interface used to verify ACME challenges.
|
||||
type Client interface {
|
||||
// Get issues an HTTP GET to the specified URL.
|
||||
Get(url string) (*http.Response, error)
|
||||
|
||||
// LookupTXT returns the DNS TXT records for the given domain name.
|
||||
LookupTxt(name string) ([]string, error)
|
||||
|
||||
// TLSDial connects to the given network address using net.Dialer and then
|
||||
// initiates a TLS handshake, returning the resulting TLS connection.
|
||||
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
type clientKey struct{}
|
||||
|
||||
// NewClientContext adds the given client to the context.
|
||||
func NewClientContext(ctx context.Context, c Client) context.Context {
|
||||
return context.WithValue(ctx, clientKey{}, c)
|
||||
}
|
||||
|
||||
// ClientFromContext returns the current client from the given context.
|
||||
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
|
||||
c, ok = ctx.Value(clientKey{}).(Client)
|
||||
return
|
||||
}
|
||||
|
||||
// MustClientFromContext returns the current client from the given context. It will
|
||||
// return a new instance of the client if it does not exist.
|
||||
func MustClientFromContext(ctx context.Context) Client {
|
||||
c, ok := ClientFromContext(ctx)
|
||||
if !ok {
|
||||
return NewClient()
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type client struct {
|
||||
http *http.Client
|
||||
dialer *net.Dialer
|
||||
}
|
||||
|
||||
// NewClient returns an implementation of Client for verifying ACME challenges.
|
||||
func NewClient() Client {
|
||||
return &client{
|
||||
http: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
//nolint:gosec // used on tls-alpn-01 challenge
|
||||
InsecureSkipVerify: true, // lgtm[go/disabled-certificate-check]
|
||||
},
|
||||
},
|
||||
},
|
||||
dialer: &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Get(url string) (*http.Response, error) {
|
||||
return c.http.Get(url)
|
||||
}
|
||||
|
||||
func (c *client) LookupTxt(name string) ([]string, error) {
|
||||
return net.LookupTXT(name)
|
||||
}
|
||||
|
||||
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(c.dialer, network, addr, config)
|
||||
}
|
132
acme/common.go
132
acme/common.go
|
@ -9,6 +9,14 @@ import (
|
|||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// CertificateAuthority is the interface implemented by a CA authority.
|
||||
type CertificateAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
IsRevoked(sn string) (bool, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock struct{}
|
||||
|
||||
|
@ -19,104 +27,27 @@ func (c *Clock) Now() time.Time {
|
|||
|
||||
var clock Clock
|
||||
|
||||
// CertificateAuthority is the interface implemented by a CA authority.
|
||||
type CertificateAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
AreSANsAllowed(ctx context.Context, sans []string) error
|
||||
IsRevoked(sn string) (bool, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// NewContext adds the given acme components to the context.
|
||||
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
|
||||
ctx = NewDatabaseContext(ctx, db)
|
||||
ctx = NewClientContext(ctx, client)
|
||||
ctx = NewLinkerContext(ctx, linker)
|
||||
// Prerequisite checker is optional.
|
||||
if fn != nil {
|
||||
ctx = NewPrerequisitesCheckerContext(ctx, fn)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// PrerequisitesChecker is a function that checks if all prerequisites for
|
||||
// serving ACME are met by the CA configuration.
|
||||
type PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
|
||||
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
|
||||
// always true.
|
||||
func DefaultPrerequisitesChecker(context.Context) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type prerequisitesKey struct{}
|
||||
|
||||
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
|
||||
// context.
|
||||
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
|
||||
return context.WithValue(ctx, prerequisitesKey{}, fn)
|
||||
}
|
||||
|
||||
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
|
||||
// context.
|
||||
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
|
||||
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
|
||||
return fn, ok && fn != nil
|
||||
}
|
||||
|
||||
// Provisioner is an interface that implements a subset of the provisioner.Interface --
|
||||
// only those methods required by the ACME api/authority.
|
||||
type Provisioner interface {
|
||||
AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error
|
||||
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
|
||||
AuthorizeRevoke(ctx context.Context, token string) error
|
||||
IsChallengeEnabled(ctx context.Context, challenge provisioner.ACMEChallenge) bool
|
||||
IsAttestationFormatEnabled(ctx context.Context, format provisioner.ACMEAttestationFormat) bool
|
||||
GetAttestationRoots() (*x509.CertPool, bool)
|
||||
GetID() string
|
||||
GetName() string
|
||||
DefaultTLSCertDuration() time.Duration
|
||||
GetOptions() *provisioner.Options
|
||||
}
|
||||
|
||||
type provisionerKey struct{}
|
||||
|
||||
// NewProvisionerContext adds the given provisioner to the context.
|
||||
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
|
||||
return context.WithValue(ctx, provisionerKey{}, v)
|
||||
}
|
||||
|
||||
// ProvisionerFromContext returns the current provisioner from the given context.
|
||||
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
|
||||
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current provisioner from the given context.
|
||||
// It will panic if it's not in the context.
|
||||
func MustProvisionerFromContext(ctx context.Context) Provisioner {
|
||||
if v, ok := ProvisionerFromContext(ctx); !ok {
|
||||
panic("acme provisioner is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// MockProvisioner for testing
|
||||
type MockProvisioner struct {
|
||||
Mret1 interface{}
|
||||
Merr error
|
||||
MgetID func() string
|
||||
MgetName func() string
|
||||
MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error
|
||||
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
MauthorizeRevoke func(ctx context.Context, token string) error
|
||||
MisChallengeEnabled func(ctx context.Context, challenge provisioner.ACMEChallenge) bool
|
||||
MisAttFormatEnabled func(ctx context.Context, format provisioner.ACMEAttestationFormat) bool
|
||||
MgetAttestationRoots func() (*x509.CertPool, bool)
|
||||
MdefaultTLSCertDuration func() time.Duration
|
||||
MgetOptions func() *provisioner.Options
|
||||
Mret1 interface{}
|
||||
Merr error
|
||||
MgetID func() string
|
||||
MgetName func() string
|
||||
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
MauthorizeRevoke func(ctx context.Context, token string) error
|
||||
MdefaultTLSCertDuration func() time.Duration
|
||||
MgetOptions func() *provisioner.Options
|
||||
}
|
||||
|
||||
// GetName mock
|
||||
|
@ -127,14 +58,6 @@ func (m *MockProvisioner) GetName() string {
|
|||
return m.Mret1.(string)
|
||||
}
|
||||
|
||||
// AuthorizeOrderIdentifiers mock
|
||||
func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
|
||||
if m.MauthorizeOrderIdentifier != nil {
|
||||
return m.MauthorizeOrderIdentifier(ctx, identifier)
|
||||
}
|
||||
return m.Merr
|
||||
}
|
||||
|
||||
// AuthorizeSign mock
|
||||
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
if m.MauthorizeSign != nil {
|
||||
|
@ -151,29 +74,6 @@ func (m *MockProvisioner) AuthorizeRevoke(ctx context.Context, token string) err
|
|||
return m.Merr
|
||||
}
|
||||
|
||||
// IsChallengeEnabled mock
|
||||
func (m *MockProvisioner) IsChallengeEnabled(ctx context.Context, challenge provisioner.ACMEChallenge) bool {
|
||||
if m.MisChallengeEnabled != nil {
|
||||
return m.MisChallengeEnabled(ctx, challenge)
|
||||
}
|
||||
return m.Merr == nil
|
||||
}
|
||||
|
||||
// IsAttestationFormatEnabled mock
|
||||
func (m *MockProvisioner) IsAttestationFormatEnabled(ctx context.Context, format provisioner.ACMEAttestationFormat) bool {
|
||||
if m.MisAttFormatEnabled != nil {
|
||||
return m.MisAttFormatEnabled(ctx, format)
|
||||
}
|
||||
return m.Merr == nil
|
||||
}
|
||||
|
||||
func (m *MockProvisioner) GetAttestationRoots() (*x509.CertPool, bool) {
|
||||
if m.MgetAttestationRoots != nil {
|
||||
return m.MgetAttestationRoots()
|
||||
}
|
||||
return m.Mret1.(*x509.CertPool), m.Mret1 != nil
|
||||
}
|
||||
|
||||
// DefaultTLSCertDuration mock
|
||||
func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration {
|
||||
if m.MdefaultTLSCertDuration != nil {
|
||||
|
|
41
acme/db.go
41
acme/db.go
|
@ -12,12 +12,6 @@ import (
|
|||
// account.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// IsErrNotFound returns true if the error is a "not found" error. Returns false
|
||||
// otherwise.
|
||||
func IsErrNotFound(err error) bool {
|
||||
return errors.Is(err, ErrNotFound)
|
||||
}
|
||||
|
||||
// DB is the DB interface expected by the step-ca ACME API.
|
||||
type DB interface {
|
||||
CreateAccount(ctx context.Context, acc *Account) error
|
||||
|
@ -29,7 +23,6 @@ type DB interface {
|
|||
GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
|
||||
GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
|
||||
GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
|
||||
GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
|
||||
DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error
|
||||
UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
|
||||
|
||||
|
@ -55,29 +48,6 @@ type DB interface {
|
|||
UpdateOrder(ctx context.Context, o *Order) error
|
||||
}
|
||||
|
||||
type dbKey struct{}
|
||||
|
||||
// NewDatabaseContext adds the given acme database to the context.
|
||||
func NewDatabaseContext(ctx context.Context, db DB) context.Context {
|
||||
return context.WithValue(ctx, dbKey{}, db)
|
||||
}
|
||||
|
||||
// DatabaseFromContext returns the current acme database from the given context.
|
||||
func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
|
||||
db, ok = ctx.Value(dbKey{}).(DB)
|
||||
return
|
||||
}
|
||||
|
||||
// MustDatabaseFromContext returns the current database from the given context.
|
||||
// It will panic if it's not in the context.
|
||||
func MustDatabaseFromContext(ctx context.Context) DB {
|
||||
if db, ok := DatabaseFromContext(ctx); !ok {
|
||||
panic("acme database is not in the context")
|
||||
} else {
|
||||
return db
|
||||
}
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
// a mock in tests.
|
||||
type MockDB struct {
|
||||
|
@ -90,7 +60,6 @@ type MockDB struct {
|
|||
MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
|
||||
MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
|
||||
MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
|
||||
MockGetExternalAccountKeyByAccountID func(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
|
||||
MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error
|
||||
MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
|
||||
|
||||
|
@ -199,16 +168,6 @@ func (m *MockDB) GetExternalAccountKeyByReference(ctx context.Context, provision
|
|||
return m.MockRet1.(*ExternalAccountKey), m.MockError
|
||||
}
|
||||
|
||||
// GetExternalAccountKeyByAccountID mock
|
||||
func (m *MockDB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) {
|
||||
if m.MockGetExternalAccountKeyByAccountID != nil {
|
||||
return m.MockGetExternalAccountKeyByAccountID(ctx, provisionerID, accountID)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*ExternalAccountKey), m.MockError
|
||||
}
|
||||
|
||||
// DeleteExternalAccountKey mock
|
||||
func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error {
|
||||
if m.MockDeleteExternalAccountKey != nil {
|
||||
|
|
|
@ -13,14 +13,12 @@ import (
|
|||
|
||||
// dbAccount represents an ACME account.
|
||||
type dbAccount struct {
|
||||
ID string `json:"id"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status acme.Status `json:"status"`
|
||||
LocationPrefix string `json:"locationPrefix"`
|
||||
ProvisionerName string `json:"provisionerName"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
DeactivatedAt time.Time `json:"deactivatedAt"`
|
||||
ID string `json:"id"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status acme.Status `json:"status"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
DeactivatedAt time.Time `json:"deactivatedAt"`
|
||||
}
|
||||
|
||||
func (dba *dbAccount) clone() *dbAccount {
|
||||
|
@ -28,7 +26,7 @@ func (dba *dbAccount) clone() *dbAccount {
|
|||
return &nu
|
||||
}
|
||||
|
||||
func (db *DB) getAccountIDByKeyID(_ context.Context, kid string) (string, error) {
|
||||
func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) {
|
||||
id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
|
||||
if err != nil {
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
|
@ -40,7 +38,7 @@ func (db *DB) getAccountIDByKeyID(_ context.Context, kid string) (string, error)
|
|||
}
|
||||
|
||||
// getDBAccount retrieves and unmarshals dbAccount.
|
||||
func (db *DB) getDBAccount(_ context.Context, id string) (*dbAccount, error) {
|
||||
func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) {
|
||||
data, err := db.db.Get(accountTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
|
@ -64,12 +62,10 @@ func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error)
|
|||
}
|
||||
|
||||
return &acme.Account{
|
||||
Status: dbacc.Status,
|
||||
Contact: dbacc.Contact,
|
||||
Key: dbacc.Key,
|
||||
ID: dbacc.ID,
|
||||
LocationPrefix: dbacc.LocationPrefix,
|
||||
ProvisionerName: dbacc.ProvisionerName,
|
||||
Status: dbacc.Status,
|
||||
Contact: dbacc.Contact,
|
||||
Key: dbacc.Key,
|
||||
ID: dbacc.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -91,13 +87,11 @@ func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
|
|||
}
|
||||
|
||||
dba := &dbAccount{
|
||||
ID: acc.ID,
|
||||
Key: acc.Key,
|
||||
Contact: acc.Contact,
|
||||
Status: acc.Status,
|
||||
CreatedAt: clock.Now(),
|
||||
LocationPrefix: acc.LocationPrefix,
|
||||
ProvisionerName: acc.ProvisionerName,
|
||||
ID: acc.ID,
|
||||
Key: acc.Key,
|
||||
Contact: acc.Contact,
|
||||
Status: acc.Status,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
|
||||
kid, err := acme.KeyToID(dba.Key)
|
||||
|
|
|
@ -95,16 +95,16 @@ func TestDB_getDBAccount(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if dbacc, err := d.getDBAccount(context.Background(), accID); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -174,16 +174,16 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if retAccID, err := d.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -197,8 +197,6 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
|
|||
|
||||
func TestDB_GetAccount(t *testing.T) {
|
||||
accID := "accID"
|
||||
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
|
||||
provisionerName := "foo"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
|
@ -224,14 +222,12 @@ func TestDB_GetAccount(t *testing.T) {
|
|||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
dbacc := &dbAccount{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
LocationPrefix: locationPrefix,
|
||||
ProvisionerName: provisionerName,
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
b, err := json.Marshal(dbacc)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -252,16 +248,16 @@ func TestDB_GetAccount(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if acc, err := d.GetAccount(context.Background(), accID); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -270,8 +266,6 @@ func TestDB_GetAccount(t *testing.T) {
|
|||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||
assert.Equals(t, acc.LocationPrefix, tc.dbacc.LocationPrefix)
|
||||
assert.Equals(t, acc.ProvisionerName, tc.dbacc.ProvisionerName)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||
}
|
||||
})
|
||||
|
@ -360,16 +354,16 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if acc, err := d.GetAccountByKeyID(context.Background(), kid); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -385,7 +379,6 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDB_CreateAccount(t *testing.T) {
|
||||
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
acc *acme.Account
|
||||
|
@ -397,10 +390,9 @@ func TestDB_CreateAccount(t *testing.T) {
|
|||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
LocationPrefix: locationPrefix,
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
|
@ -421,10 +413,9 @@ func TestDB_CreateAccount(t *testing.T) {
|
|||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
LocationPrefix: locationPrefix,
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
|
@ -445,10 +436,9 @@ func TestDB_CreateAccount(t *testing.T) {
|
|||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
LocationPrefix: locationPrefix,
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
|
@ -466,8 +456,6 @@ func TestDB_CreateAccount(t *testing.T) {
|
|||
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||
assert.Equals(t, dbacc.ID, string(key))
|
||||
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||
assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix)
|
||||
assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName)
|
||||
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||
|
@ -491,10 +479,9 @@ func TestDB_CreateAccount(t *testing.T) {
|
|||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
LocationPrefix: locationPrefix,
|
||||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
|
@ -513,8 +500,6 @@ func TestDB_CreateAccount(t *testing.T) {
|
|||
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||
assert.Equals(t, dbacc.ID, string(key))
|
||||
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||
assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix)
|
||||
assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName)
|
||||
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||
|
@ -554,14 +539,12 @@ func TestDB_UpdateAccount(t *testing.T) {
|
|||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
dbacc := &dbAccount{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
LocationPrefix: "foo",
|
||||
ProvisionerName: "alpha",
|
||||
Key: jwk,
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
CreatedAt: now,
|
||||
DeactivatedAt: now,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
b, err := json.Marshal(dbacc)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -661,12 +644,10 @@ func TestDB_UpdateAccount(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
Contact: []string{"baz", "zap"},
|
||||
LocationPrefix: "bar",
|
||||
ProvisionerName: "beta",
|
||||
Key: jwk,
|
||||
ID: accID,
|
||||
Status: acme.StatusDeactivated,
|
||||
Contact: []string{"foo", "bar"},
|
||||
Key: jwk,
|
||||
}
|
||||
return test{
|
||||
acc: acc,
|
||||
|
@ -685,10 +666,7 @@ func TestDB_UpdateAccount(t *testing.T) {
|
|||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||
assert.Equals(t, dbNew.ID, dbacc.ID)
|
||||
assert.Equals(t, dbNew.Status, acc.Status)
|
||||
assert.Equals(t, dbNew.Contact, acc.Contact)
|
||||
// LocationPrefix should not change.
|
||||
assert.Equals(t, dbNew.LocationPrefix, dbacc.LocationPrefix)
|
||||
assert.Equals(t, dbNew.ProvisionerName, dbacc.ProvisionerName)
|
||||
assert.Equals(t, dbNew.Contact, dbacc.Contact)
|
||||
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
||||
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
||||
|
@ -708,7 +686,12 @@ func TestDB_UpdateAccount(t *testing.T) {
|
|||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, dbacc.ID)
|
||||
assert.Equals(t, tc.acc.Status, dbacc.Status)
|
||||
assert.Equals(t, tc.acc.Contact, dbacc.Contact)
|
||||
assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@ type dbAuthz struct {
|
|||
Identifier acme.Identifier `json:"identifier"`
|
||||
Status acme.Status `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Fingerprint string `json:"fingerprint,omitempty"`
|
||||
ChallengeIDs []string `json:"challengeIDs"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
|
@ -32,7 +31,7 @@ func (ba *dbAuthz) clone() *dbAuthz {
|
|||
|
||||
// getDBAuthz retrieves and unmarshals a database representation of the
|
||||
// ACME Authorization type.
|
||||
func (db *DB) getDBAuthz(_ context.Context, id string) (*dbAuthz, error) {
|
||||
func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) {
|
||||
data, err := db.db.Get(authzTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "authz %s not found", id)
|
||||
|
@ -62,16 +61,15 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat
|
|||
}
|
||||
}
|
||||
return &acme.Authorization{
|
||||
ID: dbaz.ID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Identifier: dbaz.Identifier,
|
||||
Status: dbaz.Status,
|
||||
Challenges: chs,
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Token: dbaz.Token,
|
||||
Fingerprint: dbaz.Fingerprint,
|
||||
Error: dbaz.Error,
|
||||
ID: dbaz.ID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Identifier: dbaz.Identifier,
|
||||
Status: dbaz.Status,
|
||||
Challenges: chs,
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Token: dbaz.Token,
|
||||
Error: dbaz.Error,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -99,7 +97,6 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) e
|
|||
Identifier: az.Identifier,
|
||||
ChallengeIDs: chIDs,
|
||||
Token: az.Token,
|
||||
Fingerprint: az.Fingerprint,
|
||||
Wildcard: az.Wildcard,
|
||||
}
|
||||
|
||||
|
@ -114,14 +111,14 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) e
|
|||
}
|
||||
|
||||
nu := old.clone()
|
||||
|
||||
nu.Status = az.Status
|
||||
nu.Fingerprint = az.Fingerprint
|
||||
nu.Error = az.Error
|
||||
return db.save(ctx, old.ID, nu, old, "authz", authzTable)
|
||||
}
|
||||
|
||||
// GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database.
|
||||
func (db *DB) GetAuthorizationsByAccountID(_ context.Context, accountID string) ([]*acme.Authorization, error) {
|
||||
func (db *DB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*acme.Authorization, error) {
|
||||
entries, err := db.db.List(authzTable)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error listing authz")
|
||||
|
@ -139,16 +136,15 @@ func (db *DB) GetAuthorizationsByAccountID(_ context.Context, accountID string)
|
|||
continue
|
||||
}
|
||||
authzs = append(authzs, &acme.Authorization{
|
||||
ID: dbaz.ID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Identifier: dbaz.Identifier,
|
||||
Status: dbaz.Status,
|
||||
Challenges: nil, // challenges not required for current use case
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Token: dbaz.Token,
|
||||
Fingerprint: dbaz.Fingerprint,
|
||||
Error: dbaz.Error,
|
||||
ID: dbaz.ID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Identifier: dbaz.Identifier,
|
||||
Status: dbaz.Status,
|
||||
Challenges: nil, // challenges not required for current use case
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Token: dbaz.Token,
|
||||
Error: dbaz.Error,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ func TestDB_getDBAuthz(t *testing.T) {
|
|||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Error: acme.NewErrorISE("The server experienced an internal error"),
|
||||
Error: acme.NewErrorISE("force"),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
|
@ -101,16 +101,16 @@ func TestDB_getDBAuthz(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if dbaz, err := d.getDBAuthz(context.Background(), azID); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -254,7 +254,7 @@ func TestDB_GetAuthorization(t *testing.T) {
|
|||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
Error: acme.NewErrorISE("The server experienced an internal error"),
|
||||
Error: acme.NewErrorISE("force"),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
|
@ -295,16 +295,16 @@ func TestDB_GetAuthorization(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if az, err := d.GetAuthorization(context.Background(), azID); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -473,7 +473,6 @@ func TestDB_UpdateAuthorization(t *testing.T) {
|
|||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
Fingerprint: "fingerprint",
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -533,7 +532,7 @@ func TestDB_UpdateAuthorization(t *testing.T) {
|
|||
assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
|
||||
assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -550,11 +549,10 @@ func TestDB_UpdateAuthorization(t *testing.T) {
|
|||
{ID: "foo"},
|
||||
{ID: "bar"},
|
||||
},
|
||||
Token: dbaz.Token,
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Fingerprint: "fingerprint",
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||
Token: dbaz.Token,
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||
}
|
||||
return test{
|
||||
az: updAz,
|
||||
|
@ -584,8 +582,7 @@ func TestDB_UpdateAuthorization(t *testing.T) {
|
|||
assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
|
||||
assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
|
||||
assert.Equals(t, dbNew.Fingerprint, dbaz.Fingerprint)
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
|
@ -748,16 +745,16 @@ func TestDB_GetAuthorizationsByAccountID(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if azs, err := d.GetAuthorizationsByAccountID(context.Background(), accountID); err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
|
|
@ -69,7 +69,7 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err
|
|||
|
||||
// GetCertificate retrieves and unmarshals an ACME certificate type from the
|
||||
// datastore.
|
||||
func (db *DB) GetCertificate(_ context.Context, id string) (*acme.Certificate, error) {
|
||||
func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) {
|
||||
b, err := db.db.Get(certTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id)
|
||||
|
@ -138,4 +138,5 @@ func parseBundle(b []byte) ([]*x509.Certificate, error) {
|
|||
return nil, errors.New("error decoding PEM: unexpected data")
|
||||
}
|
||||
return bundle, nil
|
||||
|
||||
}
|
||||
|
|
|
@ -250,16 +250,16 @@ func TestDB_GetCertificate(t *testing.T) {
|
|||
d := DB{db: tc.db}
|
||||
cert, err := d.GetCertificate(context.Background(), certID)
|
||||
if err != nil {
|
||||
var acmeErr *acme.Error
|
||||
if errors.As(err, &acmeErr) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, acmeErr.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, acmeErr.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -444,16 +444,16 @@ func TestDB_GetCertificateBySerial(t *testing.T) {
|
|||
d := DB{db: tc.db}
|
||||
cert, err := d.GetCertificateBySerial(context.Background(), serial)
|
||||
if err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
|
|
@ -6,10 +6,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/nosql"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type dbChallenge struct {
|
||||
|
@ -21,7 +19,7 @@ type dbChallenge struct {
|
|||
Value string `json:"value"`
|
||||
ValidatedAt string `json:"validatedAt"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Error *acme.Error `json:"error"` // TODO(hs): a bit dangerous; should become db-specific type
|
||||
Error *acme.Error `json:"error"`
|
||||
}
|
||||
|
||||
func (dbc *dbChallenge) clone() *dbChallenge {
|
||||
|
@ -29,7 +27,7 @@ func (dbc *dbChallenge) clone() *dbChallenge {
|
|||
return &u
|
||||
}
|
||||
|
||||
func (db *DB) getDBChallenge(_ context.Context, id string) (*dbChallenge, error) {
|
||||
func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) {
|
||||
data, err := db.db.Get(challengeTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id)
|
||||
|
@ -69,7 +67,6 @@ func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error {
|
|||
// GetChallenge retrieves and unmarshals an ACME challenge type from the database.
|
||||
// Implements the acme.DB GetChallenge interface.
|
||||
func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) {
|
||||
_ = authzID // unused input
|
||||
dbch, err := db.getDBChallenge(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -72,7 +72,7 @@ func TestDB_getDBChallenge(t *testing.T) {
|
|||
Value: "test.ca.smallstep.com",
|
||||
CreatedAt: clock.Now(),
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewErrorISE("The server experienced an internal error"),
|
||||
Error: acme.NewErrorISE("force"),
|
||||
}
|
||||
b, err := json.Marshal(dbc)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -94,16 +94,16 @@ func TestDB_getDBChallenge(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if ch, err := d.getDBChallenge(context.Background(), chID); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -264,7 +264,7 @@ func TestDB_GetChallenge(t *testing.T) {
|
|||
Value: "test.ca.smallstep.com",
|
||||
CreatedAt: clock.Now(),
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewErrorISE("The server experienced an internal error"),
|
||||
Error: acme.NewErrorISE("force"),
|
||||
}
|
||||
b, err := json.Marshal(dbc)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -286,16 +286,16 @@ func TestDB_GetChallenge(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if ch, err := d.GetChallenge(context.Background(), chID, azID); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -354,7 +354,7 @@ func TestDB_UpdateChallenge(t *testing.T) {
|
|||
ID: chID,
|
||||
Status: acme.StatusValid,
|
||||
ValidatedAt: "foobar",
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "The request message was malformed"),
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||
}
|
||||
return test{
|
||||
ch: updCh,
|
||||
|
@ -428,7 +428,7 @@ func TestDB_UpdateChallenge(t *testing.T) {
|
|||
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
|
||||
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
||||
assert.Equals(t, dbNew.ValidatedAt, "foobar")
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
|
||||
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
nosqlDB "github.com/smallstep/nosql"
|
||||
)
|
||||
|
@ -24,7 +23,7 @@ type dbExternalAccountKey struct {
|
|||
ProvisionerID string `json:"provisionerID"`
|
||||
Reference string `json:"reference"`
|
||||
AccountID string `json:"accountID,omitempty"`
|
||||
HmacKey []byte `json:"key"`
|
||||
KeyBytes []byte `json:"key"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
BoundAt time.Time `json:"boundAt"`
|
||||
}
|
||||
|
@ -35,7 +34,7 @@ type dbExternalAccountKeyReference struct {
|
|||
}
|
||||
|
||||
// getDBExternalAccountKey retrieves and unmarshals dbExternalAccountKey.
|
||||
func (db *DB) getDBExternalAccountKey(_ context.Context, id string) (*dbExternalAccountKey, error) {
|
||||
func (db *DB) getDBExternalAccountKey(ctx context.Context, id string) (*dbExternalAccountKey, error) {
|
||||
data, err := db.db.Get(externalAccountKeyTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosqlDB.IsErrNotFound(err) {
|
||||
|
@ -54,6 +53,7 @@ func (db *DB) getDBExternalAccountKey(_ context.Context, id string) (*dbExternal
|
|||
|
||||
// CreateExternalAccountKey creates a new External Account Binding key with a name
|
||||
func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
|
||||
externalAccountKeyMutex.Lock()
|
||||
defer externalAccountKeyMutex.Unlock()
|
||||
|
||||
|
@ -72,7 +72,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer
|
|||
ID: keyID,
|
||||
ProvisionerID: provisionerID,
|
||||
Reference: reference,
|
||||
HmacKey: random,
|
||||
KeyBytes: random,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
|
||||
|
@ -99,7 +99,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer
|
|||
ProvisionerID: dbeak.ProvisionerID,
|
||||
Reference: dbeak.Reference,
|
||||
AccountID: dbeak.AccountID,
|
||||
HmacKey: dbeak.HmacKey,
|
||||
KeyBytes: dbeak.KeyBytes,
|
||||
CreatedAt: dbeak.CreatedAt,
|
||||
BoundAt: dbeak.BoundAt,
|
||||
}, nil
|
||||
|
@ -124,7 +124,7 @@ func (db *DB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID st
|
|||
ProvisionerID: dbeak.ProvisionerID,
|
||||
Reference: dbeak.Reference,
|
||||
AccountID: dbeak.AccountID,
|
||||
HmacKey: dbeak.HmacKey,
|
||||
KeyBytes: dbeak.KeyBytes,
|
||||
CreatedAt: dbeak.CreatedAt,
|
||||
BoundAt: dbeak.BoundAt,
|
||||
}, nil
|
||||
|
@ -160,8 +160,6 @@ func (db *DB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID
|
|||
|
||||
// GetExternalAccountKeys retrieves all External Account Binding keys for a provisioner
|
||||
func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*acme.ExternalAccountKey, string, error) {
|
||||
_, _ = cursor, limit // unused input
|
||||
|
||||
externalAccountKeyMutex.RLock()
|
||||
defer externalAccountKeyMutex.RUnlock()
|
||||
|
||||
|
@ -193,7 +191,7 @@ func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor
|
|||
}
|
||||
keys = append(keys, &acme.ExternalAccountKey{
|
||||
ID: eak.ID,
|
||||
HmacKey: eak.HmacKey,
|
||||
KeyBytes: eak.KeyBytes,
|
||||
ProvisionerID: eak.ProvisionerID,
|
||||
Reference: eak.Reference,
|
||||
AccountID: eak.AccountID,
|
||||
|
@ -211,7 +209,6 @@ func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerI
|
|||
defer externalAccountKeyMutex.RUnlock()
|
||||
|
||||
if reference == "" {
|
||||
//nolint:nilnil // legacy
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
@ -229,11 +226,6 @@ func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerI
|
|||
return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID)
|
||||
}
|
||||
|
||||
func (db *DB) GetExternalAccountKeyByAccountID(context.Context, string, string) (*acme.ExternalAccountKey, error) {
|
||||
//nolint:nilnil // legacy
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||
externalAccountKeyMutex.Lock()
|
||||
defer externalAccountKeyMutex.Unlock()
|
||||
|
@ -260,7 +252,7 @@ func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string
|
|||
ProvisionerID: eak.ProvisionerID,
|
||||
Reference: eak.Reference,
|
||||
AccountID: eak.AccountID,
|
||||
HmacKey: eak.HmacKey,
|
||||
KeyBytes: eak.KeyBytes,
|
||||
CreatedAt: eak.CreatedAt,
|
||||
BoundAt: eak.BoundAt,
|
||||
}
|
||||
|
@ -374,6 +366,7 @@ func sliceIndex(slice []string, item string) int {
|
|||
// removeElement deletes the item if it exists in the
|
||||
// slice. It returns a new slice, keeping the old one intact.
|
||||
func removeElement(slice []string, item string) []string {
|
||||
|
||||
newSlice := make([]string, 0)
|
||||
index := sliceIndex(slice, item)
|
||||
if index < 0 {
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
certdb "github.com/smallstep/certificates/db"
|
||||
|
@ -33,7 +32,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: "ref",
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(dbeak)
|
||||
|
@ -93,23 +92,23 @@ func TestDB_getDBExternalAccountKey(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if dbeak, err := d.getDBExternalAccountKey(context.Background(), keyID); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbeak.ID, tc.dbeak.ID)
|
||||
assert.Equals(t, dbeak.HmacKey, tc.dbeak.HmacKey)
|
||||
assert.Equals(t, dbeak.KeyBytes, tc.dbeak.KeyBytes)
|
||||
assert.Equals(t, dbeak.ProvisionerID, tc.dbeak.ProvisionerID)
|
||||
assert.Equals(t, dbeak.Reference, tc.dbeak.Reference)
|
||||
assert.Equals(t, dbeak.CreatedAt, tc.dbeak.CreatedAt)
|
||||
|
@ -137,7 +136,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: "ref",
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(dbeak)
|
||||
|
@ -155,7 +154,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: "ref",
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
@ -180,7 +179,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: "aDifferentProvID",
|
||||
Reference: "ref",
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(dbeak)
|
||||
|
@ -198,7 +197,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: "ref",
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created"),
|
||||
|
@ -210,23 +209,23 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if eak, err := d.GetExternalAccountKey(context.Background(), provID, keyID); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, eak.ID, tc.eak.ID)
|
||||
assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
|
||||
assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes)
|
||||
assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
|
||||
assert.Equals(t, eak.Reference, tc.eak.Reference)
|
||||
assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
|
||||
|
@ -256,7 +255,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
dbref := &dbExternalAccountKeyReference{
|
||||
|
@ -289,7 +288,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
},
|
||||
err: nil,
|
||||
|
@ -374,16 +373,16 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if eak, err := d.GetExternalAccountKeyByReference(context.Background(), provID, tc.ref); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -393,7 +392,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
|
|||
assert.Equals(t, eak.AccountID, tc.eak.AccountID)
|
||||
assert.Equals(t, eak.BoundAt, tc.eak.BoundAt)
|
||||
assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
|
||||
assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
|
||||
assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes)
|
||||
assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
|
||||
assert.Equals(t, eak.Reference, tc.eak.Reference)
|
||||
}
|
||||
|
@ -421,7 +420,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b1, err := json.Marshal(dbeak1)
|
||||
|
@ -431,7 +430,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b2, err := json.Marshal(dbeak2)
|
||||
|
@ -441,7 +440,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
ProvisionerID: "aDifferentProvID",
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b3, err := json.Marshal(dbeak3)
|
||||
|
@ -514,7 +513,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
|
@ -522,7 +521,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
},
|
||||
},
|
||||
|
@ -580,16 +579,16 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
cursor, limit := "", 0
|
||||
if eaks, nextCursor, err := d.GetExternalAccountKeys(context.Background(), provID, cursor, limit); err != nil {
|
||||
assert.Equals(t, "", nextCursor)
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
|
@ -599,7 +598,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
assert.Equals(t, "", nextCursor)
|
||||
for i, eak := range eaks {
|
||||
assert.Equals(t, eak.ID, tc.eaks[i].ID)
|
||||
assert.Equals(t, eak.HmacKey, tc.eaks[i].HmacKey)
|
||||
assert.Equals(t, eak.KeyBytes, tc.eaks[i].KeyBytes)
|
||||
assert.Equals(t, eak.ProvisionerID, tc.eaks[i].ProvisionerID)
|
||||
assert.Equals(t, eak.Reference, tc.eaks[i].Reference)
|
||||
assert.Equals(t, eak.CreatedAt, tc.eaks[i].CreatedAt)
|
||||
|
@ -628,7 +627,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
dbref := &dbExternalAccountKeyReference{
|
||||
|
@ -672,7 +671,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
return errors.New("force default")
|
||||
}
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||
MCmpAndSwap: func(bucket, key, old, new []byte) ([]byte, bool, error) {
|
||||
fmt.Println(string(bucket))
|
||||
switch string(bucket) {
|
||||
case string(externalAccountKeyIDsByReferenceTable):
|
||||
|
@ -708,7 +707,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: "aDifferentProvID",
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(dbeak)
|
||||
|
@ -731,7 +730,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
dbref := &dbExternalAccountKeyReference{
|
||||
|
@ -781,7 +780,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
dbref := &dbExternalAccountKeyReference{
|
||||
|
@ -831,7 +830,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
dbref := &dbExternalAccountKeyReference{
|
||||
|
@ -882,16 +881,16 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.DeleteExternalAccountKey(context.Background(), provID, keyID); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -954,7 +953,7 @@ func TestDB_CreateExternalAccountKey(t *testing.T) {
|
|||
assert.Equals(t, string(key), dbeak.ID)
|
||||
assert.Equals(t, eak.ProvisionerID, dbeak.ProvisionerID)
|
||||
assert.Equals(t, eak.Reference, dbeak.Reference)
|
||||
assert.Equals(t, 32, len(dbeak.HmacKey))
|
||||
assert.Equals(t, 32, len(dbeak.KeyBytes))
|
||||
assert.False(t, dbeak.CreatedAt.IsZero())
|
||||
assert.Equals(t, dbeak.AccountID, eak.AccountID)
|
||||
assert.True(t, dbeak.BoundAt.IsZero())
|
||||
|
@ -1079,7 +1078,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(dbeak)
|
||||
|
@ -1097,7 +1096,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
return test{
|
||||
|
@ -1121,7 +1120,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
assert.Equals(t, dbNew.AccountID, dbeak.AccountID)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbeak.CreatedAt)
|
||||
assert.Equals(t, dbNew.BoundAt, dbeak.BoundAt)
|
||||
assert.Equals(t, dbNew.HmacKey, dbeak.HmacKey)
|
||||
assert.Equals(t, dbNew.KeyBytes, dbeak.KeyBytes)
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
|
@ -1149,7 +1148,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: "aDifferentProvID",
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(newDBEAK)
|
||||
|
@ -1175,7 +1174,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(newDBEAK)
|
||||
|
@ -1201,7 +1200,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(newDBEAK)
|
||||
|
@ -1238,7 +1237,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
assert.Equals(t, dbeak.AccountID, tc.eak.AccountID)
|
||||
assert.Equals(t, dbeak.CreatedAt, tc.eak.CreatedAt)
|
||||
assert.Equals(t, dbeak.BoundAt, tc.eak.BoundAt)
|
||||
assert.Equals(t, dbeak.HmacKey, tc.eak.HmacKey)
|
||||
assert.Equals(t, dbeak.KeyBytes, tc.eak.KeyBytes)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
|
|||
|
||||
// DeleteNonce verifies that the nonce is valid (by checking if it exists),
|
||||
// and if so, consumes the nonce resource by deleting it from the database.
|
||||
func (db *DB) DeleteNonce(_ context.Context, nonce acme.Nonce) error {
|
||||
func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error {
|
||||
err := db.db.Update(&database.Tx{
|
||||
Operations: []*database.TxEntry{
|
||||
{
|
||||
|
|
|
@ -146,16 +146,16 @@ func TestDB_DeleteNonce(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if err := d.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ func New(db nosqlDB.DB) (*DB, error) {
|
|||
|
||||
// save writes the new data to the database, overwriting the old data if it
|
||||
// existed.
|
||||
func (db *DB) save(_ context.Context, id string, nu, old interface{}, typ string, table []byte) error {
|
||||
func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
|
||||
var (
|
||||
err error
|
||||
newB []byte
|
||||
|
|
|
@ -35,7 +35,7 @@ func (a *dbOrder) clone() *dbOrder {
|
|||
}
|
||||
|
||||
// getDBOrder retrieves and unmarshals an ACME Order type from the database.
|
||||
func (db *DB) getDBOrder(_ context.Context, id string) (*dbOrder, error) {
|
||||
func (db *DB) getDBOrder(ctx context.Context, id string) (*dbOrder, error) {
|
||||
b, err := db.db.Get(orderTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id)
|
||||
|
|
|
@ -80,7 +80,7 @@ func TestDB_getDBOrder(t *testing.T) {
|
|||
{Type: "dns", Value: "example.foo.com"},
|
||||
},
|
||||
AuthorizationIDs: []string{"foo", "bar"},
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "The request message was malformed"),
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "force"),
|
||||
}
|
||||
b, err := json.Marshal(dbo)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -102,16 +102,16 @@ func TestDB_getDBOrder(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if dbo, err := d.getDBOrder(context.Background(), orderID); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -185,7 +185,7 @@ func TestDB_GetOrder(t *testing.T) {
|
|||
{Type: "dns", Value: "example.foo.com"},
|
||||
},
|
||||
AuthorizationIDs: []string{"foo", "bar"},
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "The request message was malformed"),
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "force"),
|
||||
}
|
||||
b, err := json.Marshal(dbo)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -206,16 +206,16 @@ func TestDB_GetOrder(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if o, err := d.GetOrder(context.Background(), orderID); err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -284,7 +284,7 @@ func TestDB_UpdateOrder(t *testing.T) {
|
|||
ID: orderID,
|
||||
Status: acme.StatusValid,
|
||||
CertificateID: "certID",
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "The request message was malformed"),
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "force"),
|
||||
}
|
||||
return test{
|
||||
o: o,
|
||||
|
@ -324,7 +324,7 @@ func TestDB_UpdateOrder(t *testing.T) {
|
|||
ID: orderID,
|
||||
Status: acme.StatusValid,
|
||||
CertificateID: "certID",
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "The request message was malformed"),
|
||||
Error: acme.NewError(acme.ErrorMalformedType, "force"),
|
||||
}
|
||||
return test{
|
||||
o: o,
|
||||
|
@ -372,7 +372,7 @@ func TestDB_UpdateOrder(t *testing.T) {
|
|||
assert.Equals(t, tc.o.ID, dbo.ID)
|
||||
assert.Equals(t, tc.o.CertificateID, "certID")
|
||||
assert.Equals(t, tc.o.Status, acme.StatusValid)
|
||||
assert.Equals(t, tc.o.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
|
||||
assert.Equals(t, tc.o.Error.Error(), acme.NewError(acme.ErrorMalformedType, "force").Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -659,7 +659,7 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
|||
assert.Equals(t, newdbo.ID, "foo")
|
||||
assert.Equals(t, newdbo.Status, acme.StatusInvalid)
|
||||
assert.Equals(t, newdbo.ExpiresAt, expiry)
|
||||
assert.Equals(t, newdbo.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
|
||||
assert.Equals(t, newdbo.Error.Error(), acme.NewError(acme.ErrorMalformedType, "order has expired").Error())
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -1003,16 +1003,16 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
var ae *acme.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, ae.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
|
106
acme/errors.go
106
acme/errors.go
|
@ -3,10 +3,13 @@ package acme
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// ProblemType is the type of the ACME problem.
|
||||
|
@ -17,8 +20,6 @@ const (
|
|||
ErrorAccountDoesNotExistType ProblemType = iota
|
||||
// ErrorAlreadyRevokedType request specified a certificate to be revoked that has already been revoked
|
||||
ErrorAlreadyRevokedType
|
||||
// ErrorBadAttestationStatementType WebAuthn attestation statement could not be verified
|
||||
ErrorBadAttestationStatementType
|
||||
// ErrorBadCSRType CSR is unacceptable (e.g., due to a short key)
|
||||
ErrorBadCSRType
|
||||
// ErrorBadNonceType client sent an unacceptable anti-replay nonce
|
||||
|
@ -75,8 +76,6 @@ func (ap ProblemType) String() string {
|
|||
return "accountDoesNotExist"
|
||||
case ErrorAlreadyRevokedType:
|
||||
return "alreadyRevoked"
|
||||
case ErrorBadAttestationStatementType:
|
||||
return "badAttestationStatement"
|
||||
case ErrorBadCSRType:
|
||||
return "badCSR"
|
||||
case ErrorBadNonceType:
|
||||
|
@ -176,11 +175,6 @@ var (
|
|||
details: "The JWS was signed with an algorithm the server does not support",
|
||||
status: 400,
|
||||
},
|
||||
ErrorBadAttestationStatementType: {
|
||||
typ: officialACMEPrefix + ErrorBadAttestationStatementType.String(),
|
||||
details: "Attestation statement cannot be verified",
|
||||
status: 400,
|
||||
},
|
||||
ErrorCaaType: {
|
||||
typ: officialACMEPrefix + ErrorCaaType.String(),
|
||||
details: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
|
||||
|
@ -270,34 +264,14 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
// Error represents an ACME Error
|
||||
// Error represents an ACME
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Subproblems []Subproblem `json:"subproblems,omitempty"`
|
||||
Err error `json:"-"`
|
||||
Status int `json:"-"`
|
||||
}
|
||||
|
||||
// Subproblem represents an ACME subproblem. It's fairly
|
||||
// similar to an ACME error, but differs in that it can't
|
||||
// include subproblems itself, the error is reflected
|
||||
// in the Detail property and doesn't have a Status.
|
||||
type Subproblem struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
// The "identifier" field MUST NOT be present at the top level in ACME
|
||||
// problem documents. It can only be present in subproblems.
|
||||
// Subproblems need not all have the same type, and they do not need to
|
||||
// match the top level type.
|
||||
Identifier *Identifier `json:"identifier,omitempty"`
|
||||
}
|
||||
|
||||
// AddSubproblems adds the Subproblems to Error. It
|
||||
// returns the Error, allowing for fluent addition.
|
||||
func (e *Error) AddSubproblems(subproblems ...Subproblem) *Error {
|
||||
e.Subproblems = append(e.Subproblems, subproblems...)
|
||||
return e
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Subproblems []interface{} `json:"subproblems,omitempty"`
|
||||
Identifier interface{} `json:"identifier,omitempty"`
|
||||
Err error `json:"-"`
|
||||
Status int `json:"-"`
|
||||
}
|
||||
|
||||
// NewError creates a new Error type.
|
||||
|
@ -305,26 +279,6 @@ func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
|
|||
return newError(pt, errors.Errorf(msg, args...))
|
||||
}
|
||||
|
||||
// NewSubproblem creates a new Subproblem. The msg and args
|
||||
// are used to create a new error, which is set as the Detail, allowing
|
||||
// for more detailed error messages to be returned to the ACME client.
|
||||
func NewSubproblem(pt ProblemType, msg string, args ...interface{}) Subproblem {
|
||||
e := newError(pt, fmt.Errorf(msg, args...))
|
||||
s := Subproblem{
|
||||
Type: e.Type,
|
||||
Detail: e.Err.Error(),
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// NewSubproblemWithIdentifier creates a new Subproblem with a specific ACME
|
||||
// Identifier. It calls NewSubproblem and sets the Identifier.
|
||||
func NewSubproblemWithIdentifier(pt ProblemType, identifier Identifier, msg string, args ...interface{}) Subproblem {
|
||||
s := NewSubproblem(pt, msg, args...)
|
||||
s.Identifier = &identifier
|
||||
return s
|
||||
}
|
||||
|
||||
func newError(pt ProblemType, err error) *Error {
|
||||
meta, ok := errorMap[pt]
|
||||
if !ok {
|
||||
|
@ -352,11 +306,10 @@ func NewErrorISE(msg string, args ...interface{}) *Error {
|
|||
|
||||
// WrapError attempts to wrap the internal error.
|
||||
func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error {
|
||||
var e *Error
|
||||
switch {
|
||||
case err == nil:
|
||||
switch e := err.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case errors.As(err, &e):
|
||||
case *Error:
|
||||
if e.Err == nil {
|
||||
e.Err = errors.Errorf(msg+"; "+e.Detail, args...)
|
||||
} else {
|
||||
|
@ -378,12 +331,9 @@ func (e *Error) StatusCode() int {
|
|||
return e.Status
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
// Error allows AError to implement the error interface.
|
||||
func (e *Error) Error() string {
|
||||
if e.Err == nil {
|
||||
return e.Detail
|
||||
}
|
||||
return e.Err.Error()
|
||||
return e.Detail
|
||||
}
|
||||
|
||||
// Cause returns the internal error and implements the Causer interface.
|
||||
|
@ -403,8 +353,26 @@ func (e *Error) ToLog() (interface{}, error) {
|
|||
return string(b), nil
|
||||
}
|
||||
|
||||
// Render implements render.RenderableError for Error.
|
||||
func (e *Error) Render(w http.ResponseWriter) {
|
||||
// WriteError writes to w a JSON representation of the given error.
|
||||
func WriteError(w http.ResponseWriter, err *Error) {
|
||||
w.Header().Set("Content-Type", "application/problem+json")
|
||||
render.JSONStatus(w, e, e.StatusCode())
|
||||
w.WriteHeader(err.StatusCode())
|
||||
|
||||
// Write errors in the response writer
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"error": err.Err,
|
||||
})
|
||||
if os.Getenv("STEPDEBUG") == "1" {
|
||||
if e, ok := err.Err.(errs.StackTracer); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"stack-trace": fmt.Sprintf("%+v", e),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
|
122
acme/order.go
122
acme/order.go
|
@ -3,7 +3,6 @@ package acme
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"net"
|
||||
|
@ -12,7 +11,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
)
|
||||
|
||||
|
@ -23,9 +21,6 @@ const (
|
|||
IP IdentifierType = "ip"
|
||||
// DNS is the ACME dns identifier type
|
||||
DNS IdentifierType = "dns"
|
||||
// PermanentIdentifier is the ACME permanent-identifier identifier type
|
||||
// defined in https://datatracker.ietf.org/doc/html/draft-bweeks-acme-device-attest-00
|
||||
PermanentIdentifier IdentifierType = "permanent-identifier"
|
||||
)
|
||||
|
||||
// Identifier encodes the type that an order pertains to.
|
||||
|
@ -127,34 +122,8 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// getKeyFingerprint returns a fingerprint from the list of authorizations. This
|
||||
// fingerprint is used on the device-attest-01 flow to verify the attestation
|
||||
// certificate public key with the CSR public key.
|
||||
//
|
||||
// There's no point on reading all the authorizations as there will be only one
|
||||
// for a permanent identifier.
|
||||
func (o *Order) getAuthorizationFingerprint(ctx context.Context, db DB) (string, error) {
|
||||
for _, azID := range o.AuthorizationIDs {
|
||||
az, err := db.GetAuthorization(ctx, azID)
|
||||
if err != nil {
|
||||
return "", WrapErrorISE(err, "error getting authorization %q", azID)
|
||||
}
|
||||
// There's no point on reading all the authorizations as there will
|
||||
// be only one for a permanent identifier.
|
||||
if az.Fingerprint != "" {
|
||||
return az.Fingerprint, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Finalize signs a certificate if the necessary conditions for Order completion
|
||||
// have been met.
|
||||
//
|
||||
// TODO(mariano): Here or in the challenge validation we should perform some
|
||||
// external validation using the identifier value and the attestation data. From
|
||||
// a validation service we can get the list of SANs to set in the final
|
||||
// certificate.
|
||||
func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth CertificateAuthority, p Provisioner) error {
|
||||
if err := o.UpdateStatus(ctx, db); err != nil {
|
||||
return err
|
||||
|
@ -173,69 +142,13 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
|
|||
return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID)
|
||||
}
|
||||
|
||||
// Get key fingerprint if any. And then compare it with the CSR fingerprint.
|
||||
//
|
||||
// In device-attest-01 challenges we should check that the keys in the CSR
|
||||
// and the attestation certificate are the same.
|
||||
fingerprint, err := o.getAuthorizationFingerprint(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fingerprint != "" {
|
||||
fp, err := keyutil.Fingerprint(csr.PublicKey)
|
||||
if err != nil {
|
||||
return WrapErrorISE(err, "error calculating key fingerprint")
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(fingerprint), []byte(fp)) == 0 {
|
||||
return NewError(ErrorUnauthorizedType, "order %s csr does not match the attested key", o.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// canonicalize the CSR to allow for comparison
|
||||
csr = canonicalize(csr)
|
||||
|
||||
// Template data
|
||||
data := x509util.NewTemplateData()
|
||||
data.SetCommonName(csr.Subject.CommonName)
|
||||
|
||||
// Custom sign options passed to authority.Sign
|
||||
var extraOptions []provisioner.SignOption
|
||||
|
||||
// TODO: support for multiple identifiers?
|
||||
var permanentIdentifier string
|
||||
for i := range o.Identifiers {
|
||||
if o.Identifiers[i].Type == PermanentIdentifier {
|
||||
permanentIdentifier = o.Identifiers[i].Value
|
||||
// the first (and only) Permanent Identifier that gets added to the certificate
|
||||
// should be equal to the Subject Common Name if it's set. If not equal, the CSR
|
||||
// is rejected, because the Common Name hasn't been challenged in that case. This
|
||||
// could result in unauthorized access if a relying system relies on the Common
|
||||
// Name in its authorization logic.
|
||||
if csr.Subject.CommonName != "" && csr.Subject.CommonName != permanentIdentifier {
|
||||
return NewError(ErrorBadCSRType, "CSR Subject Common Name does not match identifiers exactly: "+
|
||||
"CSR Subject Common Name = %s, Order Permanent Identifier = %s", csr.Subject.CommonName, permanentIdentifier)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var defaultTemplate string
|
||||
if permanentIdentifier != "" {
|
||||
defaultTemplate = x509util.DefaultAttestedLeafTemplate
|
||||
data.SetSubjectAlternativeNames(x509util.SubjectAlternativeName{
|
||||
Type: x509util.PermanentIdentifierType,
|
||||
Value: permanentIdentifier,
|
||||
})
|
||||
extraOptions = append(extraOptions, provisioner.AttestationData{
|
||||
PermanentIdentifier: permanentIdentifier,
|
||||
})
|
||||
} else {
|
||||
defaultTemplate = x509util.DefaultLeafTemplate
|
||||
sans, err := o.sans(csr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data.SetSubjectAlternativeNames(sans...)
|
||||
// retrieve the requested SANs for the Order
|
||||
sans, err := o.sans(csr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get authorizations from the ACME provisioner.
|
||||
|
@ -244,23 +157,17 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
|
|||
if err != nil {
|
||||
return WrapErrorISE(err, "error retrieving authorization options from ACME provisioner")
|
||||
}
|
||||
// Unlike most of the provisioners, ACME's AuthorizeSign method doesn't
|
||||
// define the templates, and the template data used in WebHooks is not
|
||||
// available.
|
||||
for _, signOp := range signOps {
|
||||
if wc, ok := signOp.(*provisioner.WebhookController); ok {
|
||||
wc.TemplateData = data
|
||||
}
|
||||
}
|
||||
|
||||
templateOptions, err := provisioner.CustomTemplateOptions(p.GetOptions(), data, defaultTemplate)
|
||||
// Template data
|
||||
data := x509util.NewTemplateData()
|
||||
data.SetCommonName(csr.Subject.CommonName)
|
||||
data.Set(x509util.SANsKey, sans)
|
||||
|
||||
templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data)
|
||||
if err != nil {
|
||||
return WrapErrorISE(err, "error creating template options from ACME provisioner")
|
||||
}
|
||||
|
||||
// Build extra signing options.
|
||||
signOps = append(signOps, templateOptions)
|
||||
signOps = append(signOps, extraOptions...)
|
||||
|
||||
// Sign a new certificate.
|
||||
certChain, err := auth.Sign(csr, provisioner.SignOptions{
|
||||
|
@ -290,7 +197,9 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
|
|||
}
|
||||
|
||||
func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativeName, error) {
|
||||
|
||||
var sans []x509util.SubjectAlternativeName
|
||||
|
||||
if len(csr.EmailAddresses) > 0 || len(csr.URIs) > 0 {
|
||||
return sans, NewError(ErrorBadCSRType, "Only DNS names and IP addresses are allowed")
|
||||
}
|
||||
|
@ -298,8 +207,7 @@ func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativ
|
|||
// order the DNS names and IP addresses, so that they can be compared against the canonicalized CSR
|
||||
orderNames := make([]string, numberOfIdentifierType(DNS, o.Identifiers))
|
||||
orderIPs := make([]net.IP, numberOfIdentifierType(IP, o.Identifiers))
|
||||
orderPIDs := make([]string, numberOfIdentifierType(PermanentIdentifier, o.Identifiers))
|
||||
indexDNS, indexIP, indexPID := 0, 0, 0
|
||||
indexDNS, indexIP := 0, 0
|
||||
for _, n := range o.Identifiers {
|
||||
switch n.Type {
|
||||
case DNS:
|
||||
|
@ -308,9 +216,6 @@ func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativ
|
|||
case IP:
|
||||
orderIPs[indexIP] = net.ParseIP(n.Value) // NOTE: this assumes are all valid IPs at this time; or will result in nil entries
|
||||
indexIP++
|
||||
case PermanentIdentifier:
|
||||
orderPIDs[indexPID] = n.Value
|
||||
indexPID++
|
||||
default:
|
||||
return sans, NewErrorISE("unsupported identifier type in order: %s", n.Type)
|
||||
}
|
||||
|
@ -382,6 +287,7 @@ func numberOfIdentifierType(typ IdentifierType, ids []Identifier) int {
|
|||
// addresses or DNS names slice, depending on whether it can be parsed as an IP
|
||||
// or not. This might result in an additional SAN in the final certificate.
|
||||
func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.CertificateRequest) {
|
||||
|
||||
// for clarity only; we're operating on the same object by pointer
|
||||
canonicalized = csr
|
||||
|
||||
|
|
|
@ -2,12 +2,9 @@ package acme
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
@ -19,7 +16,6 @@ import (
|
|||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
)
|
||||
|
||||
|
@ -251,14 +247,14 @@ func TestOrder_UpdateStatus(t *testing.T) {
|
|||
tc := run(t)
|
||||
if err := tc.o.UpdateStatus(context.Background(), tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
var k *Error
|
||||
if errors.As(err, &k) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
assert.Equals(t, k.Type, tc.err.Type)
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
assert.Equals(t, k.Status, tc.err.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
} else {
|
||||
default:
|
||||
assert.FatalError(t, errors.New("unexpected error type"))
|
||||
}
|
||||
}
|
||||
|
@ -272,7 +268,6 @@ func TestOrder_UpdateStatus(t *testing.T) {
|
|||
|
||||
type mockSignAuth struct {
|
||||
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
areSANsAllowed func(ctx context.Context, sans []string) error
|
||||
loadProvisionerByName func(string) (provisioner.Interface, error)
|
||||
ret1, ret2 interface{}
|
||||
err error
|
||||
|
@ -287,13 +282,6 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S
|
|||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error {
|
||||
if m.areSANsAllowed != nil {
|
||||
return m.areSANsAllowed(ctx, sans)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByName != nil {
|
||||
return m.loadProvisionerByName(name)
|
||||
|
@ -301,7 +289,7 @@ func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface
|
|||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
||||
func (m *mockSignAuth) IsRevoked(string) (bool, error) {
|
||||
func (m *mockSignAuth) IsRevoked(sn string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
@ -310,14 +298,6 @@ func (m *mockSignAuth) Revoke(context.Context, *authority.RevokeOptions) error {
|
|||
}
|
||||
|
||||
func TestOrder_Finalize(t *testing.T) {
|
||||
mustSigner := func(kty, crv string, size int) crypto.Signer {
|
||||
s, err := keyutil.GenerateSigner(kty, crv, size)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type test struct {
|
||||
o *Order
|
||||
err *Error
|
||||
|
@ -398,72 +378,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
err: NewErrorISE("unrecognized order status: %s", o.Status),
|
||||
}
|
||||
},
|
||||
"fail/non-matching-permanent-identifier-common-name": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
ID: "oID",
|
||||
AccountID: "accID",
|
||||
Status: StatusReady,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
AuthorizationIDs: []string{"a", "b"},
|
||||
Identifiers: []Identifier{
|
||||
{Type: "permanent-identifier", Value: "a-permanent-identifier"},
|
||||
},
|
||||
}
|
||||
|
||||
signer := mustSigner("EC", "P-256", 0)
|
||||
fingerprint, err := keyutil.Fingerprint(signer.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
csr := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "a-different-identifier",
|
||||
},
|
||||
PublicKey: signer.Public(),
|
||||
ExtraExtensions: []pkix.Extension{
|
||||
{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
|
||||
Value: []byte("a-permanent-identifier"),
|
||||
},
|
||||
},
|
||||
}
|
||||
return test{
|
||||
o: o,
|
||||
csr: csr,
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
switch id {
|
||||
case "a":
|
||||
return &Authorization{
|
||||
ID: id,
|
||||
Status: StatusValid,
|
||||
}, nil
|
||||
case "b":
|
||||
return &Authorization{
|
||||
ID: id,
|
||||
Fingerprint: fingerprint,
|
||||
Status: StatusValid,
|
||||
}, nil
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected authorization %s", id))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
MockUpdateOrder: func(ctx context.Context, o *Order) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
err: &Error{
|
||||
Type: "urn:ietf:params:acme:error:badCSR",
|
||||
Detail: "The CSR is unacceptable",
|
||||
Status: 400,
|
||||
Err: fmt.Errorf("CSR Subject Common Name does not match identifiers exactly: "+
|
||||
"CSR Subject Common Name = %s, Order Permanent Identifier = %s", csr.Subject.CommonName, "a-permanent-identifier"),
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/error-provisioner-auth": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
|
@ -493,11 +407,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{ID: id, Status: StatusValid}, nil
|
||||
},
|
||||
},
|
||||
err: NewErrorISE("error retrieving authorization options from ACME provisioner: force"),
|
||||
}
|
||||
},
|
||||
|
@ -537,11 +446,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
}
|
||||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{ID: id, Status: StatusValid}, nil
|
||||
},
|
||||
},
|
||||
err: NewErrorISE("error creating template options from ACME provisioner: error unmarshaling template data: invalid character 'o' in literal false (expecting 'a')"),
|
||||
}
|
||||
},
|
||||
|
@ -583,11 +487,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{ID: id, Status: StatusValid}, nil
|
||||
},
|
||||
},
|
||||
err: NewErrorISE("error signing certificate for order oID: force"),
|
||||
}
|
||||
},
|
||||
|
@ -634,9 +533,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{ID: id, Status: StatusValid}, nil
|
||||
},
|
||||
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
|
||||
assert.Equals(t, cert.AccountID, o.AccountID)
|
||||
assert.Equals(t, cert.OrderID, o.ID)
|
||||
|
@ -691,9 +587,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{ID: id, Status: StatusValid}, nil
|
||||
},
|
||||
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
|
||||
cert.ID = "certID"
|
||||
assert.Equals(t, cert.AccountID, o.AccountID)
|
||||
|
@ -716,297 +609,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
err: NewErrorISE("error updating order oID: force"),
|
||||
}
|
||||
},
|
||||
"fail/csr-fingerprint": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
ID: "oID",
|
||||
AccountID: "accID",
|
||||
Status: StatusReady,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
AuthorizationIDs: []string{"a", "b"},
|
||||
Identifiers: []Identifier{
|
||||
{Type: "permanent-identifier", Value: "a-permanent-identifier"},
|
||||
},
|
||||
}
|
||||
|
||||
signer := mustSigner("EC", "P-256", 0)
|
||||
|
||||
csr := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "a-permanent-identifier",
|
||||
},
|
||||
PublicKey: signer.Public(),
|
||||
ExtraExtensions: []pkix.Extension{
|
||||
{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
|
||||
Value: []byte("a-permanent-identifier"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
leaf := &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "a-permanent-identifier"},
|
||||
PublicKey: signer.Public(),
|
||||
ExtraExtensions: []pkix.Extension{
|
||||
{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
|
||||
Value: []byte("a-permanent-identifier"),
|
||||
},
|
||||
},
|
||||
}
|
||||
inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}}
|
||||
root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}}
|
||||
|
||||
return test{
|
||||
o: o,
|
||||
csr: csr,
|
||||
prov: &MockProvisioner{
|
||||
MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
assert.Equals(t, token, "")
|
||||
return nil, nil
|
||||
},
|
||||
MgetOptions: func() *provisioner.Options {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ca: &mockSignAuth{
|
||||
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
assert.Equals(t, _csr, csr)
|
||||
return []*x509.Certificate{leaf, inter, root}, nil
|
||||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{
|
||||
ID: id,
|
||||
Fingerprint: "other-fingerprint",
|
||||
Status: StatusValid,
|
||||
}, nil
|
||||
},
|
||||
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
|
||||
cert.ID = "certID"
|
||||
assert.Equals(t, cert.AccountID, o.AccountID)
|
||||
assert.Equals(t, cert.OrderID, o.ID)
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||
return nil
|
||||
},
|
||||
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
|
||||
assert.Equals(t, updo.CertificateID, "certID")
|
||||
assert.Equals(t, updo.Status, StatusValid)
|
||||
assert.Equals(t, updo.ID, o.ID)
|
||||
assert.Equals(t, updo.AccountID, o.AccountID)
|
||||
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
|
||||
assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs)
|
||||
assert.Equals(t, updo.Identifiers, o.Identifiers)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
err: NewError(ErrorUnauthorizedType, "order oID csr does not match the attested key"),
|
||||
}
|
||||
},
|
||||
"ok/permanent-identifier": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
ID: "oID",
|
||||
AccountID: "accID",
|
||||
Status: StatusReady,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
AuthorizationIDs: []string{"a", "b"},
|
||||
Identifiers: []Identifier{
|
||||
{Type: "permanent-identifier", Value: "a-permanent-identifier"},
|
||||
},
|
||||
}
|
||||
|
||||
signer := mustSigner("EC", "P-256", 0)
|
||||
fingerprint, err := keyutil.Fingerprint(signer.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
csr := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "a-permanent-identifier",
|
||||
},
|
||||
PublicKey: signer.Public(),
|
||||
ExtraExtensions: []pkix.Extension{
|
||||
{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
|
||||
Value: []byte("a-permanent-identifier"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
leaf := &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "a-permanent-identifier"},
|
||||
PublicKey: signer.Public(),
|
||||
ExtraExtensions: []pkix.Extension{
|
||||
{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
|
||||
Value: []byte("a-permanent-identifier"),
|
||||
},
|
||||
},
|
||||
}
|
||||
inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}}
|
||||
root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}}
|
||||
|
||||
return test{
|
||||
o: o,
|
||||
csr: csr,
|
||||
prov: &MockProvisioner{
|
||||
MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
assert.Equals(t, token, "")
|
||||
return nil, nil
|
||||
},
|
||||
MgetOptions: func() *provisioner.Options {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ca: &mockSignAuth{
|
||||
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
assert.Equals(t, _csr, csr)
|
||||
return []*x509.Certificate{leaf, inter, root}, nil
|
||||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
switch id {
|
||||
case "a":
|
||||
return &Authorization{
|
||||
ID: id,
|
||||
Status: StatusValid,
|
||||
}, nil
|
||||
case "b":
|
||||
return &Authorization{
|
||||
ID: id,
|
||||
Fingerprint: fingerprint,
|
||||
Status: StatusValid,
|
||||
}, nil
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected authorization %s", id))
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
},
|
||||
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
|
||||
cert.ID = "certID"
|
||||
assert.Equals(t, cert.AccountID, o.AccountID)
|
||||
assert.Equals(t, cert.OrderID, o.ID)
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||
return nil
|
||||
},
|
||||
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
|
||||
assert.Equals(t, updo.CertificateID, "certID")
|
||||
assert.Equals(t, updo.Status, StatusValid)
|
||||
assert.Equals(t, updo.ID, o.ID)
|
||||
assert.Equals(t, updo.AccountID, o.AccountID)
|
||||
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
|
||||
assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs)
|
||||
assert.Equals(t, updo.Identifiers, o.Identifiers)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/permanent-identifier-only": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
ID: "oID",
|
||||
AccountID: "accID",
|
||||
Status: StatusReady,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
AuthorizationIDs: []string{"a", "b"},
|
||||
Identifiers: []Identifier{
|
||||
{Type: "dns", Value: "foo.internal"},
|
||||
{Type: "permanent-identifier", Value: "a-permanent-identifier"},
|
||||
},
|
||||
}
|
||||
|
||||
signer := mustSigner("EC", "P-256", 0)
|
||||
fingerprint, err := keyutil.Fingerprint(signer.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
csr := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "a-permanent-identifier",
|
||||
},
|
||||
DNSNames: []string{"foo.internal"},
|
||||
PublicKey: signer.Public(),
|
||||
ExtraExtensions: []pkix.Extension{
|
||||
{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
|
||||
Value: []byte("a-permanent-identifier"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
leaf := &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "a-permanent-identifier"},
|
||||
PublicKey: signer.Public(),
|
||||
ExtraExtensions: []pkix.Extension{
|
||||
{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
|
||||
Value: []byte("a-permanent-identifier"),
|
||||
},
|
||||
},
|
||||
}
|
||||
inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}}
|
||||
root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}}
|
||||
|
||||
return test{
|
||||
o: o,
|
||||
csr: csr,
|
||||
prov: &MockProvisioner{
|
||||
MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
assert.Equals(t, token, "")
|
||||
return nil, nil
|
||||
},
|
||||
MgetOptions: func() *provisioner.Options {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
// TODO(hs): we should work on making the mocks more realistic. Ideally, we should get rid of
|
||||
// the mock entirely, relying on an instances of provisioner, authority and DB (possibly hardest), so
|
||||
// that behavior of the tests is what an actual CA would do. We could gradually phase them out by
|
||||
// using the mocking functions as a wrapper for actual test helpers generated per test case or per
|
||||
// function that's tested.
|
||||
ca: &mockSignAuth{
|
||||
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
assert.Equals(t, _csr, csr)
|
||||
return []*x509.Certificate{leaf, inter, root}, nil
|
||||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{
|
||||
ID: id,
|
||||
Fingerprint: fingerprint,
|
||||
Status: StatusValid,
|
||||
}, nil
|
||||
},
|
||||
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
|
||||
cert.ID = "certID"
|
||||
assert.Equals(t, cert.AccountID, o.AccountID)
|
||||
assert.Equals(t, cert.OrderID, o.ID)
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||
return nil
|
||||
},
|
||||
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
|
||||
assert.Equals(t, updo.CertificateID, "certID")
|
||||
assert.Equals(t, updo.Status, StatusValid)
|
||||
assert.Equals(t, updo.ID, o.ID)
|
||||
assert.Equals(t, updo.AccountID, o.AccountID)
|
||||
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
|
||||
assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs)
|
||||
assert.Equals(t, updo.Identifiers, o.Identifiers)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/new-cert-dns": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
o := &Order{
|
||||
|
@ -1050,9 +652,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{ID: id, Status: StatusValid}, nil
|
||||
},
|
||||
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
|
||||
cert.ID = "certID"
|
||||
assert.Equals(t, cert.AccountID, o.AccountID)
|
||||
|
@ -1114,9 +713,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{ID: id, Status: StatusValid}, nil
|
||||
},
|
||||
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
|
||||
cert.ID = "certID"
|
||||
assert.Equals(t, cert.AccountID, o.AccountID)
|
||||
|
@ -1181,9 +777,6 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
},
|
||||
},
|
||||
db: &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{ID: id, Status: StatusValid}, nil
|
||||
},
|
||||
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
|
||||
cert.ID = "certID"
|
||||
assert.Equals(t, cert.AccountID, o.AccountID)
|
||||
|
@ -1211,14 +804,14 @@ func TestOrder_Finalize(t *testing.T) {
|
|||
tc := run(t)
|
||||
if err := tc.o.Finalize(context.Background(), tc.db, tc.csr, tc.ca, tc.prov); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
var k *Error
|
||||
if errors.As(err, &k) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
assert.Equals(t, k.Type, tc.err.Type)
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
assert.Equals(t, k.Status, tc.err.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||
} else {
|
||||
default:
|
||||
assert.FatalError(t, errors.New("unexpected error type"))
|
||||
}
|
||||
}
|
||||
|
@ -1873,14 +1466,14 @@ func TestOrder_sans(t *testing.T) {
|
|||
t.Errorf("Order.sans() = %v, want error; got none", got)
|
||||
return
|
||||
}
|
||||
var k *Error
|
||||
if errors.As(err, &k) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
assert.Equals(t, k.Type, tt.err.Type)
|
||||
assert.Equals(t, k.Detail, tt.err.Detail)
|
||||
assert.Equals(t, k.Status, tt.err.Status)
|
||||
assert.Equals(t, k.Err.Error(), tt.err.Err.Error())
|
||||
assert.Equals(t, k.Detail, tt.err.Detail)
|
||||
} else {
|
||||
default:
|
||||
assert.FatalError(t, errors.New("unexpected error type"))
|
||||
}
|
||||
return
|
||||
|
@ -1891,55 +1484,3 @@ func TestOrder_sans(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrder_getAuthorizationFingerprint(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
type fields struct {
|
||||
AuthorizationIDs []string
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
db DB
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{[]string{"az1", "az2"}}, args{ctx, &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return &Authorization{ID: id, Status: StatusValid}, nil
|
||||
},
|
||||
}}, "", false},
|
||||
{"ok fingerprint", fields{[]string{"az1", "az2"}}, args{ctx, &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
if id == "az1" {
|
||||
return &Authorization{ID: id, Status: StatusValid}, nil
|
||||
}
|
||||
return &Authorization{ID: id, Fingerprint: "fingerprint", Status: StatusValid}, nil
|
||||
},
|
||||
}}, "fingerprint", false},
|
||||
{"fail", fields{[]string{"az1", "az2"}}, args{ctx, &MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}}, "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := &Order{
|
||||
AuthorizationIDs: tt.fields.AuthorizationIDs,
|
||||
}
|
||||
got, err := o.getAuthorizationFingerprint(tt.args.ctx, tt.args.db)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Order.getAuthorizationFingerprint() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("Order.getAuthorizationFingerprint() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
204
api/api.go
204
api/api.go
|
@ -1,10 +1,9 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/dsa" //nolint:staticcheck // support legacy algorithms
|
||||
"crypto/dsa" //nolint
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
|
@ -21,10 +20,7 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/smallstep/certificates/api/log"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
|
@ -38,12 +34,12 @@ type Authority interface {
|
|||
SSHAuthority
|
||||
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||
GetTLSOptions() *config.TLSOptions
|
||||
Root(shasum string) (*x509.Certificate, error)
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
|
||||
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
|
@ -53,12 +49,6 @@ type Authority interface {
|
|||
GetRoots() ([]*x509.Certificate, error)
|
||||
GetFederation() ([]*x509.Certificate, error)
|
||||
Version() authority.Version
|
||||
GetCertificateRevocationList() ([]byte, error)
|
||||
}
|
||||
|
||||
// mustAuthority will be replaced on unit tests.
|
||||
var mustAuthority = func(ctx context.Context) Authority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// TimeDuration is an alias of provisioner.TimeDuration
|
||||
|
@ -227,39 +217,8 @@ type RootResponse struct {
|
|||
// ProvisionersResponse is the response object that returns the list of
|
||||
// provisioners.
|
||||
type ProvisionersResponse struct {
|
||||
Provisioners provisioner.List
|
||||
NextCursor string
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler. It marshals the ProvisionersResponse
|
||||
// into a byte slice.
|
||||
//
|
||||
// Special treatment is given to the SCEP provisioner, as it contains a
|
||||
// challenge secret that MUST NOT be leaked in (public) HTTP responses. The
|
||||
// challenge value is thus redacted in HTTP responses.
|
||||
func (p ProvisionersResponse) MarshalJSON() ([]byte, error) {
|
||||
for _, item := range p.Provisioners {
|
||||
scepProv, ok := item.(*provisioner.SCEP)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
old := scepProv.ChallengePassword
|
||||
scepProv.ChallengePassword = "*** REDACTED ***"
|
||||
defer func(p string) { //nolint:gocritic // defer in loop required to restore initial state of provisioners
|
||||
scepProv.ChallengePassword = p
|
||||
}(old)
|
||||
}
|
||||
|
||||
var list = struct {
|
||||
Provisioners []provisioner.Interface `json:"provisioners"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
}{
|
||||
Provisioners: []provisioner.Interface(p.Provisioners),
|
||||
NextCursor: p.NextCursor,
|
||||
}
|
||||
|
||||
return json.Marshal(list)
|
||||
Provisioners provisioner.List `json:"provisioners"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
}
|
||||
|
||||
// ProvisionerKeyResponse is the response object that returns the encrypted key
|
||||
|
@ -283,54 +242,47 @@ type caHandler struct {
|
|||
Authority Authority
|
||||
}
|
||||
|
||||
// Route configures the http request router.
|
||||
func (h *caHandler) Route(r Router) {
|
||||
Route(r)
|
||||
}
|
||||
|
||||
// New creates a new RouterHandler with the CA endpoints.
|
||||
//
|
||||
// Deprecated: Use api.Route(r Router)
|
||||
func New(Authority) RouterHandler {
|
||||
return &caHandler{}
|
||||
func New(auth Authority) RouterHandler {
|
||||
return &caHandler{
|
||||
Authority: auth,
|
||||
}
|
||||
}
|
||||
|
||||
func Route(r Router) {
|
||||
r.MethodFunc("GET", "/version", Version)
|
||||
r.MethodFunc("GET", "/health", Health)
|
||||
r.MethodFunc("GET", "/root/{sha}", Root)
|
||||
r.MethodFunc("POST", "/sign", Sign)
|
||||
r.MethodFunc("POST", "/renew", Renew)
|
||||
r.MethodFunc("POST", "/rekey", Rekey)
|
||||
r.MethodFunc("POST", "/revoke", Revoke)
|
||||
r.MethodFunc("GET", "/crl", CRL)
|
||||
r.MethodFunc("GET", "/provisioners", Provisioners)
|
||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey)
|
||||
r.MethodFunc("GET", "/roots", Roots)
|
||||
r.MethodFunc("GET", "/roots.pem", RootsPEM)
|
||||
r.MethodFunc("GET", "/federation", Federation)
|
||||
func (h *caHandler) Route(r Router) {
|
||||
r.MethodFunc("GET", "/version", h.Version)
|
||||
r.MethodFunc("GET", "/health", h.Health)
|
||||
r.MethodFunc("GET", "/root/{sha}", h.Root)
|
||||
r.MethodFunc("POST", "/sign", h.Sign)
|
||||
r.MethodFunc("POST", "/renew", h.Renew)
|
||||
r.MethodFunc("POST", "/rekey", h.Rekey)
|
||||
r.MethodFunc("POST", "/revoke", h.Revoke)
|
||||
r.MethodFunc("GET", "/provisioners", h.Provisioners)
|
||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
|
||||
r.MethodFunc("GET", "/roots", h.Roots)
|
||||
r.MethodFunc("GET", "/federation", h.Federation)
|
||||
// SSH CA
|
||||
r.MethodFunc("POST", "/ssh/sign", SSHSign)
|
||||
r.MethodFunc("POST", "/ssh/renew", SSHRenew)
|
||||
r.MethodFunc("POST", "/ssh/revoke", SSHRevoke)
|
||||
r.MethodFunc("POST", "/ssh/rekey", SSHRekey)
|
||||
r.MethodFunc("GET", "/ssh/roots", SSHRoots)
|
||||
r.MethodFunc("GET", "/ssh/federation", SSHFederation)
|
||||
r.MethodFunc("POST", "/ssh/config", SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost)
|
||||
r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts)
|
||||
r.MethodFunc("POST", "/ssh/bastion", SSHBastion)
|
||||
r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
|
||||
r.MethodFunc("POST", "/ssh/renew", h.SSHRenew)
|
||||
r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke)
|
||||
r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey)
|
||||
r.MethodFunc("GET", "/ssh/roots", h.SSHRoots)
|
||||
r.MethodFunc("GET", "/ssh/federation", h.SSHFederation)
|
||||
r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
|
||||
r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts)
|
||||
r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion)
|
||||
|
||||
// For compatibility with old code:
|
||||
r.MethodFunc("POST", "/re-sign", Renew)
|
||||
r.MethodFunc("POST", "/sign-ssh", SSHSign)
|
||||
r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts)
|
||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||
r.MethodFunc("POST", "/sign-ssh", h.SSHSign)
|
||||
r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts)
|
||||
}
|
||||
|
||||
// Version is an HTTP handler that returns the version of the server.
|
||||
func Version(w http.ResponseWriter, r *http.Request) {
|
||||
v := mustAuthority(r.Context()).Version()
|
||||
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
|
||||
v := h.Authority.Version()
|
||||
render.JSON(w, VersionResponse{
|
||||
Version: v.Version,
|
||||
RequireClientAuthentication: v.RequireClientAuthentication,
|
||||
|
@ -338,17 +290,17 @@ func Version(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Health is an HTTP handler that returns the status of the server.
|
||||
func Health(w http.ResponseWriter, _ *http.Request) {
|
||||
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSON(w, HealthResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
|
||||
// certificate for the given SHA256.
|
||||
func Root(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
||||
sha := chi.URLParam(r, "sha")
|
||||
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
|
||||
// Load root certificate with the
|
||||
cert, err := mustAuthority(r.Context()).Root(sum)
|
||||
cert, err := h.Authority.Root(sum)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||
return
|
||||
|
@ -366,19 +318,18 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
|||
}
|
||||
|
||||
// Provisioners returns the list of provisioners configured in the authority.
|
||||
func Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||
p, next, err := h.Authority.GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionersResponse{
|
||||
Provisioners: p,
|
||||
NextCursor: next,
|
||||
|
@ -386,20 +337,19 @@ func Provisioners(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// ProvisionerKey returns the encrypted key of a provisioner by it's key id.
|
||||
func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||
kid := chi.URLParam(r, "kid")
|
||||
key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
|
||||
key, err := h.Authority.GetEncryptedKey(kid)
|
||||
if err != nil {
|
||||
render.Error(w, errs.NotFoundErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionerKeyResponse{key})
|
||||
}
|
||||
|
||||
// Roots returns all the root certificates for the CA.
|
||||
func Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := h.Authority.GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
|
||||
return
|
||||
|
@ -415,32 +365,9 @@ func Roots(w http.ResponseWriter, r *http.Request) {
|
|||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// RootsPEM returns all the root certificates for the CA in PEM format.
|
||||
func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/x-pem-file")
|
||||
|
||||
for _, root := range roots {
|
||||
block := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: root.Raw,
|
||||
})
|
||||
|
||||
if _, err := w.Write(block); err != nil {
|
||||
log.Error(w, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Federation returns all the public certificates in the federation.
|
||||
func Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := mustAuthority(r.Context()).GetFederation()
|
||||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := h.Authority.GetFederation()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||
return
|
||||
|
@ -472,7 +399,7 @@ func logOtt(w http.ResponseWriter, token string) {
|
|||
}
|
||||
}
|
||||
|
||||
// LogCertificate adds certificate fields to the log message.
|
||||
// LogCertificate add certificate fields to the log message.
|
||||
func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
m := map[string]interface{}{
|
||||
|
@ -504,41 +431,6 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
|
|||
}
|
||||
}
|
||||
|
||||
// LogSSHCertificate adds SSH certificate fields to the log message.
|
||||
func LogSSHCertificate(w http.ResponseWriter, cert *ssh.Certificate) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
mak := bytes.TrimSpace(ssh.MarshalAuthorizedKey(cert))
|
||||
var certificate string
|
||||
parts := strings.Split(string(mak), " ")
|
||||
if len(parts) > 1 {
|
||||
certificate = parts[1]
|
||||
}
|
||||
var userOrHost string
|
||||
if cert.CertType == ssh.HostCert {
|
||||
userOrHost = "host"
|
||||
} else {
|
||||
userOrHost = "user"
|
||||
}
|
||||
certificateType := fmt.Sprintf("%s %s certificate", parts[0], userOrHost) // e.g. ecdsa-sha2-nistp256-cert-v01@openssh.com user certificate
|
||||
m := map[string]interface{}{
|
||||
"serial": cert.Serial,
|
||||
"principals": cert.ValidPrincipals,
|
||||
"valid-from": time.Unix(int64(cert.ValidAfter), 0).Format(time.RFC3339),
|
||||
"valid-to": time.Unix(int64(cert.ValidBefore), 0).Format(time.RFC3339),
|
||||
"certificate": certificate,
|
||||
"certificate-type": certificateType,
|
||||
}
|
||||
fingerprint, err := sshutil.FormatFingerprint(mak, sshutil.DefaultFingerprint)
|
||||
if err == nil {
|
||||
fpParts := strings.Split(fingerprint, " ")
|
||||
if len(fpParts) > 3 {
|
||||
m["public-key"] = fmt.Sprintf("%s %s", fpParts[1], fpParts[len(fpParts)-1])
|
||||
}
|
||||
}
|
||||
rl.WithFields(m)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseCursor parses the cursor and limit from the request query params.
|
||||
func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
|
||||
q := r.URL.Query()
|
||||
|
|
314
api/api_test.go
314
api/api_test.go
|
@ -4,7 +4,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/dsa" //nolint:staticcheck // support legacy algorithms
|
||||
"crypto/dsa" //nolint
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
|
@ -28,15 +28,12 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
sassert "github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"golang.org/x/crypto/ssh"
|
||||
squarejose "gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
@ -174,28 +171,16 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
|
|||
return csr
|
||||
}
|
||||
|
||||
func mockMustAuthority(t *testing.T, a Authority) {
|
||||
t.Helper()
|
||||
fn := mustAuthority
|
||||
t.Cleanup(func() {
|
||||
mustAuthority = fn
|
||||
})
|
||||
mustAuthority = func(ctx context.Context) Authority {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
type mockAuthority struct {
|
||||
ret1, ret2 interface{}
|
||||
err error
|
||||
authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
||||
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||
getTLSOptions func() *authority.TLSOptions
|
||||
root func(shasum string) (*x509.Certificate, error)
|
||||
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||
loadProvisionerByName func(name string) (provisioner.Interface, error)
|
||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||
|
@ -203,7 +188,6 @@ type mockAuthority struct {
|
|||
getEncryptedKey func(kid string) (string, error)
|
||||
getRoots func() ([]*x509.Certificate, error)
|
||||
getFederation func() ([]*x509.Certificate, error)
|
||||
getCRL func() ([]byte, error)
|
||||
signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
|
@ -217,18 +201,14 @@ type mockAuthority struct {
|
|||
version func() authority.Version
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) {
|
||||
if m.getCRL != nil {
|
||||
return m.getCRL()
|
||||
}
|
||||
|
||||
return m.ret1.([]byte), m.err
|
||||
}
|
||||
|
||||
// TODO: remove once Authorize is deprecated.
|
||||
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
if m.authorize != nil {
|
||||
return m.authorize(ctx, ott)
|
||||
return m.AuthorizeSign(ott)
|
||||
}
|
||||
|
||||
func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
||||
if m.authorizeSign != nil {
|
||||
return m.authorizeSign(ott)
|
||||
}
|
||||
return m.ret1.([]provisioner.SignOption), m.err
|
||||
}
|
||||
|
@ -268,13 +248,6 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro
|
|||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) RenewContext(ctx context.Context, oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
if m.renewContext != nil {
|
||||
return m.renewContext(ctx, oldcert, pk)
|
||||
}
|
||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
if m.rekey != nil {
|
||||
return m.rekey(oldcert, pk)
|
||||
|
@ -792,45 +765,6 @@ func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) (
|
|||
return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err
|
||||
}
|
||||
|
||||
func Test_CRLGeneration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
statusCode int
|
||||
expected []byte
|
||||
}{
|
||||
{"empty", nil, http.StatusOK, nil},
|
||||
}
|
||||
|
||||
chiCtx := chi.NewRouteContext()
|
||||
req := httptest.NewRequest("GET", "http://example.com/crl", nil)
|
||||
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{ret1: tt.expected, err: tt.err})
|
||||
w := httptest.NewRecorder()
|
||||
CRL(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.CRL StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.Root unexpected error = %v", err)
|
||||
}
|
||||
if tt.statusCode == 200 {
|
||||
if !bytes.Equal(bytes.TrimSpace(body), tt.expected) {
|
||||
t.Errorf("caHandler.Root CRL = %s, wants %s", body, tt.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Route(t *testing.T) {
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
|
@ -855,10 +789,11 @@ func Test_caHandler_Route(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_Health(t *testing.T) {
|
||||
func Test_caHandler_Health(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
Health(w, req)
|
||||
h := New(&mockAuthority{}).(*caHandler)
|
||||
h.Health(w, req)
|
||||
|
||||
res := w.Result()
|
||||
if res.StatusCode != 200 {
|
||||
|
@ -876,7 +811,7 @@ func Test_Health(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_Root(t *testing.T) {
|
||||
func Test_caHandler_Root(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
root *x509.Certificate
|
||||
|
@ -897,9 +832,9 @@ func Test_Root(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err})
|
||||
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
|
||||
w := httptest.NewRecorder()
|
||||
Root(w, req)
|
||||
h.Root(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -920,7 +855,7 @@ func Test_Root(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_Sign(t *testing.T) {
|
||||
func Test_caHandler_Sign(t *testing.T) {
|
||||
csr := parseCertificateRequest(csrPEM)
|
||||
valid, err := json.Marshal(SignRequest{
|
||||
CsrPEM: CertificateRequest{csr},
|
||||
|
@ -961,18 +896,18 @@ func Test_Sign(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
h := New(&mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return tt.certAttrOpts, tt.autherr
|
||||
},
|
||||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
|
||||
w := httptest.NewRecorder()
|
||||
Sign(logging.NewResponseLogger(w), req)
|
||||
h.Sign(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -993,7 +928,7 @@ func Test_Sign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_Renew(t *testing.T) {
|
||||
func Test_caHandler_Renew(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1083,7 +1018,7 @@ func Test_Renew(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
h := New(&mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
|
||||
|
@ -1104,12 +1039,12 @@ func Test_Renew(t *testing.T) {
|
|||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
||||
req.TLS = tt.tls
|
||||
req.Header = tt.header
|
||||
w := httptest.NewRecorder()
|
||||
Renew(logging.NewResponseLogger(w), req)
|
||||
h.Renew(logging.NewResponseLogger(w), req)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
@ -1138,7 +1073,7 @@ func Test_Renew(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_Rekey(t *testing.T) {
|
||||
func Test_caHandler_Rekey(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1169,16 +1104,16 @@ func Test_Rekey(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
h := New(&mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
Rekey(logging.NewResponseLogger(w), req)
|
||||
h.Rekey(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -1199,7 +1134,7 @@ func Test_Rekey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_Provisioners(t *testing.T) {
|
||||
func Test_caHandler_Provisioners(t *testing.T) {
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
}
|
||||
|
@ -1265,8 +1200,10 @@ func Test_Provisioners(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tt.fields.Authority)
|
||||
Provisioners(tt.args.w, tt.args.r)
|
||||
h := &caHandler{
|
||||
Authority: tt.fields.Authority,
|
||||
}
|
||||
h.Provisioners(tt.args.w, tt.args.r)
|
||||
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
res := rec.Result()
|
||||
|
@ -1301,7 +1238,7 @@ func Test_Provisioners(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_ProvisionerKey(t *testing.T) {
|
||||
func Test_caHandler_ProvisionerKey(t *testing.T) {
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
}
|
||||
|
@ -1333,8 +1270,10 @@ func Test_ProvisionerKey(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tt.fields.Authority)
|
||||
ProvisionerKey(tt.args.w, tt.args.r)
|
||||
h := &caHandler{
|
||||
Authority: tt.fields.Authority,
|
||||
}
|
||||
h.ProvisionerKey(tt.args.w, tt.args.r)
|
||||
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
res := rec.Result()
|
||||
|
@ -1359,7 +1298,7 @@ func Test_ProvisionerKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_Roots(t *testing.T) {
|
||||
func Test_caHandler_Roots(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1380,11 +1319,11 @@ func Test_Roots(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
|
||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
Roots(w, req)
|
||||
h.Roots(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -1405,47 +1344,7 @@ func Test_Roots(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_RootsPEM(t *testing.T) {
|
||||
parsedRoot := parseCertificate(rootPEM)
|
||||
tests := []struct {
|
||||
name string
|
||||
roots []*x509.Certificate
|
||||
err error
|
||||
statusCode int
|
||||
expect string
|
||||
}{
|
||||
{"one root", []*x509.Certificate{parsedRoot}, nil, http.StatusOK, rootPEM},
|
||||
{"two roots", []*x509.Certificate{parsedRoot, parsedRoot}, nil, http.StatusOK, rootPEM + "\n" + rootPEM},
|
||||
{"fail", nil, errors.New("an error"), http.StatusInternalServerError, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err})
|
||||
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
|
||||
w := httptest.NewRecorder()
|
||||
RootsPEM(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.RootsPEM StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.RootsPEM unexpected error = %v", err)
|
||||
}
|
||||
if tt.statusCode < http.StatusBadRequest {
|
||||
if !bytes.Equal(bytes.TrimSpace(body), []byte(tt.expect)) {
|
||||
t.Errorf("caHandler.RootsPEM Body = %s, wants %s", body, tt.expect)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Federation(t *testing.T) {
|
||||
func Test_caHandler_Federation(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1466,11 +1365,11 @@ func Test_Federation(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
|
||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
Federation(w, req)
|
||||
h.Federation(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -1496,7 +1395,7 @@ func Test_fmtPublicKey(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1522,7 +1421,7 @@ func Test_fmtPublicKey(t *testing.T) {
|
|||
want string
|
||||
}{
|
||||
{"p256", args{p256.Public(), p256, nil}, "ECDSA P-256"},
|
||||
{"rsa2048", args{rsa2048.Public(), rsa2048, nil}, "RSA 2048"},
|
||||
{"rsa1024", args{rsa1024.Public(), rsa1024, nil}, "RSA 1024"},
|
||||
{"ed25519", args{edPub, edPriv, nil}, "Ed25519"},
|
||||
{"dsa2048", args{cert: &x509.Certificate{PublicKeyAlgorithm: x509.DSA, PublicKey: &dsa2048.PublicKey}}, "DSA 2048"},
|
||||
{"unknown", args{cert: &x509.Certificate{PublicKeyAlgorithm: x509.ECDSA, PublicKey: []byte("12345678")}}, "ECDSA unknown"},
|
||||
|
@ -1567,122 +1466,3 @@ func mustCertificate(t *testing.T, pub, priv interface{}) *x509.Certificate {
|
|||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
func TestProvisionersResponse_MarshalJSON(t *testing.T) {
|
||||
|
||||
k := map[string]any{
|
||||
"use": "sig",
|
||||
"kty": "EC",
|
||||
"kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc",
|
||||
"crv": "P-256",
|
||||
"alg": "ES256",
|
||||
"x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8",
|
||||
"y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y",
|
||||
}
|
||||
key := squarejose.JSONWebKey{}
|
||||
b, err := json.Marshal(k)
|
||||
assert.FatalError(t, err)
|
||||
err = json.Unmarshal(b, &key)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
r := ProvisionersResponse{
|
||||
Provisioners: provisioner.List{
|
||||
&provisioner.SCEP{
|
||||
Name: "scep",
|
||||
Type: "scep",
|
||||
ChallengePassword: "not-so-secret",
|
||||
MinimumPublicKeyLength: 2048,
|
||||
EncryptionAlgorithmIdentifier: 2,
|
||||
},
|
||||
&provisioner.JWK{
|
||||
EncryptedKey: "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg",
|
||||
Key: &key,
|
||||
Name: "step-cli",
|
||||
Type: "JWK",
|
||||
},
|
||||
},
|
||||
NextCursor: "next",
|
||||
}
|
||||
|
||||
expected := map[string]any{
|
||||
"provisioners": []map[string]any{
|
||||
{
|
||||
"type": "scep",
|
||||
"name": "scep",
|
||||
"challenge": "*** REDACTED ***",
|
||||
"minimumPublicKeyLength": 2048,
|
||||
"encryptionAlgorithmIdentifier": 2,
|
||||
},
|
||||
{
|
||||
"type": "JWK",
|
||||
"name": "step-cli",
|
||||
"key": map[string]any{
|
||||
"use": "sig",
|
||||
"kty": "EC",
|
||||
"kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc",
|
||||
"crv": "P-256",
|
||||
"alg": "ES256",
|
||||
"x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8",
|
||||
"y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y",
|
||||
},
|
||||
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg",
|
||||
},
|
||||
},
|
||||
"nextCursor": "next",
|
||||
}
|
||||
|
||||
expBytes, err := json.Marshal(expected)
|
||||
sassert.NoError(t, err)
|
||||
|
||||
br, err := r.MarshalJSON()
|
||||
sassert.NoError(t, err)
|
||||
sassert.JSONEq(t, string(expBytes), string(br))
|
||||
|
||||
keyCopy := key
|
||||
expList := provisioner.List{
|
||||
&provisioner.SCEP{
|
||||
Name: "scep",
|
||||
Type: "scep",
|
||||
ChallengePassword: "not-so-secret",
|
||||
MinimumPublicKeyLength: 2048,
|
||||
EncryptionAlgorithmIdentifier: 2,
|
||||
},
|
||||
&provisioner.JWK{
|
||||
EncryptedKey: "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg",
|
||||
Key: &keyCopy,
|
||||
Name: "step-cli",
|
||||
Type: "JWK",
|
||||
},
|
||||
}
|
||||
|
||||
// MarshalJSON must not affect the struct properties itself
|
||||
sassert.Equal(t, expList, r.Provisioners)
|
||||
}
|
||||
|
||||
const (
|
||||
fixtureECDSACertificate = `ecdsa-sha2-nistp256-cert-v01@openssh.com AAAAKGVjZHNhLXNoYTItbmlzdHAyNTYtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLnkvSk4odlo3b1R+RDw+LmorL3RkN354IilCIVFVen4AAAAIbmlzdHAyNTYAAABBBHjKHss8WM2ffMYlavisoLXR0I6UEIU+cidV1ogEH1U6+/SYaFPrlzQo0tGLM5CNkMbhInbyasQsrHzn8F1Rt7nHg5/tcSf9qwAAAAEAAAAGaGVybWFuAAAACgAAAAZoZXJtYW4AAAAAY8kvJwAAAABjyhBjAAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAGgAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAAhuaXN0cDI1NgAAAEEE/ayqpPrZZF5uA1UlDt4FreTf15agztQIzpxnWq/XoxAHzagRSkFGkdgFpjgsfiRpP8URHH3BZScqc0ZDCTxhoQAAAGQAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAEkAAAAhAJuP1wCVwoyrKrEtHGfFXrVbRHySDjvXtS1tVTdHyqymAAAAIBa/CSSzfZb4D2NLP+eEmOOMJwSjYOiNM8fiOoAaqglI herman`
|
||||
)
|
||||
|
||||
func TestLogSSHCertificate(t *testing.T) {
|
||||
|
||||
out, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fixtureECDSACertificate))
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, ok := out.(*ssh.Certificate)
|
||||
require.True(t, ok)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
rl := logging.NewResponseLogger(w)
|
||||
LogSSHCertificate(rl, cert)
|
||||
|
||||
sassert.Equal(t, 200, w.Result().StatusCode)
|
||||
|
||||
fields := rl.Fields()
|
||||
sassert.Equal(t, uint64(14376510277651266987), fields["serial"])
|
||||
sassert.Equal(t, []string{"herman"}, fields["principals"])
|
||||
sassert.Equal(t, "ecdsa-sha2-nistp256-cert-v01@openssh.com user certificate", fields["certificate-type"])
|
||||
sassert.Equal(t, time.Unix(1674129191, 0).Format(time.RFC3339), fields["valid-from"])
|
||||
sassert.Equal(t, time.Unix(1674186851, 0).Format(time.RFC3339), fields["valid-to"])
|
||||
sassert.Equal(t, "AAAAKGVjZHNhLXNoYTItbmlzdHAyNTYtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLnkvSk4odlo3b1R+RDw+LmorL3RkN354IilCIVFVen4AAAAIbmlzdHAyNTYAAABBBHjKHss8WM2ffMYlavisoLXR0I6UEIU+cidV1ogEH1U6+/SYaFPrlzQo0tGLM5CNkMbhInbyasQsrHzn8F1Rt7nHg5/tcSf9qwAAAAEAAAAGaGVybWFuAAAACgAAAAZoZXJtYW4AAAAAY8kvJwAAAABjyhBjAAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAGgAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAAhuaXN0cDI1NgAAAEEE/ayqpPrZZF5uA1UlDt4FreTf15agztQIzpxnWq/XoxAHzagRSkFGkdgFpjgsfiRpP8URHH3BZScqc0ZDCTxhoQAAAGQAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAEkAAAAhAJuP1wCVwoyrKrEtHGfFXrVbRHySDjvXtS1tVTdHyqymAAAAIBa/CSSzfZb4D2NLP+eEmOOMJwSjYOiNM8fiOoAaqglI", fields["certificate"])
|
||||
sassert.Equal(t, "SHA256:RvkDPGwl/G9d7LUFm1kmWhvOD9I/moPq4yxcb0STwr0 (ECDSA-CERT)", fields["public-key"])
|
||||
}
|
||||
|
|
32
api/crl.go
32
api/crl.go
|
@ -1,32 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
)
|
||||
|
||||
// CRL is an HTTP handler that returns the current CRL in DER or PEM format
|
||||
func CRL(w http.ResponseWriter, r *http.Request) {
|
||||
crlBytes, err := mustAuthority(r.Context()).GetCertificateRevocationList()
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, formatAsPEM := r.URL.Query()["pem"]
|
||||
if formatAsPEM {
|
||||
w.Header().Add("Content-Type", "application/x-pem-file")
|
||||
w.Header().Add("Content-Disposition", "attachment; filename=\"crl.pem\"")
|
||||
|
||||
_ = pem.Encode(w, &pem.Block{
|
||||
Type: "X509 CRL",
|
||||
Bytes: crlBytes,
|
||||
})
|
||||
} else {
|
||||
w.Header().Add("Content-Type", "application/pkix-crl")
|
||||
w.Header().Add("Content-Disposition", "attachment; filename=\"crl.der\"")
|
||||
w.Write(crlBytes)
|
||||
}
|
||||
}
|
|
@ -2,58 +2,30 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// StackTracedError is the set of errors implementing the StackTrace function.
|
||||
//
|
||||
// Errors implementing this interface have their stack traces logged when passed
|
||||
// to the Error function of this package.
|
||||
type StackTracedError interface {
|
||||
error
|
||||
|
||||
StackTrace() errors.StackTrace
|
||||
}
|
||||
|
||||
type fieldCarrier interface {
|
||||
WithFields(map[string]any)
|
||||
Fields() map[string]any
|
||||
}
|
||||
|
||||
// Error adds to the response writer the given error if it implements
|
||||
// logging.ResponseLogger. If it does not implement it, then writes the error
|
||||
// using the log package.
|
||||
func Error(rw http.ResponseWriter, err error) {
|
||||
fc, ok := rw.(fieldCarrier)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
fc.WithFields(map[string]any{
|
||||
"error": err,
|
||||
})
|
||||
|
||||
if os.Getenv("STEPDEBUG") != "1" {
|
||||
return
|
||||
}
|
||||
|
||||
var st StackTracedError
|
||||
if errors.As(err, &st) {
|
||||
fc.WithFields(map[string]any{
|
||||
"stack-trace": fmt.Sprintf("%+v", st.StackTrace()),
|
||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"error": err,
|
||||
})
|
||||
} else {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
// EnabledResponse log the response object if it implements the EnableLogger
|
||||
// interface.
|
||||
func EnabledResponse(rw http.ResponseWriter, v any) {
|
||||
func EnabledResponse(rw http.ResponseWriter, v interface{}) {
|
||||
type enableLogger interface {
|
||||
ToLog() (any, error)
|
||||
ToLog() (interface{}, error)
|
||||
}
|
||||
|
||||
if el, ok := v.(enableLogger); ok {
|
||||
|
@ -64,10 +36,12 @@ func EnabledResponse(rw http.ResponseWriter, v any) {
|
|||
return
|
||||
}
|
||||
|
||||
if rl, ok := rw.(fieldCarrier); ok {
|
||||
rl.WithFields(map[string]any{
|
||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"response": out,
|
||||
})
|
||||
} else {
|
||||
log.Println(out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,78 +1,43 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
type stackTracedError struct{}
|
||||
|
||||
func (stackTracedError) Error() string {
|
||||
return "a stacktraced error"
|
||||
}
|
||||
|
||||
func (stackTracedError) StackTrace() pkgerrors.StackTrace {
|
||||
f := struct{}{}
|
||||
return pkgerrors.StackTrace{ // fake stacktrace
|
||||
pkgerrors.Frame(unsafe.Pointer(&f)),
|
||||
pkgerrors.Frame(unsafe.Pointer(&f)),
|
||||
}
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
theError := errors.New("the error")
|
||||
|
||||
type args struct {
|
||||
rw http.ResponseWriter
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
error
|
||||
rw http.ResponseWriter
|
||||
isFieldCarrier bool
|
||||
stepDebug bool
|
||||
expectStackTrace bool
|
||||
name string
|
||||
args args
|
||||
withFields bool
|
||||
}{
|
||||
{"noLogger", nil, nil, false, false, false},
|
||||
{"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false},
|
||||
{"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false},
|
||||
{"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false},
|
||||
{"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false},
|
||||
{"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true},
|
||||
{"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true},
|
||||
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
|
||||
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.stepDebug {
|
||||
t.Setenv("STEPDEBUG", "1")
|
||||
} else {
|
||||
t.Setenv("STEPDEBUG", "0")
|
||||
}
|
||||
|
||||
Error(tt.rw, tt.error)
|
||||
|
||||
// return early if test case doesn't use logger
|
||||
if !tt.isFieldCarrier {
|
||||
return
|
||||
}
|
||||
|
||||
fields := tt.rw.(logging.ResponseLogger).Fields()
|
||||
|
||||
// expect the error field to be (not) set and to be the same error that was fed to Error
|
||||
if tt.error == nil {
|
||||
assert.Nil(t, fields["error"])
|
||||
} else {
|
||||
assert.Same(t, tt.error, fields["error"])
|
||||
}
|
||||
|
||||
// check if stack-trace is set when expected
|
||||
if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] not set`)
|
||||
} else if !tt.expectStackTrace && hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] was set`)
|
||||
Error(tt.args.rw, tt.args.err)
|
||||
if tt.withFields {
|
||||
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
|
||||
fields := rl.Fields()
|
||||
if !reflect.DeepEqual(fields["error"], theError) {
|
||||
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
|
||||
}
|
||||
} else {
|
||||
t.Error("ResponseWriter does not implement logging.ResponseLogger")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
114
api/read/read.go
114
api/read/read.go
|
@ -2,65 +2,91 @@
|
|||
package read
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
|
||||
"github.com/smallstep/certificates/internal/buffer"
|
||||
)
|
||||
|
||||
// JSON reads JSON from the request body and stores it in the value
|
||||
// pointed to by v.
|
||||
func JSON(r io.Reader, v interface{}) error {
|
||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
||||
return errs.BadRequestErr(err, "error decoding json")
|
||||
// JSON unmarshals from the given request's JSON body into v. In case of an
|
||||
// error a HTTP Bad Request error will be written to w.
|
||||
func JSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
|
||||
b := read(w, r)
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
return nil
|
||||
defer buffer.Put(b)
|
||||
|
||||
if err := json.NewDecoder(b).Decode(v); err != nil {
|
||||
err = fmt.Errorf("error decoding json: %w", err)
|
||||
|
||||
render.BadRequest(w, err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AdminJSON is obsolete; it's here for backwards compatibility.
|
||||
//
|
||||
// Please don't use.
|
||||
func AdminJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
|
||||
b := read(w, r)
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
defer buffer.Put(b)
|
||||
|
||||
if err := json.NewDecoder(b).Decode(v); err != nil {
|
||||
e := admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")
|
||||
admin.WriteError(w, e)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ProtoJSON reads JSON from the request body and stores it in the value
|
||||
// pointed to by m.
|
||||
func ProtoJSON(r io.Reader, m proto.Message) error {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return errs.BadRequestErr(err, "error reading request body")
|
||||
// pointed by v.
|
||||
func ProtoJSON(w http.ResponseWriter, r *http.Request, m proto.Message) bool {
|
||||
b := read(w, r)
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
defer buffer.Put(b)
|
||||
|
||||
if err := protojson.Unmarshal(b.Bytes(), m); err != nil {
|
||||
err = fmt.Errorf("error decoding proto json: %w", err)
|
||||
|
||||
render.BadRequest(w, err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
switch err := protojson.Unmarshal(data, m); {
|
||||
case errors.Is(err, proto.Error):
|
||||
return badProtoJSONError(err.Error())
|
||||
default:
|
||||
return err
|
||||
return true
|
||||
}
|
||||
|
||||
func read(w http.ResponseWriter, r *http.Request) *bytes.Buffer {
|
||||
b := buffer.Get()
|
||||
if _, err := b.ReadFrom(r.Body); err != nil {
|
||||
buffer.Put(b)
|
||||
|
||||
err = fmt.Errorf("error reading request body: %w", err)
|
||||
|
||||
render.BadRequest(w, err)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// badProtoJSONError is an error type that is returned by ProtoJSON
|
||||
// when a proto message cannot be unmarshaled. Usually this is caused
|
||||
// by an error in the request body.
|
||||
type badProtoJSONError string
|
||||
|
||||
// Error implements error for badProtoJSONError
|
||||
func (e badProtoJSONError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// Render implements render.RenderableError for badProtoJSONError
|
||||
func (e badProtoJSONError) Render(w http.ResponseWriter) {
|
||||
v := struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Type: "badRequest",
|
||||
Detail: "bad request",
|
||||
// trim the proto prefix for the message
|
||||
Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")),
|
||||
}
|
||||
render.JSONStatus(w, v, http.StatusBadRequest)
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -1,165 +1,57 @@
|
|||
package read
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
type args struct {
|
||||
r io.Reader
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
cases := []struct {
|
||||
src io.Reader
|
||||
exp interface{}
|
||||
ok bool
|
||||
code int
|
||||
}{
|
||||
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
|
||||
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
|
||||
0: {
|
||||
src: strings.NewReader(`{"foo":"bar"}`),
|
||||
exp: map[string]interface{}{"foo": "bar"},
|
||||
ok: true,
|
||||
code: http.StatusOK,
|
||||
},
|
||||
1: {
|
||||
src: strings.NewReader(`{"foo"}`),
|
||||
code: http.StatusBadRequest,
|
||||
},
|
||||
2: {
|
||||
src: io.MultiReader(
|
||||
strings.NewReader(`{`),
|
||||
iotest.ErrReader(assert.AnError),
|
||||
strings.NewReader(`"foo":"bar"}`),
|
||||
),
|
||||
code: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := JSON(tt.args.r, &tt.args.v)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
var e *errs.Error
|
||||
if errors.As(err, &e) {
|
||||
if code := e.StatusCode(); code != 400 {
|
||||
t.Errorf("error.StatusCode() = %v, wants 400", code)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("error type = %T, wants *Error", err)
|
||||
}
|
||||
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
|
||||
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtoJSON(t *testing.T) {
|
||||
|
||||
p := new(linkedca.Policy) // TODO(hs): can we use something different, so we don't need the import?
|
||||
|
||||
type args struct {
|
||||
r io.Reader
|
||||
m proto.Message
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "fail/io.ReadAll",
|
||||
args: args{
|
||||
r: iotest.ErrReader(errors.New("read error")),
|
||||
m: p,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail/proto",
|
||||
args: args{
|
||||
r: strings.NewReader(`{?}`),
|
||||
m: p,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
r: strings.NewReader(`{"x509":{}}`),
|
||||
m: p,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ProtoJSON(tt.args.r, tt.args.m)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
var (
|
||||
ee *errs.Error
|
||||
bpe badProtoJSONError
|
||||
)
|
||||
switch {
|
||||
case errors.As(err, &bpe):
|
||||
assert.Contains(t, err.Error(), "syntax error")
|
||||
case errors.As(err, &ee):
|
||||
assert.Equal(t, http.StatusBadRequest, ee.Status)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m))
|
||||
assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_badProtoJSONError_Render(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e badProtoJSONError
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "bad proto normal space",
|
||||
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
|
||||
expected: "syntax error (line 1:2): invalid value ?",
|
||||
},
|
||||
{
|
||||
name: "bad proto non breaking space",
|
||||
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
|
||||
expected: "syntax error (line 1:2): invalid value ?",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tt.e.Render(w)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
v := struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Message string `json:"message"`
|
||||
}{}
|
||||
|
||||
assert.NoError(t, json.Unmarshal(data, &v))
|
||||
assert.Equal(t, "badRequest", v.Type)
|
||||
assert.Equal(t, "bad request", v.Detail)
|
||||
assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message)
|
||||
|
||||
for caseIndex := range cases {
|
||||
kase := cases[caseIndex]
|
||||
|
||||
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", kase.src)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
var body interface{}
|
||||
got := JSON(rec, req, &body)
|
||||
|
||||
assert.Equal(t, kase.ok, got)
|
||||
assert.Equal(t, kase.code, rec.Result().StatusCode)
|
||||
assert.Equal(t, kase.exp, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
10
api/rekey.go
10
api/rekey.go
|
@ -27,15 +27,14 @@ func (s *RekeyRequest) Validate() error {
|
|||
}
|
||||
|
||||
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
|
||||
func Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
render.Error(w, errs.BadRequest("missing client certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
var body RekeyRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if !read.JSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -44,8 +43,7 @@ func Rekey(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
a := mustAuthority(r.Context())
|
||||
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
||||
return
|
||||
|
@ -61,6 +59,6 @@ func Rekey(w http.ResponseWriter, r *http.Request) {
|
|||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
|
|
@ -3,50 +3,42 @@ package render
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/log"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/scep"
|
||||
)
|
||||
|
||||
// JSON is shorthand for JSONStatus(w, v, http.StatusOK).
|
||||
// JSON writes the passed value into the http.ResponseWriter.
|
||||
func JSON(w http.ResponseWriter, v interface{}) {
|
||||
JSONStatus(w, v, http.StatusOK)
|
||||
}
|
||||
|
||||
// JSONStatus marshals v into w. It additionally sets the status code of
|
||||
// w to the given one.
|
||||
//
|
||||
// JSONStatus sets the Content-Type of w to application/json unless one is
|
||||
// specified.
|
||||
// JSONStatus writes the given value into the http.ResponseWriter and the
|
||||
// given status is written as the status code of the response.
|
||||
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
||||
setContentTypeUnlessPresent(w, "application/json")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||
var errUnsupportedType *json.UnsupportedTypeError
|
||||
if errors.As(err, &errUnsupportedType) {
|
||||
panic(err)
|
||||
}
|
||||
log.Error(w, err)
|
||||
|
||||
var errUnsupportedValue *json.UnsupportedValueError
|
||||
if errors.As(err, &errUnsupportedValue) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var errMarshalError *json.MarshalerError
|
||||
if errors.As(err, &errMarshalError) {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.EnabledResponse(w, v)
|
||||
}
|
||||
|
||||
// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK).
|
||||
// ProtoJSON writes the passed value into the http.ResponseWriter.
|
||||
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
|
||||
ProtoJSONStatus(w, m, http.StatusOK)
|
||||
}
|
||||
|
@ -54,82 +46,103 @@ func ProtoJSON(w http.ResponseWriter, m proto.Message) {
|
|||
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
|
||||
// given status is written as the status code of the response.
|
||||
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
b, err := protojson.Marshal(m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
setContentTypeUnlessPresent(w, "application/json")
|
||||
w.WriteHeader(status)
|
||||
_, _ = w.Write(b)
|
||||
}
|
||||
|
||||
func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) {
|
||||
const header = "Content-Type"
|
||||
|
||||
h := w.Header()
|
||||
if _, ok := h[header]; !ok {
|
||||
h.Set(header, contentType)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderableError is the set of errors that implement the basic Render method.
|
||||
//
|
||||
// Errors that implement this interface will use their own Render method when
|
||||
// being rendered into responses.
|
||||
type RenderableError interface {
|
||||
error
|
||||
|
||||
Render(http.ResponseWriter)
|
||||
}
|
||||
|
||||
// Error marshals the JSON representation of err to w. In case err implements
|
||||
// RenderableError its own Render method will be called instead.
|
||||
func Error(w http.ResponseWriter, err error) {
|
||||
log.Error(w, err)
|
||||
|
||||
var r RenderableError
|
||||
if errors.As(err, &r) {
|
||||
r.Render(w)
|
||||
log.Error(w, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
JSONStatus(w, err, statusCodeFromError(err))
|
||||
if _, err := w.Write(b); err != nil {
|
||||
log.Error(w, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// log.EnabledResponse(w, v)
|
||||
}
|
||||
|
||||
// StatusCodedError is the set of errors that implement the basic StatusCode
|
||||
// function.
|
||||
// Error encodes the JSON representation of err to w.
|
||||
func Error(w http.ResponseWriter, err error) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
acme.WriteError(w, k)
|
||||
return
|
||||
case *admin.Error:
|
||||
admin.WriteError(w, k)
|
||||
return
|
||||
case *scep.Error:
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
default:
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
cause := errors.Cause(err)
|
||||
if sc, ok := err.(errs.StatusCoder); ok {
|
||||
w.WriteHeader(sc.StatusCode())
|
||||
} else {
|
||||
if sc, ok := cause.(errs.StatusCoder); ok {
|
||||
w.WriteHeader(sc.StatusCode())
|
||||
} else {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// Write errors in the response writer
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"error": err,
|
||||
})
|
||||
if os.Getenv("STEPDEBUG") == "1" {
|
||||
if e, ok := err.(errs.StackTracer); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"stack-trace": fmt.Sprintf("%+v", e),
|
||||
})
|
||||
} else if e, ok := cause.(errs.StackTracer); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"stack-trace": fmt.Sprintf("%+v", e),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
||||
log.Error(w, err)
|
||||
}
|
||||
}
|
||||
|
||||
// BadRequest renders the JSON representation of err into w and sets its
|
||||
// status code to http.StatusBadRequest.
|
||||
//
|
||||
// Errors that implement this interface will use the code reported by StatusCode
|
||||
// as the HTTP response code when being rendered by this package.
|
||||
type StatusCodedError interface {
|
||||
error
|
||||
|
||||
StatusCode() int
|
||||
// In case err is nil, a default error message will be used in its place.
|
||||
func BadRequest(w http.ResponseWriter, err error) {
|
||||
codedError(w, http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
func statusCodeFromError(err error) (code int) {
|
||||
code = http.StatusInternalServerError
|
||||
|
||||
type causer interface {
|
||||
Cause() error
|
||||
func codedError(w http.ResponseWriter, code int, err error) {
|
||||
if err == nil {
|
||||
err = errors.New(http.StatusText(code))
|
||||
}
|
||||
|
||||
for err != nil {
|
||||
var sc StatusCodedError
|
||||
if errors.As(err, &sc) {
|
||||
code = sc.StatusCode()
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
var c causer
|
||||
if !errors.As(err, &c) {
|
||||
break
|
||||
}
|
||||
err = c.Cause()
|
||||
var wrapper = struct {
|
||||
Status int `json:"status"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Status: code,
|
||||
Message: err.Error(),
|
||||
}
|
||||
|
||||
return
|
||||
data, err := json.Marshal(wrapper)
|
||||
if err != nil {
|
||||
log.Error(w, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
w.Write(data)
|
||||
}
|
||||
|
|
|
@ -1,10 +1,6 @@
|
|||
package render
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
|
@ -16,98 +12,67 @@ import (
|
|||
)
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := logging.NewResponseLogger(rec)
|
||||
|
||||
JSON(rw, map[string]interface{}{"foo": "bar"})
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Result().StatusCode)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
assert.Equal(t, "{\"foo\":\"bar\"}\n", rec.Body.String())
|
||||
|
||||
assert.Empty(t, rw.Fields())
|
||||
}
|
||||
|
||||
func TestJSONPanicsOnUnsupportedType(t *testing.T) {
|
||||
jsonPanicTest[json.UnsupportedTypeError](t, make(chan struct{}))
|
||||
}
|
||||
|
||||
func TestJSONPanicsOnUnsupportedValue(t *testing.T) {
|
||||
jsonPanicTest[json.UnsupportedValueError](t, math.NaN())
|
||||
}
|
||||
|
||||
func TestJSONPanicsOnMarshalerError(t *testing.T) {
|
||||
var v erroneousJSONMarshaler
|
||||
jsonPanicTest[json.MarshalerError](t, v)
|
||||
}
|
||||
|
||||
type erroneousJSONMarshaler struct{}
|
||||
|
||||
func (erroneousJSONMarshaler) MarshalJSON() ([]byte, error) {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
|
||||
func jsonPanicTest[T json.UnsupportedTypeError | json.UnsupportedValueError | json.MarshalerError](t *testing.T, v any) {
|
||||
t.Helper()
|
||||
|
||||
defer func() {
|
||||
var err error
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic")
|
||||
} else if e, ok := r.(error); !ok {
|
||||
t.Fatalf("did not panic with an error (%T)", r)
|
||||
} else {
|
||||
err = e
|
||||
}
|
||||
|
||||
var e *T
|
||||
assert.ErrorAs(t, err, &e)
|
||||
}()
|
||||
|
||||
JSON(httptest.NewRecorder(), v)
|
||||
}
|
||||
|
||||
type renderableError struct {
|
||||
Code int `json:"-"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (err renderableError) Error() string {
|
||||
return err.Message
|
||||
}
|
||||
|
||||
func (err renderableError) Render(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "something/custom")
|
||||
|
||||
JSONStatus(w, err, err.Code)
|
||||
}
|
||||
|
||||
type statusedError struct {
|
||||
Contents string
|
||||
}
|
||||
|
||||
func (err statusedError) Error() string { return err.Contents }
|
||||
|
||||
func (statusedError) StatusCode() int { return 432 }
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
cases := []struct {
|
||||
err error
|
||||
code int
|
||||
body string
|
||||
header string
|
||||
type args struct {
|
||||
rw http.ResponseWriter
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
ok bool
|
||||
}{
|
||||
{"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true},
|
||||
{"fail", args{httptest.NewRecorder(), make(chan int)}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rw := logging.NewResponseLogger(tt.args.rw)
|
||||
JSON(rw, tt.args.v)
|
||||
|
||||
rr, ok := tt.args.rw.(*httptest.ResponseRecorder)
|
||||
if !ok {
|
||||
t.Error("ResponseWriter does not implement *httptest.ResponseRecorder")
|
||||
return
|
||||
}
|
||||
|
||||
fields := rw.Fields()
|
||||
if tt.ok {
|
||||
if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" {
|
||||
t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body)
|
||||
}
|
||||
if len(fields) != 0 {
|
||||
t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields)
|
||||
}
|
||||
} else {
|
||||
if body := rr.Body.String(); body != "" {
|
||||
t.Errorf("Unexpected body = %s, want empty string", body)
|
||||
}
|
||||
if len(fields) != 1 {
|
||||
t.Errorf("ResponseLogger fields = %v, wants 1 element", fields)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrors(t *testing.T) {
|
||||
cases := []struct {
|
||||
fn func(http.ResponseWriter, error) // helper
|
||||
err error // error being render
|
||||
code int // expected status code
|
||||
body string // expected body
|
||||
}{
|
||||
// --- BadRequest
|
||||
0: {
|
||||
err: renderableError{532, "some string"},
|
||||
code: 532,
|
||||
body: "{\"message\":\"some string\"}\n",
|
||||
header: "something/custom",
|
||||
fn: BadRequest,
|
||||
err: assert.AnError,
|
||||
code: http.StatusBadRequest,
|
||||
body: `{"status":400,"message":"assert.AnError general error for testing"}`,
|
||||
},
|
||||
1: {
|
||||
err: statusedError{"123"},
|
||||
code: 432,
|
||||
body: "{\"Contents\":\"123\"}\n",
|
||||
header: "application/json",
|
||||
fn: BadRequest,
|
||||
code: http.StatusBadRequest,
|
||||
body: `{"status":400,"message":"Bad Request"}`,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -116,35 +81,13 @@ func TestError(t *testing.T) {
|
|||
|
||||
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
kase.fn(rec, kase.err)
|
||||
|
||||
Error(rec, kase.err)
|
||||
ret := rec.Result()
|
||||
|
||||
assert.Equal(t, kase.code, rec.Result().StatusCode)
|
||||
assert.Equal(t, "application/json", ret.Header.Get("Content-Type"))
|
||||
assert.Equal(t, kase.code, ret.StatusCode)
|
||||
assert.Equal(t, kase.body, rec.Body.String())
|
||||
assert.Equal(t, kase.header, rec.Header().Get("Content-Type"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type causedError struct {
|
||||
cause error
|
||||
}
|
||||
|
||||
func (err causedError) Error() string { return fmt.Sprintf("cause: %s", err.cause) }
|
||||
func (err causedError) Cause() error { return err.cause }
|
||||
|
||||
func TestStatusCodeFromError(t *testing.T) {
|
||||
cases := []struct {
|
||||
err error
|
||||
exp int
|
||||
}{
|
||||
0: {nil, http.StatusInternalServerError},
|
||||
1: {io.EOF, http.StatusInternalServerError},
|
||||
2: {statusedError{"123"}, 432},
|
||||
3: {causedError{statusedError{"432"}}, 432},
|
||||
}
|
||||
|
||||
for caseIndex, kase := range cases {
|
||||
assert.Equal(t, kase.exp, statusCodeFromError(kase.err), "case: %d", caseIndex)
|
||||
}
|
||||
}
|
||||
|
|
28
api/renew.go
28
api/renew.go
|
@ -6,7 +6,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
|
@ -17,23 +16,14 @@ const (
|
|||
|
||||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
// new one.
|
||||
func Renew(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get the leaf certificate from the peer or the token.
|
||||
cert, token, err := getPeerCertificate(r)
|
||||
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||
cert, err := h.getPeerCertificate(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// The token can be used by RAs to renew a certificate.
|
||||
if token != "" {
|
||||
ctx = authority.NewTokenContext(ctx, token)
|
||||
}
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
certChain, err := a.RenewContext(ctx, cert, nil)
|
||||
certChain, err := h.Authority.Renew(cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
return
|
||||
|
@ -49,20 +39,18 @@ func Renew(w http.ResponseWriter, r *http.Request) {
|
|||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) {
|
||||
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
return r.TLS.PeerCertificates[0], "", nil
|
||||
return r.TLS.PeerCertificates[0], nil
|
||||
}
|
||||
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||
ctx := r.Context()
|
||||
peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||
return peer, parts[1], err
|
||||
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
|
||||
}
|
||||
}
|
||||
return nil, "", errs.BadRequest("missing client certificate")
|
||||
return nil, errs.BadRequest("missing client certificate")
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
@ -34,11 +34,6 @@ func (r *RevokeRequest) Validate() (err error) {
|
|||
if r.Serial == "" {
|
||||
return errs.BadRequest("missing serial")
|
||||
}
|
||||
sn, ok := new(big.Int).SetString(r.Serial, 0)
|
||||
if !ok {
|
||||
return errs.BadRequest("'%s' is not a valid serial number - use a base 10 representation or a base 16 representation with '0x' prefix", r.Serial)
|
||||
}
|
||||
r.Serial = sn.String()
|
||||
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
|
||||
return errs.BadRequest("reasonCode out of bounds")
|
||||
}
|
||||
|
@ -54,10 +49,9 @@ func (r *RevokeRequest) Validate() (err error) {
|
|||
// NOTE: currently only Passive revocation is supported.
|
||||
//
|
||||
// TODO: Add CRL and OCSP support.
|
||||
func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body RevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if !read.JSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -73,14 +67,12 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
PassiveOnly: body.Passive,
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod)
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
|
||||
// A token indicates that we are using the api via a provisioner token,
|
||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
if len(body.OTT) > 0 {
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
@ -105,7 +97,7 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
opts.MTLS = true
|
||||
}
|
||||
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -31,13 +31,9 @@ func TestRevokeRequestValidate(t *testing.T) {
|
|||
rr: &RevokeRequest{},
|
||||
err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest},
|
||||
},
|
||||
"error/bad sn": {
|
||||
rr: &RevokeRequest{Serial: "sn"},
|
||||
err: &errs.Error{Err: errors.New("'sn' is not a valid serial number - use a base 10 representation or a base 16 representation with '0x' prefix"), Status: http.StatusBadRequest},
|
||||
},
|
||||
"error/bad reasonCode": {
|
||||
rr: &RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 15,
|
||||
Passive: true,
|
||||
},
|
||||
|
@ -45,7 +41,7 @@ func TestRevokeRequestValidate(t *testing.T) {
|
|||
},
|
||||
"error/non-passive not implemented": {
|
||||
rr: &RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 8,
|
||||
Passive: false,
|
||||
},
|
||||
|
@ -53,7 +49,7 @@ func TestRevokeRequestValidate(t *testing.T) {
|
|||
},
|
||||
"ok": {
|
||||
rr: &RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 9,
|
||||
Passive: true,
|
||||
},
|
||||
|
@ -62,12 +58,12 @@ func TestRevokeRequestValidate(t *testing.T) {
|
|||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.rr.Validate(); err != nil {
|
||||
var ee *errs.Error
|
||||
if errors.As(err, &ee) {
|
||||
assert.HasPrefix(t, ee.Error(), tc.err.Error())
|
||||
assert.Equals(t, ee.StatusCode(), tc.err.Status)
|
||||
} else {
|
||||
t.Errorf("unexpected error type: %T", err)
|
||||
switch v := err.(type) {
|
||||
case *errs.Error:
|
||||
assert.HasPrefix(t, v.Error(), tc.err.Error())
|
||||
assert.Equals(t, v.StatusCode(), tc.err.Status)
|
||||
default:
|
||||
t.Errorf("unexpected error type: %T", v)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
|
@ -101,7 +97,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
},
|
||||
"200/ott": func(t *testing.T) test {
|
||||
input, err := json.Marshal(RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 4,
|
||||
Reason: "foo",
|
||||
OTT: "valid",
|
||||
|
@ -112,13 +108,13 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusOK,
|
||||
auth: &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
assert.True(t, opts.PassiveOnly)
|
||||
assert.False(t, opts.MTLS)
|
||||
assert.Equals(t, opts.Serial, "10")
|
||||
assert.Equals(t, opts.Serial, "sn")
|
||||
assert.Equals(t, opts.ReasonCode, 4)
|
||||
assert.Equals(t, opts.Reason, "foo")
|
||||
return nil
|
||||
|
@ -129,7 +125,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
},
|
||||
"400/no OTT and no peer certificate": func(t *testing.T) test {
|
||||
input, err := json.Marshal(RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 4,
|
||||
Passive: true,
|
||||
})
|
||||
|
@ -156,7 +152,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
statusCode: http.StatusOK,
|
||||
tls: cs,
|
||||
auth: &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, ri *authority.RevokeOptions) error {
|
||||
|
@ -180,7 +176,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
},
|
||||
"500/ott authority.Revoke": func(t *testing.T) test {
|
||||
input, err := json.Marshal(RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 4,
|
||||
Reason: "foo",
|
||||
OTT: "valid",
|
||||
|
@ -191,7 +187,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusInternalServerError,
|
||||
auth: &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
|
@ -202,7 +198,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
},
|
||||
"403/ott authority.Revoke": func(t *testing.T) test {
|
||||
input, err := json.Marshal(RevokeRequest{
|
||||
Serial: "10",
|
||||
Serial: "sn",
|
||||
ReasonCode: 4,
|
||||
Reason: "foo",
|
||||
OTT: "valid",
|
||||
|
@ -213,7 +209,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusForbidden,
|
||||
auth: &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
|
@ -227,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
for name, _tc := range tests {
|
||||
tc := _tc(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := New(tc.auth).(*caHandler)
|
||||
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
|
||||
if tc.tls != nil {
|
||||
req.TLS = tc.tls
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
Revoke(logging.NewResponseLogger(w), req)
|
||||
h.Revoke(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
16
api/sign.go
16
api/sign.go
|
@ -49,10 +49,9 @@ type SignResponse struct {
|
|||
// Sign is an HTTP handler that reads a certificate request and an
|
||||
// one-time-token (ott) from the body and creates a new certificate with the
|
||||
// information in the certificate request.
|
||||
func Sign(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if !read.JSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -68,17 +67,13 @@ func Sign(w http.ResponseWriter, r *http.Request) {
|
|||
TemplateData: body.TemplateData,
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
|
||||
return
|
||||
|
@ -88,12 +83,11 @@ func Sign(w http.ResponseWriter, r *http.Request) {
|
|||
if len(certChainPEM) > 1 {
|
||||
caPEM = certChainPEM[1]
|
||||
}
|
||||
|
||||
LogCertificate(w, certChain[0])
|
||||
render.JSONStatus(w, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
|
58
api/ssh.go
58
api/ssh.go
|
@ -250,10 +250,9 @@ type SSHBastionResponse struct {
|
|||
// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHSignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if !read.JSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -288,16 +287,13 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
return
|
||||
|
@ -305,7 +301,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var addUserCertificate *SSHCertificate
|
||||
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
||||
addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
return
|
||||
|
@ -318,7 +314,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
if cr := body.IdentityCSR.CertificateRequest; cr != nil {
|
||||
ctx := authority.NewContextWithSkipTokenReuse(r.Context())
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
|
@ -330,7 +326,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
|
||||
})
|
||||
|
||||
certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
||||
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||
return
|
||||
|
@ -338,7 +334,6 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
identityCertificate = certChainToPEM(certChain)
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, cert)
|
||||
render.JSONStatus(w, &SSHSignResponse{
|
||||
Certificate: SSHCertificate{cert},
|
||||
AddUserCertificate: addUserCertificate,
|
||||
|
@ -348,9 +343,8 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
|
||||
// certificates.
|
||||
func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
|
||||
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHRoots(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -374,9 +368,8 @@ func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
||||
// for user and host certificates.
|
||||
func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
|
||||
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHFederation(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -400,10 +393,9 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
|
||||
// and servers.
|
||||
func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHConfigRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if !read.JSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
|
@ -411,8 +403,7 @@ func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
|
||||
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -433,10 +424,9 @@ func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||
func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHCheckPrincipalRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if !read.JSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
|
@ -444,8 +434,7 @@ func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
|
||||
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -456,14 +445,13 @@ func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
||||
func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
var cert *x509.Certificate
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
cert = r.TLS.PeerCertificates[0]
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
|
||||
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -474,10 +462,9 @@ func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// SSHBastion provides returns the bastion configured if any.
|
||||
func SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHBastionRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if !read.JSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
|
@ -485,8 +472,7 @@ func SSHBastion(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
|
||||
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
|
|
@ -39,10 +39,9 @@ type SSHRekeyResponse struct {
|
|||
// SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRekeyRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if !read.JSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -59,10 +58,7 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
|
@ -73,7 +69,7 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||
return
|
||||
|
@ -83,13 +79,12 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, newCert)
|
||||
render.JSONStatus(w, &SSHRekeyResponse{
|
||||
Certificate: SSHCertificate{newCert},
|
||||
IdentityCertificate: identity,
|
||||
|
|
|
@ -37,10 +37,9 @@ type SSHRenewResponse struct {
|
|||
// SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRenewRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if read.JSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -51,10 +50,7 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
_, err := a.Authorize(ctx, body.OTT)
|
||||
_, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
|
@ -65,7 +61,7 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
newCert, err := a.RenewSSH(ctx, oldCert)
|
||||
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||
return
|
||||
|
@ -75,13 +71,12 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, newCert)
|
||||
render.JSONStatus(w, &SSHSignResponse{
|
||||
Certificate: SSHCertificate{newCert},
|
||||
IdentityCertificate: identity,
|
||||
|
@ -89,7 +84,7 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
|
||||
func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
||||
func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -109,7 +104,7 @@ func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([
|
|||
cert.NotAfter = notAfter
|
||||
}
|
||||
|
||||
certChain, err := mustAuthority(r.Context()).Renew(cert)
|
||||
certChain, err := h.Authority.Renew(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -48,10 +48,9 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
|||
// Revoke supports handful of different methods that revoke a Certificate.
|
||||
//
|
||||
// NOTE: currently only Passive revocation is supported.
|
||||
func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
if !read.JSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -68,19 +67,16 @@ func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
// A token indicates that we are using the api via a provisioner token,
|
||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
logOtt(w, body.OTT)
|
||||
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHSign(t *testing.T) {
|
||||
func Test_caHandler_SSHSign(t *testing.T) {
|
||||
user, err := getSignedUserCertificate()
|
||||
assert.FatalError(t, err)
|
||||
host, err := getSignedHostCertificate()
|
||||
|
@ -315,8 +315,8 @@ func Test_SSHSign(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
h := New(&mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return []provisioner.SignOption{}, tt.authErr
|
||||
},
|
||||
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
|
@ -328,11 +328,11 @@ func Test_SSHSign(t *testing.T) {
|
|||
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
return tt.tlsSignCerts, tt.tlsSignErr
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
SSHSign(logging.NewResponseLogger(w), req)
|
||||
h.SSHSign(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -353,7 +353,7 @@ func Test_SSHSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHRoots(t *testing.T) {
|
||||
func Test_caHandler_SSHRoots(t *testing.T) {
|
||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||
assert.FatalError(t, err)
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
|
@ -378,15 +378,15 @@ func Test_SSHRoots(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
h := New(&mockAuthority{
|
||||
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
SSHRoots(logging.NewResponseLogger(w), req)
|
||||
h.SSHRoots(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -407,7 +407,7 @@ func Test_SSHRoots(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHFederation(t *testing.T) {
|
||||
func Test_caHandler_SSHFederation(t *testing.T) {
|
||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||
assert.FatalError(t, err)
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
|
@ -432,15 +432,15 @@ func Test_SSHFederation(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
h := New(&mockAuthority{
|
||||
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
SSHFederation(logging.NewResponseLogger(w), req)
|
||||
h.SSHFederation(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -461,7 +461,7 @@ func Test_SSHFederation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHConfig(t *testing.T) {
|
||||
func Test_caHandler_SSHConfig(t *testing.T) {
|
||||
userOutput := []templates.Output{
|
||||
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
|
||||
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
|
||||
|
@ -492,15 +492,15 @@ func Test_SSHConfig(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
h := New(&mockAuthority{
|
||||
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||
return tt.output, tt.err
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
SSHConfig(logging.NewResponseLogger(w), req)
|
||||
h.SSHConfig(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -521,7 +521,7 @@ func Test_SSHConfig(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHCheckHost(t *testing.T) {
|
||||
func Test_caHandler_SSHCheckHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req string
|
||||
|
@ -539,15 +539,15 @@ func Test_SSHCheckHost(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
h := New(&mockAuthority{
|
||||
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
|
||||
return tt.exists, tt.err
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||
h.SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -568,7 +568,7 @@ func Test_SSHCheckHost(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHGetHosts(t *testing.T) {
|
||||
func Test_caHandler_SSHGetHosts(t *testing.T) {
|
||||
hosts := []authority.Host{
|
||||
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
|
||||
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
|
||||
|
@ -590,15 +590,15 @@ func Test_SSHGetHosts(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
h := New(&mockAuthority{
|
||||
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
|
||||
return tt.hosts, tt.err
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
SSHGetHosts(logging.NewResponseLogger(w), req)
|
||||
h.SSHGetHosts(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -619,7 +619,7 @@ func Test_SSHGetHosts(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_SSHBastion(t *testing.T) {
|
||||
func Test_caHandler_SSHBastion(t *testing.T) {
|
||||
bastion := &authority.Bastion{
|
||||
Hostname: "bastion.local",
|
||||
}
|
||||
|
@ -645,15 +645,15 @@ func Test_SSHBastion(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
h := New(&mockAuthority{
|
||||
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||
return tt.bastion, tt.bastionErr
|
||||
},
|
||||
})
|
||||
}).(*caHandler)
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
SSHBastion(logging.NewResponseLogger(w), req)
|
||||
h.SSHBastion(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
|
|
@ -1,15 +1,22 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
const (
|
||||
// provisionerContextKey provisioner key
|
||||
provisionerContextKey = ContextKey("provisioner")
|
||||
)
|
||||
|
||||
// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests
|
||||
|
@ -33,120 +40,78 @@ type GetExternalAccountKeysResponse struct {
|
|||
|
||||
// requireEABEnabled is a middleware that ensures ACME EAB is enabled
|
||||
// before serving requests that act on ACME EAB credentials.
|
||||
func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||
func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
|
||||
acmeProvisioner := prov.GetDetails().GetACME()
|
||||
if acmeProvisioner == nil {
|
||||
render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName()))
|
||||
provName := chi.URLParam(r, "provisionerName")
|
||||
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !acmeProvisioner.RequireEab {
|
||||
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName()))
|
||||
if !eabEnabled {
|
||||
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
next(w, r)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// ACMEAdminResponder is responsible for writing ACME admin responses
|
||||
type ACMEAdminResponder interface {
|
||||
// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME
|
||||
// provisioner is set to true and thus has EAB enabled.
|
||||
func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) {
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil {
|
||||
return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName)
|
||||
}
|
||||
|
||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID())
|
||||
}
|
||||
|
||||
details := prov.GetDetails()
|
||||
if details == nil {
|
||||
return false, nil, admin.NewErrorISE("error getting details for provisioner with ID: %s", p.GetID())
|
||||
}
|
||||
|
||||
acmeProvisioner := details.GetACME()
|
||||
if acmeProvisioner == nil {
|
||||
return false, nil, admin.NewErrorISE("error getting ACME details for provisioner with ID: %s", p.GetID())
|
||||
}
|
||||
|
||||
return acmeProvisioner.GetRequireEab(), prov, nil
|
||||
}
|
||||
|
||||
type acmeAdminResponderInterface interface {
|
||||
GetExternalAccountKeys(w http.ResponseWriter, r *http.Request)
|
||||
CreateExternalAccountKey(w http.ResponseWriter, r *http.Request)
|
||||
DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// acmeAdminResponder implements ACMEAdminResponder.
|
||||
type acmeAdminResponder struct{}
|
||||
// ACMEAdminResponder is responsible for writing ACME admin responses
|
||||
type ACMEAdminResponder struct{}
|
||||
|
||||
// NewACMEAdminResponder returns a new ACMEAdminResponder
|
||||
func NewACMEAdminResponder() ACMEAdminResponder {
|
||||
return &acmeAdminResponder{}
|
||||
func NewACMEAdminResponder() *ACMEAdminResponder {
|
||||
return &ACMEAdminResponder{}
|
||||
}
|
||||
|
||||
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
|
||||
func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, _ *http.Request) {
|
||||
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
|
||||
func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, _ *http.Request) {
|
||||
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
|
||||
func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, _ *http.Request) {
|
||||
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
eak := &linkedca.EABKey{
|
||||
Id: k.ID,
|
||||
HmacKey: k.HmacKey,
|
||||
Provisioner: k.ProvisionerID,
|
||||
Reference: k.Reference,
|
||||
Account: k.AccountID,
|
||||
CreatedAt: timestamppb.New(k.CreatedAt),
|
||||
BoundAt: timestamppb.New(k.BoundAt),
|
||||
}
|
||||
|
||||
if k.Policy != nil {
|
||||
eak.Policy = &linkedca.Policy{
|
||||
X509: &linkedca.X509Policy{
|
||||
Allow: &linkedca.X509Names{},
|
||||
Deny: &linkedca.X509Names{},
|
||||
},
|
||||
}
|
||||
eak.Policy.X509.Allow.Dns = k.Policy.X509.Allowed.DNSNames
|
||||
eak.Policy.X509.Allow.Ips = k.Policy.X509.Allowed.IPRanges
|
||||
eak.Policy.X509.Deny.Dns = k.Policy.X509.Denied.DNSNames
|
||||
eak.Policy.X509.Deny.Ips = k.Policy.X509.Denied.IPRanges
|
||||
eak.Policy.X509.AllowWildcardNames = k.Policy.X509.AllowWildcardNames
|
||||
}
|
||||
|
||||
return eak
|
||||
}
|
||||
|
||||
func linkedEAKToCertificates(k *linkedca.EABKey) *acme.ExternalAccountKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
eak := &acme.ExternalAccountKey{
|
||||
ID: k.Id,
|
||||
ProvisionerID: k.Provisioner,
|
||||
Reference: k.Reference,
|
||||
AccountID: k.Account,
|
||||
HmacKey: k.HmacKey,
|
||||
CreatedAt: k.CreatedAt.AsTime(),
|
||||
BoundAt: k.BoundAt.AsTime(),
|
||||
}
|
||||
|
||||
if policy := k.GetPolicy(); policy != nil {
|
||||
eak.Policy = &acme.Policy{}
|
||||
if x509 := policy.GetX509(); x509 != nil {
|
||||
eak.Policy.X509 = acme.X509Policy{}
|
||||
if allow := x509.GetAllow(); allow != nil {
|
||||
eak.Policy.X509.Allowed = acme.PolicyNames{}
|
||||
eak.Policy.X509.Allowed.DNSNames = allow.Dns
|
||||
eak.Policy.X509.Allowed.IPRanges = allow.Ips
|
||||
}
|
||||
if deny := x509.GetDeny(); deny != nil {
|
||||
eak.Policy.X509.Denied = acme.PolicyNames{}
|
||||
eak.Policy.X509.Denied.DNSNames = deny.Dns
|
||||
eak.Policy.X509.Denied.IPRanges = deny.Ips
|
||||
}
|
||||
eak.Policy.X509.AllowWildcardNames = x509.AllowWildcardNames
|
||||
}
|
||||
}
|
||||
|
||||
return eak
|
||||
}
|
||||
|
|
|
@ -4,24 +4,20 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
)
|
||||
|
||||
func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
||||
|
@ -33,90 +29,109 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
|||
return protojson.Unmarshal(data, m)
|
||||
}
|
||||
|
||||
func mockMustAuthority(t *testing.T, a adminAuthority) {
|
||||
t.Helper()
|
||||
fn := mustAuthority
|
||||
t.Cleanup(func() {
|
||||
mustAuthority = fn
|
||||
})
|
||||
mustAuthority = func(ctx context.Context) adminAuthority {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_requireEABEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
next http.HandlerFunc
|
||||
adminDB admin.DB
|
||||
auth adminAuthority
|
||||
next nextHTTP
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/prov.GetDetails": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
"fail/h.provisionerHasEABEnabled": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
|
||||
err.Message = "error getting ACME details for provisioner 'provName'"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
err: err,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"fail/prov.GetDetails.GetACME": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
|
||||
err.Message = "error getting ACME details for provisioner 'provName'"
|
||||
err := admin.NewErrorISE("error loading provisioner provName: force")
|
||||
err.Message = "error loading provisioner provName: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
err: err,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"ok/eab-disabled": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: false,
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName")
|
||||
err.Message = "ACME EAB not enabled for provisioner 'provName'"
|
||||
err.Message = "ACME EAB not enabled for provisioner provName"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
err: err,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"ok/eab-enabled": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: true,
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
|
@ -128,9 +143,16 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
acmeDB: nil,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
requireEABEnabled(tc.next)(w, req)
|
||||
h.requireEABEnabled(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -154,6 +176,216 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHandler_provisionerHasEABEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
adminDB admin.DB
|
||||
auth adminAuthority
|
||||
provisionerName string
|
||||
want bool
|
||||
err *admin.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetProvisioner": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"fail/prov.GetDetails": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: nil,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"fail/details.GetACME": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: nil,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"ok/eab-disabled": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "eab-disabled", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "eab-disabled",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
adminDB: db,
|
||||
auth: auth,
|
||||
provisionerName: "eab-disabled",
|
||||
want: false,
|
||||
}
|
||||
},
|
||||
"ok/eab-enabled": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "eab-enabled", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "eab-enabled",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
adminDB: db,
|
||||
auth: auth,
|
||||
provisionerName: "eab-enabled",
|
||||
want: true,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
acmeDB: nil,
|
||||
}
|
||||
got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName)
|
||||
if (err != nil) != (tc.err != nil) {
|
||||
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
|
||||
return
|
||||
}
|
||||
if tc.err != nil {
|
||||
assert.Type(t, &linkedca.Provisioner{}, prov)
|
||||
assert.Type(t, &admin.Error{}, err)
|
||||
adminError, _ := err.(*admin.Error)
|
||||
assert.Equals(t, tc.err.Type, adminError.Type)
|
||||
assert.Equals(t, tc.err.Status, adminError.Status)
|
||||
assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode())
|
||||
assert.Equals(t, tc.err.Message, adminError.Message)
|
||||
assert.Equals(t, tc.err.Detail, adminError.Detail)
|
||||
return
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
Reference string
|
||||
|
@ -353,206 +585,3 @@ func TestHandler_GetExternalAccountKeys(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_eakToLinked(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
k *acme.ExternalAccountKey
|
||||
want *linkedca.EABKey
|
||||
}{
|
||||
{
|
||||
name: "no-key",
|
||||
k: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no-policy",
|
||||
k: &acme.ExternalAccountKey{
|
||||
ID: "keyID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
AccountID: "accID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
|
||||
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
|
||||
Policy: nil,
|
||||
},
|
||||
want: &linkedca.EABKey{
|
||||
Id: "keyID",
|
||||
Provisioner: "provID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
Reference: "ref",
|
||||
Account: "accID",
|
||||
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
|
||||
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
|
||||
Policy: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with-policy",
|
||||
k: &acme.ExternalAccountKey{
|
||||
ID: "keyID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
AccountID: "accID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
|
||||
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
|
||||
Policy: &acme.Policy{
|
||||
X509: acme.X509Policy{
|
||||
Allowed: acme.PolicyNames{
|
||||
DNSNames: []string{"*.local"},
|
||||
IPRanges: []string{"10.0.0.0/24"},
|
||||
},
|
||||
Denied: acme.PolicyNames{
|
||||
DNSNames: []string{"badhost.local"},
|
||||
IPRanges: []string{"10.0.0.30"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &linkedca.EABKey{
|
||||
Id: "keyID",
|
||||
Provisioner: "provID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
Reference: "ref",
|
||||
Account: "accID",
|
||||
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
|
||||
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
|
||||
Policy: &linkedca.Policy{
|
||||
X509: &linkedca.X509Policy{
|
||||
Allow: &linkedca.X509Names{
|
||||
Dns: []string{"*.local"},
|
||||
Ips: []string{"10.0.0.0/24"},
|
||||
},
|
||||
Deny: &linkedca.X509Names{
|
||||
Dns: []string{"badhost.local"},
|
||||
Ips: []string{"10.0.0.30"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := eakToLinked(tt.k); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("eakToLinked() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_linkedEAKToCertificates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
k *linkedca.EABKey
|
||||
want *acme.ExternalAccountKey
|
||||
}{
|
||||
{
|
||||
name: "no-key",
|
||||
k: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no-policy",
|
||||
k: &linkedca.EABKey{
|
||||
Id: "keyID",
|
||||
Provisioner: "provID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
Reference: "ref",
|
||||
Account: "accID",
|
||||
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
|
||||
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
|
||||
Policy: nil,
|
||||
},
|
||||
want: &acme.ExternalAccountKey{
|
||||
ID: "keyID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
AccountID: "accID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
|
||||
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
|
||||
Policy: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no-x509-policy",
|
||||
k: &linkedca.EABKey{
|
||||
Id: "keyID",
|
||||
Provisioner: "provID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
Reference: "ref",
|
||||
Account: "accID",
|
||||
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
|
||||
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
|
||||
Policy: &linkedca.Policy{},
|
||||
},
|
||||
want: &acme.ExternalAccountKey{
|
||||
ID: "keyID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
AccountID: "accID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
|
||||
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
|
||||
Policy: &acme.Policy{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with-x509-policy",
|
||||
k: &linkedca.EABKey{
|
||||
Id: "keyID",
|
||||
Provisioner: "provID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
Reference: "ref",
|
||||
Account: "accID",
|
||||
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
|
||||
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
|
||||
Policy: &linkedca.Policy{
|
||||
X509: &linkedca.X509Policy{
|
||||
Allow: &linkedca.X509Names{
|
||||
Dns: []string{"*.local"},
|
||||
Ips: []string{"10.0.0.0/24"},
|
||||
},
|
||||
Deny: &linkedca.X509Names{
|
||||
Dns: []string{"badhost.local"},
|
||||
Ips: []string{"10.0.0.30"},
|
||||
},
|
||||
AllowWildcardNames: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &acme.ExternalAccountKey{
|
||||
ID: "keyID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
AccountID: "accID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
|
||||
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
|
||||
Policy: &acme.Policy{
|
||||
X509: acme.X509Policy{
|
||||
Allowed: acme.PolicyNames{
|
||||
DNSNames: []string{"*.local"},
|
||||
IPRanges: []string{"10.0.0.0/24"},
|
||||
},
|
||||
Denied: acme.PolicyNames{
|
||||
DNSNames: []string{"badhost.local"},
|
||||
IPRanges: []string{"10.0.0.30"},
|
||||
},
|
||||
AllowWildcardNames: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := linkedEAKToCertificates(tt.k); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("linkedEAKToCertificates() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,10 +29,6 @@ type adminAuthority interface {
|
|||
LoadProvisionerByID(id string) (provisioner.Interface, error)
|
||||
UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error
|
||||
RemoveProvisioner(ctx context.Context, id string) error
|
||||
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
|
||||
CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
RemoveAuthorityPolicy(ctx context.Context) error
|
||||
}
|
||||
|
||||
// CreateAdminRequest represents the body for a CreateAdmin request.
|
||||
|
@ -85,10 +81,10 @@ type DeleteResponse struct {
|
|||
}
|
||||
|
||||
// GetAdmin returns the requested admin, or an error.
|
||||
func GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
adm, ok := mustAuthority(r.Context()).LoadAdminByID(id)
|
||||
adm, ok := h.auth.LoadAdminByID(id)
|
||||
if !ok {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType,
|
||||
"admin %s not found", id))
|
||||
|
@ -98,7 +94,7 @@ func GetAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// GetAdmins returns a segment of admins associated with the authority.
|
||||
func GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
|
@ -106,7 +102,7 @@ func GetAdmins(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit)
|
||||
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
||||
return
|
||||
|
@ -118,10 +114,9 @@ func GetAdmins(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// CreateAdmin creates a new admin.
|
||||
func CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body CreateAdminRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
if !read.AdminJSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -130,8 +125,7 @@ func CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
auth := mustAuthority(r.Context())
|
||||
p, err := auth.LoadProvisionerByName(body.Provisioner)
|
||||
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
||||
return
|
||||
|
@ -142,7 +136,7 @@ func CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
Type: body.Type,
|
||||
}
|
||||
// Store to authority collection.
|
||||
if err := auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
|
||||
return
|
||||
}
|
||||
|
@ -151,10 +145,10 @@ func CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// DeleteAdmin deletes admin.
|
||||
func DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil {
|
||||
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
||||
return
|
||||
}
|
||||
|
@ -163,10 +157,9 @@ func DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// UpdateAdmin updates an existing admin.
|
||||
func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body UpdateAdminRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
if !read.AdminJSON(w, r, &body) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -176,8 +169,8 @@ func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
id := chi.URLParam(r, "id")
|
||||
auth := mustAuthority(r.Context())
|
||||
adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||
|
||||
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
|
||||
return
|
||||
|
|
|
@ -14,13 +14,11 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type mockAdminAuthority struct {
|
||||
|
@ -39,11 +37,6 @@ type mockAdminAuthority struct {
|
|||
MockLoadProvisionerByID func(id string) (provisioner.Interface, error)
|
||||
MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error
|
||||
MockRemoveProvisioner func(ctx context.Context, id string) error
|
||||
|
||||
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
|
||||
MockCreateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
MockUpdateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
MockRemoveAuthorityPolicy func(ctx context.Context) error
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) IsAdminAPIEnabled() bool {
|
||||
|
@ -137,34 +130,6 @@ func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) e
|
|||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
if m.MockGetAuthorityPolicy != nil {
|
||||
return m.MockGetAuthorityPolicy(ctx)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Policy), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
|
||||
if m.MockCreateAuthorityPolicy != nil {
|
||||
return m.MockCreateAuthorityPolicy(ctx, adm, policy)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Policy), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
|
||||
if m.MockUpdateAuthorityPolicy != nil {
|
||||
return m.MockUpdateAuthorityPolicy(ctx, adm, policy)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Policy), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error {
|
||||
if m.MockRemoveAuthorityPolicy != nil {
|
||||
return m.MockRemoveAuthorityPolicy(ctx)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func TestCreateAdminRequest_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
Subject string
|
||||
|
@ -229,13 +194,11 @@ func TestCreateAdminRequest_Validate(t *testing.T) {
|
|||
|
||||
if err != nil {
|
||||
assert.Type(t, &admin.Error{}, err)
|
||||
var adminErr *admin.Error
|
||||
if assert.True(t, errors.As(err, &adminErr)) {
|
||||
assert.Equals(t, tt.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tt.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, tt.err.Status, adminErr.Status)
|
||||
assert.Equals(t, tt.err.Message, adminErr.Message)
|
||||
}
|
||||
adminErr, _ := err.(*admin.Error)
|
||||
assert.Equals(t, tt.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tt.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, tt.err.Status, adminErr.Status)
|
||||
assert.Equals(t, tt.err.Message, adminErr.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -280,13 +243,11 @@ func TestUpdateAdminRequest_Validate(t *testing.T) {
|
|||
|
||||
if err != nil {
|
||||
assert.Type(t, &admin.Error{}, err)
|
||||
var ae *admin.Error
|
||||
if assert.True(t, errors.As(err, &ae)) {
|
||||
assert.Equals(t, tt.err.Type, ae.Type)
|
||||
assert.Equals(t, tt.err.Detail, ae.Detail)
|
||||
assert.Equals(t, tt.err.Status, ae.Status)
|
||||
assert.Equals(t, tt.err.Message, ae.Message)
|
||||
}
|
||||
adminErr, _ := err.(*admin.Error)
|
||||
assert.Equals(t, tt.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tt.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, tt.err.Status, adminErr.Status)
|
||||
assert.Equals(t, tt.err.Message, adminErr.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -356,11 +317,14 @@ func TestHandler_GetAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
GetAdmin(w, req)
|
||||
h.GetAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -492,10 +456,13 @@ func TestHandler_GetAdmins(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
GetAdmins(w, req)
|
||||
h.GetAdmins(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -673,11 +640,13 @@ func TestHandler_CreateAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
CreateAdmin(w, req)
|
||||
h.CreateAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -763,11 +732,13 @@ func TestHandler_DeleteAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
DeleteAdmin(w, req)
|
||||
h.DeleteAdmin(w, req)
|
||||
res := w.Result()
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
|
@ -906,11 +877,13 @@ func TestHandler_UpdateAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
UpdateAdmin(w, req)
|
||||
h.UpdateAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
|
@ -1,133 +1,56 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
)
|
||||
|
||||
var mustAuthority = func(ctx context.Context) adminAuthority {
|
||||
return authority.MustFromContext(ctx)
|
||||
// Handler is the Admin API request handler.
|
||||
type Handler struct {
|
||||
adminDB admin.DB
|
||||
auth adminAuthority
|
||||
acmeDB acme.DB
|
||||
acmeResponder acmeAdminResponderInterface
|
||||
}
|
||||
|
||||
type router struct {
|
||||
acmeResponder ACMEAdminResponder
|
||||
policyResponder PolicyAdminResponder
|
||||
webhookResponder WebhookAdminResponder
|
||||
}
|
||||
|
||||
type RouterOption func(*router)
|
||||
|
||||
func WithACMEResponder(acmeResponder ACMEAdminResponder) RouterOption {
|
||||
return func(r *router) {
|
||||
r.acmeResponder = acmeResponder
|
||||
}
|
||||
}
|
||||
|
||||
func WithPolicyResponder(policyResponder PolicyAdminResponder) RouterOption {
|
||||
return func(r *router) {
|
||||
r.policyResponder = policyResponder
|
||||
}
|
||||
}
|
||||
|
||||
func WithWebhookResponder(webhookResponder WebhookAdminResponder) RouterOption {
|
||||
return func(r *router) {
|
||||
r.webhookResponder = webhookResponder
|
||||
// NewHandler returns a new Authority Config Handler.
|
||||
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler {
|
||||
return &Handler{
|
||||
auth: auth,
|
||||
adminDB: adminDB,
|
||||
acmeDB: acmeDB,
|
||||
acmeResponder: acmeResponder,
|
||||
}
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func Route(r api.Router, options ...RouterOption) {
|
||||
router := &router{}
|
||||
for _, fn := range options {
|
||||
fn(router)
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
authnz := func(next nextHTTP) nextHTTP {
|
||||
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next))
|
||||
}
|
||||
|
||||
authnz := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return extractAuthorizeTokenAdmin(requireAPIEnabled(next))
|
||||
}
|
||||
|
||||
enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return checkAction(next, true)
|
||||
}
|
||||
|
||||
disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return checkAction(next, false)
|
||||
}
|
||||
|
||||
acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(loadProvisionerByName(requireEABEnabled(next)))
|
||||
}
|
||||
|
||||
authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(enabledInStandalone(next))
|
||||
}
|
||||
|
||||
provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(disabledInStandalone(loadProvisionerByName(next)))
|
||||
}
|
||||
|
||||
acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next)))))
|
||||
}
|
||||
|
||||
webhookMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(loadProvisionerByName(next))
|
||||
requireEABEnabled := func(next nextHTTP) nextHTTP {
|
||||
return h.requireEABEnabled(next)
|
||||
}
|
||||
|
||||
// Provisioners
|
||||
r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners))
|
||||
r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner))
|
||||
r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner))
|
||||
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners))
|
||||
r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner))
|
||||
r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner))
|
||||
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner))
|
||||
|
||||
// Admins
|
||||
r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin))
|
||||
r.MethodFunc("GET", "/admins", authnz(GetAdmins))
|
||||
r.MethodFunc("POST", "/admins", authnz(CreateAdmin))
|
||||
r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin))
|
||||
r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin))
|
||||
r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin))
|
||||
r.MethodFunc("GET", "/admins", authnz(h.GetAdmins))
|
||||
r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin))
|
||||
r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin))
|
||||
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin))
|
||||
|
||||
// ACME responder
|
||||
if router.acmeResponder != nil {
|
||||
// ACME External Account Binding Keys
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(router.acmeResponder.GetExternalAccountKeys))
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(router.acmeResponder.GetExternalAccountKeys))
|
||||
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(router.acmeResponder.CreateExternalAccountKey))
|
||||
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(router.acmeResponder.DeleteExternalAccountKey))
|
||||
}
|
||||
|
||||
// Policy responder
|
||||
if router.policyResponder != nil {
|
||||
// Policy - Authority
|
||||
r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(router.policyResponder.GetAuthorityPolicy))
|
||||
r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(router.policyResponder.CreateAuthorityPolicy))
|
||||
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(router.policyResponder.UpdateAuthorityPolicy))
|
||||
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(router.policyResponder.DeleteAuthorityPolicy))
|
||||
|
||||
// Policy - Provisioner
|
||||
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(router.policyResponder.GetProvisionerPolicy))
|
||||
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(router.policyResponder.CreateProvisionerPolicy))
|
||||
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(router.policyResponder.UpdateProvisionerPolicy))
|
||||
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(router.policyResponder.DeleteProvisionerPolicy))
|
||||
|
||||
// Policy - ACME Account
|
||||
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(router.policyResponder.GetACMEAccountPolicy))
|
||||
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(router.policyResponder.GetACMEAccountPolicy))
|
||||
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(router.policyResponder.CreateACMEAccountPolicy))
|
||||
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(router.policyResponder.CreateACMEAccountPolicy))
|
||||
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(router.policyResponder.UpdateACMEAccountPolicy))
|
||||
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(router.policyResponder.UpdateACMEAccountPolicy))
|
||||
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(router.policyResponder.DeleteACMEAccountPolicy))
|
||||
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(router.policyResponder.DeleteACMEAccountPolicy))
|
||||
}
|
||||
|
||||
if router.webhookResponder != nil {
|
||||
r.MethodFunc("POST", "/provisioners/{provisionerName}/webhooks", webhookMiddleware(router.webhookResponder.CreateProvisionerWebhook))
|
||||
r.MethodFunc("PUT", "/provisioners/{provisionerName}/webhooks/{webhookName}", webhookMiddleware(router.webhookResponder.UpdateProvisionerWebhook))
|
||||
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/webhooks/{webhookName}", webhookMiddleware(router.webhookResponder.DeleteProvisionerWebhook))
|
||||
}
|
||||
// ACME External Account Binding Keys
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys)))
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys)))
|
||||
r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey)))
|
||||
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey)))
|
||||
}
|
||||
|
|
|
@ -1,26 +1,22 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||
|
||||
// requireAPIEnabled is a middleware that ensures the Administration API
|
||||
// is enabled before servicing requests.
|
||||
func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !mustAuthority(r.Context()).IsAdminAPIEnabled() {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
|
||||
if !h.auth.IsAdminAPIEnabled() {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||
"administration API not enabled"))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
|
@ -28,7 +24,7 @@ func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
|
|||
}
|
||||
|
||||
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
|
||||
func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
|
||||
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
tok := r.Header.Get("Authorization")
|
||||
if tok == "" {
|
||||
|
@ -37,111 +33,22 @@ func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok)
|
||||
adm, err := h.auth.AuthorizeAdminToken(r, tok)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx = linkedca.NewContextWithAdmin(ctx, adm)
|
||||
ctx := context.WithValue(r.Context(), adminContextKey, adm)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// loadProvisionerByName is a middleware that searches for a provisioner
|
||||
// by name and stores it in the context.
|
||||
func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
// ContextKey is the key type for storing and searching for ACME request
|
||||
// essentials in the context of a request.
|
||||
type ContextKey string
|
||||
|
||||
ctx := r.Context()
|
||||
auth := mustAuthority(ctx)
|
||||
adminDB := admin.MustFromContext(ctx)
|
||||
name := chi.URLParam(r, "provisionerName")
|
||||
|
||||
// TODO(hs): distinguish 404 vs. 500
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// checkAction checks if an action is supported in standalone or not
|
||||
func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// actions allowed in standalone mode are always supported
|
||||
if supportedInStandalone {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// when an action is not supported in standalone mode and when
|
||||
// using a nosql.DB backend, actions are not supported
|
||||
if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||
"operation not supported in standalone mode"))
|
||||
return
|
||||
}
|
||||
|
||||
// continue to next http handler
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// loadExternalAccountKey is a middleware that searches for an ACME
|
||||
// External Account Key by reference or keyID and stores it in the context.
|
||||
func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
reference := chi.URLParam(r, "reference")
|
||||
keyID := chi.URLParam(r, "keyID")
|
||||
|
||||
var (
|
||||
eak *acme.ExternalAccountKey
|
||||
err error
|
||||
)
|
||||
|
||||
if keyID != "" {
|
||||
eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID)
|
||||
} else {
|
||||
eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, acme.ErrNotFound) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
|
||||
return
|
||||
}
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving ACME External Account Key"))
|
||||
return
|
||||
}
|
||||
|
||||
if eak == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
|
||||
return
|
||||
}
|
||||
|
||||
linkedEAK := eakToLinked(eak)
|
||||
|
||||
ctx = linkedca.NewContextWithExternalAccountKey(ctx, linkedEAK)
|
||||
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
const (
|
||||
// adminContextKey account key
|
||||
adminContextKey = ContextKey("admin")
|
||||
)
|
||||
|
|
|
@ -4,32 +4,25 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
func TestHandler_requireAPIEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
auth adminAuthority
|
||||
next http.HandlerFunc
|
||||
next nextHTTP
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
|
@ -71,11 +64,13 @@ func TestHandler_requireAPIEnabled(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
requireAPIEnabled(tc.next)(w, req)
|
||||
h.requireAPIEnabled(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -107,7 +102,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
ctx context.Context
|
||||
auth adminAuthority
|
||||
req *http.Request
|
||||
next http.HandlerFunc
|
||||
next nextHTTP
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
|
@ -157,7 +152,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
req.Header["Authorization"] = []string{"token"}
|
||||
createdAt := time.Now()
|
||||
var deletedAt time.Time
|
||||
adm := &linkedca.Admin{
|
||||
admin := &linkedca.Admin{
|
||||
Id: "adminID",
|
||||
AuthorityId: "authorityID",
|
||||
Subject: "admin",
|
||||
|
@ -169,15 +164,20 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
auth := &mockAdminAuthority{
|
||||
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
|
||||
assert.Equals(t, "token", token)
|
||||
return adm, nil
|
||||
return admin, nil
|
||||
},
|
||||
}
|
||||
next := func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
adm := linkedca.MustAdminFromContext(ctx) // verifying that the context now has a linkedca.Admin
|
||||
a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin
|
||||
adm, ok := a.(*linkedca.Admin)
|
||||
if !ok {
|
||||
t.Errorf("expected *linkedca.Admin; got %T", a)
|
||||
return
|
||||
}
|
||||
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
|
||||
if !cmp.Equal(adm, adm, opts...) {
|
||||
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...))
|
||||
if !cmp.Equal(admin, adm, opts...) {
|
||||
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...))
|
||||
}
|
||||
w.Write(nil) // mock response with status 200
|
||||
}
|
||||
|
@ -194,459 +194,13 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
err := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||
|
||||
assert.Equals(t, tc.err.Type, err.Type)
|
||||
assert.Equals(t, tc.err.Message, err.Message)
|
||||
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_loadProvisionerByName(t *testing.T) {
|
||||
type test struct {
|
||||
adminDB admin.DB
|
||||
auth adminAuthority
|
||||
ctx context.Context
|
||||
next http.HandlerFunc
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
err := admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName")
|
||||
err.Message = "error loading provisioner provName: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: err,
|
||||
}
|
||||
},
|
||||
"fail/db.GetProvisioner": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
err := admin.WrapErrorISE(errors.New("force"), "error retrieving provisioner provName")
|
||||
err.Message = "error retrieving provisioner provName: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
statusCode: 500,
|
||||
err: err,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
statusCode: 200,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
prov := linkedca.MustProvisionerFromContext(r.Context())
|
||||
assert.NotNil(t, prov)
|
||||
assert.Equals(t, "provID", prov.GetId())
|
||||
assert.Equals(t, "provName", prov.GetName())
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
loadProvisionerByName(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
err := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||
|
||||
assert.Equals(t, tc.err.Type, err.Type)
|
||||
assert.Equals(t, tc.err.Message, err.Message)
|
||||
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_checkAction(t *testing.T) {
|
||||
type test struct {
|
||||
adminDB admin.DB
|
||||
next http.HandlerFunc
|
||||
supportedInStandalone bool
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"standalone-nosql-supported": func(t *testing.T) test {
|
||||
return test{
|
||||
supportedInStandalone: true,
|
||||
adminDB: &nosql.DB{},
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"standalone-nosql-not-supported": func(t *testing.T) test {
|
||||
err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode")
|
||||
err.Message = "operation not supported in standalone mode"
|
||||
return test{
|
||||
supportedInStandalone: false,
|
||||
adminDB: &nosql.DB{},
|
||||
statusCode: 501,
|
||||
err: err,
|
||||
}
|
||||
},
|
||||
"standalone-no-nosql-not-supported": func(t *testing.T) test {
|
||||
err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported")
|
||||
err.Message = "operation not supported"
|
||||
return test{
|
||||
supportedInStandalone: false,
|
||||
adminDB: &admin.MockDB{},
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
statusCode: 200,
|
||||
err: err,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := admin.NewContext(context.Background(), tc.adminDB)
|
||||
req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
checkAction(tc.next, tc.supportedInStandalone)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
err := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||
|
||||
assert.Equals(t, tc.err.Type, err.Type)
|
||||
assert.Equals(t, tc.err.Message, err.Message)
|
||||
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_loadExternalAccountKey(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
acmeDB acme.DB
|
||||
next http.HandlerFunc
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/keyID-not-found-error": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("keyID", "key")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
|
||||
err.Message = "ACME External Account Key not found"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "key", keyID)
|
||||
return nil, acme.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
},
|
||||
"fail/keyID-error": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("keyID", "key")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key")
|
||||
err.Message = "error retrieving ACME External Account Key: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "key", keyID)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: err,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"fail/reference-not-found-error": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("reference", "ref")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
|
||||
err.Message = "ACME External Account Key not found"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "ref", reference)
|
||||
return nil, acme.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
},
|
||||
"fail/reference-error": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("reference", "ref")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key")
|
||||
err.Message = "error retrieving ACME External Account Key: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "ref", reference)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: err,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"fail/no-key": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("reference", "ref")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
|
||||
err.Message = "ACME External Account Key not found"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "ref", reference)
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
},
|
||||
"ok/keyID": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("keyID", "eakID")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
|
||||
err.Message = "ACME External Account Key not found"
|
||||
createdAt := time.Now().Add(-1 * time.Hour)
|
||||
var boundAt time.Time
|
||||
eak := &acme.ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
ProvisionerID: "provID",
|
||||
CreatedAt: createdAt,
|
||||
BoundAt: boundAt,
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "eakID", keyID)
|
||||
return eak, nil
|
||||
},
|
||||
},
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context())
|
||||
assert.NotNil(t, eak)
|
||||
exp := &linkedca.EABKey{
|
||||
Id: "eakID",
|
||||
Provisioner: "provID",
|
||||
CreatedAt: timestamppb.New(createdAt),
|
||||
BoundAt: timestamppb.New(boundAt),
|
||||
}
|
||||
assert.Equals(t, exp, contextEAK)
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
err: nil,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/reference": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("reference", "ref")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
|
||||
err.Message = "ACME External Account Key not found"
|
||||
createdAt := time.Now().Add(-1 * time.Hour)
|
||||
var boundAt time.Time
|
||||
eak := &acme.ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
CreatedAt: createdAt,
|
||||
BoundAt: boundAt,
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "ref", reference)
|
||||
return eak, nil
|
||||
},
|
||||
},
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context())
|
||||
assert.NotNil(t, eak)
|
||||
exp := &linkedca.EABKey{
|
||||
Id: "eakID",
|
||||
Provisioner: "provID",
|
||||
Reference: "ref",
|
||||
CreatedAt: timestamppb.New(createdAt),
|
||||
BoundAt: timestamppb.New(boundAt),
|
||||
}
|
||||
assert.Equals(t, exp, contextEAK)
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
err: nil,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB)
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
loadExternalAccountKey(tc.next)(w, req)
|
||||
h.extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
|
@ -1,499 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
)
|
||||
|
||||
// PolicyAdminResponder is the interface responsible for writing ACME admin
|
||||
// responses.
|
||||
type PolicyAdminResponder interface {
|
||||
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
GetProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// policyAdminResponder implements PolicyAdminResponder.
|
||||
type policyAdminResponder struct{}
|
||||
|
||||
// NewACMEAdminResponder returns a new PolicyAdminResponder.
|
||||
func NewPolicyAdminResponder() PolicyAdminResponder {
|
||||
return &policyAdminResponder{}
|
||||
}
|
||||
|
||||
// GetAuthorityPolicy handles the GET /admin/authority/policy request
|
||||
func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(r.Context())
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, authorityPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
// CreateAuthorityPolicy handles the POST /admin/authority/policy request
|
||||
func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
|
||||
render.Error(w, adminErr)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
adm := linkedca.MustAdminFromContext(ctx)
|
||||
|
||||
var createdPolicy *linkedca.Policy
|
||||
if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, createdPolicy, http.StatusCreated)
|
||||
}
|
||||
|
||||
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
|
||||
func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
adm := linkedca.MustAdminFromContext(ctx)
|
||||
|
||||
var updatedPolicy *linkedca.Policy
|
||||
if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, updatedPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
|
||||
func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.RemoveAuthorityPolicy(ctx); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
|
||||
func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, provisionerPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
|
||||
func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
|
||||
render.Error(w, adminErr)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
prov.Policy = newPolicy
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error creating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
|
||||
}
|
||||
|
||||
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
|
||||
func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
prov.Policy = newPolicy
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
|
||||
func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
if prov.Policy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
// remove the policy
|
||||
prov.Policy = nil
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, eakPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
|
||||
render.Error(w, adminErr)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
eak.Policy = newPolicy
|
||||
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
eak.Policy = newPolicy
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
// remove the policy
|
||||
eak.Policy = nil
|
||||
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
// blockLinkedCA blocks all API operations on linked deployments
|
||||
func blockLinkedCA(ctx context.Context) error {
|
||||
// temporary blocking linked deployments
|
||||
adminDB := admin.MustFromContext(ctx)
|
||||
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() {
|
||||
return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isBadRequest checks if an error should result in a bad request error
|
||||
// returned to the client.
|
||||
func isBadRequest(err error) bool {
|
||||
var pe *authority.PolicyError
|
||||
isPolicyError := errors.As(err, &pe)
|
||||
return isPolicyError && (pe.Typ == authority.AdminLockOut || pe.Typ == authority.EvaluationFailure || pe.Typ == authority.ConfigurationFailure)
|
||||
}
|
||||
|
||||
func validatePolicy(p *linkedca.Policy) error {
|
||||
// convert the policy; return early if nil
|
||||
options := policy.LinkedToCertificates(p)
|
||||
if options == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Initialize a temporary x509 allow/deny policy engine
|
||||
if _, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize a temporary SSH allow/deny policy engine for host certificates
|
||||
if _, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize a temporary SSH allow/deny policy engine for user certificates
|
||||
if _, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,13 +1,10 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
|
@ -26,31 +23,29 @@ type GetProvisionersResponse struct {
|
|||
}
|
||||
|
||||
// GetProvisioner returns the requested provisioner, or an error.
|
||||
func GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
|
||||
ctx := r.Context()
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
auth := mustAuthority(ctx)
|
||||
db := admin.MustFromContext(ctx)
|
||||
|
||||
if len(id) > 0 {
|
||||
if p, err = auth.LoadProvisionerByID(id); err != nil {
|
||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
prov, err := db.GetProvisioner(ctx, p.GetID())
|
||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -59,7 +54,7 @@ func GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// GetProvisioners returns the given segment of provisioners associated with the authority.
|
||||
func GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
|
@ -67,7 +62,7 @@ func GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||
p, next, err := h.auth.GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -79,10 +74,9 @@ func GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// CreateProvisioner creates a new prov.
|
||||
func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var prov = new(linkedca.Provisioner)
|
||||
if err := read.ProtoJSON(r.Body, prov); err != nil {
|
||||
render.Error(w, err)
|
||||
if !read.ProtoJSON(w, r, prov) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -92,13 +86,7 @@ func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// validate the templates and template data
|
||||
if err := validateTemplates(prov.X509Template, prov.SshTemplate); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil {
|
||||
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
||||
return
|
||||
}
|
||||
|
@ -106,29 +94,27 @@ func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// DeleteProvisioner deletes a provisioner.
|
||||
func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
auth := mustAuthority(r.Context())
|
||||
|
||||
if len(id) > 0 {
|
||||
if p, err = auth.LoadProvisionerByID(id); err != nil {
|
||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
||||
return
|
||||
}
|
||||
|
@ -137,27 +123,22 @@ func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// UpdateProvisioner updates an existing prov.
|
||||
func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var nu = new(linkedca.Provisioner)
|
||||
if err := read.ProtoJSON(r.Body, nu); err != nil {
|
||||
render.Error(w, err)
|
||||
if !read.ProtoJSON(w, r, nu) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
name := chi.URLParam(r, "name")
|
||||
auth := mustAuthority(ctx)
|
||||
db := admin.MustFromContext(ctx)
|
||||
|
||||
p, err := auth.LoadProvisionerByName(name)
|
||||
_old, err := h.auth.LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
||||
return
|
||||
}
|
||||
|
||||
old, err := db.GetProvisioner(r.Context(), p.GetID())
|
||||
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID()))
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -188,47 +169,9 @@ func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// validate the templates and template data
|
||||
if err := validateTemplates(nu.X509Template, nu.SshTemplate); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
render.ProtoJSON(w, nu)
|
||||
}
|
||||
|
||||
// validateTemplates validates the X.509 and SSH templates and template data if set.
|
||||
func validateTemplates(x509, ssh *linkedca.Template) error {
|
||||
if x509 != nil {
|
||||
if len(x509.Template) > 0 {
|
||||
if err := x509util.ValidateTemplate(x509.Template); err != nil {
|
||||
return fmt.Errorf("invalid X.509 template: %w", err)
|
||||
}
|
||||
}
|
||||
if len(x509.Data) > 0 {
|
||||
if err := x509util.ValidateTemplateData(x509.Data); err != nil {
|
||||
return fmt.Errorf("invalid X.509 template data: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ssh != nil {
|
||||
if len(ssh.Template) > 0 {
|
||||
if err := sshutil.ValidateTemplate(ssh.Template); err != nil {
|
||||
return fmt.Errorf("invalid SSH template: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ssh.Data) > 0 {
|
||||
if err := sshutil.ValidateTemplateData(ssh.Data); err != nil {
|
||||
return fmt.Errorf("invalid SSH template data: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -8,21 +8,18 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
func TestHandler_GetProvisioner(t *testing.T) {
|
||||
|
@ -50,7 +47,6 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||
ctx: ctx,
|
||||
req: req,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
|
@ -75,7 +71,6 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||
ctx: ctx,
|
||||
req: req,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
|
@ -158,11 +153,13 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := tc.req.WithContext(ctx)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
}
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
GetProvisioner(w, req)
|
||||
h.GetProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -280,10 +277,12 @@ func TestHandler_GetProvisioners(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
GetProvisioners(w, req)
|
||||
h.GetProvisioners(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -336,12 +335,12 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
|||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: "badRequest",
|
||||
Status: 400,
|
||||
Detail: "bad request",
|
||||
Message: "proto: syntax error (line 1:2): invalid value !",
|
||||
statusCode: 500,
|
||||
err: &admin.Error{ // TODO(hs): this probably needs a better error
|
||||
Type: "",
|
||||
Status: 500,
|
||||
Detail: "",
|
||||
Message: "",
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -349,29 +348,6 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
|||
// "fail/authority.ValidateClaims": func(t *testing.T) test {
|
||||
// return test{}
|
||||
// },
|
||||
"fail/validateTemplates": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Type: linkedca.Provisioner_OIDC,
|
||||
Name: "provName",
|
||||
X509Template: &linkedca.Template{
|
||||
Template: []byte(`{ {{missingFunction "foo"}} }`),
|
||||
},
|
||||
}
|
||||
body, err := protojson.Marshal(prov)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: "badRequest",
|
||||
Status: 400,
|
||||
Detail: "bad request",
|
||||
Message: "invalid template: invalid X.509 template: error parsing template: template: template:1: function \"missingFunction\" not defined",
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/auth.StoreProvisioner": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
|
@ -426,11 +402,13 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
CreateProvisioner(w, req)
|
||||
h.CreateProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -445,15 +423,9 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
|||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||
|
||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
if strings.HasPrefix(tc.err.Message, "proto:") {
|
||||
assert.True(t, strings.Contains(adminErr.Message, "syntax error"))
|
||||
} else {
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -590,10 +562,12 @@ func TestHandler_DeleteProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
DeleteProvisioner(w, req)
|
||||
h.DeleteProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -642,13 +616,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: "badRequest",
|
||||
Status: 400,
|
||||
Detail: "bad request",
|
||||
Message: "proto: syntax error (line 1:2): invalid value !",
|
||||
statusCode: 500,
|
||||
err: &admin.Error{ // TODO(hs): this probably needs a better error
|
||||
Type: "",
|
||||
Status: 500,
|
||||
Detail: "",
|
||||
Message: "",
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -672,7 +645,6 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: auth,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
|
@ -959,61 +931,6 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
},
|
||||
// TODO(hs): ValidateClaims can't be mocked atm
|
||||
//"fail/ValidateClaims": func(t *testing.T) test { return test{} },
|
||||
"fail/validateTemplates": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("name", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
createdAt := time.Now()
|
||||
var deletedAt time.Time
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Type: linkedca.Provisioner_OIDC,
|
||||
Name: "provName",
|
||||
AuthorityId: "authorityID",
|
||||
CreatedAt: timestamppb.New(createdAt),
|
||||
DeletedAt: timestamppb.New(deletedAt),
|
||||
X509Template: &linkedca.Template{
|
||||
Template: []byte("{ {{ missingFunction }} }"),
|
||||
},
|
||||
}
|
||||
body, err := protojson.Marshal(prov)
|
||||
assert.FatalError(t, err)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.OIDC{
|
||||
ID: "provID",
|
||||
Name: "provName",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Type: linkedca.Provisioner_OIDC,
|
||||
AuthorityId: "authorityID",
|
||||
CreatedAt: timestamppb.New(createdAt),
|
||||
DeletedAt: timestamppb.New(deletedAt),
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: "badRequest",
|
||||
Status: 400,
|
||||
Detail: "bad request",
|
||||
Message: "invalid template: invalid X.509 template: error parsing template: template: template:1: function \"missingFunction\" not defined",
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/auth.UpdateProvisioner": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("name", "provName")
|
||||
|
@ -1135,12 +1052,14 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
}
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
UpdateProvisioner(w, req)
|
||||
h.UpdateProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -1155,15 +1074,9 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||
|
||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
if strings.HasPrefix(tc.err.Message, "proto:") {
|
||||
assert.True(t, strings.Contains(adminErr.Message, "syntax error"))
|
||||
} else {
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -1185,87 +1098,3 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validateTemplates(t *testing.T) {
|
||||
type args struct {
|
||||
x509 *linkedca.Template
|
||||
ssh *linkedca.Template
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ok/x509",
|
||||
args: args{
|
||||
x509: &linkedca.Template{
|
||||
Template: []byte(`{"x": 1}`),
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ok/ssh",
|
||||
args: args{
|
||||
ssh: &linkedca.Template{
|
||||
Template: []byte(`{"x": 1}`),
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "fail/x509-template-missing-quote",
|
||||
args: args{
|
||||
x509: &linkedca.Template{
|
||||
Template: []byte(`{ {{printf "%q" "quoted}} }`),
|
||||
},
|
||||
},
|
||||
err: errors.New("invalid X.509 template: error parsing template: template: template:1: unterminated quoted string"),
|
||||
},
|
||||
{
|
||||
name: "fail/x509-template-data",
|
||||
args: args{
|
||||
x509: &linkedca.Template{
|
||||
Data: []byte(`{!?}`),
|
||||
},
|
||||
},
|
||||
err: errors.New("invalid X.509 template data: error validating json template data"),
|
||||
},
|
||||
{
|
||||
name: "fail/ssh-template-unknown-function",
|
||||
args: args{
|
||||
ssh: &linkedca.Template{
|
||||
Template: []byte(`{ {{unknownFunction "foo"}} }`),
|
||||
},
|
||||
},
|
||||
err: errors.New("invalid SSH template: error parsing template: template: template:1: function \"unknownFunction\" not defined"),
|
||||
},
|
||||
{
|
||||
name: "fail/ssh-template-data",
|
||||
args: args{
|
||||
ssh: &linkedca.Template{
|
||||
Data: []byte(`{!?}`),
|
||||
},
|
||||
},
|
||||
err: errors.New("invalid SSH template data: error validating json template data"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateTemplates(tt.args.x509, tt.args.ssh)
|
||||
if tt.err != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equals(t, tt.err.Error(), err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,235 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"go.step.sm/crypto/randutil"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// WebhookAdminResponder is the interface responsible for writing webhook admin
|
||||
// responses.
|
||||
type WebhookAdminResponder interface {
|
||||
CreateProvisionerWebhook(w http.ResponseWriter, r *http.Request)
|
||||
UpdateProvisionerWebhook(w http.ResponseWriter, r *http.Request)
|
||||
DeleteProvisionerWebhook(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// webhoookAdminResponder implements WebhookAdminResponder
|
||||
type webhookAdminResponder struct{}
|
||||
|
||||
// NewWebhookAdminResponder returns a new WebhookAdminResponder
|
||||
func NewWebhookAdminResponder() WebhookAdminResponder {
|
||||
return &webhookAdminResponder{}
|
||||
}
|
||||
|
||||
func validateWebhook(webhook *linkedca.Webhook) error {
|
||||
if webhook == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// name
|
||||
if webhook.Name == "" {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "webhook name is required")
|
||||
}
|
||||
|
||||
// url
|
||||
parsedURL, err := url.Parse(webhook.Url)
|
||||
if err != nil {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid")
|
||||
}
|
||||
if parsedURL.Host == "" {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid")
|
||||
}
|
||||
if parsedURL.Scheme != "https" {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "webhook url must use https")
|
||||
}
|
||||
if parsedURL.User != nil {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "webhook url may not contain username or password")
|
||||
}
|
||||
|
||||
// kind
|
||||
switch webhook.Kind {
|
||||
case linkedca.Webhook_ENRICHING, linkedca.Webhook_AUTHORIZING, linkedca.Webhook_SCEPCHALLENGE:
|
||||
default:
|
||||
return admin.NewError(admin.ErrorBadRequestType, "webhook kind %q is invalid", webhook.Kind)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
|
||||
var newWebhook = new(linkedca.Webhook)
|
||||
if err := read.ProtoJSON(r.Body, newWebhook); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateWebhook(newWebhook); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if newWebhook.Secret != "" {
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "webhook secret must not be set")
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if newWebhook.Id != "" {
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "webhook ID must not be set")
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
id, err := randutil.UUIDv4()
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error generating webhook id"))
|
||||
return
|
||||
}
|
||||
newWebhook.Id = id
|
||||
|
||||
// verify the name is unique
|
||||
for _, wh := range prov.Webhooks {
|
||||
if wh.Name == newWebhook.Name {
|
||||
err := admin.NewError(admin.ErrorConflictType, "provisioner %q already has a webhook with the name %q", prov.Name, newWebhook.Name)
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
secret, err := randutil.Bytes(64)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error generating webhook secret"))
|
||||
return
|
||||
}
|
||||
newWebhook.Secret = base64.StdEncoding.EncodeToString(secret)
|
||||
|
||||
prov.Webhooks = append(prov.Webhooks, newWebhook)
|
||||
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error creating provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, newWebhook, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (war *webhookAdminResponder) DeleteProvisionerWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
|
||||
webhookName := chi.URLParam(r, "webhookName")
|
||||
|
||||
found := false
|
||||
for i, wh := range prov.Webhooks {
|
||||
if wh.Name == webhookName {
|
||||
prov.Webhooks = append(prov.Webhooks[0:i], prov.Webhooks[i+1:]...)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error deleting provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
|
||||
var newWebhook = new(linkedca.Webhook)
|
||||
if err := read.ProtoJSON(r.Body, newWebhook); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateWebhook(newWebhook); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
found := false
|
||||
for i, wh := range prov.Webhooks {
|
||||
if wh.Name != newWebhook.Name {
|
||||
continue
|
||||
}
|
||||
if newWebhook.Secret != "" && newWebhook.Secret != wh.Secret {
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "webhook secret cannot be updated")
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
newWebhook.Secret = wh.Secret
|
||||
if newWebhook.Id != "" && newWebhook.Id != wh.Id {
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "webhook ID cannot be updated")
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
newWebhook.Id = wh.Id
|
||||
prov.Webhooks[i] = newWebhook
|
||||
found = true
|
||||
break
|
||||
}
|
||||
if !found {
|
||||
msg := fmt.Sprintf("provisioner %q has no webhook with the name %q", prov.Name, newWebhook.Name)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, msg)
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
// Return a copy without the signing secret. Include the client-supplied
|
||||
// auth secrets since those may have been updated in this request and we
|
||||
// should show in the response that they changed
|
||||
whResponse := &linkedca.Webhook{
|
||||
Id: newWebhook.Id,
|
||||
Name: newWebhook.Name,
|
||||
Url: newWebhook.Url,
|
||||
Kind: newWebhook.Kind,
|
||||
CertType: newWebhook.CertType,
|
||||
Auth: newWebhook.Auth,
|
||||
DisableTlsClientAuth: newWebhook.DisableTlsClientAuth,
|
||||
}
|
||||
render.ProtoJSONStatus(w, whResponse, http.StatusCreated)
|
||||
}
|
|
@ -1,688 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
// ignore secret and id since those are set by the server
|
||||
func assertEqualWebhook(t *testing.T, a, b *linkedca.Webhook) {
|
||||
assert.Equal(t, a.Name, b.Name)
|
||||
assert.Equal(t, a.Url, b.Url)
|
||||
assert.Equal(t, a.Kind, b.Kind)
|
||||
assert.Equal(t, a.CertType, b.CertType)
|
||||
assert.Equal(t, a.DisableTlsClientAuth, b.DisableTlsClientAuth)
|
||||
|
||||
assert.Equal(t, a.GetAuth(), b.GetAuth())
|
||||
}
|
||||
|
||||
func TestWebhookAdminResponder_CreateProvisionerWebhook(t *testing.T) {
|
||||
type test struct {
|
||||
auth adminAuthority
|
||||
body []byte
|
||||
ctx context.Context
|
||||
err *admin.Error
|
||||
response *linkedca.Webhook
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/existing-webhook": func(t *testing.T) test {
|
||||
webhook := &linkedca.Webhook{
|
||||
Name: "already-exists",
|
||||
Url: "https://example.com",
|
||||
}
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{webhook},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
err := admin.NewError(admin.ErrorConflictType, `provisioner "provName" already has a webhook with the name "already-exists"`)
|
||||
err.Message = `provisioner "provName" already has a webhook with the name "already-exists"`
|
||||
body := []byte(`
|
||||
{
|
||||
"name": "already-exists",
|
||||
"url": "https://example.com",
|
||||
"kind": "ENRICHING"
|
||||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
err: err,
|
||||
statusCode: 409,
|
||||
}
|
||||
},
|
||||
"fail/read.ProtoJSON": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?")
|
||||
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
|
||||
body := []byte("{?}")
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/missing-name": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook name is required")
|
||||
adminErr.Message = "webhook name is required"
|
||||
body := []byte(`{"url": "https://example.com", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/missing-url": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid")
|
||||
adminErr.Message = "webhook url is invalid"
|
||||
body := []byte(`{"name": "metadata", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/relative-url": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid")
|
||||
adminErr.Message = "webhook url is invalid"
|
||||
body := []byte(`{"name": "metadata", "url": "example.com/path", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/http-url": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url must use https")
|
||||
adminErr.Message = "webhook url must use https"
|
||||
body := []byte(`{"name": "metadata", "url": "http://example.com", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/basic-auth-in-url": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url may not contain username or password")
|
||||
adminErr.Message = "webhook url may not contain username or password"
|
||||
body := []byte(`
|
||||
{
|
||||
"name": "metadata",
|
||||
"url": "https://user:pass@example.com",
|
||||
"kind": "ENRICHING"
|
||||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/secret-in-request": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook secret must not be set")
|
||||
adminErr.Message = "webhook secret must not be set"
|
||||
body := []byte(`
|
||||
{
|
||||
"name": "metadata",
|
||||
"url": "https://example.com",
|
||||
"kind": "ENRICHING",
|
||||
"secret": "secret"
|
||||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/unsupported-webhook-kind": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, `(line 5:13): invalid value for enum type: "UNSUPPORTED"`)
|
||||
adminErr.Message = `(line 5:13): invalid value for enum type: "UNSUPPORTED"`
|
||||
body := []byte(`
|
||||
{
|
||||
"name": "metadata",
|
||||
"url": "https://example.com",
|
||||
"kind": "UNSUPPORTED",
|
||||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/auth.UpdateProvisioner-error": func(t *testing.T) test {
|
||||
adm := &linkedca.Admin{
|
||||
Subject: "step",
|
||||
}
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
}
|
||||
ctx := linkedca.NewContextWithAdmin(context.Background(), adm)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
adminErr := admin.NewError(admin.ErrorServerInternalType, "error creating provisioner webhook: force")
|
||||
adminErr.Message = "error creating provisioner webhook: force"
|
||||
body := []byte(`{"name": "metadata", "url": "https://example.com", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return &authority.PolicyError{
|
||||
Typ: authority.StoreFailure,
|
||||
Err: errors.New("force"),
|
||||
}
|
||||
},
|
||||
},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
body := []byte(`{"name": "metadata", "url": "https://example.com", "kind": "ENRICHING", "certType": "X509"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
assert.Equal(t, linkedca.Webhook_X509, nu.Webhooks[0].CertType)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
body: body,
|
||||
response: &linkedca.Webhook{
|
||||
Name: "metadata",
|
||||
Url: "https://example.com",
|
||||
Kind: linkedca.Webhook_ENRICHING,
|
||||
CertType: linkedca.Webhook_X509,
|
||||
},
|
||||
statusCode: 201,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, &admin.MockDB{})
|
||||
war := NewWebhookAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
war.CreateProvisionerWebhook(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equal(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ae := testAdminError{}
|
||||
assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
|
||||
assert.Equal(t, tc.err.Type, ae.Type)
|
||||
assert.Equal(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equal(t, tc.err.Detail, ae.Detail)
|
||||
assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
// when the error message starts with "proto", we expect it to have
|
||||
// a syntax error (in the tests). If the message doesn't start with "proto",
|
||||
// we expect a full string match.
|
||||
if strings.HasPrefix(tc.err.Message, "proto:") {
|
||||
assert.True(t, strings.Contains(ae.Message, "syntax error"))
|
||||
} else {
|
||||
assert.Equal(t, tc.err.Message, ae.Message)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp := &linkedca.Webhook{}
|
||||
body, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, protojson.Unmarshal(body, resp))
|
||||
|
||||
assertEqualWebhook(t, tc.response, resp)
|
||||
assert.NotEmpty(t, resp.Secret)
|
||||
assert.NotEmpty(t, resp.Id)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookAdminResponder_DeleteProvisionerWebhook(t *testing.T) {
|
||||
type test struct {
|
||||
auth adminAuthority
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
provisionerWebhooks []*linkedca.Webhook
|
||||
webhookName string
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/auth.UpdateProvisioner-error": func(t *testing.T) test {
|
||||
adminErr := admin.NewError(admin.ErrorServerInternalType, "error deleting provisioner webhook: force")
|
||||
adminErr.Message = "error deleting provisioner webhook: force"
|
||||
return test{
|
||||
err: adminErr,
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return &authority.PolicyError{
|
||||
Typ: authority.StoreFailure,
|
||||
Err: errors.New("force"),
|
||||
}
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
webhookName: "my-webhook",
|
||||
provisionerWebhooks: []*linkedca.Webhook{
|
||||
{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
statusCode: 200,
|
||||
webhookName: "no-exists",
|
||||
provisionerWebhooks: nil,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
statusCode: 200,
|
||||
webhookName: "exists",
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
assert.Equal(t, nu.Webhooks, []*linkedca.Webhook{
|
||||
{Name: "my-2nd-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING},
|
||||
})
|
||||
return nil
|
||||
},
|
||||
},
|
||||
provisionerWebhooks: []*linkedca.Webhook{
|
||||
{Name: "exists", Url: "https.example.com", Kind: linkedca.Webhook_ENRICHING},
|
||||
{Name: "my-2nd-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("webhookName", tc.webhookName)
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: tc.provisionerWebhooks,
|
||||
}
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
ctx = admin.NewContext(ctx, &admin.MockDB{})
|
||||
req := httptest.NewRequest("DELETE", "/foo", nil).WithContext(ctx)
|
||||
|
||||
war := NewWebhookAdminResponder()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
war.DeleteProvisionerWebhook(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equal(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ae := testAdminError{}
|
||||
assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
|
||||
assert.Equal(t, tc.err.Type, ae.Type)
|
||||
assert.Equal(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equal(t, tc.err.Detail, ae.Detail)
|
||||
assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
// when the error message starts with "proto", we expect it to have
|
||||
// a syntax error (in the tests). If the message doesn't start with "proto",
|
||||
// we expect a full string match.
|
||||
if strings.HasPrefix(tc.err.Message, "proto:") {
|
||||
assert.True(t, strings.Contains(ae.Message, "syntax error"))
|
||||
} else {
|
||||
assert.Equal(t, tc.err.Message, ae.Message)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
res.Body.Close()
|
||||
response := DeleteResponse{}
|
||||
assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response))
|
||||
assert.Equal(t, "ok", response.Status)
|
||||
assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookAdminResponder_UpdateProvisionerWebhook(t *testing.T) {
|
||||
type test struct {
|
||||
auth adminAuthority
|
||||
adminDB admin.DB
|
||||
body []byte
|
||||
ctx context.Context
|
||||
err *admin.Error
|
||||
response *linkedca.Webhook
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{{Name: "exists", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, `provisioner "provName" has no webhook with the name "no-exists"`)
|
||||
err.Message = `provisioner "provName" has no webhook with the name "no-exists"`
|
||||
body := []byte(`
|
||||
{
|
||||
"name": "no-exists",
|
||||
"url": "https://example.com",
|
||||
"kind": "ENRICHING"
|
||||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
},
|
||||
"fail/read.ProtoJSON": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?")
|
||||
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
|
||||
body := []byte("{?}")
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/missing-name": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook name is required")
|
||||
adminErr.Message = "webhook name is required"
|
||||
body := []byte(`{"url": "https://example.com", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/missing-url": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid")
|
||||
adminErr.Message = "webhook url is invalid"
|
||||
body := []byte(`{"name": "metadata", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/relative-url": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid")
|
||||
adminErr.Message = "webhook url is invalid"
|
||||
body := []byte(`{"name": "metadata", "url": "example.com/path", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/http-url": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url must use https")
|
||||
adminErr.Message = "webhook url must use https"
|
||||
body := []byte(`{"name": "metadata", "url": "http://example.com", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/basic-auth-in-url": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url may not contain username or password")
|
||||
adminErr.Message = "webhook url may not contain username or password"
|
||||
body := []byte(`
|
||||
{
|
||||
"name": "my-webhook",
|
||||
"url": "https://user:pass@example.com",
|
||||
"kind": "ENRICHING"
|
||||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/different-secret-in-request": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING, Secret: "c2VjcmV0"}},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook secret cannot be updated")
|
||||
adminErr.Message = "webhook secret cannot be updated"
|
||||
body := []byte(`
|
||||
{
|
||||
"name": "my-webhook",
|
||||
"url": "https://example.com",
|
||||
"kind": "ENRICHING",
|
||||
"secret": "secret"
|
||||
}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"fail/auth.UpdateProvisioner-error": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating provisioner webhook: force")
|
||||
adminErr.Message = "error updating provisioner webhook: force"
|
||||
body := []byte(`{"name": "my-webhook", "url": "https://example.com", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return &authority.PolicyError{
|
||||
Typ: authority.StoreFailure,
|
||||
Err: errors.New("force"),
|
||||
}
|
||||
},
|
||||
},
|
||||
body: body,
|
||||
err: adminErr,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Name: "provName",
|
||||
Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
body := []byte(`{"name": "my-webhook", "url": "https://example.com", "kind": "ENRICHING"}`)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: &mockAdminAuthority{
|
||||
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
body: body,
|
||||
response: &linkedca.Webhook{
|
||||
Name: "my-webhook",
|
||||
Url: "https://example.com",
|
||||
Kind: linkedca.Webhook_ENRICHING,
|
||||
},
|
||||
statusCode: 201,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
war := NewWebhookAdminResponder()
|
||||
|
||||
req := httptest.NewRequest("PUT", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
war.UpdateProvisionerWebhook(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equal(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ae := testAdminError{}
|
||||
assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
|
||||
assert.Equal(t, tc.err.Type, ae.Type)
|
||||
assert.Equal(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equal(t, tc.err.Detail, ae.Detail)
|
||||
assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
// when the error message starts with "proto", we expect it to have
|
||||
// a syntax error (in the tests). If the message doesn't start with "proto",
|
||||
// we expect a full string match.
|
||||
if strings.HasPrefix(tc.err.Message, "proto:") {
|
||||
assert.True(t, strings.Contains(ae.Message, "syntax error"))
|
||||
} else {
|
||||
assert.Equal(t, tc.err.Message, ae.Message)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp := &linkedca.Webhook{}
|
||||
body, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, protojson.Unmarshal(body, resp))
|
||||
|
||||
assertEqualWebhook(t, tc.response, resp)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -69,34 +69,6 @@ type DB interface {
|
|||
GetAdmins(ctx context.Context) ([]*linkedca.Admin, error)
|
||||
UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error
|
||||
DeleteAdmin(ctx context.Context, id string) error
|
||||
|
||||
CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
|
||||
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
|
||||
UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
|
||||
DeleteAuthorityPolicy(ctx context.Context) error
|
||||
}
|
||||
|
||||
type dbKey struct{}
|
||||
|
||||
// NewContext adds the given admin database to the context.
|
||||
func NewContext(ctx context.Context, db DB) context.Context {
|
||||
return context.WithValue(ctx, dbKey{}, db)
|
||||
}
|
||||
|
||||
// FromContext returns the current admin database from the given context.
|
||||
func FromContext(ctx context.Context) (db DB, ok bool) {
|
||||
db, ok = ctx.Value(dbKey{}).(DB)
|
||||
return
|
||||
}
|
||||
|
||||
// MustFromContext returns the current admin database from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustFromContext(ctx context.Context) DB {
|
||||
if db, ok := FromContext(ctx); !ok {
|
||||
panic("admin database is not in the context")
|
||||
} else {
|
||||
return db
|
||||
}
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
|
@ -114,11 +86,6 @@ type MockDB struct {
|
|||
MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error
|
||||
MockDeleteAdmin func(ctx context.Context, id string) error
|
||||
|
||||
MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
|
||||
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
|
||||
MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
|
||||
MockDeleteAuthorityPolicy func(ctx context.Context) error
|
||||
|
||||
MockError error
|
||||
MockRet1 interface{}
|
||||
}
|
||||
|
@ -212,35 +179,3 @@ func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error {
|
|||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateAuthorityPolicy mock
|
||||
func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
if m.MockCreateAuthorityPolicy != nil {
|
||||
return m.MockCreateAuthorityPolicy(ctx, policy)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetAuthorityPolicy mock
|
||||
func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
if m.MockGetAuthorityPolicy != nil {
|
||||
return m.MockGetAuthorityPolicy(ctx)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Policy), m.MockError
|
||||
}
|
||||
|
||||
// UpdateAuthorityPolicy mock
|
||||
func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
if m.MockUpdateAuthorityPolicy != nil {
|
||||
return m.MockUpdateAuthorityPolicy(ctx, policy)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// DeleteAuthorityPolicy mock
|
||||
func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error {
|
||||
if m.MockDeleteAuthorityPolicy != nil {
|
||||
return m.MockDeleteAuthorityPolicy(ctx)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func (dba *dbAdmin) clone() *dbAdmin {
|
|||
return &u
|
||||
}
|
||||
|
||||
func (db *DB) getDBAdminBytes(_ context.Context, id string) ([]byte, error) {
|
||||
func (db *DB) getDBAdminBytes(ctx context.Context, id string) ([]byte, error) {
|
||||
data, err := db.db.Get(adminsTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)
|
||||
|
@ -102,7 +102,7 @@ func (db *DB) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error)
|
|||
// GetAdmins retrieves and unmarshals all active (not deleted) admins
|
||||
// from the database.
|
||||
// TODO should we be paginating?
|
||||
func (db *DB) GetAdmins(context.Context) ([]*linkedca.Admin, error) {
|
||||
func (db *DB) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
||||
dbEntries, err := db.db.List(adminsTable)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error loading admins")
|
||||
|
@ -111,14 +111,16 @@ func (db *DB) GetAdmins(context.Context) ([]*linkedca.Admin, error) {
|
|||
for _, entry := range dbEntries {
|
||||
adm, err := db.unmarshalAdmin(entry.Value, string(entry.Key))
|
||||
if err != nil {
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) {
|
||||
if ae.IsType(admin.ErrorDeletedType) || ae.IsType(admin.ErrorAuthorityMismatchType) {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if k.IsType(admin.ErrorDeletedType) || k.IsType(admin.ErrorAuthorityMismatchType) {
|
||||
continue
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if adm.AuthorityId != db.authorityID {
|
||||
continue
|
||||
|
|
|
@ -68,16 +68,16 @@ func TestDB_getDBAdminBytes(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if b, err := d.getDBAdminBytes(context.Background(), adminID); err != nil {
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if assert.NotNil(t, tc.adminErr) {
|
||||
assert.Equals(t, ae.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -192,16 +192,16 @@ func TestDB_getDBAdmin(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||
if dba, err := d.getDBAdmin(context.Background(), adminID); err != nil {
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if assert.NotNil(t, tc.adminErr) {
|
||||
assert.Equals(t, ae.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -280,16 +280,16 @@ func TestDB_unmarshalDBAdmin(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{authorityID: admin.DefaultAuthorityID}
|
||||
if dba, err := d.unmarshalDBAdmin(tc.in, adminID); err != nil {
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if assert.NotNil(t, tc.adminErr) {
|
||||
assert.Equals(t, ae.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -355,16 +355,16 @@ func TestDB_unmarshalAdmin(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{authorityID: admin.DefaultAuthorityID}
|
||||
if adm, err := d.unmarshalAdmin(tc.in, adminID); err != nil {
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if assert.NotNil(t, tc.adminErr) {
|
||||
assert.Equals(t, ae.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -509,16 +509,16 @@ func TestDB_GetAdmin(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||
if adm, err := d.GetAdmin(context.Background(), adminID); err != nil {
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if assert.NotNil(t, tc.adminErr) {
|
||||
assert.Equals(t, ae.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -661,16 +661,16 @@ func TestDB_DeleteAdmin(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||
if err := d.DeleteAdmin(context.Background(), adminID); err != nil {
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if assert.NotNil(t, tc.adminErr) {
|
||||
assert.Equals(t, ae.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -812,16 +812,16 @@ func TestDB_UpdateAdmin(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||
if err := d.UpdateAdmin(context.Background(), tc.adm); err != nil {
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if assert.NotNil(t, tc.adminErr) {
|
||||
assert.Equals(t, ae.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -910,16 +910,16 @@ func TestDB_CreateAdmin(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||
if err := d.CreateAdmin(context.Background(), tc.adm); err != nil {
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if assert.NotNil(t, tc.adminErr) {
|
||||
assert.Equals(t, ae.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
@ -1086,16 +1086,16 @@ func TestDB_GetAdmins(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||
if admins, err := d.GetAdmins(context.Background()); err != nil {
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if assert.NotNil(t, tc.adminErr) {
|
||||
assert.Equals(t, ae.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, ae.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, ae.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Type, tc.adminErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.adminErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.adminErr.Detail)
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue