forked from TrueCloudLab/certificates
Merge branch 'master' into hs/scep
This commit is contained in:
commit
0487686f69
62 changed files with 10768 additions and 11360 deletions
|
@ -1,5 +1,3 @@
|
||||||
README.md
|
|
||||||
.gitignore
|
|
||||||
bin
|
bin
|
||||||
coverage.txt
|
coverage.txt
|
||||||
*.test
|
*.test
|
||||||
|
|
37
.github/workflows/release.yml
vendored
37
.github/workflows/release.yml
vendored
|
@ -10,6 +10,9 @@ jobs:
|
||||||
test:
|
test:
|
||||||
name: Lint, Test, Build
|
name: Lint, Test, Build
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
go: [ '1.15', '1.16' ]
|
||||||
outputs:
|
outputs:
|
||||||
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||||
steps:
|
steps:
|
||||||
|
@ -20,15 +23,39 @@ jobs:
|
||||||
name: Setup Go
|
name: Setup Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: '1.15.8'
|
go-version: ${{ matrix.go }}
|
||||||
-
|
-
|
||||||
name: Install Deps
|
name: Install Deps
|
||||||
id: install-deps
|
id: install-deps
|
||||||
run: sudo apt-get -y install libpcsclite-dev
|
run: sudo apt-get -y install libpcsclite-dev
|
||||||
-
|
-
|
||||||
name: Lint, Test, Build
|
name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v2
|
||||||
|
with:
|
||||||
|
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||||
|
version: 'latest'
|
||||||
|
|
||||||
|
# Optional: working directory, useful for monorepos
|
||||||
|
# working-directory: somedir
|
||||||
|
|
||||||
|
# Optional: golangci-lint command line arguments.
|
||||||
|
args: --timeout=30m
|
||||||
|
|
||||||
|
# Optional: show only new issues if it's a pull request. The default value is `false`.
|
||||||
|
# only-new-issues: true
|
||||||
|
|
||||||
|
# Optional: if set to true then the action will use pre-installed Go.
|
||||||
|
# skip-go-installation: true
|
||||||
|
|
||||||
|
# Optional: if set to true then the action don't cache or restore ~/go/pkg.
|
||||||
|
# skip-pkg-cache: true
|
||||||
|
|
||||||
|
# Optional: if set to true then the action don't cache or restore ~/.cache/go-build.
|
||||||
|
# skip-build-cache: true
|
||||||
|
-
|
||||||
|
name: Test, Build
|
||||||
id: lint_test_build
|
id: lint_test_build
|
||||||
run: V=1 make -j1 bootstrap ci
|
run: V=1 make ci
|
||||||
|
|
||||||
create_release:
|
create_release:
|
||||||
name: Create Release
|
name: Create Release
|
||||||
|
@ -96,7 +123,7 @@ jobs:
|
||||||
name: Set up Go
|
name: Set up Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: '1.15.8'
|
go-version: '1.16'
|
||||||
-
|
-
|
||||||
name: APT Install
|
name: APT Install
|
||||||
id: aptInstall
|
id: aptInstall
|
||||||
|
@ -126,7 +153,7 @@ jobs:
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: '1.15.8'
|
go-version: '1.16'
|
||||||
- name: Build
|
- name: Build
|
||||||
id: build
|
id: build
|
||||||
run: |
|
run: |
|
||||||
|
|
51
.github/workflows/test.yml
vendored
51
.github/workflows/test.yml
vendored
|
@ -11,24 +11,55 @@ on:
|
||||||
jobs:
|
jobs:
|
||||||
lintTestBuild:
|
lintTestBuild:
|
||||||
name: Lint, Test, Build
|
name: Lint, Test, Build
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-20.04
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
go: [ '1.15', '1.16' ]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
-
|
||||||
|
name: Checkout
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v2
|
||||||
- name: Setup Go
|
-
|
||||||
|
name: Setup Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: '1.15.6'
|
go-version: ${{ matrix.go }}
|
||||||
- name: Install Deps
|
-
|
||||||
|
name: Install Deps
|
||||||
id: install-deps
|
id: install-deps
|
||||||
run: sudo apt-get -y install libpcsclite-dev
|
run: sudo apt-get -y install libpcsclite-dev
|
||||||
- name: Lint, Test, Build
|
-
|
||||||
id: lintTestBuild
|
name: golangci-lint
|
||||||
run: V=1 make -j1 bootstrap ci
|
uses: golangci/golangci-lint-action@v2
|
||||||
- name: Codecov
|
with:
|
||||||
|
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||||
|
version: 'latest'
|
||||||
|
|
||||||
|
# Optional: working directory, useful for monorepos
|
||||||
|
# working-directory: somedir
|
||||||
|
|
||||||
|
# Optional: golangci-lint command line arguments.
|
||||||
|
args: --timeout=30m
|
||||||
|
|
||||||
|
# Optional: show only new issues if it's a pull request. The default value is `false`.
|
||||||
|
# only-new-issues: true
|
||||||
|
|
||||||
|
# Optional: if set to true then the action will use pre-installed Go.
|
||||||
|
# skip-go-installation: true
|
||||||
|
|
||||||
|
# Optional: if set to true then the action don't cache or restore ~/go/pkg.
|
||||||
|
# skip-pkg-cache: true
|
||||||
|
|
||||||
|
# Optional: if set to true then the action don't cache or restore ~/.cache/go-build.
|
||||||
|
# skip-build-cache: true
|
||||||
|
-
|
||||||
|
name: Test, Build
|
||||||
|
id: lint_test_build
|
||||||
|
run: V=1 make ci
|
||||||
|
-
|
||||||
|
name: Codecov
|
||||||
uses: codecov/codecov-action@v1.2.1
|
uses: codecov/codecov-action@v1.2.1
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
|
|
||||||
file: ./coverage.out # optional
|
file: ./coverage.out # optional
|
||||||
name: codecov-umbrella # optional
|
name: codecov-umbrella # optional
|
||||||
fail_ci_if_error: true # optional (default = false)
|
fail_ci_if_error: true # optional (default = false)
|
||||||
|
|
8
Makefile
8
Makefile
|
@ -18,7 +18,7 @@ OUTPUT_ROOT=output/
|
||||||
|
|
||||||
all: lint test build
|
all: lint test build
|
||||||
|
|
||||||
ci: lintcgo testcgo build
|
ci: testcgo build
|
||||||
|
|
||||||
.PHONY: all ci
|
.PHONY: all ci
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ ci: lintcgo testcgo build
|
||||||
|
|
||||||
bootstra%:
|
bootstra%:
|
||||||
# Using a released version of golangci-lint to take into account custom replacements in their go.mod
|
# Using a released version of golangci-lint to take into account custom replacements in their go.mod
|
||||||
$Q GO111MODULE=on go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.24.0
|
$Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.39.0
|
||||||
|
|
||||||
.PHONY: bootstra%
|
.PHONY: bootstra%
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ bootstra%:
|
||||||
|
|
||||||
# If TRAVIS_TAG is set then we know this ref has been tagged.
|
# If TRAVIS_TAG is set then we know this ref has been tagged.
|
||||||
ifdef TRAVIS_TAG
|
ifdef TRAVIS_TAG
|
||||||
VERSION := $(TRAVIS_TAG)
|
VERSION ?= $(TRAVIS_TAG)
|
||||||
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
|
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
|
||||||
ifeq ($(NOT_RC),)
|
ifeq ($(NOT_RC),)
|
||||||
PUSHTYPE := release-candidate
|
PUSHTYPE := release-candidate
|
||||||
|
@ -47,7 +47,7 @@ PUSHTYPE := release
|
||||||
endif
|
endif
|
||||||
# GITHUB Actions
|
# GITHUB Actions
|
||||||
else ifdef GITHUB_REF
|
else ifdef GITHUB_REF
|
||||||
VERSION := $(shell echo $(GITHUB_REF) | sed 's/^refs\/tags\///')
|
VERSION ?= $(shell echo $(GITHUB_REF) | sed 's/^refs\/tags\///')
|
||||||
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
|
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
|
||||||
ifeq ($(NOT_RC),)
|
ifeq ($(NOT_RC),)
|
||||||
PUSHTYPE := release-candidate
|
PUSHTYPE := release-candidate
|
||||||
|
|
268
README.md
268
README.md
|
@ -22,8 +22,7 @@ Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [
|
||||||
|
|
||||||
[Website](https://smallstep.com/certificates) |
|
[Website](https://smallstep.com/certificates) |
|
||||||
[Documentation](https://smallstep.com/docs) |
|
[Documentation](https://smallstep.com/docs) |
|
||||||
[Installation Guide](#installation-guide) |
|
[Installation](https://smallstep.com/docs/step-ca/installation) |
|
||||||
[Quickstart](#quickstart) |
|
|
||||||
[Getting Started](https://smallstep.com/docs/step-ca/getting-started) |
|
[Getting Started](https://smallstep.com/docs/step-ca/getting-started) |
|
||||||
[Contributor's Guide](./docs/CONTRIBUTING.md)
|
[Contributor's Guide](./docs/CONTRIBUTING.md)
|
||||||
|
|
||||||
|
@ -103,270 +102,9 @@ ACME is the protocol used by Let's Encrypt to automate the issuance of HTTPS cer
|
||||||
- [Install root certificates](https://smallstep.com/docs/step-cli/reference/certificate/install/) on your machine and browsers, so your CA is trusted
|
- [Install root certificates](https://smallstep.com/docs/step-cli/reference/certificate/install/) on your machine and browsers, so your CA is trusted
|
||||||
- [Inspect](https://smallstep.com/docs/step-cli/reference/certificate/inspect/) and [lint](https://smallstep.com/docs/step-cli/reference/certificate/lint/) certificates
|
- [Inspect](https://smallstep.com/docs/step-cli/reference/certificate/inspect/) and [lint](https://smallstep.com/docs/step-cli/reference/certificate/lint/) certificates
|
||||||
|
|
||||||
## Installation Guide
|
## Installation
|
||||||
|
|
||||||
These instructions will install an OS specific version of the `step-ca` binary on
|
See our installation docs [here](https://smallstep.com/docs/step-ca/installation).
|
||||||
your local machine.
|
|
||||||
|
|
||||||
Want to build from source? See [our contributor's guide](./docs/CONTRIBUTING.md)
|
|
||||||
|
|
||||||
### Mac OS
|
|
||||||
|
|
||||||
Install `step` and `step-ca` together, via [Homebrew](https://brew.sh/):
|
|
||||||
|
|
||||||
```
|
|
||||||
$ brew install step
|
|
||||||
```
|
|
||||||
|
|
||||||
### Linux
|
|
||||||
|
|
||||||
> **Note:** The [`step` CLI tool](https://github.com/smallstep/cli) is the easiest way to initialize, configure, and control `step-ca`. While `step` is not technically required to run `step-ca`, it is very much recommended.
|
|
||||||
|
|
||||||
#### Debian
|
|
||||||
|
|
||||||
1. Install `step`.
|
|
||||||
|
|
||||||
Download the Debian package from the
|
|
||||||
[latest `step` release](https://github.com/smallstep/cli/releases/latest):
|
|
||||||
|
|
||||||
```
|
|
||||||
$ wget https://github.com/smallstep/cli/releases/download/vX.Y.Z/step-cli_X.Y.Z_amd64.deb
|
|
||||||
```
|
|
||||||
|
|
||||||
Install the Debian package:
|
|
||||||
|
|
||||||
```
|
|
||||||
$ sudo dpkg -i step-cli_X.Y.Z_amd64.deb
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Install `step-ca`.
|
|
||||||
|
|
||||||
Download the Debian package from the [latest `step-ca` release](https://github.com/smallstep/certificates/releases/latest):
|
|
||||||
|
|
||||||
```
|
|
||||||
$ wget https://github.com/smallstep/certificates/releases/download/vX.Y.Z/step-ca_X.Y.Z_amd64.deb
|
|
||||||
```
|
|
||||||
|
|
||||||
Install the Debian package:
|
|
||||||
|
|
||||||
```
|
|
||||||
$ sudo dpkg -i step-ca_X.Y.Z_amd64.deb
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Arch Linux
|
|
||||||
|
|
||||||
We are using the [Arch User Repository](https://aur.archlinux.org) to distribute
|
|
||||||
`step` binaries for Arch Linux.
|
|
||||||
|
|
||||||
* The `step` binary tarball can be found [here](https://aur.archlinux.org/packages/step-cli-bin/).
|
|
||||||
* The `step-ca` binary tarball can be found [here](https://aur.archlinux.org/packages/step-ca-bin/).
|
|
||||||
|
|
||||||
You can use [pacman](https://www.archlinux.org/pacman/) to install the packages.
|
|
||||||
|
|
||||||
#### RHEL/CentOS
|
|
||||||
|
|
||||||
1. Install `step`.
|
|
||||||
|
|
||||||
Download the Linux tarball from the
|
|
||||||
[latest `step` release](https://github.com/smallstep/cli/releases/latest):
|
|
||||||
|
|
||||||
```
|
|
||||||
$ wget -O step-cli.tar.gz https://github.com/smallstep/cli/releases/download/vX.Y.Z/step_linux_X.Y.Z_amd64.tar.gz
|
|
||||||
```
|
|
||||||
|
|
||||||
Install `step` by unzipping and copying the executable over to `/usr/bin`:
|
|
||||||
|
|
||||||
```
|
|
||||||
$ tar -xf step-cli.tar.gz
|
|
||||||
$ sudo cp step_X.Y.Z/bin/step /usr/bin
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Install `step-ca`.
|
|
||||||
|
|
||||||
Download the Linux package from the [latest `step-ca` release](https://github.com/smallstep/certificates/releases/latest):
|
|
||||||
|
|
||||||
```
|
|
||||||
$ wget -O step-ca.tar.gz https://github.com/smallstep/certificates/releases/download/vX.Y.Z/step-ca_linux_X.Y.Z_amd64.tar.gz
|
|
||||||
```
|
|
||||||
|
|
||||||
Install `step-ca` by unzipping and copying the executable over to `/usr/bin`:
|
|
||||||
|
|
||||||
```
|
|
||||||
$ tar -xf step-ca.tar.gz
|
|
||||||
$ sudo cp step-ca_X.Y.Z/bin/step-ca /usr/bin
|
|
||||||
```
|
|
||||||
|
|
||||||
See the [`systemctl` setup section](https://smallstep.com/docs/step-ca/certificate-authority-server-production#running-step-ca-as-a-daemon) for a
|
|
||||||
guide on configuring `step-ca` as a daemon.
|
|
||||||
|
|
||||||
### Kubernetes
|
|
||||||
|
|
||||||
We publish [helm charts](https://hub.helm.sh/charts/smallstep/step-certificates) for easy installation on kubernetes:
|
|
||||||
|
|
||||||
```
|
|
||||||
helm install step-certificates
|
|
||||||
```
|
|
||||||
|
|
||||||
> <a href="https://github.com/smallstep/autocert"><img width="25%" src="https://raw.githubusercontent.com/smallstep/autocert/master/autocert-logo.png"></a>
|
|
||||||
>
|
|
||||||
> If you're using Kubernetes, make sure you [check out
|
|
||||||
> autocert](https://github.com/smallstep/autocert): a kubernetes add-on that builds on `step
|
|
||||||
> certificates` to automatically inject TLS/HTTPS certificates into your containers.
|
|
||||||
|
|
||||||
### Docker
|
|
||||||
|
|
||||||
See our [Docker getting started guide](https://smallstep.com/docs/tutorials/docker-tls-certificate-authority)
|
|
||||||
|
|
||||||
### Test
|
|
||||||
|
|
||||||
<pre><code><b>$ step version</b>
|
|
||||||
Smallstep CLI/0.10.0 (darwin/amd64)
|
|
||||||
Release Date: 2019-04-30 19:01 UTC
|
|
||||||
|
|
||||||
<b>$ step-ca version</b>
|
|
||||||
Smallstep CA/0.10.0 (darwin/amd64)
|
|
||||||
Release Date: 2019-04-30 19:02 UTC</code></pre>
|
|
||||||
|
|
||||||
## Quickstart
|
|
||||||
|
|
||||||
In the following guide we'll run a simple `hello` server that requires clients
|
|
||||||
to connect over an authorized and encrypted channel using HTTPS. `step-ca`
|
|
||||||
will issue certificates to our server, allowing it to authenticate and encrypt
|
|
||||||
communication.
|
|
||||||
|
|
||||||
![Animated terminal showing step certificates in practice](https://github.com/smallstep/certificates/raw/master/docs/images/step-ca-2-legged.gif)
|
|
||||||
|
|
||||||
Let's get started!
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
* [`step`](#installation-guide)
|
|
||||||
* [golang](https://golang.org/doc/install)
|
|
||||||
|
|
||||||
### Let's get started!
|
|
||||||
|
|
||||||
#### 1. Run `step ca init` to create your CA's keys & certificates and configure `step-ca`:
|
|
||||||
|
|
||||||
<pre><code><b>$ step ca init</b>
|
|
||||||
✔ What would you like to name your new PKI? (e.g. Smallstep): <b>Example Inc.</b>
|
|
||||||
✔ What DNS names or IP addresses would you like to add to your new CA? (e.g. ca.smallstep.com[,1.1.1.1,etc.]): <b>localhost</b>
|
|
||||||
✔ What address will your new CA listen at? (e.g. :443): <b>127.0.0.1:8080</b>
|
|
||||||
✔ What would you like to name the first provisioner for your new CA? (e.g. you@smallstep.com): <b>bob@example.com</b>
|
|
||||||
✔ What do you want your password to be? [leave empty and we'll generate one]: <b>abc123</b>
|
|
||||||
|
|
||||||
Generating root certificate...
|
|
||||||
all done!
|
|
||||||
|
|
||||||
Generating intermediate certificate...
|
|
||||||
all done!
|
|
||||||
|
|
||||||
✔ Root certificate: /Users/bob/src/github.com/smallstep/step/.step/certs/root_ca.crt
|
|
||||||
✔ Root private key: /Users/bob/src/github.com/smallstep/step/.step/secrets/root_ca_key
|
|
||||||
✔ Root fingerprint: 702a094e239c9eec6f0dcd0a5f65e595bf7ed6614012825c5fe3d1ae1b2fd6ee
|
|
||||||
✔ Intermediate certificate: /Users/bob/src/github.com/smallstep/step/.step/certs/intermediate_ca.crt
|
|
||||||
✔ Intermediate private key: /Users/bob/src/github.com/smallstep/step/.step/secrets/intermediate_ca_key
|
|
||||||
✔ Default configuration: /Users/bob/src/github.com/smallstep/step/.step/config/defaults.json
|
|
||||||
✔ Certificate Authority configuration: /Users/bob/src/github.com/smallstep/step/.step/config/ca.json
|
|
||||||
|
|
||||||
Your PKI is ready to go. To generate certificates for individual services see 'step help ca'.</code></pre>
|
|
||||||
|
|
||||||
This command will:
|
|
||||||
|
|
||||||
- Generate [password protected](https://github.com/smallstep/certificates/blob/master/docs/GETTING_STARTED.md#passwords) private keys for your CA to sign certificates
|
|
||||||
- Generate a root and [intermediate signing certificate](https://security.stackexchange.com/questions/128779/why-is-it-more-secure-to-use-intermediate-ca-certificates) for your CA
|
|
||||||
- Create a JSON configuration file for `step-ca` (see [configuration docs](https://smallstep.com/docs/step-ca/configuration) for details)
|
|
||||||
|
|
||||||
You can find these artifacts in `$STEPPATH` (or `~/.step` by default).
|
|
||||||
|
|
||||||
#### 2. Start `step-ca`:
|
|
||||||
|
|
||||||
You'll be prompted for your password from the previous step, to decrypt the CA's private signing key:
|
|
||||||
|
|
||||||
<pre><code><b>$ step-ca $(step path)/config/ca.json</b>
|
|
||||||
Please enter the password to decrypt /Users/bob/src/github.com/smallstep/step/.step/secrets/intermediate_ca_key: <b>abc123</b>
|
|
||||||
2019/02/18 13:28:58 Serving HTTPS on 127.0.0.1:8080 ...</code></pre>
|
|
||||||
|
|
||||||
#### 3. Copy our `hello world` golang server.
|
|
||||||
|
|
||||||
```
|
|
||||||
$ cat > srv.go <<EOF
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"log"
|
|
||||||
)
|
|
||||||
|
|
||||||
func HiHandler(w http.ResponseWriter, req *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
|
||||||
w.Write([]byte("Hello, world!\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
http.HandleFunc("/hi", HiHandler)
|
|
||||||
err := http.ListenAndServeTLS(":8443", "srv.crt", "srv.key", nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
EOF
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 4. Get an identity for your server from the Step CA.
|
|
||||||
|
|
||||||
<pre><code><b>$ step ca certificate localhost srv.crt srv.key</b>
|
|
||||||
✔ Key ID: rQxROEr7Kx9TNjSQBTETtsu3GKmuW9zm02dMXZ8GUEk (bob@example.com)
|
|
||||||
✔ Please enter the password to decrypt the provisioner key: abc123
|
|
||||||
✔ CA: https://localhost:8080/1.0/sign
|
|
||||||
✔ Certificate: srv.crt
|
|
||||||
✔ Private Key: srv.key
|
|
||||||
|
|
||||||
<b>$ step certificate inspect --bundle srv.crt</b>
|
|
||||||
Certificate:
|
|
||||||
Data:
|
|
||||||
Version: 3 (0x2)
|
|
||||||
Serial Number: 140439335711218707689123407681832384336 (0x69a7a1d7f6f22f68059d2d9088307750)
|
|
||||||
Signature Algorithm: ECDSA-SHA256
|
|
||||||
Issuer: CN=Example Inc. Intermediate CA
|
|
||||||
Validity
|
|
||||||
Not Before: Feb 18 21:32:35 2019 UTC
|
|
||||||
Not After : Feb 19 21:32:35 2019 UTC
|
|
||||||
Subject: CN=localhost
|
|
||||||
...
|
|
||||||
Certificate:
|
|
||||||
Data:
|
|
||||||
Version: 3 (0x2)
|
|
||||||
Serial Number: 207035091234452090159026162349261226844 (0x9bc18217bd560cf07db23178ed90835c)
|
|
||||||
Signature Algorithm: ECDSA-SHA256
|
|
||||||
Issuer: CN=Example Inc. Root CA
|
|
||||||
Validity
|
|
||||||
Not Before: Feb 18 21:27:21 2019 UTC
|
|
||||||
Not After : Feb 15 21:27:21 2029 UTC
|
|
||||||
Subject: CN=Example Inc. Intermediate CA
|
|
||||||
...</code></pre>
|
|
||||||
|
|
||||||
Note that `step` and `step-ca` handle details like [certificate bundling](https://smallstep.com/blog/everything-pki.html#intermediates-chains-and-bundling) for you.
|
|
||||||
|
|
||||||
#### 5. Run the simple server.
|
|
||||||
|
|
||||||
<pre><code><b>$ go run srv.go &</b></code></pre>
|
|
||||||
|
|
||||||
#### 6. Get the root certificate from the Step CA.
|
|
||||||
|
|
||||||
In a new Terminal window:
|
|
||||||
|
|
||||||
<pre><code><b>$ step ca root root.crt</b>
|
|
||||||
The root certificate has been saved in root.crt.</code></pre>
|
|
||||||
|
|
||||||
#### 7. Make an authenticated, encrypted curl request to your server using HTTP over TLS.
|
|
||||||
|
|
||||||
<pre><code><b>$ curl --cacert root.crt https://localhost:8443/hi</b>
|
|
||||||
Hello, world!</code></pre>
|
|
||||||
|
|
||||||
*All Done!*
|
|
||||||
|
|
||||||
Check out the [Getting Started](./docs/GETTING_STARTED.md) guide for more examples
|
|
||||||
and best practices on running Step CA in production.
|
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
|
|
179
acme/account.go
179
acme/account.go
|
@ -1,197 +1,42 @@
|
||||||
package acme
|
package acme
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"crypto"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Account is a subset of the internal account type containing only those
|
// Account is a subset of the internal account type containing only those
|
||||||
// attributes required for responses in the ACME protocol.
|
// attributes required for responses in the ACME protocol.
|
||||||
type Account struct {
|
type Account struct {
|
||||||
Contact []string `json:"contact,omitempty"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Orders string `json:"orders"`
|
|
||||||
ID string `json:"-"`
|
ID string `json:"-"`
|
||||||
Key *jose.JSONWebKey `json:"-"`
|
Key *jose.JSONWebKey `json:"-"`
|
||||||
|
Contact []string `json:"contact,omitempty"`
|
||||||
|
Status Status `json:"status"`
|
||||||
|
OrdersURL string `json:"orders"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToLog enables response logging.
|
// ToLog enables response logging.
|
||||||
func (a *Account) ToLog() (interface{}, error) {
|
func (a *Account) ToLog() (interface{}, error) {
|
||||||
b, err := json.Marshal(a)
|
b, err := json.Marshal(a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging"))
|
return nil, WrapErrorISE(err, "error marshaling account for logging")
|
||||||
}
|
}
|
||||||
return string(b), nil
|
return string(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the account ID.
|
|
||||||
func (a *Account) GetID() string {
|
|
||||||
return a.ID
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetKey returns the JWK associated with the account.
|
|
||||||
func (a *Account) GetKey() *jose.JSONWebKey {
|
|
||||||
return a.Key
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValid returns true if the Account is valid.
|
// IsValid returns true if the Account is valid.
|
||||||
func (a *Account) IsValid() bool {
|
func (a *Account) IsValid() bool {
|
||||||
return a.Status == StatusValid
|
return Status(a.Status) == StatusValid
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountOptions are the options needed to create a new ACME account.
|
// KeyToID converts a JWK to a thumbprint.
|
||||||
type AccountOptions struct {
|
func KeyToID(jwk *jose.JSONWebKey) (string, error) {
|
||||||
Key *jose.JSONWebKey
|
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||||
Contact []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// account represents an ACME account.
|
|
||||||
type account struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Created time.Time `json:"created"`
|
|
||||||
Deactivated time.Time `json:"deactivated"`
|
|
||||||
Key *jose.JSONWebKey `json:"key"`
|
|
||||||
Contact []string `json:"contact,omitempty"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// newAccount returns a new acme account type.
|
|
||||||
func newAccount(db nosql.DB, ops AccountOptions) (*account, error) {
|
|
||||||
id, err := randID()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", WrapErrorISE(err, "error generating jwk thumbprint")
|
||||||
}
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(kid), nil
|
||||||
a := &account{
|
|
||||||
ID: id,
|
|
||||||
Key: ops.Key,
|
|
||||||
Contact: ops.Contact,
|
|
||||||
Status: "valid",
|
|
||||||
Created: clock.Now(),
|
|
||||||
}
|
|
||||||
return a, a.saveNew(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
// toACME converts the internal Account type into the public acmeAccount
|
|
||||||
// type for presentation in the ACME protocol.
|
|
||||||
func (a *account) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Account, error) {
|
|
||||||
return &Account{
|
|
||||||
Status: a.Status,
|
|
||||||
Contact: a.Contact,
|
|
||||||
Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID),
|
|
||||||
Key: a.Key,
|
|
||||||
ID: a.ID,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// save writes the Account to the DB.
|
|
||||||
// If the account is new then the necessary indices will be created.
|
|
||||||
// Else, the account in the DB will be updated.
|
|
||||||
func (a *account) saveNew(db nosql.DB) error {
|
|
||||||
kid, err := keyToID(a.Key)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
kidB := []byte(kid)
|
|
||||||
|
|
||||||
// Set the jwkID -> acme account ID index
|
|
||||||
_, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID))
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index"))
|
|
||||||
case !swapped:
|
|
||||||
return ServerInternalErr(errors.Errorf("key-id to account-id index already exists"))
|
|
||||||
default:
|
|
||||||
if err = a.save(db, nil); err != nil {
|
|
||||||
db.Del(accountByKeyIDTable, kidB)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *account) save(db nosql.DB, old *account) error {
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
oldB []byte
|
|
||||||
)
|
|
||||||
if old == nil {
|
|
||||||
oldB = nil
|
|
||||||
} else {
|
|
||||||
if oldB, err = json.Marshal(old); err != nil {
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
b, err := json.Marshal(*a)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "error marshaling new account object")
|
|
||||||
}
|
|
||||||
// Set the Account
|
|
||||||
_, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b)
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error storing account"))
|
|
||||||
case !swapped:
|
|
||||||
return ServerInternalErr(errors.New("error storing account; " +
|
|
||||||
"value has changed since last read"))
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// update updates the acme account object stored in the database if,
|
|
||||||
// and only if, the account has not changed since the last read.
|
|
||||||
func (a *account) update(db nosql.DB, contact []string) (*account, error) {
|
|
||||||
b := *a
|
|
||||||
b.Contact = contact
|
|
||||||
if err := b.save(db, a); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &b, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// deactivate deactivates the acme account.
|
|
||||||
func (a *account) deactivate(db nosql.DB) (*account, error) {
|
|
||||||
b := *a
|
|
||||||
b.Status = StatusDeactivated
|
|
||||||
b.Deactivated = clock.Now()
|
|
||||||
if err := b.save(db, a); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &b, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAccountByID retrieves the account with the given ID.
|
|
||||||
func getAccountByID(db nosql.DB, id string) (*account, error) {
|
|
||||||
ab, err := db.Get(accountTable, []byte(id))
|
|
||||||
if err != nil {
|
|
||||||
if nosql.IsErrNotFound(err) {
|
|
||||||
return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id))
|
|
||||||
}
|
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id))
|
|
||||||
}
|
|
||||||
|
|
||||||
a := new(account)
|
|
||||||
if err = json.Unmarshal(ab, a); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account"))
|
|
||||||
}
|
|
||||||
return a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAccountByKeyID retrieves Id associated with the given Kid.
|
|
||||||
func getAccountByKeyID(db nosql.DB, kid string) (*account, error) {
|
|
||||||
id, err := db.Get(accountByKeyIDTable, []byte(kid))
|
|
||||||
if err != nil {
|
|
||||||
if nosql.IsErrNotFound(err) {
|
|
||||||
return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid))
|
|
||||||
}
|
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index"))
|
|
||||||
}
|
|
||||||
return getAccountByID(db, string(id))
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,770 +1,81 @@
|
||||||
package acme
|
package acme
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"crypto"
|
||||||
"encoding/json"
|
"encoding/base64"
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
|
||||||
"github.com/smallstep/certificates/db"
|
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
"github.com/smallstep/nosql/database"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
func TestKeyToID(t *testing.T) {
|
||||||
defaultDisableRenewal = false
|
|
||||||
globalProvisionerClaims = provisioner.Claims{
|
|
||||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
|
||||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
|
||||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
|
||||||
DisableRenewal: &defaultDisableRenewal,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func newProv() Provisioner {
|
|
||||||
// Initialize provisioners
|
|
||||||
p := &provisioner.ACME{
|
|
||||||
Type: "ACME",
|
|
||||||
Name: "test@acme-provisioner.com",
|
|
||||||
}
|
|
||||||
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
|
||||||
fmt.Printf("%v", err)
|
|
||||||
}
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAcc() (*account, error) {
|
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mockdb := &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return newAccount(mockdb, AccountOptions{
|
|
||||||
Key: jwk, Contact: []string{"foo", "bar"},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetAccountByID(t *testing.T) {
|
|
||||||
type test struct {
|
type test struct {
|
||||||
id string
|
jwk *jose.JSONWebKey
|
||||||
db nosql.DB
|
exp string
|
||||||
acc *account
|
|
||||||
err *Error
|
err *Error
|
||||||
}
|
}
|
||||||
tests := map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"fail/not-found": func(t *testing.T) test {
|
"fail/error-generating-thumbprint": func(t *testing.T) test {
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
id: acc.ID,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
return nil, database.ErrNotFound
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/db-error": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
id: acc.ID,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
return nil, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/unmarshal-error": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
id: acc.ID,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, key, []byte(acc.ID))
|
|
||||||
return nil, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
id: acc.ID,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, key, []byte(acc.ID))
|
|
||||||
return b, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if acc, err := getAccountByID(tc.db, tc.id); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, tc.acc.ID, acc.ID)
|
|
||||||
assert.Equals(t, tc.acc.Status, acc.Status)
|
|
||||||
assert.Equals(t, tc.acc.Created, acc.Created)
|
|
||||||
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
|
||||||
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
|
||||||
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetAccountByKeyID(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
kid string
|
|
||||||
db nosql.DB
|
|
||||||
acc *account
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/kid-not-found": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
kid: "foo",
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
return nil, database.ErrNotFound
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: MalformedErr(errors.Errorf("account with key id foo not found: not found")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/db-error": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
kid: "foo",
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
return nil, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error loading key-account index: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/getAccount-error": func(t *testing.T) test {
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
kid: "foo",
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
if count == 0 {
|
|
||||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
|
||||||
assert.Equals(t, key, []byte("foo"))
|
|
||||||
count++
|
|
||||||
return []byte("bar"), nil
|
|
||||||
}
|
|
||||||
return nil, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error loading account bar: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
kid: acc.Key.KeyID,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
var ret []byte
|
|
||||||
switch count {
|
|
||||||
case 0:
|
|
||||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
|
||||||
assert.Equals(t, key, []byte(acc.Key.KeyID))
|
|
||||||
ret = []byte(acc.ID)
|
|
||||||
case 1:
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, key, []byte(acc.ID))
|
|
||||||
ret = b
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return ret, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
acc: acc,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, tc.acc.ID, acc.ID)
|
|
||||||
assert.Equals(t, tc.acc.Status, acc.Status)
|
|
||||||
assert.Equals(t, tc.acc.Created, acc.Created)
|
|
||||||
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
|
||||||
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
|
||||||
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountToACME(t *testing.T) {
|
|
||||||
dir := newDirectory("ca.smallstep.com", "acme")
|
|
||||||
prov := newProv()
|
|
||||||
provName := url.PathEscape(prov.GetName())
|
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
|
||||||
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
|
|
||||||
|
|
||||||
type test struct {
|
|
||||||
acc *account
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{acc: acc}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
tc := run(t)
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
acmeAccount, err := tc.acc.toACME(ctx, nil, dir)
|
|
||||||
if err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, acmeAccount.ID, tc.acc.ID)
|
|
||||||
assert.Equals(t, acmeAccount.Status, tc.acc.Status)
|
|
||||||
assert.Equals(t, acmeAccount.Contact, tc.acc.Contact)
|
|
||||||
assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID)
|
|
||||||
assert.Equals(t, acmeAccount.Orders,
|
|
||||||
fmt.Sprintf("%s/acme/%s/account/%s/orders", baseURL.String(), provName, tc.acc.ID))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountSave(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
acc, old *account
|
|
||||||
db nosql.DB
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/old-nil/swap-error": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
old: nil,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/old-nil/swap-false": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
old: nil,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return []byte("foo"), false, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error storing account; value has changed since last read")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/old-nil": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
old: nil,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
assert.Equals(t, b, newval)
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, []byte(acc.ID), key)
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/old-not-nil": func(t *testing.T) test {
|
|
||||||
oldAcc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
oldb, err := json.Marshal(oldAcc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
old: oldAcc,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, old, oldb)
|
|
||||||
assert.Equals(t, newval, b)
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, []byte(acc.ID), key)
|
|
||||||
return []byte("foo"), true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if err := tc.acc.save(tc.db, tc.old); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
assert.Nil(t, tc.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountSaveNew(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
acc *account
|
|
||||||
db nosql.DB
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/keyToID-error": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
acc.Key.Key = "foo"
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/swap-error": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
kid, err := keyToID(acc.Key)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
|
||||||
assert.Equals(t, key, []byte(kid))
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
assert.Equals(t, newval, []byte(acc.ID))
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/swap-false": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
kid, err := keyToID(acc.Key)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
|
||||||
assert.Equals(t, key, []byte(kid))
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
assert.Equals(t, newval, []byte(acc.ID))
|
|
||||||
return nil, false, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("key-id to account-id index already exists")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/save-error": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
kid, err := keyToID(acc.Key)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if count == 0 {
|
|
||||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
|
||||||
assert.Equals(t, key, []byte(kid))
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
assert.Equals(t, newval, []byte(acc.ID))
|
|
||||||
count++
|
|
||||||
return nil, true, nil
|
|
||||||
}
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, key, []byte(acc.ID))
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
assert.Equals(t, newval, b)
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
MDel: func(bucket, key []byte) error {
|
|
||||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
|
||||||
assert.Equals(t, key, []byte(kid))
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
kid, err := keyToID(acc.Key)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if count == 0 {
|
|
||||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
|
||||||
assert.Equals(t, key, []byte(kid))
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
assert.Equals(t, newval, []byte(acc.ID))
|
|
||||||
count++
|
|
||||||
return nil, true, nil
|
|
||||||
}
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, key, []byte(acc.ID))
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
assert.Equals(t, newval, b)
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if err := tc.acc.saveNew(tc.db); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
assert.Nil(t, tc.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountUpdate(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
acc *account
|
|
||||||
contact []string
|
|
||||||
db nosql.DB
|
|
||||||
res []byte
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
contact := []string{"foo", "bar"}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/save-error": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
oldb, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
_acc := *acc
|
|
||||||
clone := &_acc
|
|
||||||
clone.Contact = contact
|
|
||||||
b, err := json.Marshal(clone)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
contact: contact,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, key, []byte(acc.ID))
|
|
||||||
assert.Equals(t, old, oldb)
|
|
||||||
assert.Equals(t, newval, b)
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
oldb, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
_acc := *acc
|
|
||||||
clone := &_acc
|
|
||||||
clone.Contact = contact
|
|
||||||
b, err := json.Marshal(clone)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
contact: contact,
|
|
||||||
res: b,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, key, []byte(acc.ID))
|
|
||||||
assert.Equals(t, old, oldb)
|
|
||||||
assert.Equals(t, newval, b)
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
acc, err := tc.acc.update(tc.db, tc.contact)
|
|
||||||
if err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
b, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
assert.Equals(t, b, tc.res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountDeactivate(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
acc *account
|
|
||||||
db nosql.DB
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/save-error": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
oldb, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, key, []byte(acc.ID))
|
|
||||||
assert.Equals(t, old, oldb)
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
acc, err := newAcc()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
oldb, err := json.Marshal(acc)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
return test{
|
|
||||||
acc: acc,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, key, []byte(acc.ID))
|
|
||||||
assert.Equals(t, old, oldb)
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
acc, err := tc.acc.deactivate(tc.db)
|
|
||||||
if err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, acc.ID, tc.acc.ID)
|
|
||||||
assert.Equals(t, acc.Contact, tc.acc.Contact)
|
|
||||||
assert.Equals(t, acc.Status, StatusDeactivated)
|
|
||||||
assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID)
|
|
||||||
assert.Equals(t, acc.Created, tc.acc.Created)
|
|
||||||
|
|
||||||
assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute)))
|
|
||||||
assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewAccount(t *testing.T) {
|
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
kid, err := keyToID(jwk)
|
jwk.Key = "foo"
|
||||||
assert.FatalError(t, err)
|
|
||||||
ops := AccountOptions{
|
|
||||||
Key: jwk,
|
|
||||||
Contact: []string{"foo", "bar"},
|
|
||||||
}
|
|
||||||
type test struct {
|
|
||||||
ops AccountOptions
|
|
||||||
db nosql.DB
|
|
||||||
err *Error
|
|
||||||
id *string
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/store-error": func(t *testing.T) test {
|
|
||||||
return test{
|
return test{
|
||||||
ops: ops,
|
jwk: jwk,
|
||||||
db: &db.MockNoSQLDB{
|
err: NewErrorISE("error generating jwk thumbprint: square/go-jose: unknown key type 'string'"),
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
var _id string
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
id := &_id
|
assert.FatalError(t, err)
|
||||||
count := 0
|
|
||||||
|
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
return test{
|
return test{
|
||||||
ops: ops,
|
jwk: jwk,
|
||||||
db: &db.MockNoSQLDB{
|
exp: base64.RawURLEncoding.EncodeToString(kid),
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
switch count {
|
|
||||||
case 0:
|
|
||||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
|
||||||
assert.Equals(t, key, []byte(kid))
|
|
||||||
case 1:
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
*id = string(key)
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
id: id,
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
acc, err := newAccount(tc.db, tc.ops)
|
tc := run(t)
|
||||||
if err != nil {
|
if id, err := KeyToID(tc.jwk); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
ae, ok := err.(*Error)
|
switch k := err.(type) {
|
||||||
assert.True(t, ok)
|
case *Error:
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
assert.Equals(t, k.Type, tc.err.Type)
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
assert.Equals(t, k.Status, tc.err.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.New("unexpected error type"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if assert.Nil(t, tc.err) {
|
if assert.Nil(t, tc.err) {
|
||||||
assert.Equals(t, acc.ID, *tc.id)
|
assert.Equals(t, id, tc.exp)
|
||||||
assert.Equals(t, acc.Status, StatusValid)
|
|
||||||
assert.Equals(t, acc.Contact, ops.Contact)
|
|
||||||
assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID)
|
|
||||||
|
|
||||||
assert.True(t, acc.Deactivated.IsZero())
|
|
||||||
|
|
||||||
assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute)))
|
|
||||||
assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute)))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_IsValid(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
acc *Account
|
||||||
|
exp bool
|
||||||
|
}
|
||||||
|
tests := map[string]test{
|
||||||
|
"valid": {acc: &Account{Status: StatusValid}, exp: true},
|
||||||
|
"invalid": {acc: &Account{Status: StatusInvalid}, exp: false},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert.Equals(t, tc.acc.IsValid(), tc.exp)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
|
@ -21,7 +20,7 @@ type NewAccountRequest struct {
|
||||||
func validateContacts(cs []string) error {
|
func validateContacts(cs []string) error {
|
||||||
for _, c := range cs {
|
for _, c := range cs {
|
||||||
if len(c) == 0 {
|
if len(c) == 0 {
|
||||||
return acme.MalformedErr(errors.New("contact cannot be empty string"))
|
return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -30,7 +29,7 @@ func validateContacts(cs []string) error {
|
||||||
// Validate validates a new-account request body.
|
// Validate validates a new-account request body.
|
||||||
func (n *NewAccountRequest) Validate() error {
|
func (n *NewAccountRequest) Validate() error {
|
||||||
if n.OnlyReturnExisting && len(n.Contact) > 0 {
|
if n.OnlyReturnExisting && len(n.Contact) > 0 {
|
||||||
return acme.MalformedErr(errors.New("incompatible input; onlyReturnExisting must be alone"))
|
return acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone")
|
||||||
}
|
}
|
||||||
return validateContacts(n.Contact)
|
return validateContacts(n.Contact)
|
||||||
}
|
}
|
||||||
|
@ -38,21 +37,15 @@ func (n *NewAccountRequest) Validate() error {
|
||||||
// UpdateAccountRequest represents an update-account request.
|
// UpdateAccountRequest represents an update-account request.
|
||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Contact []string `json:"contact"`
|
Contact []string `json:"contact"`
|
||||||
Status string `json:"status"`
|
Status acme.Status `json:"status"`
|
||||||
}
|
|
||||||
|
|
||||||
// IsDeactivateRequest returns true if the update request is a deactivation
|
|
||||||
// request, false otherwise.
|
|
||||||
func (u *UpdateAccountRequest) IsDeactivateRequest() bool {
|
|
||||||
return u.Status == acme.StatusDeactivated
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate validates a update-account request body.
|
// Validate validates a update-account request body.
|
||||||
func (u *UpdateAccountRequest) Validate() error {
|
func (u *UpdateAccountRequest) Validate() error {
|
||||||
switch {
|
switch {
|
||||||
case len(u.Status) > 0 && len(u.Contact) > 0:
|
case len(u.Status) > 0 && len(u.Contact) > 0:
|
||||||
return acme.MalformedErr(errors.New("incompatible input; contact and " +
|
return acme.NewError(acme.ErrorMalformedType, "incompatible input; contact and "+
|
||||||
"status updates are mutually exclusive"))
|
"status updates are mutually exclusive")
|
||||||
case len(u.Contact) > 0:
|
case len(u.Contact) > 0:
|
||||||
if err := validateContacts(u.Contact); err != nil {
|
if err := validateContacts(u.Contact); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -60,8 +53,8 @@ func (u *UpdateAccountRequest) Validate() error {
|
||||||
return nil
|
return nil
|
||||||
case len(u.Status) > 0:
|
case len(u.Status) > 0:
|
||||||
if u.Status != acme.StatusDeactivated {
|
if u.Status != acme.StatusDeactivated {
|
||||||
return acme.MalformedErr(errors.Errorf("cannot update account "+
|
return acme.NewError(acme.ErrorMalformedType, "cannot update account "+
|
||||||
"status to %s, only deactivated", u.Status))
|
"status to %s, only deactivated", u.Status)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
|
@ -73,15 +66,16 @@ func (u *UpdateAccountRequest) Validate() error {
|
||||||
|
|
||||||
// NewAccount is the handler resource for creating new ACME accounts.
|
// NewAccount is the handler resource for creating new ACME accounts.
|
||||||
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
payload, err := payloadFromContext(r.Context())
|
ctx := r.Context()
|
||||||
|
payload, err := payloadFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var nar NewAccountRequest
|
var nar NewAccountRequest
|
||||||
if err := json.Unmarshal(payload.value, &nar); err != nil {
|
if err := json.Unmarshal(payload.value, &nar); err != nil {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
|
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||||
"failed to unmarshal new-account request payload")))
|
"failed to unmarshal new-account request payload"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := nar.Validate(); err != nil {
|
if err := nar.Validate(); err != nil {
|
||||||
|
@ -90,7 +84,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
httpStatus := http.StatusCreated
|
httpStatus := http.StatusCreated
|
||||||
acc, err := acme.AccountFromContext(r.Context())
|
acc, err := accountFromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
acmeErr, ok := err.(*acme.Error)
|
acmeErr, ok := err.(*acme.Error)
|
||||||
if !ok || acmeErr.Status != http.StatusBadRequest {
|
if !ok || acmeErr.Status != http.StatusBadRequest {
|
||||||
|
@ -101,20 +95,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// Account does not exist //
|
// Account does not exist //
|
||||||
if nar.OnlyReturnExisting {
|
if nar.OnlyReturnExisting {
|
||||||
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
|
||||||
|
"account does not exist"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwk, err := acme.JwkFromContext(r.Context())
|
jwk, err := jwkFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if acc, err = h.Auth.NewAccount(r.Context(), acme.AccountOptions{
|
acc = &acme.Account{
|
||||||
Key: jwk,
|
Key: jwk,
|
||||||
Contact: nar.Contact,
|
Contact: nar.Contact,
|
||||||
}); err != nil {
|
Status: acme.StatusValid,
|
||||||
api.WriteError(w, err)
|
}
|
||||||
|
if err := h.db.CreateAccount(ctx, acc); err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error creating account"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -122,19 +119,21 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
httpStatus = http.StatusOK
|
httpStatus = http.StatusOK
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink,
|
h.linker.LinkAccount(ctx, acc)
|
||||||
true, acc.GetID()))
|
|
||||||
|
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
|
||||||
api.JSONStatus(w, acc, httpStatus)
|
api.JSONStatus(w, acc, httpStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUpdateAccount is the api for updating an ACME account.
|
// GetOrUpdateAccount is the api for updating an ACME account.
|
||||||
func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
acc, err := acme.AccountFromContext(r.Context())
|
ctx := r.Context()
|
||||||
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payload, err := payloadFromContext(r.Context())
|
payload, err := payloadFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
|
@ -145,29 +144,31 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
if !payload.isPostAsGet {
|
if !payload.isPostAsGet {
|
||||||
var uar UpdateAccountRequest
|
var uar UpdateAccountRequest
|
||||||
if err := json.Unmarshal(payload.value, &uar); err != nil {
|
if err := json.Unmarshal(payload.value, &uar); err != nil {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal new-account request payload")))
|
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||||
|
"failed to unmarshal new-account request payload"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := uar.Validate(); err != nil {
|
if err := uar.Validate(); err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var err error
|
if len(uar.Status) > 0 || len(uar.Contact) > 0 {
|
||||||
// If neither the status nor the contacts are being updated then ignore
|
if len(uar.Status) > 0 {
|
||||||
// the updates and return 200. This conforms with the behavior detailed
|
acc.Status = uar.Status
|
||||||
// in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2).
|
|
||||||
if uar.IsDeactivateRequest() {
|
|
||||||
acc, err = h.Auth.DeactivateAccount(r.Context(), acc.GetID())
|
|
||||||
} else if len(uar.Contact) > 0 {
|
} else if len(uar.Contact) > 0 {
|
||||||
acc, err = h.Auth.UpdateAccount(r.Context(), acc.GetID(), uar.Contact)
|
acc.Contact = uar.Contact
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
api.WriteError(w, err)
|
if err := h.db.UpdateAccount(ctx, acc); err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error updating account"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink,
|
}
|
||||||
true, acc.GetID()))
|
|
||||||
|
h.linker.LinkAccount(ctx, acc)
|
||||||
|
|
||||||
|
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
|
||||||
api.JSON(w, acc)
|
api.JSON(w, acc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,23 +181,27 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account.
|
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
|
||||||
func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||||
acc, err := acme.AccountFromContext(r.Context())
|
ctx := r.Context()
|
||||||
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
accID := chi.URLParam(r, "accID")
|
accID := chi.URLParam(r, "accID")
|
||||||
if acc.ID != accID {
|
if acc.ID != accID {
|
||||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param")))
|
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID())
|
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.linker.LinkOrdersByAccountID(ctx, orders)
|
||||||
|
|
||||||
api.JSON(w, orders)
|
api.JSON(w, orders)
|
||||||
logOrdersByAccount(w, orders)
|
logOrdersByAccount(w, orders)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
@ -29,11 +28,11 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func newProv() provisioner.Interface {
|
func newProv() acme.Provisioner {
|
||||||
// Initialize provisioners
|
// Initialize provisioners
|
||||||
p := &provisioner.ACME{
|
p := &provisioner.ACME{
|
||||||
Type: "ACME",
|
Type: "ACME",
|
||||||
Name: "test@acme-provisioner.com",
|
Name: "test@acme-<test>provisioner.com",
|
||||||
}
|
}
|
||||||
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||||
fmt.Printf("%v", err)
|
fmt.Printf("%v", err)
|
||||||
|
@ -41,7 +40,7 @@ func newProv() provisioner.Interface {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewAccountRequestValidate(t *testing.T) {
|
func TestNewAccountRequest_Validate(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
nar *NewAccountRequest
|
nar *NewAccountRequest
|
||||||
err *acme.Error
|
err *acme.Error
|
||||||
|
@ -53,7 +52,7 @@ func TestNewAccountRequestValidate(t *testing.T) {
|
||||||
OnlyReturnExisting: true,
|
OnlyReturnExisting: true,
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"foo", "bar"},
|
||||||
},
|
},
|
||||||
err: acme.MalformedErr(errors.Errorf("incompatible input; onlyReturnExisting must be alone")),
|
err: acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/bad-contact": func(t *testing.T) test {
|
"fail/bad-contact": func(t *testing.T) test {
|
||||||
|
@ -61,7 +60,7 @@ func TestNewAccountRequestValidate(t *testing.T) {
|
||||||
nar: &NewAccountRequest{
|
nar: &NewAccountRequest{
|
||||||
Contact: []string{"foo", ""},
|
Contact: []string{"foo", ""},
|
||||||
},
|
},
|
||||||
err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
|
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
|
@ -97,7 +96,7 @@ func TestNewAccountRequestValidate(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateAccountRequestValidate(t *testing.T) {
|
func TestUpdateAccountRequest_Validate(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
uar *UpdateAccountRequest
|
uar *UpdateAccountRequest
|
||||||
err *acme.Error
|
err *acme.Error
|
||||||
|
@ -109,8 +108,8 @@ func TestUpdateAccountRequestValidate(t *testing.T) {
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"foo", "bar"},
|
||||||
Status: "foo",
|
Status: "foo",
|
||||||
},
|
},
|
||||||
err: acme.MalformedErr(errors.Errorf("incompatible input; " +
|
err: acme.NewError(acme.ErrorMalformedType, "incompatible input; "+
|
||||||
"contact and status updates are mutually exclusive")),
|
"contact and status updates are mutually exclusive"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/bad-contact": func(t *testing.T) test {
|
"fail/bad-contact": func(t *testing.T) test {
|
||||||
|
@ -118,7 +117,7 @@ func TestUpdateAccountRequestValidate(t *testing.T) {
|
||||||
uar: &UpdateAccountRequest{
|
uar: &UpdateAccountRequest{
|
||||||
Contact: []string{"foo", ""},
|
Contact: []string{"foo", ""},
|
||||||
},
|
},
|
||||||
err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
|
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/bad-status": func(t *testing.T) test {
|
"fail/bad-status": func(t *testing.T) test {
|
||||||
|
@ -126,8 +125,8 @@ func TestUpdateAccountRequestValidate(t *testing.T) {
|
||||||
uar: &UpdateAccountRequest{
|
uar: &UpdateAccountRequest{
|
||||||
Status: "foo",
|
Status: "foo",
|
||||||
},
|
},
|
||||||
err: acme.MalformedErr(errors.Errorf("cannot update account " +
|
err: acme.NewError(acme.ErrorMalformedType, "cannot update account "+
|
||||||
"status to foo, only deactivated")),
|
"status to foo, only deactivated"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/contact": func(t *testing.T) test {
|
"ok/contact": func(t *testing.T) test {
|
||||||
|
@ -168,81 +167,81 @@ func TestUpdateAccountRequestValidate(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerGetOrdersByAccount(t *testing.T) {
|
func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
||||||
oids := []string{
|
|
||||||
"https://ca.smallstep.com/acme/order/foo",
|
|
||||||
"https://ca.smallstep.com/acme/order/bar",
|
|
||||||
}
|
|
||||||
accID := "account-id"
|
accID := "account-id"
|
||||||
prov := newProv()
|
|
||||||
|
|
||||||
// Request with chi context
|
// Request with chi context
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("accID", accID)
|
chiCtx.URLParams.Add("accID", accID)
|
||||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID)
|
|
||||||
|
prov := newProv()
|
||||||
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
|
||||||
|
|
||||||
|
oids := []string{"foo", "bar"}
|
||||||
|
oidURLs := []string{
|
||||||
|
fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), provName),
|
||||||
|
fmt.Sprintf("%s/acme/%s/order/bar", baseURL.String(), provName),
|
||||||
|
}
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
auth acme.Interface
|
db acme.DB
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
statusCode int
|
statusCode int
|
||||||
problem *acme.Error
|
err *acme.Error
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{},
|
db: &acme.MockDB{},
|
||||||
ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
|
ctx: context.Background(),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, nil)
|
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{},
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: context.WithValue(context.Background(), accContextKey, nil),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "foo"}
|
acc := &acme.Account{ID: "foo"}
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{},
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 401,
|
statusCode: 401,
|
||||||
problem: acme.UnauthorizedErr(errors.New("account ID does not match url param")),
|
err: acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/getOrdersByAccount-error": func(t *testing.T) test {
|
"fail/db.GetOrdersByAccountID-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: accID}
|
acc := &acme.Account{ID: accID}
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
err: acme.ServerInternalErr(errors.New("force")),
|
MockError: acme.NewErrorISE("force"),
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("force")),
|
err: acme.NewErrorISE("force"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: accID}
|
acc := &acme.Account{ID: accID}
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
getOrdersByAccount: func(ctx context.Context, id string) ([]string, error) {
|
MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
|
||||||
p, err := acme.ProvisionerFromContext(ctx)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
assert.Equals(t, p, prov)
|
|
||||||
assert.Equals(t, id, acc.ID)
|
assert.Equals(t, id, acc.ID)
|
||||||
return oids, nil
|
return oids, nil
|
||||||
},
|
},
|
||||||
|
@ -255,11 +254,11 @@ func TestHandlerGetOrdersByAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := New(tc.auth).(*Handler)
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetOrdersByAccount(w, req)
|
h.GetOrdersByAccountID(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -268,18 +267,17 @@ func TestHandlerGetOrdersByAccount(t *testing.T) {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
var ae acme.AError
|
var ae acme.Error
|
||||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
prob := tc.problem.ToACME()
|
|
||||||
|
|
||||||
assert.Equals(t, ae.Type, prob.Type)
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
assert.Equals(t, ae.Detail, prob.Detail)
|
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
} else {
|
} else {
|
||||||
expB, err := json.Marshal(oids)
|
expB, err := json.Marshal(oidURLs)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
@ -288,47 +286,41 @@ func TestHandlerGetOrdersByAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerNewAccount(t *testing.T) {
|
func TestHandler_NewAccount(t *testing.T) {
|
||||||
accID := "accountID"
|
|
||||||
acc := acme.Account{
|
|
||||||
ID: accID,
|
|
||||||
Status: "valid",
|
|
||||||
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
|
|
||||||
}
|
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
provName := url.PathEscape(prov.GetName())
|
escProvName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
auth acme.Interface
|
db acme.DB
|
||||||
|
acc *acme.Account
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
statusCode int
|
statusCode int
|
||||||
problem *acme.Error
|
err *acme.Error
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-payload": func(t *testing.T) test {
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
|
ctx: context.Background(),
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-payload": func(t *testing.T) test {
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), payloadContextKey, nil)
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{})
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
|
err: acme.NewError(acme.ErrorMalformedType, "failed to "+
|
||||||
|
"unmarshal new-account request payload: unexpected end of JSON input"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/malformed-payload-error": func(t *testing.T) test {
|
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||||
|
@ -337,12 +329,11 @@ func TestHandlerNewAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nar)
|
b, err := json.Marshal(nar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
|
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/no-existing-account": func(t *testing.T) test {
|
"fail/no-existing-account": func(t *testing.T) test {
|
||||||
|
@ -351,12 +342,11 @@ func TestHandlerNewAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nar)
|
b, err := json.Marshal(nar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/no-jwk": func(t *testing.T) test {
|
"fail/no-jwk": func(t *testing.T) test {
|
||||||
|
@ -365,12 +355,11 @@ func TestHandlerNewAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nar)
|
b, err := json.Marshal(nar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
|
err: acme.NewErrorISE("jwk expected in request context"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-jwk": func(t *testing.T) test {
|
"fail/nil-jwk": func(t *testing.T) test {
|
||||||
|
@ -379,16 +368,15 @@ func TestHandlerNewAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nar)
|
b, err := json.Marshal(nar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||||
ctx = context.WithValue(ctx, acme.JwkContextKey, nil)
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
|
err: acme.NewErrorISE("jwk expected in request context"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/NewAccount-error": func(t *testing.T) test {
|
"fail/db.CreateAccount-error": func(t *testing.T) test {
|
||||||
nar := &NewAccountRequest{
|
nar := &NewAccountRequest{
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"foo", "bar"},
|
||||||
}
|
}
|
||||||
|
@ -396,23 +384,19 @@ func TestHandlerNewAccount(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, acme.JwkContextKey, jwk)
|
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) {
|
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
||||||
p, err := acme.ProvisionerFromContext(ctx)
|
assert.Equals(t, acc.Contact, nar.Contact)
|
||||||
assert.FatalError(t, err)
|
assert.Equals(t, acc.Key, jwk)
|
||||||
assert.Equals(t, p, prov)
|
return acme.NewErrorISE("force")
|
||||||
assert.Equals(t, ops.Contact, nar.Contact)
|
|
||||||
assert.Equals(t, ops.Key, jwk)
|
|
||||||
return nil, acme.ServerInternalErr(errors.New("force"))
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("force")),
|
err: acme.NewErrorISE("force"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/new-account": func(t *testing.T) test {
|
"ok/new-account": func(t *testing.T) test {
|
||||||
|
@ -423,28 +407,25 @@ func TestHandlerNewAccount(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
ctx = context.WithValue(ctx, acme.JwkContextKey, jwk)
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
|
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) {
|
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
||||||
p, err := acme.ProvisionerFromContext(ctx)
|
acc.ID = "accountID"
|
||||||
assert.FatalError(t, err)
|
assert.Equals(t, acc.Contact, nar.Contact)
|
||||||
assert.Equals(t, p, prov)
|
assert.Equals(t, acc.Key, jwk)
|
||||||
assert.Equals(t, ops.Contact, nar.Contact)
|
return nil
|
||||||
assert.Equals(t, ops.Key, jwk)
|
|
||||||
return &acc, nil
|
|
||||||
},
|
},
|
||||||
getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
|
|
||||||
assert.Equals(t, typ, acme.AccountLink)
|
|
||||||
assert.True(t, abs)
|
|
||||||
assert.True(t, abs)
|
|
||||||
assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx))
|
|
||||||
return fmt.Sprintf("%s/acme/%s/account/%s",
|
|
||||||
baseURL.String(), provName, accID)
|
|
||||||
},
|
},
|
||||||
|
acc: &acme.Account{
|
||||||
|
ID: "accountID",
|
||||||
|
Key: jwk,
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName),
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 201,
|
statusCode: 201,
|
||||||
|
@ -456,22 +437,21 @@ func TestHandlerNewAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(nar)
|
b, err := json.Marshal(nar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
assert.FatalError(t, err)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
acc := &acme.Account{
|
||||||
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
|
ID: "accountID",
|
||||||
|
Key: jwk,
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
|
||||||
getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
|
|
||||||
assert.Equals(t, typ, acme.AccountLink)
|
|
||||||
assert.True(t, abs)
|
|
||||||
assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx))
|
|
||||||
assert.Equals(t, ins, []string{accID})
|
|
||||||
return fmt.Sprintf("%s/acme/%s/account/%s",
|
|
||||||
baseURL.String(), provName, accID)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
acc: acc,
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -479,7 +459,7 @@ func TestHandlerNewAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := New(tc.auth).(*Handler)
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -492,90 +472,85 @@ func TestHandlerNewAccount(t *testing.T) {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
var ae acme.AError
|
var ae acme.Error
|
||||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
prob := tc.problem.ToACME()
|
|
||||||
|
|
||||||
assert.Equals(t, ae.Type, prob.Type)
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
assert.Equals(t, ae.Detail, prob.Detail)
|
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
} else {
|
} else {
|
||||||
expB, err := json.Marshal(acc)
|
expB, err := json.Marshal(tc.acc)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Location"],
|
assert.Equals(t, res.Header["Location"],
|
||||||
[]string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(),
|
[]string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(),
|
||||||
provName, accID)})
|
escProvName, "accountID")})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerGetUpdateAccount(t *testing.T) {
|
func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
||||||
accID := "accountID"
|
accID := "accountID"
|
||||||
acc := acme.Account{
|
acc := acme.Account{
|
||||||
ID: accID,
|
ID: accID,
|
||||||
Status: "valid",
|
Status: "valid",
|
||||||
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
|
OrdersURL: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
|
||||||
}
|
}
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
provName := url.PathEscape(prov.GetName())
|
escProvName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
auth acme.Interface
|
db acme.DB
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
statusCode int
|
statusCode int
|
||||||
problem *acme.Error
|
err *acme.Error
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
|
ctx: context.Background(),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, nil)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, nil)
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/no-payload": func(t *testing.T) test {
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-payload": func(t *testing.T) test {
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{})
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
|
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/malformed-payload-error": func(t *testing.T) test {
|
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||||
|
@ -584,62 +559,33 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(uar)
|
b, err := json.Marshal(uar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
|
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/Deactivate-error": func(t *testing.T) test {
|
"fail/db.UpdateAccount-error": func(t *testing.T) test {
|
||||||
uar := &UpdateAccountRequest{
|
uar := &UpdateAccountRequest{
|
||||||
Status: "deactivated",
|
Status: "deactivated",
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(uar)
|
b, err := json.Marshal(uar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||||
p, err := acme.ProvisionerFromContext(ctx)
|
assert.Equals(t, upd.Status, acme.StatusDeactivated)
|
||||||
assert.FatalError(t, err)
|
assert.Equals(t, upd.ID, acc.ID)
|
||||||
assert.Equals(t, p, prov)
|
return acme.NewErrorISE("force")
|
||||||
assert.Equals(t, id, accID)
|
|
||||||
return nil, acme.ServerInternalErr(errors.New("force"))
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("force")),
|
err: acme.NewErrorISE("force"),
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/UpdateAccount-error": func(t *testing.T) test {
|
|
||||||
uar := &UpdateAccountRequest{
|
|
||||||
Contact: []string{"foo", "bar"},
|
|
||||||
}
|
|
||||||
b, err := json.Marshal(uar)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
|
||||||
return test{
|
|
||||||
auth: &mockAcmeAuthority{
|
|
||||||
updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) {
|
|
||||||
p, err := acme.ProvisionerFromContext(ctx)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
assert.Equals(t, p, prov)
|
|
||||||
assert.Equals(t, id, accID)
|
|
||||||
assert.Equals(t, contacts, uar.Contact)
|
|
||||||
return nil, acme.ServerInternalErr(errors.New("force"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ctx: ctx,
|
|
||||||
statusCode: 500,
|
|
||||||
problem: acme.ServerInternalErr(errors.New("force")),
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/deactivate": func(t *testing.T) test {
|
"ok/deactivate": func(t *testing.T) test {
|
||||||
|
@ -648,26 +594,16 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(uar)
|
b, err := json.Marshal(uar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||||
p, err := acme.ProvisionerFromContext(ctx)
|
assert.Equals(t, upd.Status, acme.StatusDeactivated)
|
||||||
assert.FatalError(t, err)
|
assert.Equals(t, upd.ID, acc.ID)
|
||||||
assert.Equals(t, p, prov)
|
return nil
|
||||||
assert.Equals(t, id, accID)
|
|
||||||
return &acc, nil
|
|
||||||
},
|
|
||||||
getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
|
|
||||||
assert.Equals(t, typ, acme.AccountLink)
|
|
||||||
assert.True(t, abs)
|
|
||||||
assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
|
|
||||||
assert.Equals(t, ins, []string{accID})
|
|
||||||
return fmt.Sprintf("%s/acme/%s/account/%s",
|
|
||||||
baseURL.String(), provName, accID)
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -678,21 +614,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||||
uar := &UpdateAccountRequest{}
|
uar := &UpdateAccountRequest{}
|
||||||
b, err := json.Marshal(uar)
|
b, err := json.Marshal(uar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
|
||||||
getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
|
|
||||||
assert.Equals(t, typ, acme.AccountLink)
|
|
||||||
assert.True(t, abs)
|
|
||||||
assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
|
|
||||||
assert.Equals(t, ins, []string{accID})
|
|
||||||
return fmt.Sprintf("%s/acme/%s/account/%s",
|
|
||||||
baseURL.String(), provName, accID)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
}
|
}
|
||||||
|
@ -703,27 +629,16 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(uar)
|
b, err := json.Marshal(uar)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) {
|
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||||
p, err := acme.ProvisionerFromContext(ctx)
|
assert.Equals(t, upd.Contact, uar.Contact)
|
||||||
assert.FatalError(t, err)
|
assert.Equals(t, upd.ID, acc.ID)
|
||||||
assert.Equals(t, p, prov)
|
return nil
|
||||||
assert.Equals(t, id, accID)
|
|
||||||
assert.Equals(t, contacts, uar.Contact)
|
|
||||||
return &acc, nil
|
|
||||||
},
|
|
||||||
getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
|
|
||||||
assert.Equals(t, typ, acme.AccountLink)
|
|
||||||
assert.True(t, abs)
|
|
||||||
assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
|
|
||||||
assert.Equals(t, ins, []string{accID})
|
|
||||||
return fmt.Sprintf("%s/acme/%s/account/%s",
|
|
||||||
baseURL.String(), provName, accID)
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -731,21 +646,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/post-as-get": func(t *testing.T) test {
|
"ok/post-as-get": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||||
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
|
||||||
getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
|
|
||||||
assert.Equals(t, typ, acme.AccountLink)
|
|
||||||
assert.True(t, abs)
|
|
||||||
assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
|
|
||||||
assert.Equals(t, ins, []string{accID})
|
|
||||||
return fmt.Sprintf("%s/acme/%s/account/%s",
|
|
||||||
baseURL, provName, accID)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
}
|
}
|
||||||
|
@ -754,11 +659,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := New(tc.auth).(*Handler)
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetUpdateAccount(w, req)
|
h.GetOrUpdateAccount(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -767,15 +672,14 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
var ae acme.AError
|
var ae acme.Error
|
||||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
prob := tc.problem.ToACME()
|
|
||||||
|
|
||||||
assert.Equals(t, ae.Type, prob.Type)
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
assert.Equals(t, ae.Detail, prob.Detail)
|
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
} else {
|
} else {
|
||||||
expB, err := json.Marshal(acc)
|
expB, err := json.Marshal(acc)
|
||||||
|
@ -783,7 +687,7 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Location"],
|
assert.Equals(t, res.Header["Location"],
|
||||||
[]string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(),
|
[]string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(),
|
||||||
provName, accID)})
|
escProvName, accID)})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,56 +1,98 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
)
|
)
|
||||||
|
|
||||||
func link(url, typ string) string {
|
func link(url, typ string) string {
|
||||||
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ)
|
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clock that returns time in UTC rounded to seconds.
|
||||||
|
type Clock struct{}
|
||||||
|
|
||||||
|
// Now returns the UTC time rounded to seconds.
|
||||||
|
func (c *Clock) Now() time.Time {
|
||||||
|
return time.Now().UTC().Truncate(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
var clock Clock
|
||||||
|
|
||||||
type payloadInfo struct {
|
type payloadInfo struct {
|
||||||
value []byte
|
value []byte
|
||||||
isPostAsGet bool
|
isPostAsGet bool
|
||||||
isEmptyJSON bool
|
isEmptyJSON bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// payloadFromContext searches the context for a payload. Returns the payload
|
// Handler is the ACME API request handler.
|
||||||
// or an error.
|
|
||||||
func payloadFromContext(ctx context.Context) (*payloadInfo, error) {
|
|
||||||
val, ok := ctx.Value(acme.PayloadContextKey).(*payloadInfo)
|
|
||||||
if !ok || val == nil {
|
|
||||||
return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context"))
|
|
||||||
}
|
|
||||||
return val, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// New returns a new ACME API router.
|
|
||||||
func New(acmeAuth acme.Interface) api.RouterHandler {
|
|
||||||
return &Handler{acmeAuth}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handler is the ACME request handler.
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
Auth acme.Interface
|
db acme.DB
|
||||||
|
backdate provisioner.Duration
|
||||||
|
ca acme.CertificateAuthority
|
||||||
|
linker Linker
|
||||||
|
validateChallengeOptions *acme.ValidateChallengeOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlerOptions required to create a new ACME API request handler.
|
||||||
|
type HandlerOptions struct {
|
||||||
|
Backdate provisioner.Duration
|
||||||
|
// DB storage backend that impements the acme.DB interface.
|
||||||
|
DB acme.DB
|
||||||
|
// DNS the host used to generate accurate ACME links. By default the authority
|
||||||
|
// will use the Host from the request, so this value will only be used if
|
||||||
|
// request.Host is empty.
|
||||||
|
DNS string
|
||||||
|
// Prefix is a URL path prefix under which the ACME api is served. This
|
||||||
|
// prefix is required to generate accurate ACME links.
|
||||||
|
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
|
||||||
|
// "acme" is the prefix from which the ACME api is accessed.
|
||||||
|
Prefix string
|
||||||
|
CA acme.CertificateAuthority
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler returns a new ACME API handler.
|
||||||
|
func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||||
|
client := http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
return &Handler{
|
||||||
|
ca: ops.CA,
|
||||||
|
db: ops.DB,
|
||||||
|
backdate: ops.Backdate,
|
||||||
|
linker: NewLinker(ops.DNS, ops.Prefix),
|
||||||
|
validateChallengeOptions: &acme.ValidateChallengeOptions{
|
||||||
|
HTTPGet: client.Get,
|
||||||
|
LookupTxt: net.LookupTXT,
|
||||||
|
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||||
|
return tls.DialWithDialer(dialer, network, addr, config)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route traffic and implement the Router interface.
|
// Route traffic and implement the Router interface.
|
||||||
func (h *Handler) Route(r api.Router) {
|
func (h *Handler) Route(r api.Router) {
|
||||||
getLink := h.Auth.GetLinkExplicit
|
getPath := h.linker.GetUnescapedPathSuffix
|
||||||
// Standard ACME API
|
// Standard ACME API
|
||||||
r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce))))
|
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
|
||||||
r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce))))
|
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
|
||||||
r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
|
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
|
||||||
r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
|
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
|
||||||
|
|
||||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||||
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))))
|
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))))
|
||||||
|
@ -59,16 +101,16 @@ func (h *Handler) Route(r api.Router) {
|
||||||
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))))
|
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))))
|
||||||
}
|
}
|
||||||
|
|
||||||
r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount))
|
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
|
||||||
r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount))
|
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
|
||||||
r.MethodFunc("POST", getLink(acme.KeyChangeLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented))
|
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented))
|
||||||
r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder))
|
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder))
|
||||||
r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
|
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
|
||||||
r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount)))
|
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID)))
|
||||||
r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
|
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
|
||||||
r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz)))
|
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
|
||||||
r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, nil, "{chID}"), extractPayloadByKid(h.GetChallenge))
|
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
|
||||||
r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNonce just sets the right header since a Nonce is added to each response
|
// GetNonce just sets the right header since a Nonce is added to each response
|
||||||
|
@ -81,101 +123,153 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Directory represents an ACME directory for configuring clients.
|
||||||
|
type Directory struct {
|
||||||
|
NewNonce string `json:"newNonce"`
|
||||||
|
NewAccount string `json:"newAccount"`
|
||||||
|
NewOrder string `json:"newOrder"`
|
||||||
|
RevokeCert string `json:"revokeCert"`
|
||||||
|
KeyChange string `json:"keyChange"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToLog enables response logging for the Directory type.
|
||||||
|
func (d *Directory) ToLog() (interface{}, error) {
|
||||||
|
b, err := json.Marshal(d)
|
||||||
|
if err != nil {
|
||||||
|
return nil, acme.WrapErrorISE(err, "error marshaling directory for logging")
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetDirectory is the ACME resource for returning a directory configuration
|
// GetDirectory is the ACME resource for returning a directory configuration
|
||||||
// for client configuration.
|
// for client configuration.
|
||||||
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||||
dir, err := h.Auth.GetDirectory(r.Context())
|
ctx := r.Context()
|
||||||
if err != nil {
|
api.JSON(w, &Directory{
|
||||||
api.WriteError(w, err)
|
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
|
||||||
return
|
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
|
||||||
}
|
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
|
||||||
api.JSON(w, dir)
|
RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType),
|
||||||
|
KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
||||||
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
||||||
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
|
||||||
api.WriteError(w, acme.NotImplemented(nil).ToACME())
|
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuthz ACME api for retrieving an Authz.
|
// GetAuthorization ACME api for retrieving an Authz.
|
||||||
func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||||
acc, err := acme.AccountFromContext(r.Context())
|
ctx := r.Context()
|
||||||
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authz, err := h.Auth.GetAuthz(r.Context(), acc.GetID(), chi.URLParam(r, "authzID"))
|
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if acc.ID != az.AccountID {
|
||||||
|
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
|
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = az.UpdateStatus(ctx, h.db); err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AuthzLink, true, authz.GetID()))
|
h.linker.LinkAuthorization(ctx, az)
|
||||||
api.JSON(w, authz)
|
|
||||||
|
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
|
||||||
|
api.JSON(w, az)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetChallenge ACME api for retrieving a Challenge.
|
// GetChallenge ACME api for retrieving a Challenge.
|
||||||
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||||
acc, err := acme.AccountFromContext(r.Context())
|
ctx := r.Context()
|
||||||
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Just verify that the payload was set, since we're not strictly adhering
|
// Just verify that the payload was set, since we're not strictly adhering
|
||||||
// to ACME V2 spec for reasons specified below.
|
// to ACME V2 spec for reasons specified below.
|
||||||
_, err = payloadFromContext(r.Context())
|
_, err = payloadFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: We should be checking that the request is either a POST-as-GET, or
|
// NOTE: We should be checking ^^^ that the request is either a POST-as-GET, or
|
||||||
// that the payload is an empty JSON block ({}). However, older ACME clients
|
// that the payload is an empty JSON block ({}). However, older ACME clients
|
||||||
// still send a vestigial body (rather than an empty JSON block) and
|
// still send a vestigial body (rather than an empty JSON block) and
|
||||||
// strict enforcement would render these clients broken. For the time being
|
// strict enforcement would render these clients broken. For the time being
|
||||||
// we'll just ignore the body.
|
// we'll just ignore the body.
|
||||||
var (
|
|
||||||
ch *acme.Challenge
|
azID := chi.URLParam(r, "authzID")
|
||||||
chID = chi.URLParam(r, "chID")
|
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||||
)
|
if err != nil {
|
||||||
ch, err = h.Auth.ValidateChallenge(r.Context(), acc.GetID(), chID, acc.GetKey())
|
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ch.AuthorizationID = azID
|
||||||
|
if acc.ID != ch.AccountID {
|
||||||
|
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
|
"account '%s' does not own challenge '%s'", acc.ID, ch.ID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jwk, err := jwkFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), acme.AuthzLink, true, ch.GetAuthzID()), "up"))
|
h.linker.LinkChallenge(ctx, ch, azID)
|
||||||
w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.ChallengeLink, true, ch.GetID()))
|
|
||||||
|
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
|
||||||
|
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
|
||||||
api.JSON(w, ch)
|
api.JSON(w, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCertificate ACME api for retrieving a Certificate.
|
// GetCertificate ACME api for retrieving a Certificate.
|
||||||
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||||
acc, err := acme.AccountFromContext(r.Context())
|
ctx := r.Context()
|
||||||
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
certID := chi.URLParam(r, "certID")
|
certID := chi.URLParam(r, "certID")
|
||||||
certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID)
|
|
||||||
|
cert, err := h.db.GetCertificate(ctx, certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cert.AccountID != acc.ID {
|
||||||
|
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
|
"account '%s' does not own certificate '%s'", acc.ID, certID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
block, _ := pem.Decode(certBytes)
|
var certBytes []byte
|
||||||
if block == nil {
|
for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) {
|
||||||
api.WriteError(w, acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes")))
|
certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{
|
||||||
return
|
Type: "CERTIFICATE",
|
||||||
}
|
Bytes: c.Raw,
|
||||||
cert, err := x509.ParseCertificate(block.Bytes)
|
})...)
|
||||||
if err != nil {
|
|
||||||
api.WriteError(w, acme.Wrap(err, "failed to parse generated leaf certificate"))
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
api.LogCertificate(w, cert)
|
api.LogCertificate(w, cert.Leaf)
|
||||||
w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
|
w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
|
||||||
w.Write(certBytes)
|
w.Write(certBytes)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -17,206 +18,11 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
|
||||||
"github.com/smallstep/certificates/db"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockAcmeAuthority struct {
|
func TestHandler_GetNonce(t *testing.T) {
|
||||||
getLink func(ctx context.Context, link acme.Link, absPath bool, ins ...string) string
|
|
||||||
getLinkExplicit func(acme.Link, string, bool, *url.URL, ...string) string
|
|
||||||
|
|
||||||
deactivateAccount func(ctx context.Context, accID string) (*acme.Account, error)
|
|
||||||
getAccount func(ctx context.Context, accID string) (*acme.Account, error)
|
|
||||||
getAccountByKey func(ctx context.Context, key *jose.JSONWebKey) (*acme.Account, error)
|
|
||||||
newAccount func(ctx context.Context, ao acme.AccountOptions) (*acme.Account, error)
|
|
||||||
updateAccount func(context.Context, string, []string) (*acme.Account, error)
|
|
||||||
|
|
||||||
getChallenge func(ctx context.Context, accID string, chID string) (*acme.Challenge, error)
|
|
||||||
validateChallenge func(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*acme.Challenge, error)
|
|
||||||
getAuthz func(ctx context.Context, accID string, authzID string) (*acme.Authz, error)
|
|
||||||
getDirectory func(ctx context.Context) (*acme.Directory, error)
|
|
||||||
getCertificate func(string, string) ([]byte, error)
|
|
||||||
|
|
||||||
finalizeOrder func(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*acme.Order, error)
|
|
||||||
getOrder func(ctx context.Context, accID string, orderID string) (*acme.Order, error)
|
|
||||||
getOrdersByAccount func(ctx context.Context, accID string) ([]string, error)
|
|
||||||
newOrder func(ctx context.Context, oo acme.OrderOptions) (*acme.Order, error)
|
|
||||||
|
|
||||||
loadProvisionerByID func(string) (provisioner.Interface, error)
|
|
||||||
newNonce func() (string, error)
|
|
||||||
useNonce func(string) error
|
|
||||||
ret1 interface{}
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) DeactivateAccount(ctx context.Context, id string) (*acme.Account, error) {
|
|
||||||
if m.deactivateAccount != nil {
|
|
||||||
return m.deactivateAccount(ctx, id)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Account), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) FinalizeOrder(ctx context.Context, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) {
|
|
||||||
if m.finalizeOrder != nil {
|
|
||||||
return m.finalizeOrder(ctx, accID, id, csr)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Order), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
|
|
||||||
if m.getAccount != nil {
|
|
||||||
return m.getAccount(ctx, id)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Account), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
|
|
||||||
if m.getAccountByKey != nil {
|
|
||||||
return m.getAccountByKey(ctx, jwk)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Account), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) GetAuthz(ctx context.Context, accID, id string) (*acme.Authz, error) {
|
|
||||||
if m.getAuthz != nil {
|
|
||||||
return m.getAuthz(ctx, accID, id)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Authz), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) GetCertificate(accID string, id string) ([]byte, error) {
|
|
||||||
if m.getCertificate != nil {
|
|
||||||
return m.getCertificate(accID, id)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.([]byte), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) GetChallenge(ctx context.Context, accID, id string) (*acme.Challenge, error) {
|
|
||||||
if m.getChallenge != nil {
|
|
||||||
return m.getChallenge(ctx, accID, id)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Challenge), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) GetDirectory(ctx context.Context) (*acme.Directory, error) {
|
|
||||||
if m.getDirectory != nil {
|
|
||||||
return m.getDirectory(ctx)
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Directory), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) GetLink(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
|
|
||||||
if m.getLink != nil {
|
|
||||||
return m.getLink(ctx, typ, abs, ins...)
|
|
||||||
}
|
|
||||||
return m.ret1.(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) GetLinkExplicit(typ acme.Link, provID string, abs bool, baseURL *url.URL, ins ...string) string {
|
|
||||||
if m.getLinkExplicit != nil {
|
|
||||||
return m.getLinkExplicit(typ, provID, abs, baseURL, ins...)
|
|
||||||
}
|
|
||||||
return m.ret1.(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) GetOrder(ctx context.Context, accID, id string) (*acme.Order, error) {
|
|
||||||
if m.getOrder != nil {
|
|
||||||
return m.getOrder(ctx, accID, id)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Order), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) {
|
|
||||||
if m.getOrdersByAccount != nil {
|
|
||||||
return m.getOrdersByAccount(ctx, id)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.([]string), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) {
|
|
||||||
if m.loadProvisionerByID != nil {
|
|
||||||
return m.loadProvisionerByID(provID)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(provisioner.Interface), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) NewAccount(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) {
|
|
||||||
if m.newAccount != nil {
|
|
||||||
return m.newAccount(ctx, ops)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Account), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) NewNonce() (string, error) {
|
|
||||||
if m.newNonce != nil {
|
|
||||||
return m.newNonce()
|
|
||||||
} else if m.err != nil {
|
|
||||||
return "", m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(string), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) NewOrder(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) {
|
|
||||||
if m.newOrder != nil {
|
|
||||||
return m.newOrder(ctx, ops)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Order), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) UpdateAccount(ctx context.Context, id string, contact []string) (*acme.Account, error) {
|
|
||||||
if m.updateAccount != nil {
|
|
||||||
return m.updateAccount(ctx, id, contact)
|
|
||||||
} else if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.(*acme.Account), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) UseNonce(nonce string) error {
|
|
||||||
if m.useNonce != nil {
|
|
||||||
return m.useNonce(nonce)
|
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAcmeAuthority) ValidateChallenge(ctx context.Context, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
|
|
||||||
switch {
|
|
||||||
case m.validateChallenge != nil:
|
|
||||||
return m.validateChallenge(ctx, accID, id, jwk)
|
|
||||||
case m.err != nil:
|
|
||||||
return nil, m.err
|
|
||||||
default:
|
|
||||||
return m.ret1.(*acme.Challenge), m.err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerGetNonce(t *testing.T) {
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
statusCode int
|
statusCode int
|
||||||
|
@ -230,7 +36,7 @@ func TestHandlerGetNonce(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(nil).(*Handler)
|
h := &Handler{}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
req.Method = tt.name
|
req.Method = tt.name
|
||||||
h.GetNonce(w, req)
|
h.GetNonce(w, req)
|
||||||
|
@ -243,21 +49,16 @@ func TestHandlerGetNonce(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerGetDirectory(t *testing.T) {
|
func TestHandler_GetDirectory(t *testing.T) {
|
||||||
auth, err := acme.New(nil, acme.AuthorityOptions{
|
linker := NewLinker("ca.smallstep.com", "acme")
|
||||||
DB: new(db.MockNoSQLDB),
|
|
||||||
DNS: "ca.smallstep.com",
|
|
||||||
Prefix: "acme",
|
|
||||||
})
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
|
||||||
expDir := acme.Directory{
|
expDir := Directory{
|
||||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||||
NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
|
NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
|
||||||
|
@ -267,7 +68,7 @@ func TestHandlerGetDirectory(t *testing.T) {
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
statusCode int
|
statusCode int
|
||||||
problem *acme.Error
|
err *acme.Error
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
|
@ -279,7 +80,7 @@ func TestHandlerGetDirectory(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := New(auth).(*Handler)
|
h := &Handler{linker: linker}
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -292,18 +93,17 @@ func TestHandlerGetDirectory(t *testing.T) {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
var ae acme.AError
|
var ae acme.Error
|
||||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
prob := tc.problem.ToACME()
|
|
||||||
|
|
||||||
assert.Equals(t, ae.Type, prob.Type)
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
assert.Equals(t, ae.Detail, prob.Detail)
|
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
} else {
|
} else {
|
||||||
var dir acme.Directory
|
var dir Directory
|
||||||
json.Unmarshal(bytes.TrimSpace(body), &dir)
|
json.Unmarshal(bytes.TrimSpace(body), &dir)
|
||||||
assert.Equals(t, dir, expDir)
|
assert.Equals(t, dir, expDir)
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
@ -312,16 +112,17 @@ func TestHandlerGetDirectory(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerGetAuthz(t *testing.T) {
|
func TestHandler_GetAuthorization(t *testing.T) {
|
||||||
expiry := time.Now().UTC().Add(6 * time.Hour)
|
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||||
az := acme.Authz{
|
az := acme.Authorization{
|
||||||
ID: "authzID",
|
ID: "authzID",
|
||||||
|
AccountID: "accID",
|
||||||
Identifier: acme.Identifier{
|
Identifier: acme.Identifier{
|
||||||
Type: "dns",
|
Type: "dns",
|
||||||
Value: "example.com",
|
Value: "example.com",
|
||||||
},
|
},
|
||||||
Status: "pending",
|
Status: "pending",
|
||||||
Expires: expiry.Format(time.RFC3339),
|
ExpiresAt: expiry,
|
||||||
Wildcard: false,
|
Wildcard: false,
|
||||||
Challenges: []*acme.Challenge{
|
Challenges: []*acme.Challenge{
|
||||||
{
|
{
|
||||||
|
@ -330,7 +131,6 @@ func TestHandlerGetAuthz(t *testing.T) {
|
||||||
Token: "tok2",
|
Token: "tok2",
|
||||||
URL: "https://ca.smallstep.com/acme/challenge/chHTTPID",
|
URL: "https://ca.smallstep.com/acme/challenge/chHTTPID",
|
||||||
ID: "chHTTP01ID",
|
ID: "chHTTP01ID",
|
||||||
AuthzID: "authzID",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Type: "dns-01",
|
Type: "dns-01",
|
||||||
|
@ -338,7 +138,6 @@ func TestHandlerGetAuthz(t *testing.T) {
|
||||||
Token: "tok2",
|
Token: "tok2",
|
||||||
URL: "https://ca.smallstep.com/acme/challenge/chDNSID",
|
URL: "https://ca.smallstep.com/acme/challenge/chDNSID",
|
||||||
ID: "chDNSID",
|
ID: "chDNSID",
|
||||||
AuthzID: "authzID",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -349,71 +148,101 @@ func TestHandlerGetAuthz(t *testing.T) {
|
||||||
// Request with chi context
|
// Request with chi context
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("authzID", az.ID)
|
chiCtx.URLParams.Add("authzID", az.ID)
|
||||||
url := fmt.Sprintf("%s/acme/%s/challenge/%s",
|
url := fmt.Sprintf("%s/acme/%s/authz/%s",
|
||||||
baseURL.String(), provName, az.ID)
|
baseURL.String(), provName, az.ID)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
auth acme.Interface
|
db acme.DB
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
statusCode int
|
statusCode int
|
||||||
problem *acme.Error
|
err *acme.Error
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{},
|
db: &acme.MockDB{},
|
||||||
ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
|
ctx: context.Background(),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, nil)
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{},
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/getAuthz-error": func(t *testing.T) test {
|
"fail/db.GetAuthorization-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
err: acme.ServerInternalErr(errors.New("force")),
|
MockError: acme.NewErrorISE("force"),
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("force")),
|
err: acme.NewErrorISE("force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
|
||||||
|
assert.Equals(t, id, az.ID)
|
||||||
|
return &acme.Authorization{
|
||||||
|
AccountID: "foo",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 401,
|
||||||
|
err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.UpdateAuthorization-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
|
||||||
|
assert.Equals(t, id, az.ID)
|
||||||
|
return &acme.Authorization{
|
||||||
|
AccountID: "accID",
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
ExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
MockUpdateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
|
||||||
|
assert.Equals(t, az.Status, acme.StatusInvalid)
|
||||||
|
return acme.NewErrorISE("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
err: acme.NewErrorISE("force"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
getAuthz: func(ctx context.Context, accID, id string) (*acme.Authz, error) {
|
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
|
||||||
p, err := acme.ProvisionerFromContext(ctx)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
assert.Equals(t, p, prov)
|
|
||||||
assert.Equals(t, accID, acc.ID)
|
|
||||||
assert.Equals(t, id, az.ID)
|
assert.Equals(t, id, az.ID)
|
||||||
return &az, nil
|
return &az, nil
|
||||||
},
|
},
|
||||||
getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
|
|
||||||
assert.Equals(t, typ, acme.AuthzLink)
|
|
||||||
assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
|
|
||||||
assert.True(t, abs)
|
|
||||||
assert.Equals(t, in, []string{az.ID})
|
|
||||||
return url
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
|
@ -423,11 +252,11 @@ func TestHandlerGetAuthz(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := New(tc.auth).(*Handler)
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetAuthz(w, req)
|
h.GetAuthorization(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
@ -436,15 +265,14 @@ func TestHandlerGetAuthz(t *testing.T) {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
var ae acme.AError
|
var ae acme.Error
|
||||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
prob := tc.problem.ToACME()
|
|
||||||
|
|
||||||
assert.Equals(t, ae.Type, prob.Type)
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
assert.Equals(t, ae.Detail, prob.Detail)
|
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
} else {
|
} else {
|
||||||
//var gotAz acme.Authz
|
//var gotAz acme.Authz
|
||||||
|
@ -459,7 +287,7 @@ func TestHandlerGetAuthz(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerGetCertificate(t *testing.T) {
|
func TestHandler_GetCertificate(t *testing.T) {
|
||||||
leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt")
|
leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt")
|
inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt")
|
||||||
|
@ -490,89 +318,73 @@ func TestHandlerGetCertificate(t *testing.T) {
|
||||||
baseURL.String(), provName, certID)
|
baseURL.String(), provName, certID)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
auth acme.Interface
|
db acme.DB
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
statusCode int
|
statusCode int
|
||||||
problem *acme.Error
|
err *acme.Error
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{},
|
db: &acme.MockDB{},
|
||||||
ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
|
ctx: context.Background(),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.AccContextKey, nil)
|
ctx := context.WithValue(context.Background(), accContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{},
|
db: &acme.MockDB{},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/getCertificate-error": func(t *testing.T) test {
|
"fail/db.GetCertificate-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), acme.AccContextKey, acc)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
err: acme.ServerInternalErr(errors.New("force")),
|
MockError: acme.NewErrorISE("force"),
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("force")),
|
err: acme.NewErrorISE("force"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/decode-leaf-for-loggger": func(t *testing.T) test {
|
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), acme.AccContextKey, acc)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
getCertificate: func(accID, id string) ([]byte, error) {
|
MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) {
|
||||||
assert.Equals(t, accID, acc.ID)
|
|
||||||
assert.Equals(t, id, certID)
|
assert.Equals(t, id, certID)
|
||||||
return []byte("foo"), nil
|
return &acme.Certificate{AccountID: "foo"}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 401,
|
||||||
problem: acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes")),
|
err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"),
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/parse-x509-leaf-for-logger": func(t *testing.T) test {
|
|
||||||
acc := &acme.Account{ID: "accID"}
|
|
||||||
ctx := context.WithValue(context.Background(), acme.AccContextKey, acc)
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
|
||||||
return test{
|
|
||||||
auth: &mockAcmeAuthority{
|
|
||||||
getCertificate: func(accID, id string) ([]byte, error) {
|
|
||||||
assert.Equals(t, accID, acc.ID)
|
|
||||||
assert.Equals(t, id, certID)
|
|
||||||
return pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "CERTIFICATE REQUEST",
|
|
||||||
Bytes: []byte("foo"),
|
|
||||||
}), nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ctx: ctx,
|
|
||||||
statusCode: 500,
|
|
||||||
problem: acme.ServerInternalErr(errors.New("failed to parse generated leaf certificate")),
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), acme.AccContextKey, acc)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
getCertificate: func(accID, id string) ([]byte, error) {
|
MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) {
|
||||||
assert.Equals(t, accID, acc.ID)
|
|
||||||
assert.Equals(t, id, certID)
|
assert.Equals(t, id, certID)
|
||||||
return certBytes, nil
|
return &acme.Certificate{
|
||||||
|
AccountID: "accID",
|
||||||
|
OrderID: "ordID",
|
||||||
|
Leaf: leaf,
|
||||||
|
Intermediates: []*x509.Certificate{inter, root},
|
||||||
|
ID: id,
|
||||||
|
}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -583,7 +395,7 @@ func TestHandlerGetCertificate(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := New(tc.auth).(*Handler)
|
h := &Handler{db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -596,15 +408,14 @@ func TestHandlerGetCertificate(t *testing.T) {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
var ae acme.AError
|
var ae acme.Error
|
||||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
prob := tc.problem.ToACME()
|
|
||||||
|
|
||||||
assert.Equals(t, ae.Type, prob.Type)
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
assert.HasPrefix(t, ae.Detail, prob.Detail)
|
assert.HasPrefix(t, ae.Detail, tc.err.Detail)
|
||||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
} else {
|
} else {
|
||||||
assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes))
|
assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes))
|
||||||
|
@ -614,152 +425,233 @@ func TestHandlerGetCertificate(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ch() acme.Challenge {
|
func TestHandler_GetChallenge(t *testing.T) {
|
||||||
return acme.Challenge{
|
|
||||||
Type: "http-01",
|
|
||||||
Status: "pending",
|
|
||||||
Token: "tok2",
|
|
||||||
URL: "https://ca.smallstep.com/acme/challenge/chID",
|
|
||||||
ID: "chID",
|
|
||||||
AuthzID: "authzID",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerGetChallenge(t *testing.T) {
|
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("chID", "chID")
|
chiCtx.URLParams.Add("chID", "chID")
|
||||||
|
chiCtx.URLParams.Add("authzID", "authzID")
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
url := fmt.Sprintf("%s/acme/challenge/%s", baseURL, "chID")
|
|
||||||
|
url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
|
||||||
|
baseURL.String(), provName, "authzID", "chID")
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
auth acme.Interface
|
db acme.DB
|
||||||
|
vco *acme.ValidateChallengeOptions
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
statusCode int
|
statusCode int
|
||||||
ch acme.Challenge
|
ch *acme.Challenge
|
||||||
problem *acme.Error
|
err *acme.Error
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/no-account": func(t *testing.T) test {
|
"fail/no-account": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
|
ctx: context.Background(),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-account": func(t *testing.T) test {
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, nil)
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: context.WithValue(context.Background(), accContextKey, nil),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
problem: acme.AccountDoesNotExistErr(nil),
|
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/no-payload": func(t *testing.T) test {
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/nil-payload": func(t *testing.T) test {
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
err: acme.NewErrorISE("payload expected in request context"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.GetChallenge-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
||||||
|
assert.Equals(t, chID, "chID")
|
||||||
|
assert.Equals(t, azID, "authzID")
|
||||||
|
return nil, acme.NewErrorISE("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
err: acme.NewErrorISE("force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
||||||
|
assert.Equals(t, chID, "chID")
|
||||||
|
assert.Equals(t, azID, "authzID")
|
||||||
|
return &acme.Challenge{AccountID: "foo"}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 401,
|
||||||
|
err: acme.NewError(acme.ErrorUnauthorizedType, "accout id mismatch"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-jwk": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
||||||
|
assert.Equals(t, chID, "chID")
|
||||||
|
assert.Equals(t, azID, "authzID")
|
||||||
|
return &acme.Challenge{AccountID: "accID"}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
err: acme.NewErrorISE("missing jwk"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-jwk": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
|
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
||||||
|
assert.Equals(t, chID, "chID")
|
||||||
|
assert.Equals(t, azID, "authzID")
|
||||||
|
return &acme.Challenge{AccountID: "accID"}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
err: acme.NewErrorISE("nil jwk"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/validate-challenge-error": func(t *testing.T) test {
|
"fail/validate-challenge-error": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
|
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
_pub := _jwk.Public()
|
||||||
|
ctx = context.WithValue(ctx, jwkContextKey, &_pub)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
err: acme.UnauthorizedErr(nil),
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
||||||
|
assert.Equals(t, chID, "chID")
|
||||||
|
assert.Equals(t, azID, "authzID")
|
||||||
|
return &acme.Challenge{
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Type: "http-01",
|
||||||
|
AccountID: "accID",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||||
|
assert.Equals(t, ch.Status, acme.StatusPending)
|
||||||
|
assert.Equals(t, ch.Type, "http-01")
|
||||||
|
assert.Equals(t, ch.AccountID, "accID")
|
||||||
|
assert.Equals(t, ch.AuthorizationID, "authzID")
|
||||||
|
assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String())
|
||||||
|
return acme.NewErrorISE("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
vco: &acme.ValidateChallengeOptions{
|
||||||
|
HTTPGet: func(string) (*http.Response, error) {
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 401,
|
statusCode: 500,
|
||||||
problem: acme.UnauthorizedErr(nil),
|
err: acme.NewErrorISE("force"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/get-challenge-error": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{ID: "accID"}
|
acc := &acme.Account{ID: "accID"}
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true})
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
|
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
_pub := _jwk.Public()
|
||||||
|
ctx = context.WithValue(ctx, jwkContextKey, &_pub)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
return test{
|
return test{
|
||||||
auth: &mockAcmeAuthority{
|
db: &acme.MockDB{
|
||||||
err: acme.UnauthorizedErr(nil),
|
MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
|
||||||
|
assert.Equals(t, chID, "chID")
|
||||||
|
assert.Equals(t, azID, "authzID")
|
||||||
|
return &acme.Challenge{
|
||||||
|
ID: "chID",
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Type: "http-01",
|
||||||
|
AccountID: "accID",
|
||||||
|
}, nil
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||||
statusCode: 401,
|
assert.Equals(t, ch.Status, acme.StatusPending)
|
||||||
problem: acme.UnauthorizedErr(nil),
|
assert.Equals(t, ch.Type, "http-01")
|
||||||
}
|
assert.Equals(t, ch.AccountID, "accID")
|
||||||
|
assert.Equals(t, ch.AuthorizationID, "authzID")
|
||||||
|
assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String())
|
||||||
|
return nil
|
||||||
},
|
},
|
||||||
"ok/validate-challenge": func(t *testing.T) test {
|
|
||||||
key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
acc := &acme.Account{ID: "accID", Key: key}
|
|
||||||
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
|
||||||
ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true})
|
|
||||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
|
||||||
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
|
|
||||||
ch := ch()
|
|
||||||
ch.Status = "valid"
|
|
||||||
ch.Validated = time.Now().UTC().Format(time.RFC3339)
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
auth: &mockAcmeAuthority{
|
|
||||||
validateChallenge: func(ctx context.Context, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
|
|
||||||
p, err := acme.ProvisionerFromContext(ctx)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
assert.Equals(t, p, prov)
|
|
||||||
assert.Equals(t, accID, acc.ID)
|
|
||||||
assert.Equals(t, id, ch.ID)
|
|
||||||
assert.Equals(t, jwk.KeyID, key.KeyID)
|
|
||||||
return &ch, nil
|
|
||||||
},
|
},
|
||||||
getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
|
ch: &acme.Challenge{
|
||||||
var ret string
|
ID: "chID",
|
||||||
switch count {
|
Status: acme.StatusPending,
|
||||||
case 0:
|
AuthorizationID: "authzID",
|
||||||
assert.Equals(t, typ, acme.AuthzLink)
|
Type: "http-01",
|
||||||
assert.True(t, abs)
|
AccountID: "accID",
|
||||||
assert.Equals(t, in, []string{ch.AuthzID})
|
URL: url,
|
||||||
ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID)
|
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
||||||
case 1:
|
},
|
||||||
assert.Equals(t, typ, acme.ChallengeLink)
|
vco: &acme.ValidateChallengeOptions{
|
||||||
assert.True(t, abs)
|
HTTPGet: func(string) (*http.Response, error) {
|
||||||
assert.Equals(t, in, []string{ch.ID})
|
return nil, errors.New("force")
|
||||||
ret = url
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return ret
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
ch: ch,
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := New(tc.auth).(*Handler)
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -772,21 +664,20 @@ func TestHandlerGetChallenge(t *testing.T) {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
var ae acme.AError
|
var ae acme.Error
|
||||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
prob := tc.problem.ToACME()
|
|
||||||
|
|
||||||
assert.Equals(t, ae.Type, prob.Type)
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
assert.Equals(t, ae.Detail, prob.Detail)
|
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
} else {
|
} else {
|
||||||
expB, err := json.Marshal(tc.ch)
|
expB, err := json.Marshal(tc.ch)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, tc.ch.AuthzID)})
|
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")})
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{url})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
|
|
181
acme/api/linker.go
Normal file
181
acme/api/linker.go
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewLinker returns a new Directory type.
|
||||||
|
func NewLinker(dns, prefix string) Linker {
|
||||||
|
return &linker{prefix: prefix, dns: dns}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Linker interface for generating links for ACME resources.
|
||||||
|
type Linker interface {
|
||||||
|
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||||
|
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
|
||||||
|
|
||||||
|
LinkOrder(ctx context.Context, o *acme.Order)
|
||||||
|
LinkAccount(ctx context.Context, o *acme.Account)
|
||||||
|
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
|
||||||
|
LinkAuthorization(ctx context.Context, o *acme.Authorization)
|
||||||
|
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// linker generates ACME links.
|
||||||
|
type linker struct {
|
||||||
|
prefix string
|
||||||
|
dns string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||||
|
switch typ {
|
||||||
|
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||||
|
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||||
|
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||||
|
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||||
|
case ChallengeLinkType:
|
||||||
|
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||||
|
case OrdersByAccountLinkType:
|
||||||
|
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||||
|
case FinalizeLinkType:
|
||||||
|
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLink is a helper for GetLinkExplicit
|
||||||
|
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||||
|
var (
|
||||||
|
provName string
|
||||||
|
baseURL = baseURLFromContext(ctx)
|
||||||
|
u = url.URL{}
|
||||||
|
)
|
||||||
|
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
|
||||||
|
provName = p.GetName()
|
||||||
|
}
|
||||||
|
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
|
||||||
|
if baseURL != nil {
|
||||||
|
u = *baseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
|
||||||
|
|
||||||
|
// If no Scheme is set, then default to https.
|
||||||
|
if u.Scheme == "" {
|
||||||
|
u.Scheme = "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no Host is set, then use the default (first DNS attr in the ca.json).
|
||||||
|
if u.Host == "" {
|
||||||
|
u.Host = l.dns
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Path = l.prefix + u.Path
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkType captures the link type.
|
||||||
|
type LinkType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// NewNonceLinkType new-nonce
|
||||||
|
NewNonceLinkType LinkType = iota
|
||||||
|
// NewAccountLinkType new-account
|
||||||
|
NewAccountLinkType
|
||||||
|
// AccountLinkType account
|
||||||
|
AccountLinkType
|
||||||
|
// OrderLinkType order
|
||||||
|
OrderLinkType
|
||||||
|
// NewOrderLinkType new-order
|
||||||
|
NewOrderLinkType
|
||||||
|
// OrdersByAccountLinkType list of orders owned by account
|
||||||
|
OrdersByAccountLinkType
|
||||||
|
// FinalizeLinkType finalize order
|
||||||
|
FinalizeLinkType
|
||||||
|
// NewAuthzLinkType authz
|
||||||
|
NewAuthzLinkType
|
||||||
|
// AuthzLinkType new-authz
|
||||||
|
AuthzLinkType
|
||||||
|
// ChallengeLinkType challenge
|
||||||
|
ChallengeLinkType
|
||||||
|
// CertificateLinkType certificate
|
||||||
|
CertificateLinkType
|
||||||
|
// DirectoryLinkType directory
|
||||||
|
DirectoryLinkType
|
||||||
|
// RevokeCertLinkType revoke certificate
|
||||||
|
RevokeCertLinkType
|
||||||
|
// KeyChangeLinkType key rollover
|
||||||
|
KeyChangeLinkType
|
||||||
|
)
|
||||||
|
|
||||||
|
func (l LinkType) String() string {
|
||||||
|
switch l {
|
||||||
|
case NewNonceLinkType:
|
||||||
|
return "new-nonce"
|
||||||
|
case NewAccountLinkType:
|
||||||
|
return "new-account"
|
||||||
|
case AccountLinkType:
|
||||||
|
return "account"
|
||||||
|
case NewOrderLinkType:
|
||||||
|
return "new-order"
|
||||||
|
case OrderLinkType:
|
||||||
|
return "order"
|
||||||
|
case NewAuthzLinkType:
|
||||||
|
return "new-authz"
|
||||||
|
case AuthzLinkType:
|
||||||
|
return "authz"
|
||||||
|
case ChallengeLinkType:
|
||||||
|
return "challenge"
|
||||||
|
case CertificateLinkType:
|
||||||
|
return "certificate"
|
||||||
|
case DirectoryLinkType:
|
||||||
|
return "directory"
|
||||||
|
case RevokeCertLinkType:
|
||||||
|
return "revoke-cert"
|
||||||
|
case KeyChangeLinkType:
|
||||||
|
return "key-change"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unexpected LinkType '%d'", int(l))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkOrder sets the ACME links required by an ACME order.
|
||||||
|
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
||||||
|
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
|
||||||
|
for i, azID := range o.AuthorizationIDs {
|
||||||
|
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
|
||||||
|
}
|
||||||
|
o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, o.ID)
|
||||||
|
if o.CertificateID != "" {
|
||||||
|
o.CertificateURL = l.GetLink(ctx, CertificateLinkType, o.CertificateID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkAccount sets the ACME links required by an ACME account.
|
||||||
|
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
|
||||||
|
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkChallenge sets the ACME links required by an ACME challenge.
|
||||||
|
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
|
||||||
|
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkAuthorization sets the ACME links required by an ACME authorization.
|
||||||
|
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
|
||||||
|
for _, ch := range az.Challenges {
|
||||||
|
l.LinkChallenge(ctx, ch, az.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkOrdersByAccountID converts each order ID to an ACME link.
|
||||||
|
func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) {
|
||||||
|
for i, id := range orders {
|
||||||
|
orders[i] = l.GetLink(ctx, OrderLinkType, id)
|
||||||
|
}
|
||||||
|
}
|
283
acme/api/linker_test.go
Normal file
283
acme/api/linker_test.go
Normal file
|
@ -0,0 +1,283 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
||||||
|
dns := "ca.smallstep.com"
|
||||||
|
prefix := "acme"
|
||||||
|
linker := NewLinker(dns, prefix)
|
||||||
|
|
||||||
|
getPath := linker.GetUnescapedPathSuffix
|
||||||
|
|
||||||
|
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
|
||||||
|
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
|
||||||
|
assert.Equals(t, getPath(NewAccountLinkType, "{provisionerID}"), "/{provisionerID}/new-account")
|
||||||
|
assert.Equals(t, getPath(AccountLinkType, "{provisionerID}", "{accID}"), "/{provisionerID}/account/{accID}")
|
||||||
|
assert.Equals(t, getPath(KeyChangeLinkType, "{provisionerID}"), "/{provisionerID}/key-change")
|
||||||
|
assert.Equals(t, getPath(NewOrderLinkType, "{provisionerID}"), "/{provisionerID}/new-order")
|
||||||
|
assert.Equals(t, getPath(OrderLinkType, "{provisionerID}", "{ordID}"), "/{provisionerID}/order/{ordID}")
|
||||||
|
assert.Equals(t, getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), "/{provisionerID}/account/{accID}/orders")
|
||||||
|
assert.Equals(t, getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), "/{provisionerID}/order/{ordID}/finalize")
|
||||||
|
assert.Equals(t, getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), "/{provisionerID}/authz/{authzID}")
|
||||||
|
assert.Equals(t, getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), "/{provisionerID}/challenge/{authzID}/{chID}")
|
||||||
|
assert.Equals(t, getPath(CertificateLinkType, "{provisionerID}", "{certID}"), "/{provisionerID}/certificate/{certID}")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinker_GetLink(t *testing.T) {
|
||||||
|
dns := "ca.smallstep.com"
|
||||||
|
prefix := "acme"
|
||||||
|
linker := NewLinker(dns, prefix)
|
||||||
|
id := "1234"
|
||||||
|
|
||||||
|
prov := newProv()
|
||||||
|
escProvName := url.PathEscape(prov.GetName())
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
|
||||||
|
// No provisioner and no BaseURL from request
|
||||||
|
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
|
||||||
|
// Provisioner: yes, BaseURL: no
|
||||||
|
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||||
|
|
||||||
|
// Provisioner: no, BaseURL: yes
|
||||||
|
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, NewAccountLinkType), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, AccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, NewOrderLinkType), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, OrderLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, OrdersByAccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, FinalizeLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, NewAuthzLinkType), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, AuthzLinkType, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, DirectoryLinkType), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, RevokeCertLinkType, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, KeyChangeLinkType), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, ChallengeLinkType, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id))
|
||||||
|
|
||||||
|
assert.Equals(t, linker.GetLink(ctx, CertificateLinkType, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinker_LinkOrder(t *testing.T) {
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
prov := newProv()
|
||||||
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||||
|
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||||
|
|
||||||
|
oid := "orderID"
|
||||||
|
certID := "certID"
|
||||||
|
linkerPrefix := "acme"
|
||||||
|
l := NewLinker("dns", linkerPrefix)
|
||||||
|
type test struct {
|
||||||
|
o *acme.Order
|
||||||
|
validate func(o *acme.Order)
|
||||||
|
}
|
||||||
|
var tests = map[string]test{
|
||||||
|
"no-authz-and-no-cert": {
|
||||||
|
o: &acme.Order{
|
||||||
|
ID: oid,
|
||||||
|
},
|
||||||
|
validate: func(o *acme.Order) {
|
||||||
|
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||||
|
assert.Equals(t, o.AuthorizationURLs, []string{})
|
||||||
|
assert.Equals(t, o.CertificateURL, "")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"one-authz-and-cert": {
|
||||||
|
o: &acme.Order{
|
||||||
|
ID: oid,
|
||||||
|
CertificateID: certID,
|
||||||
|
AuthorizationIDs: []string{"foo"},
|
||||||
|
},
|
||||||
|
validate: func(o *acme.Order) {
|
||||||
|
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||||
|
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||||
|
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||||
|
})
|
||||||
|
assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"many-authz": {
|
||||||
|
o: &acme.Order{
|
||||||
|
ID: oid,
|
||||||
|
CertificateID: certID,
|
||||||
|
AuthorizationIDs: []string{"foo", "bar", "zap"},
|
||||||
|
},
|
||||||
|
validate: func(o *acme.Order) {
|
||||||
|
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||||
|
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||||
|
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||||
|
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "bar"),
|
||||||
|
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "zap"),
|
||||||
|
})
|
||||||
|
assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
l.LinkOrder(ctx, tc.o)
|
||||||
|
tc.validate(tc.o)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinker_LinkAccount(t *testing.T) {
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
prov := newProv()
|
||||||
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||||
|
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||||
|
|
||||||
|
accID := "accountID"
|
||||||
|
linkerPrefix := "acme"
|
||||||
|
l := NewLinker("dns", linkerPrefix)
|
||||||
|
type test struct {
|
||||||
|
a *acme.Account
|
||||||
|
validate func(o *acme.Account)
|
||||||
|
}
|
||||||
|
var tests = map[string]test{
|
||||||
|
"ok": {
|
||||||
|
a: &acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
},
|
||||||
|
validate: func(a *acme.Account) {
|
||||||
|
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
l.LinkAccount(ctx, tc.a)
|
||||||
|
tc.validate(tc.a)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinker_LinkChallenge(t *testing.T) {
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
prov := newProv()
|
||||||
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||||
|
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||||
|
|
||||||
|
chID := "chID"
|
||||||
|
azID := "azID"
|
||||||
|
linkerPrefix := "acme"
|
||||||
|
l := NewLinker("dns", linkerPrefix)
|
||||||
|
type test struct {
|
||||||
|
ch *acme.Challenge
|
||||||
|
validate func(o *acme.Challenge)
|
||||||
|
}
|
||||||
|
var tests = map[string]test{
|
||||||
|
"ok": {
|
||||||
|
ch: &acme.Challenge{
|
||||||
|
ID: chID,
|
||||||
|
},
|
||||||
|
validate: func(ch *acme.Challenge) {
|
||||||
|
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
l.LinkChallenge(ctx, tc.ch, azID)
|
||||||
|
tc.validate(tc.ch)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinker_LinkAuthorization(t *testing.T) {
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
prov := newProv()
|
||||||
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||||
|
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||||
|
|
||||||
|
chID0 := "chID-0"
|
||||||
|
chID1 := "chID-1"
|
||||||
|
chID2 := "chID-2"
|
||||||
|
azID := "azID"
|
||||||
|
linkerPrefix := "acme"
|
||||||
|
l := NewLinker("dns", linkerPrefix)
|
||||||
|
type test struct {
|
||||||
|
az *acme.Authorization
|
||||||
|
validate func(o *acme.Authorization)
|
||||||
|
}
|
||||||
|
var tests = map[string]test{
|
||||||
|
"ok": {
|
||||||
|
az: &acme.Authorization{
|
||||||
|
ID: azID,
|
||||||
|
Challenges: []*acme.Challenge{
|
||||||
|
{ID: chID0},
|
||||||
|
{ID: chID1},
|
||||||
|
{ID: chID2},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validate: func(az *acme.Authorization) {
|
||||||
|
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
|
||||||
|
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
|
||||||
|
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
l.LinkAuthorization(ctx, tc.az)
|
||||||
|
tc.validate(tc.az)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
prov := newProv()
|
||||||
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||||
|
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||||
|
|
||||||
|
linkerPrefix := "acme"
|
||||||
|
l := NewLinker("dns", linkerPrefix)
|
||||||
|
type test struct {
|
||||||
|
oids []string
|
||||||
|
}
|
||||||
|
var tests = map[string]test{
|
||||||
|
"ok": {
|
||||||
|
oids: []string{"foo", "bar", "baz"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
l.LinkOrdersByAccountID(ctx, tc.oids)
|
||||||
|
assert.Equals(t, tc.oids, []string{
|
||||||
|
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||||
|
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "bar"),
|
||||||
|
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "baz"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,13 +3,13 @@ package api
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
@ -54,7 +54,7 @@ func baseURLFromRequest(r *http.Request) *url.URL {
|
||||||
// E.g. https://ca.smallstep.com/
|
// E.g. https://ca.smallstep.com/
|
||||||
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
|
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := context.WithValue(r.Context(), acme.BaseURLContextKey, baseURLFromRequest(r))
|
ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
|
||||||
next(w, r.WithContext(ctx))
|
next(w, r.WithContext(ctx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -62,14 +62,14 @@ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
|
||||||
// addNonce is a middleware that adds a nonce to the response header.
|
// addNonce is a middleware that adds a nonce to the response header.
|
||||||
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
nonce, err := h.Auth.NewNonce()
|
nonce, err := h.db.CreateNonce(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Replay-Nonce", nonce)
|
w.Header().Set("Replay-Nonce", string(nonce))
|
||||||
w.Header().Set("Cache-Control", "no-store")
|
w.Header().Set("Cache-Control", "no-store")
|
||||||
logNonce(w, nonce)
|
logNonce(w, string(nonce))
|
||||||
next(w, r)
|
next(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -78,8 +78,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||||
// directory index url.
|
// directory index url.
|
||||||
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Add("Link", link(h.Auth.GetLink(r.Context(),
|
w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index"))
|
||||||
acme.DirectoryLink, true), "index"))
|
|
||||||
next(w, r)
|
next(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,23 +87,31 @@ func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
||||||
// application/jose+json.
|
// application/jose+json.
|
||||||
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ct := r.Header.Get("Content-Type")
|
|
||||||
var expected []string
|
var expected []string
|
||||||
if strings.Contains(r.URL.Path, h.Auth.GetLink(r.Context(), acme.CertificateLink, false, "")) {
|
p, err := provisionerFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")}
|
||||||
|
if strings.Contains(r.URL.String(), u.EscapedPath()) {
|
||||||
// GET /certificate requests allow a greater range of content types.
|
// GET /certificate requests allow a greater range of content types.
|
||||||
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||||
} else {
|
} else {
|
||||||
// By default every request should have content-type applictaion/jose+json.
|
// By default every request should have content-type applictaion/jose+json.
|
||||||
expected = []string{"application/jose+json"}
|
expected = []string{"application/jose+json"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ct := r.Header.Get("Content-Type")
|
||||||
for _, e := range expected {
|
for _, e := range expected {
|
||||||
if ct == e {
|
if ct == e {
|
||||||
next(w, r)
|
next(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf(
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
"expected content-type to be in %s, but got %s", expected, ct)))
|
"expected content-type to be in %s, but got %s", expected, ct))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,15 +120,15 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body")))
|
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jws, err := jose.ParseJWS(string(body))
|
jws, err := jose.ParseJWS(string(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body")))
|
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(r.Context(), acme.JwsContextKey, jws)
|
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
|
||||||
next(w, r.WithContext(ctx))
|
next(w, r.WithContext(ctx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -143,17 +150,18 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||||
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||||
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
jws, err := acme.JwsFromContext(r.Context())
|
ctx := r.Context()
|
||||||
|
jws, err := jwsFromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(jws.Signatures) == 0 {
|
if len(jws.Signatures) == 0 {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(jws.Signatures) > 1 {
|
if len(jws.Signatures) > 1 {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,35 +172,36 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
len(uh.Algorithm) > 0 ||
|
len(uh.Algorithm) > 0 ||
|
||||||
len(uh.Nonce) > 0 ||
|
len(uh.Nonce) > 0 ||
|
||||||
len(uh.ExtraHeaders) > 0 {
|
len(uh.ExtraHeaders) > 0 {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hdr := sig.Protected
|
hdr := sig.Protected
|
||||||
switch hdr.Algorithm {
|
switch hdr.Algorithm {
|
||||||
case jose.RS256, jose.RS384, jose.RS512:
|
case jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512:
|
||||||
if hdr.JSONWebKey != nil {
|
if hdr.JSONWebKey != nil {
|
||||||
switch k := hdr.JSONWebKey.Key.(type) {
|
switch k := hdr.JSONWebKey.Key.(type) {
|
||||||
case *rsa.PublicKey:
|
case *rsa.PublicKey:
|
||||||
if k.Size() < keyutil.MinRSAKeyBytes {
|
if k.Size() < keyutil.MinRSAKeyBytes {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
"keys must be at least %d bits (%d bytes) in size",
|
"rsa keys must be at least %d bits (%d bytes) in size",
|
||||||
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)))
|
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
|
"jws key type and algorithm do not match"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
|
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
|
||||||
// we good
|
// we good
|
||||||
default:
|
default:
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm)))
|
api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the validity/freshness of the Nonce.
|
// Check the validity/freshness of the Nonce.
|
||||||
if err := h.Auth.UseNonce(hdr.Nonce); err != nil {
|
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -200,21 +209,22 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
// Check that the JWS url matches the requested url.
|
// Check that the JWS url matches the requested url.
|
||||||
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
|
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
|
||||||
if jwsURL != reqURL.String() {
|
if jwsURL != reqURL.String() {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
|
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
|
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
|
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next(w, r)
|
next(w, r)
|
||||||
|
@ -227,24 +237,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := acme.JwsFromContext(r.Context())
|
jws, err := jwsFromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwk := jws.Signatures[0].Protected.JSONWebKey
|
jwk := jws.Signatures[0].Protected.JSONWebKey
|
||||||
if jwk == nil {
|
if jwk == nil {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !jwk.Valid() {
|
if !jwk.Valid() {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, acme.JwkContextKey, jwk)
|
|
||||||
acc, err := h.Auth.GetAccountByKey(ctx, jwk)
|
// Overwrite KeyID with the JWK thumbprint.
|
||||||
|
jwk.KeyID, err = acme.KeyToID(jwk)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the JWK in the context.
|
||||||
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
|
|
||||||
|
// Get Account or continue to generate a new one.
|
||||||
|
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||||
switch {
|
switch {
|
||||||
case nosql.IsErrNotFound(err):
|
case errors.Is(err, acme.ErrNotFound):
|
||||||
// For NewAccount requests ...
|
// For NewAccount requests ...
|
||||||
break
|
break
|
||||||
case err != nil:
|
case err != nil:
|
||||||
|
@ -252,10 +273,10 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
if !acc.IsValid() {
|
if !acc.IsValid() {
|
||||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
}
|
}
|
||||||
next(w, r.WithContext(ctx))
|
next(w, r.WithContext(ctx))
|
||||||
}
|
}
|
||||||
|
@ -270,20 +291,20 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||||
name := chi.URLParam(r, "provisionerID")
|
name := chi.URLParam(r, "provisionerID")
|
||||||
provID, err := url.PathUnescape(name)
|
provID, err := url.PathUnescape(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name)))
|
api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner id '%s'", name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p, err := h.Auth.LoadProvisionerByID("acme/" + provID)
|
p, err := h.ca.LoadProvisionerByID("acme/" + provID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acmeProv, ok := p.(*provisioner.ACME)
|
acmeProv, ok := p.(*provisioner.ACME)
|
||||||
if !ok {
|
if !ok {
|
||||||
api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME")))
|
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, acme.ProvisionerContextKey, acme.Provisioner(acmeProv))
|
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
|
||||||
next(w, r.WithContext(ctx))
|
next(w, r.WithContext(ctx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -294,36 +315,37 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||||
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := acme.JwsFromContext(ctx)
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
kidPrefix := h.Auth.GetLink(ctx, acme.AccountLink, true, "")
|
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
|
||||||
kid := jws.Signatures[0].Protected.KeyID
|
kid := jws.Signatures[0].Protected.KeyID
|
||||||
if !strings.HasPrefix(kid, kidPrefix) {
|
if !strings.HasPrefix(kid, kidPrefix) {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
"required prefix; expected %s, but got %s", kidPrefix, kid)))
|
"kid does not have required prefix; expected %s, but got %s",
|
||||||
|
kidPrefix, kid))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accID := strings.TrimPrefix(kid, kidPrefix)
|
accID := strings.TrimPrefix(kid, kidPrefix)
|
||||||
acc, err := h.Auth.GetAccount(r.Context(), accID)
|
acc, err := h.db.GetAccount(ctx, accID)
|
||||||
switch {
|
switch {
|
||||||
case nosql.IsErrNotFound(err):
|
case nosql.IsErrNotFound(err):
|
||||||
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||||
return
|
return
|
||||||
case err != nil:
|
case err != nil:
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
if !acc.IsValid() {
|
if !acc.IsValid() {
|
||||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, acme.JwkContextKey, acc.Key)
|
ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
|
||||||
next(w, r.WithContext(ctx))
|
next(w, r.WithContext(ctx))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -334,26 +356,27 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||||
// Make sure to parse and validate the JWS before running this middleware.
|
// Make sure to parse and validate the JWS before running this middleware.
|
||||||
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
jws, err := acme.JwsFromContext(r.Context())
|
ctx := r.Context()
|
||||||
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwk, err := acme.JwkFromContext(r.Context())
|
jwk, err := jwkFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payload, err := jws.Verify(jwk)
|
payload, err := jws.Verify(jwk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws")))
|
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(r.Context(), acme.PayloadContextKey, &payloadInfo{
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
|
||||||
value: payload,
|
value: payload,
|
||||||
isPostAsGet: string(payload) == "",
|
isPostAsGet: string(payload) == "",
|
||||||
isEmptyJSON: string(payload) == "{}",
|
isEmptyJSON: string(payload) == "{}",
|
||||||
|
@ -371,9 +394,89 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !payload.isPostAsGet {
|
if !payload.isPostAsGet {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET")))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next(w, r)
|
next(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContextKey is the key type for storing and searching for ACME request
|
||||||
|
// essentials in the context of a request.
|
||||||
|
type ContextKey string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// accContextKey account key
|
||||||
|
accContextKey = ContextKey("acc")
|
||||||
|
// baseURLContextKey baseURL key
|
||||||
|
baseURLContextKey = ContextKey("baseURL")
|
||||||
|
// jwsContextKey jws key
|
||||||
|
jwsContextKey = ContextKey("jws")
|
||||||
|
// jwkContextKey jwk key
|
||||||
|
jwkContextKey = ContextKey("jwk")
|
||||||
|
// payloadContextKey payload key
|
||||||
|
payloadContextKey = ContextKey("payload")
|
||||||
|
// provisionerContextKey provisioner key
|
||||||
|
provisionerContextKey = ContextKey("provisioner")
|
||||||
|
)
|
||||||
|
|
||||||
|
// accountFromContext searches the context for an ACME account. Returns the
|
||||||
|
// account or an error.
|
||||||
|
func accountFromContext(ctx context.Context) (*acme.Account, error) {
|
||||||
|
val, ok := ctx.Value(accContextKey).(*acme.Account)
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil, acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context")
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// baseURLFromContext returns the baseURL if one is stored in the context.
|
||||||
|
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||||
|
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
|
||||||
|
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
|
||||||
|
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil, acme.NewErrorISE("jwk expected in request context")
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwsFromContext searches the context for a JWS. Returns the JWS or an error.
|
||||||
|
func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
|
||||||
|
val, ok := ctx.Value(jwsContextKey).(*jose.JSONWebSignature)
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil, acme.NewErrorISE("jws expected in request context")
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// provisionerFromContext searches the context for a provisioner. Returns the
|
||||||
|
// provisioner or an error.
|
||||||
|
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
|
||||||
|
val := ctx.Value(provisionerContextKey)
|
||||||
|
if val == nil {
|
||||||
|
return nil, acme.NewErrorISE("provisioner expected in request context")
|
||||||
|
}
|
||||||
|
pval, ok := val.(acme.Provisioner)
|
||||||
|
if !ok || pval == nil {
|
||||||
|
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||||
|
}
|
||||||
|
return pval, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// payloadFromContext searches the context for a payload. Returns the payload
|
||||||
|
// or an error.
|
||||||
|
func payloadFromContext(ctx context.Context) (*payloadInfo, error) {
|
||||||
|
val, ok := ctx.Value(payloadContextKey).(*payloadInfo)
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil, acme.NewErrorISE("payload expected in request context")
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,16 +1,18 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
|
"go.step.sm/crypto/randutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewOrderRequest represents the body for a NewOrder request.
|
// NewOrderRequest represents the body for a NewOrder request.
|
||||||
|
@ -23,11 +25,11 @@ type NewOrderRequest struct {
|
||||||
// Validate validates a new-order request body.
|
// Validate validates a new-order request body.
|
||||||
func (n *NewOrderRequest) Validate() error {
|
func (n *NewOrderRequest) Validate() error {
|
||||||
if len(n.Identifiers) == 0 {
|
if len(n.Identifiers) == 0 {
|
||||||
return acme.MalformedErr(errors.Errorf("identifiers list cannot be empty"))
|
return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty")
|
||||||
}
|
}
|
||||||
for _, id := range n.Identifiers {
|
for _, id := range n.Identifiers {
|
||||||
if id.Type != "dns" {
|
if id.Type != "dns" {
|
||||||
return acme.MalformedErr(errors.Errorf("identifier type unsupported: %s", id.Type))
|
return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -44,22 +46,30 @@ func (f *FinalizeRequest) Validate() error {
|
||||||
var err error
|
var err error
|
||||||
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
|
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr"))
|
return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr")
|
||||||
}
|
}
|
||||||
f.csr, err = x509.ParseCertificateRequest(csrBytes)
|
f.csr, err = x509.ParseCertificateRequest(csrBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return acme.MalformedErr(errors.Wrap(err, "unable to parse csr"))
|
return acme.WrapError(acme.ErrorMalformedType, err, "unable to parse csr")
|
||||||
}
|
}
|
||||||
if err = f.csr.CheckSignature(); err != nil {
|
if err = f.csr.CheckSignature(); err != nil {
|
||||||
return acme.MalformedErr(errors.Wrap(err, "csr failed signature check"))
|
return acme.WrapError(acme.ErrorMalformedType, err, "csr failed signature check")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultOrderExpiry = time.Hour * 24
|
||||||
|
var defaultOrderBackdate = time.Minute
|
||||||
|
|
||||||
// NewOrder ACME api for creating a new order.
|
// NewOrder ACME api for creating a new order.
|
||||||
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := acme.AccountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
prov, err := provisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
|
@ -71,8 +81,8 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
var nor NewOrderRequest
|
var nor NewOrderRequest
|
||||||
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
|
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||||
"failed to unmarshal new-order request payload")))
|
"failed to unmarshal new-order request payload"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := nor.Validate(); err != nil {
|
if err := nor.Validate(); err != nil {
|
||||||
|
@ -80,44 +90,146 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
o, err := h.Auth.NewOrder(ctx, acme.OrderOptions{
|
now := clock.Now()
|
||||||
AccountID: acc.GetID(),
|
// New order.
|
||||||
|
o := &acme.Order{
|
||||||
|
AccountID: acc.ID,
|
||||||
|
ProvisionerID: prov.GetID(),
|
||||||
|
Status: acme.StatusPending,
|
||||||
Identifiers: nor.Identifiers,
|
Identifiers: nor.Identifiers,
|
||||||
|
ExpiresAt: now.Add(defaultOrderExpiry),
|
||||||
|
AuthorizationIDs: make([]string, len(nor.Identifiers)),
|
||||||
NotBefore: nor.NotBefore,
|
NotBefore: nor.NotBefore,
|
||||||
NotAfter: nor.NotAfter,
|
NotAfter: nor.NotAfter,
|
||||||
})
|
}
|
||||||
if err != nil {
|
|
||||||
|
for i, identifier := range o.Identifiers {
|
||||||
|
az := &acme.Authorization{
|
||||||
|
AccountID: acc.ID,
|
||||||
|
Identifier: identifier,
|
||||||
|
ExpiresAt: o.ExpiresAt,
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
}
|
||||||
|
if err := h.newAuthorization(ctx, az); err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
o.AuthorizationIDs[i] = az.ID
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID()))
|
if o.NotBefore.IsZero() {
|
||||||
|
o.NotBefore = now
|
||||||
|
}
|
||||||
|
if o.NotAfter.IsZero() {
|
||||||
|
o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration())
|
||||||
|
}
|
||||||
|
// If request NotBefore was empty then backdate the order.NotBefore (now)
|
||||||
|
// to avoid timing issues.
|
||||||
|
if nor.NotBefore.IsZero() {
|
||||||
|
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.db.CreateOrder(ctx, o); err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error creating order"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.linker.LinkOrder(ctx, o)
|
||||||
|
|
||||||
|
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||||
api.JSONStatus(w, o, http.StatusCreated)
|
api.JSONStatus(w, o, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||||
|
if strings.HasPrefix(az.Identifier.Value, "*.") {
|
||||||
|
az.Wildcard = true
|
||||||
|
az.Identifier = acme.Identifier{
|
||||||
|
Value: strings.TrimPrefix(az.Identifier.Value, "*."),
|
||||||
|
Type: az.Identifier.Type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
chTypes = []string{"dns-01"}
|
||||||
|
)
|
||||||
|
// HTTP and TLS challenges can only be used for identifiers without wildcards.
|
||||||
|
if !az.Wildcard {
|
||||||
|
chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
az.Token, err = randutil.Alphanumeric(32)
|
||||||
|
if err != nil {
|
||||||
|
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
|
||||||
|
}
|
||||||
|
az.Challenges = make([]*acme.Challenge, len(chTypes))
|
||||||
|
for i, typ := range chTypes {
|
||||||
|
ch := &acme.Challenge{
|
||||||
|
AccountID: az.AccountID,
|
||||||
|
Value: az.Identifier.Value,
|
||||||
|
Type: typ,
|
||||||
|
Token: az.Token,
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
}
|
||||||
|
if err := h.db.CreateChallenge(ctx, ch); err != nil {
|
||||||
|
return acme.WrapErrorISE(err, "error creating challenge")
|
||||||
|
}
|
||||||
|
az.Challenges[i] = ch
|
||||||
|
}
|
||||||
|
if err = h.db.CreateAuthorization(ctx, az); err != nil {
|
||||||
|
return acme.WrapErrorISE(err, "error creating authorization")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetOrder ACME api for retrieving an order.
|
// GetOrder ACME api for retrieving an order.
|
||||||
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := acme.AccountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
oid := chi.URLParam(r, "ordID")
|
prov, err := provisionerFromContext(ctx)
|
||||||
o, err := h.Auth.GetOrder(ctx, acc.GetID(), oid)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if acc.ID != o.AccountID {
|
||||||
|
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
|
"account '%s' does not own order '%s'", acc.ID, o.ID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if prov.GetID() != o.ProvisionerID {
|
||||||
|
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
|
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = o.UpdateStatus(ctx, h.db); err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error updating order status"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID()))
|
h.linker.LinkOrder(ctx, o)
|
||||||
|
|
||||||
|
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||||
api.JSON(w, o)
|
api.JSON(w, o)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
||||||
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := acme.AccountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
prov, err := provisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
|
@ -129,7 +241,8 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
var fr FinalizeRequest
|
var fr FinalizeRequest
|
||||||
if err := json.Unmarshal(payload.value, &fr); err != nil {
|
if err := json.Unmarshal(payload.value, &fr); err != nil {
|
||||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload")))
|
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||||
|
"failed to unmarshal finalize-order request payload"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := fr.Validate(); err != nil {
|
if err := fr.Validate(); err != nil {
|
||||||
|
@ -137,13 +250,28 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
oid := chi.URLParam(r, "ordID")
|
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||||
o, err := h.Auth.FinalizeOrder(ctx, acc.GetID(), oid, fr.csr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if acc.ID != o.AccountID {
|
||||||
|
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
|
"account '%s' does not own order '%s'", acc.ID, o.ID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if prov.GetID() != o.ProvisionerID {
|
||||||
|
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
|
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.ID))
|
h.linker.LinkOrder(ctx, o)
|
||||||
|
|
||||||
|
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||||
api.JSON(w, o)
|
api.JSON(w, o)
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,342 +0,0 @@
|
||||||
package acme
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/base64"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
|
||||||
database "github.com/smallstep/certificates/db"
|
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
"go.step.sm/crypto/jose"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Interface is the acme authority interface.
|
|
||||||
type Interface interface {
|
|
||||||
GetDirectory(ctx context.Context) (*Directory, error)
|
|
||||||
NewNonce() (string, error)
|
|
||||||
UseNonce(string) error
|
|
||||||
|
|
||||||
DeactivateAccount(ctx context.Context, accID string) (*Account, error)
|
|
||||||
GetAccount(ctx context.Context, accID string) (*Account, error)
|
|
||||||
GetAccountByKey(ctx context.Context, key *jose.JSONWebKey) (*Account, error)
|
|
||||||
NewAccount(ctx context.Context, ao AccountOptions) (*Account, error)
|
|
||||||
UpdateAccount(context.Context, string, []string) (*Account, error)
|
|
||||||
|
|
||||||
GetAuthz(ctx context.Context, accID string, authzID string) (*Authz, error)
|
|
||||||
ValidateChallenge(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*Challenge, error)
|
|
||||||
|
|
||||||
FinalizeOrder(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*Order, error)
|
|
||||||
GetOrder(ctx context.Context, accID string, orderID string) (*Order, error)
|
|
||||||
GetOrdersByAccount(ctx context.Context, accID string) ([]string, error)
|
|
||||||
NewOrder(ctx context.Context, oo OrderOptions) (*Order, error)
|
|
||||||
|
|
||||||
GetCertificate(string, string) ([]byte, error)
|
|
||||||
|
|
||||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
|
||||||
GetLink(ctx context.Context, linkType Link, absoluteLink bool, inputs ...string) string
|
|
||||||
GetLinkExplicit(linkType Link, provName string, absoluteLink bool, baseURL *url.URL, inputs ...string) string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authority is the layer that handles all ACME interactions.
|
|
||||||
type Authority struct {
|
|
||||||
backdate provisioner.Duration
|
|
||||||
db nosql.DB
|
|
||||||
dir *directory
|
|
||||||
signAuth SignAuthority
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthorityOptions required to create a new ACME Authority.
|
|
||||||
type AuthorityOptions struct {
|
|
||||||
Backdate provisioner.Duration
|
|
||||||
// DB is the database used by nosql.
|
|
||||||
DB nosql.DB
|
|
||||||
// DNS the host used to generate accurate ACME links. By default the authority
|
|
||||||
// will use the Host from the request, so this value will only be used if
|
|
||||||
// request.Host is empty.
|
|
||||||
DNS string
|
|
||||||
// Prefix is a URL path prefix under which the ACME api is served. This
|
|
||||||
// prefix is required to generate accurate ACME links.
|
|
||||||
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
|
|
||||||
// "acme" is the prefix from which the ACME api is accessed.
|
|
||||||
Prefix string
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
accountTable = []byte("acme_accounts")
|
|
||||||
accountByKeyIDTable = []byte("acme_keyID_accountID_index")
|
|
||||||
authzTable = []byte("acme_authzs")
|
|
||||||
challengeTable = []byte("acme_challenges")
|
|
||||||
nonceTable = []byte("nonces")
|
|
||||||
orderTable = []byte("acme_orders")
|
|
||||||
ordersByAccountIDTable = []byte("acme_account_orders_index")
|
|
||||||
certTable = []byte("acme_certs")
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewAuthority returns a new Authority that implements the ACME interface.
|
|
||||||
//
|
|
||||||
// Deprecated: NewAuthority exists for hitorical compatibility and should not
|
|
||||||
// be used. Use acme.New() instead.
|
|
||||||
func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) {
|
|
||||||
return New(signAuth, AuthorityOptions{
|
|
||||||
DB: db,
|
|
||||||
DNS: dns,
|
|
||||||
Prefix: prefix,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// New returns a new Autohrity that implements the ACME interface.
|
|
||||||
func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) {
|
|
||||||
if _, ok := ops.DB.(*database.SimpleDB); !ok {
|
|
||||||
// If it's not a SimpleDB then go ahead and bootstrap the DB with the
|
|
||||||
// necessary ACME tables. SimpleDB should ONLY be used for testing.
|
|
||||||
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
|
|
||||||
challengeTable, nonceTable, orderTable, ordersByAccountIDTable,
|
|
||||||
certTable}
|
|
||||||
for _, b := range tables {
|
|
||||||
if err := ops.DB.CreateTable(b); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "error creating table %s",
|
|
||||||
string(b))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &Authority{
|
|
||||||
backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLink returns the requested link from the directory.
|
|
||||||
func (a *Authority) GetLink(ctx context.Context, typ Link, abs bool, inputs ...string) string {
|
|
||||||
return a.dir.getLink(ctx, typ, abs, inputs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLinkExplicit returns the requested link from the directory.
|
|
||||||
func (a *Authority) GetLinkExplicit(typ Link, provName string, abs bool, baseURL *url.URL, inputs ...string) string {
|
|
||||||
return a.dir.getLinkExplicit(typ, provName, abs, baseURL, inputs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDirectory returns the ACME directory object.
|
|
||||||
func (a *Authority) GetDirectory(ctx context.Context) (*Directory, error) {
|
|
||||||
return &Directory{
|
|
||||||
NewNonce: a.dir.getLink(ctx, NewNonceLink, true),
|
|
||||||
NewAccount: a.dir.getLink(ctx, NewAccountLink, true),
|
|
||||||
NewOrder: a.dir.getLink(ctx, NewOrderLink, true),
|
|
||||||
RevokeCert: a.dir.getLink(ctx, RevokeCertLink, true),
|
|
||||||
KeyChange: a.dir.getLink(ctx, KeyChangeLink, true),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadProvisionerByID calls out to the SignAuthority interface to load a
|
|
||||||
// provisioner by ID.
|
|
||||||
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
|
||||||
return a.signAuth.LoadProvisionerByID(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewNonce generates, stores, and returns a new ACME nonce.
|
|
||||||
func (a *Authority) NewNonce() (string, error) {
|
|
||||||
n, err := newNonce(a.db)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return n.ID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UseNonce consumes the given nonce if it is valid, returns error otherwise.
|
|
||||||
func (a *Authority) UseNonce(nonce string) error {
|
|
||||||
return useNonce(a.db, nonce)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAccount creates, stores, and returns a new ACME account.
|
|
||||||
func (a *Authority) NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) {
|
|
||||||
acc, err := newAccount(a.db, ao)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return acc.toACME(ctx, a.db, a.dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateAccount updates an ACME account.
|
|
||||||
func (a *Authority) UpdateAccount(ctx context.Context, id string, contact []string) (*Account, error) {
|
|
||||||
acc, err := getAccountByID(a.db, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, ServerInternalErr(err)
|
|
||||||
}
|
|
||||||
if acc, err = acc.update(a.db, contact); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return acc.toACME(ctx, a.db, a.dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccount returns an ACME account.
|
|
||||||
func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) {
|
|
||||||
acc, err := getAccountByID(a.db, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return acc.toACME(ctx, a.db, a.dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeactivateAccount deactivates an ACME account.
|
|
||||||
func (a *Authority) DeactivateAccount(ctx context.Context, id string) (*Account, error) {
|
|
||||||
acc, err := getAccountByID(a.db, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if acc, err = acc.deactivate(a.db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return acc.toACME(ctx, a.db, a.dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
func keyToID(jwk *jose.JSONWebKey) (string, error) {
|
|
||||||
kid, err := jwk.Thumbprint(crypto.SHA256)
|
|
||||||
if err != nil {
|
|
||||||
return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint"))
|
|
||||||
}
|
|
||||||
return base64.RawURLEncoding.EncodeToString(kid), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByKey returns the ACME associated with the jwk id.
|
|
||||||
func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) {
|
|
||||||
kid, err := keyToID(jwk)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
acc, err := getAccountByKeyID(a.db, kid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return acc.toACME(ctx, a.db, a.dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOrder returns an ACME order.
|
|
||||||
func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order, error) {
|
|
||||||
o, err := getOrder(a.db, orderID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if accID != o.AccountID {
|
|
||||||
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
|
||||||
}
|
|
||||||
if o, err = o.updateStatus(a.db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return o.toACME(ctx, a.db, a.dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOrdersByAccount returns the list of order urls owned by the account.
|
|
||||||
func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) {
|
|
||||||
ordersByAccountMux.Lock()
|
|
||||||
defer ordersByAccountMux.Unlock()
|
|
||||||
|
|
||||||
var oiba = orderIDsByAccount{}
|
|
||||||
oids, err := oiba.unsafeGetOrderIDsByAccount(a.db, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var ret = []string{}
|
|
||||||
for _, oid := range oids {
|
|
||||||
ret = append(ret, a.dir.getLink(ctx, OrderLink, true, oid))
|
|
||||||
}
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewOrder generates, stores, and returns a new ACME order.
|
|
||||||
func (a *Authority) NewOrder(ctx context.Context, ops OrderOptions) (*Order, error) {
|
|
||||||
prov, err := ProvisionerFromContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ops.backdate = a.backdate.Duration
|
|
||||||
ops.defaultDuration = prov.DefaultTLSCertDuration()
|
|
||||||
order, err := newOrder(a.db, ops)
|
|
||||||
if err != nil {
|
|
||||||
return nil, Wrap(err, "error creating order")
|
|
||||||
}
|
|
||||||
return order.toACME(ctx, a.db, a.dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FinalizeOrder attempts to finalize an order and generate a new certificate.
|
|
||||||
func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) {
|
|
||||||
prov, err := ProvisionerFromContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
o, err := getOrder(a.db, orderID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if accID != o.AccountID {
|
|
||||||
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
|
||||||
}
|
|
||||||
o, err = o.finalize(a.db, csr, a.signAuth, prov)
|
|
||||||
if err != nil {
|
|
||||||
return nil, Wrap(err, "error finalizing order")
|
|
||||||
}
|
|
||||||
return o.toACME(ctx, a.db, a.dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAuthz retrieves and attempts to update the status on an ACME authz
|
|
||||||
// before returning.
|
|
||||||
func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authz, error) {
|
|
||||||
az, err := getAuthz(a.db, authzID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if accID != az.getAccountID() {
|
|
||||||
return nil, UnauthorizedErr(errors.New("account does not own authz"))
|
|
||||||
}
|
|
||||||
az, err = az.updateStatus(a.db)
|
|
||||||
if err != nil {
|
|
||||||
return nil, Wrap(err, "error updating authz status")
|
|
||||||
}
|
|
||||||
return az.toACME(ctx, a.db, a.dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateChallenge attempts to validate the challenge.
|
|
||||||
func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) {
|
|
||||||
ch, err := getChallenge(a.db, chID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if accID != ch.getAccountID() {
|
|
||||||
return nil, UnauthorizedErr(errors.New("account does not own challenge"))
|
|
||||||
}
|
|
||||||
client := http.Client{
|
|
||||||
Timeout: time.Duration(30 * time.Second),
|
|
||||||
}
|
|
||||||
dialer := &net.Dialer{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
}
|
|
||||||
ch, err = ch.validate(a.db, jwk, validateOptions{
|
|
||||||
httpGet: client.Get,
|
|
||||||
lookupTxt: net.LookupTXT,
|
|
||||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
|
||||||
return tls.DialWithDialer(dialer, network, addr, config)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, Wrap(err, "error attempting challenge validation")
|
|
||||||
}
|
|
||||||
return ch.toACME(ctx, a.db, a.dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCertificate retrieves the Certificate by ID.
|
|
||||||
func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) {
|
|
||||||
cert, err := getCert(a.db, certID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if accID != cert.AccountID {
|
|
||||||
return nil, UnauthorizedErr(errors.New("account does not own certificate"))
|
|
||||||
}
|
|
||||||
return cert.toACME(a.db, a.dir)
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load diff
69
acme/authorization.go
Normal file
69
acme/authorization.go
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Authorization representst an ACME Authorization.
|
||||||
|
type Authorization struct {
|
||||||
|
ID string `json:"-"`
|
||||||
|
AccountID string `json:"-"`
|
||||||
|
Token string `json:"-"`
|
||||||
|
Identifier Identifier `json:"identifier"`
|
||||||
|
Status Status `json:"status"`
|
||||||
|
Challenges []*Challenge `json:"challenges"`
|
||||||
|
Wildcard bool `json:"wildcard"`
|
||||||
|
ExpiresAt time.Time `json:"expires"`
|
||||||
|
Error *Error `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToLog enables response logging.
|
||||||
|
func (az *Authorization) ToLog() (interface{}, error) {
|
||||||
|
b, err := json.Marshal(az)
|
||||||
|
if err != nil {
|
||||||
|
return nil, WrapErrorISE(err, "error marshaling authz for logging")
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStatus updates the ACME Authorization Status if necessary.
|
||||||
|
// Changes to the Authorization are saved using the database interface.
|
||||||
|
func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
|
||||||
|
now := clock.Now()
|
||||||
|
|
||||||
|
switch az.Status {
|
||||||
|
case StatusInvalid:
|
||||||
|
return nil
|
||||||
|
case StatusValid:
|
||||||
|
return nil
|
||||||
|
case StatusPending:
|
||||||
|
// check expiry
|
||||||
|
if now.After(az.ExpiresAt) {
|
||||||
|
az.Status = StatusInvalid
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var isValid = false
|
||||||
|
for _, ch := range az.Challenges {
|
||||||
|
if ch.Status == StatusValid {
|
||||||
|
isValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
az.Status = StatusValid
|
||||||
|
az.Error = nil
|
||||||
|
default:
|
||||||
|
return NewErrorISE("unrecognized authorization status: %s", az.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.UpdateAuthorization(ctx, az); err != nil {
|
||||||
|
return WrapErrorISE(err, "error updating authorization")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
150
acme/authorization_test.go
Normal file
150
acme/authorization_test.go
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthorization_UpdateStatus(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
az *Authorization
|
||||||
|
err *Error
|
||||||
|
db DB
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"ok/already-invalid": func(t *testing.T) test {
|
||||||
|
az := &Authorization{
|
||||||
|
Status: StatusInvalid,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/already-valid": func(t *testing.T) test {
|
||||||
|
az := &Authorization{
|
||||||
|
Status: StatusInvalid,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/error-unexpected-status": func(t *testing.T) test {
|
||||||
|
az := &Authorization{
|
||||||
|
Status: "foo",
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
err: NewErrorISE("unrecognized authorization status: %s", az.Status),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/expired": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
az := &Authorization{
|
||||||
|
ID: "azID",
|
||||||
|
AccountID: "accID",
|
||||||
|
Status: StatusPending,
|
||||||
|
ExpiresAt: now.Add(-5 * time.Minute),
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
db: &MockDB{
|
||||||
|
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
|
||||||
|
assert.Equals(t, updaz.ID, az.ID)
|
||||||
|
assert.Equals(t, updaz.AccountID, az.AccountID)
|
||||||
|
assert.Equals(t, updaz.Status, StatusInvalid)
|
||||||
|
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.UpdateAuthorization-error": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
az := &Authorization{
|
||||||
|
ID: "azID",
|
||||||
|
AccountID: "accID",
|
||||||
|
Status: StatusPending,
|
||||||
|
ExpiresAt: now.Add(-5 * time.Minute),
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
db: &MockDB{
|
||||||
|
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
|
||||||
|
assert.Equals(t, updaz.ID, az.ID)
|
||||||
|
assert.Equals(t, updaz.AccountID, az.AccountID)
|
||||||
|
assert.Equals(t, updaz.Status, StatusInvalid)
|
||||||
|
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
|
||||||
|
return errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: NewErrorISE("error updating authorization: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/no-valid-challenges": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
az := &Authorization{
|
||||||
|
ID: "azID",
|
||||||
|
AccountID: "accID",
|
||||||
|
Status: StatusPending,
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
Challenges: []*Challenge{
|
||||||
|
{Status: StatusPending}, {Status: StatusPending}, {Status: StatusPending},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/valid": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
az := &Authorization{
|
||||||
|
ID: "azID",
|
||||||
|
AccountID: "accID",
|
||||||
|
Status: StatusPending,
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
Challenges: []*Challenge{
|
||||||
|
{Status: StatusPending}, {Status: StatusPending}, {Status: StatusValid},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
db: &MockDB{
|
||||||
|
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
|
||||||
|
assert.Equals(t, updaz.ID, az.ID)
|
||||||
|
assert.Equals(t, updaz.AccountID, az.AccountID)
|
||||||
|
assert.Equals(t, updaz.Status, StatusValid)
|
||||||
|
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
|
||||||
|
assert.Equals(t, updaz.Error, nil)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if err := tc.az.UpdateStatus(context.Background(), tc.db); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *Error:
|
||||||
|
assert.Equals(t, k.Type, tc.err.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.err.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.err.Detail)
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.New("unexpected error type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
347
acme/authz.go
347
acme/authz.go
|
@ -1,347 +0,0 @@
|
||||||
package acme
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
)
|
|
||||||
|
|
||||||
var defaultExpiryDuration = time.Hour * 24
|
|
||||||
|
|
||||||
// Authz is a subset of the Authz type containing only those attributes
|
|
||||||
// required for responses in the ACME protocol.
|
|
||||||
type Authz struct {
|
|
||||||
Identifier Identifier `json:"identifier"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Expires string `json:"expires"`
|
|
||||||
Challenges []*Challenge `json:"challenges"`
|
|
||||||
Wildcard bool `json:"wildcard"`
|
|
||||||
ID string `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToLog enables response logging.
|
|
||||||
func (a *Authz) ToLog() (interface{}, error) {
|
|
||||||
b, err := json.Marshal(a)
|
|
||||||
if err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging"))
|
|
||||||
}
|
|
||||||
return string(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetID returns the Authz ID.
|
|
||||||
func (a *Authz) GetID() string {
|
|
||||||
return a.ID
|
|
||||||
}
|
|
||||||
|
|
||||||
// authz is the interface that the various authz types must implement.
|
|
||||||
type authz interface {
|
|
||||||
save(nosql.DB, authz) error
|
|
||||||
clone() *baseAuthz
|
|
||||||
getID() string
|
|
||||||
getAccountID() string
|
|
||||||
getType() string
|
|
||||||
getIdentifier() Identifier
|
|
||||||
getStatus() string
|
|
||||||
getExpiry() time.Time
|
|
||||||
getWildcard() bool
|
|
||||||
getChallenges() []string
|
|
||||||
getCreated() time.Time
|
|
||||||
updateStatus(db nosql.DB) (authz, error)
|
|
||||||
toACME(context.Context, nosql.DB, *directory) (*Authz, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// baseAuthz is the base authz type that others build from.
|
|
||||||
type baseAuthz struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
AccountID string `json:"accountID"`
|
|
||||||
Identifier Identifier `json:"identifier"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Expires time.Time `json:"expires"`
|
|
||||||
Challenges []string `json:"challenges"`
|
|
||||||
Wildcard bool `json:"wildcard"`
|
|
||||||
Created time.Time `json:"created"`
|
|
||||||
Error *Error `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) {
|
|
||||||
id, err := randID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
now := clock.Now()
|
|
||||||
ba := &baseAuthz{
|
|
||||||
ID: id,
|
|
||||||
AccountID: accID,
|
|
||||||
Status: StatusPending,
|
|
||||||
Created: now,
|
|
||||||
Expires: now.Add(defaultExpiryDuration),
|
|
||||||
Identifier: identifier,
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(identifier.Value, "*.") {
|
|
||||||
ba.Wildcard = true
|
|
||||||
ba.Identifier = Identifier{
|
|
||||||
Value: strings.TrimPrefix(identifier.Value, "*."),
|
|
||||||
Type: identifier.Type,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ba, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getID returns the ID of the authz.
|
|
||||||
func (ba *baseAuthz) getID() string {
|
|
||||||
return ba.ID
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAccountID returns the Account ID that created the authz.
|
|
||||||
func (ba *baseAuthz) getAccountID() string {
|
|
||||||
return ba.AccountID
|
|
||||||
}
|
|
||||||
|
|
||||||
// getType returns the type of the authz.
|
|
||||||
func (ba *baseAuthz) getType() string {
|
|
||||||
return ba.Identifier.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
// getIdentifier returns the identifier for the authz.
|
|
||||||
func (ba *baseAuthz) getIdentifier() Identifier {
|
|
||||||
return ba.Identifier
|
|
||||||
}
|
|
||||||
|
|
||||||
// getStatus returns the status of the authz.
|
|
||||||
func (ba *baseAuthz) getStatus() string {
|
|
||||||
return ba.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
// getWildcard returns true if the authz identifier has a '*', false otherwise.
|
|
||||||
func (ba *baseAuthz) getWildcard() bool {
|
|
||||||
return ba.Wildcard
|
|
||||||
}
|
|
||||||
|
|
||||||
// getChallenges returns the authz challenge IDs.
|
|
||||||
func (ba *baseAuthz) getChallenges() []string {
|
|
||||||
return ba.Challenges
|
|
||||||
}
|
|
||||||
|
|
||||||
// getExpiry returns the expiration time of the authz.
|
|
||||||
func (ba *baseAuthz) getExpiry() time.Time {
|
|
||||||
return ba.Expires
|
|
||||||
}
|
|
||||||
|
|
||||||
// getCreated returns the created time of the authz.
|
|
||||||
func (ba *baseAuthz) getCreated() time.Time {
|
|
||||||
return ba.Created
|
|
||||||
}
|
|
||||||
|
|
||||||
// toACME converts the internal Authz type into the public acmeAuthz type for
|
|
||||||
// presentation in the ACME protocol.
|
|
||||||
func (ba *baseAuthz) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Authz, error) {
|
|
||||||
var chs = make([]*Challenge, len(ba.Challenges))
|
|
||||||
for i, chID := range ba.Challenges {
|
|
||||||
ch, err := getChallenge(db, chID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
chs[i], err = ch.toACME(ctx, db, dir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &Authz{
|
|
||||||
Identifier: ba.Identifier,
|
|
||||||
Status: ba.getStatus(),
|
|
||||||
Challenges: chs,
|
|
||||||
Wildcard: ba.getWildcard(),
|
|
||||||
Expires: ba.Expires.Format(time.RFC3339),
|
|
||||||
ID: ba.ID,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ba *baseAuthz) save(db nosql.DB, old authz) error {
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
oldB, newB []byte
|
|
||||||
)
|
|
||||||
if old == nil {
|
|
||||||
oldB = nil
|
|
||||||
} else {
|
|
||||||
if oldB, err = json.Marshal(old); err != nil {
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old authz"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if newB, err = json.Marshal(ba); err != nil {
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new authz"))
|
|
||||||
}
|
|
||||||
_, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB)
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
return ServerInternalErr(errors.Wrapf(err, "error storing authz"))
|
|
||||||
case !swapped:
|
|
||||||
return ServerInternalErr(errors.Errorf("error storing authz; " +
|
|
||||||
"value has changed since last read"))
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ba *baseAuthz) clone() *baseAuthz {
|
|
||||||
u := *ba
|
|
||||||
return &u
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ba *baseAuthz) parent() authz {
|
|
||||||
return &dnsAuthz{ba}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateStatus attempts to update the status on a baseAuthz and stores the
|
|
||||||
// updating object if necessary.
|
|
||||||
func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) {
|
|
||||||
newAuthz := ba.clone()
|
|
||||||
|
|
||||||
now := time.Now().UTC()
|
|
||||||
switch ba.Status {
|
|
||||||
case StatusInvalid:
|
|
||||||
return ba.parent(), nil
|
|
||||||
case StatusValid:
|
|
||||||
return ba.parent(), nil
|
|
||||||
case StatusPending:
|
|
||||||
// check expiry
|
|
||||||
if now.After(ba.Expires) {
|
|
||||||
newAuthz.Status = StatusInvalid
|
|
||||||
newAuthz.Error = MalformedErr(errors.New("authz has expired"))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
var isValid = false
|
|
||||||
for _, chID := range ba.Challenges {
|
|
||||||
ch, err := getChallenge(db, chID)
|
|
||||||
if err != nil {
|
|
||||||
return ba, err
|
|
||||||
}
|
|
||||||
if ch.getStatus() == StatusValid {
|
|
||||||
isValid = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isValid {
|
|
||||||
return ba.parent(), nil
|
|
||||||
}
|
|
||||||
newAuthz.Status = StatusValid
|
|
||||||
newAuthz.Error = nil
|
|
||||||
default:
|
|
||||||
return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := newAuthz.save(db, ba); err != nil {
|
|
||||||
return ba, err
|
|
||||||
}
|
|
||||||
return newAuthz.parent(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// unmarshalAuthz unmarshals an authz type into the correct sub-type.
|
|
||||||
func unmarshalAuthz(data []byte) (authz, error) {
|
|
||||||
var getType struct {
|
|
||||||
Identifier Identifier `json:"identifier"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(data, &getType); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type"))
|
|
||||||
}
|
|
||||||
|
|
||||||
switch getType.Identifier.Type {
|
|
||||||
case "dns":
|
|
||||||
var ba baseAuthz
|
|
||||||
if err := json.Unmarshal(data, &ba); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz"))
|
|
||||||
}
|
|
||||||
return &dnsAuthz{&ba}, nil
|
|
||||||
default:
|
|
||||||
return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s",
|
|
||||||
getType.Identifier.Type))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// dnsAuthz represents a dns acme authorization.
|
|
||||||
type dnsAuthz struct {
|
|
||||||
*baseAuthz
|
|
||||||
}
|
|
||||||
|
|
||||||
// newAuthz returns a new acme authorization object based on the identifier
|
|
||||||
// type.
|
|
||||||
func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) {
|
|
||||||
switch identifier.Type {
|
|
||||||
case "dns":
|
|
||||||
a, err = newDNSAuthz(db, accID, identifier)
|
|
||||||
default:
|
|
||||||
err = MalformedErr(errors.Errorf("unexpected authz type %s",
|
|
||||||
identifier.Type))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// newDNSAuthz returns a new dns acme authorization object.
|
|
||||||
func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) {
|
|
||||||
ba, err := newBaseAuthz(accID, identifier)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ba.Challenges = []string{}
|
|
||||||
if !ba.Wildcard {
|
|
||||||
// http and alpn challenges are only permitted if the DNS is not a wildcard dns.
|
|
||||||
ch1, err := newHTTP01Challenge(db, ChallengeOptions{
|
|
||||||
AccountID: accID,
|
|
||||||
AuthzID: ba.ID,
|
|
||||||
Identifier: ba.Identifier})
|
|
||||||
if err != nil {
|
|
||||||
return nil, Wrap(err, "error creating http challenge")
|
|
||||||
}
|
|
||||||
ba.Challenges = append(ba.Challenges, ch1.getID())
|
|
||||||
|
|
||||||
ch2, err := newTLSALPN01Challenge(db, ChallengeOptions{
|
|
||||||
AccountID: accID,
|
|
||||||
AuthzID: ba.ID,
|
|
||||||
Identifier: ba.Identifier,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, Wrap(err, "error creating alpn challenge")
|
|
||||||
}
|
|
||||||
ba.Challenges = append(ba.Challenges, ch2.getID())
|
|
||||||
}
|
|
||||||
ch3, err := newDNS01Challenge(db, ChallengeOptions{
|
|
||||||
AccountID: accID,
|
|
||||||
AuthzID: ba.ID,
|
|
||||||
Identifier: identifier})
|
|
||||||
if err != nil {
|
|
||||||
return nil, Wrap(err, "error creating dns challenge")
|
|
||||||
}
|
|
||||||
ba.Challenges = append(ba.Challenges, ch3.getID())
|
|
||||||
|
|
||||||
da := &dnsAuthz{ba}
|
|
||||||
if err := da.save(db, nil); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return da, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAuthz retrieves and unmarshals an ACME authz type from the database.
|
|
||||||
func getAuthz(db nosql.DB, id string) (authz, error) {
|
|
||||||
b, err := db.Get(authzTable, []byte(id))
|
|
||||||
if nosql.IsErrNotFound(err) {
|
|
||||||
return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id))
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id))
|
|
||||||
}
|
|
||||||
az, err := unmarshalAuthz(b)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return az, nil
|
|
||||||
}
|
|
|
@ -1,836 +0,0 @@
|
||||||
package acme
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/db"
|
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
"github.com/smallstep/nosql/database"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newAz() (authz, error) {
|
|
||||||
mockdb := &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return []byte("foo"), true, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return newAuthz(mockdb, "1234", Identifier{
|
|
||||||
Type: "dns", Value: "acme.example.com",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetAuthz(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
id string
|
|
||||||
db nosql.DB
|
|
||||||
az authz
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/not-found": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
id: az.getID(),
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
return nil, database.ErrNotFound
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/db-error": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
id: az.getID(),
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
return nil, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/unmarshal-error": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_az, ok := az.(*dnsAuthz)
|
|
||||||
assert.Fatal(t, ok)
|
|
||||||
_az.baseAuthz.Identifier.Type = "foo"
|
|
||||||
b, err := json.Marshal(az)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
id: az.getID(),
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
assert.Equals(t, bucket, authzTable)
|
|
||||||
assert.Equals(t, key, []byte(az.getID()))
|
|
||||||
return b, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(az)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
id: az.getID(),
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
assert.Equals(t, bucket, authzTable)
|
|
||||||
assert.Equals(t, key, []byte(az.getID()))
|
|
||||||
return b, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if az, err := getAuthz(tc.db, tc.id); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, tc.az.getID(), az.getID())
|
|
||||||
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
|
||||||
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
|
||||||
assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier())
|
|
||||||
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
|
||||||
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
|
||||||
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthzClone(t *testing.T) {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
clone := az.clone()
|
|
||||||
|
|
||||||
assert.Equals(t, clone.getID(), az.getID())
|
|
||||||
assert.Equals(t, clone.getAccountID(), az.getAccountID())
|
|
||||||
assert.Equals(t, clone.getStatus(), az.getStatus())
|
|
||||||
assert.Equals(t, clone.getIdentifier(), az.getIdentifier())
|
|
||||||
assert.Equals(t, clone.getExpiry(), az.getExpiry())
|
|
||||||
assert.Equals(t, clone.getCreated(), az.getCreated())
|
|
||||||
assert.Equals(t, clone.getChallenges(), az.getChallenges())
|
|
||||||
|
|
||||||
clone.Status = StatusValid
|
|
||||||
|
|
||||||
assert.NotEquals(t, clone.getStatus(), az.getStatus())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewAuthz(t *testing.T) {
|
|
||||||
iden := Identifier{
|
|
||||||
Type: "dns", Value: "acme.example.com",
|
|
||||||
}
|
|
||||||
accID := "1234"
|
|
||||||
type test struct {
|
|
||||||
iden Identifier
|
|
||||||
db nosql.DB
|
|
||||||
err *Error
|
|
||||||
resChs *([]string)
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/unexpected-type": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
iden: Identifier{Type: "foo", Value: "acme.example.com"},
|
|
||||||
err: MalformedErr(errors.New("unexpected authz type foo")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/new-http-chall-error": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
iden: iden,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/new-tls-alpn-chall-error": func(t *testing.T) test {
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
iden: iden,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if count == 1 {
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error creating alpn challenge: error saving acme challenge: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/new-dns-chall-error": func(t *testing.T) test {
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
iden: iden,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if count == 2 {
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/save-authz-error": func(t *testing.T) test {
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
iden: iden,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if count == 3 {
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
chs := &([]string{})
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
iden: iden,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if count == 3 {
|
|
||||||
assert.Equals(t, bucket, authzTable)
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
|
|
||||||
az, err := unmarshalAuthz(newval)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
assert.Equals(t, az.getID(), string(key))
|
|
||||||
assert.Equals(t, az.getAccountID(), accID)
|
|
||||||
assert.Equals(t, az.getStatus(), StatusPending)
|
|
||||||
assert.Equals(t, az.getIdentifier(), iden)
|
|
||||||
assert.Equals(t, az.getWildcard(), false)
|
|
||||||
|
|
||||||
*chs = az.getChallenges()
|
|
||||||
|
|
||||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
|
||||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
|
||||||
|
|
||||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
|
||||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
|
||||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
resChs: chs,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/wildcard": func(t *testing.T) test {
|
|
||||||
chs := &([]string{})
|
|
||||||
count := 0
|
|
||||||
_iden := Identifier{Type: "dns", Value: "*.acme.example.com"}
|
|
||||||
return test{
|
|
||||||
iden: _iden,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if count == 1 {
|
|
||||||
assert.Equals(t, bucket, authzTable)
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
|
|
||||||
az, err := unmarshalAuthz(newval)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
assert.Equals(t, az.getID(), string(key))
|
|
||||||
assert.Equals(t, az.getAccountID(), accID)
|
|
||||||
assert.Equals(t, az.getStatus(), StatusPending)
|
|
||||||
assert.Equals(t, az.getIdentifier(), iden)
|
|
||||||
assert.Equals(t, az.getWildcard(), true)
|
|
||||||
|
|
||||||
*chs = az.getChallenges()
|
|
||||||
// Verify that we only have 1 challenge instead of 2.
|
|
||||||
assert.True(t, len(*chs) == 1)
|
|
||||||
|
|
||||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
|
||||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
|
||||||
|
|
||||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
|
||||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
|
||||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
resChs: chs,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
tc := run(t)
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
az, err := newAuthz(tc.db, accID, tc.iden)
|
|
||||||
if err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, az.getAccountID(), accID)
|
|
||||||
assert.Equals(t, az.getType(), "dns")
|
|
||||||
assert.Equals(t, az.getStatus(), StatusPending)
|
|
||||||
|
|
||||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
|
||||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
|
||||||
|
|
||||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
|
||||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
|
||||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
|
||||||
|
|
||||||
assert.Equals(t, az.getChallenges(), *(tc.resChs))
|
|
||||||
|
|
||||||
if strings.HasPrefix(tc.iden.Value, "*.") {
|
|
||||||
assert.True(t, az.getWildcard())
|
|
||||||
assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*."))
|
|
||||||
} else {
|
|
||||||
assert.False(t, az.getWildcard())
|
|
||||||
assert.Equals(t, az.getIdentifier().Value, tc.iden.Value)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, az.getID() != "")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthzToACME(t *testing.T) {
|
|
||||||
dir := newDirectory("ca.smallstep.com", "acme")
|
|
||||||
|
|
||||||
var (
|
|
||||||
ch1, ch2 challenge
|
|
||||||
ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
mockdb := &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if count == 0 {
|
|
||||||
*ch1Bytes = newval
|
|
||||||
ch1, err = unmarshalChallenge(newval)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
} else if count == 1 {
|
|
||||||
*ch2Bytes = newval
|
|
||||||
ch2, err = unmarshalChallenge(newval)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return []byte("foo"), true, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
iden := Identifier{
|
|
||||||
Type: "dns", Value: "acme.example.com",
|
|
||||||
}
|
|
||||||
az, err := newAuthz(mockdb, "1234", iden)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
prov := newProv()
|
|
||||||
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
|
|
||||||
|
|
||||||
type test struct {
|
|
||||||
db nosql.DB
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/getChallenge1-error": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
return nil, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/getChallenge2-error": func(t *testing.T) test {
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
if count == 1 {
|
|
||||||
return nil, errors.New("force")
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return *ch1Bytes, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
count := 0
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
if count == 0 {
|
|
||||||
count++
|
|
||||||
return *ch1Bytes, nil
|
|
||||||
}
|
|
||||||
return *ch2Bytes, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
tc := run(t)
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
acmeAz, err := az.toACME(ctx, tc.db, dir)
|
|
||||||
if err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, acmeAz.ID, az.getID())
|
|
||||||
assert.Equals(t, acmeAz.Identifier, iden)
|
|
||||||
assert.Equals(t, acmeAz.Status, StatusPending)
|
|
||||||
|
|
||||||
acmeCh1, err := ch1.toACME(ctx, nil, dir)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
acmeCh2, err := ch2.toACME(ctx, nil, dir)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
assert.Equals(t, acmeAz.Challenges[0], acmeCh1)
|
|
||||||
assert.Equals(t, acmeAz.Challenges[1], acmeCh2)
|
|
||||||
|
|
||||||
expiry, err := time.Parse(time.RFC3339, acmeAz.Expires)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
assert.Equals(t, expiry.String(), az.getExpiry().String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthzSave(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
az, old authz
|
|
||||||
db nosql.DB
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/old-nil/swap-error": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
old: nil,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/old-nil/swap-false": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
old: nil,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return []byte("foo"), false, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/old-nil": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(az)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
old: nil,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
assert.Equals(t, b, newval)
|
|
||||||
assert.Equals(t, bucket, authzTable)
|
|
||||||
assert.Equals(t, []byte(az.getID()), key)
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/old-not-nil": func(t *testing.T) test {
|
|
||||||
oldAz, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
oldb, err := json.Marshal(oldAz)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(az)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
old: oldAz,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, old, oldb)
|
|
||||||
assert.Equals(t, b, newval)
|
|
||||||
assert.Equals(t, bucket, authzTable)
|
|
||||||
assert.Equals(t, []byte(az.getID()), key)
|
|
||||||
return []byte("foo"), true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if err := tc.az.save(tc.db, tc.old); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
assert.Nil(t, tc.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthzUnmarshal(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
az authz
|
|
||||||
azb []byte
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/nil": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
azb: nil,
|
|
||||||
err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/unexpected-type": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_az, ok := az.(*dnsAuthz)
|
|
||||||
assert.Fatal(t, ok)
|
|
||||||
_az.baseAuthz.Identifier.Type = "foo"
|
|
||||||
b, err := json.Marshal(az)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
azb: b,
|
|
||||||
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/dns": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(az)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
azb: b,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if az, err := unmarshalAuthz(tc.azb); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, tc.az.getID(), az.getID())
|
|
||||||
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
|
||||||
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
|
||||||
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
|
||||||
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
|
||||||
assert.Equals(t, tc.az.getWildcard(), az.getWildcard())
|
|
||||||
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthzUpdateStatus(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
az, res authz
|
|
||||||
err *Error
|
|
||||||
db nosql.DB
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/already-invalid": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_az, ok := az.(*dnsAuthz)
|
|
||||||
assert.Fatal(t, ok)
|
|
||||||
_az.baseAuthz.Status = StatusInvalid
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
res: az,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/already-valid": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_az, ok := az.(*dnsAuthz)
|
|
||||||
assert.Fatal(t, ok)
|
|
||||||
_az.baseAuthz.Status = StatusValid
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
res: az,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/unexpected-status": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_az, ok := az.(*dnsAuthz)
|
|
||||||
assert.Fatal(t, ok)
|
|
||||||
_az.baseAuthz.Status = StatusReady
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
res: az,
|
|
||||||
err: ServerInternalErr(errors.New("unrecognized authz status: ready")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/save-error": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_az, ok := az.(*dnsAuthz)
|
|
||||||
assert.Fatal(t, ok)
|
|
||||||
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
res: az,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/expired": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_az, ok := az.(*dnsAuthz)
|
|
||||||
assert.Fatal(t, ok)
|
|
||||||
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
|
||||||
|
|
||||||
clone := az.clone()
|
|
||||||
clone.Error = MalformedErr(errors.New("authz has expired"))
|
|
||||||
clone.Status = StatusInvalid
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
res: clone.parent(),
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/get-challenge-error": func(t *testing.T) test {
|
|
||||||
az, err := newAz()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
res: az,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
return nil, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/valid": func(t *testing.T) test {
|
|
||||||
var (
|
|
||||||
ch3 challenge
|
|
||||||
ch2Bytes = &([]byte{})
|
|
||||||
ch1Bytes = &([]byte{})
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
mockdb := &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if count == 0 {
|
|
||||||
*ch1Bytes = newval
|
|
||||||
} else if count == 1 {
|
|
||||||
*ch2Bytes = newval
|
|
||||||
} else if count == 2 {
|
|
||||||
ch3, err = unmarshalChallenge(newval)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
iden := Identifier{
|
|
||||||
Type: "dns", Value: "acme.example.com",
|
|
||||||
}
|
|
||||||
az, err := newAuthz(mockdb, "1234", iden)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_az, ok := az.(*dnsAuthz)
|
|
||||||
assert.Fatal(t, ok)
|
|
||||||
_az.baseAuthz.Error = MalformedErr(nil)
|
|
||||||
|
|
||||||
_ch, ok := ch3.(*dns01Challenge)
|
|
||||||
assert.Fatal(t, ok)
|
|
||||||
_ch.baseChallenge.Status = StatusValid
|
|
||||||
chb, err := json.Marshal(ch3)
|
|
||||||
|
|
||||||
clone := az.clone()
|
|
||||||
clone.Status = StatusValid
|
|
||||||
clone.Error = nil
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
res: clone.parent(),
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
if count == 0 {
|
|
||||||
count++
|
|
||||||
return *ch1Bytes, nil
|
|
||||||
}
|
|
||||||
if count == 1 {
|
|
||||||
count++
|
|
||||||
return *ch2Bytes, nil
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return chb, nil
|
|
||||||
},
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/still-pending": func(t *testing.T) test {
|
|
||||||
var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
mockdb := &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if count == 0 {
|
|
||||||
*ch1Bytes = newval
|
|
||||||
} else if count == 1 {
|
|
||||||
*ch2Bytes = newval
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
iden := Identifier{
|
|
||||||
Type: "dns", Value: "acme.example.com",
|
|
||||||
}
|
|
||||||
az, err := newAuthz(mockdb, "1234", iden)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
return test{
|
|
||||||
az: az,
|
|
||||||
res: az,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
if count == 0 {
|
|
||||||
count++
|
|
||||||
return *ch1Bytes, nil
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
return *ch2Bytes, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
az, err := tc.az.updateStatus(tc.db)
|
|
||||||
if err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
expB, err := json.Marshal(tc.res)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(az)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
assert.Equals(t, expB, b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -2,88 +2,13 @@ package acme
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
|
||||||
"encoding/pem"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type certificate struct {
|
// Certificate options with which to create and store a cert object.
|
||||||
ID string `json:"id"`
|
type Certificate struct {
|
||||||
Created time.Time `json:"created"`
|
ID string
|
||||||
AccountID string `json:"accountID"`
|
|
||||||
OrderID string `json:"orderID"`
|
|
||||||
Leaf []byte `json:"leaf"`
|
|
||||||
Intermediates []byte `json:"intermediates"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CertOptions options with which to create and store a cert object.
|
|
||||||
type CertOptions struct {
|
|
||||||
AccountID string
|
AccountID string
|
||||||
OrderID string
|
OrderID string
|
||||||
Leaf *x509.Certificate
|
Leaf *x509.Certificate
|
||||||
Intermediates []*x509.Certificate
|
Intermediates []*x509.Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCert(db nosql.DB, ops CertOptions) (*certificate, error) {
|
|
||||||
id, err := randID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
leaf := pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "CERTIFICATE",
|
|
||||||
Bytes: ops.Leaf.Raw,
|
|
||||||
})
|
|
||||||
var intermediates []byte
|
|
||||||
for _, cert := range ops.Intermediates {
|
|
||||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "CERTIFICATE",
|
|
||||||
Bytes: cert.Raw,
|
|
||||||
})...)
|
|
||||||
}
|
|
||||||
|
|
||||||
cert := &certificate{
|
|
||||||
ID: id,
|
|
||||||
AccountID: ops.AccountID,
|
|
||||||
OrderID: ops.OrderID,
|
|
||||||
Leaf: leaf,
|
|
||||||
Intermediates: intermediates,
|
|
||||||
Created: time.Now().UTC(),
|
|
||||||
}
|
|
||||||
certB, err := json.Marshal(cert)
|
|
||||||
if err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling certificate"))
|
|
||||||
}
|
|
||||||
|
|
||||||
_, swapped, err := db.CmpAndSwap(certTable, []byte(id), nil, certB)
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error storing certificate"))
|
|
||||||
case !swapped:
|
|
||||||
return nil, ServerInternalErr(errors.New("error storing certificate; " +
|
|
||||||
"value has changed since last read"))
|
|
||||||
default:
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *certificate) toACME(db nosql.DB, dir *directory) ([]byte, error) {
|
|
||||||
return append(c.Leaf, c.Intermediates...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCert(db nosql.DB, id string) (*certificate, error) {
|
|
||||||
b, err := db.Get(certTable, []byte(id))
|
|
||||||
if nosql.IsErrNotFound(err) {
|
|
||||||
return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id))
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate"))
|
|
||||||
}
|
|
||||||
var cert certificate
|
|
||||||
if err := json.Unmarshal(b, &cert); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate"))
|
|
||||||
}
|
|
||||||
return &cert, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,253 +0,0 @@
|
||||||
package acme
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/json"
|
|
||||||
"encoding/pem"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/db"
|
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
"github.com/smallstep/nosql/database"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func defaultCertOps() (*CertOptions, error) {
|
|
||||||
crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &CertOptions{
|
|
||||||
AccountID: "accID",
|
|
||||||
OrderID: "ordID",
|
|
||||||
Leaf: crt,
|
|
||||||
Intermediates: []*x509.Certificate{inter, root},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newcert() (*certificate, error) {
|
|
||||||
ops, err := defaultCertOps()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mockdb := &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return newCert(mockdb, *ops)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewCert(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
db nosql.DB
|
|
||||||
ops CertOptions
|
|
||||||
err *Error
|
|
||||||
id *string
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
|
||||||
ops, err := defaultCertOps()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
ops: *ops,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, certTable)
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.Errorf("error storing certificate: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
|
||||||
ops, err := defaultCertOps()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
ops: *ops,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, certTable)
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
return nil, false, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
ops, err := defaultCertOps()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
var _id string
|
|
||||||
id := &_id
|
|
||||||
return test{
|
|
||||||
ops: *ops,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, certTable)
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
*id = string(key)
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
id: id,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if cert, err := newCert(tc.db, tc.ops); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, cert.ID, *tc.id)
|
|
||||||
assert.Equals(t, cert.AccountID, tc.ops.AccountID)
|
|
||||||
assert.Equals(t, cert.OrderID, tc.ops.OrderID)
|
|
||||||
|
|
||||||
leaf := pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "CERTIFICATE",
|
|
||||||
Bytes: tc.ops.Leaf.Raw,
|
|
||||||
})
|
|
||||||
var intermediates []byte
|
|
||||||
for _, cert := range tc.ops.Intermediates {
|
|
||||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "CERTIFICATE",
|
|
||||||
Bytes: cert.Raw,
|
|
||||||
})...)
|
|
||||||
}
|
|
||||||
assert.Equals(t, cert.Leaf, leaf)
|
|
||||||
assert.Equals(t, cert.Intermediates, intermediates)
|
|
||||||
|
|
||||||
assert.True(t, cert.Created.Before(time.Now().Add(time.Minute)))
|
|
||||||
assert.True(t, cert.Created.After(time.Now().Add(-time.Minute)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetCert(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
id string
|
|
||||||
db nosql.DB
|
|
||||||
cert *certificate
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/not-found": func(t *testing.T) test {
|
|
||||||
cert, err := newcert()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
cert: cert,
|
|
||||||
id: cert.ID,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
assert.Equals(t, bucket, certTable)
|
|
||||||
assert.Equals(t, key, []byte(cert.ID))
|
|
||||||
return nil, database.ErrNotFound
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/db-error": func(t *testing.T) test {
|
|
||||||
cert, err := newcert()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
cert: cert,
|
|
||||||
id: cert.ID,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
assert.Equals(t, bucket, certTable)
|
|
||||||
assert.Equals(t, key, []byte(cert.ID))
|
|
||||||
return nil, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error loading certificate: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/unmarshal-error": func(t *testing.T) test {
|
|
||||||
cert, err := newcert()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
cert: cert,
|
|
||||||
id: cert.ID,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
assert.Equals(t, bucket, certTable)
|
|
||||||
assert.Equals(t, key, []byte(cert.ID))
|
|
||||||
return nil, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
cert, err := newcert()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
b, err := json.Marshal(cert)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
cert: cert,
|
|
||||||
id: cert.ID,
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
assert.Equals(t, bucket, certTable)
|
|
||||||
assert.Equals(t, key, []byte(cert.ID))
|
|
||||||
return b, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if cert, err := getCert(tc.db, tc.id); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, tc.cert.ID, cert.ID)
|
|
||||||
assert.Equals(t, tc.cert.AccountID, cert.AccountID)
|
|
||||||
assert.Equals(t, tc.cert.OrderID, cert.OrderID)
|
|
||||||
assert.Equals(t, tc.cert.Created, cert.Created)
|
|
||||||
assert.Equals(t, tc.cert.Leaf, cert.Leaf)
|
|
||||||
assert.Equals(t, tc.cert.Intermediates, cert.Intermediates)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateToACME(t *testing.T) {
|
|
||||||
cert, err := newcert()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
acmeCert, err := cert.toACME(nil, nil)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert)
|
|
||||||
}
|
|
|
@ -14,394 +14,115 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Challenge is a subset of the challenge type containing only those attributes
|
// Challenge represents an ACME response Challenge type.
|
||||||
// required for responses in the ACME protocol.
|
|
||||||
type Challenge struct {
|
type Challenge struct {
|
||||||
Type string `json:"type"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Token string `json:"token"`
|
|
||||||
Validated string `json:"validated,omitempty"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
Error *AError `json:"error,omitempty"`
|
|
||||||
ID string `json:"-"`
|
ID string `json:"-"`
|
||||||
AuthzID string `json:"-"`
|
AccountID string `json:"-"`
|
||||||
|
AuthorizationID string `json:"-"`
|
||||||
|
Value string `json:"-"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Status Status `json:"status"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
ValidatedAt string `json:"validated,omitempty"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Error *Error `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToLog enables response logging.
|
// ToLog enables response logging.
|
||||||
func (c *Challenge) ToLog() (interface{}, error) {
|
func (ch *Challenge) ToLog() (interface{}, error) {
|
||||||
b, err := json.Marshal(c)
|
b, err := json.Marshal(ch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging"))
|
return nil, WrapErrorISE(err, "error marshaling challenge for logging")
|
||||||
}
|
}
|
||||||
return string(b), nil
|
return string(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the Challenge ID.
|
// Validate attempts to validate the challenge. Stores changes to the Challenge
|
||||||
func (c *Challenge) GetID() string {
|
// type using the DB interface.
|
||||||
return c.ID
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAuthzID returns the parent Authz ID that owns the Challenge.
|
|
||||||
func (c *Challenge) GetAuthzID() string {
|
|
||||||
return c.AuthzID
|
|
||||||
}
|
|
||||||
|
|
||||||
type httpGetter func(string) (*http.Response, error)
|
|
||||||
type lookupTxt func(string) ([]string, error)
|
|
||||||
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
|
||||||
|
|
||||||
type validateOptions struct {
|
|
||||||
httpGet httpGetter
|
|
||||||
lookupTxt lookupTxt
|
|
||||||
tlsDial tlsDialer
|
|
||||||
}
|
|
||||||
|
|
||||||
// challenge is the interface ACME challenege types must implement.
|
|
||||||
type challenge interface {
|
|
||||||
save(db nosql.DB, swap challenge) error
|
|
||||||
validate(nosql.DB, *jose.JSONWebKey, validateOptions) (challenge, error)
|
|
||||||
getType() string
|
|
||||||
getError() *AError
|
|
||||||
getValue() string
|
|
||||||
getStatus() string
|
|
||||||
getID() string
|
|
||||||
getAuthzID() string
|
|
||||||
getToken() string
|
|
||||||
clone() *baseChallenge
|
|
||||||
getAccountID() string
|
|
||||||
getValidated() time.Time
|
|
||||||
getCreated() time.Time
|
|
||||||
toACME(context.Context, nosql.DB, *directory) (*Challenge, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChallengeOptions is the type used to created a new Challenge.
|
|
||||||
type ChallengeOptions struct {
|
|
||||||
AccountID string
|
|
||||||
AuthzID string
|
|
||||||
Identifier Identifier
|
|
||||||
}
|
|
||||||
|
|
||||||
// baseChallenge is the base Challenge type that others build from.
|
|
||||||
type baseChallenge struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
AccountID string `json:"accountID"`
|
|
||||||
AuthzID string `json:"authzID"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Token string `json:"token"`
|
|
||||||
Value string `json:"value"`
|
|
||||||
Validated time.Time `json:"validated"`
|
|
||||||
Created time.Time `json:"created"`
|
|
||||||
Error *AError `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBaseChallenge(accountID, authzID string) (*baseChallenge, error) {
|
|
||||||
id, err := randID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, Wrap(err, "error generating random id for ACME challenge")
|
|
||||||
}
|
|
||||||
token, err := randID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, Wrap(err, "error generating token for ACME challenge")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &baseChallenge{
|
|
||||||
ID: id,
|
|
||||||
AccountID: accountID,
|
|
||||||
AuthzID: authzID,
|
|
||||||
Status: StatusPending,
|
|
||||||
Token: token,
|
|
||||||
Created: clock.Now(),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getID returns the id of the baseChallenge.
|
|
||||||
func (bc *baseChallenge) getID() string {
|
|
||||||
return bc.ID
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAuthzID returns the authz ID of the baseChallenge.
|
|
||||||
func (bc *baseChallenge) getAuthzID() string {
|
|
||||||
return bc.AuthzID
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAccountID returns the account id of the baseChallenge.
|
|
||||||
func (bc *baseChallenge) getAccountID() string {
|
|
||||||
return bc.AccountID
|
|
||||||
}
|
|
||||||
|
|
||||||
// getType returns the type of the baseChallenge.
|
|
||||||
func (bc *baseChallenge) getType() string {
|
|
||||||
return bc.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
// getValue returns the type of the baseChallenge.
|
|
||||||
func (bc *baseChallenge) getValue() string {
|
|
||||||
return bc.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
// getStatus returns the status of the baseChallenge.
|
|
||||||
func (bc *baseChallenge) getStatus() string {
|
|
||||||
return bc.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
// getToken returns the token of the baseChallenge.
|
|
||||||
func (bc *baseChallenge) getToken() string {
|
|
||||||
return bc.Token
|
|
||||||
}
|
|
||||||
|
|
||||||
// getValidated returns the validated time of the baseChallenge.
|
|
||||||
func (bc *baseChallenge) getValidated() time.Time {
|
|
||||||
return bc.Validated
|
|
||||||
}
|
|
||||||
|
|
||||||
// getCreated returns the created time of the baseChallenge.
|
|
||||||
func (bc *baseChallenge) getCreated() time.Time {
|
|
||||||
return bc.Created
|
|
||||||
}
|
|
||||||
|
|
||||||
// getCreated returns the created time of the baseChallenge.
|
|
||||||
func (bc *baseChallenge) getError() *AError {
|
|
||||||
return bc.Error
|
|
||||||
}
|
|
||||||
|
|
||||||
// toACME converts the internal Challenge type into the public acmeChallenge
|
|
||||||
// type for presentation in the ACME protocol.
|
|
||||||
func (bc *baseChallenge) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Challenge, error) {
|
|
||||||
ac := &Challenge{
|
|
||||||
Type: bc.getType(),
|
|
||||||
Status: bc.getStatus(),
|
|
||||||
Token: bc.getToken(),
|
|
||||||
URL: dir.getLink(ctx, ChallengeLink, true, bc.getID()),
|
|
||||||
ID: bc.getID(),
|
|
||||||
AuthzID: bc.getAuthzID(),
|
|
||||||
}
|
|
||||||
if !bc.Validated.IsZero() {
|
|
||||||
ac.Validated = bc.Validated.Format(time.RFC3339)
|
|
||||||
}
|
|
||||||
if bc.Error != nil {
|
|
||||||
ac.Error = bc.Error
|
|
||||||
}
|
|
||||||
return ac, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// save writes the challenge to disk. For new challenges 'old' should be nil,
|
|
||||||
// otherwise 'old' should be a pointer to the acme challenge as it was at the
|
|
||||||
// start of the request. This method will fail if the value currently found
|
|
||||||
// in the bucket/row does not match the value of 'old'.
|
|
||||||
func (bc *baseChallenge) save(db nosql.DB, old challenge) error {
|
|
||||||
newB, err := json.Marshal(bc)
|
|
||||||
if err != nil {
|
|
||||||
return ServerInternalErr(errors.Wrap(err,
|
|
||||||
"error marshaling new acme challenge"))
|
|
||||||
}
|
|
||||||
var oldB []byte
|
|
||||||
if old == nil {
|
|
||||||
oldB = nil
|
|
||||||
} else {
|
|
||||||
oldB, err = json.Marshal(old)
|
|
||||||
if err != nil {
|
|
||||||
return ServerInternalErr(errors.Wrap(err,
|
|
||||||
"error marshaling old acme challenge"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB)
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error saving acme challenge"))
|
|
||||||
case !swapped:
|
|
||||||
return ServerInternalErr(errors.New("error saving acme challenge; " +
|
|
||||||
"acme challenge has changed since last read"))
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bc *baseChallenge) clone() *baseChallenge {
|
|
||||||
u := *bc
|
|
||||||
return &u
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bc *baseChallenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
|
||||||
return nil, ServerInternalErr(errors.New("unimplemented"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bc *baseChallenge) storeError(db nosql.DB, err *Error) error {
|
|
||||||
clone := bc.clone()
|
|
||||||
clone.Error = err.ToACME()
|
|
||||||
if err := clone.save(db, bc); err != nil {
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "failure saving error to acme challenge"))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// unmarshalChallenge unmarshals a challenge type into the correct sub-type.
|
|
||||||
func unmarshalChallenge(data []byte) (challenge, error) {
|
|
||||||
var getType struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(data, &getType); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type"))
|
|
||||||
}
|
|
||||||
|
|
||||||
switch getType.Type {
|
|
||||||
case "dns-01":
|
|
||||||
var bc baseChallenge
|
|
||||||
if err := json.Unmarshal(data, &bc); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
|
|
||||||
"challenge type into dns01Challenge"))
|
|
||||||
}
|
|
||||||
return &dns01Challenge{&bc}, nil
|
|
||||||
case "http-01":
|
|
||||||
var bc baseChallenge
|
|
||||||
if err := json.Unmarshal(data, &bc); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
|
|
||||||
"challenge type into http01Challenge"))
|
|
||||||
}
|
|
||||||
return &http01Challenge{&bc}, nil
|
|
||||||
case "tls-alpn-01":
|
|
||||||
var bc baseChallenge
|
|
||||||
if err := json.Unmarshal(data, &bc); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
|
|
||||||
"challenge type into tlsALPN01Challenge"))
|
|
||||||
}
|
|
||||||
return &tlsALPN01Challenge{&bc}, nil
|
|
||||||
default:
|
|
||||||
return nil, ServerInternalErr(errors.Errorf("unexpected challenge type %s", getType.Type))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// http01Challenge represents an http-01 acme challenge.
|
|
||||||
type http01Challenge struct {
|
|
||||||
*baseChallenge
|
|
||||||
}
|
|
||||||
|
|
||||||
// newHTTP01Challenge returns a new acme http-01 challenge.
|
|
||||||
func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
|
|
||||||
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
bc.Type = "http-01"
|
|
||||||
bc.Value = ops.Identifier.Value
|
|
||||||
|
|
||||||
hc := &http01Challenge{bc}
|
|
||||||
if err := hc.save(db, nil); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return hc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate attempts to validate the challenge. If the challenge has been
|
|
||||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||||
// updated.
|
// updated.
|
||||||
func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||||
// If already valid or invalid then return without performing validation.
|
// If already valid or invalid then return without performing validation.
|
||||||
if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid {
|
if ch.Status != StatusPending {
|
||||||
return hc, nil
|
return nil
|
||||||
|
}
|
||||||
|
switch ch.Type {
|
||||||
|
case "http-01":
|
||||||
|
return http01Validate(ctx, ch, db, jwk, vo)
|
||||||
|
case "dns-01":
|
||||||
|
return dns01Validate(ctx, ch, db, jwk, vo)
|
||||||
|
case "tls-alpn-01":
|
||||||
|
return tlsalpn01Validate(ctx, ch, db, jwk, vo)
|
||||||
|
default:
|
||||||
|
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token)
|
|
||||||
|
|
||||||
resp, err := vo.httpGet(url)
|
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||||
|
url := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||||
|
|
||||||
|
resp, err := vo.HTTPGet(url.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err,
|
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||||
"error doing http GET for url %s", url))); err != nil {
|
"error doing http GET for url %s", url))
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return hc, nil
|
|
||||||
}
|
|
||||||
if resp.StatusCode >= 400 {
|
|
||||||
if err = hc.storeError(db,
|
|
||||||
ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d",
|
|
||||||
url, resp.StatusCode))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return hc, nil
|
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return storeError(ctx, db, ch, false, NewError(ErrorConnectionType,
|
||||||
|
"error doing http GET for url %s with status code %d", url, resp.StatusCode))
|
||||||
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+
|
return WrapErrorISE(err, "error reading "+
|
||||||
"response body for url %s", url))
|
"response body for url %s", url)
|
||||||
}
|
}
|
||||||
keyAuth := strings.Trim(string(body), "\r\n")
|
keyAuth := strings.TrimSpace(string(body))
|
||||||
|
|
||||||
expected, err := KeyAuthorization(hc.Token, jwk)
|
expected, err := KeyAuthorization(ch.Token, jwk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
if keyAuth != expected {
|
if keyAuth != expected {
|
||||||
if err = hc.storeError(db,
|
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||||
RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+
|
"keyAuthorization does not match; expected %s, but got %s", expected, keyAuth))
|
||||||
"expected %s, but got %s", expected, keyAuth))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return hc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update and store the challenge.
|
// Update and store the challenge.
|
||||||
upd := &http01Challenge{hc.baseChallenge.clone()}
|
ch.Status = StatusValid
|
||||||
upd.Status = StatusValid
|
ch.Error = nil
|
||||||
upd.Error = nil
|
ch.ValidatedAt = clock.Now().Format(time.RFC3339)
|
||||||
upd.Validated = clock.Now()
|
|
||||||
|
|
||||||
if err := upd.save(db, hc); err != nil {
|
if err = db.UpdateChallenge(ctx, ch); err != nil {
|
||||||
return nil, err
|
return WrapErrorISE(err, "error updating challenge")
|
||||||
}
|
}
|
||||||
return upd, nil
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
type tlsALPN01Challenge struct {
|
|
||||||
*baseChallenge
|
|
||||||
}
|
|
||||||
|
|
||||||
// newTLSALPN01Challenge returns a new acme tls-alpn-01 challenge.
|
|
||||||
func newTLSALPN01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
|
|
||||||
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
bc.Type = "tls-alpn-01"
|
|
||||||
bc.Value = ops.Identifier.Value
|
|
||||||
|
|
||||||
hc := &tlsALPN01Challenge{bc}
|
|
||||||
if err := hc.save(db, nil); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return hc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
|
||||||
// If already valid or invalid then return without performing validation.
|
|
||||||
if tc.getStatus() == StatusValid || tc.getStatus() == StatusInvalid {
|
|
||||||
return tc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||||
config := &tls.Config{
|
config := &tls.Config{
|
||||||
NextProtos: []string{"acme-tls/1"},
|
NextProtos: []string{"acme-tls/1"},
|
||||||
ServerName: tc.Value,
|
// https://tools.ietf.org/html/rfc8737#section-4
|
||||||
|
// ACME servers that implement "acme-tls/1" MUST only negotiate TLS 1.2
|
||||||
|
// [RFC5246] or higher when connecting to clients for validation.
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
ServerName: ch.Value,
|
||||||
InsecureSkipVerify: true, // we expect a self-signed challenge certificate
|
InsecureSkipVerify: true, // we expect a self-signed challenge certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
hostPort := net.JoinHostPort(tc.Value, "443")
|
hostPort := net.JoinHostPort(ch.Value, "443")
|
||||||
|
|
||||||
conn, err := vo.tlsDial("tcp", hostPort, config)
|
conn, err := vo.TLSDial("tcp", hostPort, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err = tc.storeError(db,
|
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||||
ConnectionErr(errors.Wrapf(err, "error doing TLS dial for %s", hostPort))); err != nil {
|
"error doing TLS dial for %s", hostPort))
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return tc, nil
|
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
@ -409,86 +130,62 @@ func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo val
|
||||||
certs := cs.PeerCertificates
|
certs := cs.PeerCertificates
|
||||||
|
|
||||||
if len(certs) == 0 {
|
if len(certs) == 0 {
|
||||||
if err = tc.storeError(db,
|
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||||
RejectedIdentifierErr(errors.Errorf("%s challenge for %s resulted in no certificates",
|
"%s challenge for %s resulted in no certificates", ch.Type, ch.Value))
|
||||||
tc.Type, tc.Value))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return tc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" {
|
if cs.NegotiatedProtocol != "acme-tls/1" {
|
||||||
if err = tc.storeError(db,
|
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||||
RejectedIdentifierErr(errors.Errorf("cannot negotiate ALPN acme-tls/1 protocol for "+
|
"cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge"))
|
||||||
"tls-alpn-01 challenge"))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return tc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
leafCert := certs[0]
|
leafCert := certs[0]
|
||||||
|
|
||||||
if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], tc.Value) {
|
if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) {
|
||||||
if err = tc.storeError(db,
|
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||||
RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
|
"incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value))
|
||||||
"leaf certificate must contain a single DNS name, %v", tc.Value))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return tc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31}
|
idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31}
|
||||||
idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
|
idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
|
||||||
foundIDPeAcmeIdentifierV1Obsolete := false
|
foundIDPeAcmeIdentifierV1Obsolete := false
|
||||||
|
|
||||||
keyAuth, err := KeyAuthorization(tc.Token, jwk)
|
keyAuth, err := KeyAuthorization(ch.Token, jwk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
hashedKeyAuth := sha256.Sum256([]byte(keyAuth))
|
hashedKeyAuth := sha256.Sum256([]byte(keyAuth))
|
||||||
|
|
||||||
for _, ext := range leafCert.Extensions {
|
for _, ext := range leafCert.Extensions {
|
||||||
if idPeAcmeIdentifier.Equal(ext.Id) {
|
if idPeAcmeIdentifier.Equal(ext.Id) {
|
||||||
if !ext.Critical {
|
if !ext.Critical {
|
||||||
if err = tc.storeError(db,
|
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||||
RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
|
"incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical"))
|
||||||
"acmeValidationV1 extension not critical"))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return tc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var extValue []byte
|
var extValue []byte
|
||||||
rest, err := asn1.Unmarshal(ext.Value, &extValue)
|
rest, err := asn1.Unmarshal(ext.Value, &extValue)
|
||||||
|
|
||||||
if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) {
|
if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) {
|
||||||
if err = tc.storeError(db,
|
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||||
RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
|
"incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value"))
|
||||||
"malformed acmeValidationV1 extension value"))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return tc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 {
|
if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 {
|
||||||
if err = tc.storeError(db,
|
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||||
RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
|
"incorrect certificate for tls-alpn-01 challenge: "+
|
||||||
"expected acmeValidationV1 extension value %s for this challenge but got %s",
|
"expected acmeValidationV1 extension value %s for this challenge but got %s",
|
||||||
hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue)))); err != nil {
|
hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue)))
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return tc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
upd := &tlsALPN01Challenge{tc.baseChallenge.clone()}
|
ch.Status = StatusValid
|
||||||
upd.Status = StatusValid
|
ch.Error = nil
|
||||||
upd.Error = nil
|
ch.ValidatedAt = clock.Now().Format(time.RFC3339)
|
||||||
upd.Validated = clock.Now()
|
|
||||||
|
|
||||||
if err := upd.save(db, tc); err != nil {
|
if err = db.UpdateChallenge(ctx, ch); err != nil {
|
||||||
return nil, err
|
return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge")
|
||||||
}
|
}
|
||||||
return upd, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) {
|
if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) {
|
||||||
|
@ -497,82 +194,30 @@ func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo val
|
||||||
}
|
}
|
||||||
|
|
||||||
if foundIDPeAcmeIdentifierV1Obsolete {
|
if foundIDPeAcmeIdentifierV1Obsolete {
|
||||||
if err = tc.storeError(db,
|
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||||
RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
|
"incorrect certificate for tls-alpn-01 challenge: obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))
|
||||||
"obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return tc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = tc.storeError(db,
|
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||||
RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
|
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
|
||||||
"missing acmeValidationV1 extension"))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return tc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// dns01Challenge represents an dns-01 acme challenge.
|
|
||||||
type dns01Challenge struct {
|
|
||||||
*baseChallenge
|
|
||||||
}
|
|
||||||
|
|
||||||
// newDNS01Challenge returns a new acme dns-01 challenge.
|
|
||||||
func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
|
|
||||||
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
bc.Type = "dns-01"
|
|
||||||
bc.Value = ops.Identifier.Value
|
|
||||||
|
|
||||||
dc := &dns01Challenge{bc}
|
|
||||||
if err := dc.save(db, nil); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return dc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeyAuthorization creates the ACME key authorization value from a token
|
|
||||||
// and a jwk.
|
|
||||||
func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
|
|
||||||
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
|
|
||||||
if err != nil {
|
|
||||||
return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint"))
|
|
||||||
}
|
|
||||||
encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
|
|
||||||
return fmt.Sprintf("%s.%s", token, encPrint), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validate attempts to validate the challenge. If the challenge has been
|
|
||||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
|
||||||
// updated.
|
|
||||||
func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
|
||||||
// If already valid or invalid then return without performing validation.
|
|
||||||
if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid {
|
|
||||||
return dc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||||
// Normalize domain for wildcard DNS names
|
// Normalize domain for wildcard DNS names
|
||||||
// This is done to avoid making TXT lookups for domains like
|
// This is done to avoid making TXT lookups for domains like
|
||||||
// _acme-challenge.*.example.com
|
// _acme-challenge.*.example.com
|
||||||
// Instead perform txt lookup for _acme-challenge.example.com
|
// Instead perform txt lookup for _acme-challenge.example.com
|
||||||
domain := strings.TrimPrefix(dc.Value, "*.")
|
domain := strings.TrimPrefix(ch.Value, "*.")
|
||||||
|
|
||||||
txtRecords, err := vo.lookupTxt("_acme-challenge." + domain)
|
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err = dc.storeError(db,
|
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
|
||||||
DNSErr(errors.Wrapf(err, "error looking up TXT "+
|
"error looking up TXT records for domain %s", domain))
|
||||||
"records for domain %s", domain))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return dc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk)
|
expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
h := sha256.Sum256([]byte(expectedKeyAuth))
|
h := sha256.Sum256([]byte(expectedKeyAuth))
|
||||||
expected := base64.RawURLEncoding.EncodeToString(h[:])
|
expected := base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
@ -584,37 +229,51 @@ func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
if err = dc.storeError(db,
|
return storeError(ctx, db, ch, false, NewError(ErrorRejectedIdentifierType,
|
||||||
RejectedIdentifierErr(errors.Errorf("keyAuthorization "+
|
"keyAuthorization does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))
|
||||||
"does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return dc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update and store the challenge.
|
// Update and store the challenge.
|
||||||
upd := &dns01Challenge{dc.baseChallenge.clone()}
|
ch.Status = StatusValid
|
||||||
upd.Status = StatusValid
|
ch.Error = nil
|
||||||
upd.Error = nil
|
ch.ValidatedAt = clock.Now().Format(time.RFC3339)
|
||||||
upd.Validated = time.Now().UTC()
|
|
||||||
|
|
||||||
if err := upd.save(db, dc); err != nil {
|
if err = db.UpdateChallenge(ctx, ch); err != nil {
|
||||||
return nil, err
|
return WrapErrorISE(err, "error updating challenge")
|
||||||
}
|
}
|
||||||
return upd, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getChallenge retrieves and unmarshals an ACME challenge type from the database.
|
// KeyAuthorization creates the ACME key authorization value from a token
|
||||||
func getChallenge(db nosql.DB, id string) (challenge, error) {
|
// and a jwk.
|
||||||
b, err := db.Get(challengeTable, []byte(id))
|
func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
|
||||||
if nosql.IsErrNotFound(err) {
|
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
|
||||||
return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id))
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id))
|
|
||||||
}
|
|
||||||
ch, err := unmarshalChallenge(b)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", WrapErrorISE(err, "error generating JWK thumbprint")
|
||||||
}
|
}
|
||||||
return ch, nil
|
encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
|
||||||
|
return fmt.Sprintf("%s.%s", token, encPrint), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeError the given error to an ACME error and saves using the DB interface.
|
||||||
|
func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err *Error) error {
|
||||||
|
ch.Error = err
|
||||||
|
if markInvalid {
|
||||||
|
ch.Status = StatusInvalid
|
||||||
|
}
|
||||||
|
if err := db.UpdateChallenge(ctx, ch); err != nil {
|
||||||
|
return WrapErrorISE(err, "failure saving error to acme challenge")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpGetter func(string) (*http.Response, error)
|
||||||
|
type lookupTxt func(string) ([]string, error)
|
||||||
|
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||||
|
|
||||||
|
// ValidateChallengeOptions are ACME challenge validator functions.
|
||||||
|
type ValidateChallengeOptions struct {
|
||||||
|
HTTPGet httpGetter
|
||||||
|
LookupTxt lookupTxt
|
||||||
|
TLSDial tlsDialer
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
143
acme/common.go
143
acme/common.go
|
@ -3,19 +3,32 @@ package acme
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"net/url"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"go.step.sm/crypto/jose"
|
|
||||||
"go.step.sm/crypto/randutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// CertificateAuthority is the interface implemented by a CA authority.
|
||||||
|
type CertificateAuthority interface {
|
||||||
|
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||||
|
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clock that returns time in UTC rounded to seconds.
|
||||||
|
type Clock struct{}
|
||||||
|
|
||||||
|
// Now returns the UTC time rounded to seconds.
|
||||||
|
func (c *Clock) Now() time.Time {
|
||||||
|
return time.Now().UTC().Truncate(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
var clock Clock
|
||||||
|
|
||||||
// Provisioner is an interface that implements a subset of the provisioner.Interface --
|
// Provisioner is an interface that implements a subset of the provisioner.Interface --
|
||||||
// only those methods required by the ACME api/authority.
|
// only those methods required by the ACME api/authority.
|
||||||
type Provisioner interface {
|
type Provisioner interface {
|
||||||
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
|
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
|
||||||
|
GetID() string
|
||||||
GetName() string
|
GetName() string
|
||||||
DefaultTLSCertDuration() time.Duration
|
DefaultTLSCertDuration() time.Duration
|
||||||
GetOptions() *provisioner.Options
|
GetOptions() *provisioner.Options
|
||||||
|
@ -25,6 +38,7 @@ type Provisioner interface {
|
||||||
type MockProvisioner struct {
|
type MockProvisioner struct {
|
||||||
Mret1 interface{}
|
Mret1 interface{}
|
||||||
Merr error
|
Merr error
|
||||||
|
MgetID func() string
|
||||||
MgetName func() string
|
MgetName func() string
|
||||||
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||||
MdefaultTLSCertDuration func() time.Duration
|
MdefaultTLSCertDuration func() time.Duration
|
||||||
|
@ -55,6 +69,7 @@ func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration {
|
||||||
return m.Mret1.(time.Duration)
|
return m.Mret1.(time.Duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOptions mock
|
||||||
func (m *MockProvisioner) GetOptions() *provisioner.Options {
|
func (m *MockProvisioner) GetOptions() *provisioner.Options {
|
||||||
if m.MgetOptions != nil {
|
if m.MgetOptions != nil {
|
||||||
return m.MgetOptions()
|
return m.MgetOptions()
|
||||||
|
@ -62,120 +77,10 @@ func (m *MockProvisioner) GetOptions() *provisioner.Options {
|
||||||
return m.Mret1.(*provisioner.Options)
|
return m.Mret1.(*provisioner.Options)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ContextKey is the key type for storing and searching for ACME request
|
// GetID mock
|
||||||
// essentials in the context of a request.
|
func (m *MockProvisioner) GetID() string {
|
||||||
type ContextKey string
|
if m.MgetID != nil {
|
||||||
|
return m.MgetID()
|
||||||
const (
|
|
||||||
// AccContextKey account key
|
|
||||||
AccContextKey = ContextKey("acc")
|
|
||||||
// BaseURLContextKey baseURL key
|
|
||||||
BaseURLContextKey = ContextKey("baseURL")
|
|
||||||
// JwsContextKey jws key
|
|
||||||
JwsContextKey = ContextKey("jws")
|
|
||||||
// JwkContextKey jwk key
|
|
||||||
JwkContextKey = ContextKey("jwk")
|
|
||||||
// PayloadContextKey payload key
|
|
||||||
PayloadContextKey = ContextKey("payload")
|
|
||||||
// ProvisionerContextKey provisioner key
|
|
||||||
ProvisionerContextKey = ContextKey("provisioner")
|
|
||||||
)
|
|
||||||
|
|
||||||
// AccountFromContext searches the context for an ACME account. Returns the
|
|
||||||
// account or an error.
|
|
||||||
func AccountFromContext(ctx context.Context) (*Account, error) {
|
|
||||||
val, ok := ctx.Value(AccContextKey).(*Account)
|
|
||||||
if !ok || val == nil {
|
|
||||||
return nil, AccountDoesNotExistErr(nil)
|
|
||||||
}
|
}
|
||||||
return val, nil
|
return m.Mret1.(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BaseURLFromContext returns the baseURL if one is stored in the context.
|
|
||||||
func BaseURLFromContext(ctx context.Context) *url.URL {
|
|
||||||
val, ok := ctx.Value(BaseURLContextKey).(*url.URL)
|
|
||||||
if !ok || val == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
// JwkFromContext searches the context for a JWK. Returns the JWK or an error.
|
|
||||||
func JwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
|
|
||||||
val, ok := ctx.Value(JwkContextKey).(*jose.JSONWebKey)
|
|
||||||
if !ok || val == nil {
|
|
||||||
return nil, ServerInternalErr(errors.Errorf("jwk expected in request context"))
|
|
||||||
}
|
|
||||||
return val, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// JwsFromContext searches the context for a JWS. Returns the JWS or an error.
|
|
||||||
func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
|
|
||||||
val, ok := ctx.Value(JwsContextKey).(*jose.JSONWebSignature)
|
|
||||||
if !ok || val == nil {
|
|
||||||
return nil, ServerInternalErr(errors.Errorf("jws expected in request context"))
|
|
||||||
}
|
|
||||||
return val, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProvisionerFromContext searches the context for a provisioner. Returns the
|
|
||||||
// provisioner or an error.
|
|
||||||
func ProvisionerFromContext(ctx context.Context) (Provisioner, error) {
|
|
||||||
val := ctx.Value(ProvisionerContextKey)
|
|
||||||
if val == nil {
|
|
||||||
return nil, ServerInternalErr(errors.Errorf("provisioner expected in request context"))
|
|
||||||
}
|
|
||||||
pval, ok := val.(Provisioner)
|
|
||||||
if !ok || pval == nil {
|
|
||||||
return nil, ServerInternalErr(errors.Errorf("provisioner in context is not an ACME provisioner"))
|
|
||||||
}
|
|
||||||
return pval, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignAuthority is the interface implemented by a CA authority.
|
|
||||||
type SignAuthority interface {
|
|
||||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
|
||||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Identifier encodes the type that an order pertains to.
|
|
||||||
type Identifier struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Value string `json:"value"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
// StatusValid -- valid
|
|
||||||
StatusValid = "valid"
|
|
||||||
// StatusInvalid -- invalid
|
|
||||||
StatusInvalid = "invalid"
|
|
||||||
// StatusPending -- pending; e.g. an Order that is not ready to be finalized.
|
|
||||||
StatusPending = "pending"
|
|
||||||
// StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
|
|
||||||
StatusDeactivated = "deactivated"
|
|
||||||
// StatusReady -- ready; e.g. for an Order that is ready to be finalized.
|
|
||||||
StatusReady = "ready"
|
|
||||||
//statusExpired = "expired"
|
|
||||||
//statusActive = "active"
|
|
||||||
//statusProcessing = "processing"
|
|
||||||
)
|
|
||||||
|
|
||||||
var idLen = 32
|
|
||||||
|
|
||||||
func randID() (val string, err error) {
|
|
||||||
val, err = randutil.Alphanumeric(idLen)
|
|
||||||
if err != nil {
|
|
||||||
return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID"))
|
|
||||||
}
|
|
||||||
return val, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clock that returns time in UTC rounded to seconds.
|
|
||||||
type Clock int
|
|
||||||
|
|
||||||
// Now returns the UTC time rounded to seconds.
|
|
||||||
func (c *Clock) Now() time.Time {
|
|
||||||
return time.Now().UTC().Round(time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
var clock = new(Clock)
|
|
||||||
|
|
251
acme/db.go
Normal file
251
acme/db.go
Normal file
|
@ -0,0 +1,251 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrNotFound is an error that should be used by the acme.DB interface to
|
||||||
|
// indicate that an entity does not exist. For example, in the new-account
|
||||||
|
// endpoint, if GetAccountByKeyID returns ErrNotFound we will create the new
|
||||||
|
// account.
|
||||||
|
var ErrNotFound = errors.New("not found")
|
||||||
|
|
||||||
|
// DB is the DB interface expected by the step-ca ACME API.
|
||||||
|
type DB interface {
|
||||||
|
CreateAccount(ctx context.Context, acc *Account) error
|
||||||
|
GetAccount(ctx context.Context, id string) (*Account, error)
|
||||||
|
GetAccountByKeyID(ctx context.Context, kid string) (*Account, error)
|
||||||
|
UpdateAccount(ctx context.Context, acc *Account) error
|
||||||
|
|
||||||
|
CreateNonce(ctx context.Context) (Nonce, error)
|
||||||
|
DeleteNonce(ctx context.Context, nonce Nonce) error
|
||||||
|
|
||||||
|
CreateAuthorization(ctx context.Context, az *Authorization) error
|
||||||
|
GetAuthorization(ctx context.Context, id string) (*Authorization, error)
|
||||||
|
UpdateAuthorization(ctx context.Context, az *Authorization) error
|
||||||
|
|
||||||
|
CreateCertificate(ctx context.Context, cert *Certificate) error
|
||||||
|
GetCertificate(ctx context.Context, id string) (*Certificate, error)
|
||||||
|
|
||||||
|
CreateChallenge(ctx context.Context, ch *Challenge) error
|
||||||
|
GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error)
|
||||||
|
UpdateChallenge(ctx context.Context, ch *Challenge) error
|
||||||
|
|
||||||
|
CreateOrder(ctx context.Context, o *Order) error
|
||||||
|
GetOrder(ctx context.Context, id string) (*Order, error)
|
||||||
|
GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error)
|
||||||
|
UpdateOrder(ctx context.Context, o *Order) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockDB is an implementation of the DB interface that should only be used as
|
||||||
|
// a mock in tests.
|
||||||
|
type MockDB struct {
|
||||||
|
MockCreateAccount func(ctx context.Context, acc *Account) error
|
||||||
|
MockGetAccount func(ctx context.Context, id string) (*Account, error)
|
||||||
|
MockGetAccountByKeyID func(ctx context.Context, kid string) (*Account, error)
|
||||||
|
MockUpdateAccount func(ctx context.Context, acc *Account) error
|
||||||
|
|
||||||
|
MockCreateNonce func(ctx context.Context) (Nonce, error)
|
||||||
|
MockDeleteNonce func(ctx context.Context, nonce Nonce) error
|
||||||
|
|
||||||
|
MockCreateAuthorization func(ctx context.Context, az *Authorization) error
|
||||||
|
MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error)
|
||||||
|
MockUpdateAuthorization func(ctx context.Context, az *Authorization) error
|
||||||
|
|
||||||
|
MockCreateCertificate func(ctx context.Context, cert *Certificate) error
|
||||||
|
MockGetCertificate func(ctx context.Context, id string) (*Certificate, error)
|
||||||
|
|
||||||
|
MockCreateChallenge func(ctx context.Context, ch *Challenge) error
|
||||||
|
MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error)
|
||||||
|
MockUpdateChallenge func(ctx context.Context, ch *Challenge) error
|
||||||
|
|
||||||
|
MockCreateOrder func(ctx context.Context, o *Order) error
|
||||||
|
MockGetOrder func(ctx context.Context, id string) (*Order, error)
|
||||||
|
MockGetOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error)
|
||||||
|
MockUpdateOrder func(ctx context.Context, o *Order) error
|
||||||
|
|
||||||
|
MockRet1 interface{}
|
||||||
|
MockError error
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAccount mock.
|
||||||
|
func (m *MockDB) CreateAccount(ctx context.Context, acc *Account) error {
|
||||||
|
if m.MockCreateAccount != nil {
|
||||||
|
return m.MockCreateAccount(ctx, acc)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount mock.
|
||||||
|
func (m *MockDB) GetAccount(ctx context.Context, id string) (*Account, error) {
|
||||||
|
if m.MockGetAccount != nil {
|
||||||
|
return m.MockGetAccount(ctx, id)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return nil, m.MockError
|
||||||
|
}
|
||||||
|
return m.MockRet1.(*Account), m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByKeyID mock
|
||||||
|
func (m *MockDB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) {
|
||||||
|
if m.MockGetAccountByKeyID != nil {
|
||||||
|
return m.MockGetAccountByKeyID(ctx, kid)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return nil, m.MockError
|
||||||
|
}
|
||||||
|
return m.MockRet1.(*Account), m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccount mock
|
||||||
|
func (m *MockDB) UpdateAccount(ctx context.Context, acc *Account) error {
|
||||||
|
if m.MockUpdateAccount != nil {
|
||||||
|
return m.MockUpdateAccount(ctx, acc)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateNonce mock
|
||||||
|
func (m *MockDB) CreateNonce(ctx context.Context) (Nonce, error) {
|
||||||
|
if m.MockCreateNonce != nil {
|
||||||
|
return m.MockCreateNonce(ctx)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return Nonce(""), m.MockError
|
||||||
|
}
|
||||||
|
return m.MockRet1.(Nonce), m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteNonce mock
|
||||||
|
func (m *MockDB) DeleteNonce(ctx context.Context, nonce Nonce) error {
|
||||||
|
if m.MockDeleteNonce != nil {
|
||||||
|
return m.MockDeleteNonce(ctx, nonce)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAuthorization mock
|
||||||
|
func (m *MockDB) CreateAuthorization(ctx context.Context, az *Authorization) error {
|
||||||
|
if m.MockCreateAuthorization != nil {
|
||||||
|
return m.MockCreateAuthorization(ctx, az)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthorization mock
|
||||||
|
func (m *MockDB) GetAuthorization(ctx context.Context, id string) (*Authorization, error) {
|
||||||
|
if m.MockGetAuthorization != nil {
|
||||||
|
return m.MockGetAuthorization(ctx, id)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return nil, m.MockError
|
||||||
|
}
|
||||||
|
return m.MockRet1.(*Authorization), m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAuthorization mock
|
||||||
|
func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) error {
|
||||||
|
if m.MockUpdateAuthorization != nil {
|
||||||
|
return m.MockUpdateAuthorization(ctx, az)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateCertificate mock
|
||||||
|
func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error {
|
||||||
|
if m.MockCreateCertificate != nil {
|
||||||
|
return m.MockCreateCertificate(ctx, cert)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCertificate mock
|
||||||
|
func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, error) {
|
||||||
|
if m.MockGetCertificate != nil {
|
||||||
|
return m.MockGetCertificate(ctx, id)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return nil, m.MockError
|
||||||
|
}
|
||||||
|
return m.MockRet1.(*Certificate), m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateChallenge mock
|
||||||
|
func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error {
|
||||||
|
if m.MockCreateChallenge != nil {
|
||||||
|
return m.MockCreateChallenge(ctx, ch)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChallenge mock
|
||||||
|
func (m *MockDB) GetChallenge(ctx context.Context, chID, azID string) (*Challenge, error) {
|
||||||
|
if m.MockGetChallenge != nil {
|
||||||
|
return m.MockGetChallenge(ctx, chID, azID)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return nil, m.MockError
|
||||||
|
}
|
||||||
|
return m.MockRet1.(*Challenge), m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateChallenge mock
|
||||||
|
func (m *MockDB) UpdateChallenge(ctx context.Context, ch *Challenge) error {
|
||||||
|
if m.MockUpdateChallenge != nil {
|
||||||
|
return m.MockUpdateChallenge(ctx, ch)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOrder mock
|
||||||
|
func (m *MockDB) CreateOrder(ctx context.Context, o *Order) error {
|
||||||
|
if m.MockCreateOrder != nil {
|
||||||
|
return m.MockCreateOrder(ctx, o)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrder mock
|
||||||
|
func (m *MockDB) GetOrder(ctx context.Context, id string) (*Order, error) {
|
||||||
|
if m.MockGetOrder != nil {
|
||||||
|
return m.MockGetOrder(ctx, id)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return nil, m.MockError
|
||||||
|
}
|
||||||
|
return m.MockRet1.(*Order), m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOrder mock
|
||||||
|
func (m *MockDB) UpdateOrder(ctx context.Context, o *Order) error {
|
||||||
|
if m.MockUpdateOrder != nil {
|
||||||
|
return m.MockUpdateOrder(ctx, o)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
return m.MockError
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrdersByAccountID mock
|
||||||
|
func (m *MockDB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
|
||||||
|
if m.MockGetOrdersByAccountID != nil {
|
||||||
|
return m.MockGetOrdersByAccountID(ctx, accID)
|
||||||
|
} else if m.MockError != nil {
|
||||||
|
return nil, m.MockError
|
||||||
|
}
|
||||||
|
return m.MockRet1.([]string), m.MockError
|
||||||
|
}
|
136
acme/db/nosql/account.go
Normal file
136
acme/db/nosql/account.go
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
nosqlDB "github.com/smallstep/nosql"
|
||||||
|
"go.step.sm/crypto/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dbAccount represents an ACME account.
|
||||||
|
type dbAccount struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Key *jose.JSONWebKey `json:"key"`
|
||||||
|
Contact []string `json:"contact,omitempty"`
|
||||||
|
Status acme.Status `json:"status"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
DeactivatedAt time.Time `json:"deactivatedAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dba *dbAccount) clone() *dbAccount {
|
||||||
|
nu := *dba
|
||||||
|
return &nu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) {
|
||||||
|
id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
|
||||||
|
if err != nil {
|
||||||
|
if nosqlDB.IsErrNotFound(err) {
|
||||||
|
return "", acme.ErrNotFound
|
||||||
|
}
|
||||||
|
return "", errors.Wrapf(err, "error loading key-account index for key %s", kid)
|
||||||
|
}
|
||||||
|
return string(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDBAccount retrieves and unmarshals dbAccount.
|
||||||
|
func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) {
|
||||||
|
data, err := db.db.Get(accountTable, []byte(id))
|
||||||
|
if err != nil {
|
||||||
|
if nosqlDB.IsErrNotFound(err) {
|
||||||
|
return nil, acme.ErrNotFound
|
||||||
|
}
|
||||||
|
return nil, errors.Wrapf(err, "error loading account %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbacc := new(dbAccount)
|
||||||
|
if err = json.Unmarshal(data, dbacc); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error unmarshaling account %s into dbAccount", id)
|
||||||
|
}
|
||||||
|
return dbacc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount retrieves an ACME account by ID.
|
||||||
|
func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
|
||||||
|
dbacc, err := db.getDBAccount(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &acme.Account{
|
||||||
|
Status: dbacc.Status,
|
||||||
|
Contact: dbacc.Contact,
|
||||||
|
Key: dbacc.Key,
|
||||||
|
ID: dbacc.ID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK).
|
||||||
|
func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) {
|
||||||
|
id, err := db.getAccountIDByKeyID(ctx, kid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return db.GetAccount(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAccount imlements the AcmeDB.CreateAccount interface.
|
||||||
|
func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
|
||||||
|
var err error
|
||||||
|
acc.ID, err = randID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dba := &dbAccount{
|
||||||
|
ID: acc.ID,
|
||||||
|
Key: acc.Key,
|
||||||
|
Contact: acc.Contact,
|
||||||
|
Status: acc.Status,
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
kid, err := acme.KeyToID(dba.Key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
kidB := []byte(kid)
|
||||||
|
|
||||||
|
// Set the jwkID -> acme account ID index
|
||||||
|
_, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(acc.ID))
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return errors.Wrap(err, "error storing keyID to accountID index")
|
||||||
|
case !swapped:
|
||||||
|
return errors.Errorf("key-id to account-id index already exists")
|
||||||
|
default:
|
||||||
|
if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil {
|
||||||
|
db.db.Del(accountByKeyIDTable, kidB)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccount imlements the AcmeDB.UpdateAccount interface.
|
||||||
|
func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error {
|
||||||
|
old, err := db.getDBAccount(ctx, acc.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
nu := old.clone()
|
||||||
|
nu.Contact = acc.Contact
|
||||||
|
nu.Status = acc.Status
|
||||||
|
|
||||||
|
// If the status has changed to 'deactivated', then set deactivatedAt timestamp.
|
||||||
|
if acc.Status == acme.StatusDeactivated && old.Status != acme.StatusDeactivated {
|
||||||
|
nu.DeactivatedAt = clock.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.save(ctx, old.ID, nu, old, "account", accountTable)
|
||||||
|
}
|
706
acme/db/nosql/account_test.go
Normal file
706
acme/db/nosql/account_test.go
Normal file
|
@ -0,0 +1,706 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
nosqldb "github.com/smallstep/nosql/database"
|
||||||
|
"go.step.sm/crypto/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDB_getDBAccount(t *testing.T) {
|
||||||
|
accID := "accID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
dbacc *dbAccount
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: acme.ErrNotFound,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading account accID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return []byte("foo"), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error unmarshaling account accID into dbAccount"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
dbacc := &dbAccount{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
CreatedAt: now,
|
||||||
|
DeactivatedAt: now,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbacc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dbacc: dbacc,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, dbacc.ID, tc.dbacc.ID)
|
||||||
|
assert.Equals(t, dbacc.Status, tc.dbacc.Status)
|
||||||
|
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
|
||||||
|
assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
|
||||||
|
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
|
||||||
|
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_getAccountIDByKeyID(t *testing.T) {
|
||||||
|
accID := "accID"
|
||||||
|
kid := "kid"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: acme.ErrNotFound,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading key-account index for key kid: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
|
||||||
|
return []byte(accID), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, retAccID, accID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_GetAccount(t *testing.T) {
|
||||||
|
accID := "accID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
dbacc *dbAccount
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading account accID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
dbacc := &dbAccount{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
CreatedAt: now,
|
||||||
|
DeactivatedAt: now,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbacc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dbacc: dbacc,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if acc, err := db.GetAccount(context.Background(), accID); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||||
|
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||||
|
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||||
|
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||||
|
accID := "accID"
|
||||||
|
kid := "kid"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
dbacc *dbAccount
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.getAccountIDByKeyID-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, string(bucket), string(accountByKeyIDTable))
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading key-account index for key kid: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.GetAccount-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(accountByKeyIDTable):
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
return []byte(accID), nil
|
||||||
|
case string(accountTable):
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
return nil, errors.New("force")
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading account accID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
dbacc := &dbAccount{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
CreatedAt: now,
|
||||||
|
DeactivatedAt: now,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbacc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(accountByKeyIDTable):
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
return []byte(accID), nil
|
||||||
|
case string(accountTable):
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
return b, nil
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dbacc: dbacc,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||||
|
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||||
|
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||||
|
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_CreateAccount(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
acc *acme.Account
|
||||||
|
err error
|
||||||
|
_id *string
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/keyID-cmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, string(key), jwk.KeyID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
assert.Equals(t, nu, []byte(acc.ID))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acc: acc,
|
||||||
|
err: errors.New("error storing keyID to accountID index: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/keyID-cmpAndSwap-false": func(t *testing.T) test {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, string(key), jwk.KeyID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
assert.Equals(t, nu, []byte(acc.ID))
|
||||||
|
return nil, false, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acc: acc,
|
||||||
|
err: errors.New("key-id to account-id index already exists"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/account-save-error": func(t *testing.T) test {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(accountByKeyIDTable):
|
||||||
|
assert.Equals(t, string(key), jwk.KeyID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
return nu, true, nil
|
||||||
|
case string(accountTable):
|
||||||
|
assert.Equals(t, string(key), acc.ID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbacc := new(dbAccount)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||||
|
assert.Equals(t, dbacc.ID, string(key))
|
||||||
|
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||||
|
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||||
|
assert.True(t, dbacc.DeactivatedAt.IsZero())
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acc: acc,
|
||||||
|
err: errors.New("error saving acme account: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
var (
|
||||||
|
id string
|
||||||
|
idPtr = &id
|
||||||
|
)
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
id = string(key)
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(accountByKeyIDTable):
|
||||||
|
assert.Equals(t, string(key), jwk.KeyID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
return nu, true, nil
|
||||||
|
case string(accountTable):
|
||||||
|
assert.Equals(t, string(key), acc.ID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbacc := new(dbAccount)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||||
|
assert.Equals(t, dbacc.ID, string(key))
|
||||||
|
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||||
|
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||||
|
assert.True(t, dbacc.DeactivatedAt.IsZero())
|
||||||
|
return nu, true, nil
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acc: acc,
|
||||||
|
_id: idPtr,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if err := db.CreateAccount(context.Background(), tc.acc); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.acc.ID, *tc._id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_UpdateAccount(t *testing.T) {
|
||||||
|
accID := "accID"
|
||||||
|
now := clock.Now()
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
dbacc := &dbAccount{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
CreatedAt: now,
|
||||||
|
DeactivatedAt: now,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbacc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
acc *acme.Account
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
acc: &acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
},
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading account accID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/already-deactivated": func(t *testing.T) test {
|
||||||
|
clone := dbacc.clone()
|
||||||
|
clone.Status = acme.StatusDeactivated
|
||||||
|
clone.DeactivatedAt = now
|
||||||
|
dbaccb, err := json.Marshal(clone)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return dbaccb, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, old, b)
|
||||||
|
|
||||||
|
dbNew := new(dbAccount)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
|
assert.Equals(t, dbNew.ID, clone.ID)
|
||||||
|
assert.Equals(t, dbNew.Status, clone.Status)
|
||||||
|
assert.Equals(t, dbNew.Contact, clone.Contact)
|
||||||
|
assert.Equals(t, dbNew.Key.KeyID, clone.Key.KeyID)
|
||||||
|
assert.Equals(t, dbNew.CreatedAt, clone.CreatedAt)
|
||||||
|
assert.Equals(t, dbNew.DeactivatedAt, clone.DeactivatedAt)
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error saving acme account: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, old, b)
|
||||||
|
|
||||||
|
dbNew := new(dbAccount)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
|
assert.Equals(t, dbNew.ID, dbacc.ID)
|
||||||
|
assert.Equals(t, dbNew.Status, acc.Status)
|
||||||
|
assert.Equals(t, dbNew.Contact, dbacc.Contact)
|
||||||
|
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
||||||
|
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
||||||
|
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
||||||
|
assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error saving acme account: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, old, b)
|
||||||
|
|
||||||
|
dbNew := new(dbAccount)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
|
assert.Equals(t, dbNew.ID, dbacc.ID)
|
||||||
|
assert.Equals(t, dbNew.Status, acc.Status)
|
||||||
|
assert.Equals(t, dbNew.Contact, dbacc.Contact)
|
||||||
|
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
||||||
|
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
||||||
|
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
||||||
|
assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now))
|
||||||
|
return nu, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if err := db.UpdateAccount(context.Background(), tc.acc); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.acc.ID, dbacc.ID)
|
||||||
|
assert.Equals(t, tc.acc.Status, dbacc.Status)
|
||||||
|
assert.Equals(t, tc.acc.Contact, dbacc.Contact)
|
||||||
|
assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
118
acme/db/nosql/authz.go
Normal file
118
acme/db/nosql/authz.go
Normal file
|
@ -0,0 +1,118 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dbAuthz is the base authz type that others build from.
|
||||||
|
type dbAuthz struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
AccountID string `json:"accountID"`
|
||||||
|
Identifier acme.Identifier `json:"identifier"`
|
||||||
|
Status acme.Status `json:"status"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
ChallengeIDs []string `json:"challengeIDs"`
|
||||||
|
Wildcard bool `json:"wildcard"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
ExpiresAt time.Time `json:"expiresAt"`
|
||||||
|
Error *acme.Error `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ba *dbAuthz) clone() *dbAuthz {
|
||||||
|
u := *ba
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDBAuthz retrieves and unmarshals a database representation of the
|
||||||
|
// ACME Authorization type.
|
||||||
|
func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) {
|
||||||
|
data, err := db.db.Get(authzTable, []byte(id))
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, acme.NewError(acme.ErrorMalformedType, "authz %s not found", id)
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error loading authz %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dbaz dbAuthz
|
||||||
|
if err = json.Unmarshal(data, &dbaz); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error unmarshaling authz %s into dbAuthz", id)
|
||||||
|
}
|
||||||
|
return &dbaz, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthorization retrieves and unmarshals an ACME authz type from the database.
|
||||||
|
// Implements acme.DB GetAuthorization interface.
|
||||||
|
func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorization, error) {
|
||||||
|
dbaz, err := db.getDBAuthz(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var chs = make([]*acme.Challenge, len(dbaz.ChallengeIDs))
|
||||||
|
for i, chID := range dbaz.ChallengeIDs {
|
||||||
|
chs[i], err = db.GetChallenge(ctx, chID, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &acme.Authorization{
|
||||||
|
ID: dbaz.ID,
|
||||||
|
AccountID: dbaz.AccountID,
|
||||||
|
Identifier: dbaz.Identifier,
|
||||||
|
Status: dbaz.Status,
|
||||||
|
Challenges: chs,
|
||||||
|
Wildcard: dbaz.Wildcard,
|
||||||
|
ExpiresAt: dbaz.ExpiresAt,
|
||||||
|
Token: dbaz.Token,
|
||||||
|
Error: dbaz.Error,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAuthorization creates an entry in the database for the Authorization.
|
||||||
|
// Implements the acme.DB.CreateAuthorization interface.
|
||||||
|
func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||||
|
var err error
|
||||||
|
az.ID, err = randID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
chIDs := make([]string, len(az.Challenges))
|
||||||
|
for i, ch := range az.Challenges {
|
||||||
|
chIDs[i] = ch.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
now := clock.Now()
|
||||||
|
dbaz := &dbAuthz{
|
||||||
|
ID: az.ID,
|
||||||
|
AccountID: az.AccountID,
|
||||||
|
Status: az.Status,
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: az.ExpiresAt,
|
||||||
|
Identifier: az.Identifier,
|
||||||
|
ChallengeIDs: chIDs,
|
||||||
|
Token: az.Token,
|
||||||
|
Wildcard: az.Wildcard,
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.save(ctx, az.ID, dbaz, nil, "authz", authzTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAuthorization saves an updated ACME Authorization to the database.
|
||||||
|
func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||||
|
old, err := db.getDBAuthz(ctx, az.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
nu := old.clone()
|
||||||
|
|
||||||
|
nu.Status = az.Status
|
||||||
|
nu.Error = az.Error
|
||||||
|
return db.save(ctx, old.ID, nu, old, "authz", authzTable)
|
||||||
|
}
|
620
acme/db/nosql/authz_test.go
Normal file
620
acme/db/nosql/authz_test.go
Normal file
|
@ -0,0 +1,620 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
nosqldb "github.com/smallstep/nosql/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDB_getDBAuthz(t *testing.T) {
|
||||||
|
azID := "azID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
dbaz *dbAuthz
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading authz azID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
|
||||||
|
return []byte("foo"), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error unmarshaling authz azID into dbAuthz"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
dbaz := &dbAuthz{
|
||||||
|
ID: azID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
Identifier: acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
},
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
Error: acme.NewErrorISE("force"),
|
||||||
|
ChallengeIDs: []string{"foo", "bar"},
|
||||||
|
Wildcard: true,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbaz)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dbaz: dbaz,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, dbaz.ID, tc.dbaz.ID)
|
||||||
|
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
|
||||||
|
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
|
||||||
|
assert.Equals(t, dbaz.Status, tc.dbaz.Status)
|
||||||
|
assert.Equals(t, dbaz.Token, tc.dbaz.Token)
|
||||||
|
assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt)
|
||||||
|
assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||||
|
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
|
||||||
|
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_GetAuthorization(t *testing.T) {
|
||||||
|
azID := "azID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
dbaz *dbAuthz
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading authz azID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/forward-acme-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.GetChallenge-error": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
dbaz := &dbAuthz{
|
||||||
|
ID: azID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
Identifier: acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
},
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
Error: acme.NewErrorISE("force"),
|
||||||
|
ChallengeIDs: []string{"foo", "bar"},
|
||||||
|
Wildcard: true,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbaz)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(authzTable):
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
return b, nil
|
||||||
|
case string(challengeTable):
|
||||||
|
assert.Equals(t, string(key), "foo")
|
||||||
|
return nil, errors.New("force")
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading acme challenge foo: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.GetChallenge-not-found": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
dbaz := &dbAuthz{
|
||||||
|
ID: azID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
Identifier: acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
},
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
Error: acme.NewErrorISE("force"),
|
||||||
|
ChallengeIDs: []string{"foo", "bar"},
|
||||||
|
Wildcard: true,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbaz)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(authzTable):
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
return b, nil
|
||||||
|
case string(challengeTable):
|
||||||
|
assert.Equals(t, string(key), "foo")
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge foo not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
dbaz := &dbAuthz{
|
||||||
|
ID: azID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
Identifier: acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
},
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
Error: acme.NewErrorISE("force"),
|
||||||
|
ChallengeIDs: []string{"foo", "bar"},
|
||||||
|
Wildcard: true,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbaz)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
chCount := 0
|
||||||
|
fooChb, err := json.Marshal(&dbChallenge{ID: "foo"})
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
barChb, err := json.Marshal(&dbChallenge{ID: "bar"})
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(authzTable):
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
return b, nil
|
||||||
|
case string(challengeTable):
|
||||||
|
if chCount == 0 {
|
||||||
|
chCount++
|
||||||
|
assert.Equals(t, string(key), "foo")
|
||||||
|
return fooChb, nil
|
||||||
|
}
|
||||||
|
assert.Equals(t, string(key), "bar")
|
||||||
|
return barChb, nil
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dbaz: dbaz,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if az, err := db.GetAuthorization(context.Background(), azID); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, az.ID, tc.dbaz.ID)
|
||||||
|
assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
|
||||||
|
assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
|
||||||
|
assert.Equals(t, az.Status, tc.dbaz.Status)
|
||||||
|
assert.Equals(t, az.Token, tc.dbaz.Token)
|
||||||
|
assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard)
|
||||||
|
assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||||
|
assert.Equals(t, az.Challenges, []*acme.Challenge{
|
||||||
|
{ID: "foo"},
|
||||||
|
{ID: "bar"},
|
||||||
|
})
|
||||||
|
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_CreateAuthorization(t *testing.T) {
|
||||||
|
azID := "azID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
az *acme.Authorization
|
||||||
|
err error
|
||||||
|
_id *string
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
az := &acme.Authorization{
|
||||||
|
ID: azID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
Identifier: acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
},
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
Challenges: []*acme.Challenge{
|
||||||
|
{ID: "foo"},
|
||||||
|
{ID: "bar"},
|
||||||
|
},
|
||||||
|
Wildcard: true,
|
||||||
|
Error: acme.NewErrorISE("force"),
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), az.ID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbaz := new(dbAuthz)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbaz))
|
||||||
|
assert.Equals(t, dbaz.ID, string(key))
|
||||||
|
assert.Equals(t, dbaz.AccountID, az.AccountID)
|
||||||
|
assert.Equals(t, dbaz.Identifier, acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
})
|
||||||
|
assert.Equals(t, dbaz.Status, az.Status)
|
||||||
|
assert.Equals(t, dbaz.Token, az.Token)
|
||||||
|
assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"})
|
||||||
|
assert.Equals(t, dbaz.Wildcard, az.Wildcard)
|
||||||
|
assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt)
|
||||||
|
assert.Nil(t, dbaz.Error)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
az: az,
|
||||||
|
err: errors.New("error saving acme authz: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
var (
|
||||||
|
id string
|
||||||
|
idPtr = &id
|
||||||
|
now = clock.Now()
|
||||||
|
az = &acme.Authorization{
|
||||||
|
ID: azID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
Identifier: acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
},
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
Challenges: []*acme.Challenge{
|
||||||
|
{ID: "foo"},
|
||||||
|
{ID: "bar"},
|
||||||
|
},
|
||||||
|
Wildcard: true,
|
||||||
|
Error: acme.NewErrorISE("force"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
*idPtr = string(key)
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), az.ID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbaz := new(dbAuthz)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbaz))
|
||||||
|
assert.Equals(t, dbaz.ID, string(key))
|
||||||
|
assert.Equals(t, dbaz.AccountID, az.AccountID)
|
||||||
|
assert.Equals(t, dbaz.Identifier, acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
})
|
||||||
|
assert.Equals(t, dbaz.Status, az.Status)
|
||||||
|
assert.Equals(t, dbaz.Token, az.Token)
|
||||||
|
assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"})
|
||||||
|
assert.Equals(t, dbaz.Wildcard, az.Wildcard)
|
||||||
|
assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt)
|
||||||
|
assert.Nil(t, dbaz.Error)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt))
|
||||||
|
return nu, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
az: az,
|
||||||
|
_id: idPtr,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if err := db.CreateAuthorization(context.Background(), tc.az); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.az.ID, *tc._id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_UpdateAuthorization(t *testing.T) {
|
||||||
|
azID := "azID"
|
||||||
|
now := clock.Now()
|
||||||
|
dbaz := &dbAuthz{
|
||||||
|
ID: azID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
Identifier: acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
},
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(5 * time.Minute),
|
||||||
|
ChallengeIDs: []string{"foo", "bar"},
|
||||||
|
Wildcard: true,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbaz)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
az *acme.Authorization
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
az: &acme.Authorization{
|
||||||
|
ID: azID,
|
||||||
|
},
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading authz azID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
updAz := &acme.Authorization{
|
||||||
|
ID: azID,
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
az: updAz,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, old, b)
|
||||||
|
|
||||||
|
dbOld := new(dbAuthz)
|
||||||
|
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||||
|
assert.Equals(t, dbaz, dbOld)
|
||||||
|
|
||||||
|
dbNew := new(dbAuthz)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
|
assert.Equals(t, dbNew.ID, dbaz.ID)
|
||||||
|
assert.Equals(t, dbNew.AccountID, dbaz.AccountID)
|
||||||
|
assert.Equals(t, dbNew.Identifier, dbaz.Identifier)
|
||||||
|
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
||||||
|
assert.Equals(t, dbNew.Token, dbaz.Token)
|
||||||
|
assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs)
|
||||||
|
assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
|
||||||
|
assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
|
||||||
|
assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
|
||||||
|
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error saving acme authz: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
updAz := &acme.Authorization{
|
||||||
|
ID: azID,
|
||||||
|
AccountID: dbaz.AccountID,
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Identifier: dbaz.Identifier,
|
||||||
|
Challenges: []*acme.Challenge{
|
||||||
|
{ID: "foo"},
|
||||||
|
{ID: "bar"},
|
||||||
|
},
|
||||||
|
Token: dbaz.Token,
|
||||||
|
Wildcard: dbaz.Wildcard,
|
||||||
|
ExpiresAt: dbaz.ExpiresAt,
|
||||||
|
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
az: updAz,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, string(key), azID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, old, b)
|
||||||
|
|
||||||
|
dbOld := new(dbAuthz)
|
||||||
|
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||||
|
assert.Equals(t, dbaz, dbOld)
|
||||||
|
|
||||||
|
dbNew := new(dbAuthz)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
|
assert.Equals(t, dbNew.ID, dbaz.ID)
|
||||||
|
assert.Equals(t, dbNew.AccountID, dbaz.AccountID)
|
||||||
|
assert.Equals(t, dbNew.Identifier, dbaz.Identifier)
|
||||||
|
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
||||||
|
assert.Equals(t, dbNew.Token, dbaz.Token)
|
||||||
|
assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs)
|
||||||
|
assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
|
||||||
|
assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
|
||||||
|
assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
|
||||||
|
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||||
|
return nu, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.az.ID, dbaz.ID)
|
||||||
|
assert.Equals(t, tc.az.AccountID, dbaz.AccountID)
|
||||||
|
assert.Equals(t, tc.az.Identifier, dbaz.Identifier)
|
||||||
|
assert.Equals(t, tc.az.Status, acme.StatusValid)
|
||||||
|
assert.Equals(t, tc.az.Wildcard, dbaz.Wildcard)
|
||||||
|
assert.Equals(t, tc.az.Token, dbaz.Token)
|
||||||
|
assert.Equals(t, tc.az.ExpiresAt, dbaz.ExpiresAt)
|
||||||
|
assert.Equals(t, tc.az.Challenges, []*acme.Challenge{
|
||||||
|
{ID: "foo"},
|
||||||
|
{ID: "bar"},
|
||||||
|
})
|
||||||
|
assert.Equals(t, tc.az.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
109
acme/db/nosql/certificate.go
Normal file
109
acme/db/nosql/certificate.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dbCert struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
AccountID string `json:"accountID"`
|
||||||
|
OrderID string `json:"orderID"`
|
||||||
|
Leaf []byte `json:"leaf"`
|
||||||
|
Intermediates []byte `json:"intermediates"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateCertificate creates and stores an ACME certificate type.
|
||||||
|
func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error {
|
||||||
|
var err error
|
||||||
|
cert.ID, err = randID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
leaf := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: cert.Leaf.Raw,
|
||||||
|
})
|
||||||
|
var intermediates []byte
|
||||||
|
for _, cert := range cert.Intermediates {
|
||||||
|
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: cert.Raw,
|
||||||
|
})...)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbch := &dbCert{
|
||||||
|
ID: cert.ID,
|
||||||
|
AccountID: cert.AccountID,
|
||||||
|
OrderID: cert.OrderID,
|
||||||
|
Leaf: leaf,
|
||||||
|
Intermediates: intermediates,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCertificate retrieves and unmarshals an ACME certificate type from the
|
||||||
|
// datastore.
|
||||||
|
func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) {
|
||||||
|
b, err := db.db.Get(certTable, []byte(id))
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id)
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error loading certificate %s", id)
|
||||||
|
}
|
||||||
|
dbC := new(dbCert)
|
||||||
|
if err := json.Unmarshal(b, dbC); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error unmarshaling certificate %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
certs, err := parseBundle(append(dbC.Leaf, dbC.Intermediates...))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error parsing certificate chain for ACME certificate with ID %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &acme.Certificate{
|
||||||
|
ID: dbC.ID,
|
||||||
|
AccountID: dbC.AccountID,
|
||||||
|
OrderID: dbC.OrderID,
|
||||||
|
Leaf: certs[0],
|
||||||
|
Intermediates: certs[1:],
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBundle(b []byte) ([]*x509.Certificate, error) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
block *pem.Block
|
||||||
|
bundle []*x509.Certificate
|
||||||
|
)
|
||||||
|
for len(b) > 0 {
|
||||||
|
block, b = pem.Decode(b)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if block.Type != "CERTIFICATE" {
|
||||||
|
return nil, errors.New("error decoding PEM: data contains block that is not a certificate")
|
||||||
|
}
|
||||||
|
var crt *x509.Certificate
|
||||||
|
crt, err = x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error parsing x509 certificate")
|
||||||
|
}
|
||||||
|
bundle = append(bundle, crt)
|
||||||
|
}
|
||||||
|
if len(b) > 0 {
|
||||||
|
return nil, errors.New("error decoding PEM: unexpected data")
|
||||||
|
}
|
||||||
|
return bundle, nil
|
||||||
|
|
||||||
|
}
|
321
acme/db/nosql/certificate_test.go
Normal file
321
acme/db/nosql/certificate_test.go
Normal file
|
@ -0,0 +1,321 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
nosqldb "github.com/smallstep/nosql/database"
|
||||||
|
|
||||||
|
"go.step.sm/crypto/pemutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDB_CreateCertificate(t *testing.T) {
|
||||||
|
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
cert *acme.Certificate
|
||||||
|
err error
|
||||||
|
_id *string
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
cert := &acme.Certificate{
|
||||||
|
AccountID: "accountID",
|
||||||
|
OrderID: "orderID",
|
||||||
|
Leaf: leaf,
|
||||||
|
Intermediates: []*x509.Certificate{inter, root},
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, key, []byte(cert.ID))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbc := new(dbCert)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||||
|
assert.Equals(t, dbc.ID, string(key))
|
||||||
|
assert.Equals(t, dbc.ID, cert.ID)
|
||||||
|
assert.Equals(t, dbc.AccountID, cert.AccountID)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
cert: cert,
|
||||||
|
err: errors.New("error saving acme certificate: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
cert := &acme.Certificate{
|
||||||
|
AccountID: "accountID",
|
||||||
|
OrderID: "orderID",
|
||||||
|
Leaf: leaf,
|
||||||
|
Intermediates: []*x509.Certificate{inter, root},
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
id string
|
||||||
|
idPtr = &id
|
||||||
|
)
|
||||||
|
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
*idPtr = string(key)
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, key, []byte(cert.ID))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbc := new(dbCert)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||||
|
assert.Equals(t, dbc.ID, string(key))
|
||||||
|
assert.Equals(t, dbc.ID, cert.ID)
|
||||||
|
assert.Equals(t, dbc.AccountID, cert.AccountID)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
_id: idPtr,
|
||||||
|
cert: cert,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if err := db.CreateCertificate(context.Background(), tc.cert); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.cert.ID, *tc._id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_GetCertificate(t *testing.T) {
|
||||||
|
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
certID := "certID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, string(key), certID)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate certID not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, string(key), certID)
|
||||||
|
|
||||||
|
return nil, errors.Errorf("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading certificate certID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, string(key), certID)
|
||||||
|
|
||||||
|
return []byte("foobar"), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error unmarshaling certificate certID"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/parseBundle-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, string(key), certID)
|
||||||
|
|
||||||
|
cert := dbCert{
|
||||||
|
ID: certID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
OrderID: "orderID",
|
||||||
|
Leaf: pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "Public Key",
|
||||||
|
Bytes: leaf.Raw,
|
||||||
|
}),
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(cert)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.Errorf("error parsing certificate chain for ACME certificate with ID certID"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, string(key), certID)
|
||||||
|
|
||||||
|
cert := dbCert{
|
||||||
|
ID: certID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
OrderID: "orderID",
|
||||||
|
Leaf: pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: leaf.Raw,
|
||||||
|
}),
|
||||||
|
Intermediates: append(pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: inter.Raw,
|
||||||
|
}), pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: root.Raw,
|
||||||
|
})...),
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(cert)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
cert, err := db.GetCertificate(context.Background(), certID)
|
||||||
|
if err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, cert.ID, certID)
|
||||||
|
assert.Equals(t, cert.AccountID, "accountID")
|
||||||
|
assert.Equals(t, cert.OrderID, "orderID")
|
||||||
|
assert.Equals(t, cert.Leaf, leaf)
|
||||||
|
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_parseBundle(t *testing.T) {
|
||||||
|
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
var certs []byte
|
||||||
|
for _, cert := range []*x509.Certificate{leaf, inter, root} {
|
||||||
|
certs = append(certs, pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: cert.Raw,
|
||||||
|
})...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
b []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
var tests = map[string]test{
|
||||||
|
"fail/bad-type-error": {
|
||||||
|
b: pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "Public Key",
|
||||||
|
Bytes: leaf.Raw,
|
||||||
|
}),
|
||||||
|
err: errors.Errorf("error decoding PEM: data contains block that is not a certificate"),
|
||||||
|
},
|
||||||
|
"fail/bad-pem-error": {
|
||||||
|
b: pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: []byte("foo"),
|
||||||
|
}),
|
||||||
|
err: errors.Errorf("error parsing x509 certificate"),
|
||||||
|
},
|
||||||
|
"fail/unexpected-data": {
|
||||||
|
b: append(pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: leaf.Raw,
|
||||||
|
}), []byte("foo")...),
|
||||||
|
err: errors.Errorf("error decoding PEM: unexpected data"),
|
||||||
|
},
|
||||||
|
"ok": {
|
||||||
|
b: certs,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
ret, err := parseBundle(tc.b)
|
||||||
|
if err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, ret, []*x509.Certificate{leaf, inter, root})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
103
acme/db/nosql/challenge.go
Normal file
103
acme/db/nosql/challenge.go
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dbChallenge struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
AccountID string `json:"accountID"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Status acme.Status `json:"status"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
ValidatedAt string `json:"validatedAt"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
Error *acme.Error `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dbc *dbChallenge) clone() *dbChallenge {
|
||||||
|
u := *dbc
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) {
|
||||||
|
data, err := db.db.Get(challengeTable, []byte(id))
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id)
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error loading acme challenge %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbch := new(dbChallenge)
|
||||||
|
if err := json.Unmarshal(data, dbch); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error unmarshaling dbChallenge")
|
||||||
|
}
|
||||||
|
return dbch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateChallenge creates a new ACME challenge data structure in the database.
|
||||||
|
// Implements acme.DB.CreateChallenge interface.
|
||||||
|
func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error {
|
||||||
|
var err error
|
||||||
|
ch.ID, err = randID()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error generating random id for ACME challenge")
|
||||||
|
}
|
||||||
|
|
||||||
|
dbch := &dbChallenge{
|
||||||
|
ID: ch.ID,
|
||||||
|
AccountID: ch.AccountID,
|
||||||
|
Value: ch.Value,
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: ch.Token,
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
Type: ch.Type,
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChallenge retrieves and unmarshals an ACME challenge type from the database.
|
||||||
|
// Implements the acme.DB GetChallenge interface.
|
||||||
|
func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) {
|
||||||
|
dbch, err := db.getDBChallenge(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := &acme.Challenge{
|
||||||
|
ID: dbch.ID,
|
||||||
|
AccountID: dbch.AccountID,
|
||||||
|
Type: dbch.Type,
|
||||||
|
Value: dbch.Value,
|
||||||
|
Status: dbch.Status,
|
||||||
|
Token: dbch.Token,
|
||||||
|
Error: dbch.Error,
|
||||||
|
ValidatedAt: dbch.ValidatedAt,
|
||||||
|
}
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateChallenge updates an ACME challenge type in the database.
|
||||||
|
func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error {
|
||||||
|
old, err := db.getDBChallenge(ctx, ch.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
nu := old.clone()
|
||||||
|
|
||||||
|
// These should be the only values changing in an Update request.
|
||||||
|
nu.Status = ch.Status
|
||||||
|
nu.Error = ch.Error
|
||||||
|
nu.ValidatedAt = ch.ValidatedAt
|
||||||
|
|
||||||
|
return db.save(ctx, old.ID, nu, old, "challenge", challengeTable)
|
||||||
|
}
|
464
acme/db/nosql/challenge_test.go
Normal file
464
acme/db/nosql/challenge_test.go
Normal file
|
@ -0,0 +1,464 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
nosqldb "github.com/smallstep/nosql/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDB_getDBChallenge(t *testing.T) {
|
||||||
|
chID := "chID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
dbc *dbChallenge
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), chID)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), chID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading acme challenge chID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), chID)
|
||||||
|
|
||||||
|
return []byte("foo"), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error unmarshaling dbChallenge"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
dbc := &dbChallenge{
|
||||||
|
ID: chID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
Type: "dns-01",
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
ValidatedAt: "foobar",
|
||||||
|
Error: acme.NewErrorISE("force"),
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), chID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dbc: dbc,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if ch, err := db.getDBChallenge(context.Background(), chID); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||||
|
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||||
|
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||||
|
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||||
|
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||||
|
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||||
|
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||||
|
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_CreateChallenge(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
ch *acme.Challenge
|
||||||
|
err error
|
||||||
|
_id *string
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
ch := &acme.Challenge{
|
||||||
|
AccountID: "accountID",
|
||||||
|
Type: "dns-01",
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), ch.ID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbc := new(dbChallenge)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||||
|
assert.Equals(t, dbc.ID, string(key))
|
||||||
|
assert.Equals(t, dbc.AccountID, ch.AccountID)
|
||||||
|
assert.Equals(t, dbc.Type, ch.Type)
|
||||||
|
assert.Equals(t, dbc.Status, ch.Status)
|
||||||
|
assert.Equals(t, dbc.Token, ch.Token)
|
||||||
|
assert.Equals(t, dbc.Value, ch.Value)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ch: ch,
|
||||||
|
err: errors.New("error saving acme challenge: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
var (
|
||||||
|
id string
|
||||||
|
idPtr = &id
|
||||||
|
ch = &acme.Challenge{
|
||||||
|
AccountID: "accountID",
|
||||||
|
Type: "dns-01",
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return test{
|
||||||
|
ch: ch,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
*idPtr = string(key)
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), ch.ID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbc := new(dbChallenge)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
||||||
|
assert.Equals(t, dbc.ID, string(key))
|
||||||
|
assert.Equals(t, dbc.AccountID, ch.AccountID)
|
||||||
|
assert.Equals(t, dbc.Type, ch.Type)
|
||||||
|
assert.Equals(t, dbc.Status, ch.Status)
|
||||||
|
assert.Equals(t, dbc.Token, ch.Token)
|
||||||
|
assert.Equals(t, dbc.Value, ch.Value)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
_id: idPtr,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if err := db.CreateChallenge(context.Background(), tc.ch); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.ch.ID, *tc._id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_GetChallenge(t *testing.T) {
|
||||||
|
chID := "chID"
|
||||||
|
azID := "azID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
dbc *dbChallenge
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), chID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading acme challenge chID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/forward-acme-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), chID)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
dbc := &dbChallenge{
|
||||||
|
ID: chID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
Type: "dns-01",
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
ValidatedAt: "foobar",
|
||||||
|
Error: acme.NewErrorISE("force"),
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), chID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dbc: dbc,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||||
|
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||||
|
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||||
|
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||||
|
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||||
|
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||||
|
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||||
|
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_UpdateChallenge(t *testing.T) {
|
||||||
|
chID := "chID"
|
||||||
|
dbc := &dbChallenge{
|
||||||
|
ID: chID,
|
||||||
|
AccountID: "accountID",
|
||||||
|
Type: "dns-01",
|
||||||
|
Status: acme.StatusPending,
|
||||||
|
Token: "token",
|
||||||
|
Value: "test.ca.smallstep.com",
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
ch *acme.Challenge
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ch: &acme.Challenge{
|
||||||
|
ID: chID,
|
||||||
|
},
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), chID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading acme challenge chID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
updCh := &acme.Challenge{
|
||||||
|
ID: chID,
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
ValidatedAt: "foobar",
|
||||||
|
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
ch: updCh,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), chID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, old, b)
|
||||||
|
|
||||||
|
dbOld := new(dbChallenge)
|
||||||
|
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||||
|
assert.Equals(t, dbc, dbOld)
|
||||||
|
|
||||||
|
dbNew := new(dbChallenge)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
|
assert.Equals(t, dbNew.ID, dbc.ID)
|
||||||
|
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
|
||||||
|
assert.Equals(t, dbNew.Type, dbc.Type)
|
||||||
|
assert.Equals(t, dbNew.Status, updCh.Status)
|
||||||
|
assert.Equals(t, dbNew.Token, dbc.Token)
|
||||||
|
assert.Equals(t, dbNew.Value, dbc.Value)
|
||||||
|
assert.Equals(t, dbNew.Error.Error(), updCh.Error.Error())
|
||||||
|
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
|
||||||
|
assert.Equals(t, dbNew.ValidatedAt, updCh.ValidatedAt)
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error saving acme challenge: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
updCh := &acme.Challenge{
|
||||||
|
ID: dbc.ID,
|
||||||
|
AccountID: dbc.AccountID,
|
||||||
|
Type: dbc.Type,
|
||||||
|
Token: dbc.Token,
|
||||||
|
Value: dbc.Value,
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
ValidatedAt: "foobar",
|
||||||
|
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
ch: updCh,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), chID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, old, b)
|
||||||
|
|
||||||
|
dbOld := new(dbChallenge)
|
||||||
|
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
||||||
|
assert.Equals(t, dbc, dbOld)
|
||||||
|
|
||||||
|
dbNew := new(dbChallenge)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
|
assert.Equals(t, dbNew.ID, dbc.ID)
|
||||||
|
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
|
||||||
|
assert.Equals(t, dbNew.Type, dbc.Type)
|
||||||
|
assert.Equals(t, dbNew.Token, dbc.Token)
|
||||||
|
assert.Equals(t, dbNew.Value, dbc.Value)
|
||||||
|
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
|
||||||
|
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
||||||
|
assert.Equals(t, dbNew.ValidatedAt, "foobar")
|
||||||
|
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||||
|
return nu, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.ch.ID, dbc.ID)
|
||||||
|
assert.Equals(t, tc.ch.AccountID, dbc.AccountID)
|
||||||
|
assert.Equals(t, tc.ch.Type, dbc.Type)
|
||||||
|
assert.Equals(t, tc.ch.Token, dbc.Token)
|
||||||
|
assert.Equals(t, tc.ch.Value, dbc.Value)
|
||||||
|
assert.Equals(t, tc.ch.ValidatedAt, "foobar")
|
||||||
|
assert.Equals(t, tc.ch.Status, acme.StatusValid)
|
||||||
|
assert.Equals(t, tc.ch.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
66
acme/db/nosql/nonce.go
Normal file
66
acme/db/nosql/nonce.go
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
"github.com/smallstep/nosql/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dbNonce contains nonce metadata used in the ACME protocol.
|
||||||
|
type dbNonce struct {
|
||||||
|
ID string
|
||||||
|
CreatedAt time.Time
|
||||||
|
DeletedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateNonce creates, stores, and returns an ACME replay-nonce.
|
||||||
|
// Implements the acme.DB interface.
|
||||||
|
func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
|
||||||
|
_id, err := randID()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
|
||||||
|
n := &dbNonce{
|
||||||
|
ID: id,
|
||||||
|
CreatedAt: clock.Now(),
|
||||||
|
}
|
||||||
|
if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return acme.Nonce(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteNonce verifies that the nonce is valid (by checking if it exists),
|
||||||
|
// and if so, consumes the nonce resource by deleting it from the database.
|
||||||
|
func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error {
|
||||||
|
err := db.db.Update(&database.Tx{
|
||||||
|
Operations: []*database.TxEntry{
|
||||||
|
{
|
||||||
|
Bucket: nonceTable,
|
||||||
|
Key: []byte(nonce),
|
||||||
|
Cmd: database.Get,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Bucket: nonceTable,
|
||||||
|
Key: []byte(nonce),
|
||||||
|
Cmd: database.Delete,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case nosql.IsErrNotFound(err):
|
||||||
|
return acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", string(nonce))
|
||||||
|
case err != nil:
|
||||||
|
return errors.Wrapf(err, "error deleting nonce %s", string(nonce))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
168
acme/db/nosql/nonce_test.go
Normal file
168
acme/db/nosql/nonce_test.go
Normal file
|
@ -0,0 +1,168 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
"github.com/smallstep/nosql/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDB_CreateNonce(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
_id *string
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, nonceTable)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbn := new(dbNonce)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbn))
|
||||||
|
assert.Equals(t, dbn.ID, string(key))
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error saving acme nonce: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
var (
|
||||||
|
id string
|
||||||
|
idPtr = &id
|
||||||
|
)
|
||||||
|
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
*idPtr = string(key)
|
||||||
|
assert.Equals(t, bucket, nonceTable)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbn := new(dbNonce)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbn))
|
||||||
|
assert.Equals(t, dbn.ID, string(key))
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
_id: idPtr,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if n, err := db.CreateNonce(context.Background()); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, string(n), *tc._id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_DeleteNonce(t *testing.T) {
|
||||||
|
|
||||||
|
nonceID := "nonceID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
|
||||||
|
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||||
|
|
||||||
|
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
|
||||||
|
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||||
|
return database.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", nonceID),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.Update-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
|
||||||
|
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||||
|
|
||||||
|
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
|
||||||
|
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||||
|
return errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error deleting nonce nonceID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
|
||||||
|
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||||
|
|
||||||
|
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
|
||||||
|
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
96
acme/db/nosql/nosql.go
Normal file
96
acme/db/nosql/nosql.go
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
nosqlDB "github.com/smallstep/nosql"
|
||||||
|
"go.step.sm/crypto/randutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
accountTable = []byte("acme_accounts")
|
||||||
|
accountByKeyIDTable = []byte("acme_keyID_accountID_index")
|
||||||
|
authzTable = []byte("acme_authzs")
|
||||||
|
challengeTable = []byte("acme_challenges")
|
||||||
|
nonceTable = []byte("nonces")
|
||||||
|
orderTable = []byte("acme_orders")
|
||||||
|
ordersByAccountIDTable = []byte("acme_account_orders_index")
|
||||||
|
certTable = []byte("acme_certs")
|
||||||
|
)
|
||||||
|
|
||||||
|
// DB is a struct that implements the AcmeDB interface.
|
||||||
|
type DB struct {
|
||||||
|
db nosqlDB.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// New configures and returns a new ACME DB backend implemented using a nosql DB.
|
||||||
|
func New(db nosqlDB.DB) (*DB, error) {
|
||||||
|
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
|
||||||
|
challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable}
|
||||||
|
for _, b := range tables {
|
||||||
|
if err := db.CreateTable(b); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error creating table %s",
|
||||||
|
string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &DB{db}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// save writes the new data to the database, overwriting the old data if it
|
||||||
|
// existed.
|
||||||
|
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
newB []byte
|
||||||
|
)
|
||||||
|
if nu == nil {
|
||||||
|
newB = nil
|
||||||
|
} else {
|
||||||
|
newB, err = json.Marshal(nu)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var oldB []byte
|
||||||
|
if old == nil {
|
||||||
|
oldB = nil
|
||||||
|
} else {
|
||||||
|
oldB, err = json.Marshal(old)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return errors.Wrapf(err, "error saving acme %s", typ)
|
||||||
|
case !swapped:
|
||||||
|
return errors.Errorf("error saving acme %s; changed since last read", typ)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var idLen = 32
|
||||||
|
|
||||||
|
func randID() (val string, err error) {
|
||||||
|
val, err = randutil.Alphanumeric(idLen)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "error generating random alphanumeric ID")
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clock that returns time in UTC rounded to seconds.
|
||||||
|
type Clock struct{}
|
||||||
|
|
||||||
|
// Now returns the UTC time rounded to seconds.
|
||||||
|
func (c *Clock) Now() time.Time {
|
||||||
|
return time.Now().UTC().Truncate(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
var clock = new(Clock)
|
139
acme/db/nosql/nosql_test.go
Normal file
139
acme/db/nosql/nosql_test.go
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
var tests = map[string]test{
|
||||||
|
"fail/db.CreateTable-error": {
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCreateTable: func(bucket []byte) error {
|
||||||
|
assert.Equals(t, string(bucket), string(accountTable))
|
||||||
|
return errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.Errorf("error creating table %s: force", string(accountTable)),
|
||||||
|
},
|
||||||
|
"ok": {
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCreateTable: func(bucket []byte) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
if _, err := New(tc.db); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type errorThrower string
|
||||||
|
|
||||||
|
func (et errorThrower) MarshalJSON() ([]byte, error) {
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_save(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
nu interface{}
|
||||||
|
old interface{}
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
var tests = map[string]test{
|
||||||
|
"fail/error-marshaling-new": {
|
||||||
|
nu: errorThrower("foo"),
|
||||||
|
err: errors.New("error marshaling acme type: challenge"),
|
||||||
|
},
|
||||||
|
"fail/error-marshaling-old": {
|
||||||
|
nu: "new",
|
||||||
|
old: errorThrower("foo"),
|
||||||
|
err: errors.New("error marshaling acme type: challenge"),
|
||||||
|
},
|
||||||
|
"fail/db.CmpAndSwap-error": {
|
||||||
|
nu: "new",
|
||||||
|
old: "old",
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), "id")
|
||||||
|
assert.Equals(t, string(old), "\"old\"")
|
||||||
|
assert.Equals(t, string(nu), "\"new\"")
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error saving acme challenge: force"),
|
||||||
|
},
|
||||||
|
"fail/db.CmpAndSwap-false-marshaling-old": {
|
||||||
|
nu: "new",
|
||||||
|
old: "old",
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), "id")
|
||||||
|
assert.Equals(t, string(old), "\"old\"")
|
||||||
|
assert.Equals(t, string(nu), "\"new\"")
|
||||||
|
return nil, false, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error saving acme challenge; changed since last read"),
|
||||||
|
},
|
||||||
|
"ok": {
|
||||||
|
nu: "new",
|
||||||
|
old: "old",
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), "id")
|
||||||
|
assert.Equals(t, string(old), "\"old\"")
|
||||||
|
assert.Equals(t, string(nu), "\"new\"")
|
||||||
|
return nu, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ok/nils": {
|
||||||
|
nu: nil,
|
||||||
|
old: nil,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, challengeTable)
|
||||||
|
assert.Equals(t, string(key), "id")
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
assert.Equals(t, nu, nil)
|
||||||
|
return nu, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := &DB{db: tc.db}
|
||||||
|
if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
189
acme/db/nosql/order.go
Normal file
189
acme/db/nosql/order.go
Normal file
|
@ -0,0 +1,189 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mutex for locking ordersByAccount index operations.
|
||||||
|
var ordersByAccountMux sync.Mutex
|
||||||
|
|
||||||
|
type dbOrder struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
AccountID string `json:"accountID"`
|
||||||
|
ProvisionerID string `json:"provisionerID"`
|
||||||
|
Identifiers []acme.Identifier `json:"identifiers"`
|
||||||
|
AuthorizationIDs []string `json:"authorizationIDs"`
|
||||||
|
Status acme.Status `json:"status"`
|
||||||
|
NotBefore time.Time `json:"notBefore,omitempty"`
|
||||||
|
NotAfter time.Time `json:"notAfter,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
ExpiresAt time.Time `json:"expiresAt,omitempty"`
|
||||||
|
CertificateID string `json:"certificate,omitempty"`
|
||||||
|
Error *acme.Error `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *dbOrder) clone() *dbOrder {
|
||||||
|
b := *a
|
||||||
|
return &b
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDBOrder retrieves and unmarshals an ACME Order type from the database.
|
||||||
|
func (db *DB) getDBOrder(ctx context.Context, id string) (*dbOrder, error) {
|
||||||
|
b, err := db.db.Get(orderTable, []byte(id))
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id)
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error loading order %s", id)
|
||||||
|
}
|
||||||
|
o := new(dbOrder)
|
||||||
|
if err := json.Unmarshal(b, &o); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error unmarshaling order %s into dbOrder", id)
|
||||||
|
}
|
||||||
|
return o, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrder retrieves an ACME Order from the database.
|
||||||
|
func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) {
|
||||||
|
dbo, err := db.getDBOrder(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
o := &acme.Order{
|
||||||
|
ID: dbo.ID,
|
||||||
|
AccountID: dbo.AccountID,
|
||||||
|
ProvisionerID: dbo.ProvisionerID,
|
||||||
|
CertificateID: dbo.CertificateID,
|
||||||
|
Status: dbo.Status,
|
||||||
|
ExpiresAt: dbo.ExpiresAt,
|
||||||
|
Identifiers: dbo.Identifiers,
|
||||||
|
NotBefore: dbo.NotBefore,
|
||||||
|
NotAfter: dbo.NotAfter,
|
||||||
|
AuthorizationIDs: dbo.AuthorizationIDs,
|
||||||
|
Error: dbo.Error,
|
||||||
|
}
|
||||||
|
|
||||||
|
return o, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOrder creates ACME Order resources and saves them to the DB.
|
||||||
|
func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error {
|
||||||
|
var err error
|
||||||
|
o.ID, err = randID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := clock.Now()
|
||||||
|
dbo := &dbOrder{
|
||||||
|
ID: o.ID,
|
||||||
|
AccountID: o.AccountID,
|
||||||
|
ProvisionerID: o.ProvisionerID,
|
||||||
|
Status: o.Status,
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: o.ExpiresAt,
|
||||||
|
Identifiers: o.Identifiers,
|
||||||
|
NotBefore: o.NotBefore,
|
||||||
|
NotAfter: o.NotAfter,
|
||||||
|
AuthorizationIDs: o.AuthorizationIDs,
|
||||||
|
}
|
||||||
|
if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.updateAddOrderIDs(ctx, o.AccountID, o.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOrder saves an updated ACME Order to the database.
|
||||||
|
func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error {
|
||||||
|
old, err := db.getDBOrder(ctx, o.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
nu := old.clone()
|
||||||
|
|
||||||
|
nu.Status = o.Status
|
||||||
|
nu.Error = o.Error
|
||||||
|
nu.CertificateID = o.CertificateID
|
||||||
|
return db.save(ctx, old.ID, nu, old, "order", orderTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) {
|
||||||
|
ordersByAccountMux.Lock()
|
||||||
|
defer ordersByAccountMux.Unlock()
|
||||||
|
|
||||||
|
b, err := db.db.Get(ordersByAccountIDTable, []byte(accID))
|
||||||
|
var (
|
||||||
|
oldOids []string
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if !nosql.IsErrNotFound(err) {
|
||||||
|
return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := json.Unmarshal(b, &oldOids); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove any order that is not in PENDING state and update the stored list
|
||||||
|
// before returning.
|
||||||
|
//
|
||||||
|
// According to RFC 8555:
|
||||||
|
// The server SHOULD include pending orders and SHOULD NOT include orders
|
||||||
|
// that are invalid in the array of URLs.
|
||||||
|
pendOids := []string{}
|
||||||
|
for _, oid := range oldOids {
|
||||||
|
o, err := db.GetOrder(ctx, oid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, acme.WrapErrorISE(err, "error loading order %s for account %s", oid, accID)
|
||||||
|
}
|
||||||
|
if err = o.UpdateStatus(ctx, db); err != nil {
|
||||||
|
return nil, acme.WrapErrorISE(err, "error updating order %s for account %s", oid, accID)
|
||||||
|
}
|
||||||
|
if o.Status == acme.StatusPending {
|
||||||
|
pendOids = append(pendOids, oid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pendOids = append(pendOids, addOids...)
|
||||||
|
var (
|
||||||
|
_old interface{} = oldOids
|
||||||
|
_new interface{} = pendOids
|
||||||
|
)
|
||||||
|
switch {
|
||||||
|
case len(oldOids) == 0 && len(pendOids) == 0:
|
||||||
|
// If list has not changed from empty, then no need to write the DB.
|
||||||
|
return []string{}, nil
|
||||||
|
case len(oldOids) == 0:
|
||||||
|
_old = nil
|
||||||
|
case len(pendOids) == 0:
|
||||||
|
_new = nil
|
||||||
|
}
|
||||||
|
if err = db.save(ctx, accID, _new, _old, "orderIDsByAccountID", ordersByAccountIDTable); err != nil {
|
||||||
|
// Delete all orders that may have been previously stored if orderIDsByAccountID update fails.
|
||||||
|
for _, oid := range addOids {
|
||||||
|
// Ignore error from delete -- we tried our best.
|
||||||
|
// TODO when we have logging w/ request ID tracking, logging this error.
|
||||||
|
db.db.Del(orderTable, []byte(oid))
|
||||||
|
}
|
||||||
|
return nil, errors.Wrapf(err, "error saving orderIDs index for account %s", accID)
|
||||||
|
}
|
||||||
|
return pendOids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrdersByAccountID returns a list of order IDs owned by the account.
|
||||||
|
func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
|
||||||
|
return db.updateAddOrderIDs(ctx, accID)
|
||||||
|
}
|
1003
acme/db/nosql/order_test.go
Normal file
1003
acme/db/nosql/order_test.go
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,150 +0,0 @@
|
||||||
package acme
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Directory represents an ACME directory for configuring clients.
|
|
||||||
type Directory struct {
|
|
||||||
NewNonce string `json:"newNonce,omitempty"`
|
|
||||||
NewAccount string `json:"newAccount,omitempty"`
|
|
||||||
NewOrder string `json:"newOrder,omitempty"`
|
|
||||||
NewAuthz string `json:"newAuthz,omitempty"`
|
|
||||||
RevokeCert string `json:"revokeCert,omitempty"`
|
|
||||||
KeyChange string `json:"keyChange,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToLog enables response logging for the Directory type.
|
|
||||||
func (d *Directory) ToLog() (interface{}, error) {
|
|
||||||
b, err := json.Marshal(d)
|
|
||||||
if err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging"))
|
|
||||||
}
|
|
||||||
return string(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type directory struct {
|
|
||||||
prefix, dns string
|
|
||||||
}
|
|
||||||
|
|
||||||
// newDirectory returns a new Directory type.
|
|
||||||
func newDirectory(dns, prefix string) *directory {
|
|
||||||
return &directory{prefix: prefix, dns: dns}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Link captures the link type.
|
|
||||||
type Link int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// NewNonceLink new-nonce
|
|
||||||
NewNonceLink Link = iota
|
|
||||||
// NewAccountLink new-account
|
|
||||||
NewAccountLink
|
|
||||||
// AccountLink account
|
|
||||||
AccountLink
|
|
||||||
// OrderLink order
|
|
||||||
OrderLink
|
|
||||||
// NewOrderLink new-order
|
|
||||||
NewOrderLink
|
|
||||||
// OrdersByAccountLink list of orders owned by account
|
|
||||||
OrdersByAccountLink
|
|
||||||
// FinalizeLink finalize order
|
|
||||||
FinalizeLink
|
|
||||||
// NewAuthzLink authz
|
|
||||||
NewAuthzLink
|
|
||||||
// AuthzLink new-authz
|
|
||||||
AuthzLink
|
|
||||||
// ChallengeLink challenge
|
|
||||||
ChallengeLink
|
|
||||||
// CertificateLink certificate
|
|
||||||
CertificateLink
|
|
||||||
// DirectoryLink directory
|
|
||||||
DirectoryLink
|
|
||||||
// RevokeCertLink revoke certificate
|
|
||||||
RevokeCertLink
|
|
||||||
// KeyChangeLink key rollover
|
|
||||||
KeyChangeLink
|
|
||||||
)
|
|
||||||
|
|
||||||
func (l Link) String() string {
|
|
||||||
switch l {
|
|
||||||
case NewNonceLink:
|
|
||||||
return "new-nonce"
|
|
||||||
case NewAccountLink:
|
|
||||||
return "new-account"
|
|
||||||
case AccountLink:
|
|
||||||
return "account"
|
|
||||||
case NewOrderLink:
|
|
||||||
return "new-order"
|
|
||||||
case OrderLink:
|
|
||||||
return "order"
|
|
||||||
case NewAuthzLink:
|
|
||||||
return "new-authz"
|
|
||||||
case AuthzLink:
|
|
||||||
return "authz"
|
|
||||||
case ChallengeLink:
|
|
||||||
return "challenge"
|
|
||||||
case CertificateLink:
|
|
||||||
return "certificate"
|
|
||||||
case DirectoryLink:
|
|
||||||
return "directory"
|
|
||||||
case RevokeCertLink:
|
|
||||||
return "revoke-cert"
|
|
||||||
case KeyChangeLink:
|
|
||||||
return "key-change"
|
|
||||||
default:
|
|
||||||
return "unexpected"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *directory) getLink(ctx context.Context, typ Link, abs bool, inputs ...string) string {
|
|
||||||
var provName string
|
|
||||||
if p, err := ProvisionerFromContext(ctx); err == nil && p != nil {
|
|
||||||
provName = p.GetName()
|
|
||||||
}
|
|
||||||
return d.getLinkExplicit(typ, provName, abs, BaseURLFromContext(ctx), inputs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getLinkExplicit returns an absolute or partial path to the given resource and a base
|
|
||||||
// URL dynamically obtained from the request for which the link is being
|
|
||||||
// calculated.
|
|
||||||
func (d *directory) getLinkExplicit(typ Link, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string {
|
|
||||||
var link string
|
|
||||||
switch typ {
|
|
||||||
case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink:
|
|
||||||
link = fmt.Sprintf("/%s/%s", provisionerName, typ.String())
|
|
||||||
case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink:
|
|
||||||
link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0])
|
|
||||||
case OrdersByAccountLink:
|
|
||||||
link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0])
|
|
||||||
case FinalizeLink:
|
|
||||||
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if abs {
|
|
||||||
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
|
|
||||||
u := url.URL{}
|
|
||||||
if baseURL != nil {
|
|
||||||
u = *baseURL
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no Scheme is set, then default to https.
|
|
||||||
if u.Scheme == "" {
|
|
||||||
u.Scheme = "https"
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no Host is set, then use the default (first DNS attr in the ca.json).
|
|
||||||
if u.Host == "" {
|
|
||||||
u.Host = d.dns
|
|
||||||
}
|
|
||||||
|
|
||||||
u.Path = d.prefix + link
|
|
||||||
return u.String()
|
|
||||||
}
|
|
||||||
return link
|
|
||||||
}
|
|
|
@ -1,99 +0,0 @@
|
||||||
package acme
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDirectoryGetLink(t *testing.T) {
|
|
||||||
dns := "ca.smallstep.com"
|
|
||||||
prefix := "acme"
|
|
||||||
dir := newDirectory(dns, prefix)
|
|
||||||
id := "1234"
|
|
||||||
|
|
||||||
prov := newProv()
|
|
||||||
provName := url.PathEscape(prov.GetName())
|
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
|
||||||
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
|
|
||||||
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLink(ctx, NewNonceLink, true),
|
|
||||||
fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName))
|
|
||||||
assert.Equals(t, dir.getLink(ctx, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName))
|
|
||||||
|
|
||||||
// No provisioner
|
|
||||||
ctxNoProv := context.WithValue(context.Background(), BaseURLContextKey, baseURL)
|
|
||||||
assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, true),
|
|
||||||
fmt.Sprintf("%s/acme//new-nonce", baseURL.String()))
|
|
||||||
assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, false), "//new-nonce")
|
|
||||||
|
|
||||||
// No baseURL
|
|
||||||
ctxNoBaseURL := context.WithValue(context.Background(), ProvisionerContextKey, prov)
|
|
||||||
assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, true),
|
|
||||||
fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName))
|
|
||||||
assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLink(ctx, OrderLink, true, id),
|
|
||||||
fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName))
|
|
||||||
assert.Equals(t, dir.getLink(ctx, OrderLink, false, id), fmt.Sprintf("/%s/order/1234", provName))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDirectoryGetLinkExplicit(t *testing.T) {
|
|
||||||
dns := "ca.smallstep.com"
|
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
|
||||||
prefix := "acme"
|
|
||||||
dir := newDirectory(dns, prefix)
|
|
||||||
id := "1234"
|
|
||||||
|
|
||||||
prov := newProv()
|
|
||||||
provID := url.PathEscape(prov.GetName())
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/1234", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/1234", provID))
|
|
||||||
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID))
|
|
||||||
assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID))
|
|
||||||
}
|
|
709
acme/errors.go
709
acme/errors.go
|
@ -1,407 +1,339 @@
|
||||||
package acme
|
package acme
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/errs"
|
||||||
|
"github.com/smallstep/certificates/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountDoesNotExistErr returns a new acme error.
|
// ProblemType is the type of the ACME problem.
|
||||||
func AccountDoesNotExistErr(err error) *Error {
|
type ProblemType int
|
||||||
return &Error{
|
|
||||||
Type: accountDoesNotExistErr,
|
|
||||||
Detail: "Account does not exist",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AlreadyRevokedErr returns a new acme error.
|
|
||||||
func AlreadyRevokedErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: alreadyRevokedErr,
|
|
||||||
Detail: "Certificate already revoked",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BadCSRErr returns a new acme error.
|
|
||||||
func BadCSRErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: badCSRErr,
|
|
||||||
Detail: "The CSR is unacceptable",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BadNonceErr returns a new acme error.
|
|
||||||
func BadNonceErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: badNonceErr,
|
|
||||||
Detail: "Unacceptable anti-replay nonce",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BadPublicKeyErr returns a new acme error.
|
|
||||||
func BadPublicKeyErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: badPublicKeyErr,
|
|
||||||
Detail: "The jws was signed by a public key the server does not support",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BadRevocationReasonErr returns a new acme error.
|
|
||||||
func BadRevocationReasonErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: badRevocationReasonErr,
|
|
||||||
Detail: "The revocation reason provided is not allowed by the server",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BadSignatureAlgorithmErr returns a new acme error.
|
|
||||||
func BadSignatureAlgorithmErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: badSignatureAlgorithmErr,
|
|
||||||
Detail: "The JWS was signed with an algorithm the server does not support",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CaaErr returns a new acme error.
|
|
||||||
func CaaErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: caaErr,
|
|
||||||
Detail: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CompoundErr returns a new acme error.
|
|
||||||
func CompoundErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: compoundErr,
|
|
||||||
Detail: "Specific error conditions are indicated in the “subproblems” array",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectionErr returns a new acme error.
|
|
||||||
func ConnectionErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: connectionErr,
|
|
||||||
Detail: "The server could not connect to validation target",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// DNSErr returns a new acme error.
|
|
||||||
func DNSErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: dnsErr,
|
|
||||||
Detail: "There was a problem with a DNS query during identifier validation",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExternalAccountRequiredErr returns a new acme error.
|
|
||||||
func ExternalAccountRequiredErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: externalAccountRequiredErr,
|
|
||||||
Detail: "The request must include a value for the \"externalAccountBinding\" field",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncorrectResponseErr returns a new acme error.
|
|
||||||
func IncorrectResponseErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: incorrectResponseErr,
|
|
||||||
Detail: "Response received didn't match the challenge's requirements",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvalidContactErr returns a new acme error.
|
|
||||||
func InvalidContactErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: invalidContactErr,
|
|
||||||
Detail: "A contact URL for an account was invalid",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MalformedErr returns a new acme error.
|
|
||||||
func MalformedErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: malformedErr,
|
|
||||||
Detail: "The request message was malformed",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OrderNotReadyErr returns a new acme error.
|
|
||||||
func OrderNotReadyErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: orderNotReadyErr,
|
|
||||||
Detail: "The request attempted to finalize an order that is not ready to be finalized",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RateLimitedErr returns a new acme error.
|
|
||||||
func RateLimitedErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: rateLimitedErr,
|
|
||||||
Detail: "The request exceeds a rate limit",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RejectedIdentifierErr returns a new acme error.
|
|
||||||
func RejectedIdentifierErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: rejectedIdentifierErr,
|
|
||||||
Detail: "The server will not issue certificates for the identifier",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerInternalErr returns a new acme error.
|
|
||||||
func ServerInternalErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: serverInternalErr,
|
|
||||||
Detail: "The server experienced an internal error",
|
|
||||||
Status: 500,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotImplemented returns a new acme error.
|
|
||||||
func NotImplemented(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: notImplemented,
|
|
||||||
Detail: "The requested operation is not implemented",
|
|
||||||
Status: 501,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TLSErr returns a new acme error.
|
|
||||||
func TLSErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: tlsErr,
|
|
||||||
Detail: "The server received a TLS error during validation",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnauthorizedErr returns a new acme error.
|
|
||||||
func UnauthorizedErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: unauthorizedErr,
|
|
||||||
Detail: "The client lacks sufficient authorization",
|
|
||||||
Status: 401,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnsupportedContactErr returns a new acme error.
|
|
||||||
func UnsupportedContactErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: unsupportedContactErr,
|
|
||||||
Detail: "A contact URL for an account used an unsupported protocol scheme",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnsupportedIdentifierErr returns a new acme error.
|
|
||||||
func UnsupportedIdentifierErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: unsupportedIdentifierErr,
|
|
||||||
Detail: "An identifier is of an unsupported type",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserActionRequiredErr returns a new acme error.
|
|
||||||
func UserActionRequiredErr(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Type: userActionRequiredErr,
|
|
||||||
Detail: "Visit the “instance” URL and take actions specified there",
|
|
||||||
Status: 400,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProbType is the type of the ACME problem.
|
|
||||||
type ProbType int
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// The request specified an account that does not exist
|
// ErrorAccountDoesNotExistType request specified an account that does not exist
|
||||||
accountDoesNotExistErr ProbType = iota
|
ErrorAccountDoesNotExistType ProblemType = iota
|
||||||
// The request specified a certificate to be revoked that has already been revoked
|
// ErrorAlreadyRevokedType request specified a certificate to be revoked that has already been revoked
|
||||||
alreadyRevokedErr
|
ErrorAlreadyRevokedType
|
||||||
// The CSR is unacceptable (e.g., due to a short key)
|
// ErrorBadCSRType CSR is unacceptable (e.g., due to a short key)
|
||||||
badCSRErr
|
ErrorBadCSRType
|
||||||
// The client sent an unacceptable anti-replay nonce
|
// ErrorBadNonceType client sent an unacceptable anti-replay nonce
|
||||||
badNonceErr
|
ErrorBadNonceType
|
||||||
// The JWS was signed by a public key the server does not support
|
// ErrorBadPublicKeyType JWS was signed by a public key the server does not support
|
||||||
badPublicKeyErr
|
ErrorBadPublicKeyType
|
||||||
// The revocation reason provided is not allowed by the server
|
// ErrorBadRevocationReasonType revocation reason provided is not allowed by the server
|
||||||
badRevocationReasonErr
|
ErrorBadRevocationReasonType
|
||||||
// The JWS was signed with an algorithm the server does not support
|
// ErrorBadSignatureAlgorithmType JWS was signed with an algorithm the server does not support
|
||||||
badSignatureAlgorithmErr
|
ErrorBadSignatureAlgorithmType
|
||||||
// Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate
|
// ErrorCaaType Authority Authorization (CAA) records forbid the CA from issuing a certificate
|
||||||
caaErr
|
ErrorCaaType
|
||||||
// Specific error conditions are indicated in the “subproblems” array.
|
// ErrorCompoundType error conditions are indicated in the “subproblems” array.
|
||||||
compoundErr
|
ErrorCompoundType
|
||||||
// The server could not connect to validation target
|
// ErrorConnectionType server could not connect to validation target
|
||||||
connectionErr
|
ErrorConnectionType
|
||||||
// There was a problem with a DNS query during identifier validation
|
// ErrorDNSType was a problem with a DNS query during identifier validation
|
||||||
dnsErr
|
ErrorDNSType
|
||||||
// The request must include a value for the “externalAccountBinding” field
|
// ErrorExternalAccountRequiredType request must include a value for the “externalAccountBinding” field
|
||||||
externalAccountRequiredErr
|
ErrorExternalAccountRequiredType
|
||||||
// Response received didn’t match the challenge’s requirements
|
// ErrorIncorrectResponseType received didn’t match the challenge’s requirements
|
||||||
incorrectResponseErr
|
ErrorIncorrectResponseType
|
||||||
// A contact URL for an account was invalid
|
// ErrorInvalidContactType URL for an account was invalid
|
||||||
invalidContactErr
|
ErrorInvalidContactType
|
||||||
// The request message was malformed
|
// ErrorMalformedType request message was malformed
|
||||||
malformedErr
|
ErrorMalformedType
|
||||||
// The request attempted to finalize an order that is not ready to be finalized
|
// ErrorOrderNotReadyType request attempted to finalize an order that is not ready to be finalized
|
||||||
orderNotReadyErr
|
ErrorOrderNotReadyType
|
||||||
// The request exceeds a rate limit
|
// ErrorRateLimitedType request exceeds a rate limit
|
||||||
rateLimitedErr
|
ErrorRateLimitedType
|
||||||
// The server will not issue certificates for the identifier
|
// ErrorRejectedIdentifierType server will not issue certificates for the identifier
|
||||||
rejectedIdentifierErr
|
ErrorRejectedIdentifierType
|
||||||
// The server experienced an internal error
|
// ErrorServerInternalType server experienced an internal error
|
||||||
serverInternalErr
|
ErrorServerInternalType
|
||||||
// The server received a TLS error during validation
|
// ErrorTLSType server received a TLS error during validation
|
||||||
tlsErr
|
ErrorTLSType
|
||||||
// The client lacks sufficient authorization
|
// ErrorUnauthorizedType client lacks sufficient authorization
|
||||||
unauthorizedErr
|
ErrorUnauthorizedType
|
||||||
// A contact URL for an account used an unsupported protocol scheme
|
// ErrorUnsupportedContactType URL for an account used an unsupported protocol scheme
|
||||||
unsupportedContactErr
|
ErrorUnsupportedContactType
|
||||||
// An identifier is of an unsupported type
|
// ErrorUnsupportedIdentifierType identifier is of an unsupported type
|
||||||
unsupportedIdentifierErr
|
ErrorUnsupportedIdentifierType
|
||||||
// Visit the “instance” URL and take actions specified there
|
// ErrorUserActionRequiredType the “instance” URL and take actions specified there
|
||||||
userActionRequiredErr
|
ErrorUserActionRequiredType
|
||||||
// The operation is not implemented
|
// ErrorNotImplementedType operation is not implemented
|
||||||
notImplemented
|
ErrorNotImplementedType
|
||||||
)
|
)
|
||||||
|
|
||||||
// String returns the string representation of the acme problem type,
|
// String returns the string representation of the acme problem type,
|
||||||
// fulfilling the Stringer interface.
|
// fulfilling the Stringer interface.
|
||||||
func (ap ProbType) String() string {
|
func (ap ProblemType) String() string {
|
||||||
switch ap {
|
switch ap {
|
||||||
case accountDoesNotExistErr:
|
case ErrorAccountDoesNotExistType:
|
||||||
return "accountDoesNotExist"
|
return "accountDoesNotExist"
|
||||||
case alreadyRevokedErr:
|
case ErrorAlreadyRevokedType:
|
||||||
return "alreadyRevoked"
|
return "alreadyRevoked"
|
||||||
case badCSRErr:
|
case ErrorBadCSRType:
|
||||||
return "badCSR"
|
return "badCSR"
|
||||||
case badNonceErr:
|
case ErrorBadNonceType:
|
||||||
return "badNonce"
|
return "badNonce"
|
||||||
case badPublicKeyErr:
|
case ErrorBadPublicKeyType:
|
||||||
return "badPublicKey"
|
return "badPublicKey"
|
||||||
case badRevocationReasonErr:
|
case ErrorBadRevocationReasonType:
|
||||||
return "badRevocationReason"
|
return "badRevocationReason"
|
||||||
case badSignatureAlgorithmErr:
|
case ErrorBadSignatureAlgorithmType:
|
||||||
return "badSignatureAlgorithm"
|
return "badSignatureAlgorithm"
|
||||||
case caaErr:
|
case ErrorCaaType:
|
||||||
return "caa"
|
return "caa"
|
||||||
case compoundErr:
|
case ErrorCompoundType:
|
||||||
return "compound"
|
return "compound"
|
||||||
case connectionErr:
|
case ErrorConnectionType:
|
||||||
return "connection"
|
return "connection"
|
||||||
case dnsErr:
|
case ErrorDNSType:
|
||||||
return "dns"
|
return "dns"
|
||||||
case externalAccountRequiredErr:
|
case ErrorExternalAccountRequiredType:
|
||||||
return "externalAccountRequired"
|
return "externalAccountRequired"
|
||||||
case incorrectResponseErr:
|
case ErrorInvalidContactType:
|
||||||
return "incorrectResponse"
|
return "incorrectResponse"
|
||||||
case invalidContactErr:
|
case ErrorMalformedType:
|
||||||
return "invalidContact"
|
|
||||||
case malformedErr:
|
|
||||||
return "malformed"
|
return "malformed"
|
||||||
case orderNotReadyErr:
|
case ErrorOrderNotReadyType:
|
||||||
return "orderNotReady"
|
return "orderNotReady"
|
||||||
case rateLimitedErr:
|
case ErrorRateLimitedType:
|
||||||
return "rateLimited"
|
return "rateLimited"
|
||||||
case rejectedIdentifierErr:
|
case ErrorRejectedIdentifierType:
|
||||||
return "rejectedIdentifier"
|
return "rejectedIdentifier"
|
||||||
case serverInternalErr:
|
case ErrorServerInternalType:
|
||||||
return "serverInternal"
|
return "serverInternal"
|
||||||
case tlsErr:
|
case ErrorTLSType:
|
||||||
return "tls"
|
return "tls"
|
||||||
case unauthorizedErr:
|
case ErrorUnauthorizedType:
|
||||||
return "unauthorized"
|
return "unauthorized"
|
||||||
case unsupportedContactErr:
|
case ErrorUnsupportedContactType:
|
||||||
return "unsupportedContact"
|
return "unsupportedContact"
|
||||||
case unsupportedIdentifierErr:
|
case ErrorUnsupportedIdentifierType:
|
||||||
return "unsupportedIdentifier"
|
return "unsupportedIdentifier"
|
||||||
case userActionRequiredErr:
|
case ErrorUserActionRequiredType:
|
||||||
return "userActionRequired"
|
return "userActionRequired"
|
||||||
case notImplemented:
|
case ErrorNotImplementedType:
|
||||||
return "notImplemented"
|
return "notImplemented"
|
||||||
default:
|
default:
|
||||||
return "unsupported type"
|
return fmt.Sprintf("unsupported type ACME error type '%d'", int(ap))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error is an ACME error type complete with problem document.
|
type errorMetadata struct {
|
||||||
|
details string
|
||||||
|
status int
|
||||||
|
typ string
|
||||||
|
String string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
officialACMEPrefix = "urn:ietf:params:acme:error:"
|
||||||
|
errorServerInternalMetadata = errorMetadata{
|
||||||
|
typ: officialACMEPrefix + ErrorServerInternalType.String(),
|
||||||
|
details: "The server experienced an internal error",
|
||||||
|
status: 500,
|
||||||
|
}
|
||||||
|
errorMap = map[ProblemType]errorMetadata{
|
||||||
|
ErrorAccountDoesNotExistType: {
|
||||||
|
typ: officialACMEPrefix + ErrorAccountDoesNotExistType.String(),
|
||||||
|
details: "Account does not exist",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorAlreadyRevokedType: {
|
||||||
|
typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(),
|
||||||
|
details: "Certificate already Revoked",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorBadCSRType: {
|
||||||
|
typ: officialACMEPrefix + ErrorBadCSRType.String(),
|
||||||
|
details: "The CSR is unacceptable",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorBadNonceType: {
|
||||||
|
typ: officialACMEPrefix + ErrorBadNonceType.String(),
|
||||||
|
details: "Unacceptable anti-replay nonce",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorBadPublicKeyType: {
|
||||||
|
typ: officialACMEPrefix + ErrorBadPublicKeyType.String(),
|
||||||
|
details: "The jws was signed by a public key the server does not support",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorBadRevocationReasonType: {
|
||||||
|
typ: officialACMEPrefix + ErrorBadRevocationReasonType.String(),
|
||||||
|
details: "The revocation reason provided is not allowed by the server",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorBadSignatureAlgorithmType: {
|
||||||
|
typ: officialACMEPrefix + ErrorBadSignatureAlgorithmType.String(),
|
||||||
|
details: "The JWS was signed with an algorithm the server does not support",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorCaaType: {
|
||||||
|
typ: officialACMEPrefix + ErrorCaaType.String(),
|
||||||
|
details: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorCompoundType: {
|
||||||
|
typ: officialACMEPrefix + ErrorCompoundType.String(),
|
||||||
|
details: "Specific error conditions are indicated in the “subproblems” array",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorConnectionType: {
|
||||||
|
typ: officialACMEPrefix + ErrorConnectionType.String(),
|
||||||
|
details: "The server could not connect to validation target",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorDNSType: {
|
||||||
|
typ: officialACMEPrefix + ErrorDNSType.String(),
|
||||||
|
details: "There was a problem with a DNS query during identifier validation",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorExternalAccountRequiredType: {
|
||||||
|
typ: officialACMEPrefix + ErrorExternalAccountRequiredType.String(),
|
||||||
|
details: "The request must include a value for the \"externalAccountBinding\" field",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorIncorrectResponseType: {
|
||||||
|
typ: officialACMEPrefix + ErrorIncorrectResponseType.String(),
|
||||||
|
details: "Response received didn't match the challenge's requirements",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorInvalidContactType: {
|
||||||
|
typ: officialACMEPrefix + ErrorInvalidContactType.String(),
|
||||||
|
details: "A contact URL for an account was invalid",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorMalformedType: {
|
||||||
|
typ: officialACMEPrefix + ErrorMalformedType.String(),
|
||||||
|
details: "The request message was malformed",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorOrderNotReadyType: {
|
||||||
|
typ: officialACMEPrefix + ErrorOrderNotReadyType.String(),
|
||||||
|
details: "The request attempted to finalize an order that is not ready to be finalized",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorRateLimitedType: {
|
||||||
|
typ: officialACMEPrefix + ErrorRateLimitedType.String(),
|
||||||
|
details: "The request exceeds a rate limit",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorRejectedIdentifierType: {
|
||||||
|
typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(),
|
||||||
|
details: "The server will not issue certificates for the identifier",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorNotImplementedType: {
|
||||||
|
typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(),
|
||||||
|
details: "The requested operation is not implemented",
|
||||||
|
status: 501,
|
||||||
|
},
|
||||||
|
ErrorTLSType: {
|
||||||
|
typ: officialACMEPrefix + ErrorTLSType.String(),
|
||||||
|
details: "The server received a TLS error during validation",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorUnauthorizedType: {
|
||||||
|
typ: officialACMEPrefix + ErrorUnauthorizedType.String(),
|
||||||
|
details: "The client lacks sufficient authorization",
|
||||||
|
status: 401,
|
||||||
|
},
|
||||||
|
ErrorUnsupportedContactType: {
|
||||||
|
typ: officialACMEPrefix + ErrorUnsupportedContactType.String(),
|
||||||
|
details: "A contact URL for an account used an unsupported protocol scheme",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorUnsupportedIdentifierType: {
|
||||||
|
typ: officialACMEPrefix + ErrorUnsupportedIdentifierType.String(),
|
||||||
|
details: "An identifier is of an unsupported type",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorUserActionRequiredType: {
|
||||||
|
typ: officialACMEPrefix + ErrorUserActionRequiredType.String(),
|
||||||
|
details: "Visit the “instance” URL and take actions specified there",
|
||||||
|
status: 400,
|
||||||
|
},
|
||||||
|
ErrorServerInternalType: errorServerInternalMetadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error represents an ACME
|
||||||
type Error struct {
|
type Error struct {
|
||||||
Type ProbType
|
Type string `json:"type"`
|
||||||
Detail string
|
Detail string `json:"detail"`
|
||||||
Err error
|
Subproblems []interface{} `json:"subproblems,omitempty"`
|
||||||
Status int
|
Identifier interface{} `json:"identifier,omitempty"`
|
||||||
Sub []*Error
|
Err error `json:"-"`
|
||||||
Identifier *Identifier
|
Status int `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap attempts to wrap the internal error.
|
// NewError creates a new Error type.
|
||||||
func Wrap(err error, wrap string) *Error {
|
func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
|
||||||
|
return newError(pt, errors.Errorf(msg, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newError(pt ProblemType, err error) *Error {
|
||||||
|
meta, ok := errorMap[pt]
|
||||||
|
if !ok {
|
||||||
|
meta = errorServerInternalMetadata
|
||||||
|
return &Error{
|
||||||
|
Type: meta.typ,
|
||||||
|
Detail: meta.details,
|
||||||
|
Status: meta.status,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Error{
|
||||||
|
Type: meta.typ,
|
||||||
|
Detail: meta.details,
|
||||||
|
Status: meta.status,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewErrorISE creates a new ErrorServerInternalType Error.
|
||||||
|
func NewErrorISE(msg string, args ...interface{}) *Error {
|
||||||
|
return NewError(ErrorServerInternalType, msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapError attempts to wrap the internal error.
|
||||||
|
func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error {
|
||||||
switch e := err.(type) {
|
switch e := err.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
return nil
|
return nil
|
||||||
case *Error:
|
case *Error:
|
||||||
if e.Err == nil {
|
if e.Err == nil {
|
||||||
e.Err = errors.New(wrap + "; " + e.Detail)
|
e.Err = errors.Errorf(msg+"; "+e.Detail, args...)
|
||||||
} else {
|
} else {
|
||||||
e.Err = errors.Wrap(e.Err, wrap)
|
e.Err = errors.Wrapf(e.Err, msg, args...)
|
||||||
}
|
}
|
||||||
return e
|
return e
|
||||||
default:
|
default:
|
||||||
return ServerInternalErr(errors.Wrap(err, wrap))
|
return newError(typ, errors.Wrapf(err, msg, args...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error implements the error interface.
|
// WrapErrorISE shortcut to wrap an internal server error type.
|
||||||
func (e *Error) Error() string {
|
func WrapErrorISE(err error, msg string, args ...interface{}) *Error {
|
||||||
if e.Err == nil {
|
return WrapError(ErrorServerInternalType, err, msg, args...)
|
||||||
return e.Detail
|
|
||||||
}
|
}
|
||||||
return e.Err.Error()
|
|
||||||
|
// StatusCode returns the status code and implements the StatusCoder interface.
|
||||||
|
func (e *Error) StatusCode() int {
|
||||||
|
return e.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error allows AError to implement the error interface.
|
||||||
|
func (e *Error) Error() string {
|
||||||
|
return e.Detail
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cause returns the internal error and implements the Causer interface.
|
// Cause returns the internal error and implements the Causer interface.
|
||||||
|
@ -412,70 +344,35 @@ func (e *Error) Cause() error {
|
||||||
return e.Err
|
return e.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Official returns true if this error's type is listed in §6.7 of RFC 8555.
|
// ToLog implements the EnableLogger interface.
|
||||||
// Error types in §6.7 are registered under IETF urn namespace:
|
func (e *Error) ToLog() (interface{}, error) {
|
||||||
//
|
b, err := json.Marshal(e)
|
||||||
// "urn:ietf:params:acme:error:"
|
if err != nil {
|
||||||
//
|
return nil, WrapErrorISE(err, "error marshaling acme.Error for logging")
|
||||||
// and should include the namespace as a prefix when appearing as a problem
|
}
|
||||||
// document.
|
return string(b), nil
|
||||||
//
|
|
||||||
// RFC 8555 also says:
|
|
||||||
//
|
|
||||||
// This list is not exhaustive. The server MAY return errors whose
|
|
||||||
// "type" field is set to a URI other than those defined above. Servers
|
|
||||||
// MUST NOT use the ACME URN namespace for errors not listed in the
|
|
||||||
// appropriate IANA registry (see Section 9.6). Clients SHOULD display
|
|
||||||
// the "detail" field of all errors.
|
|
||||||
//
|
|
||||||
// In this case Official returns `false` so that a different namespace can
|
|
||||||
// be used.
|
|
||||||
func (e *Error) Official() bool {
|
|
||||||
return e.Type != notImplemented
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToACME returns an acme representation of the problem type.
|
// WriteError writes to w a JSON representation of the given error.
|
||||||
// For official errors, the IETF ACME namespace is prepended to the error type.
|
func WriteError(w http.ResponseWriter, err *Error) {
|
||||||
// For our own errors, we use an (yet) unregistered smallstep acme namespace.
|
w.Header().Set("Content-Type", "application/problem+json")
|
||||||
func (e *Error) ToACME() *AError {
|
w.WriteHeader(err.StatusCode())
|
||||||
prefix := "urn:step:acme:error"
|
|
||||||
if e.Official() {
|
// Write errors in the response writer
|
||||||
prefix = "urn:ietf:params:acme:error:"
|
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||||
|
rl.WithFields(map[string]interface{}{
|
||||||
|
"error": err.Err,
|
||||||
|
})
|
||||||
|
if os.Getenv("STEPDEBUG") == "1" {
|
||||||
|
if e, ok := err.Err.(errs.StackTracer); ok {
|
||||||
|
rl.WithFields(map[string]interface{}{
|
||||||
|
"stack-trace": fmt.Sprintf("%+v", e),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
ae := &AError{
|
|
||||||
Type: prefix + e.Type.String(),
|
|
||||||
Detail: e.Error(),
|
|
||||||
Status: e.Status,
|
|
||||||
}
|
}
|
||||||
if e.Identifier != nil {
|
|
||||||
ae.Identifier = *e.Identifier
|
|
||||||
}
|
|
||||||
for _, p := range e.Sub {
|
|
||||||
ae.Subproblems = append(ae.Subproblems, p.ToACME())
|
|
||||||
}
|
|
||||||
return ae
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StatusCode returns the status code and implements the StatusCode interface.
|
if err := json.NewEncoder(w).Encode(err); err != nil {
|
||||||
func (e *Error) StatusCode() int {
|
log.Println(err)
|
||||||
return e.Status
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AError is the error type as seen in acme request/responses.
|
|
||||||
type AError struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Detail string `json:"detail"`
|
|
||||||
Identifier interface{} `json:"identifier,omitempty"`
|
|
||||||
Subproblems []interface{} `json:"subproblems,omitempty"`
|
|
||||||
Status int `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error allows AError to implement the error interface.
|
|
||||||
func (ae *AError) Error() string {
|
|
||||||
return ae.Detail
|
|
||||||
}
|
|
||||||
|
|
||||||
// StatusCode returns the status code and implements the StatusCode interface.
|
|
||||||
func (ae *AError) StatusCode() int {
|
|
||||||
return ae.Status
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,73 +1,9 @@
|
||||||
package acme
|
package acme
|
||||||
|
|
||||||
import (
|
// Nonce represents an ACME nonce type.
|
||||||
"encoding/base64"
|
type Nonce string
|
||||||
"encoding/json"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
// String implements the ToString interface.
|
||||||
"github.com/smallstep/nosql"
|
func (n Nonce) String() string {
|
||||||
"github.com/smallstep/nosql/database"
|
return string(n)
|
||||||
)
|
|
||||||
|
|
||||||
// nonce contains nonce metadata used in the ACME protocol.
|
|
||||||
type nonce struct {
|
|
||||||
ID string
|
|
||||||
Created time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// newNonce creates, stores, and returns an ACME replay-nonce.
|
|
||||||
func newNonce(db nosql.DB) (*nonce, error) {
|
|
||||||
_id, err := randID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
|
|
||||||
n := &nonce{
|
|
||||||
ID: id,
|
|
||||||
Created: clock.Now(),
|
|
||||||
}
|
|
||||||
b, err := json.Marshal(n)
|
|
||||||
if err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce"))
|
|
||||||
}
|
|
||||||
_, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b)
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce"))
|
|
||||||
case !swapped:
|
|
||||||
return nil, ServerInternalErr(errors.New("error storing nonce; " +
|
|
||||||
"value has changed since last read"))
|
|
||||||
default:
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// useNonce verifies that the nonce is valid (by checking if it exists),
|
|
||||||
// and if so, consumes the nonce resource by deleting it from the database.
|
|
||||||
func useNonce(db nosql.DB, nonce string) error {
|
|
||||||
err := db.Update(&database.Tx{
|
|
||||||
Operations: []*database.TxEntry{
|
|
||||||
{
|
|
||||||
Bucket: nonceTable,
|
|
||||||
Key: []byte(nonce),
|
|
||||||
Cmd: database.Get,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Bucket: nonceTable,
|
|
||||||
Key: []byte(nonce),
|
|
||||||
Cmd: database.Delete,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case nosql.IsErrNotFound(err):
|
|
||||||
return BadNonceErr(nil)
|
|
||||||
case err != nil:
|
|
||||||
return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce))
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,163 +0,0 @@
|
||||||
package acme
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/db"
|
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
"github.com/smallstep/nosql/database"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewNonce(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
db nosql.DB
|
|
||||||
err *Error
|
|
||||||
id *string
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, nonceTable)
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
return nil, false, errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.Errorf("error storing nonce: force")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, nonceTable)
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
return nil, false, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
var _id string
|
|
||||||
id := &_id
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
assert.Equals(t, bucket, nonceTable)
|
|
||||||
assert.Equals(t, old, nil)
|
|
||||||
*id = string(key)
|
|
||||||
return nil, true, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
id: id,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if n, err := newNonce(tc.db); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, n.ID, *tc.id)
|
|
||||||
|
|
||||||
assert.True(t, n.Created.Before(time.Now().Add(time.Minute)))
|
|
||||||
assert.True(t, n.Created.After(time.Now().Add(-time.Minute)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUseNonce(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
id string
|
|
||||||
db nosql.DB
|
|
||||||
err *Error
|
|
||||||
}
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
|
||||||
"fail/update-not-found": func(t *testing.T) test {
|
|
||||||
id := "foo"
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MUpdate: func(tx *database.Tx) error {
|
|
||||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
|
||||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
|
||||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
|
||||||
|
|
||||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
|
||||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
|
||||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
|
||||||
return database.ErrNotFound
|
|
||||||
},
|
|
||||||
},
|
|
||||||
id: id,
|
|
||||||
err: BadNonceErr(nil),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/update-error": func(t *testing.T) test {
|
|
||||||
id := "foo"
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MUpdate: func(tx *database.Tx) error {
|
|
||||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
|
||||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
|
||||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
|
||||||
|
|
||||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
|
||||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
|
||||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
|
||||||
return errors.New("force")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
id: id,
|
|
||||||
err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
|
||||||
id := "foo"
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MUpdate: func(tx *database.Tx) error {
|
|
||||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
|
||||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
|
||||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
|
||||||
|
|
||||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
|
||||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
|
||||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
id: id,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, run := range tests {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
tc := run(t)
|
|
||||||
if err := useNonce(tc.db, tc.id); err != nil {
|
|
||||||
if assert.NotNil(t, tc.err) {
|
|
||||||
ae, ok := err.(*Error)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
|
||||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
|
||||||
assert.Equals(t, ae.Type, tc.err.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
423
acme/order.go
423
acme/order.go
|
@ -6,351 +6,129 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
"go.step.sm/crypto/x509util"
|
"go.step.sm/crypto/x509util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var defaultOrderExpiry = time.Hour * 24
|
// Identifier encodes the type that an order pertains to.
|
||||||
|
type Identifier struct {
|
||||||
// Mutex for locking ordersByAccount index operations.
|
Type string `json:"type"`
|
||||||
var ordersByAccountMux sync.Mutex
|
Value string `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
// Order contains order metadata for the ACME protocol order type.
|
// Order contains order metadata for the ACME protocol order type.
|
||||||
type Order struct {
|
type Order struct {
|
||||||
Status string `json:"status"`
|
ID string `json:"id"`
|
||||||
Expires string `json:"expires,omitempty"`
|
AccountID string `json:"-"`
|
||||||
|
ProvisionerID string `json:"-"`
|
||||||
|
Status Status `json:"status"`
|
||||||
|
ExpiresAt time.Time `json:"expires"`
|
||||||
Identifiers []Identifier `json:"identifiers"`
|
Identifiers []Identifier `json:"identifiers"`
|
||||||
NotBefore string `json:"notBefore,omitempty"`
|
NotBefore time.Time `json:"notBefore"`
|
||||||
NotAfter string `json:"notAfter,omitempty"`
|
NotAfter time.Time `json:"notAfter"`
|
||||||
Error interface{} `json:"error,omitempty"`
|
Error *Error `json:"error,omitempty"`
|
||||||
Authorizations []string `json:"authorizations"`
|
AuthorizationIDs []string `json:"-"`
|
||||||
Finalize string `json:"finalize"`
|
AuthorizationURLs []string `json:"authorizations"`
|
||||||
Certificate string `json:"certificate,omitempty"`
|
FinalizeURL string `json:"finalize"`
|
||||||
ID string `json:"-"`
|
CertificateID string `json:"-"`
|
||||||
|
CertificateURL string `json:"certificate,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToLog enables response logging.
|
// ToLog enables response logging.
|
||||||
func (o *Order) ToLog() (interface{}, error) {
|
func (o *Order) ToLog() (interface{}, error) {
|
||||||
b, err := json.Marshal(o)
|
b, err := json.Marshal(o)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging"))
|
return nil, WrapErrorISE(err, "error marshaling order for logging")
|
||||||
}
|
}
|
||||||
return string(b), nil
|
return string(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the Order ID.
|
// UpdateStatus updates the ACME Order Status if necessary.
|
||||||
func (o *Order) GetID() string {
|
// Changes to the order are saved using the database interface.
|
||||||
return o.ID
|
func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
|
||||||
}
|
|
||||||
|
|
||||||
// OrderOptions options with which to create a new Order.
|
|
||||||
type OrderOptions struct {
|
|
||||||
AccountID string `json:"accID"`
|
|
||||||
Identifiers []Identifier `json:"identifiers"`
|
|
||||||
NotBefore time.Time `json:"notBefore"`
|
|
||||||
NotAfter time.Time `json:"notAfter"`
|
|
||||||
backdate time.Duration
|
|
||||||
defaultDuration time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
type order struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
AccountID string `json:"accountID"`
|
|
||||||
Created time.Time `json:"created"`
|
|
||||||
Expires time.Time `json:"expires,omitempty"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Identifiers []Identifier `json:"identifiers"`
|
|
||||||
NotBefore time.Time `json:"notBefore,omitempty"`
|
|
||||||
NotAfter time.Time `json:"notAfter,omitempty"`
|
|
||||||
Error *Error `json:"error,omitempty"`
|
|
||||||
Authorizations []string `json:"authorizations"`
|
|
||||||
Certificate string `json:"certificate,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// newOrder returns a new Order type.
|
|
||||||
func newOrder(db nosql.DB, ops OrderOptions) (*order, error) {
|
|
||||||
id, err := randID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
authzs := make([]string, len(ops.Identifiers))
|
|
||||||
for i, identifier := range ops.Identifiers {
|
|
||||||
az, err := newAuthz(db, ops.AccountID, identifier)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
authzs[i] = az.getID()
|
|
||||||
}
|
|
||||||
|
|
||||||
now := clock.Now()
|
now := clock.Now()
|
||||||
var backdate time.Duration
|
|
||||||
nbf := ops.NotBefore
|
|
||||||
if nbf.IsZero() {
|
|
||||||
nbf = now
|
|
||||||
backdate = -1 * ops.backdate
|
|
||||||
}
|
|
||||||
naf := ops.NotAfter
|
|
||||||
if naf.IsZero() {
|
|
||||||
naf = nbf.Add(ops.defaultDuration)
|
|
||||||
}
|
|
||||||
|
|
||||||
o := &order{
|
|
||||||
ID: id,
|
|
||||||
AccountID: ops.AccountID,
|
|
||||||
Created: now,
|
|
||||||
Status: StatusPending,
|
|
||||||
Expires: now.Add(defaultOrderExpiry),
|
|
||||||
Identifiers: ops.Identifiers,
|
|
||||||
NotBefore: nbf.Add(backdate),
|
|
||||||
NotAfter: naf,
|
|
||||||
Authorizations: authzs,
|
|
||||||
}
|
|
||||||
if err := o.save(db, nil); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var oidHelper = orderIDsByAccount{}
|
|
||||||
_, err = oidHelper.addOrderID(db, ops.AccountID, o.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return o, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type orderIDsByAccount struct{}
|
|
||||||
|
|
||||||
// addOrderID adds an order ID to a users index of in progress order IDs.
|
|
||||||
// This method will also cull any orders that are no longer in the `pending`
|
|
||||||
// state from the index before returning it.
|
|
||||||
func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) ([]string, error) {
|
|
||||||
ordersByAccountMux.Lock()
|
|
||||||
defer ordersByAccountMux.Unlock()
|
|
||||||
|
|
||||||
// Update the "order IDs by account ID" index
|
|
||||||
oids, err := oiba.unsafeGetOrderIDsByAccount(db, accID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newOids := append(oids, oid)
|
|
||||||
if err = orderIDs(newOids).save(db, oids, accID); err != nil {
|
|
||||||
// Delete the entire order if storing the index fails.
|
|
||||||
db.Del(orderTable, []byte(oid))
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return newOids, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// unsafeGetOrderIDsByAccount retrieves a list of Order IDs that were created by the
|
|
||||||
// account.
|
|
||||||
func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) {
|
|
||||||
b, err := db.Get(ordersByAccountIDTable, []byte(accID))
|
|
||||||
if err != nil {
|
|
||||||
if nosql.IsErrNotFound(err) {
|
|
||||||
return []string{}, nil
|
|
||||||
}
|
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID))
|
|
||||||
}
|
|
||||||
var oids []string
|
|
||||||
if err := json.Unmarshal(b, &oids); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove any order that is not in PENDING state and update the stored list
|
|
||||||
// before returning.
|
|
||||||
//
|
|
||||||
// According to RFC 8555:
|
|
||||||
// The server SHOULD include pending orders and SHOULD NOT include orders
|
|
||||||
// that are invalid in the array of URLs.
|
|
||||||
pendOids := []string{}
|
|
||||||
for _, oid := range oids {
|
|
||||||
o, err := getOrder(db, oid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID))
|
|
||||||
}
|
|
||||||
if o, err = o.updateStatus(db); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID))
|
|
||||||
}
|
|
||||||
if o.Status == StatusPending {
|
|
||||||
pendOids = append(pendOids, oid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If the number of pending orders is less than the number of orders in the
|
|
||||||
// list, then update the pending order list.
|
|
||||||
if len(pendOids) != len(oids) {
|
|
||||||
if err = orderIDs(pendOids).save(db, oids, accID); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+
|
|
||||||
"len(orderIDs) = %d", len(pendOids)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return pendOids, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type orderIDs []string
|
|
||||||
|
|
||||||
// save is used to update the list of orderIDs keyed by ACME account ID
|
|
||||||
// stored in the database.
|
|
||||||
//
|
|
||||||
// This method always converts empty lists to 'nil' when storing to the DB. We
|
|
||||||
// do this to avoid any confusion between an empty list and a nil value in the
|
|
||||||
// db.
|
|
||||||
func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error {
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
oldb []byte
|
|
||||||
newb []byte
|
|
||||||
)
|
|
||||||
if len(old) == 0 {
|
|
||||||
oldb = nil
|
|
||||||
} else {
|
|
||||||
oldb, err = json.Marshal(old)
|
|
||||||
if err != nil {
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(oids) == 0 {
|
|
||||||
newb = nil
|
|
||||||
} else {
|
|
||||||
newb, err = json.Marshal(oids)
|
|
||||||
if err != nil {
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb)
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID))
|
|
||||||
case !swapped:
|
|
||||||
return ServerInternalErr(errors.Errorf("error storing order IDs "+
|
|
||||||
"for account %s; order IDs changed since last read", accID))
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *order) save(db nosql.DB, old *order) error {
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
oldB []byte
|
|
||||||
)
|
|
||||||
if old == nil {
|
|
||||||
oldB = nil
|
|
||||||
} else {
|
|
||||||
if oldB, err = json.Marshal(old); err != nil {
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newB, err := json.Marshal(o)
|
|
||||||
if err != nil {
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order"))
|
|
||||||
}
|
|
||||||
|
|
||||||
_, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB)
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
return ServerInternalErr(errors.Wrap(err, "error storing order"))
|
|
||||||
case !swapped:
|
|
||||||
return ServerInternalErr(errors.New("error storing order; " +
|
|
||||||
"value has changed since last read"))
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateStatus updates order status if necessary.
|
|
||||||
func (o *order) updateStatus(db nosql.DB) (*order, error) {
|
|
||||||
_newOrder := *o
|
|
||||||
newOrder := &_newOrder
|
|
||||||
|
|
||||||
now := time.Now().UTC()
|
|
||||||
switch o.Status {
|
switch o.Status {
|
||||||
case StatusInvalid:
|
case StatusInvalid:
|
||||||
return o, nil
|
return nil
|
||||||
case StatusValid:
|
case StatusValid:
|
||||||
return o, nil
|
return nil
|
||||||
case StatusReady:
|
case StatusReady:
|
||||||
// check expiry
|
// Check expiry
|
||||||
if now.After(o.Expires) {
|
if now.After(o.ExpiresAt) {
|
||||||
newOrder.Status = StatusInvalid
|
o.Status = StatusInvalid
|
||||||
newOrder.Error = MalformedErr(errors.New("order has expired"))
|
o.Error = NewError(ErrorMalformedType, "order has expired")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
return o, nil
|
return nil
|
||||||
case StatusPending:
|
case StatusPending:
|
||||||
// check expiry
|
// Check expiry
|
||||||
if now.After(o.Expires) {
|
if now.After(o.ExpiresAt) {
|
||||||
newOrder.Status = StatusInvalid
|
o.Status = StatusInvalid
|
||||||
newOrder.Error = MalformedErr(errors.New("order has expired"))
|
o.Error = NewError(ErrorMalformedType, "order has expired")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
var count = map[string]int{
|
var count = map[Status]int{
|
||||||
StatusValid: 0,
|
StatusValid: 0,
|
||||||
StatusInvalid: 0,
|
StatusInvalid: 0,
|
||||||
StatusPending: 0,
|
StatusPending: 0,
|
||||||
}
|
}
|
||||||
for _, azID := range o.Authorizations {
|
for _, azID := range o.AuthorizationIDs {
|
||||||
az, err := getAuthz(db, azID)
|
az, err := db.GetAuthorization(ctx, azID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return WrapErrorISE(err, "error getting authorization ID %s", azID)
|
||||||
}
|
}
|
||||||
if az, err = az.updateStatus(db); err != nil {
|
if err = az.UpdateStatus(ctx, db); err != nil {
|
||||||
return nil, err
|
return WrapErrorISE(err, "error updating authorization ID %s", azID)
|
||||||
}
|
}
|
||||||
st := az.getStatus()
|
st := az.Status
|
||||||
count[st]++
|
count[st]++
|
||||||
}
|
}
|
||||||
switch {
|
switch {
|
||||||
case count[StatusInvalid] > 0:
|
case count[StatusInvalid] > 0:
|
||||||
newOrder.Status = StatusInvalid
|
o.Status = StatusInvalid
|
||||||
|
|
||||||
// No change in the order status, so just return the order as is -
|
// No change in the order status, so just return the order as is -
|
||||||
// without writing any changes.
|
// without writing any changes.
|
||||||
case count[StatusPending] > 0:
|
case count[StatusPending] > 0:
|
||||||
return newOrder, nil
|
return nil
|
||||||
|
|
||||||
case count[StatusValid] == len(o.Authorizations):
|
case count[StatusValid] == len(o.AuthorizationIDs):
|
||||||
newOrder.Status = StatusReady
|
o.Status = StatusReady
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, ServerInternalErr(errors.New("unexpected authz status"))
|
return NewErrorISE("unexpected authz status")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status))
|
return NewErrorISE("unrecognized order status: %s", o.Status)
|
||||||
|
}
|
||||||
|
if err := db.UpdateOrder(ctx, o); err != nil {
|
||||||
|
return WrapErrorISE(err, "error updating order")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := newOrder.save(db, o); err != nil {
|
// Finalize signs a certificate if the necessary conditions for Order completion
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return newOrder, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// finalize signs a certificate if the necessary conditions for Order completion
|
|
||||||
// have been met.
|
// have been met.
|
||||||
func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) (*order, error) {
|
func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth CertificateAuthority, p Provisioner) error {
|
||||||
var err error
|
if err := o.UpdateStatus(ctx, db); err != nil {
|
||||||
if o, err = o.updateStatus(db); err != nil {
|
return err
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch o.Status {
|
switch o.Status {
|
||||||
case StatusInvalid:
|
case StatusInvalid:
|
||||||
return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID))
|
return NewError(ErrorOrderNotReadyType, "order %s has been abandoned", o.ID)
|
||||||
case StatusValid:
|
case StatusValid:
|
||||||
return o, nil
|
return nil
|
||||||
case StatusPending:
|
case StatusPending:
|
||||||
return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID))
|
return NewError(ErrorOrderNotReadyType, "order %s is not ready", o.ID)
|
||||||
case StatusReady:
|
case StatusReady:
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID))
|
return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RFC8555: The CSR MUST indicate the exact same set of requested
|
// RFC8555: The CSR MUST indicate the exact same set of requested
|
||||||
|
@ -361,12 +139,12 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut
|
||||||
if csr.Subject.CommonName != "" {
|
if csr.Subject.CommonName != "" {
|
||||||
csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
|
csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
|
||||||
}
|
}
|
||||||
csr.DNSNames = uniqueLowerNames(csr.DNSNames)
|
csr.DNSNames = uniqueSortedLowerNames(csr.DNSNames)
|
||||||
orderNames := make([]string, len(o.Identifiers))
|
orderNames := make([]string, len(o.Identifiers))
|
||||||
for i, n := range o.Identifiers {
|
for i, n := range o.Identifiers {
|
||||||
orderNames[i] = n.Value
|
orderNames[i] = n.Value
|
||||||
}
|
}
|
||||||
orderNames = uniqueLowerNames(orderNames)
|
orderNames = uniqueSortedLowerNames(orderNames)
|
||||||
|
|
||||||
// Validate identifier names against CSR alternative names.
|
// Validate identifier names against CSR alternative names.
|
||||||
//
|
//
|
||||||
|
@ -374,13 +152,15 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut
|
||||||
// absence of other SANs as they will only be set if the templates allows
|
// absence of other SANs as they will only be set if the templates allows
|
||||||
// them.
|
// them.
|
||||||
if len(csr.DNSNames) != len(orderNames) {
|
if len(csr.DNSNames) != len(orderNames) {
|
||||||
return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames))
|
return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||||
|
"CSR names = %v, Order names = %v", csr.DNSNames, orderNames)
|
||||||
}
|
}
|
||||||
|
|
||||||
sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames))
|
sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames))
|
||||||
for i := range csr.DNSNames {
|
for i := range csr.DNSNames {
|
||||||
if csr.DNSNames[i] != orderNames[i] {
|
if csr.DNSNames[i] != orderNames[i] {
|
||||||
return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames))
|
return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
|
||||||
|
"CSR names = %v, Order names = %v", csr.DNSNames, orderNames)
|
||||||
}
|
}
|
||||||
sans[i] = x509util.SubjectAlternativeName{
|
sans[i] = x509util.SubjectAlternativeName{
|
||||||
Type: x509util.DNSType,
|
Type: x509util.DNSType,
|
||||||
|
@ -389,10 +169,10 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get authorizations from the ACME provisioner.
|
// Get authorizations from the ACME provisioner.
|
||||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||||
signOps, err := p.AuthorizeSign(ctx, "")
|
signOps, err := p.AuthorizeSign(ctx, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner"))
|
return WrapErrorISE(err, "error retrieving authorization options from ACME provisioner")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Template data
|
// Template data
|
||||||
|
@ -402,82 +182,41 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut
|
||||||
|
|
||||||
templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data)
|
templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error creating template options from ACME provisioner"))
|
return WrapErrorISE(err, "error creating template options from ACME provisioner")
|
||||||
}
|
}
|
||||||
signOps = append(signOps, templateOptions)
|
signOps = append(signOps, templateOptions)
|
||||||
|
|
||||||
// Create and store a new certificate.
|
// Sign a new certificate.
|
||||||
certChain, err := auth.Sign(csr, provisioner.SignOptions{
|
certChain, err := auth.Sign(csr, provisioner.SignOptions{
|
||||||
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
|
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
|
||||||
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
|
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
|
||||||
}, signOps...)
|
}, signOps...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID))
|
return WrapErrorISE(err, "error signing certificate for order %s", o.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, err := newCert(db, CertOptions{
|
cert := &Certificate{
|
||||||
AccountID: o.AccountID,
|
AccountID: o.AccountID,
|
||||||
OrderID: o.ID,
|
OrderID: o.ID,
|
||||||
Leaf: certChain[0],
|
Leaf: certChain[0],
|
||||||
Intermediates: certChain[1:],
|
Intermediates: certChain[1:],
|
||||||
})
|
}
|
||||||
if err != nil {
|
if err := db.CreateCertificate(ctx, cert); err != nil {
|
||||||
return nil, err
|
return WrapErrorISE(err, "error creating certificate for order %s", o.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
_newOrder := *o
|
o.CertificateID = cert.ID
|
||||||
newOrder := &_newOrder
|
o.Status = StatusValid
|
||||||
newOrder.Certificate = cert.ID
|
if err = db.UpdateOrder(ctx, o); err != nil {
|
||||||
newOrder.Status = StatusValid
|
return WrapErrorISE(err, "error updating order %s", o.ID)
|
||||||
if err := newOrder.save(db, o); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return newOrder, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrder retrieves and unmarshals an ACME Order type from the database.
|
// uniqueSortedLowerNames returns the set of all unique names in the input after all
|
||||||
func getOrder(db nosql.DB, id string) (*order, error) {
|
|
||||||
b, err := db.Get(orderTable, []byte(id))
|
|
||||||
if nosql.IsErrNotFound(err) {
|
|
||||||
return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id))
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id))
|
|
||||||
}
|
|
||||||
var o order
|
|
||||||
if err := json.Unmarshal(b, &o); err != nil {
|
|
||||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order"))
|
|
||||||
}
|
|
||||||
return &o, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// toACME converts the internal Order type into the public acmeOrder type for
|
|
||||||
// presentation in the ACME protocol.
|
|
||||||
func (o *order) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Order, error) {
|
|
||||||
azs := make([]string, len(o.Authorizations))
|
|
||||||
for i, aid := range o.Authorizations {
|
|
||||||
azs[i] = dir.getLink(ctx, AuthzLink, true, aid)
|
|
||||||
}
|
|
||||||
ao := &Order{
|
|
||||||
Status: o.Status,
|
|
||||||
Expires: o.Expires.Format(time.RFC3339),
|
|
||||||
Identifiers: o.Identifiers,
|
|
||||||
NotBefore: o.NotBefore.Format(time.RFC3339),
|
|
||||||
NotAfter: o.NotAfter.Format(time.RFC3339),
|
|
||||||
Authorizations: azs,
|
|
||||||
Finalize: dir.getLink(ctx, FinalizeLink, true, o.ID),
|
|
||||||
ID: o.ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
if o.Certificate != "" {
|
|
||||||
ao.Certificate = dir.getLink(ctx, CertificateLink, true, o.Certificate)
|
|
||||||
}
|
|
||||||
return ao, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// uniqueLowerNames returns the set of all unique names in the input after all
|
|
||||||
// of them are lowercased. The returned names will be in their lowercased form
|
// of them are lowercased. The returned names will be in their lowercased form
|
||||||
// and sorted alphabetically.
|
// and sorted alphabetically.
|
||||||
func uniqueLowerNames(names []string) (unique []string) {
|
func uniqueSortedLowerNames(names []string) (unique []string) {
|
||||||
nameMap := make(map[string]int, len(names))
|
nameMap := make(map[string]int, len(names))
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
nameMap[strings.ToLower(name)] = 1
|
nameMap[strings.ToLower(name)] = 1
|
||||||
|
|
2056
acme/order_test.go
2056
acme/order_test.go
File diff suppressed because it is too large
Load diff
20
acme/status.go
Normal file
20
acme/status.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
// Status represents an ACME status.
|
||||||
|
type Status string
|
||||||
|
|
||||||
|
var (
|
||||||
|
// StatusValid -- valid
|
||||||
|
StatusValid = Status("valid")
|
||||||
|
// StatusInvalid -- invalid
|
||||||
|
StatusInvalid = Status("invalid")
|
||||||
|
// StatusPending -- pending; e.g. an Order that is not ready to be finalized.
|
||||||
|
StatusPending = Status("pending")
|
||||||
|
// StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
|
||||||
|
StatusDeactivated = Status("deactivated")
|
||||||
|
// StatusReady -- ready; e.g. for an Order that is ready to be finalized.
|
||||||
|
StatusReady = Status("ready")
|
||||||
|
//statusExpired = "expired"
|
||||||
|
//statusActive = "active"
|
||||||
|
//statusProcessing = "processing"
|
||||||
|
)
|
|
@ -17,14 +17,14 @@ import (
|
||||||
func WriteError(w http.ResponseWriter, err error) {
|
func WriteError(w http.ResponseWriter, err error) {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
w.Header().Set("Content-Type", "application/problem+json")
|
acme.WriteError(w, k)
|
||||||
err = k.ToACME()
|
return
|
||||||
case *scep.Error:
|
case *scep.Error:
|
||||||
// TODO: check if this is correct; and should we do some more processing?
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
default:
|
default:
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
}
|
}
|
||||||
|
|
||||||
cause := errors.Cause(err)
|
cause := errors.Cause(err)
|
||||||
if sc, ok := err.(errs.StatusCoder); ok {
|
if sc, ok := err.(errs.StatusCoder); ok {
|
||||||
w.WriteHeader(sc.StatusCode())
|
w.WriteHeader(sc.StatusCode())
|
||||||
|
|
|
@ -56,8 +56,7 @@ func NewContextWithMethod(ctx context.Context, method Method) context.Context {
|
||||||
return context.WithValue(ctx, methodKey{}, method)
|
return context.WithValue(ctx, methodKey{}, method)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MethodFromContext returns the Method saved in ctx. Returns Sign if the given
|
// MethodFromContext returns the Method saved in ctx.
|
||||||
// context has no Method associated with it.
|
|
||||||
func MethodFromContext(ctx context.Context) Method {
|
func MethodFromContext(ctx context.Context) Method {
|
||||||
m, _ := ctx.Value(methodKey{}).(Method)
|
m, _ := ctx.Value(methodKey{}).(Method)
|
||||||
return m
|
return m
|
||||||
|
|
|
@ -450,7 +450,7 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
|
||||||
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("Match exec \"step ssh check-host %h\"\n\tUserKnownHostsFile /home/user/.step/ssh/known_hosts\n\tProxyCommand step ssh proxycommand %r %h %p\n")},
|
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("Match exec \"step ssh check-host %h\"\n\tUserKnownHostsFile /home/user/.step/ssh/known_hosts\n\tProxyCommand step ssh proxycommand %r %h %p\n")},
|
||||||
}
|
}
|
||||||
hostOutputWithUserData := []templates.Output{
|
hostOutputWithUserData := []templates.Output{
|
||||||
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("TrustedUserCAKeys /etc/ssh/ca.pub\nHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\nHostKey /etc/ssh/ssh_host_ecdsa_key")},
|
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
|
||||||
}
|
}
|
||||||
|
|
||||||
tmplConfigErr := &templates.Templates{
|
tmplConfigErr := &templates.Templates{
|
||||||
|
|
1
authority/testdata/templates/sshd_config.tpl
vendored
1
authority/testdata/templates/sshd_config.tpl
vendored
|
@ -1,3 +1,4 @@
|
||||||
|
Match all
|
||||||
TrustedUserCAKeys /etc/ssh/ca.pub
|
TrustedUserCAKeys /etc/ssh/ca.pub
|
||||||
HostCertificate /etc/ssh/{{.User.Certificate}}
|
HostCertificate /etc/ssh/{{.User.Certificate}}
|
||||||
HostKey /etc/ssh/{{.User.Key}}
|
HostKey /etc/ssh/{{.User.Key}}
|
|
@ -21,7 +21,7 @@ import (
|
||||||
type ACMEClient struct {
|
type ACMEClient struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
dirLoc string
|
dirLoc string
|
||||||
dir *acme.Directory
|
dir *acmeAPI.Directory
|
||||||
acc *acme.Account
|
acc *acme.Account
|
||||||
Key *jose.JSONWebKey
|
Key *jose.JSONWebKey
|
||||||
kid string
|
kid string
|
||||||
|
@ -53,7 +53,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
return nil, readACMEError(resp.Body)
|
return nil, readACMEError(resp.Body)
|
||||||
}
|
}
|
||||||
var dir acme.Directory
|
var dir acmeAPI.Directory
|
||||||
if err := readJSON(resp.Body, &dir); err != nil {
|
if err := readJSON(resp.Body, &dir); err != nil {
|
||||||
return nil, errors.Wrapf(err, "error reading %s", endpoint)
|
return nil, errors.Wrapf(err, "error reading %s", endpoint)
|
||||||
}
|
}
|
||||||
|
@ -93,7 +93,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC
|
||||||
|
|
||||||
// GetDirectory makes a directory request to the ACME api and returns an
|
// GetDirectory makes a directory request to the ACME api and returns an
|
||||||
// ACME directory object.
|
// ACME directory object.
|
||||||
func (c *ACMEClient) GetDirectory() (*acme.Directory, error) {
|
func (c *ACMEClient) GetDirectory() (*acmeAPI.Directory, error) {
|
||||||
return c.dir, nil
|
return c.dir, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -231,7 +231,7 @@ func (c *ACMEClient) ValidateChallenge(url string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuthz returns the Authz at the given path.
|
// GetAuthz returns the Authz at the given path.
|
||||||
func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) {
|
func (c *ACMEClient) GetAuthz(url string) (*acme.Authorization, error) {
|
||||||
resp, err := c.post(nil, url, withKid(c))
|
resp, err := c.post(nil, url, withKid(c))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -240,7 +240,7 @@ func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) {
|
||||||
return nil, readACMEError(resp.Body)
|
return nil, readACMEError(resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
var az acme.Authz
|
var az acme.Authorization
|
||||||
if err := readJSON(resp.Body, &az); err != nil {
|
if err := readJSON(resp.Body, &az); err != nil {
|
||||||
return nil, errors.Wrapf(err, "error reading %s", url)
|
return nil, errors.Wrapf(err, "error reading %s", url)
|
||||||
}
|
}
|
||||||
|
@ -320,7 +320,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) {
|
||||||
if c.acc == nil {
|
if c.acc == nil {
|
||||||
return nil, errors.New("acme client not configured with account")
|
return nil, errors.New("acme client not configured with account")
|
||||||
}
|
}
|
||||||
resp, err := c.post(nil, c.acc.Orders, withKid(c))
|
resp, err := c.post(nil, c.acc.OrdersURL, withKid(c))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -330,7 +330,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) {
|
||||||
|
|
||||||
var orders []string
|
var orders []string
|
||||||
if err := readJSON(resp.Body, &orders); err != nil {
|
if err := readJSON(resp.Body, &orders); err != nil {
|
||||||
return nil, errors.Wrapf(err, "error reading %s", c.acc.Orders)
|
return nil, errors.Wrapf(err, "error reading %s", c.acc.OrdersURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
return orders, nil
|
return orders, nil
|
||||||
|
@ -342,7 +342,7 @@ func readACMEError(r io.ReadCloser) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "error reading from body")
|
return errors.Wrap(err, "error reading from body")
|
||||||
}
|
}
|
||||||
ae := new(acme.AError)
|
ae := new(acme.Error)
|
||||||
err = json.Unmarshal(b, &ae)
|
err = json.Unmarshal(b, &ae)
|
||||||
// If we successfully marshaled to an ACMEError then return the ACMEError.
|
// If we successfully marshaled to an ACMEError then return the ACMEError.
|
||||||
if err != nil || len(ae.Error()) == 0 {
|
if err != nil || len(ae.Error()) == 0 {
|
||||||
|
|
|
@ -31,18 +31,17 @@ func TestNewACMEClient(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
NewAccount: srv.URL + "/bar",
|
NewAccount: srv.URL + "/bar",
|
||||||
NewOrder: srv.URL + "/baz",
|
NewOrder: srv.URL + "/baz",
|
||||||
NewAuthz: srv.URL + "/zap",
|
|
||||||
RevokeCert: srv.URL + "/zip",
|
RevokeCert: srv.URL + "/zip",
|
||||||
KeyChange: srv.URL + "/blorp",
|
KeyChange: srv.URL + "/blorp",
|
||||||
}
|
}
|
||||||
acc := acme.Account{
|
acc := acme.Account{
|
||||||
Contact: []string{"max", "mariano"},
|
Contact: []string{"max", "mariano"},
|
||||||
Status: "valid",
|
Status: "valid",
|
||||||
Orders: "orders-url",
|
OrdersURL: "orders-url",
|
||||||
}
|
}
|
||||||
tests := map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"fail/client-option-error": func(t *testing.T) test {
|
"fail/client-option-error": func(t *testing.T) test {
|
||||||
|
@ -58,7 +57,7 @@ func TestNewACMEClient(t *testing.T) {
|
||||||
"fail/get-directory": func(t *testing.T) test {
|
"fail/get-directory": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
ops: []ClientOption{WithTransport(http.DefaultTransport)},
|
ops: []ClientOption{WithTransport(http.DefaultTransport)},
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -76,7 +75,7 @@ func TestNewACMEClient(t *testing.T) {
|
||||||
ops: []ClientOption{WithTransport(http.DefaultTransport)},
|
ops: []ClientOption{WithTransport(http.DefaultTransport)},
|
||||||
r1: dir,
|
r1: dir,
|
||||||
rc1: 200,
|
rc1: 200,
|
||||||
r2: acme.AccountDoesNotExistErr(nil).ToACME(),
|
r2: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||||
rc2: 400,
|
rc2: 400,
|
||||||
err: errors.New("Account does not exist"),
|
err: errors.New("Account does not exist"),
|
||||||
}
|
}
|
||||||
|
@ -142,11 +141,10 @@ func TestNewACMEClient(t *testing.T) {
|
||||||
|
|
||||||
func TestACMEClient_GetDirectory(t *testing.T) {
|
func TestACMEClient_GetDirectory(t *testing.T) {
|
||||||
c := &ACMEClient{
|
c := &ACMEClient{
|
||||||
dir: &acme.Directory{
|
dir: &acmeAPI.Directory{
|
||||||
NewNonce: "/foo",
|
NewNonce: "/foo",
|
||||||
NewAccount: "/bar",
|
NewAccount: "/bar",
|
||||||
NewOrder: "/baz",
|
NewOrder: "/baz",
|
||||||
NewAuthz: "/zap",
|
|
||||||
RevokeCert: "/zip",
|
RevokeCert: "/zip",
|
||||||
KeyChange: "/blorp",
|
KeyChange: "/blorp",
|
||||||
},
|
},
|
||||||
|
@ -166,7 +164,7 @@ func TestACMEClient_GetNonce(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
}
|
}
|
||||||
// Retrieve transport from options.
|
// Retrieve transport from options.
|
||||||
|
@ -185,7 +183,7 @@ func TestACMEClient_GetNonce(t *testing.T) {
|
||||||
tests := map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"fail/GET-nonce": func(t *testing.T) test {
|
"fail/GET-nonce": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -237,7 +235,7 @@ func TestACMEClient_post(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
}
|
}
|
||||||
// Retrieve transport from options.
|
// Retrieve transport from options.
|
||||||
|
@ -250,7 +248,7 @@ func TestACMEClient_post(t *testing.T) {
|
||||||
acc := acme.Account{
|
acc := acme.Account{
|
||||||
Contact: []string{"max", "mariano"},
|
Contact: []string{"max", "mariano"},
|
||||||
Status: "valid",
|
Status: "valid",
|
||||||
Orders: "orders-url",
|
OrdersURL: "orders-url",
|
||||||
}
|
}
|
||||||
ac := &ACMEClient{
|
ac := &ACMEClient{
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
|
@ -266,7 +264,7 @@ func TestACMEClient_post(t *testing.T) {
|
||||||
"fail/account-not-configured": func(t *testing.T) test {
|
"fail/account-not-configured": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
client: &ACMEClient{},
|
client: &ACMEClient{},
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("acme client not configured with account"),
|
err: errors.New("acme client not configured with account"),
|
||||||
}
|
}
|
||||||
|
@ -274,7 +272,7 @@ func TestACMEClient_post(t *testing.T) {
|
||||||
"fail/GET-nonce": func(t *testing.T) test {
|
"fail/GET-nonce": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
client: ac,
|
client: ac,
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -365,7 +363,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
NewOrder: srv.URL + "/bar",
|
NewOrder: srv.URL + "/bar",
|
||||||
}
|
}
|
||||||
|
@ -376,20 +374,21 @@ func TestACMEClient_NewOrder(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
now := time.Now().UTC().Round(time.Second)
|
||||||
nor := acmeAPI.NewOrderRequest{
|
nor := acmeAPI.NewOrderRequest{
|
||||||
Identifiers: []acme.Identifier{
|
Identifiers: []acme.Identifier{
|
||||||
{Type: "dns", Value: "example.com"},
|
{Type: "dns", Value: "example.com"},
|
||||||
{Type: "dns", Value: "acme.example.com"},
|
{Type: "dns", Value: "acme.example.com"},
|
||||||
},
|
},
|
||||||
NotBefore: time.Now(),
|
NotBefore: now,
|
||||||
NotAfter: time.Now().Add(time.Minute),
|
NotAfter: now.Add(time.Minute),
|
||||||
}
|
}
|
||||||
norb, err := json.Marshal(nor)
|
norb, err := json.Marshal(nor)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ord := acme.Order{
|
ord := acme.Order{
|
||||||
Status: "valid",
|
Status: "valid",
|
||||||
Expires: "soon",
|
ExpiresAt: now, // "soon"
|
||||||
Finalize: "finalize-url",
|
FinalizeURL: "finalize-url",
|
||||||
}
|
}
|
||||||
ac := &ACMEClient{
|
ac := &ACMEClient{
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
|
@ -404,7 +403,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
|
||||||
tests := map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"fail/client-post": func(t *testing.T) test {
|
"fail/client-post": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -413,7 +412,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
r1: []byte{},
|
r1: []byte{},
|
||||||
rc1: 200,
|
rc1: 200,
|
||||||
r2: acme.MalformedErr(nil).ToACME(),
|
r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc2: 400,
|
rc2: 400,
|
||||||
ops: []withHeaderOption{withKid(ac)},
|
ops: []withHeaderOption{withKid(ac)},
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
|
@ -498,7 +497,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
}
|
}
|
||||||
// Retrieve transport from options.
|
// Retrieve transport from options.
|
||||||
|
@ -510,8 +509,8 @@ func TestACMEClient_GetOrder(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ord := acme.Order{
|
ord := acme.Order{
|
||||||
Status: "valid",
|
Status: "valid",
|
||||||
Expires: "soon",
|
ExpiresAt: time.Now().UTC().Round(time.Second), // "soon"
|
||||||
Finalize: "finalize-url",
|
FinalizeURL: "finalize-url",
|
||||||
}
|
}
|
||||||
ac := &ACMEClient{
|
ac := &ACMEClient{
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
|
@ -526,7 +525,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
|
||||||
tests := map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"fail/client-post": func(t *testing.T) test {
|
"fail/client-post": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -535,7 +534,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
r1: []byte{},
|
r1: []byte{},
|
||||||
rc1: 200,
|
rc1: 200,
|
||||||
r2: acme.MalformedErr(nil).ToACME(),
|
r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc2: 400,
|
rc2: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -618,7 +617,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
}
|
}
|
||||||
// Retrieve transport from options.
|
// Retrieve transport from options.
|
||||||
|
@ -628,9 +627,9 @@ func TestACMEClient_GetAuthz(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
az := acme.Authz{
|
az := acme.Authorization{
|
||||||
Status: "valid",
|
Status: "valid",
|
||||||
Expires: "soon",
|
ExpiresAt: time.Now().UTC().Round(time.Second),
|
||||||
Identifier: acme.Identifier{Type: "dns", Value: "example.com"},
|
Identifier: acme.Identifier{Type: "dns", Value: "example.com"},
|
||||||
}
|
}
|
||||||
ac := &ACMEClient{
|
ac := &ACMEClient{
|
||||||
|
@ -646,7 +645,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
|
||||||
tests := map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"fail/client-post": func(t *testing.T) test {
|
"fail/client-post": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -655,7 +654,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
r1: []byte{},
|
r1: []byte{},
|
||||||
rc1: 200,
|
rc1: 200,
|
||||||
r2: acme.MalformedErr(nil).ToACME(),
|
r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc2: 400,
|
rc2: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -738,7 +737,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
}
|
}
|
||||||
// Retrieve transport from options.
|
// Retrieve transport from options.
|
||||||
|
@ -766,7 +765,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
|
||||||
tests := map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"fail/client-post": func(t *testing.T) test {
|
"fail/client-post": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -775,7 +774,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
r1: []byte{},
|
r1: []byte{},
|
||||||
rc1: 200,
|
rc1: 200,
|
||||||
r2: acme.MalformedErr(nil).ToACME(),
|
r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc2: 400,
|
rc2: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -859,7 +858,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
}
|
}
|
||||||
// Retrieve transport from options.
|
// Retrieve transport from options.
|
||||||
|
@ -887,7 +886,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
|
||||||
tests := map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"fail/client-post": func(t *testing.T) test {
|
"fail/client-post": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -896,7 +895,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
r1: []byte{},
|
r1: []byte{},
|
||||||
rc1: 200,
|
rc1: 200,
|
||||||
r2: acme.MalformedErr(nil).ToACME(),
|
r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc2: 400,
|
rc2: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -976,7 +975,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
}
|
}
|
||||||
// Retrieve transport from options.
|
// Retrieve transport from options.
|
||||||
|
@ -988,9 +987,9 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ord := acme.Order{
|
ord := acme.Order{
|
||||||
Status: "valid",
|
Status: "valid",
|
||||||
Expires: "soon",
|
ExpiresAt: time.Now(), // "soon"
|
||||||
Finalize: "finalize-url",
|
FinalizeURL: "finalize-url",
|
||||||
Certificate: "cert-url",
|
CertificateURL: "cert-url",
|
||||||
}
|
}
|
||||||
_csr, err := pemutil.Read("../authority/testdata/certs/foo.csr")
|
_csr, err := pemutil.Read("../authority/testdata/certs/foo.csr")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
@ -1012,7 +1011,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
|
||||||
tests := map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"fail/client-post": func(t *testing.T) test {
|
"fail/client-post": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -1021,7 +1020,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
r1: []byte{},
|
r1: []byte{},
|
||||||
rc1: 200,
|
rc1: 200,
|
||||||
r2: acme.MalformedErr(nil).ToACME(),
|
r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc2: 400,
|
rc2: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -1101,7 +1100,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
}
|
}
|
||||||
// Retrieve transport from options.
|
// Retrieve transport from options.
|
||||||
|
@ -1123,7 +1122,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
|
||||||
acc: &acme.Account{
|
acc: &acme.Account{
|
||||||
Contact: []string{"max", "mariano"},
|
Contact: []string{"max", "mariano"},
|
||||||
Status: "valid",
|
Status: "valid",
|
||||||
Orders: srv.URL + "/orders-url",
|
OrdersURL: srv.URL + "/orders-url",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1137,7 +1136,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
|
||||||
"fail/client-post": func(t *testing.T) test {
|
"fail/client-post": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
client: ac,
|
client: ac,
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -1147,7 +1146,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
|
||||||
client: ac,
|
client: ac,
|
||||||
r1: []byte{},
|
r1: []byte{},
|
||||||
rc1: 200,
|
rc1: 200,
|
||||||
r2: acme.MalformedErr(nil).ToACME(),
|
r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc2: 400,
|
rc2: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -1198,7 +1197,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
|
||||||
assert.Equals(t, hdr.Nonce, expectedNonce)
|
assert.Equals(t, hdr.Nonce, expectedNonce)
|
||||||
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
||||||
assert.Fatal(t, ok)
|
assert.Fatal(t, ok)
|
||||||
assert.Equals(t, jwsURL, ac.acc.Orders)
|
assert.Equals(t, jwsURL, ac.acc.OrdersURL)
|
||||||
assert.Equals(t, hdr.KeyID, ac.kid)
|
assert.Equals(t, hdr.KeyID, ac.kid)
|
||||||
|
|
||||||
payload, err := jws.Verify(ac.Key.Public())
|
payload, err := jws.Verify(ac.Key.Public())
|
||||||
|
@ -1232,7 +1231,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
|
||||||
srv := httptest.NewServer(nil)
|
srv := httptest.NewServer(nil)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
dir := acme.Directory{
|
dir := acmeAPI.Directory{
|
||||||
NewNonce: srv.URL + "/foo",
|
NewNonce: srv.URL + "/foo",
|
||||||
}
|
}
|
||||||
// Retrieve transport from options.
|
// Retrieve transport from options.
|
||||||
|
@ -1261,14 +1260,14 @@ func TestACMEClient_GetCertificate(t *testing.T) {
|
||||||
acc: &acme.Account{
|
acc: &acme.Account{
|
||||||
Contact: []string{"max", "mariano"},
|
Contact: []string{"max", "mariano"},
|
||||||
Status: "valid",
|
Status: "valid",
|
||||||
Orders: srv.URL + "/orders-url",
|
OrdersURL: srv.URL + "/orders-url",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"fail/client-post": func(t *testing.T) test {
|
"fail/client-post": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
r1: acme.MalformedErr(nil).ToACME(),
|
r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc1: 400,
|
rc1: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
@ -1277,7 +1276,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
|
||||||
return test{
|
return test{
|
||||||
r1: []byte{},
|
r1: []byte{},
|
||||||
rc1: 200,
|
rc1: 200,
|
||||||
r2: acme.MalformedErr(nil).ToACME(),
|
r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
|
||||||
rc2: 400,
|
rc2: 400,
|
||||||
err: errors.New("The request message was malformed"),
|
err: errors.New("The request message was malformed"),
|
||||||
}
|
}
|
||||||
|
|
23
ca/ca.go
23
ca/ca.go
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
acmeAPI "github.com/smallstep/certificates/acme/api"
|
acmeAPI "github.com/smallstep/certificates/acme/api"
|
||||||
|
acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
|
@ -149,23 +150,29 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix := "acme"
|
prefix := "acme"
|
||||||
acmeAuth, err := acme.New(auth, acme.AuthorityOptions{
|
var acmeDB acme.DB
|
||||||
|
if config.DB == nil {
|
||||||
|
acmeDB = nil
|
||||||
|
} else {
|
||||||
|
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error configuring ACME DB interface")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
|
||||||
Backdate: *config.AuthorityConfig.Backdate,
|
Backdate: *config.AuthorityConfig.Backdate,
|
||||||
DB: auth.GetDatabase().(nosql.DB),
|
DB: acmeDB,
|
||||||
DNS: dns,
|
DNS: dns,
|
||||||
Prefix: prefix,
|
Prefix: prefix,
|
||||||
|
CA: auth,
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error creating ACME authority")
|
|
||||||
}
|
|
||||||
acmeRouterHandler := acmeAPI.New(acmeAuth)
|
|
||||||
mux.Route("/"+prefix, func(r chi.Router) {
|
mux.Route("/"+prefix, func(r chi.Router) {
|
||||||
acmeRouterHandler.Route(r)
|
acmeHandler.Route(r)
|
||||||
})
|
})
|
||||||
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
|
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
|
||||||
// of the ACME spec.
|
// of the ACME spec.
|
||||||
mux.Route("/2.0/"+prefix, func(r chi.Router) {
|
mux.Route("/2.0/"+prefix, func(r chi.Router) {
|
||||||
acmeRouterHandler.Route(r)
|
acmeHandler.Route(r)
|
||||||
})
|
})
|
||||||
|
|
||||||
if ca.shouldServeSCEPEndpoints() {
|
if ca.shouldServeSCEPEndpoints() {
|
||||||
|
|
|
@ -57,6 +57,7 @@ func newInsecureClient() *uaClient {
|
||||||
return &uaClient{
|
return &uaClient{
|
||||||
Client: &http.Client{
|
Client: &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
4
debian/rules
vendored
4
debian/rules
vendored
|
@ -4,8 +4,10 @@ override_dh_install-arch:
|
||||||
dh_install --arch
|
dh_install --arch
|
||||||
|
|
||||||
build:
|
build:
|
||||||
make bootstrap
|
|
||||||
dh build
|
dh build
|
||||||
|
|
||||||
|
override_dh_auto_build:
|
||||||
|
dh_auto_build -- build
|
||||||
|
|
||||||
%:
|
%:
|
||||||
dh $@
|
dh $@
|
||||||
|
|
|
@ -80,7 +80,7 @@ Example `claims`:
|
||||||
use this value.
|
use this value.
|
||||||
|
|
||||||
* `enableSSHCA`: enable all provisioners to generate SSH Certificates.
|
* `enableSSHCA`: enable all provisioners to generate SSH Certificates.
|
||||||
The deault value is `false`. You can enable this option per provisioner
|
The default value is `false`. You can enable this option per provisioner
|
||||||
by setting it to `true` in the provisioner claims.
|
by setting it to `true` in the provisioner claims.
|
||||||
|
|
||||||
## Provisioner Types
|
## Provisioner Types
|
||||||
|
|
|
@ -99,7 +99,8 @@ var DefaultSSHTemplateData = map[string]string{
|
||||||
`,
|
`,
|
||||||
|
|
||||||
// sshd_config.tpl adds the configuration to support certificates
|
// sshd_config.tpl adds the configuration to support certificates
|
||||||
"sshd_config.tpl": `TrustedUserCAKeys /etc/ssh/ca.pub
|
"sshd_config.tpl": `Match all
|
||||||
|
TrustedUserCAKeys /etc/ssh/ca.pub
|
||||||
HostCertificate /etc/ssh/{{.User.Certificate}}
|
HostCertificate /etc/ssh/{{.User.Certificate}}
|
||||||
HostKey /etc/ssh/{{.User.Key}}`,
|
HostKey /etc/ssh/{{.User.Key}}`,
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue