diff --git a/.dockerignore b/.dockerignore
index 5b671c40..8dbc8eb2 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,5 +1,3 @@
-README.md
-.gitignore
bin
coverage.txt
*.test
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index f547d61d..819a470e 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -10,6 +10,9 @@ jobs:
test:
name: Lint, Test, Build
runs-on: ubuntu-20.04
+ strategy:
+ matrix:
+ go: [ '1.15', '1.16' ]
outputs:
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
steps:
@@ -20,15 +23,39 @@ jobs:
name: Setup Go
uses: actions/setup-go@v2
with:
- go-version: '1.15.8'
+ go-version: ${{ matrix.go }}
-
name: Install Deps
id: install-deps
run: sudo apt-get -y install libpcsclite-dev
-
- name: Lint, Test, Build
+ name: golangci-lint
+ uses: golangci/golangci-lint-action@v2
+ with:
+ # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
+ version: 'latest'
+
+ # Optional: working directory, useful for monorepos
+ # working-directory: somedir
+
+ # Optional: golangci-lint command line arguments.
+ args: --timeout=30m
+
+ # Optional: show only new issues if it's a pull request. The default value is `false`.
+ # only-new-issues: true
+
+ # Optional: if set to true then the action will use pre-installed Go.
+ # skip-go-installation: true
+
+ # Optional: if set to true then the action don't cache or restore ~/go/pkg.
+ # skip-pkg-cache: true
+
+ # Optional: if set to true then the action don't cache or restore ~/.cache/go-build.
+ # skip-build-cache: true
+ -
+ name: Test, Build
id: lint_test_build
- run: V=1 make -j1 bootstrap ci
+ run: V=1 make ci
create_release:
name: Create Release
@@ -96,7 +123,7 @@ jobs:
name: Set up Go
uses: actions/setup-go@v2
with:
- go-version: '1.15.8'
+ go-version: '1.16'
-
name: APT Install
id: aptInstall
@@ -126,7 +153,7 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v2
with:
- go-version: '1.15.8'
+ go-version: '1.16'
- name: Build
id: build
run: |
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 74e435fd..9c73cfbd 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -11,24 +11,55 @@ on:
jobs:
lintTestBuild:
name: Lint, Test, Build
- runs-on: ubuntu-latest
+ runs-on: ubuntu-20.04
+ strategy:
+ matrix:
+ go: [ '1.15', '1.16' ]
steps:
- - name: Checkout
+ -
+ name: Checkout
uses: actions/checkout@v2
- - name: Setup Go
+ -
+ name: Setup Go
uses: actions/setup-go@v2
with:
- go-version: '1.15.6'
- - name: Install Deps
+ go-version: ${{ matrix.go }}
+ -
+ name: Install Deps
id: install-deps
run: sudo apt-get -y install libpcsclite-dev
- - name: Lint, Test, Build
- id: lintTestBuild
- run: V=1 make -j1 bootstrap ci
- - name: Codecov
+ -
+ name: golangci-lint
+ uses: golangci/golangci-lint-action@v2
+ with:
+ # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
+ version: 'latest'
+
+ # Optional: working directory, useful for monorepos
+ # working-directory: somedir
+
+ # Optional: golangci-lint command line arguments.
+ args: --timeout=30m
+
+ # Optional: show only new issues if it's a pull request. The default value is `false`.
+ # only-new-issues: true
+
+ # Optional: if set to true then the action will use pre-installed Go.
+ # skip-go-installation: true
+
+ # Optional: if set to true then the action don't cache or restore ~/go/pkg.
+ # skip-pkg-cache: true
+
+ # Optional: if set to true then the action don't cache or restore ~/.cache/go-build.
+ # skip-build-cache: true
+ -
+ name: Test, Build
+ id: lint_test_build
+ run: V=1 make ci
+ -
+ name: Codecov
uses: codecov/codecov-action@v1.2.1
with:
- token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
file: ./coverage.out # optional
name: codecov-umbrella # optional
fail_ci_if_error: true # optional (default = false)
diff --git a/Makefile b/Makefile
index a8907b8b..1a3e7023 100644
--- a/Makefile
+++ b/Makefile
@@ -18,7 +18,7 @@ OUTPUT_ROOT=output/
all: lint test build
-ci: lintcgo testcgo build
+ci: testcgo build
.PHONY: all ci
@@ -28,7 +28,7 @@ ci: lintcgo testcgo build
bootstra%:
# Using a released version of golangci-lint to take into account custom replacements in their go.mod
- $Q GO111MODULE=on go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.24.0
+ $Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.39.0
.PHONY: bootstra%
@@ -38,7 +38,7 @@ bootstra%:
# If TRAVIS_TAG is set then we know this ref has been tagged.
ifdef TRAVIS_TAG
-VERSION := $(TRAVIS_TAG)
+VERSION ?= $(TRAVIS_TAG)
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
ifeq ($(NOT_RC),)
PUSHTYPE := release-candidate
@@ -47,7 +47,7 @@ PUSHTYPE := release
endif
# GITHUB Actions
else ifdef GITHUB_REF
-VERSION := $(shell echo $(GITHUB_REF) | sed 's/^refs\/tags\///')
+VERSION ?= $(shell echo $(GITHUB_REF) | sed 's/^refs\/tags\///')
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
ifeq ($(NOT_RC),)
PUSHTYPE := release-candidate
diff --git a/README.md b/README.md
index 21f4c35d..f0649175 100644
--- a/README.md
+++ b/README.md
@@ -22,8 +22,7 @@ Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [
[Website](https://smallstep.com/certificates) |
[Documentation](https://smallstep.com/docs) |
-[Installation Guide](#installation-guide) |
-[Quickstart](#quickstart) |
+[Installation](https://smallstep.com/docs/step-ca/installation) |
[Getting Started](https://smallstep.com/docs/step-ca/getting-started) |
[Contributor's Guide](./docs/CONTRIBUTING.md)
@@ -103,270 +102,9 @@ ACME is the protocol used by Let's Encrypt to automate the issuance of HTTPS cer
- [Install root certificates](https://smallstep.com/docs/step-cli/reference/certificate/install/) on your machine and browsers, so your CA is trusted
- [Inspect](https://smallstep.com/docs/step-cli/reference/certificate/inspect/) and [lint](https://smallstep.com/docs/step-cli/reference/certificate/lint/) certificates
-## Installation Guide
+## Installation
-These instructions will install an OS specific version of the `step-ca` binary on
-your local machine.
-
-Want to build from source? See [our contributor's guide](./docs/CONTRIBUTING.md)
-
-### Mac OS
-
-Install `step` and `step-ca` together, via [Homebrew](https://brew.sh/):
-
-```
-$ brew install step
-```
-
-### Linux
-
-> **Note:** The [`step` CLI tool](https://github.com/smallstep/cli) is the easiest way to initialize, configure, and control `step-ca`. While `step` is not technically required to run `step-ca`, it is very much recommended.
-
-#### Debian
-
-1. Install `step`.
-
- Download the Debian package from the
- [latest `step` release](https://github.com/smallstep/cli/releases/latest):
-
- ```
- $ wget https://github.com/smallstep/cli/releases/download/vX.Y.Z/step-cli_X.Y.Z_amd64.deb
- ```
-
- Install the Debian package:
-
- ```
- $ sudo dpkg -i step-cli_X.Y.Z_amd64.deb
- ```
-
-2. Install `step-ca`.
-
- Download the Debian package from the [latest `step-ca` release](https://github.com/smallstep/certificates/releases/latest):
-
- ```
- $ wget https://github.com/smallstep/certificates/releases/download/vX.Y.Z/step-ca_X.Y.Z_amd64.deb
- ```
-
- Install the Debian package:
-
- ```
- $ sudo dpkg -i step-ca_X.Y.Z_amd64.deb
- ```
-
-#### Arch Linux
-
-We are using the [Arch User Repository](https://aur.archlinux.org) to distribute
-`step` binaries for Arch Linux.
-
-* The `step` binary tarball can be found [here](https://aur.archlinux.org/packages/step-cli-bin/).
-* The `step-ca` binary tarball can be found [here](https://aur.archlinux.org/packages/step-ca-bin/).
-
-You can use [pacman](https://www.archlinux.org/pacman/) to install the packages.
-
-#### RHEL/CentOS
-
-1. Install `step`.
-
- Download the Linux tarball from the
- [latest `step` release](https://github.com/smallstep/cli/releases/latest):
-
- ```
- $ wget -O step-cli.tar.gz https://github.com/smallstep/cli/releases/download/vX.Y.Z/step_linux_X.Y.Z_amd64.tar.gz
- ```
-
- Install `step` by unzipping and copying the executable over to `/usr/bin`:
-
- ```
- $ tar -xf step-cli.tar.gz
- $ sudo cp step_X.Y.Z/bin/step /usr/bin
- ```
-
-2. Install `step-ca`.
-
- Download the Linux package from the [latest `step-ca` release](https://github.com/smallstep/certificates/releases/latest):
-
- ```
- $ wget -O step-ca.tar.gz https://github.com/smallstep/certificates/releases/download/vX.Y.Z/step-ca_linux_X.Y.Z_amd64.tar.gz
- ```
-
- Install `step-ca` by unzipping and copying the executable over to `/usr/bin`:
-
- ```
- $ tar -xf step-ca.tar.gz
- $ sudo cp step-ca_X.Y.Z/bin/step-ca /usr/bin
- ```
-
-See the [`systemctl` setup section](https://smallstep.com/docs/step-ca/certificate-authority-server-production#running-step-ca-as-a-daemon) for a
-guide on configuring `step-ca` as a daemon.
-
-### Kubernetes
-
-We publish [helm charts](https://hub.helm.sh/charts/smallstep/step-certificates) for easy installation on kubernetes:
-
-```
-helm install step-certificates
-```
-
->
->
-> If you're using Kubernetes, make sure you [check out
-> autocert](https://github.com/smallstep/autocert): a kubernetes add-on that builds on `step
-> certificates` to automatically inject TLS/HTTPS certificates into your containers.
-
-### Docker
-
-See our [Docker getting started guide](https://smallstep.com/docs/tutorials/docker-tls-certificate-authority)
-
-### Test
-
-
$ step version
-Smallstep CLI/0.10.0 (darwin/amd64)
-Release Date: 2019-04-30 19:01 UTC
-
-$ step-ca version
-Smallstep CA/0.10.0 (darwin/amd64)
-Release Date: 2019-04-30 19:02 UTC
-
-## Quickstart
-
-In the following guide we'll run a simple `hello` server that requires clients
-to connect over an authorized and encrypted channel using HTTPS. `step-ca`
-will issue certificates to our server, allowing it to authenticate and encrypt
-communication.
-
-![Animated terminal showing step certificates in practice](https://github.com/smallstep/certificates/raw/master/docs/images/step-ca-2-legged.gif)
-
-Let's get started!
-
-### Prerequisites
-
-* [`step`](#installation-guide)
-* [golang](https://golang.org/doc/install)
-
-### Let's get started!
-
-#### 1. Run `step ca init` to create your CA's keys & certificates and configure `step-ca`:
-
-$ step ca init
-✔ What would you like to name your new PKI? (e.g. Smallstep): Example Inc.
-✔ What DNS names or IP addresses would you like to add to your new CA? (e.g. ca.smallstep.com[,1.1.1.1,etc.]): localhost
-✔ What address will your new CA listen at? (e.g. :443): 127.0.0.1:8080
-✔ What would you like to name the first provisioner for your new CA? (e.g. you@smallstep.com): bob@example.com
-✔ What do you want your password to be? [leave empty and we'll generate one]: abc123
-
-Generating root certificate...
-all done!
-
-Generating intermediate certificate...
-all done!
-
-✔ Root certificate: /Users/bob/src/github.com/smallstep/step/.step/certs/root_ca.crt
-✔ Root private key: /Users/bob/src/github.com/smallstep/step/.step/secrets/root_ca_key
-✔ Root fingerprint: 702a094e239c9eec6f0dcd0a5f65e595bf7ed6614012825c5fe3d1ae1b2fd6ee
-✔ Intermediate certificate: /Users/bob/src/github.com/smallstep/step/.step/certs/intermediate_ca.crt
-✔ Intermediate private key: /Users/bob/src/github.com/smallstep/step/.step/secrets/intermediate_ca_key
-✔ Default configuration: /Users/bob/src/github.com/smallstep/step/.step/config/defaults.json
-✔ Certificate Authority configuration: /Users/bob/src/github.com/smallstep/step/.step/config/ca.json
-
-Your PKI is ready to go. To generate certificates for individual services see 'step help ca'.
-
-This command will:
-
-- Generate [password protected](https://github.com/smallstep/certificates/blob/master/docs/GETTING_STARTED.md#passwords) private keys for your CA to sign certificates
-- Generate a root and [intermediate signing certificate](https://security.stackexchange.com/questions/128779/why-is-it-more-secure-to-use-intermediate-ca-certificates) for your CA
-- Create a JSON configuration file for `step-ca` (see [configuration docs](https://smallstep.com/docs/step-ca/configuration) for details)
-
-You can find these artifacts in `$STEPPATH` (or `~/.step` by default).
-
-#### 2. Start `step-ca`:
-
-You'll be prompted for your password from the previous step, to decrypt the CA's private signing key:
-
-$ step-ca $(step path)/config/ca.json
-Please enter the password to decrypt /Users/bob/src/github.com/smallstep/step/.step/secrets/intermediate_ca_key: abc123
-2019/02/18 13:28:58 Serving HTTPS on 127.0.0.1:8080 ...
-
-#### 3. Copy our `hello world` golang server.
-
-```
-$ cat > srv.go <$ step ca certificate localhost srv.crt srv.key
-✔ Key ID: rQxROEr7Kx9TNjSQBTETtsu3GKmuW9zm02dMXZ8GUEk (bob@example.com)
-✔ Please enter the password to decrypt the provisioner key: abc123
-✔ CA: https://localhost:8080/1.0/sign
-✔ Certificate: srv.crt
-✔ Private Key: srv.key
-
-$ step certificate inspect --bundle srv.crt
-Certificate:
- Data:
- Version: 3 (0x2)
- Serial Number: 140439335711218707689123407681832384336 (0x69a7a1d7f6f22f68059d2d9088307750)
- Signature Algorithm: ECDSA-SHA256
- Issuer: CN=Example Inc. Intermediate CA
- Validity
- Not Before: Feb 18 21:32:35 2019 UTC
- Not After : Feb 19 21:32:35 2019 UTC
- Subject: CN=localhost
-...
-Certificate:
- Data:
- Version: 3 (0x2)
- Serial Number: 207035091234452090159026162349261226844 (0x9bc18217bd560cf07db23178ed90835c)
- Signature Algorithm: ECDSA-SHA256
- Issuer: CN=Example Inc. Root CA
- Validity
- Not Before: Feb 18 21:27:21 2019 UTC
- Not After : Feb 15 21:27:21 2029 UTC
- Subject: CN=Example Inc. Intermediate CA
-...
-
-Note that `step` and `step-ca` handle details like [certificate bundling](https://smallstep.com/blog/everything-pki.html#intermediates-chains-and-bundling) for you.
-
-#### 5. Run the simple server.
-
-$ go run srv.go &
-
-#### 6. Get the root certificate from the Step CA.
-
-In a new Terminal window:
-
-$ step ca root root.crt
-The root certificate has been saved in root.crt.
-
-#### 7. Make an authenticated, encrypted curl request to your server using HTTP over TLS.
-
-$ curl --cacert root.crt https://localhost:8443/hi
-Hello, world!
-
-*All Done!*
-
-Check out the [Getting Started](./docs/GETTING_STARTED.md) guide for more examples
-and best practices on running Step CA in production.
+See our installation docs [here](https://smallstep.com/docs/step-ca/installation).
## Documentation
diff --git a/acme/account.go b/acme/account.go
index 1c5870d5..197a3400 100644
--- a/acme/account.go
+++ b/acme/account.go
@@ -1,197 +1,42 @@
package acme
import (
- "context"
+ "crypto"
+ "encoding/base64"
"encoding/json"
- "time"
- "github.com/pkg/errors"
- "github.com/smallstep/nosql"
"go.step.sm/crypto/jose"
)
// Account is a subset of the internal account type containing only those
// attributes required for responses in the ACME protocol.
type Account struct {
- Contact []string `json:"contact,omitempty"`
- Status string `json:"status"`
- Orders string `json:"orders"`
- ID string `json:"-"`
- Key *jose.JSONWebKey `json:"-"`
+ ID string `json:"-"`
+ Key *jose.JSONWebKey `json:"-"`
+ Contact []string `json:"contact,omitempty"`
+ Status Status `json:"status"`
+ OrdersURL string `json:"orders"`
}
// ToLog enables response logging.
func (a *Account) ToLog() (interface{}, error) {
b, err := json.Marshal(a)
if err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging"))
+ return nil, WrapErrorISE(err, "error marshaling account for logging")
}
return string(b), nil
}
-// GetID returns the account ID.
-func (a *Account) GetID() string {
- return a.ID
-}
-
-// GetKey returns the JWK associated with the account.
-func (a *Account) GetKey() *jose.JSONWebKey {
- return a.Key
-}
-
// IsValid returns true if the Account is valid.
func (a *Account) IsValid() bool {
- return a.Status == StatusValid
+ return Status(a.Status) == StatusValid
}
-// AccountOptions are the options needed to create a new ACME account.
-type AccountOptions struct {
- Key *jose.JSONWebKey
- Contact []string
-}
-
-// account represents an ACME account.
-type account struct {
- ID string `json:"id"`
- Created time.Time `json:"created"`
- Deactivated time.Time `json:"deactivated"`
- Key *jose.JSONWebKey `json:"key"`
- Contact []string `json:"contact,omitempty"`
- Status string `json:"status"`
-}
-
-// newAccount returns a new acme account type.
-func newAccount(db nosql.DB, ops AccountOptions) (*account, error) {
- id, err := randID()
+// KeyToID converts a JWK to a thumbprint.
+func KeyToID(jwk *jose.JSONWebKey) (string, error) {
+ kid, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
- return nil, err
+ return "", WrapErrorISE(err, "error generating jwk thumbprint")
}
-
- a := &account{
- ID: id,
- Key: ops.Key,
- Contact: ops.Contact,
- Status: "valid",
- Created: clock.Now(),
- }
- return a, a.saveNew(db)
-}
-
-// toACME converts the internal Account type into the public acmeAccount
-// type for presentation in the ACME protocol.
-func (a *account) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Account, error) {
- return &Account{
- Status: a.Status,
- Contact: a.Contact,
- Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID),
- Key: a.Key,
- ID: a.ID,
- }, nil
-}
-
-// save writes the Account to the DB.
-// If the account is new then the necessary indices will be created.
-// Else, the account in the DB will be updated.
-func (a *account) saveNew(db nosql.DB) error {
- kid, err := keyToID(a.Key)
- if err != nil {
- return err
- }
- kidB := []byte(kid)
-
- // Set the jwkID -> acme account ID index
- _, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID))
- switch {
- case err != nil:
- return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index"))
- case !swapped:
- return ServerInternalErr(errors.Errorf("key-id to account-id index already exists"))
- default:
- if err = a.save(db, nil); err != nil {
- db.Del(accountByKeyIDTable, kidB)
- return err
- }
- return nil
- }
-}
-
-func (a *account) save(db nosql.DB, old *account) error {
- var (
- err error
- oldB []byte
- )
- if old == nil {
- oldB = nil
- } else {
- if oldB, err = json.Marshal(old); err != nil {
- return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
- }
- }
-
- b, err := json.Marshal(*a)
- if err != nil {
- return errors.Wrap(err, "error marshaling new account object")
- }
- // Set the Account
- _, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b)
- switch {
- case err != nil:
- return ServerInternalErr(errors.Wrap(err, "error storing account"))
- case !swapped:
- return ServerInternalErr(errors.New("error storing account; " +
- "value has changed since last read"))
- default:
- return nil
- }
-}
-
-// update updates the acme account object stored in the database if,
-// and only if, the account has not changed since the last read.
-func (a *account) update(db nosql.DB, contact []string) (*account, error) {
- b := *a
- b.Contact = contact
- if err := b.save(db, a); err != nil {
- return nil, err
- }
- return &b, nil
-}
-
-// deactivate deactivates the acme account.
-func (a *account) deactivate(db nosql.DB) (*account, error) {
- b := *a
- b.Status = StatusDeactivated
- b.Deactivated = clock.Now()
- if err := b.save(db, a); err != nil {
- return nil, err
- }
- return &b, nil
-}
-
-// getAccountByID retrieves the account with the given ID.
-func getAccountByID(db nosql.DB, id string) (*account, error) {
- ab, err := db.Get(accountTable, []byte(id))
- if err != nil {
- if nosql.IsErrNotFound(err) {
- return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id))
- }
- return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id))
- }
-
- a := new(account)
- if err = json.Unmarshal(ab, a); err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account"))
- }
- return a, nil
-}
-
-// getAccountByKeyID retrieves Id associated with the given Kid.
-func getAccountByKeyID(db nosql.DB, kid string) (*account, error) {
- id, err := db.Get(accountByKeyIDTable, []byte(kid))
- if err != nil {
- if nosql.IsErrNotFound(err) {
- return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid))
- }
- return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index"))
- }
- return getAccountByID(db, string(id))
+ return base64.RawURLEncoding.EncodeToString(kid), nil
}
diff --git a/acme/account_test.go b/acme/account_test.go
index 2e072af5..5625c3dc 100644
--- a/acme/account_test.go
+++ b/acme/account_test.go
@@ -1,770 +1,81 @@
package acme
import (
- "context"
- "encoding/json"
- "fmt"
- "net/url"
+ "crypto"
+ "encoding/base64"
"testing"
- "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
- "github.com/smallstep/certificates/authority/provisioner"
- "github.com/smallstep/certificates/db"
- "github.com/smallstep/nosql"
- "github.com/smallstep/nosql/database"
"go.step.sm/crypto/jose"
)
-var (
- defaultDisableRenewal = false
- globalProvisionerClaims = provisioner.Claims{
- MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
- MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
- DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
- DisableRenewal: &defaultDisableRenewal,
- }
-)
-
-func newProv() Provisioner {
- // Initialize provisioners
- p := &provisioner.ACME{
- Type: "ACME",
- Name: "test@acme-provisioner.com",
- }
- if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
- fmt.Printf("%v", err)
- }
- return p
-}
-
-func newAcc() (*account, error) {
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- if err != nil {
- return nil, err
- }
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, true, nil
- },
- }
- return newAccount(mockdb, AccountOptions{
- Key: jwk, Contact: []string{"foo", "bar"},
- })
-}
-
-func TestGetAccountByID(t *testing.T) {
+func TestKeyToID(t *testing.T) {
type test struct {
- id string
- db nosql.DB
- acc *account
+ jwk *jose.JSONWebKey
+ exp string
err *Error
}
tests := map[string]func(t *testing.T) test{
- "fail/not-found": func(t *testing.T) test {
- acc, err := newAcc()
+ "fail/error-generating-thumbprint": func(t *testing.T) test {
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
+ jwk.Key = "foo"
return test{
- acc: acc,
- id: acc.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
- },
- },
- err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)),
- }
- },
- "fail/db-error": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- return test{
- acc: acc,
- id: acc.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)),
- }
- },
- "fail/unmarshal-error": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- return test{
- acc: acc,
- id: acc.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- return nil, nil
- },
- },
- err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")),
+ jwk: jwk,
+ err: NewErrorISE("error generating jwk thumbprint: square/go-jose: unknown key type 'string'"),
}
},
"ok": func(t *testing.T) test {
- acc, err := newAcc()
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- b, err := json.Marshal(acc)
+
+ kid, err := jwk.Thumbprint(crypto.SHA256)
assert.FatalError(t, err)
+
return test{
- acc: acc,
- id: acc.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- return b, nil
- },
- },
+ jwk: jwk,
+ exp: base64.RawURLEncoding.EncodeToString(kid),
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
- if acc, err := getAccountByID(tc.db, tc.id); err != nil {
+ if id, err := KeyToID(tc.jwk); err != nil {
if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
+ switch k := err.(type) {
+ case *Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
}
} else {
if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.acc.ID, acc.ID)
- assert.Equals(t, tc.acc.Status, acc.Status)
- assert.Equals(t, tc.acc.Created, acc.Created)
- assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
- assert.Equals(t, tc.acc.Contact, acc.Contact)
- assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
+ assert.Equals(t, id, tc.exp)
}
}
})
}
}
-func TestGetAccountByKeyID(t *testing.T) {
+func TestAccount_IsValid(t *testing.T) {
type test struct {
- kid string
- db nosql.DB
- acc *account
- err *Error
+ acc *Account
+ exp bool
}
- tests := map[string]func(t *testing.T) test{
- "fail/kid-not-found": func(t *testing.T) test {
- return test{
- kid: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
- },
- },
- err: MalformedErr(errors.Errorf("account with key id foo not found: not found")),
- }
- },
- "fail/db-error": func(t *testing.T) test {
- return test{
- kid: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error loading key-account index: force")),
- }
- },
- "fail/getAccount-error": func(t *testing.T) test {
- count := 0
- return test{
- kid: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- if count == 0 {
- assert.Equals(t, bucket, accountByKeyIDTable)
- assert.Equals(t, key, []byte("foo"))
- count++
- return []byte("bar"), nil
- }
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error loading account bar: force")),
- }
- },
- "ok": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
- count := 0
- return test{
- kid: acc.Key.KeyID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- var ret []byte
- switch count {
- case 0:
- assert.Equals(t, bucket, accountByKeyIDTable)
- assert.Equals(t, key, []byte(acc.Key.KeyID))
- ret = []byte(acc.ID)
- case 1:
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- ret = b
- }
- count++
- return ret, nil
- },
- },
- acc: acc,
- }
- },
+ tests := map[string]test{
+ "valid": {acc: &Account{Status: StatusValid}, exp: true},
+ "invalid": {acc: &Account{Status: StatusInvalid}, exp: false},
}
- for name, run := range tests {
+ for name, tc := range tests {
t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.acc.ID, acc.ID)
- assert.Equals(t, tc.acc.Status, acc.Status)
- assert.Equals(t, tc.acc.Created, acc.Created)
- assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
- assert.Equals(t, tc.acc.Contact, acc.Contact)
- assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
- }
- }
- })
- }
-}
-
-func TestAccountToACME(t *testing.T) {
- dir := newDirectory("ca.smallstep.com", "acme")
- prov := newProv()
- provName := url.PathEscape(prov.GetName())
- baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
-
- type test struct {
- acc *account
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "ok": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- return test{acc: acc}
- },
- }
- for name, run := range tests {
- tc := run(t)
- t.Run(name, func(t *testing.T) {
- acmeAccount, err := tc.acc.toACME(ctx, nil, dir)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, acmeAccount.ID, tc.acc.ID)
- assert.Equals(t, acmeAccount.Status, tc.acc.Status)
- assert.Equals(t, acmeAccount.Contact, tc.acc.Contact)
- assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID)
- assert.Equals(t, acmeAccount.Orders,
- fmt.Sprintf("%s/acme/%s/account/%s/orders", baseURL.String(), provName, tc.acc.ID))
- }
- }
- })
- }
-}
-
-func TestAccountSave(t *testing.T) {
- type test struct {
- acc, old *account
- db nosql.DB
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/old-nil/swap-error": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- return test{
- acc: acc,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error storing account: force")),
- }
- },
- "fail/old-nil/swap-false": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- return test{
- acc: acc,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), false, nil
- },
- },
- err: ServerInternalErr(errors.New("error storing account; value has changed since last read")),
- }
- },
- "ok/old-nil": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
- return test{
- acc: acc,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, nil)
- assert.Equals(t, b, newval)
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, []byte(acc.ID), key)
- return nil, true, nil
- },
- },
- }
- },
- "ok/old-not-nil": func(t *testing.T) test {
- oldAcc, err := newAcc()
- assert.FatalError(t, err)
- acc, err := newAcc()
- assert.FatalError(t, err)
-
- oldb, err := json.Marshal(oldAcc)
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
- return test{
- acc: acc,
- old: oldAcc,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, b)
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, []byte(acc.ID), key)
- return []byte("foo"), true, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if err := tc.acc.save(tc.db, tc.old); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- assert.Nil(t, tc.err)
- }
- })
- }
-}
-
-func TestAccountSaveNew(t *testing.T) {
- type test struct {
- acc *account
- db nosql.DB
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/keyToID-error": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- acc.Key.Key = "foo"
- return test{
- acc: acc,
- err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")),
- }
- },
- "fail/swap-error": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- kid, err := keyToID(acc.Key)
- assert.FatalError(t, err)
- return test{
- acc: acc,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, accountByKeyIDTable)
- assert.Equals(t, key, []byte(kid))
- assert.Equals(t, old, nil)
- assert.Equals(t, newval, []byte(acc.ID))
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
- }
- },
- "fail/swap-false": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- kid, err := keyToID(acc.Key)
- assert.FatalError(t, err)
- return test{
- acc: acc,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, accountByKeyIDTable)
- assert.Equals(t, key, []byte(kid))
- assert.Equals(t, old, nil)
- assert.Equals(t, newval, []byte(acc.ID))
- return nil, false, nil
- },
- },
- err: ServerInternalErr(errors.New("key-id to account-id index already exists")),
- }
- },
- "fail/save-error": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- kid, err := keyToID(acc.Key)
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
- count := 0
- return test{
- acc: acc,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 0 {
- assert.Equals(t, bucket, accountByKeyIDTable)
- assert.Equals(t, key, []byte(kid))
- assert.Equals(t, old, nil)
- assert.Equals(t, newval, []byte(acc.ID))
- count++
- return nil, true, nil
- }
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- assert.Equals(t, old, nil)
- assert.Equals(t, newval, b)
- return nil, false, errors.New("force")
- },
- MDel: func(bucket, key []byte) error {
- assert.Equals(t, bucket, accountByKeyIDTable)
- assert.Equals(t, key, []byte(kid))
- return nil
- },
- },
- err: ServerInternalErr(errors.New("error storing account: force")),
- }
- },
- "ok": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- kid, err := keyToID(acc.Key)
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
- count := 0
- return test{
- acc: acc,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 0 {
- assert.Equals(t, bucket, accountByKeyIDTable)
- assert.Equals(t, key, []byte(kid))
- assert.Equals(t, old, nil)
- assert.Equals(t, newval, []byte(acc.ID))
- count++
- return nil, true, nil
- }
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- assert.Equals(t, old, nil)
- assert.Equals(t, newval, b)
- return nil, true, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if err := tc.acc.saveNew(tc.db); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- assert.Nil(t, tc.err)
- }
- })
- }
-}
-
-func TestAccountUpdate(t *testing.T) {
- type test struct {
- acc *account
- contact []string
- db nosql.DB
- res []byte
- err *Error
- }
- contact := []string{"foo", "bar"}
- tests := map[string]func(t *testing.T) test{
- "fail/save-error": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(acc)
- assert.FatalError(t, err)
-
- _acc := *acc
- clone := &_acc
- clone.Contact = contact
- b, err := json.Marshal(clone)
- assert.FatalError(t, err)
- return test{
- acc: acc,
- contact: contact,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, b)
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error storing account: force")),
- }
- },
- "ok": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(acc)
- assert.FatalError(t, err)
-
- _acc := *acc
- clone := &_acc
- clone.Contact = contact
- b, err := json.Marshal(clone)
- assert.FatalError(t, err)
- return test{
- acc: acc,
- contact: contact,
- res: b,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, b)
- return nil, true, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- acc, err := tc.acc.update(tc.db, tc.contact)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
- assert.Equals(t, b, tc.res)
- }
- }
- })
- }
-}
-
-func TestAccountDeactivate(t *testing.T) {
- type test struct {
- acc *account
- db nosql.DB
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/save-error": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(acc)
- assert.FatalError(t, err)
-
- return test{
- acc: acc,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- assert.Equals(t, old, oldb)
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error storing account: force")),
- }
- },
- "ok": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(acc)
- assert.FatalError(t, err)
-
- return test{
- acc: acc,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- assert.Equals(t, old, oldb)
- return nil, true, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- acc, err := tc.acc.deactivate(tc.db)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, acc.ID, tc.acc.ID)
- assert.Equals(t, acc.Contact, tc.acc.Contact)
- assert.Equals(t, acc.Status, StatusDeactivated)
- assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID)
- assert.Equals(t, acc.Created, tc.acc.Created)
-
- assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute)))
- assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute)))
- }
- }
- })
- }
-}
-
-func TestNewAccount(t *testing.T) {
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
- kid, err := keyToID(jwk)
- assert.FatalError(t, err)
- ops := AccountOptions{
- Key: jwk,
- Contact: []string{"foo", "bar"},
- }
- type test struct {
- ops AccountOptions
- db nosql.DB
- err *Error
- id *string
- }
- tests := map[string]func(t *testing.T) test{
- "fail/store-error": func(t *testing.T) test {
- return test{
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
- }
- },
- "ok": func(t *testing.T) test {
- var _id string
- id := &_id
- count := 0
- return test{
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- switch count {
- case 0:
- assert.Equals(t, bucket, accountByKeyIDTable)
- assert.Equals(t, key, []byte(kid))
- case 1:
- assert.Equals(t, bucket, accountTable)
- *id = string(key)
- }
- count++
- return nil, true, nil
- },
- },
- id: id,
- }
- },
- }
- for name, run := range tests {
- tc := run(t)
- t.Run(name, func(t *testing.T) {
- acc, err := newAccount(tc.db, tc.ops)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, acc.ID, *tc.id)
- assert.Equals(t, acc.Status, StatusValid)
- assert.Equals(t, acc.Contact, ops.Contact)
- assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID)
-
- assert.True(t, acc.Deactivated.IsZero())
-
- assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute)))
- assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute)))
- }
- }
+ assert.Equals(t, tc.acc.IsValid(), tc.exp)
})
}
}
diff --git a/acme/api/account.go b/acme/api/account.go
index 93f46651..b733c679 100644
--- a/acme/api/account.go
+++ b/acme/api/account.go
@@ -5,7 +5,6 @@ import (
"net/http"
"github.com/go-chi/chi"
- "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/logging"
@@ -21,7 +20,7 @@ type NewAccountRequest struct {
func validateContacts(cs []string) error {
for _, c := range cs {
if len(c) == 0 {
- return acme.MalformedErr(errors.New("contact cannot be empty string"))
+ return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string")
}
}
return nil
@@ -30,29 +29,23 @@ func validateContacts(cs []string) error {
// Validate validates a new-account request body.
func (n *NewAccountRequest) Validate() error {
if n.OnlyReturnExisting && len(n.Contact) > 0 {
- return acme.MalformedErr(errors.New("incompatible input; onlyReturnExisting must be alone"))
+ return acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone")
}
return validateContacts(n.Contact)
}
// UpdateAccountRequest represents an update-account request.
type UpdateAccountRequest struct {
- Contact []string `json:"contact"`
- Status string `json:"status"`
-}
-
-// IsDeactivateRequest returns true if the update request is a deactivation
-// request, false otherwise.
-func (u *UpdateAccountRequest) IsDeactivateRequest() bool {
- return u.Status == acme.StatusDeactivated
+ Contact []string `json:"contact"`
+ Status acme.Status `json:"status"`
}
// Validate validates a update-account request body.
func (u *UpdateAccountRequest) Validate() error {
switch {
case len(u.Status) > 0 && len(u.Contact) > 0:
- return acme.MalformedErr(errors.New("incompatible input; contact and " +
- "status updates are mutually exclusive"))
+ return acme.NewError(acme.ErrorMalformedType, "incompatible input; contact and "+
+ "status updates are mutually exclusive")
case len(u.Contact) > 0:
if err := validateContacts(u.Contact); err != nil {
return err
@@ -60,8 +53,8 @@ func (u *UpdateAccountRequest) Validate() error {
return nil
case len(u.Status) > 0:
if u.Status != acme.StatusDeactivated {
- return acme.MalformedErr(errors.Errorf("cannot update account "+
- "status to %s, only deactivated", u.Status))
+ return acme.NewError(acme.ErrorMalformedType, "cannot update account "+
+ "status to %s, only deactivated", u.Status)
}
return nil
default:
@@ -73,15 +66,16 @@ func (u *UpdateAccountRequest) Validate() error {
// NewAccount is the handler resource for creating new ACME accounts.
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
- payload, err := payloadFromContext(r.Context())
+ ctx := r.Context()
+ payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
var nar NewAccountRequest
if err := json.Unmarshal(payload.value, &nar); err != nil {
- api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
- "failed to unmarshal new-account request payload")))
+ api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
+ "failed to unmarshal new-account request payload"))
return
}
if err := nar.Validate(); err != nil {
@@ -90,7 +84,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
}
httpStatus := http.StatusCreated
- acc, err := acme.AccountFromContext(r.Context())
+ acc, err := accountFromContext(r.Context())
if err != nil {
acmeErr, ok := err.(*acme.Error)
if !ok || acmeErr.Status != http.StatusBadRequest {
@@ -101,20 +95,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
// Account does not exist //
if nar.OnlyReturnExisting {
- api.WriteError(w, acme.AccountDoesNotExistErr(nil))
+ api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
+ "account does not exist"))
return
}
- jwk, err := acme.JwkFromContext(r.Context())
+ jwk, err := jwkFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
- if acc, err = h.Auth.NewAccount(r.Context(), acme.AccountOptions{
+ acc = &acme.Account{
Key: jwk,
Contact: nar.Contact,
- }); err != nil {
- api.WriteError(w, err)
+ Status: acme.StatusValid,
+ }
+ if err := h.db.CreateAccount(ctx, acc); err != nil {
+ api.WriteError(w, acme.WrapErrorISE(err, "error creating account"))
return
}
} else {
@@ -122,19 +119,21 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
httpStatus = http.StatusOK
}
- w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink,
- true, acc.GetID()))
+ h.linker.LinkAccount(ctx, acc)
+
+ w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
api.JSONStatus(w, acc, httpStatus)
}
-// GetUpdateAccount is the api for updating an ACME account.
-func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
- acc, err := acme.AccountFromContext(r.Context())
+// GetOrUpdateAccount is the api for updating an ACME account.
+func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
- payload, err := payloadFromContext(r.Context())
+ payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
@@ -145,29 +144,31 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
if !payload.isPostAsGet {
var uar UpdateAccountRequest
if err := json.Unmarshal(payload.value, &uar); err != nil {
- api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal new-account request payload")))
+ api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
+ "failed to unmarshal new-account request payload"))
return
}
if err := uar.Validate(); err != nil {
api.WriteError(w, err)
return
}
- var err error
- // If neither the status nor the contacts are being updated then ignore
- // the updates and return 200. This conforms with the behavior detailed
- // in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2).
- if uar.IsDeactivateRequest() {
- acc, err = h.Auth.DeactivateAccount(r.Context(), acc.GetID())
- } else if len(uar.Contact) > 0 {
- acc, err = h.Auth.UpdateAccount(r.Context(), acc.GetID(), uar.Contact)
- }
- if err != nil {
- api.WriteError(w, err)
- return
+ if len(uar.Status) > 0 || len(uar.Contact) > 0 {
+ if len(uar.Status) > 0 {
+ acc.Status = uar.Status
+ } else if len(uar.Contact) > 0 {
+ acc.Contact = uar.Contact
+ }
+
+ if err := h.db.UpdateAccount(ctx, acc); err != nil {
+ api.WriteError(w, acme.WrapErrorISE(err, "error updating account"))
+ return
+ }
}
}
- w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink,
- true, acc.GetID()))
+
+ h.linker.LinkAccount(ctx, acc)
+
+ w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
api.JSON(w, acc)
}
@@ -180,23 +181,27 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
}
}
-// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account.
-func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
- acc, err := acme.AccountFromContext(r.Context())
+// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
+func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
accID := chi.URLParam(r, "accID")
if acc.ID != accID {
- api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param")))
+ api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return
}
- orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID())
+ orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil {
api.WriteError(w, err)
return
}
+
+ h.linker.LinkOrdersByAccountID(ctx, orders)
+
api.JSON(w, orders)
logOrdersByAccount(w, orders)
}
diff --git a/acme/api/account_test.go b/acme/api/account_test.go
index bdd61c59..c4d7a812 100644
--- a/acme/api/account_test.go
+++ b/acme/api/account_test.go
@@ -12,7 +12,6 @@ import (
"time"
"github.com/go-chi/chi"
- "github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
@@ -29,11 +28,11 @@ var (
}
)
-func newProv() provisioner.Interface {
+func newProv() acme.Provisioner {
// Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
- Name: "test@acme-provisioner.com",
+ Name: "test@acme-provisioner.com",
}
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
fmt.Printf("%v", err)
@@ -41,7 +40,7 @@ func newProv() provisioner.Interface {
return p
}
-func TestNewAccountRequestValidate(t *testing.T) {
+func TestNewAccountRequest_Validate(t *testing.T) {
type test struct {
nar *NewAccountRequest
err *acme.Error
@@ -53,7 +52,7 @@ func TestNewAccountRequestValidate(t *testing.T) {
OnlyReturnExisting: true,
Contact: []string{"foo", "bar"},
},
- err: acme.MalformedErr(errors.Errorf("incompatible input; onlyReturnExisting must be alone")),
+ err: acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone"),
}
},
"fail/bad-contact": func(t *testing.T) test {
@@ -61,7 +60,7 @@ func TestNewAccountRequestValidate(t *testing.T) {
nar: &NewAccountRequest{
Contact: []string{"foo", ""},
},
- err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
+ err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
}
},
"ok": func(t *testing.T) test {
@@ -97,7 +96,7 @@ func TestNewAccountRequestValidate(t *testing.T) {
}
}
-func TestUpdateAccountRequestValidate(t *testing.T) {
+func TestUpdateAccountRequest_Validate(t *testing.T) {
type test struct {
uar *UpdateAccountRequest
err *acme.Error
@@ -109,8 +108,8 @@ func TestUpdateAccountRequestValidate(t *testing.T) {
Contact: []string{"foo", "bar"},
Status: "foo",
},
- err: acme.MalformedErr(errors.Errorf("incompatible input; " +
- "contact and status updates are mutually exclusive")),
+ err: acme.NewError(acme.ErrorMalformedType, "incompatible input; "+
+ "contact and status updates are mutually exclusive"),
}
},
"fail/bad-contact": func(t *testing.T) test {
@@ -118,7 +117,7 @@ func TestUpdateAccountRequestValidate(t *testing.T) {
uar: &UpdateAccountRequest{
Contact: []string{"foo", ""},
},
- err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
+ err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
}
},
"fail/bad-status": func(t *testing.T) test {
@@ -126,8 +125,8 @@ func TestUpdateAccountRequestValidate(t *testing.T) {
uar: &UpdateAccountRequest{
Status: "foo",
},
- err: acme.MalformedErr(errors.Errorf("cannot update account " +
- "status to foo, only deactivated")),
+ err: acme.NewError(acme.ErrorMalformedType, "cannot update account "+
+ "status to foo, only deactivated"),
}
},
"ok/contact": func(t *testing.T) test {
@@ -168,81 +167,81 @@ func TestUpdateAccountRequestValidate(t *testing.T) {
}
}
-func TestHandlerGetOrdersByAccount(t *testing.T) {
- oids := []string{
- "https://ca.smallstep.com/acme/order/foo",
- "https://ca.smallstep.com/acme/order/bar",
- }
+func TestHandler_GetOrdersByAccountID(t *testing.T) {
accID := "account-id"
- prov := newProv()
// Request with chi context
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("accID", accID)
- url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID)
+
+ prov := newProv()
+ provName := url.PathEscape(prov.GetName())
+ baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
+
+ url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
+
+ oids := []string{"foo", "bar"}
+ oidURLs := []string{
+ fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), provName),
+ fmt.Sprintf("%s/acme/%s/order/bar", baseURL.String(), provName),
+ }
type test struct {
- auth acme.Interface
+ db acme.DB
ctx context.Context
statusCode int
- problem *acme.Error
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
- auth: &mockAcmeAuthority{},
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ db: &acme.MockDB{},
+ ctx: context.Background(),
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, nil)
return test{
- auth: &mockAcmeAuthority{},
- ctx: ctx,
+ db: &acme.MockDB{},
+ ctx: context.WithValue(context.Background(), accContextKey, nil),
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "foo"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
- auth: &mockAcmeAuthority{},
+ db: &acme.MockDB{},
ctx: ctx,
statusCode: 401,
- problem: acme.UnauthorizedErr(errors.New("account ID does not match url param")),
+ err: acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param"),
}
},
- "fail/getOrdersByAccount-error": func(t *testing.T) test {
+ "fail/db.GetOrdersByAccountID-error": func(t *testing.T) test {
acc := &acme.Account{ID: accID}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
- auth: &mockAcmeAuthority{
- err: acme.ServerInternalErr(errors.New("force")),
+ db: &acme.MockDB{
+ MockError: acme.NewErrorISE("force"),
},
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
+ err: acme.NewErrorISE("force"),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: accID}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
+ ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{
- auth: &mockAcmeAuthority{
- getOrdersByAccount: func(ctx context.Context, id string) ([]string, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
+ db: &acme.MockDB{
+ MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
assert.Equals(t, id, acc.ID)
return oids, nil
},
@@ -255,11 +254,11 @@ func TestHandlerGetOrdersByAccount(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
- h.GetOrdersByAccount(w, req)
+ h.GetOrdersByAccountID(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@@ -268,18 +267,17 @@ func TestHandlerGetOrdersByAccount(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
- expB, err := json.Marshal(oids)
+ expB, err := json.Marshal(oidURLs)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
@@ -288,47 +286,41 @@ func TestHandlerGetOrdersByAccount(t *testing.T) {
}
}
-func TestHandlerNewAccount(t *testing.T) {
- accID := "accountID"
- acc := acme.Account{
- ID: accID,
- Status: "valid",
- Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
- }
+func TestHandler_NewAccount(t *testing.T) {
prov := newProv()
- provName := url.PathEscape(prov.GetName())
+ escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct {
- auth acme.Interface
+ db acme.DB
+ acc *acme.Account
ctx context.Context
statusCode int
- problem *acme.Error
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-payload": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.Background(),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("payload expected in request context"),
}
},
"fail/nil-payload": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
+ ctx := context.WithValue(context.Background(), payloadContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("payload expected in request context"),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{})
+ ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
+ err: acme.NewError(acme.ErrorMalformedType, "failed to "+
+ "unmarshal new-account request payload: unexpected end of JSON input"),
}
},
"fail/malformed-payload-error": func(t *testing.T) test {
@@ -337,12 +329,11 @@ func TestHandlerNewAccount(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
+ ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
+ err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
}
},
"fail/no-existing-account": func(t *testing.T) test {
@@ -351,12 +342,11 @@ func TestHandlerNewAccount(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
+ ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/no-jwk": func(t *testing.T) test {
@@ -365,12 +355,11 @@ func TestHandlerNewAccount(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
+ ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
+ err: acme.NewErrorISE("jwk expected in request context"),
}
},
"fail/nil-jwk": func(t *testing.T) test {
@@ -379,16 +368,15 @@ func TestHandlerNewAccount(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
- ctx = context.WithValue(ctx, acme.JwkContextKey, nil)
+ ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, jwkContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
+ err: acme.NewErrorISE("jwk expected in request context"),
}
},
- "fail/NewAccount-error": func(t *testing.T) test {
+ "fail/db.CreateAccount-error": func(t *testing.T) test {
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
}
@@ -396,23 +384,19 @@ func TestHandlerNewAccount(t *testing.T) {
assert.FatalError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
- ctx = context.WithValue(ctx, acme.JwkContextKey, jwk)
+ ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, jwkContextKey, jwk)
return test{
- auth: &mockAcmeAuthority{
- newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, ops.Contact, nar.Contact)
- assert.Equals(t, ops.Key, jwk)
- return nil, acme.ServerInternalErr(errors.New("force"))
+ db: &acme.MockDB{
+ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
+ assert.Equals(t, acc.Contact, nar.Contact)
+ assert.Equals(t, acc.Key, jwk)
+ return acme.NewErrorISE("force")
},
},
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
+ err: acme.NewErrorISE("force"),
}
},
"ok/new-account": func(t *testing.T) test {
@@ -423,29 +407,26 @@ func TestHandlerNewAccount(t *testing.T) {
assert.FatalError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
- ctx = context.WithValue(ctx, acme.JwkContextKey, jwk)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, jwkContextKey, jwk)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
+ ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{
- auth: &mockAcmeAuthority{
- newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, ops.Contact, nar.Contact)
- assert.Equals(t, ops.Key, jwk)
- return &acc, nil
- },
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.True(t, abs)
- assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx))
- return fmt.Sprintf("%s/acme/%s/account/%s",
- baseURL.String(), provName, accID)
+ db: &acme.MockDB{
+ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
+ acc.ID = "accountID"
+ assert.Equals(t, acc.Contact, nar.Contact)
+ assert.Equals(t, acc.Key, jwk)
+ return nil
},
},
+ acc: &acme.Account{
+ ID: "accountID",
+ Key: jwk,
+ Status: acme.StatusValid,
+ Contact: []string{"foo", "bar"},
+ OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName),
+ },
ctx: ctx,
statusCode: 201,
}
@@ -456,22 +437,21 @@ func TestHandlerNewAccount(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+ acc := &acme.Account{
+ ID: "accountID",
+ Key: jwk,
+ Status: acme.StatusValid,
+ Contact: []string{"foo", "bar"},
+ }
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx))
- assert.Equals(t, ins, []string{accID})
- return fmt.Sprintf("%s/acme/%s/account/%s",
- baseURL.String(), provName, accID)
- },
- },
ctx: ctx,
+ acc: acc,
statusCode: 200,
}
},
@@ -479,7 +459,7 @@ func TestHandlerNewAccount(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -492,90 +472,85 @@ func TestHandlerNewAccount(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
- expB, err := json.Marshal(acc)
+ expB, err := json.Marshal(tc.acc)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(),
- provName, accID)})
+ escProvName, "accountID")})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}
-func TestHandlerGetUpdateAccount(t *testing.T) {
+func TestHandler_GetOrUpdateAccount(t *testing.T) {
accID := "accountID"
acc := acme.Account{
- ID: accID,
- Status: "valid",
- Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
+ ID: accID,
+ Status: "valid",
+ OrdersURL: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
}
prov := newProv()
- provName := url.PathEscape(prov.GetName())
+ escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct {
- auth acme.Interface
+ db acme.DB
ctx context.Context
statusCode int
- problem *acme.Error
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.Background(),
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, nil)
+ ctx := context.WithValue(context.Background(), accContextKey, nil)
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/no-payload": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
+ ctx := context.WithValue(context.Background(), accContextKey, &acc)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("payload expected in request context"),
}
},
"fail/nil-payload": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
+ ctx := context.WithValue(context.Background(), accContextKey, &acc)
+ ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("payload expected in request context"),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{})
+ ctx := context.WithValue(context.Background(), accContextKey, &acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
+ err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
}
},
"fail/malformed-payload-error": func(t *testing.T) test {
@@ -584,62 +559,33 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
+ ctx := context.WithValue(context.Background(), accContextKey, &acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
+ err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
}
},
- "fail/Deactivate-error": func(t *testing.T) test {
+ "fail/db.UpdateAccount-error": func(t *testing.T) test {
uar := &UpdateAccountRequest{
Status: "deactivated",
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
+ ctx := context.WithValue(context.Background(), accContextKey, &acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
- auth: &mockAcmeAuthority{
- deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, id, accID)
- return nil, acme.ServerInternalErr(errors.New("force"))
+ db: &acme.MockDB{
+ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
+ assert.Equals(t, upd.Status, acme.StatusDeactivated)
+ assert.Equals(t, upd.ID, acc.ID)
+ return acme.NewErrorISE("force")
},
},
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
- }
- },
- "fail/UpdateAccount-error": func(t *testing.T) test {
- uar := &UpdateAccountRequest{
- Contact: []string{"foo", "bar"},
- }
- b, err := json.Marshal(uar)
- assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
- return test{
- auth: &mockAcmeAuthority{
- updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, id, accID)
- assert.Equals(t, contacts, uar.Contact)
- return nil, acme.ServerInternalErr(errors.New("force"))
- },
- },
- ctx: ctx,
- statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
+ err: acme.NewErrorISE("force"),
}
},
"ok/deactivate": func(t *testing.T) test {
@@ -648,26 +594,16 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, &acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, id, accID)
- return &acc, nil
- },
- getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
- assert.Equals(t, ins, []string{accID})
- return fmt.Sprintf("%s/acme/%s/account/%s",
- baseURL.String(), provName, accID)
+ db: &acme.MockDB{
+ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
+ assert.Equals(t, upd.Status, acme.StatusDeactivated)
+ assert.Equals(t, upd.ID, acc.ID)
+ return nil
},
},
ctx: ctx,
@@ -678,21 +614,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
uar := &UpdateAccountRequest{}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, &acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
- assert.Equals(t, ins, []string{accID})
- return fmt.Sprintf("%s/acme/%s/account/%s",
- baseURL.String(), provName, accID)
- },
- },
ctx: ctx,
statusCode: 200,
}
@@ -703,27 +629,16 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, &acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, id, accID)
- assert.Equals(t, contacts, uar.Contact)
- return &acc, nil
- },
- getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
- assert.Equals(t, ins, []string{accID})
- return fmt.Sprintf("%s/acme/%s/account/%s",
- baseURL.String(), provName, accID)
+ db: &acme.MockDB{
+ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
+ assert.Equals(t, upd.Contact, uar.Contact)
+ assert.Equals(t, upd.ID, acc.ID)
+ return nil
},
},
ctx: ctx,
@@ -731,21 +646,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
}
},
"ok/post-as-get": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true})
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, &acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
- assert.Equals(t, ins, []string{accID})
- return fmt.Sprintf("%s/acme/%s/account/%s",
- baseURL, provName, accID)
- },
- },
ctx: ctx,
statusCode: 200,
}
@@ -754,11 +659,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
- h.GetUpdateAccount(w, req)
+ h.GetOrUpdateAccount(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@@ -767,15 +672,14 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(acc)
@@ -783,7 +687,7 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(),
- provName, accID)})
+ escProvName, accID)})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
diff --git a/acme/api/handler.go b/acme/api/handler.go
index 921e614e..c1d2d62a 100644
--- a/acme/api/handler.go
+++ b/acme/api/handler.go
@@ -1,56 +1,98 @@
package api
import (
- "context"
+ "crypto/tls"
"crypto/x509"
+ "encoding/json"
"encoding/pem"
"fmt"
+ "net"
"net/http"
+ "time"
"github.com/go-chi/chi"
- "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
+ "github.com/smallstep/certificates/authority/provisioner"
)
func link(url, typ string) string {
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ)
}
+// Clock that returns time in UTC rounded to seconds.
+type Clock struct{}
+
+// Now returns the UTC time rounded to seconds.
+func (c *Clock) Now() time.Time {
+ return time.Now().UTC().Truncate(time.Second)
+}
+
+var clock Clock
+
type payloadInfo struct {
value []byte
isPostAsGet bool
isEmptyJSON bool
}
-// payloadFromContext searches the context for a payload. Returns the payload
-// or an error.
-func payloadFromContext(ctx context.Context) (*payloadInfo, error) {
- val, ok := ctx.Value(acme.PayloadContextKey).(*payloadInfo)
- if !ok || val == nil {
- return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context"))
- }
- return val, nil
-}
-
-// New returns a new ACME API router.
-func New(acmeAuth acme.Interface) api.RouterHandler {
- return &Handler{acmeAuth}
-}
-
-// Handler is the ACME request handler.
+// Handler is the ACME API request handler.
type Handler struct {
- Auth acme.Interface
+ db acme.DB
+ backdate provisioner.Duration
+ ca acme.CertificateAuthority
+ linker Linker
+ validateChallengeOptions *acme.ValidateChallengeOptions
+}
+
+// HandlerOptions required to create a new ACME API request handler.
+type HandlerOptions struct {
+ Backdate provisioner.Duration
+ // DB storage backend that impements the acme.DB interface.
+ DB acme.DB
+ // DNS the host used to generate accurate ACME links. By default the authority
+ // will use the Host from the request, so this value will only be used if
+ // request.Host is empty.
+ DNS string
+ // Prefix is a URL path prefix under which the ACME api is served. This
+ // prefix is required to generate accurate ACME links.
+ // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
+ // "acme" is the prefix from which the ACME api is accessed.
+ Prefix string
+ CA acme.CertificateAuthority
+}
+
+// NewHandler returns a new ACME API handler.
+func NewHandler(ops HandlerOptions) api.RouterHandler {
+ client := http.Client{
+ Timeout: 30 * time.Second,
+ }
+ dialer := &net.Dialer{
+ Timeout: 30 * time.Second,
+ }
+ return &Handler{
+ ca: ops.CA,
+ db: ops.DB,
+ backdate: ops.Backdate,
+ linker: NewLinker(ops.DNS, ops.Prefix),
+ validateChallengeOptions: &acme.ValidateChallengeOptions{
+ HTTPGet: client.Get,
+ LookupTxt: net.LookupTXT,
+ TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
+ return tls.DialWithDialer(dialer, network, addr, config)
+ },
+ },
+ }
}
// Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) {
- getLink := h.Auth.GetLinkExplicit
+ getPath := h.linker.GetUnescapedPathSuffix
// Standard ACME API
- r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce))))
- r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce))))
- r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
- r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
+ r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
+ r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
+ r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
+ r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))))
@@ -59,16 +101,16 @@ func (h *Handler) Route(r api.Router) {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))))
}
- r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount))
- r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount))
- r.MethodFunc("POST", getLink(acme.KeyChangeLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented))
- r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder))
- r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
- r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount)))
- r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
- r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz)))
- r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, nil, "{chID}"), extractPayloadByKid(h.GetChallenge))
- r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
+ r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
+ r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
+ r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented))
+ r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder))
+ r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
+ r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID)))
+ r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
+ r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
+ r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
+ r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
}
// GetNonce just sets the right header since a Nonce is added to each response
@@ -81,101 +123,153 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
}
}
+// Directory represents an ACME directory for configuring clients.
+type Directory struct {
+ NewNonce string `json:"newNonce"`
+ NewAccount string `json:"newAccount"`
+ NewOrder string `json:"newOrder"`
+ RevokeCert string `json:"revokeCert"`
+ KeyChange string `json:"keyChange"`
+}
+
+// ToLog enables response logging for the Directory type.
+func (d *Directory) ToLog() (interface{}, error) {
+ b, err := json.Marshal(d)
+ if err != nil {
+ return nil, acme.WrapErrorISE(err, "error marshaling directory for logging")
+ }
+ return string(b), nil
+}
+
// GetDirectory is the ACME resource for returning a directory configuration
// for client configuration.
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
- dir, err := h.Auth.GetDirectory(r.Context())
- if err != nil {
- api.WriteError(w, err)
- return
- }
- api.JSON(w, dir)
+ ctx := r.Context()
+ api.JSON(w, &Directory{
+ NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
+ NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
+ NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
+ RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType),
+ KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType),
+ })
}
// NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such.
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
- api.WriteError(w, acme.NotImplemented(nil).ToACME())
+ api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
}
-// GetAuthz ACME api for retrieving an Authz.
-func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) {
- acc, err := acme.AccountFromContext(r.Context())
+// GetAuthorization ACME api for retrieving an Authz.
+func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
- authz, err := h.Auth.GetAuthz(r.Context(), acc.GetID(), chi.URLParam(r, "authzID"))
+ az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil {
- api.WriteError(w, err)
+ api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization"))
+ return
+ }
+ if acc.ID != az.AccountID {
+ api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
+ "account '%s' does not own authorization '%s'", acc.ID, az.ID))
+ return
+ }
+ if err = az.UpdateStatus(ctx, h.db); err != nil {
+ api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status"))
return
}
- w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AuthzLink, true, authz.GetID()))
- api.JSON(w, authz)
+ h.linker.LinkAuthorization(ctx, az)
+
+ w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
+ api.JSON(w, az)
}
// GetChallenge ACME api for retrieving a Challenge.
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
- acc, err := acme.AccountFromContext(r.Context())
+ ctx := r.Context()
+ acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
// Just verify that the payload was set, since we're not strictly adhering
// to ACME V2 spec for reasons specified below.
- _, err = payloadFromContext(r.Context())
+ _, err = payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
- // NOTE: We should be checking that the request is either a POST-as-GET, or
+ // NOTE: We should be checking ^^^ that the request is either a POST-as-GET, or
// that the payload is an empty JSON block ({}). However, older ACME clients
// still send a vestigial body (rather than an empty JSON block) and
// strict enforcement would render these clients broken. For the time being
// we'll just ignore the body.
- var (
- ch *acme.Challenge
- chID = chi.URLParam(r, "chID")
- )
- ch, err = h.Auth.ValidateChallenge(r.Context(), acc.GetID(), chID, acc.GetKey())
+
+ azID := chi.URLParam(r, "authzID")
+ ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
+ if err != nil {
+ api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge"))
+ return
+ }
+ ch.AuthorizationID = azID
+ if acc.ID != ch.AccountID {
+ api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
+ "account '%s' does not own challenge '%s'", acc.ID, ch.ID))
+ return
+ }
+ jwk, err := jwkFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
+ if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
+ api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge"))
+ return
+ }
- w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), acme.AuthzLink, true, ch.GetAuthzID()), "up"))
- w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.ChallengeLink, true, ch.GetID()))
+ h.linker.LinkChallenge(ctx, ch, azID)
+
+ w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
+ w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
api.JSON(w, ch)
}
// GetCertificate ACME api for retrieving a Certificate.
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
- acc, err := acme.AccountFromContext(r.Context())
+ ctx := r.Context()
+ acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
certID := chi.URLParam(r, "certID")
- certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID)
+
+ cert, err := h.db.GetCertificate(ctx, certID)
if err != nil {
- api.WriteError(w, err)
+ api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate"))
+ return
+ }
+ if cert.AccountID != acc.ID {
+ api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
+ "account '%s' does not own certificate '%s'", acc.ID, certID))
return
}
- block, _ := pem.Decode(certBytes)
- if block == nil {
- api.WriteError(w, acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes")))
- return
- }
- cert, err := x509.ParseCertificate(block.Bytes)
- if err != nil {
- api.WriteError(w, acme.Wrap(err, "failed to parse generated leaf certificate"))
- return
+ var certBytes []byte
+ for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) {
+ certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: c.Raw,
+ })...)
}
- api.LogCertificate(w, cert)
+ api.LogCertificate(w, cert.Leaf)
w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
w.Write(certBytes)
}
diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go
index 7e19ea75..5501479d 100644
--- a/acme/api/handler_test.go
+++ b/acme/api/handler_test.go
@@ -8,6 +8,7 @@ import (
"encoding/pem"
"fmt"
"io/ioutil"
+ "net/http"
"net/http/httptest"
"net/url"
"testing"
@@ -17,206 +18,11 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
- "github.com/smallstep/certificates/authority/provisioner"
- "github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
)
-type mockAcmeAuthority struct {
- getLink func(ctx context.Context, link acme.Link, absPath bool, ins ...string) string
- getLinkExplicit func(acme.Link, string, bool, *url.URL, ...string) string
-
- deactivateAccount func(ctx context.Context, accID string) (*acme.Account, error)
- getAccount func(ctx context.Context, accID string) (*acme.Account, error)
- getAccountByKey func(ctx context.Context, key *jose.JSONWebKey) (*acme.Account, error)
- newAccount func(ctx context.Context, ao acme.AccountOptions) (*acme.Account, error)
- updateAccount func(context.Context, string, []string) (*acme.Account, error)
-
- getChallenge func(ctx context.Context, accID string, chID string) (*acme.Challenge, error)
- validateChallenge func(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*acme.Challenge, error)
- getAuthz func(ctx context.Context, accID string, authzID string) (*acme.Authz, error)
- getDirectory func(ctx context.Context) (*acme.Directory, error)
- getCertificate func(string, string) ([]byte, error)
-
- finalizeOrder func(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*acme.Order, error)
- getOrder func(ctx context.Context, accID string, orderID string) (*acme.Order, error)
- getOrdersByAccount func(ctx context.Context, accID string) ([]string, error)
- newOrder func(ctx context.Context, oo acme.OrderOptions) (*acme.Order, error)
-
- loadProvisionerByID func(string) (provisioner.Interface, error)
- newNonce func() (string, error)
- useNonce func(string) error
- ret1 interface{}
- err error
-}
-
-func (m *mockAcmeAuthority) DeactivateAccount(ctx context.Context, id string) (*acme.Account, error) {
- if m.deactivateAccount != nil {
- return m.deactivateAccount(ctx, id)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(*acme.Account), m.err
-}
-
-func (m *mockAcmeAuthority) FinalizeOrder(ctx context.Context, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) {
- if m.finalizeOrder != nil {
- return m.finalizeOrder(ctx, accID, id, csr)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(*acme.Order), m.err
-}
-
-func (m *mockAcmeAuthority) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
- if m.getAccount != nil {
- return m.getAccount(ctx, id)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(*acme.Account), m.err
-}
-
-func (m *mockAcmeAuthority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
- if m.getAccountByKey != nil {
- return m.getAccountByKey(ctx, jwk)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(*acme.Account), m.err
-}
-
-func (m *mockAcmeAuthority) GetAuthz(ctx context.Context, accID, id string) (*acme.Authz, error) {
- if m.getAuthz != nil {
- return m.getAuthz(ctx, accID, id)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(*acme.Authz), m.err
-}
-
-func (m *mockAcmeAuthority) GetCertificate(accID string, id string) ([]byte, error) {
- if m.getCertificate != nil {
- return m.getCertificate(accID, id)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.([]byte), m.err
-}
-
-func (m *mockAcmeAuthority) GetChallenge(ctx context.Context, accID, id string) (*acme.Challenge, error) {
- if m.getChallenge != nil {
- return m.getChallenge(ctx, accID, id)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(*acme.Challenge), m.err
-}
-
-func (m *mockAcmeAuthority) GetDirectory(ctx context.Context) (*acme.Directory, error) {
- if m.getDirectory != nil {
- return m.getDirectory(ctx)
- }
- return m.ret1.(*acme.Directory), m.err
-}
-
-func (m *mockAcmeAuthority) GetLink(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
- if m.getLink != nil {
- return m.getLink(ctx, typ, abs, ins...)
- }
- return m.ret1.(string)
-}
-
-func (m *mockAcmeAuthority) GetLinkExplicit(typ acme.Link, provID string, abs bool, baseURL *url.URL, ins ...string) string {
- if m.getLinkExplicit != nil {
- return m.getLinkExplicit(typ, provID, abs, baseURL, ins...)
- }
- return m.ret1.(string)
-}
-
-func (m *mockAcmeAuthority) GetOrder(ctx context.Context, accID, id string) (*acme.Order, error) {
- if m.getOrder != nil {
- return m.getOrder(ctx, accID, id)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(*acme.Order), m.err
-}
-
-func (m *mockAcmeAuthority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) {
- if m.getOrdersByAccount != nil {
- return m.getOrdersByAccount(ctx, id)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.([]string), m.err
-}
-
-func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) {
- if m.loadProvisionerByID != nil {
- return m.loadProvisionerByID(provID)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(provisioner.Interface), m.err
-}
-
-func (m *mockAcmeAuthority) NewAccount(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) {
- if m.newAccount != nil {
- return m.newAccount(ctx, ops)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(*acme.Account), m.err
-}
-
-func (m *mockAcmeAuthority) NewNonce() (string, error) {
- if m.newNonce != nil {
- return m.newNonce()
- } else if m.err != nil {
- return "", m.err
- }
- return m.ret1.(string), m.err
-}
-
-func (m *mockAcmeAuthority) NewOrder(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) {
- if m.newOrder != nil {
- return m.newOrder(ctx, ops)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(*acme.Order), m.err
-}
-
-func (m *mockAcmeAuthority) UpdateAccount(ctx context.Context, id string, contact []string) (*acme.Account, error) {
- if m.updateAccount != nil {
- return m.updateAccount(ctx, id, contact)
- } else if m.err != nil {
- return nil, m.err
- }
- return m.ret1.(*acme.Account), m.err
-}
-
-func (m *mockAcmeAuthority) UseNonce(nonce string) error {
- if m.useNonce != nil {
- return m.useNonce(nonce)
- }
- return m.err
-}
-
-func (m *mockAcmeAuthority) ValidateChallenge(ctx context.Context, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
- switch {
- case m.validateChallenge != nil:
- return m.validateChallenge(ctx, accID, id, jwk)
- case m.err != nil:
- return nil, m.err
- default:
- return m.ret1.(*acme.Challenge), m.err
- }
-}
-
-func TestHandlerGetNonce(t *testing.T) {
+func TestHandler_GetNonce(t *testing.T) {
tests := []struct {
name string
statusCode int
@@ -230,7 +36,7 @@ func TestHandlerGetNonce(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- h := New(nil).(*Handler)
+ h := &Handler{}
w := httptest.NewRecorder()
req.Method = tt.name
h.GetNonce(w, req)
@@ -243,21 +49,16 @@ func TestHandlerGetNonce(t *testing.T) {
}
}
-func TestHandlerGetDirectory(t *testing.T) {
- auth, err := acme.New(nil, acme.AuthorityOptions{
- DB: new(db.MockNoSQLDB),
- DNS: "ca.smallstep.com",
- Prefix: "acme",
- })
- assert.FatalError(t, err)
+func TestHandler_GetDirectory(t *testing.T) {
+ linker := NewLinker("ca.smallstep.com", "acme")
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
- expDir := acme.Directory{
+ expDir := Directory{
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
@@ -267,7 +68,7 @@ func TestHandlerGetDirectory(t *testing.T) {
type test struct {
statusCode int
- problem *acme.Error
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"ok": func(t *testing.T) test {
@@ -279,7 +80,7 @@ func TestHandlerGetDirectory(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(auth).(*Handler)
+ h := &Handler{linker: linker}
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
@@ -292,18 +93,17 @@ func TestHandlerGetDirectory(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
- var dir acme.Directory
+ var dir Directory
json.Unmarshal(bytes.TrimSpace(body), &dir)
assert.Equals(t, dir, expDir)
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
@@ -312,33 +112,32 @@ func TestHandlerGetDirectory(t *testing.T) {
}
}
-func TestHandlerGetAuthz(t *testing.T) {
+func TestHandler_GetAuthorization(t *testing.T) {
expiry := time.Now().UTC().Add(6 * time.Hour)
- az := acme.Authz{
- ID: "authzID",
+ az := acme.Authorization{
+ ID: "authzID",
+ AccountID: "accID",
Identifier: acme.Identifier{
Type: "dns",
Value: "example.com",
},
- Status: "pending",
- Expires: expiry.Format(time.RFC3339),
- Wildcard: false,
+ Status: "pending",
+ ExpiresAt: expiry,
+ Wildcard: false,
Challenges: []*acme.Challenge{
{
- Type: "http-01",
- Status: "pending",
- Token: "tok2",
- URL: "https://ca.smallstep.com/acme/challenge/chHTTPID",
- ID: "chHTTP01ID",
- AuthzID: "authzID",
+ Type: "http-01",
+ Status: "pending",
+ Token: "tok2",
+ URL: "https://ca.smallstep.com/acme/challenge/chHTTPID",
+ ID: "chHTTP01ID",
},
{
- Type: "dns-01",
- Status: "pending",
- Token: "tok2",
- URL: "https://ca.smallstep.com/acme/challenge/chDNSID",
- ID: "chDNSID",
- AuthzID: "authzID",
+ Type: "dns-01",
+ Status: "pending",
+ Token: "tok2",
+ URL: "https://ca.smallstep.com/acme/challenge/chDNSID",
+ ID: "chDNSID",
},
},
}
@@ -349,71 +148,101 @@ func TestHandlerGetAuthz(t *testing.T) {
// Request with chi context
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("authzID", az.ID)
- url := fmt.Sprintf("%s/acme/%s/challenge/%s",
+ url := fmt.Sprintf("%s/acme/%s/authz/%s",
baseURL.String(), provName, az.ID)
type test struct {
- auth acme.Interface
+ db acme.DB
ctx context.Context
statusCode int
- problem *acme.Error
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
- auth: &mockAcmeAuthority{},
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ db: &acme.MockDB{},
+ ctx: context.Background(),
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, nil)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, nil)
return test{
- auth: &mockAcmeAuthority{},
+ db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
- "fail/getAuthz-error": func(t *testing.T) test {
+ "fail/db.GetAuthorization-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
- auth: &mockAcmeAuthority{
- err: acme.ServerInternalErr(errors.New("force")),
+ db: &acme.MockDB{
+ MockError: acme.NewErrorISE("force"),
},
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
+ err: acme.NewErrorISE("force"),
+ }
+ },
+ "fail/account-id-mismatch": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accID"}
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
+ assert.Equals(t, id, az.ID)
+ return &acme.Authorization{
+ AccountID: "foo",
+ }, nil
+ },
+ },
+ ctx: ctx,
+ statusCode: 401,
+ err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"),
+ }
+ },
+ "fail/db.UpdateAuthorization-error": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accID"}
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
+ assert.Equals(t, id, az.ID)
+ return &acme.Authorization{
+ AccountID: "accID",
+ Status: acme.StatusPending,
+ ExpiresAt: time.Now().Add(-1 * time.Hour),
+ }, nil
+ },
+ MockUpdateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
+ assert.Equals(t, az.Status, acme.StatusInvalid)
+ return acme.NewErrorISE("force")
+ },
+ },
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("force"),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getAuthz: func(ctx context.Context, accID, id string) (*acme.Authz, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, accID, acc.ID)
+ db: &acme.MockDB{
+ MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
assert.Equals(t, id, az.ID)
return &az, nil
},
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.AuthzLink)
- assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
- assert.True(t, abs)
- assert.Equals(t, in, []string{az.ID})
- return url
- },
},
ctx: ctx,
statusCode: 200,
@@ -423,11 +252,11 @@ func TestHandlerGetAuthz(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
- h.GetAuthz(w, req)
+ h.GetAuthorization(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@@ -436,15 +265,14 @@ func TestHandlerGetAuthz(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
//var gotAz acme.Authz
@@ -459,7 +287,7 @@ func TestHandlerGetAuthz(t *testing.T) {
}
}
-func TestHandlerGetCertificate(t *testing.T) {
+func TestHandler_GetCertificate(t *testing.T) {
leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt")
assert.FatalError(t, err)
inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt")
@@ -490,89 +318,73 @@ func TestHandlerGetCertificate(t *testing.T) {
baseURL.String(), provName, certID)
type test struct {
- auth acme.Interface
+ db acme.DB
ctx context.Context
statusCode int
- problem *acme.Error
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
- auth: &mockAcmeAuthority{},
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ db: &acme.MockDB{},
+ ctx: context.Background(),
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.AccContextKey, nil)
+ ctx := context.WithValue(context.Background(), accContextKey, nil)
return test{
- auth: &mockAcmeAuthority{},
+ db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
- "fail/getCertificate-error": func(t *testing.T) test {
+ "fail/db.GetCertificate-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.AccContextKey, acc)
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
- auth: &mockAcmeAuthority{
- err: acme.ServerInternalErr(errors.New("force")),
+ db: &acme.MockDB{
+ MockError: acme.NewErrorISE("force"),
},
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
+ err: acme.NewErrorISE("force"),
}
},
- "fail/decode-leaf-for-loggger": func(t *testing.T) test {
+ "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.AccContextKey, acc)
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
- auth: &mockAcmeAuthority{
- getCertificate: func(accID, id string) ([]byte, error) {
- assert.Equals(t, accID, acc.ID)
+ db: &acme.MockDB{
+ MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) {
assert.Equals(t, id, certID)
- return []byte("foo"), nil
+ return &acme.Certificate{AccountID: "foo"}, nil
},
},
ctx: ctx,
- statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes")),
- }
- },
- "fail/parse-x509-leaf-for-logger": func(t *testing.T) test {
- acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
- return test{
- auth: &mockAcmeAuthority{
- getCertificate: func(accID, id string) ([]byte, error) {
- assert.Equals(t, accID, acc.ID)
- assert.Equals(t, id, certID)
- return pem.EncodeToMemory(&pem.Block{
- Type: "CERTIFICATE REQUEST",
- Bytes: []byte("foo"),
- }), nil
- },
- },
- ctx: ctx,
- statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("failed to parse generated leaf certificate")),
+ statusCode: 401,
+ err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.AccContextKey, acc)
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
- auth: &mockAcmeAuthority{
- getCertificate: func(accID, id string) ([]byte, error) {
- assert.Equals(t, accID, acc.ID)
+ db: &acme.MockDB{
+ MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) {
assert.Equals(t, id, certID)
- return certBytes, nil
+ return &acme.Certificate{
+ AccountID: "accID",
+ OrderID: "ordID",
+ Leaf: leaf,
+ Intermediates: []*x509.Certificate{inter, root},
+ ID: id,
+ }, nil
},
},
ctx: ctx,
@@ -583,7 +395,7 @@ func TestHandlerGetCertificate(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -596,15 +408,14 @@ func TestHandlerGetCertificate(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.HasPrefix(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.HasPrefix(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes))
@@ -614,152 +425,233 @@ func TestHandlerGetCertificate(t *testing.T) {
}
}
-func ch() acme.Challenge {
- return acme.Challenge{
- Type: "http-01",
- Status: "pending",
- Token: "tok2",
- URL: "https://ca.smallstep.com/acme/challenge/chID",
- ID: "chID",
- AuthzID: "authzID",
- }
-}
-
-func TestHandlerGetChallenge(t *testing.T) {
+func TestHandler_GetChallenge(t *testing.T) {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("chID", "chID")
+ chiCtx.URLParams.Add("authzID", "authzID")
prov := newProv()
provName := url.PathEscape(prov.GetName())
+
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- url := fmt.Sprintf("%s/acme/challenge/%s", baseURL, "chID")
+
+ url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
+ baseURL.String(), provName, "authzID", "chID")
type test struct {
- auth acme.Interface
+ db acme.DB
+ vco *acme.ValidateChallengeOptions
ctx context.Context
statusCode int
- ch acme.Challenge
- problem *acme.Error
+ ch *acme.Challenge
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.Background(),
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, nil)
return test{
- ctx: ctx,
+ ctx: context.WithValue(context.Background(), accContextKey, nil),
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("payload expected in request context"),
}
},
"fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("payload expected in request context"),
+ }
+ },
+ "fail/db.GetChallenge-error": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
+ assert.Equals(t, chID, "chID")
+ assert.Equals(t, azID, "authzID")
+ return nil, acme.NewErrorISE("force")
+ },
+ },
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("force"),
+ }
+ },
+ "fail/account-id-mismatch": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
+ assert.Equals(t, chID, "chID")
+ assert.Equals(t, azID, "authzID")
+ return &acme.Challenge{AccountID: "foo"}, nil
+ },
+ },
+ ctx: ctx,
+ statusCode: 401,
+ err: acme.NewError(acme.ErrorUnauthorizedType, "accout id mismatch"),
+ }
+ },
+ "fail/no-jwk": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
+ assert.Equals(t, chID, "chID")
+ assert.Equals(t, azID, "authzID")
+ return &acme.Challenge{AccountID: "accID"}, nil
+ },
+ },
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("missing jwk"),
+ }
+ },
+ "fail/nil-jwk": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
+ ctx = context.WithValue(ctx, jwkContextKey, nil)
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
+ assert.Equals(t, chID, "chID")
+ assert.Equals(t, azID, "authzID")
+ return &acme.Challenge{AccountID: "accID"}, nil
+ },
+ },
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("nil jwk"),
}
},
"fail/validate-challenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true})
- ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
- return test{
- auth: &mockAcmeAuthority{
- err: acme.UnauthorizedErr(nil),
- },
- ctx: ctx,
- statusCode: 401,
- problem: acme.UnauthorizedErr(nil),
- }
- },
- "fail/get-challenge-error": func(t *testing.T) test {
- acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true})
- ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
- return test{
- auth: &mockAcmeAuthority{
- err: acme.UnauthorizedErr(nil),
- },
- ctx: ctx,
- statusCode: 401,
- problem: acme.UnauthorizedErr(nil),
- }
- },
- "ok/validate-challenge": func(t *testing.T) test {
- key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
+ _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- acc := &acme.Account{ID: "accID", Key: key}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true})
+ _pub := _jwk.Public()
+ ctx = context.WithValue(ctx, jwkContextKey, &_pub)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
- ch := ch()
- ch.Status = "valid"
- ch.Validated = time.Now().UTC().Format(time.RFC3339)
- count := 0
return test{
- auth: &mockAcmeAuthority{
- validateChallenge: func(ctx context.Context, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, accID, acc.ID)
- assert.Equals(t, id, ch.ID)
- assert.Equals(t, jwk.KeyID, key.KeyID)
- return &ch, nil
+ db: &acme.MockDB{
+ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
+ assert.Equals(t, chID, "chID")
+ assert.Equals(t, azID, "authzID")
+ return &acme.Challenge{
+ Status: acme.StatusPending,
+ Type: "http-01",
+ AccountID: "accID",
+ }, nil
},
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- var ret string
- switch count {
- case 0:
- assert.Equals(t, typ, acme.AuthzLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{ch.AuthzID})
- ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID)
- case 1:
- assert.Equals(t, typ, acme.ChallengeLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{ch.ID})
- ret = url
- }
- count++
- return ret
+ MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Type, "http-01")
+ assert.Equals(t, ch.AccountID, "accID")
+ assert.Equals(t, ch.AuthorizationID, "authzID")
+ assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String())
+ return acme.NewErrorISE("force")
+ },
+ },
+ vco: &acme.ValidateChallengeOptions{
+ HTTPGet: func(string) (*http.Response, error) {
+ return nil, errors.New("force")
+ },
+ },
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
+ _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+ _pub := _jwk.Public()
+ ctx = context.WithValue(ctx, jwkContextKey, &_pub)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) {
+ assert.Equals(t, chID, "chID")
+ assert.Equals(t, azID, "authzID")
+ return &acme.Challenge{
+ ID: "chID",
+ Status: acme.StatusPending,
+ Type: "http-01",
+ AccountID: "accID",
+ }, nil
+ },
+ MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Type, "http-01")
+ assert.Equals(t, ch.AccountID, "accID")
+ assert.Equals(t, ch.AuthorizationID, "authzID")
+ assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String())
+ return nil
+ },
+ },
+ ch: &acme.Challenge{
+ ID: "chID",
+ Status: acme.StatusPending,
+ AuthorizationID: "authzID",
+ Type: "http-01",
+ AccountID: "accID",
+ URL: url,
+ Error: acme.NewError(acme.ErrorConnectionType, "force"),
+ },
+ vco: &acme.ValidateChallengeOptions{
+ HTTPGet: func(string) (*http.Response, error) {
+ return nil, errors.New("force")
},
},
ctx: ctx,
statusCode: 200,
- ch: ch,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -772,21 +664,20 @@ func TestHandlerGetChallenge(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(tc.ch)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
- assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, tc.ch.AuthzID)})
+ assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")})
assert.Equals(t, res.Header["Location"], []string{url})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
diff --git a/acme/api/linker.go b/acme/api/linker.go
new file mode 100644
index 00000000..d4490470
--- /dev/null
+++ b/acme/api/linker.go
@@ -0,0 +1,181 @@
+package api
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+
+ "github.com/smallstep/certificates/acme"
+)
+
+// NewLinker returns a new Directory type.
+func NewLinker(dns, prefix string) Linker {
+ return &linker{prefix: prefix, dns: dns}
+}
+
+// Linker interface for generating links for ACME resources.
+type Linker interface {
+ GetLink(ctx context.Context, typ LinkType, inputs ...string) string
+ GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
+
+ LinkOrder(ctx context.Context, o *acme.Order)
+ LinkAccount(ctx context.Context, o *acme.Account)
+ LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
+ LinkAuthorization(ctx context.Context, o *acme.Authorization)
+ LinkOrdersByAccountID(ctx context.Context, orders []string)
+}
+
+// linker generates ACME links.
+type linker struct {
+ prefix string
+ dns string
+}
+
+func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
+ switch typ {
+ case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
+ return fmt.Sprintf("/%s/%s", provisionerName, typ)
+ case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
+ return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
+ case ChallengeLinkType:
+ return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
+ case OrdersByAccountLinkType:
+ return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
+ case FinalizeLinkType:
+ return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
+ default:
+ return ""
+ }
+}
+
+// GetLink is a helper for GetLinkExplicit
+func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
+ var (
+ provName string
+ baseURL = baseURLFromContext(ctx)
+ u = url.URL{}
+ )
+ if p, err := provisionerFromContext(ctx); err == nil && p != nil {
+ provName = p.GetName()
+ }
+ // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
+ if baseURL != nil {
+ u = *baseURL
+ }
+
+ u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
+
+ // If no Scheme is set, then default to https.
+ if u.Scheme == "" {
+ u.Scheme = "https"
+ }
+
+ // If no Host is set, then use the default (first DNS attr in the ca.json).
+ if u.Host == "" {
+ u.Host = l.dns
+ }
+
+ u.Path = l.prefix + u.Path
+ return u.String()
+}
+
+// LinkType captures the link type.
+type LinkType int
+
+const (
+ // NewNonceLinkType new-nonce
+ NewNonceLinkType LinkType = iota
+ // NewAccountLinkType new-account
+ NewAccountLinkType
+ // AccountLinkType account
+ AccountLinkType
+ // OrderLinkType order
+ OrderLinkType
+ // NewOrderLinkType new-order
+ NewOrderLinkType
+ // OrdersByAccountLinkType list of orders owned by account
+ OrdersByAccountLinkType
+ // FinalizeLinkType finalize order
+ FinalizeLinkType
+ // NewAuthzLinkType authz
+ NewAuthzLinkType
+ // AuthzLinkType new-authz
+ AuthzLinkType
+ // ChallengeLinkType challenge
+ ChallengeLinkType
+ // CertificateLinkType certificate
+ CertificateLinkType
+ // DirectoryLinkType directory
+ DirectoryLinkType
+ // RevokeCertLinkType revoke certificate
+ RevokeCertLinkType
+ // KeyChangeLinkType key rollover
+ KeyChangeLinkType
+)
+
+func (l LinkType) String() string {
+ switch l {
+ case NewNonceLinkType:
+ return "new-nonce"
+ case NewAccountLinkType:
+ return "new-account"
+ case AccountLinkType:
+ return "account"
+ case NewOrderLinkType:
+ return "new-order"
+ case OrderLinkType:
+ return "order"
+ case NewAuthzLinkType:
+ return "new-authz"
+ case AuthzLinkType:
+ return "authz"
+ case ChallengeLinkType:
+ return "challenge"
+ case CertificateLinkType:
+ return "certificate"
+ case DirectoryLinkType:
+ return "directory"
+ case RevokeCertLinkType:
+ return "revoke-cert"
+ case KeyChangeLinkType:
+ return "key-change"
+ default:
+ return fmt.Sprintf("unexpected LinkType '%d'", int(l))
+ }
+}
+
+// LinkOrder sets the ACME links required by an ACME order.
+func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
+ o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
+ for i, azID := range o.AuthorizationIDs {
+ o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
+ }
+ o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, o.ID)
+ if o.CertificateID != "" {
+ o.CertificateURL = l.GetLink(ctx, CertificateLinkType, o.CertificateID)
+ }
+}
+
+// LinkAccount sets the ACME links required by an ACME account.
+func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
+ acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
+}
+
+// LinkChallenge sets the ACME links required by an ACME challenge.
+func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
+ ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
+}
+
+// LinkAuthorization sets the ACME links required by an ACME authorization.
+func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
+ for _, ch := range az.Challenges {
+ l.LinkChallenge(ctx, ch, az.ID)
+ }
+}
+
+// LinkOrdersByAccountID converts each order ID to an ACME link.
+func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) {
+ for i, id := range orders {
+ orders[i] = l.GetLink(ctx, OrderLinkType, id)
+ }
+}
diff --git a/acme/api/linker_test.go b/acme/api/linker_test.go
new file mode 100644
index 00000000..4790dec8
--- /dev/null
+++ b/acme/api/linker_test.go
@@ -0,0 +1,283 @@
+package api
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "testing"
+
+ "github.com/smallstep/assert"
+ "github.com/smallstep/certificates/acme"
+)
+
+func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
+ dns := "ca.smallstep.com"
+ prefix := "acme"
+ linker := NewLinker(dns, prefix)
+
+ getPath := linker.GetUnescapedPathSuffix
+
+ assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
+ assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
+ assert.Equals(t, getPath(NewAccountLinkType, "{provisionerID}"), "/{provisionerID}/new-account")
+ assert.Equals(t, getPath(AccountLinkType, "{provisionerID}", "{accID}"), "/{provisionerID}/account/{accID}")
+ assert.Equals(t, getPath(KeyChangeLinkType, "{provisionerID}"), "/{provisionerID}/key-change")
+ assert.Equals(t, getPath(NewOrderLinkType, "{provisionerID}"), "/{provisionerID}/new-order")
+ assert.Equals(t, getPath(OrderLinkType, "{provisionerID}", "{ordID}"), "/{provisionerID}/order/{ordID}")
+ assert.Equals(t, getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), "/{provisionerID}/account/{accID}/orders")
+ assert.Equals(t, getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), "/{provisionerID}/order/{ordID}/finalize")
+ assert.Equals(t, getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), "/{provisionerID}/authz/{authzID}")
+ assert.Equals(t, getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), "/{provisionerID}/challenge/{authzID}/{chID}")
+ assert.Equals(t, getPath(CertificateLinkType, "{provisionerID}", "{certID}"), "/{provisionerID}/certificate/{certID}")
+}
+
+func TestLinker_GetLink(t *testing.T) {
+ dns := "ca.smallstep.com"
+ prefix := "acme"
+ linker := NewLinker(dns, prefix)
+ id := "1234"
+
+ prov := newProv()
+ escProvName := url.PathEscape(prov.GetName())
+ baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
+
+ // No provisioner and no BaseURL from request
+ assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
+ // Provisioner: yes, BaseURL: no
+ assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
+
+ // Provisioner: no, BaseURL: yes
+ assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
+
+ assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
+ assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, NewAccountLinkType), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, AccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, NewOrderLinkType), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, OrderLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, OrdersByAccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, FinalizeLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, NewAuthzLinkType), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, AuthzLinkType, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, DirectoryLinkType), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, RevokeCertLinkType, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, KeyChangeLinkType), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName))
+
+ assert.Equals(t, linker.GetLink(ctx, ChallengeLinkType, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id))
+
+ assert.Equals(t, linker.GetLink(ctx, CertificateLinkType, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName))
+}
+
+func TestLinker_LinkOrder(t *testing.T) {
+ baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
+ prov := newProv()
+ provName := url.PathEscape(prov.GetName())
+ ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
+ ctx = context.WithValue(ctx, provisionerContextKey, prov)
+
+ oid := "orderID"
+ certID := "certID"
+ linkerPrefix := "acme"
+ l := NewLinker("dns", linkerPrefix)
+ type test struct {
+ o *acme.Order
+ validate func(o *acme.Order)
+ }
+ var tests = map[string]test{
+ "no-authz-and-no-cert": {
+ o: &acme.Order{
+ ID: oid,
+ },
+ validate: func(o *acme.Order) {
+ assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
+ assert.Equals(t, o.AuthorizationURLs, []string{})
+ assert.Equals(t, o.CertificateURL, "")
+ },
+ },
+ "one-authz-and-cert": {
+ o: &acme.Order{
+ ID: oid,
+ CertificateID: certID,
+ AuthorizationIDs: []string{"foo"},
+ },
+ validate: func(o *acme.Order) {
+ assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
+ assert.Equals(t, o.AuthorizationURLs, []string{
+ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
+ })
+ assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
+ },
+ },
+ "many-authz": {
+ o: &acme.Order{
+ ID: oid,
+ CertificateID: certID,
+ AuthorizationIDs: []string{"foo", "bar", "zap"},
+ },
+ validate: func(o *acme.Order) {
+ assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
+ assert.Equals(t, o.AuthorizationURLs, []string{
+ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
+ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "bar"),
+ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "zap"),
+ })
+ assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
+ },
+ },
+ }
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ l.LinkOrder(ctx, tc.o)
+ tc.validate(tc.o)
+ })
+ }
+}
+
+func TestLinker_LinkAccount(t *testing.T) {
+ baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
+ prov := newProv()
+ provName := url.PathEscape(prov.GetName())
+ ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
+ ctx = context.WithValue(ctx, provisionerContextKey, prov)
+
+ accID := "accountID"
+ linkerPrefix := "acme"
+ l := NewLinker("dns", linkerPrefix)
+ type test struct {
+ a *acme.Account
+ validate func(o *acme.Account)
+ }
+ var tests = map[string]test{
+ "ok": {
+ a: &acme.Account{
+ ID: accID,
+ },
+ validate: func(a *acme.Account) {
+ assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
+ },
+ },
+ }
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ l.LinkAccount(ctx, tc.a)
+ tc.validate(tc.a)
+ })
+ }
+}
+
+func TestLinker_LinkChallenge(t *testing.T) {
+ baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
+ prov := newProv()
+ provName := url.PathEscape(prov.GetName())
+ ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
+ ctx = context.WithValue(ctx, provisionerContextKey, prov)
+
+ chID := "chID"
+ azID := "azID"
+ linkerPrefix := "acme"
+ l := NewLinker("dns", linkerPrefix)
+ type test struct {
+ ch *acme.Challenge
+ validate func(o *acme.Challenge)
+ }
+ var tests = map[string]test{
+ "ok": {
+ ch: &acme.Challenge{
+ ID: chID,
+ },
+ validate: func(ch *acme.Challenge) {
+ assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
+ },
+ },
+ }
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ l.LinkChallenge(ctx, tc.ch, azID)
+ tc.validate(tc.ch)
+ })
+ }
+}
+
+func TestLinker_LinkAuthorization(t *testing.T) {
+ baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
+ prov := newProv()
+ provName := url.PathEscape(prov.GetName())
+ ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
+ ctx = context.WithValue(ctx, provisionerContextKey, prov)
+
+ chID0 := "chID-0"
+ chID1 := "chID-1"
+ chID2 := "chID-2"
+ azID := "azID"
+ linkerPrefix := "acme"
+ l := NewLinker("dns", linkerPrefix)
+ type test struct {
+ az *acme.Authorization
+ validate func(o *acme.Authorization)
+ }
+ var tests = map[string]test{
+ "ok": {
+ az: &acme.Authorization{
+ ID: azID,
+ Challenges: []*acme.Challenge{
+ {ID: chID0},
+ {ID: chID1},
+ {ID: chID2},
+ },
+ },
+ validate: func(az *acme.Authorization) {
+ assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
+ assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
+ assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
+ },
+ },
+ }
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ l.LinkAuthorization(ctx, tc.az)
+ tc.validate(tc.az)
+ })
+ }
+}
+
+func TestLinker_LinkOrdersByAccountID(t *testing.T) {
+ baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
+ prov := newProv()
+ provName := url.PathEscape(prov.GetName())
+ ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
+ ctx = context.WithValue(ctx, provisionerContextKey, prov)
+
+ linkerPrefix := "acme"
+ l := NewLinker("dns", linkerPrefix)
+ type test struct {
+ oids []string
+ }
+ var tests = map[string]test{
+ "ok": {
+ oids: []string{"foo", "bar", "baz"},
+ },
+ }
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ l.LinkOrdersByAccountID(ctx, tc.oids)
+ assert.Equals(t, tc.oids, []string{
+ fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "foo"),
+ fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "bar"),
+ fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "baz"),
+ })
+ })
+ }
+}
diff --git a/acme/api/middleware.go b/acme/api/middleware.go
index 3bf5f89a..50f7146f 100644
--- a/acme/api/middleware.go
+++ b/acme/api/middleware.go
@@ -3,13 +3,13 @@ package api
import (
"context"
"crypto/rsa"
+ "errors"
"io/ioutil"
"net/http"
"net/url"
"strings"
"github.com/go-chi/chi"
- "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/provisioner"
@@ -54,7 +54,7 @@ func baseURLFromRequest(r *http.Request) *url.URL {
// E.g. https://ca.smallstep.com/
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
- ctx := context.WithValue(r.Context(), acme.BaseURLContextKey, baseURLFromRequest(r))
+ ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
next(w, r.WithContext(ctx))
}
}
@@ -62,14 +62,14 @@ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
// addNonce is a middleware that adds a nonce to the response header.
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
- nonce, err := h.Auth.NewNonce()
+ nonce, err := h.db.CreateNonce(r.Context())
if err != nil {
api.WriteError(w, err)
return
}
- w.Header().Set("Replay-Nonce", nonce)
+ w.Header().Set("Replay-Nonce", string(nonce))
w.Header().Set("Cache-Control", "no-store")
- logNonce(w, nonce)
+ logNonce(w, string(nonce))
next(w, r)
}
}
@@ -78,8 +78,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
// directory index url.
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
- w.Header().Add("Link", link(h.Auth.GetLink(r.Context(),
- acme.DirectoryLink, true), "index"))
+ w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index"))
next(w, r)
}
}
@@ -88,23 +87,31 @@ func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
// application/jose+json.
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
- ct := r.Header.Get("Content-Type")
var expected []string
- if strings.Contains(r.URL.Path, h.Auth.GetLink(r.Context(), acme.CertificateLink, false, "")) {
+ p, err := provisionerFromContext(r.Context())
+ if err != nil {
+ api.WriteError(w, err)
+ return
+ }
+
+ u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")}
+ if strings.Contains(r.URL.String(), u.EscapedPath()) {
// GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
} else {
// By default every request should have content-type applictaion/jose+json.
expected = []string{"application/jose+json"}
}
+
+ ct := r.Header.Get("Content-Type")
for _, e := range expected {
if ct == e {
next(w, r)
return
}
}
- api.WriteError(w, acme.MalformedErr(errors.Errorf(
- "expected content-type to be in %s, but got %s", expected, ct)))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
+ "expected content-type to be in %s, but got %s", expected, ct))
}
}
@@ -113,15 +120,15 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
- api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body")))
+ api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body"))
return
}
jws, err := jose.ParseJWS(string(body))
if err != nil {
- api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body")))
+ api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
return
}
- ctx := context.WithValue(r.Context(), acme.JwsContextKey, jws)
+ ctx := context.WithValue(r.Context(), jwsContextKey, jws)
next(w, r.WithContext(ctx))
}
}
@@ -143,17 +150,18 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
- jws, err := acme.JwsFromContext(r.Context())
+ ctx := r.Context()
+ jws, err := jwsFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
return
}
if len(jws.Signatures) == 0 {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
return
}
if len(jws.Signatures) > 1 {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
return
}
@@ -164,35 +172,36 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
len(uh.Algorithm) > 0 ||
len(uh.Nonce) > 0 ||
len(uh.ExtraHeaders) > 0 {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
return
}
hdr := sig.Protected
switch hdr.Algorithm {
- case jose.RS256, jose.RS384, jose.RS512:
+ case jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512:
if hdr.JSONWebKey != nil {
switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey:
if k.Size() < keyutil.MinRSAKeyBytes {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+
- "keys must be at least %d bits (%d bytes) in size",
- 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
+ "rsa keys must be at least %d bits (%d bytes) in size",
+ 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
return
}
default:
- api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
+ "jws key type and algorithm do not match"))
return
}
}
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good
default:
- api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm)))
+ api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
return
}
// Check the validity/freshness of the Nonce.
- if err := h.Auth.UseNonce(hdr.Nonce); err != nil {
+ if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
api.WriteError(w, err)
return
}
@@ -200,21 +209,22 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
// Check that the JWS url matches the requested url.
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
if !ok {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
return
}
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
if jwsURL != reqURL.String() {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
+ "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
return
}
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return
}
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return
}
next(w, r)
@@ -227,24 +237,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
- jws, err := acme.JwsFromContext(r.Context())
+ jws, err := jwsFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
return
}
jwk := jws.Signatures[0].Protected.JSONWebKey
if jwk == nil {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
return
}
if !jwk.Valid() {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
return
}
- ctx = context.WithValue(ctx, acme.JwkContextKey, jwk)
- acc, err := h.Auth.GetAccountByKey(ctx, jwk)
+
+ // Overwrite KeyID with the JWK thumbprint.
+ jwk.KeyID, err = acme.KeyToID(jwk)
+ if err != nil {
+ api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
+ return
+ }
+
+ // Store the JWK in the context.
+ ctx = context.WithValue(ctx, jwkContextKey, jwk)
+
+ // Get Account or continue to generate a new one.
+ acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
switch {
- case nosql.IsErrNotFound(err):
+ case errors.Is(err, acme.ErrNotFound):
// For NewAccount requests ...
break
case err != nil:
@@ -252,10 +273,10 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
return
default:
if !acc.IsValid() {
- api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
+ api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return
}
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ ctx = context.WithValue(ctx, accContextKey, acc)
}
next(w, r.WithContext(ctx))
}
@@ -270,20 +291,20 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
name := chi.URLParam(r, "provisionerID")
provID, err := url.PathUnescape(name)
if err != nil {
- api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name)))
+ api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner id '%s'", name))
return
}
- p, err := h.Auth.LoadProvisionerByID("acme/" + provID)
+ p, err := h.ca.LoadProvisionerByID("acme/" + provID)
if err != nil {
api.WriteError(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
- api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME")))
+ api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
- ctx = context.WithValue(ctx, acme.ProvisionerContextKey, acme.Provisioner(acmeProv))
+ ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
next(w, r.WithContext(ctx))
}
}
@@ -294,36 +315,37 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
- jws, err := acme.JwsFromContext(ctx)
+ jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
- kidPrefix := h.Auth.GetLink(ctx, acme.AccountLink, true, "")
+ kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+
- "required prefix; expected %s, but got %s", kidPrefix, kid)))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
+ "kid does not have required prefix; expected %s, but got %s",
+ kidPrefix, kid))
return
}
accID := strings.TrimPrefix(kid, kidPrefix)
- acc, err := h.Auth.GetAccount(r.Context(), accID)
+ acc, err := h.db.GetAccount(ctx, accID)
switch {
case nosql.IsErrNotFound(err):
- api.WriteError(w, acme.AccountDoesNotExistErr(nil))
+ api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
return
case err != nil:
api.WriteError(w, err)
return
default:
if !acc.IsValid() {
- api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
+ api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return
}
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.JwkContextKey, acc.Key)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
next(w, r.WithContext(ctx))
return
}
@@ -334,26 +356,27 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
// Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
- jws, err := acme.JwsFromContext(r.Context())
+ ctx := r.Context()
+ jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
- jwk, err := acme.JwkFromContext(r.Context())
+ jwk, err := jwkFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
- api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return
}
payload, err := jws.Verify(jwk)
if err != nil {
- api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws")))
+ api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
return
}
- ctx := context.WithValue(r.Context(), acme.PayloadContextKey, &payloadInfo{
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
value: payload,
isPostAsGet: string(payload) == "",
isEmptyJSON: string(payload) == "{}",
@@ -371,9 +394,89 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
return
}
if !payload.isPostAsGet {
- api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET")))
+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
return
}
next(w, r)
}
}
+
+// ContextKey is the key type for storing and searching for ACME request
+// essentials in the context of a request.
+type ContextKey string
+
+const (
+ // accContextKey account key
+ accContextKey = ContextKey("acc")
+ // baseURLContextKey baseURL key
+ baseURLContextKey = ContextKey("baseURL")
+ // jwsContextKey jws key
+ jwsContextKey = ContextKey("jws")
+ // jwkContextKey jwk key
+ jwkContextKey = ContextKey("jwk")
+ // payloadContextKey payload key
+ payloadContextKey = ContextKey("payload")
+ // provisionerContextKey provisioner key
+ provisionerContextKey = ContextKey("provisioner")
+)
+
+// accountFromContext searches the context for an ACME account. Returns the
+// account or an error.
+func accountFromContext(ctx context.Context) (*acme.Account, error) {
+ val, ok := ctx.Value(accContextKey).(*acme.Account)
+ if !ok || val == nil {
+ return nil, acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context")
+ }
+ return val, nil
+}
+
+// baseURLFromContext returns the baseURL if one is stored in the context.
+func baseURLFromContext(ctx context.Context) *url.URL {
+ val, ok := ctx.Value(baseURLContextKey).(*url.URL)
+ if !ok || val == nil {
+ return nil
+ }
+ return val
+}
+
+// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
+func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
+ val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
+ if !ok || val == nil {
+ return nil, acme.NewErrorISE("jwk expected in request context")
+ }
+ return val, nil
+}
+
+// jwsFromContext searches the context for a JWS. Returns the JWS or an error.
+func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
+ val, ok := ctx.Value(jwsContextKey).(*jose.JSONWebSignature)
+ if !ok || val == nil {
+ return nil, acme.NewErrorISE("jws expected in request context")
+ }
+ return val, nil
+}
+
+// provisionerFromContext searches the context for a provisioner. Returns the
+// provisioner or an error.
+func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
+ val := ctx.Value(provisionerContextKey)
+ if val == nil {
+ return nil, acme.NewErrorISE("provisioner expected in request context")
+ }
+ pval, ok := val.(acme.Provisioner)
+ if !ok || pval == nil {
+ return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
+ }
+ return pval, nil
+}
+
+// payloadFromContext searches the context for a payload. Returns the payload
+// or an error.
+func payloadFromContext(ctx context.Context) (*payloadInfo, error) {
+ val, ok := ctx.Value(payloadContextKey).(*payloadInfo)
+ if !ok || val == nil {
+ return nil, acme.NewErrorISE("payload expected in request context")
+ }
+ return val, nil
+}
diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go
index d2a9cdc0..40090e83 100644
--- a/acme/api/middleware_test.go
+++ b/acme/api/middleware_test.go
@@ -81,14 +81,14 @@ func Test_baseURLFromRequest(t *testing.T) {
}
}
-func TestHandlerBaseURLFromRequest(t *testing.T) {
- h := New(&mockAcmeAuthority{}).(*Handler)
+func TestHandler_baseURLFromRequest(t *testing.T) {
+ h := &Handler{}
req := httptest.NewRequest("GET", "/foo", nil)
req.Host = "test.ca.smallstep.com:8080"
w := httptest.NewRecorder()
next := func(w http.ResponseWriter, r *http.Request) {
- bu := acme.BaseURLFromContext(r.Context())
+ bu := baseURLFromContext(r.Context())
if assert.NotNil(t, bu) {
assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080")
assert.Equals(t, bu.Scheme, "https")
@@ -101,35 +101,35 @@ func TestHandlerBaseURLFromRequest(t *testing.T) {
req.Host = ""
next = func(w http.ResponseWriter, r *http.Request) {
- assert.Equals(t, acme.BaseURLFromContext(r.Context()), nil)
+ assert.Equals(t, baseURLFromContext(r.Context()), nil)
}
h.baseURLFromRequest(next)(w, req)
}
-func TestHandlerAddNonce(t *testing.T) {
+func TestHandler_addNonce(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-nonce"
type test struct {
- auth acme.Interface
- problem *acme.Error
+ db acme.DB
+ err *acme.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/AddNonce-error": func(t *testing.T) test {
return test{
- auth: &mockAcmeAuthority{
- newNonce: func() (string, error) {
- return "", acme.ServerInternalErr(errors.New("force"))
+ db: &acme.MockDB{
+ MockCreateNonce: func(ctx context.Context) (acme.Nonce, error) {
+ return acme.Nonce(""), acme.NewErrorISE("force")
},
},
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
+ err: acme.NewErrorISE("force"),
}
},
"ok": func(t *testing.T) test {
return test{
- auth: &mockAcmeAuthority{
- newNonce: func() (string, error) {
+ db: &acme.MockDB{
+ MockCreateNonce: func(ctx context.Context) (acme.Nonce, error) {
return "bar", nil
},
},
@@ -140,7 +140,7 @@ func TestHandlerAddNonce(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil)
w := httptest.NewRecorder()
h.addNonce(testNext)(w, req)
@@ -152,15 +152,14 @@ func TestHandlerAddNonce(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, res.Header["Replay-Nonce"], []string{"bar"})
@@ -171,28 +170,23 @@ func TestHandlerAddNonce(t *testing.T) {
}
}
-func TestHandlerAddDirLink(t *testing.T) {
+func TestHandler_addDirLink(t *testing.T) {
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct {
- auth acme.Interface
link string
+ linker Linker
statusCode int
ctx context.Context
- problem *acme.Error
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"ok": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
- assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
- return fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName)
- },
- },
+ linker: NewLinker("dns", "acme"),
ctx: ctx,
link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
statusCode: 200,
@@ -202,7 +196,7 @@ func TestHandlerAddDirLink(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{linker: tc.linker}
req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -215,15 +209,14 @@ func TestHandlerAddDirLink(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s>;rel=\"index\"", tc.link)})
@@ -233,70 +226,61 @@ func TestHandlerAddDirLink(t *testing.T) {
}
}
-func TestHandlerVerifyContentType(t *testing.T) {
+func TestHandler_verifyContentType(t *testing.T) {
prov := newProv()
- provName := prov.GetName()
+ escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), provName)
+ url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
type test struct {
h Handler
ctx context.Context
contentType string
- problem *acme.Error
+ err *acme.Error
statusCode int
url string
}
var tests = map[string]func(t *testing.T) test{
+ "fail/provisioner-not-set": func(t *testing.T) test {
+ return test{
+ h: Handler{
+ linker: NewLinker("dns", "acme"),
+ },
+ url: url,
+ ctx: context.Background(),
+ contentType: "foo",
+ statusCode: 500,
+ err: acme.NewErrorISE("provisioner expected in request context"),
+ }
+ },
"fail/general-bad-content-type": func(t *testing.T) test {
return test{
h: Handler{
- Auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.CertificateLink)
- assert.Equals(t, abs, false)
- assert.Equals(t, in, []string{""})
- return fmt.Sprintf("/acme/%s/certificate/", provName)
- },
- },
+ linker: NewLinker("dns", "acme"),
},
- url: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ url: url,
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "foo",
statusCode: 400,
- problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json], but got foo")),
+ err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"),
}
},
"fail/certificate-bad-content-type": func(t *testing.T) test {
return test{
h: Handler{
- Auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.CertificateLink)
- assert.Equals(t, abs, false)
- assert.Equals(t, in, []string{""})
- return "/certificate/"
- },
- },
+ linker: NewLinker("dns", "acme"),
},
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "foo",
statusCode: 400,
- problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo")),
+ err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"),
}
},
"ok": func(t *testing.T) test {
return test{
h: Handler{
- Auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.CertificateLink)
- assert.Equals(t, abs, false)
- assert.Equals(t, in, []string{""})
- return "/certificate/"
- },
- },
+ linker: NewLinker("dns", "acme"),
},
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/jose+json",
statusCode: 200,
}
@@ -304,16 +288,9 @@ func TestHandlerVerifyContentType(t *testing.T) {
"ok/certificate/pkix-cert": func(t *testing.T) test {
return test{
h: Handler{
- Auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.CertificateLink)
- assert.Equals(t, abs, false)
- assert.Equals(t, in, []string{""})
- return "/certificate/"
- },
- },
+ linker: NewLinker("dns", "acme"),
},
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/pkix-cert",
statusCode: 200,
}
@@ -321,16 +298,9 @@ func TestHandlerVerifyContentType(t *testing.T) {
"ok/certificate/jose+json": func(t *testing.T) test {
return test{
h: Handler{
- Auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.CertificateLink)
- assert.Equals(t, abs, false)
- assert.Equals(t, in, []string{""})
- return "/certificate/"
- },
- },
+ linker: NewLinker("dns", "acme"),
},
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/jose+json",
statusCode: 200,
}
@@ -338,16 +308,9 @@ func TestHandlerVerifyContentType(t *testing.T) {
"ok/certificate/pkcs7-mime": func(t *testing.T) test {
return test{
h: Handler{
- Auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.CertificateLink)
- assert.Equals(t, abs, false)
- assert.Equals(t, in, []string{""})
- return "/certificate/"
- },
- },
+ linker: NewLinker("dns", "acme"),
},
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/pkcs7-mime",
statusCode: 200,
}
@@ -373,15 +336,14 @@ func TestHandlerVerifyContentType(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
@@ -390,11 +352,11 @@ func TestHandlerVerifyContentType(t *testing.T) {
}
}
-func TestHandlerIsPostAsGet(t *testing.T) {
+func TestHandler_isPostAsGet(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-account"
type test struct {
ctx context.Context
- problem *acme.Error
+ err *acme.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
@@ -402,26 +364,26 @@ func TestHandlerIsPostAsGet(t *testing.T) {
return test{
ctx: context.Background(),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("payload expected in request context"),
}
},
"fail/nil-payload": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.PayloadContextKey, nil),
+ ctx: context.WithValue(context.Background(), payloadContextKey, nil),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("payload expected in request context"),
}
},
"fail/not-post-as-get": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{}),
+ ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}),
statusCode: 400,
- problem: acme.MalformedErr(errors.New("expected POST-as-GET")),
+ err: acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"),
}
},
"ok": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}),
+ ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{isPostAsGet: true}),
statusCode: 200,
}
},
@@ -429,7 +391,7 @@ func TestHandlerIsPostAsGet(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(nil).(*Handler)
+ h := &Handler{}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -442,15 +404,14 @@ func TestHandlerIsPostAsGet(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
@@ -468,12 +429,12 @@ func (errReader) Close() error {
return nil
}
-func TestHandlerParseJWS(t *testing.T) {
+func TestHandler_parseJWS(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-account"
type test struct {
next nextHTTP
body io.Reader
- problem *acme.Error
+ err *acme.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
@@ -481,14 +442,14 @@ func TestHandlerParseJWS(t *testing.T) {
return test{
body: errReader(0),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("failed to read request body: force")),
+ err: acme.NewErrorISE("failed to read request body: force"),
}
},
"fail/parse-jws-error": func(t *testing.T) test {
return test{
body: strings.NewReader("foo"),
statusCode: 400,
- problem: acme.MalformedErr(errors.New("failed to parse JWS from request body: square/go-jose: compact JWS format must have three parts")),
+ err: acme.NewError(acme.ErrorMalformedType, "failed to parse JWS from request body: square/go-jose: compact JWS format must have three parts"),
}
},
"ok": func(t *testing.T) test {
@@ -507,7 +468,7 @@ func TestHandlerParseJWS(t *testing.T) {
return test{
body: strings.NewReader(expRaw),
next: func(w http.ResponseWriter, r *http.Request) {
- jws, err := acme.JwsFromContext(r.Context())
+ jws, err := jwsFromContext(r.Context())
assert.FatalError(t, err)
gotRaw, err := jws.CompactSerialize()
assert.FatalError(t, err)
@@ -521,7 +482,7 @@ func TestHandlerParseJWS(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(nil).(*Handler)
+ h := &Handler{}
req := httptest.NewRequest("GET", url, tc.body)
w := httptest.NewRecorder()
h.parseJWS(tc.next)(w, req)
@@ -533,15 +494,14 @@ func TestHandlerParseJWS(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
@@ -550,7 +510,7 @@ func TestHandlerParseJWS(t *testing.T) {
}
}
-func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
+func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
_pub := jwk.Public()
@@ -572,7 +532,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
type test struct {
ctx context.Context
next func(http.ResponseWriter, *http.Request)
- problem *acme.Error
+ err *acme.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
@@ -580,58 +540,58 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
return test{
ctx: context.Background(),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
+ err: acme.NewErrorISE("jws expected in request context"),
}
},
"fail/nil-jws": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil),
+ ctx: context.WithValue(context.Background(), jwsContextKey, nil),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
+ err: acme.NewErrorISE("jws expected in request context"),
}
},
"fail/no-jwk": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("jwk expected in request context")),
+ err: acme.NewErrorISE("jwk expected in request context"),
}
},
"fail/nil-jwk": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS)
+ ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
return test{
- ctx: context.WithValue(ctx, acme.JwkContextKey, nil),
+ ctx: context.WithValue(ctx, jwsContextKey, nil),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("jwk expected in request context")),
+ err: acme.NewErrorISE("jwk expected in request context"),
}
},
"fail/verify-jws-failure": func(t *testing.T) test {
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
_pub := _jwk.Public()
- ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS)
- ctx = context.WithValue(ctx, acme.JwkContextKey, &_pub)
+ ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
+ ctx = context.WithValue(ctx, jwkContextKey, &_pub)
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("error verifying jws: square/go-jose: error in cryptographic primitive")),
+ err: acme.NewError(acme.ErrorMalformedType, "error verifying jws: square/go-jose: error in cryptographic primitive"),
}
},
"fail/algorithm-mismatch": func(t *testing.T) test {
_pub := *pub
clone := &_pub
clone.Algorithm = jose.HS256
- ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS)
- ctx = context.WithValue(ctx, acme.JwkContextKey, clone)
+ ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
+ ctx = context.WithValue(ctx, jwkContextKey, clone)
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("verifier and signature algorithm do not match")),
+ err: acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"),
}
},
"ok": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS)
- ctx = context.WithValue(ctx, acme.JwkContextKey, pub)
+ ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
+ ctx = context.WithValue(ctx, jwkContextKey, pub)
return test{
ctx: ctx,
statusCode: 200,
@@ -651,8 +611,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
_pub := *pub
clone := &_pub
clone.Algorithm = ""
- ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS)
- ctx = context.WithValue(ctx, acme.JwkContextKey, pub)
+ ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
+ ctx = context.WithValue(ctx, jwkContextKey, pub)
return test{
ctx: ctx,
statusCode: 200,
@@ -675,8 +635,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed)
- ctx = context.WithValue(ctx, acme.JwkContextKey, pub)
+ ctx := context.WithValue(context.Background(), jwsContextKey, _parsed)
+ ctx = context.WithValue(ctx, jwkContextKey, pub)
return test{
ctx: ctx,
statusCode: 200,
@@ -699,8 +659,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed)
- ctx = context.WithValue(ctx, acme.JwkContextKey, pub)
+ ctx := context.WithValue(context.Background(), jwsContextKey, _parsed)
+ ctx = context.WithValue(ctx, jwkContextKey, pub)
return test{
ctx: ctx,
statusCode: 200,
@@ -720,7 +680,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(nil).(*Handler)
+ h := &Handler{}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -733,15 +693,14 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
@@ -750,7 +709,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
}
}
-func TestHandlerLookupJWK(t *testing.T) {
+func TestHandler_lookupJWK(t *testing.T) {
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
@@ -775,27 +734,28 @@ func TestHandlerLookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
type test struct {
- auth acme.Interface
+ linker Linker
+ db acme.DB
ctx context.Context
next func(http.ResponseWriter, *http.Request)
- problem *acme.Error
+ err *acme.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
+ err: acme.NewErrorISE("jws expected in request context"),
}
},
"fail/nil-jws": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, nil)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
+ err: acme.NewErrorISE("jws expected in request context"),
}
},
"fail/no-kid": func(t *testing.T) test {
@@ -806,21 +766,14 @@ func TestHandlerLookupJWK(t *testing.T) {
assert.FatalError(t, err)
_jws, err := _signer.Sign([]byte("baz"))
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, _jws)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, _jws)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{""})
- return prefix
- },
- },
+ linker: NewLinker("dns", "acme"),
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.Errorf("kid does not have required prefix; expected %s, but got ", prefix)),
+ err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
}
},
"fail/bad-kid-prefix": func(t *testing.T) test {
@@ -837,126 +790,87 @@ func TestHandlerLookupJWK(t *testing.T) {
assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, _parsed)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, _parsed)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{""})
- return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName)
- },
- },
+ linker: NewLinker("dns", "acme"),
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.Errorf("kid does not have required prefix; expected %s, but got foo", prefix)),
+ err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
}
},
"fail/account-not-found": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
+ linker: NewLinker("dns", "acme"),
+ db: &acme.MockDB{
+ MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, accID)
return nil, database.ErrNotFound
},
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{""})
- return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName)
- },
},
ctx: ctx,
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/GetAccount-error": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, accID, accID)
- return nil, acme.ServerInternalErr(errors.New("force"))
- },
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{""})
- return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName)
+ linker: NewLinker("dns", "acme"),
+ db: &acme.MockDB{
+ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
+ assert.Equals(t, id, accID)
+ return nil, acme.NewErrorISE("force")
},
},
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
+ err: acme.NewErrorISE("force"),
}
},
"fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, accID, accID)
+ linker: NewLinker("dns", "acme"),
+ db: &acme.MockDB{
+ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
+ assert.Equals(t, id, accID)
return acc, nil
},
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{""})
- return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName)
- },
},
ctx: ctx,
statusCode: 401,
- problem: acme.UnauthorizedErr(errors.New("account is not active")),
+ err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid", Key: jwk}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, accID, accID)
+ linker: NewLinker("dns", "acme"),
+ db: &acme.MockDB{
+ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
+ assert.Equals(t, id, accID)
return acc, nil
},
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.AccountLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{""})
- return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName)
- },
},
ctx: ctx,
next: func(w http.ResponseWriter, r *http.Request) {
- _acc, err := acme.AccountFromContext(r.Context())
+ _acc, err := accountFromContext(r.Context())
assert.FatalError(t, err)
assert.Equals(t, _acc, acc)
- _jwk, err := acme.JwkFromContext(r.Context())
+ _jwk, err := jwkFromContext(r.Context())
assert.FatalError(t, err)
assert.Equals(t, _jwk, jwk)
w.Write(testBody)
@@ -968,7 +882,7 @@ func TestHandlerLookupJWK(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{db: tc.db, linker: tc.linker}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -981,15 +895,14 @@ func TestHandlerLookupJWK(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
@@ -998,7 +911,7 @@ func TestHandlerLookupJWK(t *testing.T) {
}
}
-func TestHandlerExtractJWK(t *testing.T) {
+func TestHandler_extractJWK(t *testing.T) {
prov := newProv()
provName := url.PathEscape(prov.GetName())
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
@@ -1024,27 +937,27 @@ func TestHandlerExtractJWK(t *testing.T) {
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234",
provName)
type test struct {
- auth acme.Interface
+ db acme.DB
ctx context.Context
next func(http.ResponseWriter, *http.Request)
- problem *acme.Error
+ err *acme.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
+ err: acme.NewErrorISE("jws expected in request context"),
}
},
"fail/nil-jws": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, nil)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
+ err: acme.NewErrorISE("jws expected in request context"),
}
},
"fail/nil-jwk": func(t *testing.T) test {
@@ -1057,12 +970,12 @@ func TestHandlerExtractJWK(t *testing.T) {
},
},
}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, _jws)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, _jws)
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("jwk expected in protected header")),
+ err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
}
},
"fail/invalid-jwk": func(t *testing.T) test {
@@ -1075,71 +988,62 @@ func TestHandlerExtractJWK(t *testing.T) {
},
},
}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, _jws)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, _jws)
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("invalid jwk in protected header")),
+ err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
}
},
"fail/GetAccountByKey-error": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
ctx: ctx,
- auth: &mockAcmeAuthority{
- getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, jwk.KeyID, pub.KeyID)
- return nil, acme.ServerInternalErr(errors.New("force"))
+ db: &acme.MockDB{
+ MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
+ assert.Equals(t, kid, pub.KeyID)
+ return nil, acme.NewErrorISE("force")
},
},
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
+ err: acme.NewErrorISE("force"),
}
},
"fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
ctx: ctx,
- auth: &mockAcmeAuthority{
- getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, jwk.KeyID, pub.KeyID)
+ db: &acme.MockDB{
+ MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
+ assert.Equals(t, kid, pub.KeyID)
return acc, nil
},
},
statusCode: 401,
- problem: acme.UnauthorizedErr(errors.New("account is not active")),
+ err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
ctx: ctx,
- auth: &mockAcmeAuthority{
- getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, jwk.KeyID, pub.KeyID)
+ db: &acme.MockDB{
+ MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
+ assert.Equals(t, kid, pub.KeyID)
return acc, nil
},
},
next: func(w http.ResponseWriter, r *http.Request) {
- _acc, err := acme.AccountFromContext(r.Context())
+ _acc, err := accountFromContext(r.Context())
assert.FatalError(t, err)
assert.Equals(t, _acc, acc)
- _jwk, err := acme.JwkFromContext(r.Context())
+ _jwk, err := jwkFromContext(r.Context())
assert.FatalError(t, err)
assert.Equals(t, _jwk.KeyID, pub.KeyID)
w.Write(testBody)
@@ -1148,24 +1052,21 @@ func TestHandlerExtractJWK(t *testing.T) {
}
},
"ok/no-account": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
ctx: ctx,
- auth: &mockAcmeAuthority{
- getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, jwk.KeyID, pub.KeyID)
- return nil, database.ErrNotFound
+ db: &acme.MockDB{
+ MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
+ assert.Equals(t, kid, pub.KeyID)
+ return nil, acme.ErrNotFound
},
},
next: func(w http.ResponseWriter, r *http.Request) {
- _acc, err := acme.AccountFromContext(r.Context())
+ _acc, err := accountFromContext(r.Context())
assert.NotNil(t, err)
assert.Nil(t, _acc)
- _jwk, err := acme.JwkFromContext(r.Context())
+ _jwk, err := jwkFromContext(r.Context())
assert.FatalError(t, err)
assert.Equals(t, _jwk.KeyID, pub.KeyID)
w.Write(testBody)
@@ -1177,7 +1078,7 @@ func TestHandlerExtractJWK(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -1190,15 +1091,14 @@ func TestHandlerExtractJWK(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
@@ -1207,13 +1107,13 @@ func TestHandlerExtractJWK(t *testing.T) {
}
}
-func TestHandlerValidateJWS(t *testing.T) {
+func TestHandler_validateJWS(t *testing.T) {
url := "https://ca.smallstep.com/acme/account/1234"
type test struct {
- auth acme.Interface
+ db acme.DB
ctx context.Context
next func(http.ResponseWriter, *http.Request)
- problem *acme.Error
+ err *acme.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
@@ -1221,21 +1121,21 @@ func TestHandlerValidateJWS(t *testing.T) {
return test{
ctx: context.Background(),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
+ err: acme.NewErrorISE("jws expected in request context"),
}
},
"fail/nil-jws": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil),
+ ctx: context.WithValue(context.Background(), jwsContextKey, nil),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
+ err: acme.NewErrorISE("jws expected in request context"),
}
},
"fail/no-signature": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, &jose.JSONWebSignature{}),
+ ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
statusCode: 400,
- problem: acme.MalformedErr(errors.New("request body does not contain a signature")),
+ err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
}
},
"fail/more-than-one-signature": func(t *testing.T) test {
@@ -1246,9 +1146,9 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
- problem: acme.MalformedErr(errors.New("request body contains more than one signature")),
+ err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
}
},
"fail/unprotected-header-not-empty": func(t *testing.T) test {
@@ -1258,9 +1158,9 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
- problem: acme.MalformedErr(errors.New("unprotected header must not be used")),
+ err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
}
},
"fail/unsuitable-algorithm-none": func(t *testing.T) test {
@@ -1270,9 +1170,9 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
- problem: acme.MalformedErr(errors.New("unsuitable algorithm: none")),
+ err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
}
},
"fail/unsuitable-algorithm-mac": func(t *testing.T) test {
@@ -1282,9 +1182,9 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
- problem: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)),
+ err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
}
},
"fail/rsa-key-&-alg-mismatch": func(t *testing.T) test {
@@ -1305,14 +1205,14 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- auth: &mockAcmeAuthority{
- useNonce: func(n string) error {
+ db: &acme.MockDB{
+ MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
return nil
},
},
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
- problem: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")),
+ err: acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match"),
}
},
"fail/rsa-key-too-small": func(t *testing.T) test {
@@ -1333,14 +1233,14 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- auth: &mockAcmeAuthority{
- useNonce: func(n string) error {
+ db: &acme.MockDB{
+ MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
return nil
},
},
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
- problem: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")),
+ err: acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least 2048 bits (256 bytes) in size"),
}
},
"fail/UseNonce-error": func(t *testing.T) test {
@@ -1350,14 +1250,14 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- auth: &mockAcmeAuthority{
- useNonce: func(n string) error {
- return acme.ServerInternalErr(errors.New("force"))
+ db: &acme.MockDB{
+ MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
+ return acme.NewErrorISE("force")
},
},
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
+ err: acme.NewErrorISE("force"),
}
},
"fail/no-url-header": func(t *testing.T) test {
@@ -1367,14 +1267,14 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- auth: &mockAcmeAuthority{
- useNonce: func(n string) error {
+ db: &acme.MockDB{
+ MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
return nil
},
},
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
- problem: acme.MalformedErr(errors.New("jws missing url protected header")),
+ err: acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"),
}
},
"fail/url-mismatch": func(t *testing.T) test {
@@ -1391,14 +1291,14 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- auth: &mockAcmeAuthority{
- useNonce: func(n string) error {
+ db: &acme.MockDB{
+ MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
return nil
},
},
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
- problem: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)),
+ err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", url),
}
},
"fail/both-jwk-kid": func(t *testing.T) test {
@@ -1420,14 +1320,14 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- auth: &mockAcmeAuthority{
- useNonce: func(n string) error {
+ db: &acme.MockDB{
+ MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
return nil
},
},
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
- problem: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")),
+ err: acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"),
}
},
"fail/no-jwk-kid": func(t *testing.T) test {
@@ -1444,14 +1344,14 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- auth: &mockAcmeAuthority{
- useNonce: func(n string) error {
+ db: &acme.MockDB{
+ MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
return nil
},
},
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
- problem: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")),
+ err: acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"),
}
},
"ok/kid": func(t *testing.T) test {
@@ -1469,12 +1369,12 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- auth: &mockAcmeAuthority{
- useNonce: func(n string) error {
+ db: &acme.MockDB{
+ MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
return nil
},
},
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
@@ -1499,12 +1399,12 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- auth: &mockAcmeAuthority{
- useNonce: func(n string) error {
+ db: &acme.MockDB{
+ MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
return nil
},
},
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
@@ -1529,12 +1429,12 @@ func TestHandlerValidateJWS(t *testing.T) {
},
}
return test{
- auth: &mockAcmeAuthority{
- useNonce: func(n string) error {
+ db: &acme.MockDB{
+ MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
return nil
},
},
- ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
+ ctx: context.WithValue(context.Background(), jwsContextKey, jws),
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
@@ -1545,7 +1445,7 @@ func TestHandlerValidateJWS(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -1558,15 +1458,14 @@ func TestHandlerValidateJWS(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
diff --git a/acme/api/order.go b/acme/api/order.go
index 5c62cb52..9d410173 100644
--- a/acme/api/order.go
+++ b/acme/api/order.go
@@ -1,16 +1,18 @@
package api
import (
+ "context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"net/http"
+ "strings"
"time"
"github.com/go-chi/chi"
- "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
+ "go.step.sm/crypto/randutil"
)
// NewOrderRequest represents the body for a NewOrder request.
@@ -23,11 +25,11 @@ type NewOrderRequest struct {
// Validate validates a new-order request body.
func (n *NewOrderRequest) Validate() error {
if len(n.Identifiers) == 0 {
- return acme.MalformedErr(errors.Errorf("identifiers list cannot be empty"))
+ return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty")
}
for _, id := range n.Identifiers {
if id.Type != "dns" {
- return acme.MalformedErr(errors.Errorf("identifier type unsupported: %s", id.Type))
+ return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type)
}
}
return nil
@@ -44,22 +46,30 @@ func (f *FinalizeRequest) Validate() error {
var err error
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
if err != nil {
- return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr"))
+ return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr")
}
f.csr, err = x509.ParseCertificateRequest(csrBytes)
if err != nil {
- return acme.MalformedErr(errors.Wrap(err, "unable to parse csr"))
+ return acme.WrapError(acme.ErrorMalformedType, err, "unable to parse csr")
}
if err = f.csr.CheckSignature(); err != nil {
- return acme.MalformedErr(errors.Wrap(err, "csr failed signature check"))
+ return acme.WrapError(acme.ErrorMalformedType, err, "csr failed signature check")
}
return nil
}
+var defaultOrderExpiry = time.Hour * 24
+var defaultOrderBackdate = time.Minute
+
// NewOrder ACME api for creating a new order.
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
- acc, err := acme.AccountFromContext(ctx)
+ acc, err := accountFromContext(ctx)
+ if err != nil {
+ api.WriteError(w, err)
+ return
+ }
+ prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
@@ -71,8 +81,8 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
}
var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil {
- api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
- "failed to unmarshal new-order request payload")))
+ api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
+ "failed to unmarshal new-order request payload"))
return
}
if err := nor.Validate(); err != nil {
@@ -80,44 +90,146 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
return
}
- o, err := h.Auth.NewOrder(ctx, acme.OrderOptions{
- AccountID: acc.GetID(),
- Identifiers: nor.Identifiers,
- NotBefore: nor.NotBefore,
- NotAfter: nor.NotAfter,
- })
- if err != nil {
- api.WriteError(w, err)
+ now := clock.Now()
+ // New order.
+ o := &acme.Order{
+ AccountID: acc.ID,
+ ProvisionerID: prov.GetID(),
+ Status: acme.StatusPending,
+ Identifiers: nor.Identifiers,
+ ExpiresAt: now.Add(defaultOrderExpiry),
+ AuthorizationIDs: make([]string, len(nor.Identifiers)),
+ NotBefore: nor.NotBefore,
+ NotAfter: nor.NotAfter,
+ }
+
+ for i, identifier := range o.Identifiers {
+ az := &acme.Authorization{
+ AccountID: acc.ID,
+ Identifier: identifier,
+ ExpiresAt: o.ExpiresAt,
+ Status: acme.StatusPending,
+ }
+ if err := h.newAuthorization(ctx, az); err != nil {
+ api.WriteError(w, err)
+ return
+ }
+ o.AuthorizationIDs[i] = az.ID
+ }
+
+ if o.NotBefore.IsZero() {
+ o.NotBefore = now
+ }
+ if o.NotAfter.IsZero() {
+ o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration())
+ }
+ // If request NotBefore was empty then backdate the order.NotBefore (now)
+ // to avoid timing issues.
+ if nor.NotBefore.IsZero() {
+ o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
+ }
+
+ if err := h.db.CreateOrder(ctx, o); err != nil {
+ api.WriteError(w, acme.WrapErrorISE(err, "error creating order"))
return
}
- w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID()))
+ h.linker.LinkOrder(ctx, o)
+
+ w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSONStatus(w, o, http.StatusCreated)
}
+func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
+ if strings.HasPrefix(az.Identifier.Value, "*.") {
+ az.Wildcard = true
+ az.Identifier = acme.Identifier{
+ Value: strings.TrimPrefix(az.Identifier.Value, "*."),
+ Type: az.Identifier.Type,
+ }
+ }
+
+ var (
+ err error
+ chTypes = []string{"dns-01"}
+ )
+ // HTTP and TLS challenges can only be used for identifiers without wildcards.
+ if !az.Wildcard {
+ chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...)
+ }
+
+ az.Token, err = randutil.Alphanumeric(32)
+ if err != nil {
+ return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
+ }
+ az.Challenges = make([]*acme.Challenge, len(chTypes))
+ for i, typ := range chTypes {
+ ch := &acme.Challenge{
+ AccountID: az.AccountID,
+ Value: az.Identifier.Value,
+ Type: typ,
+ Token: az.Token,
+ Status: acme.StatusPending,
+ }
+ if err := h.db.CreateChallenge(ctx, ch); err != nil {
+ return acme.WrapErrorISE(err, "error creating challenge")
+ }
+ az.Challenges[i] = ch
+ }
+ if err = h.db.CreateAuthorization(ctx, az); err != nil {
+ return acme.WrapErrorISE(err, "error creating authorization")
+ }
+ return nil
+}
+
// GetOrder ACME api for retrieving an order.
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
- acc, err := acme.AccountFromContext(ctx)
+ acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
- oid := chi.URLParam(r, "ordID")
- o, err := h.Auth.GetOrder(ctx, acc.GetID(), oid)
+ prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
+ o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
+ if err != nil {
+ api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
+ return
+ }
+ if acc.ID != o.AccountID {
+ api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
+ "account '%s' does not own order '%s'", acc.ID, o.ID))
+ return
+ }
+ if prov.GetID() != o.ProvisionerID {
+ api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
+ "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
+ return
+ }
+ if err = o.UpdateStatus(ctx, h.db); err != nil {
+ api.WriteError(w, acme.WrapErrorISE(err, "error updating order status"))
+ return
+ }
- w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID()))
+ h.linker.LinkOrder(ctx, o)
+
+ w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o)
}
// FinalizeOrder attemptst to finalize an order and create a certificate.
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
- acc, err := acme.AccountFromContext(ctx)
+ acc, err := accountFromContext(ctx)
+ if err != nil {
+ api.WriteError(w, err)
+ return
+ }
+ prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
@@ -129,7 +241,8 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
}
var fr FinalizeRequest
if err := json.Unmarshal(payload.value, &fr); err != nil {
- api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload")))
+ api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
+ "failed to unmarshal finalize-order request payload"))
return
}
if err := fr.Validate(); err != nil {
@@ -137,13 +250,28 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
return
}
- oid := chi.URLParam(r, "ordID")
- o, err := h.Auth.FinalizeOrder(ctx, acc.GetID(), oid, fr.csr)
+ o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil {
- api.WriteError(w, err)
+ api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
+ return
+ }
+ if acc.ID != o.AccountID {
+ api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
+ "account '%s' does not own order '%s'", acc.ID, o.ID))
+ return
+ }
+ if prov.GetID() != o.ProvisionerID {
+ api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
+ "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
+ return
+ }
+ if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
+ api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order"))
return
}
- w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.ID))
+ h.linker.LinkOrder(ctx, o)
+
+ w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o)
}
diff --git a/acme/api/order_test.go b/acme/api/order_test.go
index a1c8fef7..300aa61b 100644
--- a/acme/api/order_test.go
+++ b/acme/api/order_test.go
@@ -20,7 +20,7 @@ import (
"go.step.sm/crypto/pemutil"
)
-func TestNewOrderRequestValidate(t *testing.T) {
+func TestNewOrderRequest_Validate(t *testing.T) {
type test struct {
nor *NewOrderRequest
nbf, naf time.Time
@@ -30,7 +30,7 @@ func TestNewOrderRequestValidate(t *testing.T) {
"fail/no-identifiers": func(t *testing.T) test {
return test{
nor: &NewOrderRequest{},
- err: acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")),
+ err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
}
},
"fail/bad-identifier": func(t *testing.T) test {
@@ -41,7 +41,7 @@ func TestNewOrderRequestValidate(t *testing.T) {
{Type: "foo", Value: "bar.com"},
},
},
- err: acme.MalformedErr(errors.Errorf("identifier type unsupported: foo")),
+ err: acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: foo"),
}
},
"ok": func(t *testing.T) test {
@@ -105,7 +105,7 @@ func TestFinalizeRequestValidate(t *testing.T) {
"fail/parse-csr-error": func(t *testing.T) test {
return test{
fr: &FinalizeRequest{},
- err: acme.MalformedErr(errors.Errorf("unable to parse csr: asn1: syntax error: sequence truncated")),
+ err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"),
}
},
"fail/invalid-csr-signature": func(t *testing.T) test {
@@ -117,7 +117,7 @@ func TestFinalizeRequestValidate(t *testing.T) {
fr: &FinalizeRequest{
CSR: base64.RawURLEncoding.EncodeToString(c.Raw),
},
- err: acme.MalformedErr(errors.Errorf("csr failed signature check: x509: ECDSA verification failure")),
+ err: acme.NewError(acme.ErrorMalformedType, "csr failed signature check: x509: ECDSA verification failure"),
}
},
"ok": func(t *testing.T) test {
@@ -148,15 +148,19 @@ func TestFinalizeRequestValidate(t *testing.T) {
}
}
-func TestHandlerGetOrder(t *testing.T) {
- expiry := time.Now().UTC().Add(6 * time.Hour)
- nbf := time.Now().UTC()
- naf := time.Now().UTC().Add(24 * time.Hour)
+func TestHandler_GetOrder(t *testing.T) {
+ prov := newProv()
+ escProvName := url.PathEscape(prov.GetName())
+ baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
+
+ now := clock.Now()
+ nbf := now
+ naf := now.Add(24 * time.Hour)
+ expiry := now.Add(-time.Hour)
o := acme.Order{
ID: "orderID",
- Expires: expiry.Format(time.RFC3339),
- NotBefore: nbf.Format(time.RFC3339),
- NotAfter: naf.Format(time.RFC3339),
+ NotBefore: nbf,
+ NotAfter: naf,
Identifiers: []acme.Identifier{
{
Type: "dns",
@@ -167,79 +171,167 @@ func TestHandlerGetOrder(t *testing.T) {
Value: "*.smallstep.com",
},
},
- Status: "pending",
- Authorizations: []string{"foo", "bar"},
+ ExpiresAt: expiry,
+ Status: acme.StatusInvalid,
+ Error: acme.NewError(acme.ErrorMalformedType, "order has expired"),
+ AuthorizationURLs: []string{
+ fmt.Sprintf("%s/acme/%s/authz/foo", baseURL.String(), escProvName),
+ fmt.Sprintf("%s/acme/%s/authz/bar", baseURL.String(), escProvName),
+ fmt.Sprintf("%s/acme/%s/authz/baz", baseURL.String(), escProvName),
+ },
+ FinalizeURL: fmt.Sprintf("%s/acme/%s/order/orderID/finalize", baseURL.String(), escProvName),
}
// Request with chi context
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID)
- prov := newProv()
- provName := url.PathEscape(prov.GetName())
- baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/order/%s",
- baseURL.String(), provName, o.ID)
+ baseURL.String(), escProvName, o.ID)
type test struct {
- auth acme.Interface
+ db acme.DB
ctx context.Context
statusCode int
- problem *acme.Error
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
- auth: &mockAcmeAuthority{},
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, nil)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, nil)
return test{
- auth: &mockAcmeAuthority{},
ctx: ctx,
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
- "fail/getOrder-error": func(t *testing.T) test {
- acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ "fail/no-provisioner": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
+ return test{
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("provisioner does not exist"),
+ }
+ },
+ "fail/nil-provisioner": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ return test{
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("provisioner does not exist"),
+ }
+ },
+ "fail/db.GetOrder-error": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
- auth: &mockAcmeAuthority{
- err: acme.ServerInternalErr(errors.New("force")),
+ db: &acme.MockDB{
+ MockError: acme.NewErrorISE("force"),
},
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("force")),
+ err: acme.NewErrorISE("force"),
+ }
+ },
+ "fail/account-id-mismatch": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
+ return &acme.Order{AccountID: "foo"}, nil
+ },
+ },
+ ctx: ctx,
+ statusCode: 401,
+ err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"),
+ }
+ },
+ "fail/provisioner-id-mismatch": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
+ return &acme.Order{AccountID: "accountID", ProvisionerID: "bar"}, nil
+ },
+ },
+ ctx: ctx,
+ statusCode: 401,
+ err: acme.NewError(acme.ErrorUnauthorizedType, "provisioner id mismatch"),
+ }
+ },
+ "fail/order-update-error": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
+ return &acme.Order{
+ AccountID: "accountID",
+ ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()),
+ ExpiresAt: clock.Now().Add(-time.Hour),
+ Status: acme.StatusReady,
+ }, nil
+ },
+ MockUpdateOrder: func(ctx context.Context, o *acme.Order) error {
+ return acme.NewErrorISE("force")
+ },
+ },
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("force"),
}
},
"ok": func(t *testing.T) test {
- acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- getOrder: func(ctx context.Context, accID, id string) (*acme.Order, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, accID, acc.ID)
- assert.Equals(t, id, o.ID)
- return &o, nil
+ db: &acme.MockDB{
+ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
+ return &acme.Order{
+ ID: "orderID",
+ AccountID: "accountID",
+ ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()),
+ ExpiresAt: expiry,
+ Status: acme.StatusReady,
+ AuthorizationIDs: []string{"foo", "bar", "baz"},
+ NotBefore: nbf,
+ NotAfter: naf,
+ Identifiers: []acme.Identifier{
+ {
+ Type: "dns",
+ Value: "example.com",
+ },
+ {
+ Type: "dns",
+ Value: "*.smallstep.com",
+ },
+ },
+ }, nil
},
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.OrderLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{o.ID})
- return url
+ MockUpdateOrder: func(ctx context.Context, o *acme.Order) error {
+ return nil
},
},
ctx: ctx,
@@ -250,7 +342,7 @@ func TestHandlerGetOrder(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -263,19 +355,19 @@ func TestHandlerGetOrder(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(o)
assert.FatalError(t, err)
+
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], []string{url})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
@@ -284,209 +376,886 @@ func TestHandlerGetOrder(t *testing.T) {
}
}
-func TestHandlerNewOrder(t *testing.T) {
- expiry := time.Now().UTC().Add(6 * time.Hour)
- nbf := time.Now().UTC().Add(5 * time.Hour)
- naf := nbf.Add(17 * time.Hour)
- o := acme.Order{
- ID: "orderID",
- Expires: expiry.Format(time.RFC3339),
- NotBefore: nbf.Format(time.RFC3339),
- NotAfter: naf.Format(time.RFC3339),
- Identifiers: []acme.Identifier{
- {Type: "dns", Value: "example.com"},
- {Type: "dns", Value: "bar.com"},
- },
- Status: "pending",
- Authorizations: []string{"foo", "bar"},
+func TestHandler_newAuthorization(t *testing.T) {
+ type test struct {
+ az *acme.Authorization
+ db acme.DB
+ err *acme.Error
}
+ var tests = map[string]func(t *testing.T) test{
+ "fail/error-db.CreateChallenge": func(t *testing.T) test {
+ az := &acme.Authorization{
+ AccountID: "accID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "zap.internal",
+ },
+ }
+ return test{
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ assert.Equals(t, ch.AccountID, az.AccountID)
+ assert.Equals(t, ch.Type, "dns-01")
+ assert.Equals(t, ch.Token, az.Token)
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Value, az.Identifier.Value)
+ return errors.New("force")
+ },
+ },
+ az: az,
+ err: acme.NewErrorISE("error creating challenge: force"),
+ }
+ },
+ "fail/error-db.CreateAuthorization": func(t *testing.T) test {
+ az := &acme.Authorization{
+ AccountID: "accID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "zap.internal",
+ },
+ Status: acme.StatusPending,
+ ExpiresAt: clock.Now(),
+ }
+ count := 0
+ var ch1, ch2, ch3 **acme.Challenge
+ return test{
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ switch count {
+ case 0:
+ ch.ID = "dns"
+ assert.Equals(t, ch.Type, "dns-01")
+ ch1 = &ch
+ case 1:
+ ch.ID = "http"
+ assert.Equals(t, ch.Type, "http-01")
+ ch2 = &ch
+ case 2:
+ ch.ID = "tls"
+ assert.Equals(t, ch.Type, "tls-alpn-01")
+ ch3 = &ch
+ default:
+ assert.FatalError(t, errors.New("test logic error"))
+ return errors.New("force")
+ }
+ count++
+ assert.Equals(t, ch.AccountID, az.AccountID)
+ assert.Equals(t, ch.Token, az.Token)
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Value, az.Identifier.Value)
+ return nil
+ },
+ MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error {
+ assert.Equals(t, _az.AccountID, az.AccountID)
+ assert.Equals(t, _az.Token, az.Token)
+ assert.Equals(t, _az.Status, acme.StatusPending)
+ assert.Equals(t, _az.Identifier, az.Identifier)
+ assert.Equals(t, _az.ExpiresAt, az.ExpiresAt)
+ assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
+ assert.Equals(t, _az.Wildcard, false)
+ return errors.New("force")
+ },
+ },
+ az: az,
+ err: acme.NewErrorISE("error creating authorization: force"),
+ }
+ },
+ "ok/no-wildcard": func(t *testing.T) test {
+ az := &acme.Authorization{
+ AccountID: "accID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "zap.internal",
+ },
+ Status: acme.StatusPending,
+ ExpiresAt: clock.Now(),
+ }
+ count := 0
+ var ch1, ch2, ch3 **acme.Challenge
+ return test{
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ switch count {
+ case 0:
+ ch.ID = "dns"
+ assert.Equals(t, ch.Type, "dns-01")
+ ch1 = &ch
+ case 1:
+ ch.ID = "http"
+ assert.Equals(t, ch.Type, "http-01")
+ ch2 = &ch
+ case 2:
+ ch.ID = "tls"
+ assert.Equals(t, ch.Type, "tls-alpn-01")
+ ch3 = &ch
+ default:
+ assert.FatalError(t, errors.New("test logic error"))
+ return errors.New("force")
+ }
+ count++
+ assert.Equals(t, ch.AccountID, az.AccountID)
+ assert.Equals(t, ch.Token, az.Token)
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Value, az.Identifier.Value)
+ return nil
+ },
+ MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error {
+ assert.Equals(t, _az.AccountID, az.AccountID)
+ assert.Equals(t, _az.Token, az.Token)
+ assert.Equals(t, _az.Status, acme.StatusPending)
+ assert.Equals(t, _az.Identifier, az.Identifier)
+ assert.Equals(t, _az.ExpiresAt, az.ExpiresAt)
+ assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
+ assert.Equals(t, _az.Wildcard, false)
+ return nil
+ },
+ },
+ az: az,
+ }
+ },
+ "ok/wildcard": func(t *testing.T) test {
+ az := &acme.Authorization{
+ AccountID: "accID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "*.zap.internal",
+ },
+ Status: acme.StatusPending,
+ ExpiresAt: clock.Now(),
+ }
+ var ch1 **acme.Challenge
+ return test{
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ ch.ID = "dns"
+ assert.Equals(t, ch.Type, "dns-01")
+ assert.Equals(t, ch.AccountID, az.AccountID)
+ assert.Equals(t, ch.Token, az.Token)
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Value, "zap.internal")
+ ch1 = &ch
+ return nil
+ },
+ MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error {
+ assert.Equals(t, _az.AccountID, az.AccountID)
+ assert.Equals(t, _az.Token, az.Token)
+ assert.Equals(t, _az.Status, acme.StatusPending)
+ assert.Equals(t, _az.Identifier, acme.Identifier{
+ Type: "dns",
+ Value: "zap.internal",
+ })
+ assert.Equals(t, _az.ExpiresAt, az.ExpiresAt)
+ assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1})
+ assert.Equals(t, _az.Wildcard, true)
+ return nil
+ },
+ },
+ az: az,
+ }
+ },
+ }
+ for name, run := range tests {
+ t.Run(name, func(t *testing.T) {
+ tc := run(t)
+ h := &Handler{db: tc.db}
+ if err := h.newAuthorization(context.Background(), tc.az); err != nil {
+ if assert.NotNil(t, tc.err) {
+ switch k := err.(type) {
+ case *acme.Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
+ }
+ } else {
+ assert.Nil(t, tc.err)
+ }
+ })
+ }
+}
+
+func TestHandler_NewOrder(t *testing.T) {
+ // Request with chi context
prov := newProv()
- provName := url.PathEscape(prov.GetName())
+ escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- url := fmt.Sprintf("%s/acme/%s/new-order",
- baseURL.String(), provName)
+ url := fmt.Sprintf("%s/acme/%s/order/ordID",
+ baseURL.String(), escProvName)
type test struct {
- auth acme.Interface
+ db acme.DB
ctx context.Context
+ nor *NewOrderRequest
statusCode int
- problem *acme.Error
+ vr func(t *testing.T, o *acme.Order)
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, nil)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, nil)
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
+ }
+ },
+ "fail/no-provisioner": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
+ return test{
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("provisioner does not exist"),
+ }
+ },
+ "fail/nil-provisioner": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ return test{
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("provisioner does not exist"),
}
},
"fail/no-payload": func(t *testing.T) test {
- acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
+ ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("payload does not exist"),
}
},
"fail/nil-payload": func(t *testing.T) test {
- acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("paylod does not exist"),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{})
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("failed to unmarshal new-order request payload: unexpected end of JSON input")),
+ err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"),
}
},
"fail/malformed-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- nor := &NewOrderRequest{}
- b, err := json.Marshal(nor)
+ fr := &NewOrderRequest{}
+ b, err := json.Marshal(fr)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("identifiers list cannot be empty")),
+ err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
}
},
- "fail/NewOrder-error": func(t *testing.T) test {
+ "fail/error-h.newAuthorization": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- nor := &NewOrderRequest{
+ fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
- {Type: "dns", Value: "example.com"},
- {Type: "dns", Value: "bar.com"},
+ {Type: "dns", Value: "zap.internal"},
},
}
- b, err := json.Marshal(nor)
+ b, err := json.Marshal(fr)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
- auth: &mockAcmeAuthority{
- newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, ops.AccountID, acc.ID)
- assert.Equals(t, ops.Identifiers, nor.Identifiers)
- return nil, acme.MalformedErr(errors.New("force"))
- },
- },
ctx: ctx,
- statusCode: 400,
- problem: acme.MalformedErr(errors.New("force")),
+ statusCode: 500,
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ assert.Equals(t, ch.AccountID, "accID")
+ assert.Equals(t, ch.Type, "dns-01")
+ assert.NotEquals(t, ch.Token, "")
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Value, "zap.internal")
+ return errors.New("force")
+ },
+ },
+ err: acme.NewErrorISE("error creating challenge: force"),
}
},
- "ok": func(t *testing.T) test {
+ "fail/error-db.CreateOrder": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accID"}
+ fr := &NewOrderRequest{
+ Identifiers: []acme.Identifier{
+ {Type: "dns", Value: "zap.internal"},
+ },
+ }
+ b, err := json.Marshal(fr)
+ assert.FatalError(t, err)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
+ var (
+ ch1, ch2, ch3 **acme.Challenge
+ az1ID *string
+ count = 0
+ )
+ return test{
+ ctx: ctx,
+ statusCode: 500,
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ switch count {
+ case 0:
+ ch.ID = "dns"
+ assert.Equals(t, ch.Type, "dns-01")
+ ch1 = &ch
+ case 1:
+ ch.ID = "http"
+ assert.Equals(t, ch.Type, "http-01")
+ ch2 = &ch
+ case 2:
+ ch.ID = "tls"
+ assert.Equals(t, ch.Type, "tls-alpn-01")
+ ch3 = &ch
+ default:
+ assert.FatalError(t, errors.New("test logic error"))
+ return errors.New("force")
+ }
+ count++
+ assert.Equals(t, ch.AccountID, "accID")
+ assert.NotEquals(t, ch.Token, "")
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Value, "zap.internal")
+ return nil
+ },
+ MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
+ az.ID = "az1ID"
+ az1ID = &az.ID
+ assert.Equals(t, az.AccountID, "accID")
+ assert.NotEquals(t, az.Token, "")
+ assert.Equals(t, az.Status, acme.StatusPending)
+ assert.Equals(t, az.Identifier, fr.Identifiers[0])
+ assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
+ assert.Equals(t, az.Wildcard, false)
+ return nil
+ },
+ MockCreateOrder: func(ctx context.Context, o *acme.Order) error {
+ assert.Equals(t, o.AccountID, "accID")
+ assert.Equals(t, o.ProvisionerID, prov.GetID())
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, fr.Identifiers)
+ assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
+ return errors.New("force")
+ },
+ },
+ err: acme.NewErrorISE("error creating order: force"),
+ }
+ },
+ "ok/multiple-authz": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
nor := &NewOrderRequest{
Identifiers: []acme.Identifier{
- {Type: "dns", Value: "example.com"},
- {Type: "dns", Value: "bar.com"},
+ {Type: "dns", Value: "zap.internal"},
+ {Type: "dns", Value: "*.zar.internal"},
},
- NotBefore: nbf,
- NotAfter: naf,
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
+ var (
+ ch1, ch2, ch3, ch4 **acme.Challenge
+ az1ID, az2ID *string
+ chCount, azCount = 0, 0
+ )
return test{
- auth: &mockAcmeAuthority{
- newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, ops.AccountID, acc.ID)
- assert.Equals(t, ops.Identifiers, nor.Identifiers)
- assert.Equals(t, ops.NotBefore, nbf)
- assert.Equals(t, ops.NotAfter, naf)
- return &o, nil
- },
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.OrderLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{o.ID})
- return fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID)
- },
- },
ctx: ctx,
statusCode: 201,
+ nor: nor,
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ switch chCount {
+ case 0:
+ ch.ID = "dns"
+ assert.Equals(t, ch.Type, "dns-01")
+ assert.Equals(t, ch.Value, "zap.internal")
+ ch1 = &ch
+ case 1:
+ ch.ID = "http"
+ assert.Equals(t, ch.Type, "http-01")
+ assert.Equals(t, ch.Value, "zap.internal")
+ ch2 = &ch
+ case 2:
+ ch.ID = "tls"
+ assert.Equals(t, ch.Type, "tls-alpn-01")
+ assert.Equals(t, ch.Value, "zap.internal")
+ ch3 = &ch
+ case 3:
+ ch.ID = "dns"
+ assert.Equals(t, ch.Type, "dns-01")
+ assert.Equals(t, ch.Value, "zar.internal")
+ ch4 = &ch
+ default:
+ assert.FatalError(t, errors.New("test logic error"))
+ return errors.New("force")
+ }
+ chCount++
+ assert.Equals(t, ch.AccountID, "accID")
+ assert.NotEquals(t, ch.Token, "")
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ return nil
+ },
+ MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
+ switch azCount {
+ case 0:
+ az.ID = "az1ID"
+ az1ID = &az.ID
+ assert.Equals(t, az.Identifier, nor.Identifiers[0])
+ assert.Equals(t, az.Wildcard, false)
+ assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
+ case 1:
+ az.ID = "az2ID"
+ az2ID = &az.ID
+ assert.Equals(t, az.Identifier, acme.Identifier{
+ Type: "dns",
+ Value: "zar.internal",
+ })
+ assert.Equals(t, az.Wildcard, true)
+ assert.Equals(t, az.Challenges, []*acme.Challenge{*ch4})
+ default:
+ assert.FatalError(t, errors.New("test logic error"))
+ return errors.New("force")
+ }
+ azCount++
+ assert.Equals(t, az.AccountID, "accID")
+ assert.NotEquals(t, az.Token, "")
+ assert.Equals(t, az.Status, acme.StatusPending)
+ return nil
+ },
+ MockCreateOrder: func(ctx context.Context, o *acme.Order) error {
+ o.ID = "ordID"
+ assert.Equals(t, o.AccountID, "accID")
+ assert.Equals(t, o.ProvisionerID, prov.GetID())
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, nor.Identifiers)
+ assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID})
+ return nil
+ },
+ },
+ vr: func(t *testing.T, o *acme.Order) {
+ now := clock.Now()
+ testBufferDur := 5 * time.Second
+ orderExpiry := now.Add(defaultOrderExpiry)
+ expNbf := now.Add(-defaultOrderBackdate)
+ expNaf := now.Add(prov.DefaultTLSCertDuration())
+
+ assert.Equals(t, o.ID, "ordID")
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, nor.Identifiers)
+ assert.Equals(t, o.AuthorizationURLs, []string{
+ fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName),
+ fmt.Sprintf("%s/acme/%s/authz/az2ID", baseURL.String(), escProvName),
+ })
+ assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
+ assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
+ assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
+ assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf))
+ assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry))
+ assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry))
+ },
}
},
"ok/default-naf-nbf": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
nor := &NewOrderRequest{
Identifiers: []acme.Identifier{
- {Type: "dns", Value: "example.com"},
- {Type: "dns", Value: "bar.com"},
+ {Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
+ var (
+ ch1, ch2, ch3 **acme.Challenge
+ az1ID *string
+ count = 0
+ )
return test{
- auth: &mockAcmeAuthority{
- newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, ops.AccountID, acc.ID)
- assert.Equals(t, ops.Identifiers, nor.Identifiers)
-
- assert.True(t, ops.NotBefore.IsZero())
- assert.True(t, ops.NotAfter.IsZero())
- return &o, nil
- },
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.OrderLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{o.ID})
- return fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID)
- },
- },
ctx: ctx,
statusCode: 201,
+ nor: nor,
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ switch count {
+ case 0:
+ ch.ID = "dns"
+ assert.Equals(t, ch.Type, "dns-01")
+ ch1 = &ch
+ case 1:
+ ch.ID = "http"
+ assert.Equals(t, ch.Type, "http-01")
+ ch2 = &ch
+ case 2:
+ ch.ID = "tls"
+ assert.Equals(t, ch.Type, "tls-alpn-01")
+ ch3 = &ch
+ default:
+ assert.FatalError(t, errors.New("test logic error"))
+ return errors.New("force")
+ }
+ count++
+ assert.Equals(t, ch.AccountID, "accID")
+ assert.NotEquals(t, ch.Token, "")
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Value, "zap.internal")
+ return nil
+ },
+ MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
+ az.ID = "az1ID"
+ az1ID = &az.ID
+ assert.Equals(t, az.AccountID, "accID")
+ assert.NotEquals(t, az.Token, "")
+ assert.Equals(t, az.Status, acme.StatusPending)
+ assert.Equals(t, az.Identifier, nor.Identifiers[0])
+ assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
+ assert.Equals(t, az.Wildcard, false)
+ return nil
+ },
+ MockCreateOrder: func(ctx context.Context, o *acme.Order) error {
+ o.ID = "ordID"
+ assert.Equals(t, o.AccountID, "accID")
+ assert.Equals(t, o.ProvisionerID, prov.GetID())
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, nor.Identifiers)
+ assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
+ return nil
+ },
+ },
+ vr: func(t *testing.T, o *acme.Order) {
+ now := clock.Now()
+ testBufferDur := 5 * time.Second
+ orderExpiry := now.Add(defaultOrderExpiry)
+ expNbf := now.Add(-defaultOrderBackdate)
+ expNaf := now.Add(prov.DefaultTLSCertDuration())
+
+ assert.Equals(t, o.ID, "ordID")
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, nor.Identifiers)
+ assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
+ assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
+ assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
+ assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
+ assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf))
+ assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry))
+ assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry))
+ },
+ }
+ },
+ "ok/nbf-no-naf": func(t *testing.T) test {
+ now := clock.Now()
+ expNbf := now.Add(10 * time.Minute)
+ acc := &acme.Account{ID: "accID"}
+ nor := &NewOrderRequest{
+ Identifiers: []acme.Identifier{
+ {Type: "dns", Value: "zap.internal"},
+ },
+ NotBefore: expNbf,
+ }
+ b, err := json.Marshal(nor)
+ assert.FatalError(t, err)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
+ var (
+ ch1, ch2, ch3 **acme.Challenge
+ az1ID *string
+ count = 0
+ )
+ return test{
+ ctx: ctx,
+ statusCode: 201,
+ nor: nor,
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ switch count {
+ case 0:
+ ch.ID = "dns"
+ assert.Equals(t, ch.Type, "dns-01")
+ ch1 = &ch
+ case 1:
+ ch.ID = "http"
+ assert.Equals(t, ch.Type, "http-01")
+ ch2 = &ch
+ case 2:
+ ch.ID = "tls"
+ assert.Equals(t, ch.Type, "tls-alpn-01")
+ ch3 = &ch
+ default:
+ assert.FatalError(t, errors.New("test logic error"))
+ return errors.New("force")
+ }
+ count++
+ assert.Equals(t, ch.AccountID, "accID")
+ assert.NotEquals(t, ch.Token, "")
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Value, "zap.internal")
+ return nil
+ },
+ MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
+ az.ID = "az1ID"
+ az1ID = &az.ID
+ assert.Equals(t, az.AccountID, "accID")
+ assert.NotEquals(t, az.Token, "")
+ assert.Equals(t, az.Status, acme.StatusPending)
+ assert.Equals(t, az.Identifier, nor.Identifiers[0])
+ assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
+ assert.Equals(t, az.Wildcard, false)
+ return nil
+ },
+ MockCreateOrder: func(ctx context.Context, o *acme.Order) error {
+ o.ID = "ordID"
+ assert.Equals(t, o.AccountID, "accID")
+ assert.Equals(t, o.ProvisionerID, prov.GetID())
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, nor.Identifiers)
+ assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
+ return nil
+ },
+ },
+ vr: func(t *testing.T, o *acme.Order) {
+ now := clock.Now()
+ testBufferDur := 5 * time.Second
+ orderExpiry := now.Add(defaultOrderExpiry)
+ expNaf := expNbf.Add(prov.DefaultTLSCertDuration())
+
+ assert.Equals(t, o.ID, "ordID")
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, nor.Identifiers)
+ assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
+ assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
+ assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
+ assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
+ assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf))
+ assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry))
+ assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry))
+ },
+ }
+ },
+ "ok/naf-no-nbf": func(t *testing.T) test {
+ now := clock.Now()
+ expNaf := now.Add(15 * time.Minute)
+ acc := &acme.Account{ID: "accID"}
+ nor := &NewOrderRequest{
+ Identifiers: []acme.Identifier{
+ {Type: "dns", Value: "zap.internal"},
+ },
+ NotAfter: expNaf,
+ }
+ b, err := json.Marshal(nor)
+ assert.FatalError(t, err)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
+ var (
+ ch1, ch2, ch3 **acme.Challenge
+ az1ID *string
+ count = 0
+ )
+ return test{
+ ctx: ctx,
+ statusCode: 201,
+ nor: nor,
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ switch count {
+ case 0:
+ ch.ID = "dns"
+ assert.Equals(t, ch.Type, "dns-01")
+ ch1 = &ch
+ case 1:
+ ch.ID = "http"
+ assert.Equals(t, ch.Type, "http-01")
+ ch2 = &ch
+ case 2:
+ ch.ID = "tls"
+ assert.Equals(t, ch.Type, "tls-alpn-01")
+ ch3 = &ch
+ default:
+ assert.FatalError(t, errors.New("test logic error"))
+ return errors.New("force")
+ }
+ count++
+ assert.Equals(t, ch.AccountID, "accID")
+ assert.NotEquals(t, ch.Token, "")
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Value, "zap.internal")
+ return nil
+ },
+ MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
+ az.ID = "az1ID"
+ az1ID = &az.ID
+ assert.Equals(t, az.AccountID, "accID")
+ assert.NotEquals(t, az.Token, "")
+ assert.Equals(t, az.Status, acme.StatusPending)
+ assert.Equals(t, az.Identifier, nor.Identifiers[0])
+ assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
+ assert.Equals(t, az.Wildcard, false)
+ return nil
+ },
+ MockCreateOrder: func(ctx context.Context, o *acme.Order) error {
+ o.ID = "ordID"
+ assert.Equals(t, o.AccountID, "accID")
+ assert.Equals(t, o.ProvisionerID, prov.GetID())
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, nor.Identifiers)
+ assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
+ return nil
+ },
+ },
+ vr: func(t *testing.T, o *acme.Order) {
+ testBufferDur := 5 * time.Second
+ orderExpiry := now.Add(defaultOrderExpiry)
+ expNbf := now.Add(-defaultOrderBackdate)
+
+ assert.Equals(t, o.ID, "ordID")
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, nor.Identifiers)
+ assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
+ assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
+ assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
+ assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
+ assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf))
+ assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry))
+ assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry))
+ },
+ }
+ },
+ "ok/naf-nbf": func(t *testing.T) test {
+ now := clock.Now()
+ expNbf := now.Add(5 * time.Minute)
+ expNaf := now.Add(15 * time.Minute)
+ acc := &acme.Account{ID: "accID"}
+ nor := &NewOrderRequest{
+ Identifiers: []acme.Identifier{
+ {Type: "dns", Value: "zap.internal"},
+ },
+ NotBefore: expNbf,
+ NotAfter: expNaf,
+ }
+ b, err := json.Marshal(nor)
+ assert.FatalError(t, err)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
+ var (
+ ch1, ch2, ch3 **acme.Challenge
+ az1ID *string
+ count = 0
+ )
+ return test{
+ ctx: ctx,
+ statusCode: 201,
+ nor: nor,
+ db: &acme.MockDB{
+ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
+ switch count {
+ case 0:
+ ch.ID = "dns"
+ assert.Equals(t, ch.Type, "dns-01")
+ ch1 = &ch
+ case 1:
+ ch.ID = "http"
+ assert.Equals(t, ch.Type, "http-01")
+ ch2 = &ch
+ case 2:
+ ch.ID = "tls"
+ assert.Equals(t, ch.Type, "tls-alpn-01")
+ ch3 = &ch
+ default:
+ assert.FatalError(t, errors.New("test logic error"))
+ return errors.New("force")
+ }
+ count++
+ assert.Equals(t, ch.AccountID, "accID")
+ assert.NotEquals(t, ch.Token, "")
+ assert.Equals(t, ch.Status, acme.StatusPending)
+ assert.Equals(t, ch.Value, "zap.internal")
+ return nil
+ },
+ MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
+ az.ID = "az1ID"
+ az1ID = &az.ID
+ assert.Equals(t, az.AccountID, "accID")
+ assert.NotEquals(t, az.Token, "")
+ assert.Equals(t, az.Status, acme.StatusPending)
+ assert.Equals(t, az.Identifier, nor.Identifiers[0])
+ assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
+ assert.Equals(t, az.Wildcard, false)
+ return nil
+ },
+ MockCreateOrder: func(ctx context.Context, o *acme.Order) error {
+ o.ID = "ordID"
+ assert.Equals(t, o.AccountID, "accID")
+ assert.Equals(t, o.ProvisionerID, prov.GetID())
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, nor.Identifiers)
+ assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
+ return nil
+ },
+ },
+ vr: func(t *testing.T, o *acme.Order) {
+ testBufferDur := 5 * time.Second
+ orderExpiry := now.Add(defaultOrderExpiry)
+
+ assert.Equals(t, o.ID, "ordID")
+ assert.Equals(t, o.Status, acme.StatusPending)
+ assert.Equals(t, o.Identifiers, nor.Identifiers)
+ assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
+ assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
+ assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
+ assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
+ assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf))
+ assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry))
+ assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry))
+ },
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -499,115 +1268,151 @@ func TestHandlerNewOrder(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
- expB, err := json.Marshal(o)
- assert.FatalError(t, err)
- assert.Equals(t, bytes.TrimSpace(body), expB)
- assert.Equals(t, res.Header["Location"],
- []string{fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(),
- provName, o.ID)})
+ ro := new(acme.Order)
+ assert.FatalError(t, json.Unmarshal(body, ro))
+ if tc.vr != nil {
+ tc.vr(t, ro)
+ }
+
+ assert.Equals(t, res.Header["Location"], []string{url})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}
-func TestHandlerFinalizeOrder(t *testing.T) {
- expiry := time.Now().UTC().Add(6 * time.Hour)
- nbf := time.Now().UTC().Add(5 * time.Hour)
- naf := nbf.Add(17 * time.Hour)
+func TestHandler_FinalizeOrder(t *testing.T) {
+ prov := newProv()
+ escProvName := url.PathEscape(prov.GetName())
+ baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
+
+ now := clock.Now()
+ nbf := now
+ naf := now.Add(24 * time.Hour)
o := acme.Order{
ID: "orderID",
- Expires: expiry.Format(time.RFC3339),
- NotBefore: nbf.Format(time.RFC3339),
- NotAfter: naf.Format(time.RFC3339),
+ NotBefore: nbf,
+ NotAfter: naf,
Identifiers: []acme.Identifier{
- {Type: "dns", Value: "example.com"},
- {Type: "dns", Value: "bar.com"},
+ {
+ Type: "dns",
+ Value: "example.com",
+ },
+ {
+ Type: "dns",
+ Value: "*.smallstep.com",
+ },
},
- Status: "valid",
- Authorizations: []string{"foo", "bar"},
- Certificate: "https://ca.smallstep.com/acme/certificate/certID",
+ ExpiresAt: naf,
+ Status: acme.StatusValid,
+ AuthorizationURLs: []string{
+ fmt.Sprintf("%s/acme/%s/authz/foo", baseURL.String(), escProvName),
+ fmt.Sprintf("%s/acme/%s/authz/bar", baseURL.String(), escProvName),
+ fmt.Sprintf("%s/acme/%s/authz/baz", baseURL.String(), escProvName),
+ },
+ FinalizeURL: fmt.Sprintf("%s/acme/%s/order/orderID/finalize", baseURL.String(), escProvName),
+ CertificateURL: fmt.Sprintf("%s/acme/%s/certificate/certID", baseURL.String(), escProvName),
}
+
+ // Request with chi context
+ chiCtx := chi.NewRouteContext()
+ chiCtx.URLParams.Add("ordID", o.ID)
+ url := fmt.Sprintf("%s/acme/%s/order/%s",
+ baseURL.String(), escProvName, o.ID)
+
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
assert.FatalError(t, err)
csr, ok := _csr.(*x509.CertificateRequest)
assert.Fatal(t, ok)
- // Request with chi context
- chiCtx := chi.NewRouteContext()
- chiCtx.URLParams.Add("ordID", o.ID)
- prov := newProv()
- provName := url.PathEscape(prov.GetName())
- baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- url := fmt.Sprintf("%s/acme/%s/order/%s/finalize",
- baseURL.String(), provName, o.ID)
+ nor := &FinalizeRequest{
+ CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
+ }
+ payloadBytes, err := json.Marshal(nor)
+ assert.FatalError(t, err)
type test struct {
- auth acme.Interface
+ db acme.DB
ctx context.Context
statusCode int
- problem *acme.Error
+ err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
- auth: &mockAcmeAuthority{},
- ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
+ ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, nil)
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, nil)
return test{
- auth: &mockAcmeAuthority{},
ctx: ctx,
statusCode: 400,
- problem: acme.AccountDoesNotExistErr(nil),
+ err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
+ }
+ },
+ "fail/no-provisioner": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
+ return test{
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("provisioner does not exist"),
+ }
+ },
+ "fail/nil-provisioner": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ return test{
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("provisioner does not exist"),
}
},
"fail/no-payload": func(t *testing.T) test {
- acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), accContextKey, acc)
+ ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("payload does not exist"),
}
},
"fail/nil-payload": func(t *testing.T) test {
- acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
- problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
+ err: acme.NewErrorISE("paylod does not exist"),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{})
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("failed to unmarshal finalize-order request payload: unexpected end of JSON input")),
+ err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"),
}
},
"fail/malformed-payload-error": func(t *testing.T) test {
@@ -615,72 +1420,121 @@ func TestHandlerFinalizeOrder(t *testing.T) {
fr := &FinalizeRequest{}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
- problem: acme.MalformedErr(errors.New("unable to parse csr: asn1: syntax error: sequence truncated")),
+ err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"),
}
},
- "fail/FinalizeOrder-error": func(t *testing.T) test {
- acc := &acme.Account{ID: "accID"}
- nor := &FinalizeRequest{
- CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
- }
- b, err := json.Marshal(nor)
- assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
+ "fail/db.GetOrder-error": func(t *testing.T) test {
+
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
- auth: &mockAcmeAuthority{
- finalizeOrder: func(ctx context.Context, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, accID, acc.ID)
- assert.Equals(t, id, o.ID)
- assert.Equals(t, incsr.Raw, csr.Raw)
- return nil, acme.MalformedErr(errors.New("force"))
+ db: &acme.MockDB{
+ MockError: acme.NewErrorISE("force"),
+ },
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("force"),
+ }
+ },
+ "fail/account-id-mismatch": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
+ return &acme.Order{AccountID: "foo"}, nil
},
},
ctx: ctx,
- statusCode: 400,
- problem: acme.MalformedErr(errors.New("force")),
+ statusCode: 401,
+ err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"),
+ }
+ },
+ "fail/provisioner-id-mismatch": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
+ return &acme.Order{AccountID: "accountID", ProvisionerID: "bar"}, nil
+ },
+ },
+ ctx: ctx,
+ statusCode: 401,
+ err: acme.NewError(acme.ErrorUnauthorizedType, "provisioner id mismatch"),
+ }
+ },
+ "fail/order-finalize-error": func(t *testing.T) test {
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
+ ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
+ return test{
+ db: &acme.MockDB{
+ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
+ return &acme.Order{
+ AccountID: "accountID",
+ ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()),
+ ExpiresAt: clock.Now().Add(-time.Hour),
+ Status: acme.StatusReady,
+ }, nil
+ },
+ MockUpdateOrder: func(ctx context.Context, o *acme.Order) error {
+ return acme.NewErrorISE("force")
+ },
+ },
+ ctx: ctx,
+ statusCode: 500,
+ err: acme.NewErrorISE("force"),
}
},
"ok": func(t *testing.T) test {
- acc := &acme.Account{ID: "accID"}
- nor := &FinalizeRequest{
- CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
- }
- b, err := json.Marshal(nor)
- assert.FatalError(t, err)
- ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, acme.AccContextKey, acc)
- ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
+ acc := &acme.Account{ID: "accountID"}
+ ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
+ ctx = context.WithValue(ctx, accContextKey, acc)
+ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
+ ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
- ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{
- auth: &mockAcmeAuthority{
- finalizeOrder: func(ctx context.Context, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
- p, err := acme.ProvisionerFromContext(ctx)
- assert.FatalError(t, err)
- assert.Equals(t, p, prov)
- assert.Equals(t, accID, acc.ID)
- assert.Equals(t, id, o.ID)
- assert.Equals(t, incsr.Raw, csr.Raw)
- return &o, nil
- },
- getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
- assert.Equals(t, typ, acme.OrderLink)
- assert.True(t, abs)
- assert.Equals(t, in, []string{o.ID})
- return fmt.Sprintf("%s/acme/%s/order/%s",
- baseURL.String(), provName, o.ID)
+ db: &acme.MockDB{
+ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
+ return &acme.Order{
+ ID: "orderID",
+ AccountID: "accountID",
+ ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()),
+ ExpiresAt: naf,
+ Status: acme.StatusValid,
+ AuthorizationIDs: []string{"foo", "bar", "baz"},
+ NotBefore: nbf,
+ NotAfter: naf,
+ Identifiers: []acme.Identifier{
+ {
+ Type: "dns",
+ Value: "example.com",
+ },
+ {
+ Type: "dns",
+ Value: "*.smallstep.com",
+ },
+ },
+ CertificateID: "certID",
+ }, nil
},
},
ctx: ctx,
@@ -691,7 +1545,7 @@ func TestHandlerFinalizeOrder(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
- h := New(tc.auth).(*Handler)
+ h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
@@ -704,23 +1558,24 @@ func TestHandlerFinalizeOrder(t *testing.T) {
res.Body.Close()
assert.FatalError(t, err)
- if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
- var ae acme.AError
+ if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
+ var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
- prob := tc.problem.ToACME()
- assert.Equals(t, ae.Type, prob.Type)
- assert.Equals(t, ae.Detail, prob.Detail)
- assert.Equals(t, ae.Identifier, prob.Identifier)
- assert.Equals(t, ae.Subproblems, prob.Subproblems)
+ assert.Equals(t, ae.Type, tc.err.Type)
+ assert.Equals(t, ae.Detail, tc.err.Detail)
+ assert.Equals(t, ae.Identifier, tc.err.Identifier)
+ assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(o)
assert.FatalError(t, err)
+
+ ro := new(acme.Order)
+ assert.FatalError(t, json.Unmarshal(body, ro))
+
assert.Equals(t, bytes.TrimSpace(body), expB)
- assert.Equals(t, res.Header["Location"],
- []string{fmt.Sprintf("%s/acme/%s/order/%s",
- baseURL, provName, o.ID)})
+ assert.Equals(t, res.Header["Location"], []string{url})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
diff --git a/acme/authority.go b/acme/authority.go
deleted file mode 100644
index 0f5f2c9f..00000000
--- a/acme/authority.go
+++ /dev/null
@@ -1,342 +0,0 @@
-package acme
-
-import (
- "context"
- "crypto"
- "crypto/tls"
- "crypto/x509"
- "encoding/base64"
- "net"
- "net/http"
- "net/url"
- "time"
-
- "github.com/pkg/errors"
- "github.com/smallstep/certificates/authority/provisioner"
- database "github.com/smallstep/certificates/db"
- "github.com/smallstep/nosql"
- "go.step.sm/crypto/jose"
-)
-
-// Interface is the acme authority interface.
-type Interface interface {
- GetDirectory(ctx context.Context) (*Directory, error)
- NewNonce() (string, error)
- UseNonce(string) error
-
- DeactivateAccount(ctx context.Context, accID string) (*Account, error)
- GetAccount(ctx context.Context, accID string) (*Account, error)
- GetAccountByKey(ctx context.Context, key *jose.JSONWebKey) (*Account, error)
- NewAccount(ctx context.Context, ao AccountOptions) (*Account, error)
- UpdateAccount(context.Context, string, []string) (*Account, error)
-
- GetAuthz(ctx context.Context, accID string, authzID string) (*Authz, error)
- ValidateChallenge(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*Challenge, error)
-
- FinalizeOrder(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*Order, error)
- GetOrder(ctx context.Context, accID string, orderID string) (*Order, error)
- GetOrdersByAccount(ctx context.Context, accID string) ([]string, error)
- NewOrder(ctx context.Context, oo OrderOptions) (*Order, error)
-
- GetCertificate(string, string) ([]byte, error)
-
- LoadProvisionerByID(string) (provisioner.Interface, error)
- GetLink(ctx context.Context, linkType Link, absoluteLink bool, inputs ...string) string
- GetLinkExplicit(linkType Link, provName string, absoluteLink bool, baseURL *url.URL, inputs ...string) string
-}
-
-// Authority is the layer that handles all ACME interactions.
-type Authority struct {
- backdate provisioner.Duration
- db nosql.DB
- dir *directory
- signAuth SignAuthority
-}
-
-// AuthorityOptions required to create a new ACME Authority.
-type AuthorityOptions struct {
- Backdate provisioner.Duration
- // DB is the database used by nosql.
- DB nosql.DB
- // DNS the host used to generate accurate ACME links. By default the authority
- // will use the Host from the request, so this value will only be used if
- // request.Host is empty.
- DNS string
- // Prefix is a URL path prefix under which the ACME api is served. This
- // prefix is required to generate accurate ACME links.
- // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
- // "acme" is the prefix from which the ACME api is accessed.
- Prefix string
-}
-
-var (
- accountTable = []byte("acme_accounts")
- accountByKeyIDTable = []byte("acme_keyID_accountID_index")
- authzTable = []byte("acme_authzs")
- challengeTable = []byte("acme_challenges")
- nonceTable = []byte("nonces")
- orderTable = []byte("acme_orders")
- ordersByAccountIDTable = []byte("acme_account_orders_index")
- certTable = []byte("acme_certs")
-)
-
-// NewAuthority returns a new Authority that implements the ACME interface.
-//
-// Deprecated: NewAuthority exists for hitorical compatibility and should not
-// be used. Use acme.New() instead.
-func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) {
- return New(signAuth, AuthorityOptions{
- DB: db,
- DNS: dns,
- Prefix: prefix,
- })
-}
-
-// New returns a new Autohrity that implements the ACME interface.
-func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) {
- if _, ok := ops.DB.(*database.SimpleDB); !ok {
- // If it's not a SimpleDB then go ahead and bootstrap the DB with the
- // necessary ACME tables. SimpleDB should ONLY be used for testing.
- tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
- challengeTable, nonceTable, orderTable, ordersByAccountIDTable,
- certTable}
- for _, b := range tables {
- if err := ops.DB.CreateTable(b); err != nil {
- return nil, errors.Wrapf(err, "error creating table %s",
- string(b))
- }
- }
- }
- return &Authority{
- backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth,
- }, nil
-}
-
-// GetLink returns the requested link from the directory.
-func (a *Authority) GetLink(ctx context.Context, typ Link, abs bool, inputs ...string) string {
- return a.dir.getLink(ctx, typ, abs, inputs...)
-}
-
-// GetLinkExplicit returns the requested link from the directory.
-func (a *Authority) GetLinkExplicit(typ Link, provName string, abs bool, baseURL *url.URL, inputs ...string) string {
- return a.dir.getLinkExplicit(typ, provName, abs, baseURL, inputs...)
-}
-
-// GetDirectory returns the ACME directory object.
-func (a *Authority) GetDirectory(ctx context.Context) (*Directory, error) {
- return &Directory{
- NewNonce: a.dir.getLink(ctx, NewNonceLink, true),
- NewAccount: a.dir.getLink(ctx, NewAccountLink, true),
- NewOrder: a.dir.getLink(ctx, NewOrderLink, true),
- RevokeCert: a.dir.getLink(ctx, RevokeCertLink, true),
- KeyChange: a.dir.getLink(ctx, KeyChangeLink, true),
- }, nil
-}
-
-// LoadProvisionerByID calls out to the SignAuthority interface to load a
-// provisioner by ID.
-func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
- return a.signAuth.LoadProvisionerByID(id)
-}
-
-// NewNonce generates, stores, and returns a new ACME nonce.
-func (a *Authority) NewNonce() (string, error) {
- n, err := newNonce(a.db)
- if err != nil {
- return "", err
- }
- return n.ID, nil
-}
-
-// UseNonce consumes the given nonce if it is valid, returns error otherwise.
-func (a *Authority) UseNonce(nonce string) error {
- return useNonce(a.db, nonce)
-}
-
-// NewAccount creates, stores, and returns a new ACME account.
-func (a *Authority) NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) {
- acc, err := newAccount(a.db, ao)
- if err != nil {
- return nil, err
- }
- return acc.toACME(ctx, a.db, a.dir)
-}
-
-// UpdateAccount updates an ACME account.
-func (a *Authority) UpdateAccount(ctx context.Context, id string, contact []string) (*Account, error) {
- acc, err := getAccountByID(a.db, id)
- if err != nil {
- return nil, ServerInternalErr(err)
- }
- if acc, err = acc.update(a.db, contact); err != nil {
- return nil, err
- }
- return acc.toACME(ctx, a.db, a.dir)
-}
-
-// GetAccount returns an ACME account.
-func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) {
- acc, err := getAccountByID(a.db, id)
- if err != nil {
- return nil, err
- }
- return acc.toACME(ctx, a.db, a.dir)
-}
-
-// DeactivateAccount deactivates an ACME account.
-func (a *Authority) DeactivateAccount(ctx context.Context, id string) (*Account, error) {
- acc, err := getAccountByID(a.db, id)
- if err != nil {
- return nil, err
- }
- if acc, err = acc.deactivate(a.db); err != nil {
- return nil, err
- }
- return acc.toACME(ctx, a.db, a.dir)
-}
-
-func keyToID(jwk *jose.JSONWebKey) (string, error) {
- kid, err := jwk.Thumbprint(crypto.SHA256)
- if err != nil {
- return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint"))
- }
- return base64.RawURLEncoding.EncodeToString(kid), nil
-}
-
-// GetAccountByKey returns the ACME associated with the jwk id.
-func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) {
- kid, err := keyToID(jwk)
- if err != nil {
- return nil, err
- }
- acc, err := getAccountByKeyID(a.db, kid)
- if err != nil {
- return nil, err
- }
- return acc.toACME(ctx, a.db, a.dir)
-}
-
-// GetOrder returns an ACME order.
-func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order, error) {
- o, err := getOrder(a.db, orderID)
- if err != nil {
- return nil, err
- }
- if accID != o.AccountID {
- return nil, UnauthorizedErr(errors.New("account does not own order"))
- }
- if o, err = o.updateStatus(a.db); err != nil {
- return nil, err
- }
- return o.toACME(ctx, a.db, a.dir)
-}
-
-// GetOrdersByAccount returns the list of order urls owned by the account.
-func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) {
- ordersByAccountMux.Lock()
- defer ordersByAccountMux.Unlock()
-
- var oiba = orderIDsByAccount{}
- oids, err := oiba.unsafeGetOrderIDsByAccount(a.db, id)
- if err != nil {
- return nil, err
- }
-
- var ret = []string{}
- for _, oid := range oids {
- ret = append(ret, a.dir.getLink(ctx, OrderLink, true, oid))
- }
- return ret, nil
-}
-
-// NewOrder generates, stores, and returns a new ACME order.
-func (a *Authority) NewOrder(ctx context.Context, ops OrderOptions) (*Order, error) {
- prov, err := ProvisionerFromContext(ctx)
- if err != nil {
- return nil, err
- }
- ops.backdate = a.backdate.Duration
- ops.defaultDuration = prov.DefaultTLSCertDuration()
- order, err := newOrder(a.db, ops)
- if err != nil {
- return nil, Wrap(err, "error creating order")
- }
- return order.toACME(ctx, a.db, a.dir)
-}
-
-// FinalizeOrder attempts to finalize an order and generate a new certificate.
-func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) {
- prov, err := ProvisionerFromContext(ctx)
- if err != nil {
- return nil, err
- }
- o, err := getOrder(a.db, orderID)
- if err != nil {
- return nil, err
- }
- if accID != o.AccountID {
- return nil, UnauthorizedErr(errors.New("account does not own order"))
- }
- o, err = o.finalize(a.db, csr, a.signAuth, prov)
- if err != nil {
- return nil, Wrap(err, "error finalizing order")
- }
- return o.toACME(ctx, a.db, a.dir)
-}
-
-// GetAuthz retrieves and attempts to update the status on an ACME authz
-// before returning.
-func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authz, error) {
- az, err := getAuthz(a.db, authzID)
- if err != nil {
- return nil, err
- }
- if accID != az.getAccountID() {
- return nil, UnauthorizedErr(errors.New("account does not own authz"))
- }
- az, err = az.updateStatus(a.db)
- if err != nil {
- return nil, Wrap(err, "error updating authz status")
- }
- return az.toACME(ctx, a.db, a.dir)
-}
-
-// ValidateChallenge attempts to validate the challenge.
-func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) {
- ch, err := getChallenge(a.db, chID)
- if err != nil {
- return nil, err
- }
- if accID != ch.getAccountID() {
- return nil, UnauthorizedErr(errors.New("account does not own challenge"))
- }
- client := http.Client{
- Timeout: time.Duration(30 * time.Second),
- }
- dialer := &net.Dialer{
- Timeout: 30 * time.Second,
- }
- ch, err = ch.validate(a.db, jwk, validateOptions{
- httpGet: client.Get,
- lookupTxt: net.LookupTXT,
- tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
- return tls.DialWithDialer(dialer, network, addr, config)
- },
- })
- if err != nil {
- return nil, Wrap(err, "error attempting challenge validation")
- }
- return ch.toACME(ctx, a.db, a.dir)
-}
-
-// GetCertificate retrieves the Certificate by ID.
-func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) {
- cert, err := getCert(a.db, certID)
- if err != nil {
- return nil, err
- }
- if accID != cert.AccountID {
- return nil, UnauthorizedErr(errors.New("account does not own certificate"))
- }
- return cert.toACME(a.db, a.dir)
-}
diff --git a/acme/authority_test.go b/acme/authority_test.go
deleted file mode 100644
index 8861c15e..00000000
--- a/acme/authority_test.go
+++ /dev/null
@@ -1,1739 +0,0 @@
-package acme
-
-import (
- "context"
- "crypto"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "net/http"
- "net/http/httptest"
- "net/url"
- "strings"
- "testing"
- "time"
-
- "github.com/pkg/errors"
- "github.com/smallstep/assert"
- "github.com/smallstep/certificates/db"
- "github.com/smallstep/nosql/database"
- "go.step.sm/crypto/jose"
-)
-
-func TestAuthorityGetLink(t *testing.T) {
- auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- prov := newProv()
- provName := url.PathEscape(prov.GetName())
- baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
- type test struct {
- auth *Authority
- typ Link
- abs bool
- inputs []string
- res string
- }
- tests := map[string]func(t *testing.T) test{
- "ok/new-account/abs": func(t *testing.T) test {
- return test{
- auth: auth,
- typ: NewAccountLink,
- abs: true,
- res: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
- }
- },
- "ok/new-account/no-abs": func(t *testing.T) test {
- return test{
- auth: auth,
- typ: NewAccountLink,
- abs: false,
- res: fmt.Sprintf("/%s/new-account", provName),
- }
- },
- "ok/order/abs": func(t *testing.T) test {
- return test{
- auth: auth,
- typ: OrderLink,
- abs: true,
- inputs: []string{"foo"},
- res: fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), provName),
- }
- },
- "ok/order/no-abs": func(t *testing.T) test {
- return test{
- auth: auth,
- typ: OrderLink,
- abs: false,
- inputs: []string{"foo"},
- res: fmt.Sprintf("/%s/order/foo", provName),
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- link := tc.auth.GetLink(ctx, tc.typ, tc.abs, tc.inputs...)
- assert.Equals(t, tc.res, link)
- })
- }
-}
-
-func TestAuthorityGetDirectory(t *testing.T) {
- auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
-
- prov := newProv()
- baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
-
- type test struct {
- ctx context.Context
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "ok/empty-provisioner": func(t *testing.T) test {
- return test{
- ctx: context.Background(),
- }
- },
- "ok/no-baseURL": func(t *testing.T) test {
- return test{
- ctx: context.WithValue(context.Background(), ProvisionerContextKey, prov),
- }
- },
- "ok/baseURL": func(t *testing.T) test {
- return test{
- ctx: ctx,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if dir, err := auth.GetDirectory(tc.ctx); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- bu := BaseURLFromContext(tc.ctx)
- if bu == nil {
- bu = &url.URL{Scheme: "https", Host: "ca.smallstep.com"}
- }
-
- var provName string
- prov, err := ProvisionerFromContext(tc.ctx)
- if err != nil {
- provName = ""
- } else {
- provName = url.PathEscape(prov.GetName())
- }
-
- assert.Equals(t, dir.NewNonce, fmt.Sprintf("%s/acme/%s/new-nonce", bu.String(), provName))
- assert.Equals(t, dir.NewAccount, fmt.Sprintf("%s/acme/%s/new-account", bu.String(), provName))
- assert.Equals(t, dir.NewOrder, fmt.Sprintf("%s/acme/%s/new-order", bu.String(), provName))
- assert.Equals(t, dir.RevokeCert, fmt.Sprintf("%s/acme/%s/revoke-cert", bu.String(), provName))
- assert.Equals(t, dir.KeyChange, fmt.Sprintf("%s/acme/%s/key-change", bu.String(), provName))
- }
- }
- })
- }
-}
-
-func TestAuthorityNewNonce(t *testing.T) {
- type test struct {
- auth *Authority
- res *string
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/newNonce-error": func(t *testing.T) test {
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- res: nil,
- err: ServerInternalErr(errors.New("error storing nonce: force")),
- }
- },
- "ok": func(t *testing.T) test {
- var _res string
- res := &_res
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- *res = string(key)
- return nil, true, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- res: res,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if nonce, err := tc.auth.NewNonce(); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, nonce, *tc.res)
- }
- }
- })
- }
-}
-
-func TestAuthorityUseNonce(t *testing.T) {
- type test struct {
- auth *Authority
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/newNonce-error": func(t *testing.T) test {
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MUpdate: func(tx *database.Tx) error {
- return errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- err: ServerInternalErr(errors.New("error deleting nonce foo: force")),
- }
- },
- "ok": func(t *testing.T) test {
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MUpdate: func(tx *database.Tx) error {
- return nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if err := tc.auth.UseNonce("foo"); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- assert.Nil(t, tc.err)
- }
- })
- }
-}
-
-func TestAuthorityNewAccount(t *testing.T) {
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
- ops := AccountOptions{
- Key: jwk, Contact: []string{"foo", "bar"},
- }
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
- type test struct {
- auth *Authority
- ops AccountOptions
- err *Error
- acc **Account
- }
- tests := map[string]func(t *testing.T) test{
- "fail/newAccount-error": func(t *testing.T) test {
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- ops: ops,
- err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
- }
- },
- "ok": func(t *testing.T) test {
- var (
- _acmeacc = &Account{}
- acmeacc = &_acmeacc
- count = 0
- dir = newDirectory("ca.smallstep.com", "acme")
- )
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 1 {
- var acc *account
- assert.FatalError(t, json.Unmarshal(newval, &acc))
- *acmeacc, err = acc.toACME(ctx, nil, dir)
- return nil, true, nil
- }
- count++
- return nil, true, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- ops: ops,
- acc: acmeacc,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeAcc, err := tc.auth.NewAccount(ctx, tc.ops); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeAcc)
- assert.FatalError(t, err)
- expb, err := json.Marshal(*tc.acc)
- assert.FatalError(t, err)
- assert.Equals(t, expb, gotb)
- }
- }
- })
- }
-}
-
-func TestAuthorityGetAccount(t *testing.T) {
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
- type test struct {
- auth *Authority
- id string
- err *Error
- acc *account
- }
- tests := map[string]func(t *testing.T) test{
- "fail/getAccount-error": func(t *testing.T) test {
- id := "foo"
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(id))
- return nil, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: id,
- err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)),
- }
- },
- "ok": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return b, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: acc.ID,
- acc: acc,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeAcc, err := tc.auth.GetAccount(ctx, tc.id); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeAcc)
- assert.FatalError(t, err)
-
- acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir)
- assert.FatalError(t, err)
- expb, err := json.Marshal(acmeExp)
- assert.FatalError(t, err)
-
- assert.Equals(t, expb, gotb)
- }
- }
- })
- }
-}
-
-func TestAuthorityGetAccountByKey(t *testing.T) {
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
- type test struct {
- auth *Authority
- jwk *jose.JSONWebKey
- err *Error
- acc *account
- }
- tests := map[string]func(t *testing.T) test{
- "fail/generate-thumbprint-error": func(t *testing.T) test {
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
- jwk.Key = "foo"
- auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- jwk: jwk,
- err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")),
- }
- },
- "fail/getAccount-error": func(t *testing.T) test {
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
- kid, err := keyToID(jwk)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, accountByKeyIDTable)
- assert.Equals(t, key, []byte(kid))
- return nil, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- jwk: jwk,
- err: ServerInternalErr(errors.New("error loading key-account index: force")),
- }
- },
- "ok": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
- count := 0
- kid, err := keyToID(acc.Key)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- var ret []byte
- switch {
- case count == 0:
- assert.Equals(t, bucket, accountByKeyIDTable)
- assert.Equals(t, key, []byte(kid))
- ret = []byte(acc.ID)
- case count == 1:
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- ret = b
- }
- count++
- return ret, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- jwk: acc.Key,
- acc: acc,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeAcc, err := tc.auth.GetAccountByKey(ctx, tc.jwk); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeAcc)
- assert.FatalError(t, err)
-
- acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir)
- assert.FatalError(t, err)
- expb, err := json.Marshal(acmeExp)
- assert.FatalError(t, err)
-
- assert.Equals(t, expb, gotb)
- }
- }
- })
- }
-}
-
-func TestAuthorityGetOrder(t *testing.T) {
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
- type test struct {
- auth *Authority
- id, accID string
- err *Error
- o *order
- }
- tests := map[string]func(t *testing.T) test{
- "fail/getOrder-error": func(t *testing.T) test {
- id := "foo"
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(id))
- return nil, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: id,
- err: ServerInternalErr(errors.New("error loading order foo: force")),
- }
- },
- "fail/order-not-owned-by-account": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(o.ID))
- return b, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: o.ID,
- accID: "foo",
- err: UnauthorizedErr(errors.New("account does not own order")),
- }
- },
- "fail/updateStatus-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
- i := 0
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- switch {
- case i == 0:
- i++
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(o.ID))
- return b, nil
- default:
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, key, []byte(o.Authorizations[0]))
- return nil, ServerInternalErr(errors.New("force"))
- }
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: o.ID,
- accID: o.AccountID,
- err: ServerInternalErr(errors.Errorf("error loading authz %s: force", o.Authorizations[0])),
- }
- },
- "ok": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = "valid"
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(o.ID))
- return b, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: o.ID,
- accID: o.AccountID,
- o: o,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeO, err := tc.auth.GetOrder(ctx, tc.accID, tc.id); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeO)
- assert.FatalError(t, err)
-
- acmeExp, err := tc.o.toACME(ctx, nil, tc.auth.dir)
- assert.FatalError(t, err)
- expb, err := json.Marshal(acmeExp)
- assert.FatalError(t, err)
-
- assert.Equals(t, expb, gotb)
- }
- }
- })
- }
-}
-
-func TestAuthorityGetCertificate(t *testing.T) {
- type test struct {
- auth *Authority
- id, accID string
- err *Error
- cert *certificate
- }
- tests := map[string]func(t *testing.T) test{
- "fail/getCertificate-error": func(t *testing.T) test {
- id := "foo"
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, certTable)
- assert.Equals(t, key, []byte(id))
- return nil, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: id,
- err: ServerInternalErr(errors.New("error loading certificate: force")),
- }
- },
- "fail/certificate-not-owned-by-account": func(t *testing.T) test {
- cert, err := newcert()
- assert.FatalError(t, err)
- b, err := json.Marshal(cert)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, certTable)
- assert.Equals(t, key, []byte(cert.ID))
- return b, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: cert.ID,
- accID: "foo",
- err: UnauthorizedErr(errors.New("account does not own certificate")),
- }
- },
- "ok": func(t *testing.T) test {
- cert, err := newcert()
- assert.FatalError(t, err)
- b, err := json.Marshal(cert)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, certTable)
- assert.Equals(t, key, []byte(cert.ID))
- return b, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: cert.ID,
- accID: cert.AccountID,
- cert: cert,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeCert, err := tc.auth.GetCertificate(tc.accID, tc.id); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeCert)
- assert.FatalError(t, err)
-
- acmeExp, err := tc.cert.toACME(nil, tc.auth.dir)
- assert.FatalError(t, err)
- expb, err := json.Marshal(acmeExp)
- assert.FatalError(t, err)
-
- assert.Equals(t, expb, gotb)
- }
- }
- })
- }
-}
-
-func TestAuthorityGetAuthz(t *testing.T) {
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
- type test struct {
- auth *Authority
- id, accID string
- err *Error
- acmeAz *Authz
- }
- tests := map[string]func(t *testing.T) test{
- "fail/getAuthz-error": func(t *testing.T) test {
- id := "foo"
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, key, []byte(id))
- return nil, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: id,
- err: ServerInternalErr(errors.Errorf("error loading authz %s: force", id)),
- }
- },
- "fail/authz-not-owned-by-account": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- b, err := json.Marshal(az)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, key, []byte(az.getID()))
- return b, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: az.getID(),
- accID: "foo",
- err: UnauthorizedErr(errors.New("account does not own authz")),
- }
- },
- "fail/update-status-error": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- b, err := json.Marshal(az)
- assert.FatalError(t, err)
- count := 0
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- var ret []byte
- switch count {
- case 0:
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, key, []byte(az.getID()))
- ret = b
- case 1:
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(az.getChallenges()[0]))
- return nil, errors.New("force")
- }
- count++
- return ret, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: az.getID(),
- accID: az.getAccountID(),
- err: ServerInternalErr(errors.New("error updating authz status: error loading challenge")),
- }
- },
- "ok": func(t *testing.T) test {
- var ch1B, ch2B, ch3B = &[]byte{}, &[]byte{}, &[]byte{}
- count := 0
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- switch count {
- case 0:
- *ch1B = newval
- case 1:
- *ch2B = newval
- case 2:
- *ch3B = newval
- }
- count++
- return nil, true, nil
- },
- }
- az, err := newAuthz(mockdb, "1234", Identifier{
- Type: "dns", Value: "acme.example.com",
- })
- assert.FatalError(t, err)
- _az, ok := az.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az.baseAuthz.Status = StatusValid
- b, err := json.Marshal(az)
- assert.FatalError(t, err)
-
- ch1, err := unmarshalChallenge(*ch1B)
- assert.FatalError(t, err)
- ch2, err := unmarshalChallenge(*ch2B)
- assert.FatalError(t, err)
- ch3, err := unmarshalChallenge(*ch3B)
- assert.FatalError(t, err)
- count = 0
- mockdb = &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- var ret []byte
- switch count {
- case 0:
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch1.getID()))
- ret = *ch1B
- case 1:
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch2.getID()))
- ret = *ch2B
- case 2:
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch3.getID()))
- ret = *ch3B
- }
- count++
- return ret, nil
- },
- }
- acmeAz, err := az.toACME(ctx, mockdb, newDirectory("ca.smallstep.com", "acme"))
- assert.FatalError(t, err)
-
- count = 0
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- var ret []byte
- switch count {
- case 0:
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, key, []byte(az.getID()))
- ret = b
- case 1:
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch1.getID()))
- ret = *ch1B
- case 2:
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch2.getID()))
- ret = *ch2B
- case 3:
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch3.getID()))
- ret = *ch3B
- }
- count++
- return ret, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: az.getID(),
- accID: az.getAccountID(),
- acmeAz: acmeAz,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeAz, err := tc.auth.GetAuthz(ctx, tc.accID, tc.id); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeAz)
- assert.FatalError(t, err)
-
- expb, err := json.Marshal(tc.acmeAz)
- assert.FatalError(t, err)
-
- assert.Equals(t, expb, gotb)
- }
- }
- })
- }
-}
-
-func TestAuthorityNewOrder(t *testing.T) {
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
- type test struct {
- auth *Authority
- ops OrderOptions
- ctx context.Context
- err *Error
- o **Order
- }
- tests := map[string]func(t *testing.T) test{
- "fail/no-provisioner": func(t *testing.T) test {
- auth, err := NewAuthority(&db.MockNoSQLDB{}, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- ops: defaultOrderOps(),
- ctx: context.Background(),
- err: ServerInternalErr(errors.New("provisioner expected in request context")),
- }
- },
- "fail/newOrder-error": func(t *testing.T) test {
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- ops: defaultOrderOps(),
- ctx: ctx,
- err: ServerInternalErr(errors.New("error creating order: error creating http challenge: error saving acme challenge: force")),
- }
- },
- "ok": func(t *testing.T) test {
- var (
- _acmeO = &Order{}
- acmeO = &_acmeO
- count = 0
- dir = newDirectory("ca.smallstep.com", "acme")
- err error
- _accID string
- accID = &_accID
- )
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- switch count {
- case 0:
- assert.Equals(t, bucket, challengeTable)
- case 1:
- assert.Equals(t, bucket, challengeTable)
- case 2:
- assert.Equals(t, bucket, challengeTable)
- case 3:
- assert.Equals(t, bucket, authzTable)
- case 4:
- assert.Equals(t, bucket, challengeTable)
- case 5:
- assert.Equals(t, bucket, challengeTable)
- case 6:
- assert.Equals(t, bucket, challengeTable)
- case 7:
- assert.Equals(t, bucket, authzTable)
- case 8:
- assert.Equals(t, bucket, orderTable)
- var o order
- assert.FatalError(t, json.Unmarshal(newval, &o))
- *acmeO, err = o.toACME(ctx, nil, dir)
- assert.FatalError(t, err)
- *accID = o.AccountID
- case 9:
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, string(key), *accID)
- }
- count++
- return nil, true, nil
- },
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- ops: defaultOrderOps(),
- ctx: ctx,
- o: acmeO,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeO, err := tc.auth.NewOrder(tc.ctx, tc.ops); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeO)
- assert.FatalError(t, err)
- expb, err := json.Marshal(*tc.o)
- assert.FatalError(t, err)
- assert.Equals(t, expb, gotb)
- }
- }
- })
- }
-}
-
-func TestAuthorityGetOrdersByAccount(t *testing.T) {
- prov := newProv()
- provName := url.PathEscape(prov.GetName())
- baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
- type test struct {
- auth *Authority
- id string
- err *Error
- res []string
- }
- tests := map[string]func(t *testing.T) test{
- "fail/getOrderIDsByAccount-error": func(t *testing.T) test {
- id := "foo"
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte(id))
- return nil, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: id,
- err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")),
- }
- },
- "fail/getOrder-error": func(t *testing.T) test {
- var (
- id = "zap"
- oids = []string{"foo", "bar"}
- count = 0
- err error
- )
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- var ret []byte
- switch count {
- case 0:
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte(id))
- ret, err = json.Marshal(oids)
- assert.FatalError(t, err)
- case 1:
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(oids[0]))
- return nil, errors.New("force")
- }
- count++
- return ret, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: id,
- err: ServerInternalErr(errors.New("error loading order foo for account zap: error loading order foo: force")),
- }
- },
- "ok": func(t *testing.T) test {
- accID := "zap"
-
- foo, err := newO()
- assert.FatalError(t, err)
- bfoo, err := json.Marshal(foo)
- assert.FatalError(t, err)
-
- bar, err := newO()
- assert.FatalError(t, err)
- bar.Status = StatusInvalid
- bbar, err := json.Marshal(bar)
- assert.FatalError(t, err)
-
- zap, err := newO()
- assert.FatalError(t, err)
- bzap, err := json.Marshal(zap)
- assert.FatalError(t, err)
-
- az, err := newAz()
- assert.FatalError(t, err)
- baz, err := json.Marshal(az)
- assert.FatalError(t, err)
-
- ch, err := newDNSCh()
- assert.FatalError(t, err)
- bch, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- dbGetOrder := 0
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- switch string(bucket) {
- case string(orderTable):
- dbGetOrder++
- switch dbGetOrder {
- case 1:
- return bfoo, nil
- case 2:
- return bbar, nil
- case 3:
- return bzap, nil
- }
- case string(ordersByAccountIDTable):
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte(accID))
- ret, err := json.Marshal([]string{foo.ID, bar.ID, zap.ID})
- assert.FatalError(t, err)
- return ret, nil
- case string(challengeTable):
- return bch, nil
- case string(authzTable):
- return baz, nil
- }
- return nil, errors.Errorf("should not be query db table %s", bucket)
- },
- MCmpAndSwap: func(bucket, key, old, newVal []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, string(key), accID)
- return nil, true, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: accID,
- res: []string{
- fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, foo.ID),
- fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, zap.ID),
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if orderLinks, err := tc.auth.GetOrdersByAccount(ctx, tc.id); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.res, orderLinks)
- }
- }
- })
- }
-}
-
-func TestAuthorityFinalizeOrder(t *testing.T) {
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
- type test struct {
- auth *Authority
- id, accID string
- ctx context.Context
- err *Error
- o *order
- }
- tests := map[string]func(t *testing.T) test{
- "fail/no-provisioner": func(t *testing.T) test {
- auth, err := NewAuthority(&db.MockNoSQLDB{}, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: "foo",
- ctx: context.Background(),
- err: ServerInternalErr(errors.New("provisioner expected in request context")),
- }
- },
- "fail/getOrder-error": func(t *testing.T) test {
- id := "foo"
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(id))
- return nil, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: id,
- ctx: ctx,
- err: ServerInternalErr(errors.New("error loading order foo: force")),
- }
- },
- "fail/order-not-owned-by-account": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(o.ID))
- return b, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: o.ID,
- accID: "foo",
- ctx: ctx,
- err: UnauthorizedErr(errors.New("account does not own order")),
- }
- },
- "fail/finalize-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Expires = time.Now().Add(-time.Minute)
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(o.ID))
- return b, nil
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(o.ID))
- return nil, false, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: o.ID,
- accID: o.AccountID,
- ctx: ctx,
- err: ServerInternalErr(errors.New("error finalizing order: error storing order: force")),
- }
- },
- "ok": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusValid
- o.Certificate = "certID"
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(o.ID))
- return b, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: o.ID,
- accID: o.AccountID,
- ctx: ctx,
- o: o,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeO, err := tc.auth.FinalizeOrder(tc.ctx, tc.accID, tc.id, nil); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeO)
- assert.FatalError(t, err)
-
- acmeExp, err := tc.o.toACME(ctx, nil, tc.auth.dir)
- assert.FatalError(t, err)
- expb, err := json.Marshal(acmeExp)
- assert.FatalError(t, err)
-
- assert.Equals(t, expb, gotb)
- }
- }
- })
- }
-}
-
-func TestAuthorityValidateChallenge(t *testing.T) {
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
-
- type test struct {
- auth *Authority
- id, accID string
- err *Error
- ch challenge
- jwk *jose.JSONWebKey
- server *httptest.Server
- }
- tests := map[string]func(t *testing.T) test{
- "fail/getChallenge-error": func(t *testing.T) test {
- id := "foo"
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(id))
- return nil, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: id,
- err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", id)),
- }
- },
- "fail/challenge-not-owned-by-account": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
- b, err := json.Marshal(ch)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- return b, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: ch.getID(),
- accID: "foo",
- err: UnauthorizedErr(errors.New("account does not own challenge")),
- }
- },
- "fail/validate-error": func(t *testing.T) test {
- keyauth := "temp"
- keyauthp := &keyauth
- // Create test server that returns challenge auth
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprintf(w, "%s\r\n", *keyauthp)
- }))
- t.Cleanup(func() { ts.Close() })
-
- ch, err := newHTTPChWithServer(strings.TrimPrefix(ts.URL, "http://"))
- assert.FatalError(t, err)
-
- jwk, _, err := jose.GenerateDefaultKeyPair([]byte("pass"))
- assert.FatalError(t, err)
-
- thumbprint, err := jwk.Thumbprint(crypto.SHA256)
- assert.FatalError(t, err)
- encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
- *keyauthp = fmt.Sprintf("%s.%s", ch.getToken(), encPrint)
-
- b, err := json.Marshal(ch)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- return b, nil
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- return nil, false, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: ch.getID(),
- accID: ch.getAccountID(),
- jwk: jwk,
- server: ts,
- err: ServerInternalErr(errors.New("error attempting challenge validation: error saving acme challenge: force")),
- }
- },
- "ok/already-valid": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*http01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Status = StatusValid
- _ch.baseChallenge.Validated = clock.Now()
- b, err := json.Marshal(ch)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- return b, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: ch.getID(),
- accID: ch.getAccountID(),
- ch: ch,
- }
- },
- "ok": func(t *testing.T) test {
- keyauth := "temp"
- keyauthp := &keyauth
- // Create test server that returns challenge auth
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprintf(w, "%s\r\n", *keyauthp)
- }))
- t.Cleanup(func() { ts.Close() })
-
- ch, err := newHTTPChWithServer(strings.TrimPrefix(ts.URL, "http://"))
- assert.FatalError(t, err)
-
- jwk, _, err := jose.GenerateDefaultKeyPair([]byte("pass"))
- assert.FatalError(t, err)
-
- thumbprint, err := jwk.Thumbprint(crypto.SHA256)
- assert.FatalError(t, err)
- encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
- *keyauthp = fmt.Sprintf("%s.%s", ch.getToken(), encPrint)
-
- b, err := json.Marshal(ch)
- assert.FatalError(t, err)
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- return b, nil
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- return nil, true, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: ch.getID(),
- accID: ch.getAccountID(),
- jwk: jwk,
- server: ts,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeCh, err := tc.auth.ValidateChallenge(ctx, tc.accID, tc.id, tc.jwk); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeCh)
- assert.FatalError(t, err)
-
- if tc.ch != nil {
- acmeExp, err := tc.ch.toACME(ctx, nil, tc.auth.dir)
- assert.FatalError(t, err)
- expb, err := json.Marshal(acmeExp)
- assert.FatalError(t, err)
-
- assert.Equals(t, expb, gotb)
- }
- }
- }
- })
- }
-}
-
-func TestAuthorityUpdateAccount(t *testing.T) {
- contact := []string{"baz", "zap"}
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
- type test struct {
- auth *Authority
- id string
- contact []string
- acc *account
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/getAccount-error": func(t *testing.T) test {
- id := "foo"
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(id))
- return nil, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: id,
- contact: contact,
- err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)),
- }
- },
- "fail/update-error": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
-
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return b, nil
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: acc.ID,
- contact: contact,
- err: ServerInternalErr(errors.New("error storing account: force")),
- }
- },
-
- "ok": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
-
- _acc := *acc
- clone := &_acc
- clone.Contact = contact
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return b, nil
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- return nil, true, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: acc.ID,
- contact: contact,
- acc: clone,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeAcc, err := tc.auth.UpdateAccount(ctx, tc.id, tc.contact); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeAcc)
- assert.FatalError(t, err)
-
- acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir)
- assert.FatalError(t, err)
- expb, err := json.Marshal(acmeExp)
- assert.FatalError(t, err)
-
- assert.Equals(t, expb, gotb)
- }
- }
- })
- }
-}
-
-func TestAuthorityDeactivateAccount(t *testing.T) {
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
- type test struct {
- auth *Authority
- id string
- acc *account
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/getAccount-error": func(t *testing.T) test {
- id := "foo"
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(id))
- return nil, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: id,
- err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)),
- }
- },
- "fail/deactivate-error": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
-
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return b, nil
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: acc.ID,
- err: ServerInternalErr(errors.New("error storing account: force")),
- }
- },
-
- "ok": func(t *testing.T) test {
- acc, err := newAcc()
- assert.FatalError(t, err)
- b, err := json.Marshal(acc)
- assert.FatalError(t, err)
-
- _acc := *acc
- clone := &_acc
- clone.Status = StatusDeactivated
- clone.Deactivated = clock.Now()
- auth, err := NewAuthority(&db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return b, nil
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, accountTable)
- assert.Equals(t, key, []byte(acc.ID))
- return nil, true, nil
- },
- }, "ca.smallstep.com", "acme", nil)
- assert.FatalError(t, err)
- return test{
- auth: auth,
- id: acc.ID,
- acc: clone,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if acmeAcc, err := tc.auth.DeactivateAccount(ctx, tc.id); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- gotb, err := json.Marshal(acmeAcc)
- assert.FatalError(t, err)
-
- acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir)
- assert.FatalError(t, err)
- expb, err := json.Marshal(acmeExp)
- assert.FatalError(t, err)
-
- assert.Equals(t, expb, gotb)
- }
- }
- })
- }
-}
diff --git a/acme/authorization.go b/acme/authorization.go
new file mode 100644
index 00000000..d2df5ea5
--- /dev/null
+++ b/acme/authorization.go
@@ -0,0 +1,69 @@
+package acme
+
+import (
+ "context"
+ "encoding/json"
+ "time"
+)
+
+// Authorization representst an ACME Authorization.
+type Authorization struct {
+ ID string `json:"-"`
+ AccountID string `json:"-"`
+ Token string `json:"-"`
+ Identifier Identifier `json:"identifier"`
+ Status Status `json:"status"`
+ Challenges []*Challenge `json:"challenges"`
+ Wildcard bool `json:"wildcard"`
+ ExpiresAt time.Time `json:"expires"`
+ Error *Error `json:"error,omitempty"`
+}
+
+// ToLog enables response logging.
+func (az *Authorization) ToLog() (interface{}, error) {
+ b, err := json.Marshal(az)
+ if err != nil {
+ return nil, WrapErrorISE(err, "error marshaling authz for logging")
+ }
+ return string(b), nil
+}
+
+// UpdateStatus updates the ACME Authorization Status if necessary.
+// Changes to the Authorization are saved using the database interface.
+func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
+ now := clock.Now()
+
+ switch az.Status {
+ case StatusInvalid:
+ return nil
+ case StatusValid:
+ return nil
+ case StatusPending:
+ // check expiry
+ if now.After(az.ExpiresAt) {
+ az.Status = StatusInvalid
+ break
+ }
+
+ var isValid = false
+ for _, ch := range az.Challenges {
+ if ch.Status == StatusValid {
+ isValid = true
+ break
+ }
+ }
+
+ if !isValid {
+ return nil
+ }
+ az.Status = StatusValid
+ az.Error = nil
+ default:
+ return NewErrorISE("unrecognized authorization status: %s", az.Status)
+ }
+
+ if err := db.UpdateAuthorization(ctx, az); err != nil {
+ return WrapErrorISE(err, "error updating authorization")
+ }
+ return nil
+}
diff --git a/acme/authorization_test.go b/acme/authorization_test.go
new file mode 100644
index 00000000..00b35b99
--- /dev/null
+++ b/acme/authorization_test.go
@@ -0,0 +1,150 @@
+package acme
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/assert"
+)
+
+func TestAuthorization_UpdateStatus(t *testing.T) {
+ type test struct {
+ az *Authorization
+ err *Error
+ db DB
+ }
+ tests := map[string]func(t *testing.T) test{
+ "ok/already-invalid": func(t *testing.T) test {
+ az := &Authorization{
+ Status: StatusInvalid,
+ }
+ return test{
+ az: az,
+ }
+ },
+ "ok/already-valid": func(t *testing.T) test {
+ az := &Authorization{
+ Status: StatusInvalid,
+ }
+ return test{
+ az: az,
+ }
+ },
+ "fail/error-unexpected-status": func(t *testing.T) test {
+ az := &Authorization{
+ Status: "foo",
+ }
+ return test{
+ az: az,
+ err: NewErrorISE("unrecognized authorization status: %s", az.Status),
+ }
+ },
+ "ok/expired": func(t *testing.T) test {
+ now := clock.Now()
+ az := &Authorization{
+ ID: "azID",
+ AccountID: "accID",
+ Status: StatusPending,
+ ExpiresAt: now.Add(-5 * time.Minute),
+ }
+ return test{
+ az: az,
+ db: &MockDB{
+ MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
+ assert.Equals(t, updaz.ID, az.ID)
+ assert.Equals(t, updaz.AccountID, az.AccountID)
+ assert.Equals(t, updaz.Status, StatusInvalid)
+ assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
+ return nil
+ },
+ },
+ }
+ },
+ "fail/db.UpdateAuthorization-error": func(t *testing.T) test {
+ now := clock.Now()
+ az := &Authorization{
+ ID: "azID",
+ AccountID: "accID",
+ Status: StatusPending,
+ ExpiresAt: now.Add(-5 * time.Minute),
+ }
+ return test{
+ az: az,
+ db: &MockDB{
+ MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
+ assert.Equals(t, updaz.ID, az.ID)
+ assert.Equals(t, updaz.AccountID, az.AccountID)
+ assert.Equals(t, updaz.Status, StatusInvalid)
+ assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
+ return errors.New("force")
+ },
+ },
+ err: NewErrorISE("error updating authorization: force"),
+ }
+ },
+ "ok/no-valid-challenges": func(t *testing.T) test {
+ now := clock.Now()
+ az := &Authorization{
+ ID: "azID",
+ AccountID: "accID",
+ Status: StatusPending,
+ ExpiresAt: now.Add(5 * time.Minute),
+ Challenges: []*Challenge{
+ {Status: StatusPending}, {Status: StatusPending}, {Status: StatusPending},
+ },
+ }
+ return test{
+ az: az,
+ }
+ },
+ "ok/valid": func(t *testing.T) test {
+ now := clock.Now()
+ az := &Authorization{
+ ID: "azID",
+ AccountID: "accID",
+ Status: StatusPending,
+ ExpiresAt: now.Add(5 * time.Minute),
+ Challenges: []*Challenge{
+ {Status: StatusPending}, {Status: StatusPending}, {Status: StatusValid},
+ },
+ }
+ return test{
+ az: az,
+ db: &MockDB{
+ MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
+ assert.Equals(t, updaz.ID, az.ID)
+ assert.Equals(t, updaz.AccountID, az.AccountID)
+ assert.Equals(t, updaz.Status, StatusValid)
+ assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
+ assert.Equals(t, updaz.Error, nil)
+ return nil
+ },
+ },
+ }
+ },
+ }
+ for name, run := range tests {
+ t.Run(name, func(t *testing.T) {
+ tc := run(t)
+ if err := tc.az.UpdateStatus(context.Background(), tc.db); err != nil {
+ if assert.NotNil(t, tc.err) {
+ switch k := err.(type) {
+ case *Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
+ }
+ } else {
+ assert.Nil(t, tc.err)
+ }
+ })
+
+ }
+}
diff --git a/acme/authz.go b/acme/authz.go
deleted file mode 100644
index 8c45bce0..00000000
--- a/acme/authz.go
+++ /dev/null
@@ -1,347 +0,0 @@
-package acme
-
-import (
- "context"
- "encoding/json"
- "strings"
- "time"
-
- "github.com/pkg/errors"
- "github.com/smallstep/nosql"
-)
-
-var defaultExpiryDuration = time.Hour * 24
-
-// Authz is a subset of the Authz type containing only those attributes
-// required for responses in the ACME protocol.
-type Authz struct {
- Identifier Identifier `json:"identifier"`
- Status string `json:"status"`
- Expires string `json:"expires"`
- Challenges []*Challenge `json:"challenges"`
- Wildcard bool `json:"wildcard"`
- ID string `json:"-"`
-}
-
-// ToLog enables response logging.
-func (a *Authz) ToLog() (interface{}, error) {
- b, err := json.Marshal(a)
- if err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging"))
- }
- return string(b), nil
-}
-
-// GetID returns the Authz ID.
-func (a *Authz) GetID() string {
- return a.ID
-}
-
-// authz is the interface that the various authz types must implement.
-type authz interface {
- save(nosql.DB, authz) error
- clone() *baseAuthz
- getID() string
- getAccountID() string
- getType() string
- getIdentifier() Identifier
- getStatus() string
- getExpiry() time.Time
- getWildcard() bool
- getChallenges() []string
- getCreated() time.Time
- updateStatus(db nosql.DB) (authz, error)
- toACME(context.Context, nosql.DB, *directory) (*Authz, error)
-}
-
-// baseAuthz is the base authz type that others build from.
-type baseAuthz struct {
- ID string `json:"id"`
- AccountID string `json:"accountID"`
- Identifier Identifier `json:"identifier"`
- Status string `json:"status"`
- Expires time.Time `json:"expires"`
- Challenges []string `json:"challenges"`
- Wildcard bool `json:"wildcard"`
- Created time.Time `json:"created"`
- Error *Error `json:"error"`
-}
-
-func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) {
- id, err := randID()
- if err != nil {
- return nil, err
- }
-
- now := clock.Now()
- ba := &baseAuthz{
- ID: id,
- AccountID: accID,
- Status: StatusPending,
- Created: now,
- Expires: now.Add(defaultExpiryDuration),
- Identifier: identifier,
- }
-
- if strings.HasPrefix(identifier.Value, "*.") {
- ba.Wildcard = true
- ba.Identifier = Identifier{
- Value: strings.TrimPrefix(identifier.Value, "*."),
- Type: identifier.Type,
- }
- }
-
- return ba, nil
-}
-
-// getID returns the ID of the authz.
-func (ba *baseAuthz) getID() string {
- return ba.ID
-}
-
-// getAccountID returns the Account ID that created the authz.
-func (ba *baseAuthz) getAccountID() string {
- return ba.AccountID
-}
-
-// getType returns the type of the authz.
-func (ba *baseAuthz) getType() string {
- return ba.Identifier.Type
-}
-
-// getIdentifier returns the identifier for the authz.
-func (ba *baseAuthz) getIdentifier() Identifier {
- return ba.Identifier
-}
-
-// getStatus returns the status of the authz.
-func (ba *baseAuthz) getStatus() string {
- return ba.Status
-}
-
-// getWildcard returns true if the authz identifier has a '*', false otherwise.
-func (ba *baseAuthz) getWildcard() bool {
- return ba.Wildcard
-}
-
-// getChallenges returns the authz challenge IDs.
-func (ba *baseAuthz) getChallenges() []string {
- return ba.Challenges
-}
-
-// getExpiry returns the expiration time of the authz.
-func (ba *baseAuthz) getExpiry() time.Time {
- return ba.Expires
-}
-
-// getCreated returns the created time of the authz.
-func (ba *baseAuthz) getCreated() time.Time {
- return ba.Created
-}
-
-// toACME converts the internal Authz type into the public acmeAuthz type for
-// presentation in the ACME protocol.
-func (ba *baseAuthz) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Authz, error) {
- var chs = make([]*Challenge, len(ba.Challenges))
- for i, chID := range ba.Challenges {
- ch, err := getChallenge(db, chID)
- if err != nil {
- return nil, err
- }
- chs[i], err = ch.toACME(ctx, db, dir)
- if err != nil {
- return nil, err
- }
- }
- return &Authz{
- Identifier: ba.Identifier,
- Status: ba.getStatus(),
- Challenges: chs,
- Wildcard: ba.getWildcard(),
- Expires: ba.Expires.Format(time.RFC3339),
- ID: ba.ID,
- }, nil
-}
-
-func (ba *baseAuthz) save(db nosql.DB, old authz) error {
- var (
- err error
- oldB, newB []byte
- )
- if old == nil {
- oldB = nil
- } else {
- if oldB, err = json.Marshal(old); err != nil {
- return ServerInternalErr(errors.Wrap(err, "error marshaling old authz"))
- }
- }
- if newB, err = json.Marshal(ba); err != nil {
- return ServerInternalErr(errors.Wrap(err, "error marshaling new authz"))
- }
- _, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB)
- switch {
- case err != nil:
- return ServerInternalErr(errors.Wrapf(err, "error storing authz"))
- case !swapped:
- return ServerInternalErr(errors.Errorf("error storing authz; " +
- "value has changed since last read"))
- default:
- return nil
- }
-}
-
-func (ba *baseAuthz) clone() *baseAuthz {
- u := *ba
- return &u
-}
-
-func (ba *baseAuthz) parent() authz {
- return &dnsAuthz{ba}
-}
-
-// updateStatus attempts to update the status on a baseAuthz and stores the
-// updating object if necessary.
-func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) {
- newAuthz := ba.clone()
-
- now := time.Now().UTC()
- switch ba.Status {
- case StatusInvalid:
- return ba.parent(), nil
- case StatusValid:
- return ba.parent(), nil
- case StatusPending:
- // check expiry
- if now.After(ba.Expires) {
- newAuthz.Status = StatusInvalid
- newAuthz.Error = MalformedErr(errors.New("authz has expired"))
- break
- }
-
- var isValid = false
- for _, chID := range ba.Challenges {
- ch, err := getChallenge(db, chID)
- if err != nil {
- return ba, err
- }
- if ch.getStatus() == StatusValid {
- isValid = true
- break
- }
- }
-
- if !isValid {
- return ba.parent(), nil
- }
- newAuthz.Status = StatusValid
- newAuthz.Error = nil
- default:
- return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status))
- }
-
- if err := newAuthz.save(db, ba); err != nil {
- return ba, err
- }
- return newAuthz.parent(), nil
-}
-
-// unmarshalAuthz unmarshals an authz type into the correct sub-type.
-func unmarshalAuthz(data []byte) (authz, error) {
- var getType struct {
- Identifier Identifier `json:"identifier"`
- }
- if err := json.Unmarshal(data, &getType); err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type"))
- }
-
- switch getType.Identifier.Type {
- case "dns":
- var ba baseAuthz
- if err := json.Unmarshal(data, &ba); err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz"))
- }
- return &dnsAuthz{&ba}, nil
- default:
- return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s",
- getType.Identifier.Type))
- }
-}
-
-// dnsAuthz represents a dns acme authorization.
-type dnsAuthz struct {
- *baseAuthz
-}
-
-// newAuthz returns a new acme authorization object based on the identifier
-// type.
-func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) {
- switch identifier.Type {
- case "dns":
- a, err = newDNSAuthz(db, accID, identifier)
- default:
- err = MalformedErr(errors.Errorf("unexpected authz type %s",
- identifier.Type))
- }
- return
-}
-
-// newDNSAuthz returns a new dns acme authorization object.
-func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) {
- ba, err := newBaseAuthz(accID, identifier)
- if err != nil {
- return nil, err
- }
-
- ba.Challenges = []string{}
- if !ba.Wildcard {
- // http and alpn challenges are only permitted if the DNS is not a wildcard dns.
- ch1, err := newHTTP01Challenge(db, ChallengeOptions{
- AccountID: accID,
- AuthzID: ba.ID,
- Identifier: ba.Identifier})
- if err != nil {
- return nil, Wrap(err, "error creating http challenge")
- }
- ba.Challenges = append(ba.Challenges, ch1.getID())
-
- ch2, err := newTLSALPN01Challenge(db, ChallengeOptions{
- AccountID: accID,
- AuthzID: ba.ID,
- Identifier: ba.Identifier,
- })
- if err != nil {
- return nil, Wrap(err, "error creating alpn challenge")
- }
- ba.Challenges = append(ba.Challenges, ch2.getID())
- }
- ch3, err := newDNS01Challenge(db, ChallengeOptions{
- AccountID: accID,
- AuthzID: ba.ID,
- Identifier: identifier})
- if err != nil {
- return nil, Wrap(err, "error creating dns challenge")
- }
- ba.Challenges = append(ba.Challenges, ch3.getID())
-
- da := &dnsAuthz{ba}
- if err := da.save(db, nil); err != nil {
- return nil, err
- }
-
- return da, nil
-}
-
-// getAuthz retrieves and unmarshals an ACME authz type from the database.
-func getAuthz(db nosql.DB, id string) (authz, error) {
- b, err := db.Get(authzTable, []byte(id))
- if nosql.IsErrNotFound(err) {
- return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id))
- } else if err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id))
- }
- az, err := unmarshalAuthz(b)
- if err != nil {
- return nil, err
- }
- return az, nil
-}
diff --git a/acme/authz_test.go b/acme/authz_test.go
deleted file mode 100644
index 31e6bb58..00000000
--- a/acme/authz_test.go
+++ /dev/null
@@ -1,836 +0,0 @@
-package acme
-
-import (
- "context"
- "encoding/json"
- "strings"
- "testing"
- "time"
-
- "github.com/pkg/errors"
- "github.com/smallstep/assert"
- "github.com/smallstep/certificates/db"
- "github.com/smallstep/nosql"
- "github.com/smallstep/nosql/database"
-)
-
-func newAz() (authz, error) {
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), true, nil
- },
- }
- return newAuthz(mockdb, "1234", Identifier{
- Type: "dns", Value: "acme.example.com",
- })
-}
-
-func TestGetAuthz(t *testing.T) {
- type test struct {
- id string
- db nosql.DB
- az authz
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/not-found": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- return test{
- az: az,
- id: az.getID(),
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
- },
- },
- err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())),
- }
- },
- "fail/db-error": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- return test{
- az: az,
- id: az.getID(),
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())),
- }
- },
- "fail/unmarshal-error": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- _az, ok := az.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az.baseAuthz.Identifier.Type = "foo"
- b, err := json.Marshal(az)
- assert.FatalError(t, err)
- return test{
- az: az,
- id: az.getID(),
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, key, []byte(az.getID()))
- return b, nil
- },
- },
- err: ServerInternalErr(errors.New("unexpected authz type foo")),
- }
- },
- "ok": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- b, err := json.Marshal(az)
- assert.FatalError(t, err)
- return test{
- az: az,
- id: az.getID(),
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, key, []byte(az.getID()))
- return b, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if az, err := getAuthz(tc.db, tc.id); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.az.getID(), az.getID())
- assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
- assert.Equals(t, tc.az.getStatus(), az.getStatus())
- assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier())
- assert.Equals(t, tc.az.getCreated(), az.getCreated())
- assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
- assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
- }
- }
- })
- }
-}
-
-func TestAuthzClone(t *testing.T) {
- az, err := newAz()
- assert.FatalError(t, err)
-
- clone := az.clone()
-
- assert.Equals(t, clone.getID(), az.getID())
- assert.Equals(t, clone.getAccountID(), az.getAccountID())
- assert.Equals(t, clone.getStatus(), az.getStatus())
- assert.Equals(t, clone.getIdentifier(), az.getIdentifier())
- assert.Equals(t, clone.getExpiry(), az.getExpiry())
- assert.Equals(t, clone.getCreated(), az.getCreated())
- assert.Equals(t, clone.getChallenges(), az.getChallenges())
-
- clone.Status = StatusValid
-
- assert.NotEquals(t, clone.getStatus(), az.getStatus())
-}
-
-func TestNewAuthz(t *testing.T) {
- iden := Identifier{
- Type: "dns", Value: "acme.example.com",
- }
- accID := "1234"
- type test struct {
- iden Identifier
- db nosql.DB
- err *Error
- resChs *([]string)
- }
- tests := map[string]func(t *testing.T) test{
- "fail/unexpected-type": func(t *testing.T) test {
- return test{
- iden: Identifier{Type: "foo", Value: "acme.example.com"},
- err: MalformedErr(errors.New("unexpected authz type foo")),
- }
- },
- "fail/new-http-chall-error": func(t *testing.T) test {
- return test{
- iden: iden,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")),
- }
- },
- "fail/new-tls-alpn-chall-error": func(t *testing.T) test {
- count := 0
- return test{
- iden: iden,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 1 {
- return nil, false, errors.New("force")
- }
- count++
- return nil, true, nil
- },
- },
- err: ServerInternalErr(errors.New("error creating alpn challenge: error saving acme challenge: force")),
- }
- },
- "fail/new-dns-chall-error": func(t *testing.T) test {
- count := 0
- return test{
- iden: iden,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 2 {
- return nil, false, errors.New("force")
- }
- count++
- return nil, true, nil
- },
- },
- err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")),
- }
- },
- "fail/save-authz-error": func(t *testing.T) test {
- count := 0
- return test{
- iden: iden,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 3 {
- return nil, false, errors.New("force")
- }
- count++
- return nil, true, nil
- },
- },
- err: ServerInternalErr(errors.New("error storing authz: force")),
- }
- },
- "ok": func(t *testing.T) test {
- chs := &([]string{})
- count := 0
- return test{
- iden: iden,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 3 {
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, old, nil)
-
- az, err := unmarshalAuthz(newval)
- assert.FatalError(t, err)
-
- assert.Equals(t, az.getID(), string(key))
- assert.Equals(t, az.getAccountID(), accID)
- assert.Equals(t, az.getStatus(), StatusPending)
- assert.Equals(t, az.getIdentifier(), iden)
- assert.Equals(t, az.getWildcard(), false)
-
- *chs = az.getChallenges()
-
- assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
- assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
-
- expiry := az.getCreated().Add(defaultExpiryDuration)
- assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
- assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
- }
- count++
- return nil, true, nil
- },
- },
- resChs: chs,
- }
- },
- "ok/wildcard": func(t *testing.T) test {
- chs := &([]string{})
- count := 0
- _iden := Identifier{Type: "dns", Value: "*.acme.example.com"}
- return test{
- iden: _iden,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 1 {
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, old, nil)
-
- az, err := unmarshalAuthz(newval)
- assert.FatalError(t, err)
-
- assert.Equals(t, az.getID(), string(key))
- assert.Equals(t, az.getAccountID(), accID)
- assert.Equals(t, az.getStatus(), StatusPending)
- assert.Equals(t, az.getIdentifier(), iden)
- assert.Equals(t, az.getWildcard(), true)
-
- *chs = az.getChallenges()
- // Verify that we only have 1 challenge instead of 2.
- assert.True(t, len(*chs) == 1)
-
- assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
- assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
-
- expiry := az.getCreated().Add(defaultExpiryDuration)
- assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
- assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
- }
- count++
- return nil, true, nil
- },
- },
- resChs: chs,
- }
- },
- }
- for name, run := range tests {
- tc := run(t)
- t.Run(name, func(t *testing.T) {
- az, err := newAuthz(tc.db, accID, tc.iden)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, az.getAccountID(), accID)
- assert.Equals(t, az.getType(), "dns")
- assert.Equals(t, az.getStatus(), StatusPending)
-
- assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
- assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
-
- expiry := az.getCreated().Add(defaultExpiryDuration)
- assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
- assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
-
- assert.Equals(t, az.getChallenges(), *(tc.resChs))
-
- if strings.HasPrefix(tc.iden.Value, "*.") {
- assert.True(t, az.getWildcard())
- assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*."))
- } else {
- assert.False(t, az.getWildcard())
- assert.Equals(t, az.getIdentifier().Value, tc.iden.Value)
- }
-
- assert.True(t, az.getID() != "")
- }
- }
- })
- }
-}
-
-func TestAuthzToACME(t *testing.T) {
- dir := newDirectory("ca.smallstep.com", "acme")
-
- var (
- ch1, ch2 challenge
- ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
- err error
- )
-
- count := 0
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 0 {
- *ch1Bytes = newval
- ch1, err = unmarshalChallenge(newval)
- assert.FatalError(t, err)
- } else if count == 1 {
- *ch2Bytes = newval
- ch2, err = unmarshalChallenge(newval)
- assert.FatalError(t, err)
- }
- count++
- return []byte("foo"), true, nil
- },
- }
- iden := Identifier{
- Type: "dns", Value: "acme.example.com",
- }
- az, err := newAuthz(mockdb, "1234", iden)
- assert.FatalError(t, err)
-
- prov := newProv()
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
-
- type test struct {
- db nosql.DB
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/getChallenge1-error": func(t *testing.T) test {
- return test{
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error loading challenge")),
- }
- },
- "fail/getChallenge2-error": func(t *testing.T) test {
- count := 0
- return test{
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- if count == 1 {
- return nil, errors.New("force")
- }
- count++
- return *ch1Bytes, nil
- },
- },
- err: ServerInternalErr(errors.New("error loading challenge")),
- }
- },
- "ok": func(t *testing.T) test {
- count := 0
- return test{
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- if count == 0 {
- count++
- return *ch1Bytes, nil
- }
- return *ch2Bytes, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- tc := run(t)
- t.Run(name, func(t *testing.T) {
- acmeAz, err := az.toACME(ctx, tc.db, dir)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, acmeAz.ID, az.getID())
- assert.Equals(t, acmeAz.Identifier, iden)
- assert.Equals(t, acmeAz.Status, StatusPending)
-
- acmeCh1, err := ch1.toACME(ctx, nil, dir)
- assert.FatalError(t, err)
- acmeCh2, err := ch2.toACME(ctx, nil, dir)
- assert.FatalError(t, err)
-
- assert.Equals(t, acmeAz.Challenges[0], acmeCh1)
- assert.Equals(t, acmeAz.Challenges[1], acmeCh2)
-
- expiry, err := time.Parse(time.RFC3339, acmeAz.Expires)
- assert.FatalError(t, err)
- assert.Equals(t, expiry.String(), az.getExpiry().String())
- }
- }
- })
- }
-}
-
-func TestAuthzSave(t *testing.T) {
- type test struct {
- az, old authz
- db nosql.DB
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/old-nil/swap-error": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- return test{
- az: az,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error storing authz: force")),
- }
- },
- "fail/old-nil/swap-false": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- return test{
- az: az,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), false, nil
- },
- },
- err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")),
- }
- },
- "ok/old-nil": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- b, err := json.Marshal(az)
- assert.FatalError(t, err)
- return test{
- az: az,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, nil)
- assert.Equals(t, b, newval)
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, []byte(az.getID()), key)
- return nil, true, nil
- },
- },
- }
- },
- "ok/old-not-nil": func(t *testing.T) test {
- oldAz, err := newAz()
- assert.FatalError(t, err)
- az, err := newAz()
- assert.FatalError(t, err)
-
- oldb, err := json.Marshal(oldAz)
- assert.FatalError(t, err)
- b, err := json.Marshal(az)
- assert.FatalError(t, err)
- return test{
- az: az,
- old: oldAz,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, oldb)
- assert.Equals(t, b, newval)
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, []byte(az.getID()), key)
- return []byte("foo"), true, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if err := tc.az.save(tc.db, tc.old); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- assert.Nil(t, tc.err)
- }
- })
- }
-}
-
-func TestAuthzUnmarshal(t *testing.T) {
- type test struct {
- az authz
- azb []byte
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/nil": func(t *testing.T) test {
- return test{
- azb: nil,
- err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")),
- }
- },
- "fail/unexpected-type": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- _az, ok := az.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az.baseAuthz.Identifier.Type = "foo"
- b, err := json.Marshal(az)
- assert.FatalError(t, err)
- return test{
- azb: b,
- err: ServerInternalErr(errors.New("unexpected authz type foo")),
- }
- },
- "ok/dns": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- b, err := json.Marshal(az)
- assert.FatalError(t, err)
- return test{
- az: az,
- azb: b,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if az, err := unmarshalAuthz(tc.azb); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.az.getID(), az.getID())
- assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
- assert.Equals(t, tc.az.getStatus(), az.getStatus())
- assert.Equals(t, tc.az.getCreated(), az.getCreated())
- assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
- assert.Equals(t, tc.az.getWildcard(), az.getWildcard())
- assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
- }
- }
- })
- }
-}
-
-func TestAuthzUpdateStatus(t *testing.T) {
- type test struct {
- az, res authz
- err *Error
- db nosql.DB
- }
- tests := map[string]func(t *testing.T) test{
- "fail/already-invalid": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- _az, ok := az.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az.baseAuthz.Status = StatusInvalid
- return test{
- az: az,
- res: az,
- }
- },
- "fail/already-valid": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- _az, ok := az.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az.baseAuthz.Status = StatusValid
- return test{
- az: az,
- res: az,
- }
- },
- "fail/unexpected-status": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- _az, ok := az.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az.baseAuthz.Status = StatusReady
- return test{
- az: az,
- res: az,
- err: ServerInternalErr(errors.New("unrecognized authz status: ready")),
- }
- },
- "fail/save-error": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- _az, ok := az.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
- return test{
- az: az,
- res: az,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error storing authz: force")),
- }
- },
- "ok/expired": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
- _az, ok := az.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
-
- clone := az.clone()
- clone.Error = MalformedErr(errors.New("authz has expired"))
- clone.Status = StatusInvalid
- return test{
- az: az,
- res: clone.parent(),
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, true, nil
- },
- },
- }
- },
- "fail/get-challenge-error": func(t *testing.T) test {
- az, err := newAz()
- assert.FatalError(t, err)
-
- return test{
- az: az,
- res: az,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error loading challenge")),
- }
- },
- "ok/valid": func(t *testing.T) test {
- var (
- ch3 challenge
- ch2Bytes = &([]byte{})
- ch1Bytes = &([]byte{})
- err error
- )
-
- count := 0
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 0 {
- *ch1Bytes = newval
- } else if count == 1 {
- *ch2Bytes = newval
- } else if count == 2 {
- ch3, err = unmarshalChallenge(newval)
- assert.FatalError(t, err)
- }
- count++
- return nil, true, nil
- },
- }
- iden := Identifier{
- Type: "dns", Value: "acme.example.com",
- }
- az, err := newAuthz(mockdb, "1234", iden)
- assert.FatalError(t, err)
- _az, ok := az.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az.baseAuthz.Error = MalformedErr(nil)
-
- _ch, ok := ch3.(*dns01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Status = StatusValid
- chb, err := json.Marshal(ch3)
-
- clone := az.clone()
- clone.Status = StatusValid
- clone.Error = nil
-
- count = 0
- return test{
- az: az,
- res: clone.parent(),
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- if count == 0 {
- count++
- return *ch1Bytes, nil
- }
- if count == 1 {
- count++
- return *ch2Bytes, nil
- }
- count++
- return chb, nil
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, true, nil
- },
- },
- }
- },
- "ok/still-pending": func(t *testing.T) test {
- var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
-
- count := 0
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 0 {
- *ch1Bytes = newval
- } else if count == 1 {
- *ch2Bytes = newval
- }
- count++
- return nil, true, nil
- },
- }
- iden := Identifier{
- Type: "dns", Value: "acme.example.com",
- }
- az, err := newAuthz(mockdb, "1234", iden)
- assert.FatalError(t, err)
-
- count = 0
- return test{
- az: az,
- res: az,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- if count == 0 {
- count++
- return *ch1Bytes, nil
- }
- count++
- return *ch2Bytes, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- az, err := tc.az.updateStatus(tc.db)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- expB, err := json.Marshal(tc.res)
- assert.FatalError(t, err)
- b, err := json.Marshal(az)
- assert.FatalError(t, err)
- assert.Equals(t, expB, b)
- }
- }
- })
- }
-}
diff --git a/acme/certificate.go b/acme/certificate.go
index 6a31c880..d46d1a08 100644
--- a/acme/certificate.go
+++ b/acme/certificate.go
@@ -2,88 +2,13 @@ package acme
import (
"crypto/x509"
- "encoding/json"
- "encoding/pem"
- "time"
-
- "github.com/pkg/errors"
- "github.com/smallstep/nosql"
)
-type certificate struct {
- ID string `json:"id"`
- Created time.Time `json:"created"`
- AccountID string `json:"accountID"`
- OrderID string `json:"orderID"`
- Leaf []byte `json:"leaf"`
- Intermediates []byte `json:"intermediates"`
-}
-
-// CertOptions options with which to create and store a cert object.
-type CertOptions struct {
+// Certificate options with which to create and store a cert object.
+type Certificate struct {
+ ID string
AccountID string
OrderID string
Leaf *x509.Certificate
Intermediates []*x509.Certificate
}
-
-func newCert(db nosql.DB, ops CertOptions) (*certificate, error) {
- id, err := randID()
- if err != nil {
- return nil, err
- }
-
- leaf := pem.EncodeToMemory(&pem.Block{
- Type: "CERTIFICATE",
- Bytes: ops.Leaf.Raw,
- })
- var intermediates []byte
- for _, cert := range ops.Intermediates {
- intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
- Type: "CERTIFICATE",
- Bytes: cert.Raw,
- })...)
- }
-
- cert := &certificate{
- ID: id,
- AccountID: ops.AccountID,
- OrderID: ops.OrderID,
- Leaf: leaf,
- Intermediates: intermediates,
- Created: time.Now().UTC(),
- }
- certB, err := json.Marshal(cert)
- if err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error marshaling certificate"))
- }
-
- _, swapped, err := db.CmpAndSwap(certTable, []byte(id), nil, certB)
- switch {
- case err != nil:
- return nil, ServerInternalErr(errors.Wrap(err, "error storing certificate"))
- case !swapped:
- return nil, ServerInternalErr(errors.New("error storing certificate; " +
- "value has changed since last read"))
- default:
- return cert, nil
- }
-}
-
-func (c *certificate) toACME(db nosql.DB, dir *directory) ([]byte, error) {
- return append(c.Leaf, c.Intermediates...), nil
-}
-
-func getCert(db nosql.DB, id string) (*certificate, error) {
- b, err := db.Get(certTable, []byte(id))
- if nosql.IsErrNotFound(err) {
- return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id))
- } else if err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate"))
- }
- var cert certificate
- if err := json.Unmarshal(b, &cert); err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate"))
- }
- return &cert, nil
-}
diff --git a/acme/certificate_test.go b/acme/certificate_test.go
deleted file mode 100644
index a4b8f91a..00000000
--- a/acme/certificate_test.go
+++ /dev/null
@@ -1,253 +0,0 @@
-package acme
-
-import (
- "crypto/x509"
- "encoding/json"
- "encoding/pem"
- "testing"
- "time"
-
- "github.com/pkg/errors"
- "github.com/smallstep/assert"
- "github.com/smallstep/certificates/db"
- "github.com/smallstep/nosql"
- "github.com/smallstep/nosql/database"
- "go.step.sm/crypto/pemutil"
-)
-
-func defaultCertOps() (*CertOptions, error) {
- crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt")
- if err != nil {
- return nil, err
- }
- inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt")
- if err != nil {
- return nil, err
- }
- root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt")
- if err != nil {
- return nil, err
- }
- return &CertOptions{
- AccountID: "accID",
- OrderID: "ordID",
- Leaf: crt,
- Intermediates: []*x509.Certificate{inter, root},
- }, nil
-}
-
-func newcert() (*certificate, error) {
- ops, err := defaultCertOps()
- if err != nil {
- return nil, err
- }
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, true, nil
- },
- }
- return newCert(mockdb, *ops)
-}
-
-func TestNewCert(t *testing.T) {
- type test struct {
- db nosql.DB
- ops CertOptions
- err *Error
- id *string
- }
- tests := map[string]func(t *testing.T) test{
- "fail/cmpAndSwap-error": func(t *testing.T) test {
- ops, err := defaultCertOps()
- assert.FatalError(t, err)
- return test{
- ops: *ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, certTable)
- assert.Equals(t, old, nil)
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.Errorf("error storing certificate: force")),
- }
- },
- "fail/cmpAndSwap-false": func(t *testing.T) test {
- ops, err := defaultCertOps()
- assert.FatalError(t, err)
- return test{
- ops: *ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, certTable)
- assert.Equals(t, old, nil)
- return nil, false, nil
- },
- },
- err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")),
- }
- },
- "ok": func(t *testing.T) test {
- ops, err := defaultCertOps()
- assert.FatalError(t, err)
- var _id string
- id := &_id
- return test{
- ops: *ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, certTable)
- assert.Equals(t, old, nil)
- *id = string(key)
- return nil, true, nil
- },
- },
- id: id,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if cert, err := newCert(tc.db, tc.ops); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, cert.ID, *tc.id)
- assert.Equals(t, cert.AccountID, tc.ops.AccountID)
- assert.Equals(t, cert.OrderID, tc.ops.OrderID)
-
- leaf := pem.EncodeToMemory(&pem.Block{
- Type: "CERTIFICATE",
- Bytes: tc.ops.Leaf.Raw,
- })
- var intermediates []byte
- for _, cert := range tc.ops.Intermediates {
- intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
- Type: "CERTIFICATE",
- Bytes: cert.Raw,
- })...)
- }
- assert.Equals(t, cert.Leaf, leaf)
- assert.Equals(t, cert.Intermediates, intermediates)
-
- assert.True(t, cert.Created.Before(time.Now().Add(time.Minute)))
- assert.True(t, cert.Created.After(time.Now().Add(-time.Minute)))
- }
- }
- })
- }
-}
-
-func TestGetCert(t *testing.T) {
- type test struct {
- id string
- db nosql.DB
- cert *certificate
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/not-found": func(t *testing.T) test {
- cert, err := newcert()
- assert.FatalError(t, err)
- return test{
- cert: cert,
- id: cert.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, certTable)
- assert.Equals(t, key, []byte(cert.ID))
- return nil, database.ErrNotFound
- },
- },
- err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)),
- }
- },
- "fail/db-error": func(t *testing.T) test {
- cert, err := newcert()
- assert.FatalError(t, err)
- return test{
- cert: cert,
- id: cert.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, certTable)
- assert.Equals(t, key, []byte(cert.ID))
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error loading certificate: force")),
- }
- },
- "fail/unmarshal-error": func(t *testing.T) test {
- cert, err := newcert()
- assert.FatalError(t, err)
- return test{
- cert: cert,
- id: cert.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, certTable)
- assert.Equals(t, key, []byte(cert.ID))
- return nil, nil
- },
- },
- err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")),
- }
- },
- "ok": func(t *testing.T) test {
- cert, err := newcert()
- assert.FatalError(t, err)
- b, err := json.Marshal(cert)
- assert.FatalError(t, err)
- return test{
- cert: cert,
- id: cert.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, certTable)
- assert.Equals(t, key, []byte(cert.ID))
- return b, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if cert, err := getCert(tc.db, tc.id); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.cert.ID, cert.ID)
- assert.Equals(t, tc.cert.AccountID, cert.AccountID)
- assert.Equals(t, tc.cert.OrderID, cert.OrderID)
- assert.Equals(t, tc.cert.Created, cert.Created)
- assert.Equals(t, tc.cert.Leaf, cert.Leaf)
- assert.Equals(t, tc.cert.Intermediates, cert.Intermediates)
- }
- }
- })
- }
-}
-
-func TestCertificateToACME(t *testing.T) {
- cert, err := newcert()
- assert.FatalError(t, err)
- acmeCert, err := cert.toACME(nil, nil)
- assert.FatalError(t, err)
- assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert)
-}
diff --git a/acme/challenge.go b/acme/challenge.go
index 6d2d13d1..1059e437 100644
--- a/acme/challenge.go
+++ b/acme/challenge.go
@@ -14,394 +14,115 @@ import (
"io/ioutil"
"net"
"net/http"
+ "net/url"
"strings"
"time"
- "github.com/pkg/errors"
- "github.com/smallstep/nosql"
"go.step.sm/crypto/jose"
)
-// Challenge is a subset of the challenge type containing only those attributes
-// required for responses in the ACME protocol.
+// Challenge represents an ACME response Challenge type.
type Challenge struct {
- Type string `json:"type"`
- Status string `json:"status"`
- Token string `json:"token"`
- Validated string `json:"validated,omitempty"`
- URL string `json:"url"`
- Error *AError `json:"error,omitempty"`
- ID string `json:"-"`
- AuthzID string `json:"-"`
+ ID string `json:"-"`
+ AccountID string `json:"-"`
+ AuthorizationID string `json:"-"`
+ Value string `json:"-"`
+ Type string `json:"type"`
+ Status Status `json:"status"`
+ Token string `json:"token"`
+ ValidatedAt string `json:"validated,omitempty"`
+ URL string `json:"url"`
+ Error *Error `json:"error,omitempty"`
}
// ToLog enables response logging.
-func (c *Challenge) ToLog() (interface{}, error) {
- b, err := json.Marshal(c)
+func (ch *Challenge) ToLog() (interface{}, error) {
+ b, err := json.Marshal(ch)
if err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging"))
+ return nil, WrapErrorISE(err, "error marshaling challenge for logging")
}
return string(b), nil
}
-// GetID returns the Challenge ID.
-func (c *Challenge) GetID() string {
- return c.ID
-}
-
-// GetAuthzID returns the parent Authz ID that owns the Challenge.
-func (c *Challenge) GetAuthzID() string {
- return c.AuthzID
-}
-
-type httpGetter func(string) (*http.Response, error)
-type lookupTxt func(string) ([]string, error)
-type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
-
-type validateOptions struct {
- httpGet httpGetter
- lookupTxt lookupTxt
- tlsDial tlsDialer
-}
-
-// challenge is the interface ACME challenege types must implement.
-type challenge interface {
- save(db nosql.DB, swap challenge) error
- validate(nosql.DB, *jose.JSONWebKey, validateOptions) (challenge, error)
- getType() string
- getError() *AError
- getValue() string
- getStatus() string
- getID() string
- getAuthzID() string
- getToken() string
- clone() *baseChallenge
- getAccountID() string
- getValidated() time.Time
- getCreated() time.Time
- toACME(context.Context, nosql.DB, *directory) (*Challenge, error)
-}
-
-// ChallengeOptions is the type used to created a new Challenge.
-type ChallengeOptions struct {
- AccountID string
- AuthzID string
- Identifier Identifier
-}
-
-// baseChallenge is the base Challenge type that others build from.
-type baseChallenge struct {
- ID string `json:"id"`
- AccountID string `json:"accountID"`
- AuthzID string `json:"authzID"`
- Type string `json:"type"`
- Status string `json:"status"`
- Token string `json:"token"`
- Value string `json:"value"`
- Validated time.Time `json:"validated"`
- Created time.Time `json:"created"`
- Error *AError `json:"error"`
-}
-
-func newBaseChallenge(accountID, authzID string) (*baseChallenge, error) {
- id, err := randID()
- if err != nil {
- return nil, Wrap(err, "error generating random id for ACME challenge")
- }
- token, err := randID()
- if err != nil {
- return nil, Wrap(err, "error generating token for ACME challenge")
- }
-
- return &baseChallenge{
- ID: id,
- AccountID: accountID,
- AuthzID: authzID,
- Status: StatusPending,
- Token: token,
- Created: clock.Now(),
- }, nil
-}
-
-// getID returns the id of the baseChallenge.
-func (bc *baseChallenge) getID() string {
- return bc.ID
-}
-
-// getAuthzID returns the authz ID of the baseChallenge.
-func (bc *baseChallenge) getAuthzID() string {
- return bc.AuthzID
-}
-
-// getAccountID returns the account id of the baseChallenge.
-func (bc *baseChallenge) getAccountID() string {
- return bc.AccountID
-}
-
-// getType returns the type of the baseChallenge.
-func (bc *baseChallenge) getType() string {
- return bc.Type
-}
-
-// getValue returns the type of the baseChallenge.
-func (bc *baseChallenge) getValue() string {
- return bc.Value
-}
-
-// getStatus returns the status of the baseChallenge.
-func (bc *baseChallenge) getStatus() string {
- return bc.Status
-}
-
-// getToken returns the token of the baseChallenge.
-func (bc *baseChallenge) getToken() string {
- return bc.Token
-}
-
-// getValidated returns the validated time of the baseChallenge.
-func (bc *baseChallenge) getValidated() time.Time {
- return bc.Validated
-}
-
-// getCreated returns the created time of the baseChallenge.
-func (bc *baseChallenge) getCreated() time.Time {
- return bc.Created
-}
-
-// getCreated returns the created time of the baseChallenge.
-func (bc *baseChallenge) getError() *AError {
- return bc.Error
-}
-
-// toACME converts the internal Challenge type into the public acmeChallenge
-// type for presentation in the ACME protocol.
-func (bc *baseChallenge) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Challenge, error) {
- ac := &Challenge{
- Type: bc.getType(),
- Status: bc.getStatus(),
- Token: bc.getToken(),
- URL: dir.getLink(ctx, ChallengeLink, true, bc.getID()),
- ID: bc.getID(),
- AuthzID: bc.getAuthzID(),
- }
- if !bc.Validated.IsZero() {
- ac.Validated = bc.Validated.Format(time.RFC3339)
- }
- if bc.Error != nil {
- ac.Error = bc.Error
- }
- return ac, nil
-}
-
-// save writes the challenge to disk. For new challenges 'old' should be nil,
-// otherwise 'old' should be a pointer to the acme challenge as it was at the
-// start of the request. This method will fail if the value currently found
-// in the bucket/row does not match the value of 'old'.
-func (bc *baseChallenge) save(db nosql.DB, old challenge) error {
- newB, err := json.Marshal(bc)
- if err != nil {
- return ServerInternalErr(errors.Wrap(err,
- "error marshaling new acme challenge"))
- }
- var oldB []byte
- if old == nil {
- oldB = nil
- } else {
- oldB, err = json.Marshal(old)
- if err != nil {
- return ServerInternalErr(errors.Wrap(err,
- "error marshaling old acme challenge"))
- }
- }
-
- _, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB)
- switch {
- case err != nil:
- return ServerInternalErr(errors.Wrap(err, "error saving acme challenge"))
- case !swapped:
- return ServerInternalErr(errors.New("error saving acme challenge; " +
- "acme challenge has changed since last read"))
- default:
+// Validate attempts to validate the challenge. Stores changes to the Challenge
+// type using the DB interface.
+// satisfactorily validated, the 'status' and 'validated' attributes are
+// updated.
+func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
+ // If already valid or invalid then return without performing validation.
+ if ch.Status != StatusPending {
return nil
}
+ switch ch.Type {
+ case "http-01":
+ return http01Validate(ctx, ch, db, jwk, vo)
+ case "dns-01":
+ return dns01Validate(ctx, ch, db, jwk, vo)
+ case "tls-alpn-01":
+ return tlsalpn01Validate(ctx, ch, db, jwk, vo)
+ default:
+ return NewErrorISE("unexpected challenge type '%s'", ch.Type)
+ }
}
-func (bc *baseChallenge) clone() *baseChallenge {
- u := *bc
- return &u
-}
+func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
+ url := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
-func (bc *baseChallenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
- return nil, ServerInternalErr(errors.New("unimplemented"))
-}
+ resp, err := vo.HTTPGet(url.String())
+ if err != nil {
+ return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
+ "error doing http GET for url %s", url))
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode >= 400 {
+ return storeError(ctx, db, ch, false, NewError(ErrorConnectionType,
+ "error doing http GET for url %s with status code %d", url, resp.StatusCode))
+ }
-func (bc *baseChallenge) storeError(db nosql.DB, err *Error) error {
- clone := bc.clone()
- clone.Error = err.ToACME()
- if err := clone.save(db, bc); err != nil {
- return ServerInternalErr(errors.Wrap(err, "failure saving error to acme challenge"))
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return WrapErrorISE(err, "error reading "+
+ "response body for url %s", url)
+ }
+ keyAuth := strings.TrimSpace(string(body))
+
+ expected, err := KeyAuthorization(ch.Token, jwk)
+ if err != nil {
+ return err
+ }
+ if keyAuth != expected {
+ return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
+ "keyAuthorization does not match; expected %s, but got %s", expected, keyAuth))
+ }
+
+ // Update and store the challenge.
+ ch.Status = StatusValid
+ ch.Error = nil
+ ch.ValidatedAt = clock.Now().Format(time.RFC3339)
+
+ if err = db.UpdateChallenge(ctx, ch); err != nil {
+ return WrapErrorISE(err, "error updating challenge")
}
return nil
}
-// unmarshalChallenge unmarshals a challenge type into the correct sub-type.
-func unmarshalChallenge(data []byte) (challenge, error) {
- var getType struct {
- Type string `json:"type"`
- }
- if err := json.Unmarshal(data, &getType); err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type"))
- }
-
- switch getType.Type {
- case "dns-01":
- var bc baseChallenge
- if err := json.Unmarshal(data, &bc); err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
- "challenge type into dns01Challenge"))
- }
- return &dns01Challenge{&bc}, nil
- case "http-01":
- var bc baseChallenge
- if err := json.Unmarshal(data, &bc); err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
- "challenge type into http01Challenge"))
- }
- return &http01Challenge{&bc}, nil
- case "tls-alpn-01":
- var bc baseChallenge
- if err := json.Unmarshal(data, &bc); err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
- "challenge type into tlsALPN01Challenge"))
- }
- return &tlsALPN01Challenge{&bc}, nil
- default:
- return nil, ServerInternalErr(errors.Errorf("unexpected challenge type %s", getType.Type))
- }
-}
-
-// http01Challenge represents an http-01 acme challenge.
-type http01Challenge struct {
- *baseChallenge
-}
-
-// newHTTP01Challenge returns a new acme http-01 challenge.
-func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
- bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
- if err != nil {
- return nil, err
- }
- bc.Type = "http-01"
- bc.Value = ops.Identifier.Value
-
- hc := &http01Challenge{bc}
- if err := hc.save(db, nil); err != nil {
- return nil, err
- }
- return hc, nil
-}
-
-// Validate attempts to validate the challenge. If the challenge has been
-// satisfactorily validated, the 'status' and 'validated' attributes are
-// updated.
-func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
- // If already valid or invalid then return without performing validation.
- if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid {
- return hc, nil
- }
- url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token)
-
- resp, err := vo.httpGet(url)
- if err != nil {
- if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err,
- "error doing http GET for url %s", url))); err != nil {
- return nil, err
- }
- return hc, nil
- }
- if resp.StatusCode >= 400 {
- if err = hc.storeError(db,
- ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d",
- url, resp.StatusCode))); err != nil {
- return nil, err
- }
- return hc, nil
- }
- defer resp.Body.Close()
-
- body, err := ioutil.ReadAll(resp.Body)
- if err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+
- "response body for url %s", url))
- }
- keyAuth := strings.Trim(string(body), "\r\n")
-
- expected, err := KeyAuthorization(hc.Token, jwk)
- if err != nil {
- return nil, err
- }
- if keyAuth != expected {
- if err = hc.storeError(db,
- RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+
- "expected %s, but got %s", expected, keyAuth))); err != nil {
- return nil, err
- }
- return hc, nil
- }
-
- // Update and store the challenge.
- upd := &http01Challenge{hc.baseChallenge.clone()}
- upd.Status = StatusValid
- upd.Error = nil
- upd.Validated = clock.Now()
-
- if err := upd.save(db, hc); err != nil {
- return nil, err
- }
- return upd, nil
-}
-
-type tlsALPN01Challenge struct {
- *baseChallenge
-}
-
-// newTLSALPN01Challenge returns a new acme tls-alpn-01 challenge.
-func newTLSALPN01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
- bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
- if err != nil {
- return nil, err
- }
- bc.Type = "tls-alpn-01"
- bc.Value = ops.Identifier.Value
-
- hc := &tlsALPN01Challenge{bc}
- if err := hc.save(db, nil); err != nil {
- return nil, err
- }
- return hc, nil
-}
-
-func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
- // If already valid or invalid then return without performing validation.
- if tc.getStatus() == StatusValid || tc.getStatus() == StatusInvalid {
- return tc, nil
- }
-
+func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
config := &tls.Config{
- NextProtos: []string{"acme-tls/1"},
- ServerName: tc.Value,
+ NextProtos: []string{"acme-tls/1"},
+ // https://tools.ietf.org/html/rfc8737#section-4
+ // ACME servers that implement "acme-tls/1" MUST only negotiate TLS 1.2
+ // [RFC5246] or higher when connecting to clients for validation.
+ MinVersion: tls.VersionTLS12,
+ ServerName: ch.Value,
InsecureSkipVerify: true, // we expect a self-signed challenge certificate
}
- hostPort := net.JoinHostPort(tc.Value, "443")
+ hostPort := net.JoinHostPort(ch.Value, "443")
- conn, err := vo.tlsDial("tcp", hostPort, config)
+ conn, err := vo.TLSDial("tcp", hostPort, config)
if err != nil {
- if err = tc.storeError(db,
- ConnectionErr(errors.Wrapf(err, "error doing TLS dial for %s", hostPort))); err != nil {
- return nil, err
- }
- return tc, nil
+ return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
+ "error doing TLS dial for %s", hostPort))
}
defer conn.Close()
@@ -409,86 +130,62 @@ func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo val
certs := cs.PeerCertificates
if len(certs) == 0 {
- if err = tc.storeError(db,
- RejectedIdentifierErr(errors.Errorf("%s challenge for %s resulted in no certificates",
- tc.Type, tc.Value))); err != nil {
- return nil, err
- }
- return tc, nil
+ return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
+ "%s challenge for %s resulted in no certificates", ch.Type, ch.Value))
}
- if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" {
- if err = tc.storeError(db,
- RejectedIdentifierErr(errors.Errorf("cannot negotiate ALPN acme-tls/1 protocol for "+
- "tls-alpn-01 challenge"))); err != nil {
- return nil, err
- }
- return tc, nil
+ if cs.NegotiatedProtocol != "acme-tls/1" {
+ return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
+ "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge"))
}
leafCert := certs[0]
- if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], tc.Value) {
- if err = tc.storeError(db,
- RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
- "leaf certificate must contain a single DNS name, %v", tc.Value))); err != nil {
- return nil, err
- }
- return tc, nil
+ if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) {
+ return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
+ "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value))
}
idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31}
idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
foundIDPeAcmeIdentifierV1Obsolete := false
- keyAuth, err := KeyAuthorization(tc.Token, jwk)
+ keyAuth, err := KeyAuthorization(ch.Token, jwk)
if err != nil {
- return nil, err
+ return err
}
hashedKeyAuth := sha256.Sum256([]byte(keyAuth))
for _, ext := range leafCert.Extensions {
if idPeAcmeIdentifier.Equal(ext.Id) {
if !ext.Critical {
- if err = tc.storeError(db,
- RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
- "acmeValidationV1 extension not critical"))); err != nil {
- return nil, err
- }
- return tc, nil
+ return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
+ "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical"))
}
var extValue []byte
rest, err := asn1.Unmarshal(ext.Value, &extValue)
if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) {
- if err = tc.storeError(db,
- RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
- "malformed acmeValidationV1 extension value"))); err != nil {
- return nil, err
- }
- return tc, nil
+ return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
+ "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value"))
}
if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 {
- if err = tc.storeError(db,
- RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
+ return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
+ "incorrect certificate for tls-alpn-01 challenge: "+
"expected acmeValidationV1 extension value %s for this challenge but got %s",
- hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue)))); err != nil {
- return nil, err
- }
- return tc, nil
+ hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue)))
}
- upd := &tlsALPN01Challenge{tc.baseChallenge.clone()}
- upd.Status = StatusValid
- upd.Error = nil
- upd.Validated = clock.Now()
+ ch.Status = StatusValid
+ ch.Error = nil
+ ch.ValidatedAt = clock.Now().Format(time.RFC3339)
- if err := upd.save(db, tc); err != nil {
- return nil, err
+ if err = db.UpdateChallenge(ctx, ch); err != nil {
+ return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge")
}
- return upd, nil
+ return nil
}
if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) {
@@ -497,82 +194,30 @@ func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo val
}
if foundIDPeAcmeIdentifierV1Obsolete {
- if err = tc.storeError(db,
- RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
- "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))); err != nil {
- return nil, err
- }
- return tc, nil
+ return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
+ "incorrect certificate for tls-alpn-01 challenge: obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))
}
- if err = tc.storeError(db,
- RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
- "missing acmeValidationV1 extension"))); err != nil {
- return nil, err
- }
- return tc, nil
+ return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
+ "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
}
-// dns01Challenge represents an dns-01 acme challenge.
-type dns01Challenge struct {
- *baseChallenge
-}
-
-// newDNS01Challenge returns a new acme dns-01 challenge.
-func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
- bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
- if err != nil {
- return nil, err
- }
- bc.Type = "dns-01"
- bc.Value = ops.Identifier.Value
-
- dc := &dns01Challenge{bc}
- if err := dc.save(db, nil); err != nil {
- return nil, err
- }
- return dc, nil
-}
-
-// KeyAuthorization creates the ACME key authorization value from a token
-// and a jwk.
-func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
- thumbprint, err := jwk.Thumbprint(crypto.SHA256)
- if err != nil {
- return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint"))
- }
- encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
- return fmt.Sprintf("%s.%s", token, encPrint), nil
-}
-
-// validate attempts to validate the challenge. If the challenge has been
-// satisfactorily validated, the 'status' and 'validated' attributes are
-// updated.
-func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
- // If already valid or invalid then return without performing validation.
- if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid {
- return dc, nil
- }
-
+func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
// Normalize domain for wildcard DNS names
// This is done to avoid making TXT lookups for domains like
// _acme-challenge.*.example.com
// Instead perform txt lookup for _acme-challenge.example.com
- domain := strings.TrimPrefix(dc.Value, "*.")
+ domain := strings.TrimPrefix(ch.Value, "*.")
- txtRecords, err := vo.lookupTxt("_acme-challenge." + domain)
+ txtRecords, err := vo.LookupTxt("_acme-challenge." + domain)
if err != nil {
- if err = dc.storeError(db,
- DNSErr(errors.Wrapf(err, "error looking up TXT "+
- "records for domain %s", domain))); err != nil {
- return nil, err
- }
- return dc, nil
+ return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
+ "error looking up TXT records for domain %s", domain))
}
- expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk)
+ expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk)
if err != nil {
- return nil, err
+ return err
}
h := sha256.Sum256([]byte(expectedKeyAuth))
expected := base64.RawURLEncoding.EncodeToString(h[:])
@@ -584,37 +229,51 @@ func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validat
}
}
if !found {
- if err = dc.storeError(db,
- RejectedIdentifierErr(errors.Errorf("keyAuthorization "+
- "does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil {
- return nil, err
- }
- return dc, nil
+ return storeError(ctx, db, ch, false, NewError(ErrorRejectedIdentifierType,
+ "keyAuthorization does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))
}
// Update and store the challenge.
- upd := &dns01Challenge{dc.baseChallenge.clone()}
- upd.Status = StatusValid
- upd.Error = nil
- upd.Validated = time.Now().UTC()
+ ch.Status = StatusValid
+ ch.Error = nil
+ ch.ValidatedAt = clock.Now().Format(time.RFC3339)
- if err := upd.save(db, dc); err != nil {
- return nil, err
+ if err = db.UpdateChallenge(ctx, ch); err != nil {
+ return WrapErrorISE(err, "error updating challenge")
}
- return upd, nil
+ return nil
}
-// getChallenge retrieves and unmarshals an ACME challenge type from the database.
-func getChallenge(db nosql.DB, id string) (challenge, error) {
- b, err := db.Get(challengeTable, []byte(id))
- if nosql.IsErrNotFound(err) {
- return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id))
- } else if err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id))
- }
- ch, err := unmarshalChallenge(b)
+// KeyAuthorization creates the ACME key authorization value from a token
+// and a jwk.
+func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
+ thumbprint, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
- return nil, err
+ return "", WrapErrorISE(err, "error generating JWK thumbprint")
}
- return ch, nil
+ encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
+ return fmt.Sprintf("%s.%s", token, encPrint), nil
+}
+
+// storeError the given error to an ACME error and saves using the DB interface.
+func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err *Error) error {
+ ch.Error = err
+ if markInvalid {
+ ch.Status = StatusInvalid
+ }
+ if err := db.UpdateChallenge(ctx, ch); err != nil {
+ return WrapErrorISE(err, "failure saving error to acme challenge")
+ }
+ return nil
+}
+
+type httpGetter func(string) (*http.Response, error)
+type lookupTxt func(string) ([]string, error)
+type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
+
+// ValidateChallengeOptions are ACME challenge validator functions.
+type ValidateChallengeOptions struct {
+ HTTPGet httpGetter
+ LookupTxt lookupTxt
+ TLSDial tlsDialer
}
diff --git a/acme/challenge_test.go b/acme/challenge_test.go
index 87ec0c4c..14287945 100644
--- a/acme/challenge_test.go
+++ b/acme/challenge_test.go
@@ -13,7 +13,6 @@ import (
"encoding/asn1"
"encoding/base64"
"encoding/hex"
- "encoding/json"
"fmt"
"io"
"io/ioutil"
@@ -21,644 +20,150 @@ import (
"net"
"net/http"
"net/http/httptest"
- "net/url"
+ "strings"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
- "github.com/smallstep/certificates/db"
- "github.com/smallstep/nosql"
- "github.com/smallstep/nosql/database"
"go.step.sm/crypto/jose"
)
-var testOps = ChallengeOptions{
- AccountID: "accID",
- AuthzID: "authzID",
- Identifier: Identifier{
- Type: "", // will get set correctly depending on the "new.." method.
- Value: "zap.internal",
- },
-}
-
-func newDNSCh() (challenge, error) {
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), true, nil
- },
- }
- return newDNS01Challenge(mockdb, testOps)
-}
-
-func newTLSALPNCh() (challenge, error) {
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), true, nil
- },
- }
- return newTLSALPN01Challenge(mockdb, testOps)
-}
-
-func newHTTPCh() (challenge, error) {
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), true, nil
- },
- }
- return newHTTP01Challenge(mockdb, testOps)
-}
-
-func newHTTPChWithServer(host string) (challenge, error) {
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), true, nil
- },
- }
- return newHTTP01Challenge(mockdb, ChallengeOptions{
- AccountID: "accID",
- AuthzID: "authzID",
- Identifier: Identifier{
- Type: "", // will get set correctly depending on the "new.." method.
- Value: host,
- },
- })
-}
-
-func TestNewHTTP01Challenge(t *testing.T) {
- ops := ChallengeOptions{
- AccountID: "accID",
- AuthzID: "authzID",
- Identifier: Identifier{
- Type: "http",
- Value: "zap.internal",
- },
- }
+func Test_storeError(t *testing.T) {
type test struct {
- ops ChallengeOptions
- db nosql.DB
- err *Error
- }
- tests := map[string]test{
- "fail/store-error": {
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error saving acme challenge: force")),
- },
- "ok": {
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), true, nil
- },
- },
- },
- }
- for name, tc := range tests {
- t.Run(name, func(t *testing.T) {
- ch, err := newHTTP01Challenge(tc.db, tc.ops)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, ch.getAccountID(), ops.AccountID)
- assert.Equals(t, ch.getAuthzID(), ops.AuthzID)
- assert.Equals(t, ch.getType(), "http-01")
- assert.Equals(t, ch.getValue(), "zap.internal")
- assert.Equals(t, ch.getStatus(), StatusPending)
-
- assert.True(t, ch.getValidated().IsZero())
- assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute)))
- assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
-
- assert.True(t, ch.getID() != "")
- assert.True(t, ch.getToken() != "")
- }
- }
- })
- }
-}
-
-func TestNewTLSALPN01Challenge(t *testing.T) {
- ops := ChallengeOptions{
- AccountID: "accID",
- AuthzID: "authzID",
- Identifier: Identifier{
- Type: "http",
- Value: "zap.internal",
- },
- }
- type test struct {
- ops ChallengeOptions
- db nosql.DB
- err *Error
- }
- tests := map[string]test{
- "fail/store-error": {
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error saving acme challenge: force")),
- },
- "ok": {
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), true, nil
- },
- },
- },
- }
- for name, tc := range tests {
- t.Run(name, func(t *testing.T) {
- ch, err := newTLSALPN01Challenge(tc.db, tc.ops)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, ch.getAccountID(), ops.AccountID)
- assert.Equals(t, ch.getAuthzID(), ops.AuthzID)
- assert.Equals(t, ch.getType(), "tls-alpn-01")
- assert.Equals(t, ch.getValue(), "zap.internal")
- assert.Equals(t, ch.getStatus(), StatusPending)
-
- assert.True(t, ch.getValidated().IsZero())
- assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute)))
- assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
-
- assert.True(t, ch.getID() != "")
- assert.True(t, ch.getToken() != "")
- }
- }
- })
- }
-}
-
-func TestNewDNS01Challenge(t *testing.T) {
- ops := ChallengeOptions{
- AccountID: "accID",
- AuthzID: "authzID",
- Identifier: Identifier{
- Type: "dns",
- Value: "zap.internal",
- },
- }
- type test struct {
- ops ChallengeOptions
- db nosql.DB
- err *Error
- }
- tests := map[string]test{
- "fail/store-error": {
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error saving acme challenge: force")),
- },
- "ok": {
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), true, nil
- },
- },
- },
- }
- for name, tc := range tests {
- t.Run(name, func(t *testing.T) {
- ch, err := newDNS01Challenge(tc.db, tc.ops)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, ch.getAccountID(), ops.AccountID)
- assert.Equals(t, ch.getAuthzID(), ops.AuthzID)
- assert.Equals(t, ch.getType(), "dns-01")
- assert.Equals(t, ch.getValue(), "zap.internal")
- assert.Equals(t, ch.getStatus(), StatusPending)
-
- assert.True(t, ch.getValidated().IsZero())
- assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute)))
- assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
-
- assert.True(t, ch.getID() != "")
- assert.True(t, ch.getToken() != "")
- }
- }
- })
- }
-}
-
-func TestChallengeToACME(t *testing.T) {
- dir := newDirectory("ca.smallstep.com", "acme")
-
- httpCh, err := newHTTPCh()
- assert.FatalError(t, err)
- _httpCh, ok := httpCh.(*http01Challenge)
- assert.Fatal(t, ok)
- _httpCh.baseChallenge.Validated = clock.Now()
- dnsCh, err := newDNSCh()
- assert.FatalError(t, err)
- tlsALPNCh, err := newTLSALPNCh()
- assert.FatalError(t, err)
-
- prov := newProv()
- provName := url.PathEscape(prov.GetName())
- baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
- tests := map[string]challenge{
- "dns": dnsCh,
- "http": httpCh,
- "tls-alpn": tlsALPNCh,
- }
- for name, ch := range tests {
- t.Run(name, func(t *testing.T) {
- ach, err := ch.toACME(ctx, nil, dir)
- assert.FatalError(t, err)
-
- assert.Equals(t, ach.Type, ch.getType())
- assert.Equals(t, ach.Status, ch.getStatus())
- assert.Equals(t, ach.Token, ch.getToken())
- assert.Equals(t, ach.URL,
- fmt.Sprintf("%s/acme/%s/challenge/%s",
- baseURL.String(), provName, ch.getID()))
- assert.Equals(t, ach.ID, ch.getID())
- assert.Equals(t, ach.AuthzID, ch.getAuthzID())
-
- if ach.Type == "http-01" {
- v, err := time.Parse(time.RFC3339, ach.Validated)
- assert.FatalError(t, err)
- assert.Equals(t, v.String(), _httpCh.baseChallenge.Validated.String())
- } else {
- assert.Equals(t, ach.Validated, "")
- }
- })
- }
-}
-
-func TestChallengeSave(t *testing.T) {
- type test struct {
- ch challenge
- old challenge
- db nosql.DB
- err *Error
+ ch *Challenge
+ db DB
+ markInvalid bool
+ err *Error
}
+ err := NewError(ErrorMalformedType, "foo")
tests := map[string]func(t *testing.T) test{
- "fail/old-nil/swap-error": func(t *testing.T) test {
- httpCh, err := newHTTPCh()
- assert.FatalError(t, err)
+ "fail/db.UpdateChallenge-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusValid,
+ }
return test{
- ch: httpCh,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
+ ch: ch,
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusValid)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
},
},
- err: ServerInternalErr(errors.New("error saving acme challenge: force")),
+ err: NewErrorISE("failure saving error to acme challenge: force"),
}
},
- "fail/old-nil/swap-false": func(t *testing.T) test {
- httpCh, err := newHTTPCh()
- assert.FatalError(t, err)
+ "fail/db.UpdateChallenge-acme-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusValid,
+ }
return test{
- ch: httpCh,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), false, nil
+ ch: ch,
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusValid)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return NewError(ErrorMalformedType, "bar")
},
},
- err: ServerInternalErr(errors.New("error saving acme challenge; acme challenge has changed since last read")),
- }
- },
- "ok/old-nil": func(t *testing.T) test {
- httpCh, err := newHTTPCh()
- assert.FatalError(t, err)
- b, err := json.Marshal(httpCh)
- assert.FatalError(t, err)
- return test{
- ch: httpCh,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, nil)
- assert.Equals(t, b, newval)
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, []byte(httpCh.getID()), key)
- return []byte("foo"), true, nil
- },
- },
- }
- },
- "ok/old-not-nil": func(t *testing.T) test {
- oldHTTPCh, err := newHTTPCh()
- assert.FatalError(t, err)
- httpCh, err := newHTTPCh()
- assert.FatalError(t, err)
-
- oldb, err := json.Marshal(oldHTTPCh)
- assert.FatalError(t, err)
- b, err := json.Marshal(httpCh)
- assert.FatalError(t, err)
- return test{
- ch: httpCh,
- old: oldHTTPCh,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, oldb)
- assert.Equals(t, b, newval)
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, []byte(httpCh.getID()), key)
- return []byte("foo"), true, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if err := tc.ch.save(tc.db, tc.old); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- assert.Nil(t, tc.err)
- }
- })
- }
-}
-
-func TestChallengeClone(t *testing.T) {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
-
- clone := ch.clone()
-
- assert.Equals(t, clone.getID(), ch.getID())
- assert.Equals(t, clone.getAccountID(), ch.getAccountID())
- assert.Equals(t, clone.getAuthzID(), ch.getAuthzID())
- assert.Equals(t, clone.getStatus(), ch.getStatus())
- assert.Equals(t, clone.getToken(), ch.getToken())
- assert.Equals(t, clone.getCreated(), ch.getCreated())
- assert.Equals(t, clone.getValidated(), ch.getValidated())
-
- clone.Status = StatusValid
-
- assert.NotEquals(t, clone.getStatus(), ch.getStatus())
-}
-
-func TestChallengeUnmarshal(t *testing.T) {
- type test struct {
- ch challenge
- chb []byte
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/nil": func(t *testing.T) test {
- return test{
- chb: nil,
- err: ServerInternalErr(errors.New("error unmarshaling challenge type: unexpected end of JSON input")),
- }
- },
- "fail/unexpected-type-http": func(t *testing.T) test {
- httpCh, err := newHTTPCh()
- assert.FatalError(t, err)
- _httpCh, ok := httpCh.(*http01Challenge)
- assert.Fatal(t, ok)
- _httpCh.baseChallenge.Type = "foo"
- b, err := json.Marshal(httpCh)
- assert.FatalError(t, err)
- return test{
- chb: b,
- err: ServerInternalErr(errors.New("unexpected challenge type foo")),
- }
- },
- "fail/unexpected-type-alpn": func(t *testing.T) test {
- tlsALPNCh, err := newTLSALPNCh()
- assert.FatalError(t, err)
- _tlsALPNCh, ok := tlsALPNCh.(*tlsALPN01Challenge)
- assert.Fatal(t, ok)
- _tlsALPNCh.baseChallenge.Type = "foo"
- b, err := json.Marshal(tlsALPNCh)
- assert.FatalError(t, err)
- return test{
- chb: b,
- err: ServerInternalErr(errors.New("unexpected challenge type foo")),
- }
- },
- "fail/unexpected-type-dns": func(t *testing.T) test {
- dnsCh, err := newDNSCh()
- assert.FatalError(t, err)
- _dnsCh, ok := dnsCh.(*dns01Challenge)
- assert.Fatal(t, ok)
- _dnsCh.baseChallenge.Type = "foo"
- b, err := json.Marshal(dnsCh)
- assert.FatalError(t, err)
- return test{
- chb: b,
- err: ServerInternalErr(errors.New("unexpected challenge type foo")),
- }
- },
- "ok/dns": func(t *testing.T) test {
- dnsCh, err := newDNSCh()
- assert.FatalError(t, err)
- b, err := json.Marshal(dnsCh)
- assert.FatalError(t, err)
- return test{
- ch: dnsCh,
- chb: b,
- }
- },
- "ok/http": func(t *testing.T) test {
- httpCh, err := newHTTPCh()
- assert.FatalError(t, err)
- b, err := json.Marshal(httpCh)
- assert.FatalError(t, err)
- return test{
- ch: httpCh,
- chb: b,
- }
- },
- "ok/alpn": func(t *testing.T) test {
- tlsALPNCh, err := newTLSALPNCh()
- assert.FatalError(t, err)
- b, err := json.Marshal(tlsALPNCh)
- assert.FatalError(t, err)
- return test{
- ch: tlsALPNCh,
- chb: b,
- }
- },
- "ok/err": func(t *testing.T) test {
- httpCh, err := newHTTPCh()
- assert.FatalError(t, err)
- _httpCh, ok := httpCh.(*http01Challenge)
- assert.Fatal(t, ok)
- _httpCh.baseChallenge.Error = ServerInternalErr(errors.New("force")).ToACME()
- b, err := json.Marshal(httpCh)
- assert.FatalError(t, err)
- return test{
- ch: httpCh,
- chb: b,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if ch, err := unmarshalChallenge(tc.chb); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.ch.getID(), ch.getID())
- assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID())
- assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID())
- assert.Equals(t, tc.ch.getStatus(), ch.getStatus())
- assert.Equals(t, tc.ch.getToken(), ch.getToken())
- assert.Equals(t, tc.ch.getCreated(), ch.getCreated())
- assert.Equals(t, tc.ch.getValidated(), ch.getValidated())
- }
- }
- })
- }
-}
-func TestGetChallenge(t *testing.T) {
- type test struct {
- id string
- db nosql.DB
- ch challenge
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/not-found": func(t *testing.T) test {
- dnsCh, err := newDNSCh()
- assert.FatalError(t, err)
- return test{
- ch: dnsCh,
- id: dnsCh.getID(),
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
- },
- },
- err: MalformedErr(errors.Errorf("challenge %s not found: not found", dnsCh.getID())),
- }
- },
- "fail/db-error": func(t *testing.T) test {
- dnsCh, err := newDNSCh()
- assert.FatalError(t, err)
- return test{
- ch: dnsCh,
- id: dnsCh.getID(),
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", dnsCh.getID())),
- }
- },
- "fail/unmarshal-error": func(t *testing.T) test {
- dnsCh, err := newDNSCh()
- assert.FatalError(t, err)
- _dnsCh, ok := dnsCh.(*dns01Challenge)
- assert.Fatal(t, ok)
- _dnsCh.baseChallenge.Type = "foo"
- b, err := json.Marshal(dnsCh)
- assert.FatalError(t, err)
- return test{
- ch: dnsCh,
- id: dnsCh.getID(),
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(dnsCh.getID()))
- return b, nil
- },
- },
- err: ServerInternalErr(errors.New("unexpected challenge type foo")),
+ err: NewError(ErrorMalformedType, "failure saving error to acme challenge: bar"),
}
},
"ok": func(t *testing.T) test {
- dnsCh, err := newDNSCh()
- assert.FatalError(t, err)
- b, err := json.Marshal(dnsCh)
- assert.FatalError(t, err)
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusValid,
+ }
return test{
- ch: dnsCh,
- id: dnsCh.getID(),
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(dnsCh.getID()))
- return b, nil
+ ch: ch,
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusValid)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
},
},
}
},
+ "ok/mark-invalid": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusValid,
+ }
+ return test{
+ ch: ch,
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusInvalid)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ markInvalid: true,
+ }
+ },
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
- if ch, err := getChallenge(tc.db, tc.id); err != nil {
+ if err := storeError(context.Background(), tc.db, tc.ch, tc.markInvalid, err); err != nil {
if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
+ switch k := err.(type) {
+ case *Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
}
} else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.ch.getID(), ch.getID())
- assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID())
- assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID())
- assert.Equals(t, tc.ch.getStatus(), ch.getStatus())
- assert.Equals(t, tc.ch.getToken(), ch.getToken())
- assert.Equals(t, tc.ch.getCreated(), ch.getCreated())
- assert.Equals(t, tc.ch.getValidated(), ch.getValidated())
- }
+ assert.Nil(t, tc.err)
}
})
}
@@ -679,7 +184,7 @@ func TestKeyAuthorization(t *testing.T) {
return test{
token: "1234",
jwk: jwk,
- err: ServerInternalErr(errors.Errorf("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")),
+ err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"),
}
},
"ok": func(t *testing.T) test {
@@ -701,11 +206,16 @@ func TestKeyAuthorization(t *testing.T) {
tc := run(t)
if ka, err := KeyAuthorization(tc.token, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
+ switch k := err.(type) {
+ case *Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
}
} else {
if assert.Nil(t, tc.err) {
@@ -716,6 +226,293 @@ func TestKeyAuthorization(t *testing.T) {
}
}
+func TestChallenge_Validate(t *testing.T) {
+ type test struct {
+ ch *Challenge
+ vo *ValidateChallengeOptions
+ jwk *jose.JSONWebKey
+ db DB
+ srv *httptest.Server
+ err *Error
+ }
+ tests := map[string]func(t *testing.T) test{
+ "ok/already-valid": func(t *testing.T) test {
+ ch := &Challenge{
+ Status: StatusValid,
+ }
+ return test{
+ ch: ch,
+ }
+ },
+ "fail/already-invalid": func(t *testing.T) test {
+ ch := &Challenge{
+ Status: StatusInvalid,
+ }
+ return test{
+ ch: ch,
+ }
+ },
+ "fail/unexpected-type": func(t *testing.T) test {
+ ch := &Challenge{
+ Status: StatusPending,
+ Type: "foo",
+ }
+ return test{
+ ch: ch,
+ err: NewErrorISE("unexpected challenge type 'foo'"),
+ }
+ },
+ "fail/http-01": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Status: StatusPending,
+ Type: "http-01",
+ Token: "token",
+ Value: "zap.internal",
+ }
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
+ return nil, errors.New("force")
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Status, ch.Status)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token)
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
+ },
+ },
+ err: NewErrorISE("failure saving error to acme challenge: force"),
+ }
+ },
+ "ok/http-01": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Status: StatusPending,
+ Type: "http-01",
+ Token: "token",
+ Value: "zap.internal",
+ }
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
+ return nil, errors.New("force")
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Status, ch.Status)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token)
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ }
+ },
+ "fail/dns-01": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Type: "dns-01",
+ Status: StatusPending,
+ Token: "token",
+ Value: "zap.internal",
+ }
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ LookupTxt: func(url string) ([]string, error) {
+ return nil, errors.New("force")
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Status, ch.Status)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
+ },
+ },
+ err: NewErrorISE("failure saving error to acme challenge: force"),
+ }
+ },
+ "ok/dns-01": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Type: "dns-01",
+ Status: StatusPending,
+ Token: "token",
+ Value: "zap.internal",
+ }
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ LookupTxt: func(url string) ([]string, error) {
+ return nil, errors.New("force")
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Status, ch.Status)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ }
+ },
+ "fail/tls-alpn-01": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Type: "tls-alpn-01",
+ Status: StatusPending,
+ Value: "zap.internal",
+ }
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
+ return nil, errors.New("force")
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, ch.Status)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: force", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
+ },
+ },
+ err: NewErrorISE("failure saving error to acme challenge: force"),
+ }
+ },
+ "ok/tls-alpn-01": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Type: "tls-alpn-01",
+ Status: StatusPending,
+ Value: "zap.internal",
+ }
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, ch.Status)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Error, nil)
+ return nil
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ }
+ },
+ }
+ for name, run := range tests {
+ t.Run(name, func(t *testing.T) {
+ tc := run(t)
+
+ if tc.srv != nil {
+ defer tc.srv.Close()
+ }
+
+ if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil {
+ if assert.NotNil(t, tc.err) {
+ switch k := err.(type) {
+ case *Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
+ }
+ } else {
+ assert.Nil(t, tc.err)
+ }
+ })
+ }
+}
+
type errReader int
func (errReader) Read(p []byte) (n int, err error) {
@@ -727,258 +524,361 @@ func (errReader) Close() error {
func TestHTTP01Validate(t *testing.T) {
type test struct {
- vo validateOptions
- ch challenge
- res challenge
+ vo *ValidateChallengeOptions
+ ch *Challenge
jwk *jose.JSONWebKey
- db nosql.DB
+ db DB
err *Error
}
tests := map[string]func(t *testing.T) test{
- "ok/status-already-valid": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*http01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Status = StatusValid
- return test{
- ch: ch,
- res: ch,
+ "fail/http-get-error-store-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusPending,
}
- },
- "ok/status-already-invalid": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*http01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Status = StatusInvalid
- return test{
- ch: ch,
- res: ch,
- }
- },
- "ok/http-get-error": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
- expErr := ConnectionErr(errors.Errorf("error doing http GET for url "+
- "http://zap.internal/.well-known/acme-challenge/%s: force", ch.getToken()))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &http01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
return test{
ch: ch,
- vo: validateOptions{
- httpGet: func(url string) (*http.Response, error) {
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
return nil, errors.New("force")
},
},
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, newb)
- return nil, true, nil
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusPending)
+
+ err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token)
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
},
},
- res: ch,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
}
},
- "ok/http-get->=400": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
+ "ok/http-get-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusPending,
+ }
- expErr := ConnectionErr(errors.Errorf("error doing http GET for url "+
- "http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.getToken()))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &http01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
return test{
ch: ch,
- vo: validateOptions{
- httpGet: func(url string) (*http.Response, error) {
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
+ return nil, errors.New("force")
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusPending)
+
+ err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token)
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ }
+ },
+ "fail/http-get->=400-store-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusPending,
+ }
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadRequest,
+ Body: errReader(0),
}, nil
},
},
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, newb)
- return nil, true, nil
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusPending)
+
+ err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token)
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
},
},
- res: ch,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
}
},
- "fail/read-body": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
- jwk.Key = "foo"
+ "ok/http-get->=400": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusPending,
+ }
return test{
ch: ch,
- vo: validateOptions{
- httpGet: func(url string) (*http.Response, error) {
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
+ return &http.Response{
+ StatusCode: http.StatusBadRequest,
+ Body: errReader(0),
+ }, nil
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusPending)
+
+ err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token)
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ }
+ },
+ "fail/read-body": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusPending,
+ }
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{
Body: errReader(0),
}, nil
},
},
- jwk: jwk,
- err: ServerInternalErr(errors.Errorf("error reading response "+
- "body for url http://zap.internal/.well-known/acme-challenge/%s: force",
- ch.getToken())),
+ err: NewErrorISE("error reading response body for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token),
}
},
- "fail/key-authorization-gen-error": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
+ "fail/key-auth-gen-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusPending,
+ }
+
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
jwk.Key = "foo"
return test{
ch: ch,
- vo: validateOptions{
- httpGet: func(url string) (*http.Response, error) {
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")),
}, nil
},
},
jwk: jwk,
- err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")),
+ err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"),
}
},
"ok/key-auth-mismatch": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusPending,
+ }
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
assert.FatalError(t, err)
-
- expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+
- "expected %s, but got foo", expKeyAuth))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &http01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
-
return test{
ch: ch,
- vo: validateOptions{
- httpGet: func(url string) (*http.Response, error) {
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")),
}, nil
},
},
jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, newb)
- return nil, true, nil
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusInvalid)
+
+ err := NewError(ErrorRejectedIdentifierType,
+ "keyAuthorization does not match; expected %s, but got foo", expKeyAuth)
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
},
},
- res: ch,
}
},
- "fail/save-error": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
+ "fail/key-auth-mismatch-store-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusPending,
+ }
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
assert.FatalError(t, err)
return test{
ch: ch,
- vo: validateOptions{
- httpGet: func(url string) (*http.Response, error) {
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{
- Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)),
+ Body: ioutil.NopCloser(bytes.NewBufferString("foo")),
}, nil
},
},
jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusInvalid)
+
+ err := NewError(ErrorRejectedIdentifierType,
+ "keyAuthorization does not match; expected %s, but got foo", expKeyAuth)
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
},
},
- err: ServerInternalErr(errors.New("error saving acme challenge: force")),
+ err: NewErrorISE("failure saving error to acme challenge: force"),
}
},
- "ok": func(t *testing.T) test {
- ch, err := newHTTPCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*http01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Error = MalformedErr(nil).ToACME()
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
+ "fail/update-challenge-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusPending,
+ }
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
assert.FatalError(t, err)
-
- baseClone := ch.clone()
- baseClone.Status = StatusValid
- baseClone.Error = nil
- newCh := &http01Challenge{baseClone}
-
return test{
- ch: ch,
- res: newCh,
- vo: validateOptions{
- httpGet: func(url string) (*http.Response, error) {
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil
},
},
jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
-
- httpCh, err := unmarshalChallenge(newval)
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusValid)
+ assert.Equals(t, updch.Error, nil)
+ va, err := time.Parse(time.RFC3339, updch.ValidatedAt)
assert.FatalError(t, err)
- assert.Equals(t, httpCh.getStatus(), StatusValid)
- assert.True(t, httpCh.getValidated().Before(time.Now().UTC().Add(time.Minute)))
- assert.True(t, httpCh.getValidated().After(time.Now().UTC().Add(-1*time.Second)))
+ now := clock.Now()
+ assert.True(t, va.Add(-time.Minute).Before(now))
+ assert.True(t, va.Add(time.Minute).After(now))
- baseClone.Validated = httpCh.getValidated()
+ return errors.New("force")
+ },
+ },
+ err: NewErrorISE("error updating challenge: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: "zap.internal",
+ Status: StatusPending,
+ }
- return nil, true, nil
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ HTTPGet: func(url string) (*http.Response, error) {
+ return &http.Response{
+ Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)),
+ }, nil
+ },
+ },
+ jwk: jwk,
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ assert.Equals(t, updch.Status, StatusValid)
+ assert.Equals(t, updch.Error, nil)
+ va, err := time.Parse(time.RFC3339, updch.ValidatedAt)
+ assert.FatalError(t, err)
+ now := clock.Now()
+ assert.True(t, va.Add(-time.Minute).Before(now))
+ assert.True(t, va.Add(time.Minute).After(now))
+ return nil
},
},
}
@@ -987,648 +887,320 @@ func TestHTTP01Validate(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
- if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil {
+ if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
+ switch k := err.(type) {
+ case *Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
}
} else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.res.getID(), ch.getID())
- assert.Equals(t, tc.res.getAccountID(), ch.getAccountID())
- assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID())
- assert.Equals(t, tc.res.getStatus(), ch.getStatus())
- assert.Equals(t, tc.res.getToken(), ch.getToken())
- assert.Equals(t, tc.res.getCreated(), ch.getCreated())
- assert.Equals(t, tc.res.getValidated(), ch.getValidated())
- assert.Equals(t, tc.res.getError(), ch.getError())
- }
+ assert.Nil(t, tc.err)
}
})
}
}
-func TestTLSALPN01Validate(t *testing.T) {
+func TestDNS01Validate(t *testing.T) {
+ fulldomain := "*.zap.internal"
+ domain := strings.TrimPrefix(fulldomain, "*.")
type test struct {
- srv *httptest.Server
- vo validateOptions
- ch challenge
- res challenge
+ vo *ValidateChallengeOptions
+ ch *Challenge
jwk *jose.JSONWebKey
- db nosql.DB
+ db DB
err *Error
}
tests := map[string]func(t *testing.T) test{
- "ok/status-already-valid": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*tlsALPN01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Status = StatusValid
-
- return test{
- ch: ch,
- res: ch,
+ "fail/lookupTXT-store-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: fulldomain,
+ Status: StatusPending,
}
- },
- "ok/status-already-invalid": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*tlsALPN01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Status = StatusInvalid
-
- return test{
- ch: ch,
- res: ch,
- }
- },
- "ok/tls-dial-error": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- expErr := ConnectionErr(errors.Errorf("error doing TLS dial for %v:443: force", ch.getValue()))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
return test{
ch: ch,
- vo: validateOptions{
- tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
+ vo: &ValidateChallengeOptions{
+ LookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force")
},
},
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, newb)
- return nil, true, nil
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusPending)
+
+ err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
},
},
- res: ch,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
}
},
- "ok/timeout": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- expErr := ConnectionErr(errors.Errorf("error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.getValue()))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
-
- srv, tlsDial := newTestTLSALPNServer(nil)
- // srv.Start() - do not start server to cause timeout
+ "ok/lookupTXT-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: fulldomain,
+ Status: StatusPending,
+ }
return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: tlsDial,
- },
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ LookupTxt: func(url string) ([]string, error) {
+ return nil, errors.New("force")
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusPending)
+
+ err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
},
},
- res: ch,
}
},
- "ok/no-certificates": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
+ "fail/key-auth-gen-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: fulldomain,
+ Status: StatusPending,
+ }
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
+ jwk.Key = "foo"
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ LookupTxt: func(url string) ([]string, error) {
+ return []string{"foo"}, nil
+ },
+ },
+ jwk: jwk,
+ err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"),
+ }
+ },
+ "fail/key-auth-mismatch-store-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: fulldomain,
+ Status: StatusPending,
+ }
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- expErr := RejectedIdentifierErr(errors.Errorf("tls-alpn-01 challenge for %v resulted in no certificates", ch.getValue()))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
assert.FatalError(t, err)
return test{
ch: ch,
- vo: validateOptions{
- tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
- return tls.Client(&noopConn{}, config), nil
+ vo: &ValidateChallengeOptions{
+ LookupTxt: func(url string) ([]string, error) {
+ return []string{"foo", "bar"}, nil
},
},
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusPending)
+
+ err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"})
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
},
},
- res: ch,
+ jwk: jwk,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
}
},
- "ok/no-names": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue()))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
+ "ok/key-auth-mismatch-store-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: fulldomain,
+ Status: StatusPending,
+ }
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
assert.FatalError(t, err)
- expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
-
- cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true)
- assert.FatalError(t, err)
-
- srv, tlsDial := newTestTLSALPNServer(cert)
- srv.Start()
return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: tlsDial,
- },
- jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ LookupTxt: func(url string) ([]string, error) {
+ return []string{"foo", "bar"}, nil
},
},
- res: ch,
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusPending)
+
+ err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"})
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ jwk: jwk,
}
},
- "ok/too-many-names": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue()))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
+ "fail/update-challenge-error": func(t *testing.T) test {
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: fulldomain,
+ Status: StatusPending,
+ }
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
assert.FatalError(t, err)
- expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
-
- cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.getValue(), "other.internal")
- assert.FatalError(t, err)
-
- srv, tlsDial := newTestTLSALPNServer(cert)
- srv.Start()
+ h := sha256.Sum256([]byte(expKeyAuth))
+ expected := base64.RawURLEncoding.EncodeToString(h[:])
return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: tlsDial,
- },
- jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ LookupTxt: func(url string) ([]string, error) {
+ return []string{"foo", expected}, nil
},
},
- res: ch,
- }
- },
- "ok/wrong-name": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusValid)
- expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue()))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
+ assert.Equals(t, updch.Status, StatusValid)
+ assert.Equals(t, updch.Error, nil)
+ va, err := time.Parse(time.RFC3339, updch.ValidatedAt)
+ assert.FatalError(t, err)
+ now := clock.Now()
+ assert.True(t, va.Add(-time.Minute).Before(now))
+ assert.True(t, va.Add(time.Minute).After(now))
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
-
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
- assert.FatalError(t, err)
- expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
-
- cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, "other.internal")
- assert.FatalError(t, err)
-
- srv, tlsDial := newTestTLSALPNServer(cert)
- srv.Start()
-
- return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: tlsDial,
- },
- jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
- },
- },
- res: ch,
- }
- },
- "ok/no-extension": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
-
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
-
- cert, err := newTLSALPNValidationCert(nil, false, true, ch.getValue())
- assert.FatalError(t, err)
-
- srv, tlsDial := newTestTLSALPNServer(cert)
- srv.Start()
-
- return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: tlsDial,
- },
- jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
- },
- },
- res: ch,
- }
- },
- "ok/extension-not-critical": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical"))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
-
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
-
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
- assert.FatalError(t, err)
- expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
-
- cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.getValue())
- assert.FatalError(t, err)
-
- srv, tlsDial := newTestTLSALPNServer(cert)
- srv.Start()
-
- return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: tlsDial,
- },
- jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
- },
- },
- res: ch,
- }
- },
- "ok/extension-malformed": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value"))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
-
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
-
- cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.getValue())
- assert.FatalError(t, err)
-
- srv, tlsDial := newTestTLSALPNServer(cert)
- srv.Start()
-
- return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: tlsDial,
- },
- jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
- },
- },
- res: ch,
- }
- },
- "ok/no-protocol": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
-
- expErr := RejectedIdentifierErr(errors.New("cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge"))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
-
- srv := httptest.NewTLSServer(nil)
-
- return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
- return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
+ return errors.New("force")
},
},
jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
- },
- },
- res: ch,
- }
- },
- "ok/mismatched-token": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
-
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
- assert.FatalError(t, err)
- expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
- incorrectTokenHash := sha256.Sum256([]byte("mismatched"))
-
- expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
- "expected acmeValidationV1 extension value %s for this challenge but got %s",
- hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:])))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
-
- cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.getValue())
- assert.FatalError(t, err)
-
- srv, tlsDial := newTestTLSALPNServer(cert)
- srv.Start()
-
- return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: tlsDial,
- },
- jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
- },
- },
- res: ch,
- }
- },
- "ok/obsolete-oid": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
-
- expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: " +
- "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &tlsALPN01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
-
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
- assert.FatalError(t, err)
- expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
-
- cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.getValue())
- assert.FatalError(t, err)
-
- srv, tlsDial := newTestTLSALPNServer(cert)
- srv.Start()
-
- return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: tlsDial,
- },
- jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, string(newval), string(newb))
- return nil, true, nil
- },
- },
- res: ch,
+ err: NewErrorISE("error updating challenge: force"),
}
},
"ok": func(t *testing.T) test {
- ch, err := newTLSALPNCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*tlsALPN01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Error = MalformedErr(nil).ToACME()
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- baseClone := ch.clone()
- baseClone.Status = StatusValid
- baseClone.Error = nil
- newCh := &tlsALPN01Challenge{baseClone}
+ ch := &Challenge{
+ ID: "chID",
+ Token: "token",
+ Value: fulldomain,
+ Status: StatusPending,
+ }
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
assert.FatalError(t, err)
- expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
-
- cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.getValue())
- assert.FatalError(t, err)
-
- srv, tlsDial := newTestTLSALPNServer(cert)
- srv.Start()
+ h := sha256.Sum256([]byte(expKeyAuth))
+ expected := base64.RawURLEncoding.EncodeToString(h[:])
return test{
- srv: srv,
- ch: ch,
- vo: validateOptions{
- tlsDial: func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) {
- assert.Equals(t, network, "tcp")
- assert.Equals(t, addr, net.JoinHostPort(newCh.getValue(), "443"))
- assert.Equals(t, config.NextProtos, []string{"acme-tls/1"})
- assert.Equals(t, config.ServerName, newCh.getValue())
- assert.True(t, config.InsecureSkipVerify)
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ LookupTxt: func(url string) ([]string, error) {
+ return []string{"foo", expected}, nil
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Status, StatusValid)
- return tlsDial(network, addr, config)
+ assert.Equals(t, updch.Status, StatusValid)
+ assert.Equals(t, updch.Error, nil)
+ va, err := time.Parse(time.RFC3339, updch.ValidatedAt)
+ assert.FatalError(t, err)
+ now := clock.Now()
+ assert.True(t, va.Add(-time.Minute).Before(now))
+ assert.True(t, va.Add(time.Minute).After(now))
+
+ return nil
},
},
jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
-
- alpnCh, err := unmarshalChallenge(newval)
- assert.FatalError(t, err)
- assert.Equals(t, alpnCh.getStatus(), StatusValid)
- assert.True(t, alpnCh.getValidated().Before(time.Now().UTC().Add(time.Minute)))
- assert.True(t, alpnCh.getValidated().After(time.Now().UTC().Add(-1*time.Second)))
-
- baseClone.Validated = alpnCh.getValidated()
-
- return nil, true, nil
- },
- },
- res: newCh,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
-
- if tc.srv != nil {
- defer tc.srv.Close()
- }
-
- if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil {
+ if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
+ switch k := err.(type) {
+ case *Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
}
} else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.res.getID(), ch.getID())
- assert.Equals(t, tc.res.getAccountID(), ch.getAccountID())
- assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID())
- assert.Equals(t, tc.res.getStatus(), ch.getStatus())
- assert.Equals(t, tc.res.getToken(), ch.getToken())
- assert.Equals(t, tc.res.getCreated(), ch.getCreated())
- assert.Equals(t, tc.res.getValidated(), ch.getValidated())
- assert.Equals(t, tc.res.getError(), ch.getError())
- }
+ assert.Nil(t, tc.err)
}
})
}
@@ -1726,268 +1298,939 @@ func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, na
}, nil
}
-func TestDNS01Validate(t *testing.T) {
+func TestTLSALPN01Validate(t *testing.T) {
+ makeTLSCh := func() *Challenge {
+ return &Challenge{
+ ID: "chID",
+ Token: "token",
+ Type: "tls-alpn-01",
+ Status: StatusPending,
+ Value: "zap.internal",
+ }
+ }
type test struct {
- vo validateOptions
- ch challenge
- res challenge
+ vo *ValidateChallengeOptions
+ ch *Challenge
jwk *jose.JSONWebKey
- db nosql.DB
+ db DB
+ srv *httptest.Server
err *Error
}
tests := map[string]func(t *testing.T) test{
- "ok/status-already-valid": func(t *testing.T) test {
- ch, err := newDNSCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*dns01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Status = StatusValid
- return test{
- ch: ch,
- res: ch,
- }
- },
- "ok/status-already-invalid": func(t *testing.T) test {
- ch, err := newDNSCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*dns01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Status = StatusInvalid
- return test{
- ch: ch,
- res: ch,
- }
- },
- "ok/lookup-txt-error": func(t *testing.T) test {
- ch, err := newDNSCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- expErr := DNSErr(errors.Errorf("error looking up TXT records for "+
- "domain %s: force", ch.getValue()))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &dns01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
+ "fail/tlsDial-store-error": func(t *testing.T) test {
+ ch := makeTLSCh()
return test{
ch: ch,
- vo: validateOptions{
- lookupTxt: func(url string) ([]string, error) {
+ vo: &ValidateChallengeOptions{
+ TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force")
},
},
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, newb)
- return nil, true, nil
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, ch.Status)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: force", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
},
},
- res: ch,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
}
},
- "ok/lookup-txt-wildcard": func(t *testing.T) test {
- ch, err := newDNSCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*dns01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Value = "*.zap.internal"
+ "ok/tlsDial-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
+ return nil, errors.New("force")
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, ch.Status)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
- jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
- assert.FatalError(t, err)
+ err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: force", ch.Value)
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
- assert.FatalError(t, err)
- h := sha256.Sum256([]byte(expKeyAuth))
- expected := base64.RawURLEncoding.EncodeToString(h[:])
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ }
+ },
+ "ok/tlsDial-timeout": func(t *testing.T) test {
+ ch := makeTLSCh()
- baseClone := ch.clone()
- baseClone.Status = StatusValid
- baseClone.Error = nil
- newCh := &dns01Challenge{baseClone}
+ srv, tlsDial := newTestTLSALPNServer(nil)
+ // srv.Start() - do not start server to cause timeout
return test{
- ch: ch,
- res: newCh,
- vo: validateOptions{
- lookupTxt: func(url string) ([]string, error) {
- assert.Equals(t, url, "_acme-challenge.zap.internal")
- return []string{"foo", expected}, nil
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, ch.Status)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
},
},
- jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- dnsCh, err := unmarshalChallenge(newval)
- assert.FatalError(t, err)
- assert.Equals(t, dnsCh.getStatus(), StatusValid)
- baseClone.Validated = dnsCh.getValidated()
- return nil, true, nil
+ srv: srv,
+ }
+ },
+ "ok/no-certificates-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
+ return tls.Client(&noopConn{}, config), nil
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "tls-alpn-01 challenge for %v resulted in no certificates", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
},
},
}
},
- "fail/key-authorization-gen-error": func(t *testing.T) test {
- ch, err := newDNSCh()
- assert.FatalError(t, err)
+ "fail/no-certificates-store-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
+ return tls.Client(&noopConn{}, config), nil
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "tls-alpn-01 challenge for %v resulted in no certificates", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
+ },
+ },
+ err: NewErrorISE("failure saving error to acme challenge: force"),
+ }
+ },
+ "ok/error-no-protocol": func(t *testing.T) test {
+ ch := makeTLSCh()
+
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
+
+ srv := httptest.NewTLSServer(nil)
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
+ return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ }
+ },
+ "fail/no-protocol-store-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ srv := httptest.NewTLSServer(nil)
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
+ return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
+ },
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
+ }
+ },
+ "ok/no-names-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ }
+ },
+ "fail/no-names-store-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
+ }
+ },
+ "ok/too-many-names-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value, "other.internal")
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ }
+ },
+ "ok/wrong-name": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, "other.internal")
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ }
+ },
+ "fail/key-auth-gen-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
jwk.Key = "foo"
+
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
return test{
ch: ch,
- vo: validateOptions{
- lookupTxt: func(url string) ([]string, error) {
- return []string{"foo", "bar"}, nil
- },
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
},
+ srv: srv,
jwk: jwk,
- err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")),
+ err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"),
}
},
- "ok/key-auth-mismatch": func(t *testing.T) test {
- ch, err := newDNSCh()
- assert.FatalError(t, err)
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
+ "ok/error-no-extension": func(t *testing.T) test {
+ ch := makeTLSCh()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
+ cert, err := newTLSALPNValidationCert(nil, false, true, ch.Value)
assert.FatalError(t, err)
- expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+
- "expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}))
- baseClone := ch.clone()
- baseClone.Error = expErr.ToACME()
- newCh := &http01Challenge{baseClone}
- newb, err := json.Marshal(newCh)
- assert.FatalError(t, err)
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
return test{
ch: ch,
- vo: validateOptions{
- lookupTxt: func(url string) ([]string, error) {
- return []string{"foo", "bar"}, nil
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
},
},
+ srv: srv,
jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, newb)
- return nil, true, nil
- },
- },
- res: ch,
}
},
- "fail/save-error": func(t *testing.T) test {
- ch, err := newDNSCh()
- assert.FatalError(t, err)
+ "fail/no-extension-store-error": func(t *testing.T) test {
+ ch := makeTLSCh()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
+ cert, err := newTLSALPNValidationCert(nil, false, true, ch.Value)
assert.FatalError(t, err)
- h := sha256.Sum256([]byte(expKeyAuth))
- expected := base64.RawURLEncoding.EncodeToString(h[:])
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
return test{
ch: ch,
- vo: validateOptions{
- lookupTxt: func(url string) ([]string, error) {
- return []string{"foo", expected}, nil
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
},
},
+ srv: srv,
jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
+ err: NewErrorISE("failure saving error to acme challenge: force"),
+ }
+ },
+ "ok/error-extension-not-critical": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical")
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
},
},
- err: ServerInternalErr(errors.New("error saving acme challenge: force")),
+ srv: srv,
+ jwk: jwk,
+ }
+ },
+ "fail/extension-not-critical-store-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical")
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
+ }
+ },
+ "ok/error-malformed-extension": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value")
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ }
+ },
+ "fail/malformed-extension-store-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value")
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
+ }
+ },
+ "ok/error-keyauth-mismatch": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+ incorrectTokenHash := sha256.Sum256([]byte("mismatched"))
+
+ cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+
+ "expected acmeValidationV1 extension value %s for this challenge but got %s",
+ hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:]))
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ }
+ },
+ "fail/keyauth-mismatch-store-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+ incorrectTokenHash := sha256.Sum256([]byte("mismatched"))
+
+ cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+
+ "expected acmeValidationV1 extension value %s for this challenge but got %s",
+ hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:]))
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
+ }
+ },
+ "ok/error-obsolete-oid": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+
+ "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension")
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return nil
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ }
+ },
+ "fail/obsolete-oid-store-error": func(t *testing.T) test {
+ ch := makeTLSCh()
+
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
+ assert.FatalError(t, err)
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
+
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
+
+ return test{
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusInvalid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+
+ err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+
+ "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension")
+
+ assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updch.Error.Type, err.Type)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ assert.Equals(t, updch.Error.Status, err.Status)
+ assert.Equals(t, updch.Error.Detail, err.Detail)
+ return errors.New("force")
+ },
+ },
+ srv: srv,
+ jwk: jwk,
+ err: NewErrorISE("failure saving error to acme challenge: force"),
}
},
"ok": func(t *testing.T) test {
- ch, err := newDNSCh()
- assert.FatalError(t, err)
- _ch, ok := ch.(*dns01Challenge)
- assert.Fatal(t, ok)
- _ch.baseChallenge.Error = MalformedErr(nil).ToACME()
- oldb, err := json.Marshal(ch)
- assert.FatalError(t, err)
+ ch := makeTLSCh()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk)
+ expKeyAuth, err := KeyAuthorization(ch.Token, jwk)
assert.FatalError(t, err)
- h := sha256.Sum256([]byte(expKeyAuth))
- expected := base64.RawURLEncoding.EncodeToString(h[:])
+ expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth))
- baseClone := ch.clone()
- baseClone.Status = StatusValid
- baseClone.Error = nil
- newCh := &dns01Challenge{baseClone}
+ cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value)
+ assert.FatalError(t, err)
+
+ srv, tlsDial := newTestTLSALPNServer(cert)
+ srv.Start()
return test{
- ch: ch,
- res: newCh,
- vo: validateOptions{
- lookupTxt: func(url string) ([]string, error) {
- return []string{"foo", expected}, nil
+ ch: ch,
+ vo: &ValidateChallengeOptions{
+ TLSDial: tlsDial,
+ },
+ db: &MockDB{
+ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
+ assert.Equals(t, updch.ID, ch.ID)
+ assert.Equals(t, updch.Token, ch.Token)
+ assert.Equals(t, updch.Status, StatusValid)
+ assert.Equals(t, updch.Type, ch.Type)
+ assert.Equals(t, updch.Value, ch.Value)
+ assert.Equals(t, updch.Error, nil)
+ return nil
},
},
+ srv: srv,
jwk: jwk,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, challengeTable)
- assert.Equals(t, key, []byte(ch.getID()))
- assert.Equals(t, old, oldb)
-
- dnsCh, err := unmarshalChallenge(newval)
- assert.FatalError(t, err)
- assert.Equals(t, dnsCh.getStatus(), StatusValid)
- assert.True(t, dnsCh.getValidated().Before(time.Now().UTC()))
- assert.True(t, dnsCh.getValidated().After(time.Now().UTC().Add(-1*time.Second)))
-
- baseClone.Validated = dnsCh.getValidated()
-
- return nil, true, nil
- },
- },
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
- if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil {
+
+ if tc.srv != nil {
+ defer tc.srv.Close()
+ }
+
+ if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
+ switch k := err.(type) {
+ case *Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
}
} else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.res.getID(), ch.getID())
- assert.Equals(t, tc.res.getAccountID(), ch.getAccountID())
- assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID())
- assert.Equals(t, tc.res.getStatus(), ch.getStatus())
- assert.Equals(t, tc.res.getToken(), ch.getToken())
- assert.Equals(t, tc.res.getCreated(), ch.getCreated())
- assert.Equals(t, tc.res.getValidated(), ch.getValidated())
- assert.Equals(t, tc.res.getError(), ch.getError())
- }
+ assert.Nil(t, tc.err)
}
})
}
diff --git a/acme/common.go b/acme/common.go
index fec47b94..26552c61 100644
--- a/acme/common.go
+++ b/acme/common.go
@@ -3,19 +3,32 @@ package acme
import (
"context"
"crypto/x509"
- "net/url"
"time"
- "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
- "go.step.sm/crypto/jose"
- "go.step.sm/crypto/randutil"
)
+// CertificateAuthority is the interface implemented by a CA authority.
+type CertificateAuthority interface {
+ Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
+ LoadProvisionerByID(string) (provisioner.Interface, error)
+}
+
+// Clock that returns time in UTC rounded to seconds.
+type Clock struct{}
+
+// Now returns the UTC time rounded to seconds.
+func (c *Clock) Now() time.Time {
+ return time.Now().UTC().Truncate(time.Second)
+}
+
+var clock Clock
+
// Provisioner is an interface that implements a subset of the provisioner.Interface --
// only those methods required by the ACME api/authority.
type Provisioner interface {
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
+ GetID() string
GetName() string
DefaultTLSCertDuration() time.Duration
GetOptions() *provisioner.Options
@@ -25,6 +38,7 @@ type Provisioner interface {
type MockProvisioner struct {
Mret1 interface{}
Merr error
+ MgetID func() string
MgetName func() string
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
MdefaultTLSCertDuration func() time.Duration
@@ -55,6 +69,7 @@ func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration {
return m.Mret1.(time.Duration)
}
+// GetOptions mock
func (m *MockProvisioner) GetOptions() *provisioner.Options {
if m.MgetOptions != nil {
return m.MgetOptions()
@@ -62,120 +77,10 @@ func (m *MockProvisioner) GetOptions() *provisioner.Options {
return m.Mret1.(*provisioner.Options)
}
-// ContextKey is the key type for storing and searching for ACME request
-// essentials in the context of a request.
-type ContextKey string
-
-const (
- // AccContextKey account key
- AccContextKey = ContextKey("acc")
- // BaseURLContextKey baseURL key
- BaseURLContextKey = ContextKey("baseURL")
- // JwsContextKey jws key
- JwsContextKey = ContextKey("jws")
- // JwkContextKey jwk key
- JwkContextKey = ContextKey("jwk")
- // PayloadContextKey payload key
- PayloadContextKey = ContextKey("payload")
- // ProvisionerContextKey provisioner key
- ProvisionerContextKey = ContextKey("provisioner")
-)
-
-// AccountFromContext searches the context for an ACME account. Returns the
-// account or an error.
-func AccountFromContext(ctx context.Context) (*Account, error) {
- val, ok := ctx.Value(AccContextKey).(*Account)
- if !ok || val == nil {
- return nil, AccountDoesNotExistErr(nil)
+// GetID mock
+func (m *MockProvisioner) GetID() string {
+ if m.MgetID != nil {
+ return m.MgetID()
}
- return val, nil
+ return m.Mret1.(string)
}
-
-// BaseURLFromContext returns the baseURL if one is stored in the context.
-func BaseURLFromContext(ctx context.Context) *url.URL {
- val, ok := ctx.Value(BaseURLContextKey).(*url.URL)
- if !ok || val == nil {
- return nil
- }
- return val
-}
-
-// JwkFromContext searches the context for a JWK. Returns the JWK or an error.
-func JwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
- val, ok := ctx.Value(JwkContextKey).(*jose.JSONWebKey)
- if !ok || val == nil {
- return nil, ServerInternalErr(errors.Errorf("jwk expected in request context"))
- }
- return val, nil
-}
-
-// JwsFromContext searches the context for a JWS. Returns the JWS or an error.
-func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
- val, ok := ctx.Value(JwsContextKey).(*jose.JSONWebSignature)
- if !ok || val == nil {
- return nil, ServerInternalErr(errors.Errorf("jws expected in request context"))
- }
- return val, nil
-}
-
-// ProvisionerFromContext searches the context for a provisioner. Returns the
-// provisioner or an error.
-func ProvisionerFromContext(ctx context.Context) (Provisioner, error) {
- val := ctx.Value(ProvisionerContextKey)
- if val == nil {
- return nil, ServerInternalErr(errors.Errorf("provisioner expected in request context"))
- }
- pval, ok := val.(Provisioner)
- if !ok || pval == nil {
- return nil, ServerInternalErr(errors.Errorf("provisioner in context is not an ACME provisioner"))
- }
- return pval, nil
-}
-
-// SignAuthority is the interface implemented by a CA authority.
-type SignAuthority interface {
- Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
- LoadProvisionerByID(string) (provisioner.Interface, error)
-}
-
-// Identifier encodes the type that an order pertains to.
-type Identifier struct {
- Type string `json:"type"`
- Value string `json:"value"`
-}
-
-var (
- // StatusValid -- valid
- StatusValid = "valid"
- // StatusInvalid -- invalid
- StatusInvalid = "invalid"
- // StatusPending -- pending; e.g. an Order that is not ready to be finalized.
- StatusPending = "pending"
- // StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
- StatusDeactivated = "deactivated"
- // StatusReady -- ready; e.g. for an Order that is ready to be finalized.
- StatusReady = "ready"
- //statusExpired = "expired"
- //statusActive = "active"
- //statusProcessing = "processing"
-)
-
-var idLen = 32
-
-func randID() (val string, err error) {
- val, err = randutil.Alphanumeric(idLen)
- if err != nil {
- return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID"))
- }
- return val, nil
-}
-
-// Clock that returns time in UTC rounded to seconds.
-type Clock int
-
-// Now returns the UTC time rounded to seconds.
-func (c *Clock) Now() time.Time {
- return time.Now().UTC().Round(time.Second)
-}
-
-var clock = new(Clock)
diff --git a/acme/db.go b/acme/db.go
new file mode 100644
index 00000000..d678fef4
--- /dev/null
+++ b/acme/db.go
@@ -0,0 +1,251 @@
+package acme
+
+import (
+ "context"
+
+ "github.com/pkg/errors"
+)
+
+// ErrNotFound is an error that should be used by the acme.DB interface to
+// indicate that an entity does not exist. For example, in the new-account
+// endpoint, if GetAccountByKeyID returns ErrNotFound we will create the new
+// account.
+var ErrNotFound = errors.New("not found")
+
+// DB is the DB interface expected by the step-ca ACME API.
+type DB interface {
+ CreateAccount(ctx context.Context, acc *Account) error
+ GetAccount(ctx context.Context, id string) (*Account, error)
+ GetAccountByKeyID(ctx context.Context, kid string) (*Account, error)
+ UpdateAccount(ctx context.Context, acc *Account) error
+
+ CreateNonce(ctx context.Context) (Nonce, error)
+ DeleteNonce(ctx context.Context, nonce Nonce) error
+
+ CreateAuthorization(ctx context.Context, az *Authorization) error
+ GetAuthorization(ctx context.Context, id string) (*Authorization, error)
+ UpdateAuthorization(ctx context.Context, az *Authorization) error
+
+ CreateCertificate(ctx context.Context, cert *Certificate) error
+ GetCertificate(ctx context.Context, id string) (*Certificate, error)
+
+ CreateChallenge(ctx context.Context, ch *Challenge) error
+ GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error)
+ UpdateChallenge(ctx context.Context, ch *Challenge) error
+
+ CreateOrder(ctx context.Context, o *Order) error
+ GetOrder(ctx context.Context, id string) (*Order, error)
+ GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error)
+ UpdateOrder(ctx context.Context, o *Order) error
+}
+
+// MockDB is an implementation of the DB interface that should only be used as
+// a mock in tests.
+type MockDB struct {
+ MockCreateAccount func(ctx context.Context, acc *Account) error
+ MockGetAccount func(ctx context.Context, id string) (*Account, error)
+ MockGetAccountByKeyID func(ctx context.Context, kid string) (*Account, error)
+ MockUpdateAccount func(ctx context.Context, acc *Account) error
+
+ MockCreateNonce func(ctx context.Context) (Nonce, error)
+ MockDeleteNonce func(ctx context.Context, nonce Nonce) error
+
+ MockCreateAuthorization func(ctx context.Context, az *Authorization) error
+ MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error)
+ MockUpdateAuthorization func(ctx context.Context, az *Authorization) error
+
+ MockCreateCertificate func(ctx context.Context, cert *Certificate) error
+ MockGetCertificate func(ctx context.Context, id string) (*Certificate, error)
+
+ MockCreateChallenge func(ctx context.Context, ch *Challenge) error
+ MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error)
+ MockUpdateChallenge func(ctx context.Context, ch *Challenge) error
+
+ MockCreateOrder func(ctx context.Context, o *Order) error
+ MockGetOrder func(ctx context.Context, id string) (*Order, error)
+ MockGetOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error)
+ MockUpdateOrder func(ctx context.Context, o *Order) error
+
+ MockRet1 interface{}
+ MockError error
+}
+
+// CreateAccount mock.
+func (m *MockDB) CreateAccount(ctx context.Context, acc *Account) error {
+ if m.MockCreateAccount != nil {
+ return m.MockCreateAccount(ctx, acc)
+ } else if m.MockError != nil {
+ return m.MockError
+ }
+ return m.MockError
+}
+
+// GetAccount mock.
+func (m *MockDB) GetAccount(ctx context.Context, id string) (*Account, error) {
+ if m.MockGetAccount != nil {
+ return m.MockGetAccount(ctx, id)
+ } else if m.MockError != nil {
+ return nil, m.MockError
+ }
+ return m.MockRet1.(*Account), m.MockError
+}
+
+// GetAccountByKeyID mock
+func (m *MockDB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) {
+ if m.MockGetAccountByKeyID != nil {
+ return m.MockGetAccountByKeyID(ctx, kid)
+ } else if m.MockError != nil {
+ return nil, m.MockError
+ }
+ return m.MockRet1.(*Account), m.MockError
+}
+
+// UpdateAccount mock
+func (m *MockDB) UpdateAccount(ctx context.Context, acc *Account) error {
+ if m.MockUpdateAccount != nil {
+ return m.MockUpdateAccount(ctx, acc)
+ } else if m.MockError != nil {
+ return m.MockError
+ }
+ return m.MockError
+}
+
+// CreateNonce mock
+func (m *MockDB) CreateNonce(ctx context.Context) (Nonce, error) {
+ if m.MockCreateNonce != nil {
+ return m.MockCreateNonce(ctx)
+ } else if m.MockError != nil {
+ return Nonce(""), m.MockError
+ }
+ return m.MockRet1.(Nonce), m.MockError
+}
+
+// DeleteNonce mock
+func (m *MockDB) DeleteNonce(ctx context.Context, nonce Nonce) error {
+ if m.MockDeleteNonce != nil {
+ return m.MockDeleteNonce(ctx, nonce)
+ } else if m.MockError != nil {
+ return m.MockError
+ }
+ return m.MockError
+}
+
+// CreateAuthorization mock
+func (m *MockDB) CreateAuthorization(ctx context.Context, az *Authorization) error {
+ if m.MockCreateAuthorization != nil {
+ return m.MockCreateAuthorization(ctx, az)
+ } else if m.MockError != nil {
+ return m.MockError
+ }
+ return m.MockError
+}
+
+// GetAuthorization mock
+func (m *MockDB) GetAuthorization(ctx context.Context, id string) (*Authorization, error) {
+ if m.MockGetAuthorization != nil {
+ return m.MockGetAuthorization(ctx, id)
+ } else if m.MockError != nil {
+ return nil, m.MockError
+ }
+ return m.MockRet1.(*Authorization), m.MockError
+}
+
+// UpdateAuthorization mock
+func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) error {
+ if m.MockUpdateAuthorization != nil {
+ return m.MockUpdateAuthorization(ctx, az)
+ } else if m.MockError != nil {
+ return m.MockError
+ }
+ return m.MockError
+}
+
+// CreateCertificate mock
+func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error {
+ if m.MockCreateCertificate != nil {
+ return m.MockCreateCertificate(ctx, cert)
+ } else if m.MockError != nil {
+ return m.MockError
+ }
+ return m.MockError
+}
+
+// GetCertificate mock
+func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, error) {
+ if m.MockGetCertificate != nil {
+ return m.MockGetCertificate(ctx, id)
+ } else if m.MockError != nil {
+ return nil, m.MockError
+ }
+ return m.MockRet1.(*Certificate), m.MockError
+}
+
+// CreateChallenge mock
+func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error {
+ if m.MockCreateChallenge != nil {
+ return m.MockCreateChallenge(ctx, ch)
+ } else if m.MockError != nil {
+ return m.MockError
+ }
+ return m.MockError
+}
+
+// GetChallenge mock
+func (m *MockDB) GetChallenge(ctx context.Context, chID, azID string) (*Challenge, error) {
+ if m.MockGetChallenge != nil {
+ return m.MockGetChallenge(ctx, chID, azID)
+ } else if m.MockError != nil {
+ return nil, m.MockError
+ }
+ return m.MockRet1.(*Challenge), m.MockError
+}
+
+// UpdateChallenge mock
+func (m *MockDB) UpdateChallenge(ctx context.Context, ch *Challenge) error {
+ if m.MockUpdateChallenge != nil {
+ return m.MockUpdateChallenge(ctx, ch)
+ } else if m.MockError != nil {
+ return m.MockError
+ }
+ return m.MockError
+}
+
+// CreateOrder mock
+func (m *MockDB) CreateOrder(ctx context.Context, o *Order) error {
+ if m.MockCreateOrder != nil {
+ return m.MockCreateOrder(ctx, o)
+ } else if m.MockError != nil {
+ return m.MockError
+ }
+ return m.MockError
+}
+
+// GetOrder mock
+func (m *MockDB) GetOrder(ctx context.Context, id string) (*Order, error) {
+ if m.MockGetOrder != nil {
+ return m.MockGetOrder(ctx, id)
+ } else if m.MockError != nil {
+ return nil, m.MockError
+ }
+ return m.MockRet1.(*Order), m.MockError
+}
+
+// UpdateOrder mock
+func (m *MockDB) UpdateOrder(ctx context.Context, o *Order) error {
+ if m.MockUpdateOrder != nil {
+ return m.MockUpdateOrder(ctx, o)
+ } else if m.MockError != nil {
+ return m.MockError
+ }
+ return m.MockError
+}
+
+// GetOrdersByAccountID mock
+func (m *MockDB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
+ if m.MockGetOrdersByAccountID != nil {
+ return m.MockGetOrdersByAccountID(ctx, accID)
+ } else if m.MockError != nil {
+ return nil, m.MockError
+ }
+ return m.MockRet1.([]string), m.MockError
+}
diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go
new file mode 100644
index 00000000..1c3bec5d
--- /dev/null
+++ b/acme/db/nosql/account.go
@@ -0,0 +1,136 @@
+package nosql
+
+import (
+ "context"
+ "encoding/json"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/certificates/acme"
+ nosqlDB "github.com/smallstep/nosql"
+ "go.step.sm/crypto/jose"
+)
+
+// dbAccount represents an ACME account.
+type dbAccount struct {
+ ID string `json:"id"`
+ Key *jose.JSONWebKey `json:"key"`
+ Contact []string `json:"contact,omitempty"`
+ Status acme.Status `json:"status"`
+ CreatedAt time.Time `json:"createdAt"`
+ DeactivatedAt time.Time `json:"deactivatedAt"`
+}
+
+func (dba *dbAccount) clone() *dbAccount {
+ nu := *dba
+ return &nu
+}
+
+func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) {
+ id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
+ if err != nil {
+ if nosqlDB.IsErrNotFound(err) {
+ return "", acme.ErrNotFound
+ }
+ return "", errors.Wrapf(err, "error loading key-account index for key %s", kid)
+ }
+ return string(id), nil
+}
+
+// getDBAccount retrieves and unmarshals dbAccount.
+func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) {
+ data, err := db.db.Get(accountTable, []byte(id))
+ if err != nil {
+ if nosqlDB.IsErrNotFound(err) {
+ return nil, acme.ErrNotFound
+ }
+ return nil, errors.Wrapf(err, "error loading account %s", id)
+ }
+
+ dbacc := new(dbAccount)
+ if err = json.Unmarshal(data, dbacc); err != nil {
+ return nil, errors.Wrapf(err, "error unmarshaling account %s into dbAccount", id)
+ }
+ return dbacc, nil
+}
+
+// GetAccount retrieves an ACME account by ID.
+func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
+ dbacc, err := db.getDBAccount(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+
+ return &acme.Account{
+ Status: dbacc.Status,
+ Contact: dbacc.Contact,
+ Key: dbacc.Key,
+ ID: dbacc.ID,
+ }, nil
+}
+
+// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK).
+func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) {
+ id, err := db.getAccountIDByKeyID(ctx, kid)
+ if err != nil {
+ return nil, err
+ }
+ return db.GetAccount(ctx, id)
+}
+
+// CreateAccount imlements the AcmeDB.CreateAccount interface.
+func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
+ var err error
+ acc.ID, err = randID()
+ if err != nil {
+ return err
+ }
+
+ dba := &dbAccount{
+ ID: acc.ID,
+ Key: acc.Key,
+ Contact: acc.Contact,
+ Status: acc.Status,
+ CreatedAt: clock.Now(),
+ }
+
+ kid, err := acme.KeyToID(dba.Key)
+ if err != nil {
+ return err
+ }
+ kidB := []byte(kid)
+
+ // Set the jwkID -> acme account ID index
+ _, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(acc.ID))
+ switch {
+ case err != nil:
+ return errors.Wrap(err, "error storing keyID to accountID index")
+ case !swapped:
+ return errors.Errorf("key-id to account-id index already exists")
+ default:
+ if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil {
+ db.db.Del(accountByKeyIDTable, kidB)
+ return err
+ }
+ return nil
+ }
+}
+
+// UpdateAccount imlements the AcmeDB.UpdateAccount interface.
+func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error {
+ old, err := db.getDBAccount(ctx, acc.ID)
+ if err != nil {
+ return err
+ }
+
+ nu := old.clone()
+ nu.Contact = acc.Contact
+ nu.Status = acc.Status
+
+ // If the status has changed to 'deactivated', then set deactivatedAt timestamp.
+ if acc.Status == acme.StatusDeactivated && old.Status != acme.StatusDeactivated {
+ nu.DeactivatedAt = clock.Now()
+ }
+
+ return db.save(ctx, old.ID, nu, old, "account", accountTable)
+}
diff --git a/acme/db/nosql/account_test.go b/acme/db/nosql/account_test.go
new file mode 100644
index 00000000..5ba99a73
--- /dev/null
+++ b/acme/db/nosql/account_test.go
@@ -0,0 +1,706 @@
+package nosql
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/assert"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/certificates/db"
+ "github.com/smallstep/nosql"
+ nosqldb "github.com/smallstep/nosql/database"
+ "go.step.sm/crypto/jose"
+)
+
+func TestDB_getDBAccount(t *testing.T) {
+ accID := "accID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ dbacc *dbAccount
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/not-found": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, string(key), accID)
+
+ return nil, nosqldb.ErrNotFound
+ },
+ },
+ err: acme.ErrNotFound,
+ }
+ },
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, string(key), accID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading account accID: force"),
+ }
+ },
+ "fail/unmarshal-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, string(key), accID)
+
+ return []byte("foo"), nil
+ },
+ },
+ err: errors.New("error unmarshaling account accID into dbAccount"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ now := clock.Now()
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+ dbacc := &dbAccount{
+ ID: accID,
+ Status: acme.StatusDeactivated,
+ CreatedAt: now,
+ DeactivatedAt: now,
+ Contact: []string{"foo", "bar"},
+ Key: jwk,
+ }
+ b, err := json.Marshal(dbacc)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, string(key), accID)
+
+ return b, nil
+ },
+ },
+ dbacc: dbacc,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, dbacc.ID, tc.dbacc.ID)
+ assert.Equals(t, dbacc.Status, tc.dbacc.Status)
+ assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
+ assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
+ assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
+ assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_getAccountIDByKeyID(t *testing.T) {
+ accID := "accID"
+ kid := "kid"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/not-found": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountByKeyIDTable)
+ assert.Equals(t, string(key), kid)
+
+ return nil, nosqldb.ErrNotFound
+ },
+ },
+ err: acme.ErrNotFound,
+ }
+ },
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountByKeyIDTable)
+ assert.Equals(t, string(key), kid)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading key-account index for key kid: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountByKeyIDTable)
+ assert.Equals(t, string(key), kid)
+
+ return []byte(accID), nil
+ },
+ },
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, retAccID, accID)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_GetAccount(t *testing.T) {
+ accID := "accID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ dbacc *dbAccount
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, string(key), accID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading account accID: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ now := clock.Now()
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+ dbacc := &dbAccount{
+ ID: accID,
+ Status: acme.StatusDeactivated,
+ CreatedAt: now,
+ DeactivatedAt: now,
+ Contact: []string{"foo", "bar"},
+ Key: jwk,
+ }
+ b, err := json.Marshal(dbacc)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, string(key), accID)
+ return b, nil
+ },
+ },
+ dbacc: dbacc,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if acc, err := db.GetAccount(context.Background(), accID); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, acc.ID, tc.dbacc.ID)
+ assert.Equals(t, acc.Status, tc.dbacc.Status)
+ assert.Equals(t, acc.Contact, tc.dbacc.Contact)
+ assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_GetAccountByKeyID(t *testing.T) {
+ accID := "accID"
+ kid := "kid"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ dbacc *dbAccount
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/db.getAccountIDByKeyID-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, string(bucket), string(accountByKeyIDTable))
+ assert.Equals(t, string(key), kid)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading key-account index for key kid: force"),
+ }
+ },
+ "fail/db.GetAccount-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ switch string(bucket) {
+ case string(accountByKeyIDTable):
+ assert.Equals(t, string(key), kid)
+ return []byte(accID), nil
+ case string(accountTable):
+ assert.Equals(t, string(key), accID)
+ return nil, errors.New("force")
+ default:
+ assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
+ return nil, errors.New("force")
+ }
+ },
+ },
+ err: errors.New("error loading account accID: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ now := clock.Now()
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+ dbacc := &dbAccount{
+ ID: accID,
+ Status: acme.StatusDeactivated,
+ CreatedAt: now,
+ DeactivatedAt: now,
+ Contact: []string{"foo", "bar"},
+ Key: jwk,
+ }
+ b, err := json.Marshal(dbacc)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ switch string(bucket) {
+ case string(accountByKeyIDTable):
+ assert.Equals(t, string(key), kid)
+ return []byte(accID), nil
+ case string(accountTable):
+ assert.Equals(t, string(key), accID)
+ return b, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
+ return nil, errors.New("force")
+ }
+ },
+ },
+ dbacc: dbacc,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, acc.ID, tc.dbacc.ID)
+ assert.Equals(t, acc.Status, tc.dbacc.Status)
+ assert.Equals(t, acc.Contact, tc.dbacc.Contact)
+ assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_CreateAccount(t *testing.T) {
+ type test struct {
+ db nosql.DB
+ acc *acme.Account
+ err error
+ _id *string
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/keyID-cmpAndSwap-error": func(t *testing.T) test {
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+ acc := &acme.Account{
+ Status: acme.StatusValid,
+ Contact: []string{"foo", "bar"},
+ Key: jwk,
+ }
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, accountByKeyIDTable)
+ assert.Equals(t, string(key), jwk.KeyID)
+ assert.Equals(t, old, nil)
+
+ assert.Equals(t, nu, []byte(acc.ID))
+ return nil, false, errors.New("force")
+ },
+ },
+ acc: acc,
+ err: errors.New("error storing keyID to accountID index: force"),
+ }
+ },
+ "fail/keyID-cmpAndSwap-false": func(t *testing.T) test {
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+ acc := &acme.Account{
+ Status: acme.StatusValid,
+ Contact: []string{"foo", "bar"},
+ Key: jwk,
+ }
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, accountByKeyIDTable)
+ assert.Equals(t, string(key), jwk.KeyID)
+ assert.Equals(t, old, nil)
+
+ assert.Equals(t, nu, []byte(acc.ID))
+ return nil, false, nil
+ },
+ },
+ acc: acc,
+ err: errors.New("key-id to account-id index already exists"),
+ }
+ },
+ "fail/account-save-error": func(t *testing.T) test {
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+ acc := &acme.Account{
+ Status: acme.StatusValid,
+ Contact: []string{"foo", "bar"},
+ Key: jwk,
+ }
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ switch string(bucket) {
+ case string(accountByKeyIDTable):
+ assert.Equals(t, string(key), jwk.KeyID)
+ assert.Equals(t, old, nil)
+ return nu, true, nil
+ case string(accountTable):
+ assert.Equals(t, string(key), acc.ID)
+ assert.Equals(t, old, nil)
+
+ dbacc := new(dbAccount)
+ assert.FatalError(t, json.Unmarshal(nu, dbacc))
+ assert.Equals(t, dbacc.ID, string(key))
+ assert.Equals(t, dbacc.Contact, acc.Contact)
+ assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
+ assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
+ assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
+ assert.True(t, dbacc.DeactivatedAt.IsZero())
+ return nil, false, errors.New("force")
+ default:
+ assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
+ return nil, false, errors.New("force")
+ }
+ },
+ },
+ acc: acc,
+ err: errors.New("error saving acme account: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ var (
+ id string
+ idPtr = &id
+ )
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+ acc := &acme.Account{
+ Status: acme.StatusValid,
+ Contact: []string{"foo", "bar"},
+ Key: jwk,
+ }
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ id = string(key)
+ switch string(bucket) {
+ case string(accountByKeyIDTable):
+ assert.Equals(t, string(key), jwk.KeyID)
+ assert.Equals(t, old, nil)
+ return nu, true, nil
+ case string(accountTable):
+ assert.Equals(t, string(key), acc.ID)
+ assert.Equals(t, old, nil)
+
+ dbacc := new(dbAccount)
+ assert.FatalError(t, json.Unmarshal(nu, dbacc))
+ assert.Equals(t, dbacc.ID, string(key))
+ assert.Equals(t, dbacc.Contact, acc.Contact)
+ assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
+ assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
+ assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
+ assert.True(t, dbacc.DeactivatedAt.IsZero())
+ return nu, true, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
+ return nil, false, errors.New("force")
+ }
+ },
+ },
+ acc: acc,
+ _id: idPtr,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if err := db.CreateAccount(context.Background(), tc.acc); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, tc.acc.ID, *tc._id)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_UpdateAccount(t *testing.T) {
+ accID := "accID"
+ now := clock.Now()
+ jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
+ assert.FatalError(t, err)
+ dbacc := &dbAccount{
+ ID: accID,
+ Status: acme.StatusDeactivated,
+ CreatedAt: now,
+ DeactivatedAt: now,
+ Contact: []string{"foo", "bar"},
+ Key: jwk,
+ }
+ b, err := json.Marshal(dbacc)
+ assert.FatalError(t, err)
+ type test struct {
+ db nosql.DB
+ acc *acme.Account
+ err error
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ acc: &acme.Account{
+ ID: accID,
+ },
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, string(key), accID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading account accID: force"),
+ }
+ },
+ "fail/already-deactivated": func(t *testing.T) test {
+ clone := dbacc.clone()
+ clone.Status = acme.StatusDeactivated
+ clone.DeactivatedAt = now
+ dbaccb, err := json.Marshal(clone)
+ assert.FatalError(t, err)
+ acc := &acme.Account{
+ ID: accID,
+ Status: acme.StatusDeactivated,
+ Contact: []string{"foo", "bar"},
+ }
+ return test{
+ acc: acc,
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, string(key), accID)
+
+ return dbaccb, nil
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, old, b)
+
+ dbNew := new(dbAccount)
+ assert.FatalError(t, json.Unmarshal(nu, dbNew))
+ assert.Equals(t, dbNew.ID, clone.ID)
+ assert.Equals(t, dbNew.Status, clone.Status)
+ assert.Equals(t, dbNew.Contact, clone.Contact)
+ assert.Equals(t, dbNew.Key.KeyID, clone.Key.KeyID)
+ assert.Equals(t, dbNew.CreatedAt, clone.CreatedAt)
+ assert.Equals(t, dbNew.DeactivatedAt, clone.DeactivatedAt)
+ return nil, false, errors.New("force")
+ },
+ },
+ err: errors.New("error saving acme account: force"),
+ }
+ },
+ "fail/db.CmpAndSwap-error": func(t *testing.T) test {
+ acc := &acme.Account{
+ ID: accID,
+ Status: acme.StatusDeactivated,
+ Contact: []string{"foo", "bar"},
+ }
+ return test{
+ acc: acc,
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, string(key), accID)
+
+ return b, nil
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, old, b)
+
+ dbNew := new(dbAccount)
+ assert.FatalError(t, json.Unmarshal(nu, dbNew))
+ assert.Equals(t, dbNew.ID, dbacc.ID)
+ assert.Equals(t, dbNew.Status, acc.Status)
+ assert.Equals(t, dbNew.Contact, dbacc.Contact)
+ assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
+ assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
+ assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
+ assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now))
+ return nil, false, errors.New("force")
+ },
+ },
+ err: errors.New("error saving acme account: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ acc := &acme.Account{
+ ID: accID,
+ Status: acme.StatusDeactivated,
+ Contact: []string{"foo", "bar"},
+ Key: jwk,
+ }
+ return test{
+ acc: acc,
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, string(key), accID)
+
+ return b, nil
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, accountTable)
+ assert.Equals(t, old, b)
+
+ dbNew := new(dbAccount)
+ assert.FatalError(t, json.Unmarshal(nu, dbNew))
+ assert.Equals(t, dbNew.ID, dbacc.ID)
+ assert.Equals(t, dbNew.Status, acc.Status)
+ assert.Equals(t, dbNew.Contact, dbacc.Contact)
+ assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
+ assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
+ assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
+ assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now))
+ return nu, true, nil
+ },
+ },
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if err := db.UpdateAccount(context.Background(), tc.acc); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, tc.acc.ID, dbacc.ID)
+ assert.Equals(t, tc.acc.Status, dbacc.Status)
+ assert.Equals(t, tc.acc.Contact, dbacc.Contact)
+ assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID)
+ }
+ }
+ })
+ }
+}
diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go
new file mode 100644
index 00000000..6decbe4f
--- /dev/null
+++ b/acme/db/nosql/authz.go
@@ -0,0 +1,118 @@
+package nosql
+
+import (
+ "context"
+ "encoding/json"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/nosql"
+)
+
+// dbAuthz is the base authz type that others build from.
+type dbAuthz struct {
+ ID string `json:"id"`
+ AccountID string `json:"accountID"`
+ Identifier acme.Identifier `json:"identifier"`
+ Status acme.Status `json:"status"`
+ Token string `json:"token"`
+ ChallengeIDs []string `json:"challengeIDs"`
+ Wildcard bool `json:"wildcard"`
+ CreatedAt time.Time `json:"createdAt"`
+ ExpiresAt time.Time `json:"expiresAt"`
+ Error *acme.Error `json:"error"`
+}
+
+func (ba *dbAuthz) clone() *dbAuthz {
+ u := *ba
+ return &u
+}
+
+// getDBAuthz retrieves and unmarshals a database representation of the
+// ACME Authorization type.
+func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) {
+ data, err := db.db.Get(authzTable, []byte(id))
+ if nosql.IsErrNotFound(err) {
+ return nil, acme.NewError(acme.ErrorMalformedType, "authz %s not found", id)
+ } else if err != nil {
+ return nil, errors.Wrapf(err, "error loading authz %s", id)
+ }
+
+ var dbaz dbAuthz
+ if err = json.Unmarshal(data, &dbaz); err != nil {
+ return nil, errors.Wrapf(err, "error unmarshaling authz %s into dbAuthz", id)
+ }
+ return &dbaz, nil
+}
+
+// GetAuthorization retrieves and unmarshals an ACME authz type from the database.
+// Implements acme.DB GetAuthorization interface.
+func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorization, error) {
+ dbaz, err := db.getDBAuthz(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ var chs = make([]*acme.Challenge, len(dbaz.ChallengeIDs))
+ for i, chID := range dbaz.ChallengeIDs {
+ chs[i], err = db.GetChallenge(ctx, chID, id)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return &acme.Authorization{
+ ID: dbaz.ID,
+ AccountID: dbaz.AccountID,
+ Identifier: dbaz.Identifier,
+ Status: dbaz.Status,
+ Challenges: chs,
+ Wildcard: dbaz.Wildcard,
+ ExpiresAt: dbaz.ExpiresAt,
+ Token: dbaz.Token,
+ Error: dbaz.Error,
+ }, nil
+}
+
+// CreateAuthorization creates an entry in the database for the Authorization.
+// Implements the acme.DB.CreateAuthorization interface.
+func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) error {
+ var err error
+ az.ID, err = randID()
+ if err != nil {
+ return err
+ }
+
+ chIDs := make([]string, len(az.Challenges))
+ for i, ch := range az.Challenges {
+ chIDs[i] = ch.ID
+ }
+
+ now := clock.Now()
+ dbaz := &dbAuthz{
+ ID: az.ID,
+ AccountID: az.AccountID,
+ Status: az.Status,
+ CreatedAt: now,
+ ExpiresAt: az.ExpiresAt,
+ Identifier: az.Identifier,
+ ChallengeIDs: chIDs,
+ Token: az.Token,
+ Wildcard: az.Wildcard,
+ }
+
+ return db.save(ctx, az.ID, dbaz, nil, "authz", authzTable)
+}
+
+// UpdateAuthorization saves an updated ACME Authorization to the database.
+func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) error {
+ old, err := db.getDBAuthz(ctx, az.ID)
+ if err != nil {
+ return err
+ }
+
+ nu := old.clone()
+
+ nu.Status = az.Status
+ nu.Error = az.Error
+ return db.save(ctx, old.ID, nu, old, "authz", authzTable)
+}
diff --git a/acme/db/nosql/authz_test.go b/acme/db/nosql/authz_test.go
new file mode 100644
index 00000000..0c2cec50
--- /dev/null
+++ b/acme/db/nosql/authz_test.go
@@ -0,0 +1,620 @@
+package nosql
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/assert"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/certificates/db"
+ "github.com/smallstep/nosql"
+ nosqldb "github.com/smallstep/nosql/database"
+)
+
+func TestDB_getDBAuthz(t *testing.T) {
+ azID := "azID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ dbaz *dbAuthz
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/not-found": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), azID)
+
+ return nil, nosqldb.ErrNotFound
+ },
+ },
+ acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"),
+ }
+ },
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), azID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading authz azID: force"),
+ }
+ },
+ "fail/unmarshal-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), azID)
+
+ return []byte("foo"), nil
+ },
+ },
+ err: errors.New("error unmarshaling authz azID into dbAuthz"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ now := clock.Now()
+ dbaz := &dbAuthz{
+ ID: azID,
+ AccountID: "accountID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "test.ca.smallstep.com",
+ },
+ Status: acme.StatusPending,
+ Token: "token",
+ CreatedAt: now,
+ ExpiresAt: now.Add(5 * time.Minute),
+ Error: acme.NewErrorISE("force"),
+ ChallengeIDs: []string{"foo", "bar"},
+ Wildcard: true,
+ }
+ b, err := json.Marshal(dbaz)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), azID)
+
+ return b, nil
+ },
+ },
+ dbaz: dbaz,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, dbaz.ID, tc.dbaz.ID)
+ assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
+ assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
+ assert.Equals(t, dbaz.Status, tc.dbaz.Status)
+ assert.Equals(t, dbaz.Token, tc.dbaz.Token)
+ assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt)
+ assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt)
+ assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
+ assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_GetAuthorization(t *testing.T) {
+ azID := "azID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ dbaz *dbAuthz
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), azID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading authz azID: force"),
+ }
+ },
+ "fail/forward-acme-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), azID)
+
+ return nil, nosqldb.ErrNotFound
+ },
+ },
+ acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"),
+ }
+ },
+ "fail/db.GetChallenge-error": func(t *testing.T) test {
+ now := clock.Now()
+ dbaz := &dbAuthz{
+ ID: azID,
+ AccountID: "accountID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "test.ca.smallstep.com",
+ },
+ Status: acme.StatusPending,
+ Token: "token",
+ CreatedAt: now,
+ ExpiresAt: now.Add(5 * time.Minute),
+ Error: acme.NewErrorISE("force"),
+ ChallengeIDs: []string{"foo", "bar"},
+ Wildcard: true,
+ }
+ b, err := json.Marshal(dbaz)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ switch string(bucket) {
+ case string(authzTable):
+ assert.Equals(t, string(key), azID)
+ return b, nil
+ case string(challengeTable):
+ assert.Equals(t, string(key), "foo")
+ return nil, errors.New("force")
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
+ return nil, errors.New("force")
+ }
+ },
+ },
+ err: errors.New("error loading acme challenge foo: force"),
+ }
+ },
+ "fail/db.GetChallenge-not-found": func(t *testing.T) test {
+ now := clock.Now()
+ dbaz := &dbAuthz{
+ ID: azID,
+ AccountID: "accountID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "test.ca.smallstep.com",
+ },
+ Status: acme.StatusPending,
+ Token: "token",
+ CreatedAt: now,
+ ExpiresAt: now.Add(5 * time.Minute),
+ Error: acme.NewErrorISE("force"),
+ ChallengeIDs: []string{"foo", "bar"},
+ Wildcard: true,
+ }
+ b, err := json.Marshal(dbaz)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ switch string(bucket) {
+ case string(authzTable):
+ assert.Equals(t, string(key), azID)
+ return b, nil
+ case string(challengeTable):
+ assert.Equals(t, string(key), "foo")
+ return nil, nosqldb.ErrNotFound
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
+ return nil, errors.New("force")
+ }
+ },
+ },
+ acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge foo not found"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ now := clock.Now()
+ dbaz := &dbAuthz{
+ ID: azID,
+ AccountID: "accountID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "test.ca.smallstep.com",
+ },
+ Status: acme.StatusPending,
+ Token: "token",
+ CreatedAt: now,
+ ExpiresAt: now.Add(5 * time.Minute),
+ Error: acme.NewErrorISE("force"),
+ ChallengeIDs: []string{"foo", "bar"},
+ Wildcard: true,
+ }
+ b, err := json.Marshal(dbaz)
+ assert.FatalError(t, err)
+ chCount := 0
+ fooChb, err := json.Marshal(&dbChallenge{ID: "foo"})
+ assert.FatalError(t, err)
+ barChb, err := json.Marshal(&dbChallenge{ID: "bar"})
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ switch string(bucket) {
+ case string(authzTable):
+ assert.Equals(t, string(key), azID)
+ return b, nil
+ case string(challengeTable):
+ if chCount == 0 {
+ chCount++
+ assert.Equals(t, string(key), "foo")
+ return fooChb, nil
+ }
+ assert.Equals(t, string(key), "bar")
+ return barChb, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket)))
+ return nil, errors.New("force")
+ }
+ },
+ },
+ dbaz: dbaz,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if az, err := db.GetAuthorization(context.Background(), azID); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, az.ID, tc.dbaz.ID)
+ assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
+ assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
+ assert.Equals(t, az.Status, tc.dbaz.Status)
+ assert.Equals(t, az.Token, tc.dbaz.Token)
+ assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard)
+ assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt)
+ assert.Equals(t, az.Challenges, []*acme.Challenge{
+ {ID: "foo"},
+ {ID: "bar"},
+ })
+ assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
+ }
+ }
+ })
+ }
+}
+
+func TestDB_CreateAuthorization(t *testing.T) {
+ azID := "azID"
+ type test struct {
+ db nosql.DB
+ az *acme.Authorization
+ err error
+ _id *string
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/cmpAndSwap-error": func(t *testing.T) test {
+ now := clock.Now()
+ az := &acme.Authorization{
+ ID: azID,
+ AccountID: "accountID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "test.ca.smallstep.com",
+ },
+ Status: acme.StatusPending,
+ Token: "token",
+ ExpiresAt: now.Add(5 * time.Minute),
+ Challenges: []*acme.Challenge{
+ {ID: "foo"},
+ {ID: "bar"},
+ },
+ Wildcard: true,
+ Error: acme.NewErrorISE("force"),
+ }
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), az.ID)
+ assert.Equals(t, old, nil)
+
+ dbaz := new(dbAuthz)
+ assert.FatalError(t, json.Unmarshal(nu, dbaz))
+ assert.Equals(t, dbaz.ID, string(key))
+ assert.Equals(t, dbaz.AccountID, az.AccountID)
+ assert.Equals(t, dbaz.Identifier, acme.Identifier{
+ Type: "dns",
+ Value: "test.ca.smallstep.com",
+ })
+ assert.Equals(t, dbaz.Status, az.Status)
+ assert.Equals(t, dbaz.Token, az.Token)
+ assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"})
+ assert.Equals(t, dbaz.Wildcard, az.Wildcard)
+ assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt)
+ assert.Nil(t, dbaz.Error)
+ assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt))
+ assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt))
+ return nil, false, errors.New("force")
+ },
+ },
+ az: az,
+ err: errors.New("error saving acme authz: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ var (
+ id string
+ idPtr = &id
+ now = clock.Now()
+ az = &acme.Authorization{
+ ID: azID,
+ AccountID: "accountID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "test.ca.smallstep.com",
+ },
+ Status: acme.StatusPending,
+ Token: "token",
+ ExpiresAt: now.Add(5 * time.Minute),
+ Challenges: []*acme.Challenge{
+ {ID: "foo"},
+ {ID: "bar"},
+ },
+ Wildcard: true,
+ Error: acme.NewErrorISE("force"),
+ }
+ )
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ *idPtr = string(key)
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), az.ID)
+ assert.Equals(t, old, nil)
+
+ dbaz := new(dbAuthz)
+ assert.FatalError(t, json.Unmarshal(nu, dbaz))
+ assert.Equals(t, dbaz.ID, string(key))
+ assert.Equals(t, dbaz.AccountID, az.AccountID)
+ assert.Equals(t, dbaz.Identifier, acme.Identifier{
+ Type: "dns",
+ Value: "test.ca.smallstep.com",
+ })
+ assert.Equals(t, dbaz.Status, az.Status)
+ assert.Equals(t, dbaz.Token, az.Token)
+ assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"})
+ assert.Equals(t, dbaz.Wildcard, az.Wildcard)
+ assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt)
+ assert.Nil(t, dbaz.Error)
+ assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt))
+ assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt))
+ return nu, true, nil
+ },
+ },
+ az: az,
+ _id: idPtr,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if err := db.CreateAuthorization(context.Background(), tc.az); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, tc.az.ID, *tc._id)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_UpdateAuthorization(t *testing.T) {
+ azID := "azID"
+ now := clock.Now()
+ dbaz := &dbAuthz{
+ ID: azID,
+ AccountID: "accountID",
+ Identifier: acme.Identifier{
+ Type: "dns",
+ Value: "test.ca.smallstep.com",
+ },
+ Status: acme.StatusPending,
+ Token: "token",
+ CreatedAt: now,
+ ExpiresAt: now.Add(5 * time.Minute),
+ ChallengeIDs: []string{"foo", "bar"},
+ Wildcard: true,
+ }
+ b, err := json.Marshal(dbaz)
+ assert.FatalError(t, err)
+ type test struct {
+ db nosql.DB
+ az *acme.Authorization
+ err error
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ az: &acme.Authorization{
+ ID: azID,
+ },
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), azID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading authz azID: force"),
+ }
+ },
+ "fail/db.CmpAndSwap-error": func(t *testing.T) test {
+ updAz := &acme.Authorization{
+ ID: azID,
+ Status: acme.StatusValid,
+ Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
+ }
+ return test{
+ az: updAz,
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), azID)
+
+ return b, nil
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, old, b)
+
+ dbOld := new(dbAuthz)
+ assert.FatalError(t, json.Unmarshal(old, dbOld))
+ assert.Equals(t, dbaz, dbOld)
+
+ dbNew := new(dbAuthz)
+ assert.FatalError(t, json.Unmarshal(nu, dbNew))
+ assert.Equals(t, dbNew.ID, dbaz.ID)
+ assert.Equals(t, dbNew.AccountID, dbaz.AccountID)
+ assert.Equals(t, dbNew.Identifier, dbaz.Identifier)
+ assert.Equals(t, dbNew.Status, acme.StatusValid)
+ assert.Equals(t, dbNew.Token, dbaz.Token)
+ assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs)
+ assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
+ assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
+ assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
+ assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
+ return nil, false, errors.New("force")
+ },
+ },
+ err: errors.New("error saving acme authz: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ updAz := &acme.Authorization{
+ ID: azID,
+ AccountID: dbaz.AccountID,
+ Status: acme.StatusValid,
+ Identifier: dbaz.Identifier,
+ Challenges: []*acme.Challenge{
+ {ID: "foo"},
+ {ID: "bar"},
+ },
+ Token: dbaz.Token,
+ Wildcard: dbaz.Wildcard,
+ ExpiresAt: dbaz.ExpiresAt,
+ Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
+ }
+ return test{
+ az: updAz,
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, string(key), azID)
+
+ return b, nil
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, authzTable)
+ assert.Equals(t, old, b)
+
+ dbOld := new(dbAuthz)
+ assert.FatalError(t, json.Unmarshal(old, dbOld))
+ assert.Equals(t, dbaz, dbOld)
+
+ dbNew := new(dbAuthz)
+ assert.FatalError(t, json.Unmarshal(nu, dbNew))
+ assert.Equals(t, dbNew.ID, dbaz.ID)
+ assert.Equals(t, dbNew.AccountID, dbaz.AccountID)
+ assert.Equals(t, dbNew.Identifier, dbaz.Identifier)
+ assert.Equals(t, dbNew.Status, acme.StatusValid)
+ assert.Equals(t, dbNew.Token, dbaz.Token)
+ assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs)
+ assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
+ assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
+ assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
+ assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
+ return nu, true, nil
+ },
+ },
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, tc.az.ID, dbaz.ID)
+ assert.Equals(t, tc.az.AccountID, dbaz.AccountID)
+ assert.Equals(t, tc.az.Identifier, dbaz.Identifier)
+ assert.Equals(t, tc.az.Status, acme.StatusValid)
+ assert.Equals(t, tc.az.Wildcard, dbaz.Wildcard)
+ assert.Equals(t, tc.az.Token, dbaz.Token)
+ assert.Equals(t, tc.az.ExpiresAt, dbaz.ExpiresAt)
+ assert.Equals(t, tc.az.Challenges, []*acme.Challenge{
+ {ID: "foo"},
+ {ID: "bar"},
+ })
+ assert.Equals(t, tc.az.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
+ }
+ }
+ })
+ }
+}
diff --git a/acme/db/nosql/certificate.go b/acme/db/nosql/certificate.go
new file mode 100644
index 00000000..d3e15833
--- /dev/null
+++ b/acme/db/nosql/certificate.go
@@ -0,0 +1,109 @@
+package nosql
+
+import (
+ "context"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/pem"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/nosql"
+)
+
+type dbCert struct {
+ ID string `json:"id"`
+ CreatedAt time.Time `json:"createdAt"`
+ AccountID string `json:"accountID"`
+ OrderID string `json:"orderID"`
+ Leaf []byte `json:"leaf"`
+ Intermediates []byte `json:"intermediates"`
+}
+
+// CreateCertificate creates and stores an ACME certificate type.
+func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error {
+ var err error
+ cert.ID, err = randID()
+ if err != nil {
+ return err
+ }
+
+ leaf := pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: cert.Leaf.Raw,
+ })
+ var intermediates []byte
+ for _, cert := range cert.Intermediates {
+ intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: cert.Raw,
+ })...)
+ }
+
+ dbch := &dbCert{
+ ID: cert.ID,
+ AccountID: cert.AccountID,
+ OrderID: cert.OrderID,
+ Leaf: leaf,
+ Intermediates: intermediates,
+ CreatedAt: time.Now().UTC(),
+ }
+ return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable)
+}
+
+// GetCertificate retrieves and unmarshals an ACME certificate type from the
+// datastore.
+func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) {
+ b, err := db.db.Get(certTable, []byte(id))
+ if nosql.IsErrNotFound(err) {
+ return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id)
+ } else if err != nil {
+ return nil, errors.Wrapf(err, "error loading certificate %s", id)
+ }
+ dbC := new(dbCert)
+ if err := json.Unmarshal(b, dbC); err != nil {
+ return nil, errors.Wrapf(err, "error unmarshaling certificate %s", id)
+ }
+
+ certs, err := parseBundle(append(dbC.Leaf, dbC.Intermediates...))
+ if err != nil {
+ return nil, errors.Wrapf(err, "error parsing certificate chain for ACME certificate with ID %s", id)
+ }
+
+ return &acme.Certificate{
+ ID: dbC.ID,
+ AccountID: dbC.AccountID,
+ OrderID: dbC.OrderID,
+ Leaf: certs[0],
+ Intermediates: certs[1:],
+ }, nil
+}
+
+func parseBundle(b []byte) ([]*x509.Certificate, error) {
+ var (
+ err error
+ block *pem.Block
+ bundle []*x509.Certificate
+ )
+ for len(b) > 0 {
+ block, b = pem.Decode(b)
+ if block == nil {
+ break
+ }
+ if block.Type != "CERTIFICATE" {
+ return nil, errors.New("error decoding PEM: data contains block that is not a certificate")
+ }
+ var crt *x509.Certificate
+ crt, err = x509.ParseCertificate(block.Bytes)
+ if err != nil {
+ return nil, errors.Wrapf(err, "error parsing x509 certificate")
+ }
+ bundle = append(bundle, crt)
+ }
+ if len(b) > 0 {
+ return nil, errors.New("error decoding PEM: unexpected data")
+ }
+ return bundle, nil
+
+}
diff --git a/acme/db/nosql/certificate_test.go b/acme/db/nosql/certificate_test.go
new file mode 100644
index 00000000..4ec4589e
--- /dev/null
+++ b/acme/db/nosql/certificate_test.go
@@ -0,0 +1,321 @@
+package nosql
+
+import (
+ "context"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/pem"
+ "testing"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/assert"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/certificates/db"
+ "github.com/smallstep/nosql"
+ nosqldb "github.com/smallstep/nosql/database"
+
+ "go.step.sm/crypto/pemutil"
+)
+
+func TestDB_CreateCertificate(t *testing.T) {
+ leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
+ assert.FatalError(t, err)
+ inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
+ assert.FatalError(t, err)
+ root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
+ assert.FatalError(t, err)
+ type test struct {
+ db nosql.DB
+ cert *acme.Certificate
+ err error
+ _id *string
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/cmpAndSwap-error": func(t *testing.T) test {
+ cert := &acme.Certificate{
+ AccountID: "accountID",
+ OrderID: "orderID",
+ Leaf: leaf,
+ Intermediates: []*x509.Certificate{inter, root},
+ }
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, certTable)
+ assert.Equals(t, key, []byte(cert.ID))
+ assert.Equals(t, old, nil)
+
+ dbc := new(dbCert)
+ assert.FatalError(t, json.Unmarshal(nu, dbc))
+ assert.Equals(t, dbc.ID, string(key))
+ assert.Equals(t, dbc.ID, cert.ID)
+ assert.Equals(t, dbc.AccountID, cert.AccountID)
+ assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
+ assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
+ return nil, false, errors.New("force")
+ },
+ },
+ cert: cert,
+ err: errors.New("error saving acme certificate: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ cert := &acme.Certificate{
+ AccountID: "accountID",
+ OrderID: "orderID",
+ Leaf: leaf,
+ Intermediates: []*x509.Certificate{inter, root},
+ }
+ var (
+ id string
+ idPtr = &id
+ )
+
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ *idPtr = string(key)
+ assert.Equals(t, bucket, certTable)
+ assert.Equals(t, key, []byte(cert.ID))
+ assert.Equals(t, old, nil)
+
+ dbc := new(dbCert)
+ assert.FatalError(t, json.Unmarshal(nu, dbc))
+ assert.Equals(t, dbc.ID, string(key))
+ assert.Equals(t, dbc.ID, cert.ID)
+ assert.Equals(t, dbc.AccountID, cert.AccountID)
+ assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
+ assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
+ return nil, true, nil
+ },
+ },
+ _id: idPtr,
+ cert: cert,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if err := db.CreateCertificate(context.Background(), tc.cert); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, tc.cert.ID, *tc._id)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_GetCertificate(t *testing.T) {
+ leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
+ assert.FatalError(t, err)
+ inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
+ assert.FatalError(t, err)
+ root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
+ assert.FatalError(t, err)
+
+ certID := "certID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/not-found": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, certTable)
+ assert.Equals(t, string(key), certID)
+
+ return nil, nosqldb.ErrNotFound
+ },
+ },
+ acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate certID not found"),
+ }
+ },
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, certTable)
+ assert.Equals(t, string(key), certID)
+
+ return nil, errors.Errorf("force")
+ },
+ },
+ err: errors.New("error loading certificate certID: force"),
+ }
+ },
+ "fail/unmarshal-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, certTable)
+ assert.Equals(t, string(key), certID)
+
+ return []byte("foobar"), nil
+ },
+ },
+ err: errors.New("error unmarshaling certificate certID"),
+ }
+ },
+ "fail/parseBundle-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, certTable)
+ assert.Equals(t, string(key), certID)
+
+ cert := dbCert{
+ ID: certID,
+ AccountID: "accountID",
+ OrderID: "orderID",
+ Leaf: pem.EncodeToMemory(&pem.Block{
+ Type: "Public Key",
+ Bytes: leaf.Raw,
+ }),
+ CreatedAt: clock.Now(),
+ }
+ b, err := json.Marshal(cert)
+ assert.FatalError(t, err)
+
+ return b, nil
+ },
+ },
+ err: errors.Errorf("error parsing certificate chain for ACME certificate with ID certID"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, certTable)
+ assert.Equals(t, string(key), certID)
+
+ cert := dbCert{
+ ID: certID,
+ AccountID: "accountID",
+ OrderID: "orderID",
+ Leaf: pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: leaf.Raw,
+ }),
+ Intermediates: append(pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: inter.Raw,
+ }), pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: root.Raw,
+ })...),
+ CreatedAt: clock.Now(),
+ }
+ b, err := json.Marshal(cert)
+ assert.FatalError(t, err)
+
+ return b, nil
+ },
+ },
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ cert, err := db.GetCertificate(context.Background(), certID)
+ if err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, cert.ID, certID)
+ assert.Equals(t, cert.AccountID, "accountID")
+ assert.Equals(t, cert.OrderID, "orderID")
+ assert.Equals(t, cert.Leaf, leaf)
+ assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
+ }
+ }
+ })
+ }
+}
+
+func Test_parseBundle(t *testing.T) {
+ leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
+ assert.FatalError(t, err)
+ inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
+ assert.FatalError(t, err)
+ root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
+ assert.FatalError(t, err)
+
+ var certs []byte
+ for _, cert := range []*x509.Certificate{leaf, inter, root} {
+ certs = append(certs, pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: cert.Raw,
+ })...)
+ }
+
+ type test struct {
+ b []byte
+ err error
+ }
+ var tests = map[string]test{
+ "fail/bad-type-error": {
+ b: pem.EncodeToMemory(&pem.Block{
+ Type: "Public Key",
+ Bytes: leaf.Raw,
+ }),
+ err: errors.Errorf("error decoding PEM: data contains block that is not a certificate"),
+ },
+ "fail/bad-pem-error": {
+ b: pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: []byte("foo"),
+ }),
+ err: errors.Errorf("error parsing x509 certificate"),
+ },
+ "fail/unexpected-data": {
+ b: append(pem.EncodeToMemory(&pem.Block{
+ Type: "CERTIFICATE",
+ Bytes: leaf.Raw,
+ }), []byte("foo")...),
+ err: errors.Errorf("error decoding PEM: unexpected data"),
+ },
+ "ok": {
+ b: certs,
+ },
+ }
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ ret, err := parseBundle(tc.b)
+ if err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, ret, []*x509.Certificate{leaf, inter, root})
+ }
+ }
+ })
+ }
+}
diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go
new file mode 100644
index 00000000..f3a3cfca
--- /dev/null
+++ b/acme/db/nosql/challenge.go
@@ -0,0 +1,103 @@
+package nosql
+
+import (
+ "context"
+ "encoding/json"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/nosql"
+)
+
+type dbChallenge struct {
+ ID string `json:"id"`
+ AccountID string `json:"accountID"`
+ Type string `json:"type"`
+ Status acme.Status `json:"status"`
+ Token string `json:"token"`
+ Value string `json:"value"`
+ ValidatedAt string `json:"validatedAt"`
+ CreatedAt time.Time `json:"createdAt"`
+ Error *acme.Error `json:"error"`
+}
+
+func (dbc *dbChallenge) clone() *dbChallenge {
+ u := *dbc
+ return &u
+}
+
+func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) {
+ data, err := db.db.Get(challengeTable, []byte(id))
+ if nosql.IsErrNotFound(err) {
+ return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id)
+ } else if err != nil {
+ return nil, errors.Wrapf(err, "error loading acme challenge %s", id)
+ }
+
+ dbch := new(dbChallenge)
+ if err := json.Unmarshal(data, dbch); err != nil {
+ return nil, errors.Wrap(err, "error unmarshaling dbChallenge")
+ }
+ return dbch, nil
+}
+
+// CreateChallenge creates a new ACME challenge data structure in the database.
+// Implements acme.DB.CreateChallenge interface.
+func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error {
+ var err error
+ ch.ID, err = randID()
+ if err != nil {
+ return errors.Wrap(err, "error generating random id for ACME challenge")
+ }
+
+ dbch := &dbChallenge{
+ ID: ch.ID,
+ AccountID: ch.AccountID,
+ Value: ch.Value,
+ Status: acme.StatusPending,
+ Token: ch.Token,
+ CreatedAt: clock.Now(),
+ Type: ch.Type,
+ }
+
+ return db.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable)
+}
+
+// GetChallenge retrieves and unmarshals an ACME challenge type from the database.
+// Implements the acme.DB GetChallenge interface.
+func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) {
+ dbch, err := db.getDBChallenge(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+
+ ch := &acme.Challenge{
+ ID: dbch.ID,
+ AccountID: dbch.AccountID,
+ Type: dbch.Type,
+ Value: dbch.Value,
+ Status: dbch.Status,
+ Token: dbch.Token,
+ Error: dbch.Error,
+ ValidatedAt: dbch.ValidatedAt,
+ }
+ return ch, nil
+}
+
+// UpdateChallenge updates an ACME challenge type in the database.
+func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error {
+ old, err := db.getDBChallenge(ctx, ch.ID)
+ if err != nil {
+ return err
+ }
+
+ nu := old.clone()
+
+ // These should be the only values changing in an Update request.
+ nu.Status = ch.Status
+ nu.Error = ch.Error
+ nu.ValidatedAt = ch.ValidatedAt
+
+ return db.save(ctx, old.ID, nu, old, "challenge", challengeTable)
+}
diff --git a/acme/db/nosql/challenge_test.go b/acme/db/nosql/challenge_test.go
new file mode 100644
index 00000000..b39395e8
--- /dev/null
+++ b/acme/db/nosql/challenge_test.go
@@ -0,0 +1,464 @@
+package nosql
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/assert"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/certificates/db"
+ "github.com/smallstep/nosql"
+ nosqldb "github.com/smallstep/nosql/database"
+)
+
+func TestDB_getDBChallenge(t *testing.T) {
+ chID := "chID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ dbc *dbChallenge
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/not-found": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), chID)
+
+ return nil, nosqldb.ErrNotFound
+ },
+ },
+ acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
+ }
+ },
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), chID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading acme challenge chID: force"),
+ }
+ },
+ "fail/unmarshal-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), chID)
+
+ return []byte("foo"), nil
+ },
+ },
+ err: errors.New("error unmarshaling dbChallenge"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ dbc := &dbChallenge{
+ ID: chID,
+ AccountID: "accountID",
+ Type: "dns-01",
+ Status: acme.StatusPending,
+ Token: "token",
+ Value: "test.ca.smallstep.com",
+ CreatedAt: clock.Now(),
+ ValidatedAt: "foobar",
+ Error: acme.NewErrorISE("force"),
+ }
+ b, err := json.Marshal(dbc)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), chID)
+
+ return b, nil
+ },
+ },
+ dbc: dbc,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if ch, err := db.getDBChallenge(context.Background(), chID); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, ch.ID, tc.dbc.ID)
+ assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
+ assert.Equals(t, ch.Type, tc.dbc.Type)
+ assert.Equals(t, ch.Status, tc.dbc.Status)
+ assert.Equals(t, ch.Token, tc.dbc.Token)
+ assert.Equals(t, ch.Value, tc.dbc.Value)
+ assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
+ assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
+ }
+ }
+ })
+ }
+}
+
+func TestDB_CreateChallenge(t *testing.T) {
+ type test struct {
+ db nosql.DB
+ ch *acme.Challenge
+ err error
+ _id *string
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/cmpAndSwap-error": func(t *testing.T) test {
+ ch := &acme.Challenge{
+ AccountID: "accountID",
+ Type: "dns-01",
+ Status: acme.StatusPending,
+ Token: "token",
+ Value: "test.ca.smallstep.com",
+ }
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), ch.ID)
+ assert.Equals(t, old, nil)
+
+ dbc := new(dbChallenge)
+ assert.FatalError(t, json.Unmarshal(nu, dbc))
+ assert.Equals(t, dbc.ID, string(key))
+ assert.Equals(t, dbc.AccountID, ch.AccountID)
+ assert.Equals(t, dbc.Type, ch.Type)
+ assert.Equals(t, dbc.Status, ch.Status)
+ assert.Equals(t, dbc.Token, ch.Token)
+ assert.Equals(t, dbc.Value, ch.Value)
+ assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
+ assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
+ return nil, false, errors.New("force")
+ },
+ },
+ ch: ch,
+ err: errors.New("error saving acme challenge: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ var (
+ id string
+ idPtr = &id
+ ch = &acme.Challenge{
+ AccountID: "accountID",
+ Type: "dns-01",
+ Status: acme.StatusPending,
+ Token: "token",
+ Value: "test.ca.smallstep.com",
+ }
+ )
+
+ return test{
+ ch: ch,
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ *idPtr = string(key)
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), ch.ID)
+ assert.Equals(t, old, nil)
+
+ dbc := new(dbChallenge)
+ assert.FatalError(t, json.Unmarshal(nu, dbc))
+ assert.Equals(t, dbc.ID, string(key))
+ assert.Equals(t, dbc.AccountID, ch.AccountID)
+ assert.Equals(t, dbc.Type, ch.Type)
+ assert.Equals(t, dbc.Status, ch.Status)
+ assert.Equals(t, dbc.Token, ch.Token)
+ assert.Equals(t, dbc.Value, ch.Value)
+ assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
+ assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
+ return nil, true, nil
+ },
+ },
+ _id: idPtr,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if err := db.CreateChallenge(context.Background(), tc.ch); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, tc.ch.ID, *tc._id)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_GetChallenge(t *testing.T) {
+ chID := "chID"
+ azID := "azID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ dbc *dbChallenge
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), chID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading acme challenge chID: force"),
+ }
+ },
+ "fail/forward-acme-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), chID)
+
+ return nil, nosqldb.ErrNotFound
+ },
+ },
+ acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ dbc := &dbChallenge{
+ ID: chID,
+ AccountID: "accountID",
+ Type: "dns-01",
+ Status: acme.StatusPending,
+ Token: "token",
+ Value: "test.ca.smallstep.com",
+ CreatedAt: clock.Now(),
+ ValidatedAt: "foobar",
+ Error: acme.NewErrorISE("force"),
+ }
+ b, err := json.Marshal(dbc)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), chID)
+
+ return b, nil
+ },
+ },
+ dbc: dbc,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, ch.ID, tc.dbc.ID)
+ assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
+ assert.Equals(t, ch.Type, tc.dbc.Type)
+ assert.Equals(t, ch.Status, tc.dbc.Status)
+ assert.Equals(t, ch.Token, tc.dbc.Token)
+ assert.Equals(t, ch.Value, tc.dbc.Value)
+ assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
+ assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
+ }
+ }
+ })
+ }
+}
+
+func TestDB_UpdateChallenge(t *testing.T) {
+ chID := "chID"
+ dbc := &dbChallenge{
+ ID: chID,
+ AccountID: "accountID",
+ Type: "dns-01",
+ Status: acme.StatusPending,
+ Token: "token",
+ Value: "test.ca.smallstep.com",
+ CreatedAt: clock.Now(),
+ }
+ b, err := json.Marshal(dbc)
+ assert.FatalError(t, err)
+ type test struct {
+ db nosql.DB
+ ch *acme.Challenge
+ err error
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ ch: &acme.Challenge{
+ ID: chID,
+ },
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), chID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading acme challenge chID: force"),
+ }
+ },
+ "fail/db.CmpAndSwap-error": func(t *testing.T) test {
+ updCh := &acme.Challenge{
+ ID: chID,
+ Status: acme.StatusValid,
+ ValidatedAt: "foobar",
+ Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
+ }
+ return test{
+ ch: updCh,
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), chID)
+
+ return b, nil
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, old, b)
+
+ dbOld := new(dbChallenge)
+ assert.FatalError(t, json.Unmarshal(old, dbOld))
+ assert.Equals(t, dbc, dbOld)
+
+ dbNew := new(dbChallenge)
+ assert.FatalError(t, json.Unmarshal(nu, dbNew))
+ assert.Equals(t, dbNew.ID, dbc.ID)
+ assert.Equals(t, dbNew.AccountID, dbc.AccountID)
+ assert.Equals(t, dbNew.Type, dbc.Type)
+ assert.Equals(t, dbNew.Status, updCh.Status)
+ assert.Equals(t, dbNew.Token, dbc.Token)
+ assert.Equals(t, dbNew.Value, dbc.Value)
+ assert.Equals(t, dbNew.Error.Error(), updCh.Error.Error())
+ assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
+ assert.Equals(t, dbNew.ValidatedAt, updCh.ValidatedAt)
+ return nil, false, errors.New("force")
+ },
+ },
+ err: errors.New("error saving acme challenge: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ updCh := &acme.Challenge{
+ ID: dbc.ID,
+ AccountID: dbc.AccountID,
+ Type: dbc.Type,
+ Token: dbc.Token,
+ Value: dbc.Value,
+ Status: acme.StatusValid,
+ ValidatedAt: "foobar",
+ Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
+ }
+ return test{
+ ch: updCh,
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), chID)
+
+ return b, nil
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, old, b)
+
+ dbOld := new(dbChallenge)
+ assert.FatalError(t, json.Unmarshal(old, dbOld))
+ assert.Equals(t, dbc, dbOld)
+
+ dbNew := new(dbChallenge)
+ assert.FatalError(t, json.Unmarshal(nu, dbNew))
+ assert.Equals(t, dbNew.ID, dbc.ID)
+ assert.Equals(t, dbNew.AccountID, dbc.AccountID)
+ assert.Equals(t, dbNew.Type, dbc.Type)
+ assert.Equals(t, dbNew.Token, dbc.Token)
+ assert.Equals(t, dbNew.Value, dbc.Value)
+ assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
+ assert.Equals(t, dbNew.Status, acme.StatusValid)
+ assert.Equals(t, dbNew.ValidatedAt, "foobar")
+ assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
+ return nu, true, nil
+ },
+ },
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, tc.ch.ID, dbc.ID)
+ assert.Equals(t, tc.ch.AccountID, dbc.AccountID)
+ assert.Equals(t, tc.ch.Type, dbc.Type)
+ assert.Equals(t, tc.ch.Token, dbc.Token)
+ assert.Equals(t, tc.ch.Value, dbc.Value)
+ assert.Equals(t, tc.ch.ValidatedAt, "foobar")
+ assert.Equals(t, tc.ch.Status, acme.StatusValid)
+ assert.Equals(t, tc.ch.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
+ }
+ }
+ })
+ }
+}
diff --git a/acme/db/nosql/nonce.go b/acme/db/nosql/nonce.go
new file mode 100644
index 00000000..9badae87
--- /dev/null
+++ b/acme/db/nosql/nonce.go
@@ -0,0 +1,66 @@
+package nosql
+
+import (
+ "context"
+ "encoding/base64"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/nosql"
+ "github.com/smallstep/nosql/database"
+)
+
+// dbNonce contains nonce metadata used in the ACME protocol.
+type dbNonce struct {
+ ID string
+ CreatedAt time.Time
+ DeletedAt time.Time
+}
+
+// CreateNonce creates, stores, and returns an ACME replay-nonce.
+// Implements the acme.DB interface.
+func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
+ _id, err := randID()
+ if err != nil {
+ return "", err
+ }
+
+ id := base64.RawURLEncoding.EncodeToString([]byte(_id))
+ n := &dbNonce{
+ ID: id,
+ CreatedAt: clock.Now(),
+ }
+ if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
+ return "", err
+ }
+ return acme.Nonce(id), nil
+}
+
+// DeleteNonce verifies that the nonce is valid (by checking if it exists),
+// and if so, consumes the nonce resource by deleting it from the database.
+func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error {
+ err := db.db.Update(&database.Tx{
+ Operations: []*database.TxEntry{
+ {
+ Bucket: nonceTable,
+ Key: []byte(nonce),
+ Cmd: database.Get,
+ },
+ {
+ Bucket: nonceTable,
+ Key: []byte(nonce),
+ Cmd: database.Delete,
+ },
+ },
+ })
+
+ switch {
+ case nosql.IsErrNotFound(err):
+ return acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", string(nonce))
+ case err != nil:
+ return errors.Wrapf(err, "error deleting nonce %s", string(nonce))
+ default:
+ return nil
+ }
+}
diff --git a/acme/db/nosql/nonce_test.go b/acme/db/nosql/nonce_test.go
new file mode 100644
index 00000000..05d73d52
--- /dev/null
+++ b/acme/db/nosql/nonce_test.go
@@ -0,0 +1,168 @@
+package nosql
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/assert"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/certificates/db"
+ "github.com/smallstep/nosql"
+ "github.com/smallstep/nosql/database"
+)
+
+func TestDB_CreateNonce(t *testing.T) {
+ type test struct {
+ db nosql.DB
+ err error
+ _id *string
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/cmpAndSwap-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, nonceTable)
+ assert.Equals(t, old, nil)
+
+ dbn := new(dbNonce)
+ assert.FatalError(t, json.Unmarshal(nu, dbn))
+ assert.Equals(t, dbn.ID, string(key))
+ assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
+ assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
+ return nil, false, errors.New("force")
+ },
+ },
+ err: errors.New("error saving acme nonce: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ var (
+ id string
+ idPtr = &id
+ )
+
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ *idPtr = string(key)
+ assert.Equals(t, bucket, nonceTable)
+ assert.Equals(t, old, nil)
+
+ dbn := new(dbNonce)
+ assert.FatalError(t, json.Unmarshal(nu, dbn))
+ assert.Equals(t, dbn.ID, string(key))
+ assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
+ assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
+ return nil, true, nil
+ },
+ },
+ _id: idPtr,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if n, err := db.CreateNonce(context.Background()); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, string(n), *tc._id)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_DeleteNonce(t *testing.T) {
+
+ nonceID := "nonceID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/not-found": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MUpdate: func(tx *database.Tx) error {
+ assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
+ assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
+ assert.Equals(t, tx.Operations[0].Cmd, database.Get)
+
+ assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
+ assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
+ assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
+ return database.ErrNotFound
+ },
+ },
+ acmeErr: acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", nonceID),
+ }
+ },
+ "fail/db.Update-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MUpdate: func(tx *database.Tx) error {
+ assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
+ assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
+ assert.Equals(t, tx.Operations[0].Cmd, database.Get)
+
+ assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
+ assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
+ assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
+ return errors.New("force")
+ },
+ },
+ err: errors.New("error deleting nonce nonceID: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MUpdate: func(tx *database.Tx) error {
+ assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
+ assert.Equals(t, tx.Operations[0].Key, []byte(nonceID))
+ assert.Equals(t, tx.Operations[0].Cmd, database.Get)
+
+ assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
+ assert.Equals(t, tx.Operations[1].Key, []byte(nonceID))
+ assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
+ return nil
+ },
+ },
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ assert.Nil(t, tc.err)
+ }
+ })
+ }
+}
diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go
new file mode 100644
index 00000000..052f5729
--- /dev/null
+++ b/acme/db/nosql/nosql.go
@@ -0,0 +1,96 @@
+package nosql
+
+import (
+ "context"
+ "encoding/json"
+ "time"
+
+ "github.com/pkg/errors"
+ nosqlDB "github.com/smallstep/nosql"
+ "go.step.sm/crypto/randutil"
+)
+
+var (
+ accountTable = []byte("acme_accounts")
+ accountByKeyIDTable = []byte("acme_keyID_accountID_index")
+ authzTable = []byte("acme_authzs")
+ challengeTable = []byte("acme_challenges")
+ nonceTable = []byte("nonces")
+ orderTable = []byte("acme_orders")
+ ordersByAccountIDTable = []byte("acme_account_orders_index")
+ certTable = []byte("acme_certs")
+)
+
+// DB is a struct that implements the AcmeDB interface.
+type DB struct {
+ db nosqlDB.DB
+}
+
+// New configures and returns a new ACME DB backend implemented using a nosql DB.
+func New(db nosqlDB.DB) (*DB, error) {
+ tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
+ challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable}
+ for _, b := range tables {
+ if err := db.CreateTable(b); err != nil {
+ return nil, errors.Wrapf(err, "error creating table %s",
+ string(b))
+ }
+ }
+ return &DB{db}, nil
+}
+
+// save writes the new data to the database, overwriting the old data if it
+// existed.
+func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error {
+ var (
+ err error
+ newB []byte
+ )
+ if nu == nil {
+ newB = nil
+ } else {
+ newB, err = json.Marshal(nu)
+ if err != nil {
+ return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu)
+ }
+ }
+ var oldB []byte
+ if old == nil {
+ oldB = nil
+ } else {
+ oldB, err = json.Marshal(old)
+ if err != nil {
+ return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old)
+ }
+ }
+
+ _, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB)
+ switch {
+ case err != nil:
+ return errors.Wrapf(err, "error saving acme %s", typ)
+ case !swapped:
+ return errors.Errorf("error saving acme %s; changed since last read", typ)
+ default:
+ return nil
+ }
+}
+
+var idLen = 32
+
+func randID() (val string, err error) {
+ val, err = randutil.Alphanumeric(idLen)
+ if err != nil {
+ return "", errors.Wrap(err, "error generating random alphanumeric ID")
+ }
+ return val, nil
+}
+
+// Clock that returns time in UTC rounded to seconds.
+type Clock struct{}
+
+// Now returns the UTC time rounded to seconds.
+func (c *Clock) Now() time.Time {
+ return time.Now().UTC().Truncate(time.Second)
+}
+
+var clock = new(Clock)
diff --git a/acme/db/nosql/nosql_test.go b/acme/db/nosql/nosql_test.go
new file mode 100644
index 00000000..4396acc8
--- /dev/null
+++ b/acme/db/nosql/nosql_test.go
@@ -0,0 +1,139 @@
+package nosql
+
+import (
+ "context"
+ "testing"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/assert"
+ "github.com/smallstep/certificates/db"
+ "github.com/smallstep/nosql"
+)
+
+func TestNew(t *testing.T) {
+ type test struct {
+ db nosql.DB
+ err error
+ }
+ var tests = map[string]test{
+ "fail/db.CreateTable-error": {
+ db: &db.MockNoSQLDB{
+ MCreateTable: func(bucket []byte) error {
+ assert.Equals(t, string(bucket), string(accountTable))
+ return errors.New("force")
+ },
+ },
+ err: errors.Errorf("error creating table %s: force", string(accountTable)),
+ },
+ "ok": {
+ db: &db.MockNoSQLDB{
+ MCreateTable: func(bucket []byte) error {
+ return nil
+ },
+ },
+ },
+ }
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ if _, err := New(tc.db); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ assert.Nil(t, tc.err)
+ }
+ })
+ }
+}
+
+type errorThrower string
+
+func (et errorThrower) MarshalJSON() ([]byte, error) {
+ return nil, errors.New("force")
+}
+
+func TestDB_save(t *testing.T) {
+ type test struct {
+ db nosql.DB
+ nu interface{}
+ old interface{}
+ err error
+ }
+ var tests = map[string]test{
+ "fail/error-marshaling-new": {
+ nu: errorThrower("foo"),
+ err: errors.New("error marshaling acme type: challenge"),
+ },
+ "fail/error-marshaling-old": {
+ nu: "new",
+ old: errorThrower("foo"),
+ err: errors.New("error marshaling acme type: challenge"),
+ },
+ "fail/db.CmpAndSwap-error": {
+ nu: "new",
+ old: "old",
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), "id")
+ assert.Equals(t, string(old), "\"old\"")
+ assert.Equals(t, string(nu), "\"new\"")
+ return nil, false, errors.New("force")
+ },
+ },
+ err: errors.New("error saving acme challenge: force"),
+ },
+ "fail/db.CmpAndSwap-false-marshaling-old": {
+ nu: "new",
+ old: "old",
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), "id")
+ assert.Equals(t, string(old), "\"old\"")
+ assert.Equals(t, string(nu), "\"new\"")
+ return nil, false, nil
+ },
+ },
+ err: errors.New("error saving acme challenge; changed since last read"),
+ },
+ "ok": {
+ nu: "new",
+ old: "old",
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), "id")
+ assert.Equals(t, string(old), "\"old\"")
+ assert.Equals(t, string(nu), "\"new\"")
+ return nu, true, nil
+ },
+ },
+ },
+ "ok/nils": {
+ nu: nil,
+ old: nil,
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, challengeTable)
+ assert.Equals(t, string(key), "id")
+ assert.Equals(t, old, nil)
+ assert.Equals(t, nu, nil)
+ return nu, true, nil
+ },
+ },
+ },
+ }
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ db := &DB{db: tc.db}
+ if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ assert.Nil(t, tc.err)
+ }
+ })
+ }
+}
diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go
new file mode 100644
index 00000000..ba3934af
--- /dev/null
+++ b/acme/db/nosql/order.go
@@ -0,0 +1,189 @@
+package nosql
+
+import (
+ "context"
+ "encoding/json"
+ "sync"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/nosql"
+)
+
+// Mutex for locking ordersByAccount index operations.
+var ordersByAccountMux sync.Mutex
+
+type dbOrder struct {
+ ID string `json:"id"`
+ AccountID string `json:"accountID"`
+ ProvisionerID string `json:"provisionerID"`
+ Identifiers []acme.Identifier `json:"identifiers"`
+ AuthorizationIDs []string `json:"authorizationIDs"`
+ Status acme.Status `json:"status"`
+ NotBefore time.Time `json:"notBefore,omitempty"`
+ NotAfter time.Time `json:"notAfter,omitempty"`
+ CreatedAt time.Time `json:"createdAt"`
+ ExpiresAt time.Time `json:"expiresAt,omitempty"`
+ CertificateID string `json:"certificate,omitempty"`
+ Error *acme.Error `json:"error,omitempty"`
+}
+
+func (a *dbOrder) clone() *dbOrder {
+ b := *a
+ return &b
+}
+
+// getDBOrder retrieves and unmarshals an ACME Order type from the database.
+func (db *DB) getDBOrder(ctx context.Context, id string) (*dbOrder, error) {
+ b, err := db.db.Get(orderTable, []byte(id))
+ if nosql.IsErrNotFound(err) {
+ return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id)
+ } else if err != nil {
+ return nil, errors.Wrapf(err, "error loading order %s", id)
+ }
+ o := new(dbOrder)
+ if err := json.Unmarshal(b, &o); err != nil {
+ return nil, errors.Wrapf(err, "error unmarshaling order %s into dbOrder", id)
+ }
+ return o, nil
+}
+
+// GetOrder retrieves an ACME Order from the database.
+func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) {
+ dbo, err := db.getDBOrder(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+
+ o := &acme.Order{
+ ID: dbo.ID,
+ AccountID: dbo.AccountID,
+ ProvisionerID: dbo.ProvisionerID,
+ CertificateID: dbo.CertificateID,
+ Status: dbo.Status,
+ ExpiresAt: dbo.ExpiresAt,
+ Identifiers: dbo.Identifiers,
+ NotBefore: dbo.NotBefore,
+ NotAfter: dbo.NotAfter,
+ AuthorizationIDs: dbo.AuthorizationIDs,
+ Error: dbo.Error,
+ }
+
+ return o, nil
+}
+
+// CreateOrder creates ACME Order resources and saves them to the DB.
+func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error {
+ var err error
+ o.ID, err = randID()
+ if err != nil {
+ return err
+ }
+
+ now := clock.Now()
+ dbo := &dbOrder{
+ ID: o.ID,
+ AccountID: o.AccountID,
+ ProvisionerID: o.ProvisionerID,
+ Status: o.Status,
+ CreatedAt: now,
+ ExpiresAt: o.ExpiresAt,
+ Identifiers: o.Identifiers,
+ NotBefore: o.NotBefore,
+ NotAfter: o.NotAfter,
+ AuthorizationIDs: o.AuthorizationIDs,
+ }
+ if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil {
+ return err
+ }
+
+ _, err = db.updateAddOrderIDs(ctx, o.AccountID, o.ID)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// UpdateOrder saves an updated ACME Order to the database.
+func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error {
+ old, err := db.getDBOrder(ctx, o.ID)
+ if err != nil {
+ return err
+ }
+
+ nu := old.clone()
+
+ nu.Status = o.Status
+ nu.Error = o.Error
+ nu.CertificateID = o.CertificateID
+ return db.save(ctx, old.ID, nu, old, "order", orderTable)
+}
+
+func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) {
+ ordersByAccountMux.Lock()
+ defer ordersByAccountMux.Unlock()
+
+ b, err := db.db.Get(ordersByAccountIDTable, []byte(accID))
+ var (
+ oldOids []string
+ )
+ if err != nil {
+ if !nosql.IsErrNotFound(err) {
+ return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID)
+ }
+ } else {
+ if err := json.Unmarshal(b, &oldOids); err != nil {
+ return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)
+ }
+ }
+
+ // Remove any order that is not in PENDING state and update the stored list
+ // before returning.
+ //
+ // According to RFC 8555:
+ // The server SHOULD include pending orders and SHOULD NOT include orders
+ // that are invalid in the array of URLs.
+ pendOids := []string{}
+ for _, oid := range oldOids {
+ o, err := db.GetOrder(ctx, oid)
+ if err != nil {
+ return nil, acme.WrapErrorISE(err, "error loading order %s for account %s", oid, accID)
+ }
+ if err = o.UpdateStatus(ctx, db); err != nil {
+ return nil, acme.WrapErrorISE(err, "error updating order %s for account %s", oid, accID)
+ }
+ if o.Status == acme.StatusPending {
+ pendOids = append(pendOids, oid)
+ }
+ }
+ pendOids = append(pendOids, addOids...)
+ var (
+ _old interface{} = oldOids
+ _new interface{} = pendOids
+ )
+ switch {
+ case len(oldOids) == 0 && len(pendOids) == 0:
+ // If list has not changed from empty, then no need to write the DB.
+ return []string{}, nil
+ case len(oldOids) == 0:
+ _old = nil
+ case len(pendOids) == 0:
+ _new = nil
+ }
+ if err = db.save(ctx, accID, _new, _old, "orderIDsByAccountID", ordersByAccountIDTable); err != nil {
+ // Delete all orders that may have been previously stored if orderIDsByAccountID update fails.
+ for _, oid := range addOids {
+ // Ignore error from delete -- we tried our best.
+ // TODO when we have logging w/ request ID tracking, logging this error.
+ db.db.Del(orderTable, []byte(oid))
+ }
+ return nil, errors.Wrapf(err, "error saving orderIDs index for account %s", accID)
+ }
+ return pendOids, nil
+}
+
+// GetOrdersByAccountID returns a list of order IDs owned by the account.
+func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
+ return db.updateAddOrderIDs(ctx, accID)
+}
diff --git a/acme/db/nosql/order_test.go b/acme/db/nosql/order_test.go
new file mode 100644
index 00000000..7248700f
--- /dev/null
+++ b/acme/db/nosql/order_test.go
@@ -0,0 +1,1003 @@
+package nosql
+
+import (
+ "context"
+ "encoding/json"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/smallstep/assert"
+ "github.com/smallstep/certificates/acme"
+ "github.com/smallstep/certificates/db"
+ "github.com/smallstep/nosql"
+ nosqldb "github.com/smallstep/nosql/database"
+)
+
+func TestDB_getDBOrder(t *testing.T) {
+ orderID := "orderID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ dbo *dbOrder
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/not-found": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, string(key), orderID)
+
+ return nil, nosqldb.ErrNotFound
+ },
+ },
+ acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
+ }
+ },
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, string(key), orderID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading order orderID: force"),
+ }
+ },
+ "fail/unmarshal-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, string(key), orderID)
+
+ return []byte("foo"), nil
+ },
+ },
+ err: errors.New("error unmarshaling order orderID into dbOrder"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ now := clock.Now()
+ dbo := &dbOrder{
+ ID: orderID,
+ AccountID: "accID",
+ ProvisionerID: "provID",
+ CertificateID: "certID",
+ Status: acme.StatusValid,
+ ExpiresAt: now,
+ CreatedAt: now,
+ NotBefore: now,
+ NotAfter: now,
+ Identifiers: []acme.Identifier{
+ {Type: "dns", Value: "test.ca.smallstep.com"},
+ {Type: "dns", Value: "example.foo.com"},
+ },
+ AuthorizationIDs: []string{"foo", "bar"},
+ Error: acme.NewError(acme.ErrorMalformedType, "force"),
+ }
+ b, err := json.Marshal(dbo)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, string(key), orderID)
+
+ return b, nil
+ },
+ },
+ dbo: dbo,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if dbo, err := db.getDBOrder(context.Background(), orderID); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, dbo.ID, tc.dbo.ID)
+ assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID)
+ assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID)
+ assert.Equals(t, dbo.Status, tc.dbo.Status)
+ assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt)
+ assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt)
+ assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore)
+ assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter)
+ assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers)
+ assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs)
+ assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
+ }
+ }
+ })
+ }
+}
+
+func TestDB_GetOrder(t *testing.T) {
+ orderID := "orderID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ dbo *dbOrder
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, string(key), orderID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading order orderID: force"),
+ }
+ },
+ "fail/forward-acme-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, string(key), orderID)
+
+ return nil, nosqldb.ErrNotFound
+ },
+ },
+ acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ now := clock.Now()
+ dbo := &dbOrder{
+ ID: orderID,
+ AccountID: "accID",
+ ProvisionerID: "provID",
+ CertificateID: "certID",
+ Status: acme.StatusValid,
+ ExpiresAt: now,
+ CreatedAt: now,
+ NotBefore: now,
+ NotAfter: now,
+ Identifiers: []acme.Identifier{
+ {Type: "dns", Value: "test.ca.smallstep.com"},
+ {Type: "dns", Value: "example.foo.com"},
+ },
+ AuthorizationIDs: []string{"foo", "bar"},
+ Error: acme.NewError(acme.ErrorMalformedType, "force"),
+ }
+ b, err := json.Marshal(dbo)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, string(key), orderID)
+ return b, nil
+ },
+ },
+ dbo: dbo,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if o, err := db.GetOrder(context.Background(), orderID); err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, o.ID, tc.dbo.ID)
+ assert.Equals(t, o.AccountID, tc.dbo.AccountID)
+ assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID)
+ assert.Equals(t, o.CertificateID, tc.dbo.CertificateID)
+ assert.Equals(t, o.Status, tc.dbo.Status)
+ assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt)
+ assert.Equals(t, o.NotBefore, tc.dbo.NotBefore)
+ assert.Equals(t, o.NotAfter, tc.dbo.NotAfter)
+ assert.Equals(t, o.Identifiers, tc.dbo.Identifiers)
+ assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs)
+ assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
+ }
+ }
+ })
+ }
+}
+
+func TestDB_UpdateOrder(t *testing.T) {
+ orderID := "orderID"
+ now := clock.Now()
+ dbo := &dbOrder{
+ ID: orderID,
+ AccountID: "accID",
+ ProvisionerID: "provID",
+ Status: acme.StatusPending,
+ ExpiresAt: now,
+ CreatedAt: now,
+ NotBefore: now,
+ NotAfter: now,
+ Identifiers: []acme.Identifier{
+ {Type: "dns", Value: "test.ca.smallstep.com"},
+ {Type: "dns", Value: "example.foo.com"},
+ },
+ AuthorizationIDs: []string{"foo", "bar"},
+ }
+ b, err := json.Marshal(dbo)
+ assert.FatalError(t, err)
+ type test struct {
+ db nosql.DB
+ o *acme.Order
+ err error
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ o: &acme.Order{
+ ID: orderID,
+ },
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, string(key), orderID)
+
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.New("error loading order orderID: force"),
+ }
+ },
+ "fail/save-error": func(t *testing.T) test {
+ o := &acme.Order{
+ ID: orderID,
+ Status: acme.StatusValid,
+ CertificateID: "certID",
+ Error: acme.NewError(acme.ErrorMalformedType, "force"),
+ }
+ return test{
+ o: o,
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, string(key), orderID)
+
+ return b, nil
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, old, b)
+
+ dbNew := new(dbOrder)
+ assert.FatalError(t, json.Unmarshal(nu, dbNew))
+ assert.Equals(t, dbNew.ID, dbo.ID)
+ assert.Equals(t, dbNew.AccountID, dbo.AccountID)
+ assert.Equals(t, dbNew.ProvisionerID, dbo.ProvisionerID)
+ assert.Equals(t, dbNew.CertificateID, o.CertificateID)
+ assert.Equals(t, dbNew.Status, o.Status)
+ assert.Equals(t, dbNew.CreatedAt, dbo.CreatedAt)
+ assert.Equals(t, dbNew.ExpiresAt, dbo.ExpiresAt)
+ assert.Equals(t, dbNew.NotBefore, dbo.NotBefore)
+ assert.Equals(t, dbNew.NotAfter, dbo.NotAfter)
+ assert.Equals(t, dbNew.AuthorizationIDs, dbo.AuthorizationIDs)
+ assert.Equals(t, dbNew.Identifiers, dbo.Identifiers)
+ assert.Equals(t, dbNew.Error.Error(), o.Error.Error())
+ return nil, false, errors.New("force")
+ },
+ },
+ err: errors.New("error saving acme order: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ o := &acme.Order{
+ ID: orderID,
+ Status: acme.StatusValid,
+ CertificateID: "certID",
+ Error: acme.NewError(acme.ErrorMalformedType, "force"),
+ }
+ return test{
+ o: o,
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, string(key), orderID)
+
+ return b, nil
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, old, b)
+
+ dbNew := new(dbOrder)
+ assert.FatalError(t, json.Unmarshal(nu, dbNew))
+ assert.Equals(t, dbNew.ID, dbo.ID)
+ assert.Equals(t, dbNew.AccountID, dbo.AccountID)
+ assert.Equals(t, dbNew.ProvisionerID, dbo.ProvisionerID)
+ assert.Equals(t, dbNew.CertificateID, o.CertificateID)
+ assert.Equals(t, dbNew.Status, o.Status)
+ assert.Equals(t, dbNew.CreatedAt, dbo.CreatedAt)
+ assert.Equals(t, dbNew.ExpiresAt, dbo.ExpiresAt)
+ assert.Equals(t, dbNew.NotBefore, dbo.NotBefore)
+ assert.Equals(t, dbNew.NotAfter, dbo.NotAfter)
+ assert.Equals(t, dbNew.AuthorizationIDs, dbo.AuthorizationIDs)
+ assert.Equals(t, dbNew.Identifiers, dbo.Identifiers)
+ assert.Equals(t, dbNew.Error.Error(), o.Error.Error())
+ return nu, true, nil
+ },
+ },
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if err := db.UpdateOrder(context.Background(), tc.o); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, tc.o.ID, dbo.ID)
+ assert.Equals(t, tc.o.CertificateID, "certID")
+ assert.Equals(t, tc.o.Status, acme.StatusValid)
+ assert.Equals(t, tc.o.Error.Error(), acme.NewError(acme.ErrorMalformedType, "force").Error())
+ }
+ }
+ })
+ }
+}
+
+func TestDB_CreateOrder(t *testing.T) {
+ now := clock.Now()
+ nbf := now.Add(5 * time.Minute)
+ naf := now.Add(15 * time.Minute)
+ type test struct {
+ db nosql.DB
+ o *acme.Order
+ err error
+ _id *string
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/order-save-error": func(t *testing.T) test {
+ o := &acme.Order{
+ AccountID: "accID",
+ ProvisionerID: "provID",
+ CertificateID: "certID",
+ Status: acme.StatusValid,
+ ExpiresAt: now,
+ NotBefore: nbf,
+ NotAfter: naf,
+ Identifiers: []acme.Identifier{
+ {Type: "dns", Value: "test.ca.smallstep.com"},
+ {Type: "dns", Value: "example.foo.com"},
+ },
+ AuthorizationIDs: []string{"foo", "bar"},
+ }
+ return test{
+ db: &db.MockNoSQLDB{
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, string(bucket), string(orderTable))
+ assert.Equals(t, string(key), o.ID)
+ assert.Equals(t, old, nil)
+
+ dbo := new(dbOrder)
+ assert.FatalError(t, json.Unmarshal(nu, dbo))
+ assert.Equals(t, dbo.ID, o.ID)
+ assert.Equals(t, dbo.AccountID, o.AccountID)
+ assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID)
+ assert.Equals(t, dbo.CertificateID, "")
+ assert.Equals(t, dbo.Status, o.Status)
+ assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now))
+ assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now))
+ assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt)
+ assert.Equals(t, dbo.NotBefore, o.NotBefore)
+ assert.Equals(t, dbo.NotAfter, o.NotAfter)
+ assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs)
+ assert.Equals(t, dbo.Identifiers, o.Identifiers)
+ assert.Equals(t, dbo.Error, nil)
+ return nil, false, errors.New("force")
+ },
+ },
+ o: o,
+ err: errors.New("error saving acme order: force"),
+ }
+ },
+ "fail/orderIDsByOrderUpdate-error": func(t *testing.T) test {
+ o := &acme.Order{
+ AccountID: "accID",
+ ProvisionerID: "provID",
+ CertificateID: "certID",
+ Status: acme.StatusValid,
+ ExpiresAt: now,
+ NotBefore: nbf,
+ NotAfter: naf,
+ Identifiers: []acme.Identifier{
+ {Type: "dns", Value: "test.ca.smallstep.com"},
+ {Type: "dns", Value: "example.foo.com"},
+ },
+ AuthorizationIDs: []string{"foo", "bar"},
+ }
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
+ assert.Equals(t, string(key), o.AccountID)
+ return nil, errors.New("force")
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, string(bucket), string(orderTable))
+ assert.Equals(t, string(key), o.ID)
+ assert.Equals(t, old, nil)
+
+ dbo := new(dbOrder)
+ assert.FatalError(t, json.Unmarshal(nu, dbo))
+ assert.Equals(t, dbo.ID, o.ID)
+ assert.Equals(t, dbo.AccountID, o.AccountID)
+ assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID)
+ assert.Equals(t, dbo.CertificateID, "")
+ assert.Equals(t, dbo.Status, o.Status)
+ assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now))
+ assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now))
+ assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt)
+ assert.Equals(t, dbo.NotBefore, o.NotBefore)
+ assert.Equals(t, dbo.NotAfter, o.NotAfter)
+ assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs)
+ assert.Equals(t, dbo.Identifiers, o.Identifiers)
+ assert.Equals(t, dbo.Error, nil)
+ return nu, true, nil
+ },
+ },
+ o: o,
+ err: errors.New("error loading orderIDs for account accID: force"),
+ }
+ },
+ "ok": func(t *testing.T) test {
+ var (
+ id string
+ idptr = &id
+ )
+
+ o := &acme.Order{
+ AccountID: "accID",
+ ProvisionerID: "provID",
+ Status: acme.StatusValid,
+ ExpiresAt: now,
+ NotBefore: nbf,
+ NotAfter: naf,
+ Identifiers: []acme.Identifier{
+ {Type: "dns", Value: "test.ca.smallstep.com"},
+ {Type: "dns", Value: "example.foo.com"},
+ },
+ AuthorizationIDs: []string{"foo", "bar"},
+ }
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
+ assert.Equals(t, string(key), o.AccountID)
+ return nil, nosqldb.ErrNotFound
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ switch string(bucket) {
+ case string(ordersByAccountIDTable):
+ b, err := json.Marshal([]string{o.ID})
+ assert.FatalError(t, err)
+ assert.Equals(t, string(key), "accID")
+ assert.Equals(t, old, nil)
+ assert.Equals(t, nu, b)
+ return nu, true, nil
+ case string(orderTable):
+ *idptr = string(key)
+ assert.Equals(t, string(key), o.ID)
+ assert.Equals(t, old, nil)
+
+ dbo := new(dbOrder)
+ assert.FatalError(t, json.Unmarshal(nu, dbo))
+ assert.Equals(t, dbo.ID, o.ID)
+ assert.Equals(t, dbo.AccountID, o.AccountID)
+ assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID)
+ assert.Equals(t, dbo.CertificateID, "")
+ assert.Equals(t, dbo.Status, o.Status)
+ assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now))
+ assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now))
+ assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt)
+ assert.Equals(t, dbo.NotBefore, o.NotBefore)
+ assert.Equals(t, dbo.NotAfter, o.NotAfter)
+ assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs)
+ assert.Equals(t, dbo.Identifiers, o.Identifiers)
+ assert.Equals(t, dbo.Error, nil)
+ return nu, true, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
+ return nil, false, errors.New("force")
+ }
+ },
+ },
+ o: o,
+ _id: idptr,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ if err := db.CreateOrder(context.Background(), tc.o); err != nil {
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.Equals(t, tc.o.ID, *tc._id)
+ }
+ }
+ })
+ }
+}
+
+func TestDB_updateAddOrderIDs(t *testing.T) {
+ accID := "accID"
+ type test struct {
+ db nosql.DB
+ err error
+ acmeErr *acme.Error
+ addOids []string
+ res []string
+ }
+ var tests = map[string]func(t *testing.T) test{
+ "fail/db.Get-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, ordersByAccountIDTable)
+ assert.Equals(t, key, []byte(accID))
+ return nil, errors.New("force")
+ },
+ },
+ err: errors.Errorf("error loading orderIDs for account %s", accID),
+ }
+ },
+ "fail/unmarshal-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, ordersByAccountIDTable)
+ assert.Equals(t, key, []byte(accID))
+ return []byte("foo"), nil
+ },
+ },
+ err: errors.Errorf("error unmarshaling orderIDs for account %s", accID),
+ }
+ },
+ "fail/db.Get-order-error": func(t *testing.T) test {
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ switch string(bucket) {
+ case string(ordersByAccountIDTable):
+ assert.Equals(t, key, []byte(accID))
+ b, err := json.Marshal([]string{"foo", "bar"})
+ assert.FatalError(t, err)
+ return b, nil
+ case string(orderTable):
+ assert.Equals(t, key, []byte("foo"))
+ return nil, errors.New("force")
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
+ return nil, errors.New("force")
+ }
+ },
+ },
+ acmeErr: acme.NewErrorISE("error loading order foo for account accID: error loading order foo: force"),
+ }
+ },
+ "fail/update-order-status-error": func(t *testing.T) test {
+ expiry := clock.Now().Add(-5 * time.Minute)
+ ofoo := &dbOrder{
+ ID: "foo",
+ Status: acme.StatusPending,
+ ExpiresAt: expiry,
+ }
+ bfoo, err := json.Marshal(ofoo)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ switch string(bucket) {
+ case string(ordersByAccountIDTable):
+ assert.Equals(t, key, []byte(accID))
+ b, err := json.Marshal([]string{"foo", "bar"})
+ assert.FatalError(t, err)
+ return b, nil
+ case string(orderTable):
+ assert.Equals(t, key, []byte("foo"))
+ return bfoo, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
+ return nil, errors.New("force")
+ }
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, key, []byte("foo"))
+ assert.Equals(t, old, bfoo)
+
+ newdbo := new(dbOrder)
+ assert.FatalError(t, json.Unmarshal(nu, newdbo))
+ assert.Equals(t, newdbo.ID, "foo")
+ assert.Equals(t, newdbo.Status, acme.StatusInvalid)
+ assert.Equals(t, newdbo.ExpiresAt, expiry)
+ assert.Equals(t, newdbo.Error.Error(), acme.NewError(acme.ErrorMalformedType, "order has expired").Error())
+ return nil, false, errors.New("force")
+ },
+ },
+ acmeErr: acme.NewErrorISE("error updating order foo for account accID: error updating order: error saving acme order: force"),
+ }
+ },
+ "fail/db.save-order-error": func(t *testing.T) test {
+ addOids := []string{"foo", "bar"}
+ b, err := json.Marshal(addOids)
+ assert.FatalError(t, err)
+ delCount := 0
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ assert.Equals(t, bucket, ordersByAccountIDTable)
+ assert.Equals(t, key, []byte(accID))
+ return nil, nosqldb.ErrNotFound
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ assert.Equals(t, bucket, ordersByAccountIDTable)
+ assert.Equals(t, key, []byte(accID))
+ assert.Equals(t, old, nil)
+ assert.Equals(t, nu, b)
+ return nil, false, errors.New("force")
+ },
+ MDel: func(bucket, key []byte) error {
+ delCount++
+ switch delCount {
+ case 1:
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, key, []byte("foo"))
+ return nil
+ case 2:
+ assert.Equals(t, bucket, orderTable)
+ assert.Equals(t, key, []byte("bar"))
+ return nil
+ default:
+ assert.FatalError(t, errors.New("delete should only be called twice"))
+ return errors.New("force")
+ }
+ },
+ },
+ addOids: addOids,
+ err: errors.Errorf("error saving orderIDs index for account %s", accID),
+ }
+ },
+ "ok/all-old-not-pending": func(t *testing.T) test {
+ oldOids := []string{"foo", "bar"}
+ bOldOids, err := json.Marshal(oldOids)
+ assert.FatalError(t, err)
+ expiry := clock.Now().Add(-5 * time.Minute)
+ ofoo := &dbOrder{
+ ID: "foo",
+ Status: acme.StatusPending,
+ ExpiresAt: expiry,
+ }
+ bfoo, err := json.Marshal(ofoo)
+ assert.FatalError(t, err)
+ obar := &dbOrder{
+ ID: "bar",
+ Status: acme.StatusPending,
+ ExpiresAt: expiry,
+ }
+ bbar, err := json.Marshal(obar)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ switch string(bucket) {
+ case string(ordersByAccountIDTable):
+ return bOldOids, nil
+ case string(orderTable):
+ switch string(key) {
+ case "foo":
+ assert.Equals(t, key, []byte("foo"))
+ return bfoo, nil
+ case "bar":
+ assert.Equals(t, key, []byte("bar"))
+ return bbar, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected key %s", string(key)))
+ return nil, errors.New("force")
+ }
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
+ return nil, errors.New("force")
+ }
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ switch string(bucket) {
+ case string(orderTable):
+ return nil, true, nil
+ case string(ordersByAccountIDTable):
+ assert.Equals(t, key, []byte(accID))
+ assert.Equals(t, old, bOldOids)
+ assert.Equals(t, nu, nil)
+ return nil, true, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
+ return nil, false, errors.New("force")
+ }
+ },
+ },
+ res: []string{},
+ }
+ },
+ "ok/old-and-new": func(t *testing.T) test {
+ oldOids := []string{"foo", "bar"}
+ bOldOids, err := json.Marshal(oldOids)
+ assert.FatalError(t, err)
+ addOids := []string{"zap", "zar"}
+ bAddOids, err := json.Marshal(addOids)
+ assert.FatalError(t, err)
+ expiry := clock.Now().Add(-5 * time.Minute)
+ ofoo := &dbOrder{
+ ID: "foo",
+ Status: acme.StatusPending,
+ ExpiresAt: expiry,
+ }
+ bfoo, err := json.Marshal(ofoo)
+ assert.FatalError(t, err)
+ obar := &dbOrder{
+ ID: "bar",
+ Status: acme.StatusPending,
+ ExpiresAt: expiry,
+ }
+ bbar, err := json.Marshal(obar)
+ assert.FatalError(t, err)
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ switch string(bucket) {
+ case string(ordersByAccountIDTable):
+ return bOldOids, nil
+ case string(orderTable):
+ switch string(key) {
+ case "foo":
+ assert.Equals(t, key, []byte("foo"))
+ return bfoo, nil
+ case "bar":
+ assert.Equals(t, key, []byte("bar"))
+ return bbar, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected key %s", string(key)))
+ return nil, errors.New("force")
+ }
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
+ return nil, errors.New("force")
+ }
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ switch string(bucket) {
+ case string(orderTable):
+ return nil, true, nil
+ case string(ordersByAccountIDTable):
+ assert.Equals(t, key, []byte(accID))
+ assert.Equals(t, old, bOldOids)
+ assert.Equals(t, nu, bAddOids)
+ return nil, true, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
+ return nil, false, errors.New("force")
+ }
+ },
+ },
+ addOids: addOids,
+ res: addOids,
+ }
+ },
+ "ok/old-and-new-2": func(t *testing.T) test {
+ oldOids := []string{"foo", "bar", "baz"}
+ bOldOids, err := json.Marshal(oldOids)
+ assert.FatalError(t, err)
+ addOids := []string{"zap", "zar"}
+ now := clock.Now()
+ min5 := now.Add(5 * time.Minute)
+ expiry := now.Add(-5 * time.Minute)
+
+ o1 := &dbOrder{
+ ID: "foo",
+ Status: acme.StatusPending,
+ ExpiresAt: min5,
+ AuthorizationIDs: []string{"a"},
+ }
+ bo1, err := json.Marshal(o1)
+ assert.FatalError(t, err)
+ o2 := &dbOrder{
+ ID: "bar",
+ Status: acme.StatusPending,
+ ExpiresAt: expiry,
+ }
+ bo2, err := json.Marshal(o2)
+ assert.FatalError(t, err)
+ o3 := &dbOrder{
+ ID: "baz",
+ Status: acme.StatusPending,
+ ExpiresAt: min5,
+ AuthorizationIDs: []string{"b"},
+ }
+ bo3, err := json.Marshal(o3)
+ assert.FatalError(t, err)
+
+ az1 := &dbAuthz{
+ ID: "a",
+ Status: acme.StatusPending,
+ ExpiresAt: min5,
+ ChallengeIDs: []string{"aa"},
+ }
+ baz1, err := json.Marshal(az1)
+ assert.FatalError(t, err)
+ az2 := &dbAuthz{
+ ID: "b",
+ Status: acme.StatusPending,
+ ExpiresAt: min5,
+ ChallengeIDs: []string{"bb"},
+ }
+ baz2, err := json.Marshal(az2)
+ assert.FatalError(t, err)
+
+ ch1 := &dbChallenge{
+ ID: "aa",
+ Status: acme.StatusPending,
+ }
+ bch1, err := json.Marshal(ch1)
+ assert.FatalError(t, err)
+ ch2 := &dbChallenge{
+ ID: "bb",
+ Status: acme.StatusPending,
+ }
+ bch2, err := json.Marshal(ch2)
+ assert.FatalError(t, err)
+
+ newOids := append([]string{"foo", "baz"}, addOids...)
+ bNewOids, err := json.Marshal(newOids)
+ assert.FatalError(t, err)
+
+ return test{
+ db: &db.MockNoSQLDB{
+ MGet: func(bucket, key []byte) ([]byte, error) {
+ switch string(bucket) {
+ case string(authzTable):
+ switch string(key) {
+ case "a":
+ return baz1, nil
+ case "b":
+ return baz2, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected authz key %s", string(key)))
+ return nil, errors.New("force")
+ }
+ case string(challengeTable):
+ switch string(key) {
+ case "aa":
+ return bch1, nil
+ case "bb":
+ return bch2, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected challenge key %s", string(key)))
+ return nil, errors.New("force")
+ }
+ case string(ordersByAccountIDTable):
+ return bOldOids, nil
+ case string(orderTable):
+ switch string(key) {
+ case "foo":
+ return bo1, nil
+ case "bar":
+ return bo2, nil
+ case "baz":
+ return bo3, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected key %s", string(key)))
+ return nil, errors.New("force")
+ }
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
+ return nil, errors.New("force")
+ }
+ },
+ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
+ switch string(bucket) {
+ case string(orderTable):
+ return nil, true, nil
+ case string(ordersByAccountIDTable):
+ assert.Equals(t, key, []byte(accID))
+ assert.Equals(t, old, bOldOids)
+ assert.Equals(t, nu, bNewOids)
+ return nil, true, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
+ return nil, false, errors.New("force")
+ }
+ },
+ },
+ addOids: addOids,
+ res: newOids,
+ }
+ },
+ }
+ for name, run := range tests {
+ tc := run(t)
+ t.Run(name, func(t *testing.T) {
+ db := DB{db: tc.db}
+ var (
+ res []string
+ err error
+ )
+ if tc.addOids == nil {
+ res, err = db.updateAddOrderIDs(context.Background(), accID)
+ } else {
+ res, err = db.updateAddOrderIDs(context.Background(), accID, tc.addOids...)
+ }
+
+ if err != nil {
+ switch k := err.(type) {
+ case *acme.Error:
+ if assert.NotNil(t, tc.acmeErr) {
+ assert.Equals(t, k.Type, tc.acmeErr.Type)
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ assert.Equals(t, k.Status, tc.acmeErr.Status)
+ assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
+ assert.Equals(t, k.Detail, tc.acmeErr.Detail)
+ }
+ default:
+ if assert.NotNil(t, tc.err) {
+ assert.HasPrefix(t, err.Error(), tc.err.Error())
+ }
+ }
+ } else {
+ if assert.Nil(t, tc.err) {
+ assert.True(t, reflect.DeepEqual(res, tc.res))
+ }
+ }
+ })
+ }
+}
diff --git a/acme/directory.go b/acme/directory.go
deleted file mode 100644
index d5681b73..00000000
--- a/acme/directory.go
+++ /dev/null
@@ -1,150 +0,0 @@
-package acme
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "net/url"
-
- "github.com/pkg/errors"
-)
-
-// Directory represents an ACME directory for configuring clients.
-type Directory struct {
- NewNonce string `json:"newNonce,omitempty"`
- NewAccount string `json:"newAccount,omitempty"`
- NewOrder string `json:"newOrder,omitempty"`
- NewAuthz string `json:"newAuthz,omitempty"`
- RevokeCert string `json:"revokeCert,omitempty"`
- KeyChange string `json:"keyChange,omitempty"`
-}
-
-// ToLog enables response logging for the Directory type.
-func (d *Directory) ToLog() (interface{}, error) {
- b, err := json.Marshal(d)
- if err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging"))
- }
- return string(b), nil
-}
-
-type directory struct {
- prefix, dns string
-}
-
-// newDirectory returns a new Directory type.
-func newDirectory(dns, prefix string) *directory {
- return &directory{prefix: prefix, dns: dns}
-}
-
-// Link captures the link type.
-type Link int
-
-const (
- // NewNonceLink new-nonce
- NewNonceLink Link = iota
- // NewAccountLink new-account
- NewAccountLink
- // AccountLink account
- AccountLink
- // OrderLink order
- OrderLink
- // NewOrderLink new-order
- NewOrderLink
- // OrdersByAccountLink list of orders owned by account
- OrdersByAccountLink
- // FinalizeLink finalize order
- FinalizeLink
- // NewAuthzLink authz
- NewAuthzLink
- // AuthzLink new-authz
- AuthzLink
- // ChallengeLink challenge
- ChallengeLink
- // CertificateLink certificate
- CertificateLink
- // DirectoryLink directory
- DirectoryLink
- // RevokeCertLink revoke certificate
- RevokeCertLink
- // KeyChangeLink key rollover
- KeyChangeLink
-)
-
-func (l Link) String() string {
- switch l {
- case NewNonceLink:
- return "new-nonce"
- case NewAccountLink:
- return "new-account"
- case AccountLink:
- return "account"
- case NewOrderLink:
- return "new-order"
- case OrderLink:
- return "order"
- case NewAuthzLink:
- return "new-authz"
- case AuthzLink:
- return "authz"
- case ChallengeLink:
- return "challenge"
- case CertificateLink:
- return "certificate"
- case DirectoryLink:
- return "directory"
- case RevokeCertLink:
- return "revoke-cert"
- case KeyChangeLink:
- return "key-change"
- default:
- return "unexpected"
- }
-}
-
-func (d *directory) getLink(ctx context.Context, typ Link, abs bool, inputs ...string) string {
- var provName string
- if p, err := ProvisionerFromContext(ctx); err == nil && p != nil {
- provName = p.GetName()
- }
- return d.getLinkExplicit(typ, provName, abs, BaseURLFromContext(ctx), inputs...)
-}
-
-// getLinkExplicit returns an absolute or partial path to the given resource and a base
-// URL dynamically obtained from the request for which the link is being
-// calculated.
-func (d *directory) getLinkExplicit(typ Link, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string {
- var link string
- switch typ {
- case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink:
- link = fmt.Sprintf("/%s/%s", provisionerName, typ.String())
- case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink:
- link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0])
- case OrdersByAccountLink:
- link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0])
- case FinalizeLink:
- link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0])
- }
-
- if abs {
- // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
- u := url.URL{}
- if baseURL != nil {
- u = *baseURL
- }
-
- // If no Scheme is set, then default to https.
- if u.Scheme == "" {
- u.Scheme = "https"
- }
-
- // If no Host is set, then use the default (first DNS attr in the ca.json).
- if u.Host == "" {
- u.Host = d.dns
- }
-
- u.Path = d.prefix + link
- return u.String()
- }
- return link
-}
diff --git a/acme/directory_test.go b/acme/directory_test.go
deleted file mode 100644
index dd4c534c..00000000
--- a/acme/directory_test.go
+++ /dev/null
@@ -1,99 +0,0 @@
-package acme
-
-import (
- "context"
- "fmt"
- "net/url"
- "testing"
-
- "github.com/smallstep/assert"
-)
-
-func TestDirectoryGetLink(t *testing.T) {
- dns := "ca.smallstep.com"
- prefix := "acme"
- dir := newDirectory(dns, prefix)
- id := "1234"
-
- prov := newProv()
- provName := url.PathEscape(prov.GetName())
- baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
-
- assert.Equals(t, dir.getLink(ctx, NewNonceLink, true),
- fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName))
- assert.Equals(t, dir.getLink(ctx, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName))
-
- // No provisioner
- ctxNoProv := context.WithValue(context.Background(), BaseURLContextKey, baseURL)
- assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, true),
- fmt.Sprintf("%s/acme//new-nonce", baseURL.String()))
- assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, false), "//new-nonce")
-
- // No baseURL
- ctxNoBaseURL := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, true),
- fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName))
- assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName))
-
- assert.Equals(t, dir.getLink(ctx, OrderLink, true, id),
- fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName))
- assert.Equals(t, dir.getLink(ctx, OrderLink, false, id), fmt.Sprintf("/%s/order/1234", provName))
-}
-
-func TestDirectoryGetLinkExplicit(t *testing.T) {
- dns := "ca.smallstep.com"
- baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- prefix := "acme"
- dir := newDirectory(dns, prefix)
- id := "1234"
-
- prov := newProv()
- provID := url.PathEscape(prov.GetName())
-
- assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID))
- assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID))
- assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID))
- assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID))
-
- assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID))
-
- assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID))
-
- assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID))
-
- assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID))
-
- assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID))
-
- assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID))
-
- assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID))
-
- assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID))
-
- assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID))
-
- assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID))
-
- assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID))
-
- assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/1234", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/1234", provID))
-
- assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID))
- assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID))
-}
diff --git a/acme/errors.go b/acme/errors.go
index a4dd8159..6ecf0912 100644
--- a/acme/errors.go
+++ b/acme/errors.go
@@ -1,407 +1,339 @@
package acme
import (
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+
"github.com/pkg/errors"
+ "github.com/smallstep/certificates/errs"
+ "github.com/smallstep/certificates/logging"
)
-// AccountDoesNotExistErr returns a new acme error.
-func AccountDoesNotExistErr(err error) *Error {
- return &Error{
- Type: accountDoesNotExistErr,
- Detail: "Account does not exist",
- Status: 400,
- Err: err,
- }
-}
-
-// AlreadyRevokedErr returns a new acme error.
-func AlreadyRevokedErr(err error) *Error {
- return &Error{
- Type: alreadyRevokedErr,
- Detail: "Certificate already revoked",
- Status: 400,
- Err: err,
- }
-}
-
-// BadCSRErr returns a new acme error.
-func BadCSRErr(err error) *Error {
- return &Error{
- Type: badCSRErr,
- Detail: "The CSR is unacceptable",
- Status: 400,
- Err: err,
- }
-}
-
-// BadNonceErr returns a new acme error.
-func BadNonceErr(err error) *Error {
- return &Error{
- Type: badNonceErr,
- Detail: "Unacceptable anti-replay nonce",
- Status: 400,
- Err: err,
- }
-}
-
-// BadPublicKeyErr returns a new acme error.
-func BadPublicKeyErr(err error) *Error {
- return &Error{
- Type: badPublicKeyErr,
- Detail: "The jws was signed by a public key the server does not support",
- Status: 400,
- Err: err,
- }
-}
-
-// BadRevocationReasonErr returns a new acme error.
-func BadRevocationReasonErr(err error) *Error {
- return &Error{
- Type: badRevocationReasonErr,
- Detail: "The revocation reason provided is not allowed by the server",
- Status: 400,
- Err: err,
- }
-}
-
-// BadSignatureAlgorithmErr returns a new acme error.
-func BadSignatureAlgorithmErr(err error) *Error {
- return &Error{
- Type: badSignatureAlgorithmErr,
- Detail: "The JWS was signed with an algorithm the server does not support",
- Status: 400,
- Err: err,
- }
-}
-
-// CaaErr returns a new acme error.
-func CaaErr(err error) *Error {
- return &Error{
- Type: caaErr,
- Detail: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
- Status: 400,
- Err: err,
- }
-}
-
-// CompoundErr returns a new acme error.
-func CompoundErr(err error) *Error {
- return &Error{
- Type: compoundErr,
- Detail: "Specific error conditions are indicated in the “subproblems” array",
- Status: 400,
- Err: err,
- }
-}
-
-// ConnectionErr returns a new acme error.
-func ConnectionErr(err error) *Error {
- return &Error{
- Type: connectionErr,
- Detail: "The server could not connect to validation target",
- Status: 400,
- Err: err,
- }
-}
-
-// DNSErr returns a new acme error.
-func DNSErr(err error) *Error {
- return &Error{
- Type: dnsErr,
- Detail: "There was a problem with a DNS query during identifier validation",
- Status: 400,
- Err: err,
- }
-}
-
-// ExternalAccountRequiredErr returns a new acme error.
-func ExternalAccountRequiredErr(err error) *Error {
- return &Error{
- Type: externalAccountRequiredErr,
- Detail: "The request must include a value for the \"externalAccountBinding\" field",
- Status: 400,
- Err: err,
- }
-}
-
-// IncorrectResponseErr returns a new acme error.
-func IncorrectResponseErr(err error) *Error {
- return &Error{
- Type: incorrectResponseErr,
- Detail: "Response received didn't match the challenge's requirements",
- Status: 400,
- Err: err,
- }
-}
-
-// InvalidContactErr returns a new acme error.
-func InvalidContactErr(err error) *Error {
- return &Error{
- Type: invalidContactErr,
- Detail: "A contact URL for an account was invalid",
- Status: 400,
- Err: err,
- }
-}
-
-// MalformedErr returns a new acme error.
-func MalformedErr(err error) *Error {
- return &Error{
- Type: malformedErr,
- Detail: "The request message was malformed",
- Status: 400,
- Err: err,
- }
-}
-
-// OrderNotReadyErr returns a new acme error.
-func OrderNotReadyErr(err error) *Error {
- return &Error{
- Type: orderNotReadyErr,
- Detail: "The request attempted to finalize an order that is not ready to be finalized",
- Status: 400,
- Err: err,
- }
-}
-
-// RateLimitedErr returns a new acme error.
-func RateLimitedErr(err error) *Error {
- return &Error{
- Type: rateLimitedErr,
- Detail: "The request exceeds a rate limit",
- Status: 400,
- Err: err,
- }
-}
-
-// RejectedIdentifierErr returns a new acme error.
-func RejectedIdentifierErr(err error) *Error {
- return &Error{
- Type: rejectedIdentifierErr,
- Detail: "The server will not issue certificates for the identifier",
- Status: 400,
- Err: err,
- }
-}
-
-// ServerInternalErr returns a new acme error.
-func ServerInternalErr(err error) *Error {
- return &Error{
- Type: serverInternalErr,
- Detail: "The server experienced an internal error",
- Status: 500,
- Err: err,
- }
-}
-
-// NotImplemented returns a new acme error.
-func NotImplemented(err error) *Error {
- return &Error{
- Type: notImplemented,
- Detail: "The requested operation is not implemented",
- Status: 501,
- Err: err,
- }
-}
-
-// TLSErr returns a new acme error.
-func TLSErr(err error) *Error {
- return &Error{
- Type: tlsErr,
- Detail: "The server received a TLS error during validation",
- Status: 400,
- Err: err,
- }
-}
-
-// UnauthorizedErr returns a new acme error.
-func UnauthorizedErr(err error) *Error {
- return &Error{
- Type: unauthorizedErr,
- Detail: "The client lacks sufficient authorization",
- Status: 401,
- Err: err,
- }
-}
-
-// UnsupportedContactErr returns a new acme error.
-func UnsupportedContactErr(err error) *Error {
- return &Error{
- Type: unsupportedContactErr,
- Detail: "A contact URL for an account used an unsupported protocol scheme",
- Status: 400,
- Err: err,
- }
-}
-
-// UnsupportedIdentifierErr returns a new acme error.
-func UnsupportedIdentifierErr(err error) *Error {
- return &Error{
- Type: unsupportedIdentifierErr,
- Detail: "An identifier is of an unsupported type",
- Status: 400,
- Err: err,
- }
-}
-
-// UserActionRequiredErr returns a new acme error.
-func UserActionRequiredErr(err error) *Error {
- return &Error{
- Type: userActionRequiredErr,
- Detail: "Visit the “instance” URL and take actions specified there",
- Status: 400,
- Err: err,
- }
-}
-
-// ProbType is the type of the ACME problem.
-type ProbType int
+// ProblemType is the type of the ACME problem.
+type ProblemType int
const (
- // The request specified an account that does not exist
- accountDoesNotExistErr ProbType = iota
- // The request specified a certificate to be revoked that has already been revoked
- alreadyRevokedErr
- // The CSR is unacceptable (e.g., due to a short key)
- badCSRErr
- // The client sent an unacceptable anti-replay nonce
- badNonceErr
- // The JWS was signed by a public key the server does not support
- badPublicKeyErr
- // The revocation reason provided is not allowed by the server
- badRevocationReasonErr
- // The JWS was signed with an algorithm the server does not support
- badSignatureAlgorithmErr
- // Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate
- caaErr
- // Specific error conditions are indicated in the “subproblems” array.
- compoundErr
- // The server could not connect to validation target
- connectionErr
- // There was a problem with a DNS query during identifier validation
- dnsErr
- // The request must include a value for the “externalAccountBinding” field
- externalAccountRequiredErr
- // Response received didn’t match the challenge’s requirements
- incorrectResponseErr
- // A contact URL for an account was invalid
- invalidContactErr
- // The request message was malformed
- malformedErr
- // The request attempted to finalize an order that is not ready to be finalized
- orderNotReadyErr
- // The request exceeds a rate limit
- rateLimitedErr
- // The server will not issue certificates for the identifier
- rejectedIdentifierErr
- // The server experienced an internal error
- serverInternalErr
- // The server received a TLS error during validation
- tlsErr
- // The client lacks sufficient authorization
- unauthorizedErr
- // A contact URL for an account used an unsupported protocol scheme
- unsupportedContactErr
- // An identifier is of an unsupported type
- unsupportedIdentifierErr
- // Visit the “instance” URL and take actions specified there
- userActionRequiredErr
- // The operation is not implemented
- notImplemented
+ // ErrorAccountDoesNotExistType request specified an account that does not exist
+ ErrorAccountDoesNotExistType ProblemType = iota
+ // ErrorAlreadyRevokedType request specified a certificate to be revoked that has already been revoked
+ ErrorAlreadyRevokedType
+ // ErrorBadCSRType CSR is unacceptable (e.g., due to a short key)
+ ErrorBadCSRType
+ // ErrorBadNonceType client sent an unacceptable anti-replay nonce
+ ErrorBadNonceType
+ // ErrorBadPublicKeyType JWS was signed by a public key the server does not support
+ ErrorBadPublicKeyType
+ // ErrorBadRevocationReasonType revocation reason provided is not allowed by the server
+ ErrorBadRevocationReasonType
+ // ErrorBadSignatureAlgorithmType JWS was signed with an algorithm the server does not support
+ ErrorBadSignatureAlgorithmType
+ // ErrorCaaType Authority Authorization (CAA) records forbid the CA from issuing a certificate
+ ErrorCaaType
+ // ErrorCompoundType error conditions are indicated in the “subproblems” array.
+ ErrorCompoundType
+ // ErrorConnectionType server could not connect to validation target
+ ErrorConnectionType
+ // ErrorDNSType was a problem with a DNS query during identifier validation
+ ErrorDNSType
+ // ErrorExternalAccountRequiredType request must include a value for the “externalAccountBinding” field
+ ErrorExternalAccountRequiredType
+ // ErrorIncorrectResponseType received didn’t match the challenge’s requirements
+ ErrorIncorrectResponseType
+ // ErrorInvalidContactType URL for an account was invalid
+ ErrorInvalidContactType
+ // ErrorMalformedType request message was malformed
+ ErrorMalformedType
+ // ErrorOrderNotReadyType request attempted to finalize an order that is not ready to be finalized
+ ErrorOrderNotReadyType
+ // ErrorRateLimitedType request exceeds a rate limit
+ ErrorRateLimitedType
+ // ErrorRejectedIdentifierType server will not issue certificates for the identifier
+ ErrorRejectedIdentifierType
+ // ErrorServerInternalType server experienced an internal error
+ ErrorServerInternalType
+ // ErrorTLSType server received a TLS error during validation
+ ErrorTLSType
+ // ErrorUnauthorizedType client lacks sufficient authorization
+ ErrorUnauthorizedType
+ // ErrorUnsupportedContactType URL for an account used an unsupported protocol scheme
+ ErrorUnsupportedContactType
+ // ErrorUnsupportedIdentifierType identifier is of an unsupported type
+ ErrorUnsupportedIdentifierType
+ // ErrorUserActionRequiredType the “instance” URL and take actions specified there
+ ErrorUserActionRequiredType
+ // ErrorNotImplementedType operation is not implemented
+ ErrorNotImplementedType
)
// String returns the string representation of the acme problem type,
// fulfilling the Stringer interface.
-func (ap ProbType) String() string {
+func (ap ProblemType) String() string {
switch ap {
- case accountDoesNotExistErr:
+ case ErrorAccountDoesNotExistType:
return "accountDoesNotExist"
- case alreadyRevokedErr:
+ case ErrorAlreadyRevokedType:
return "alreadyRevoked"
- case badCSRErr:
+ case ErrorBadCSRType:
return "badCSR"
- case badNonceErr:
+ case ErrorBadNonceType:
return "badNonce"
- case badPublicKeyErr:
+ case ErrorBadPublicKeyType:
return "badPublicKey"
- case badRevocationReasonErr:
+ case ErrorBadRevocationReasonType:
return "badRevocationReason"
- case badSignatureAlgorithmErr:
+ case ErrorBadSignatureAlgorithmType:
return "badSignatureAlgorithm"
- case caaErr:
+ case ErrorCaaType:
return "caa"
- case compoundErr:
+ case ErrorCompoundType:
return "compound"
- case connectionErr:
+ case ErrorConnectionType:
return "connection"
- case dnsErr:
+ case ErrorDNSType:
return "dns"
- case externalAccountRequiredErr:
+ case ErrorExternalAccountRequiredType:
return "externalAccountRequired"
- case incorrectResponseErr:
+ case ErrorInvalidContactType:
return "incorrectResponse"
- case invalidContactErr:
- return "invalidContact"
- case malformedErr:
+ case ErrorMalformedType:
return "malformed"
- case orderNotReadyErr:
+ case ErrorOrderNotReadyType:
return "orderNotReady"
- case rateLimitedErr:
+ case ErrorRateLimitedType:
return "rateLimited"
- case rejectedIdentifierErr:
+ case ErrorRejectedIdentifierType:
return "rejectedIdentifier"
- case serverInternalErr:
+ case ErrorServerInternalType:
return "serverInternal"
- case tlsErr:
+ case ErrorTLSType:
return "tls"
- case unauthorizedErr:
+ case ErrorUnauthorizedType:
return "unauthorized"
- case unsupportedContactErr:
+ case ErrorUnsupportedContactType:
return "unsupportedContact"
- case unsupportedIdentifierErr:
+ case ErrorUnsupportedIdentifierType:
return "unsupportedIdentifier"
- case userActionRequiredErr:
+ case ErrorUserActionRequiredType:
return "userActionRequired"
- case notImplemented:
+ case ErrorNotImplementedType:
return "notImplemented"
default:
- return "unsupported type"
+ return fmt.Sprintf("unsupported type ACME error type '%d'", int(ap))
}
}
-// Error is an ACME error type complete with problem document.
-type Error struct {
- Type ProbType
- Detail string
- Err error
- Status int
- Sub []*Error
- Identifier *Identifier
+type errorMetadata struct {
+ details string
+ status int
+ typ string
+ String string
}
-// Wrap attempts to wrap the internal error.
-func Wrap(err error, wrap string) *Error {
+var (
+ officialACMEPrefix = "urn:ietf:params:acme:error:"
+ errorServerInternalMetadata = errorMetadata{
+ typ: officialACMEPrefix + ErrorServerInternalType.String(),
+ details: "The server experienced an internal error",
+ status: 500,
+ }
+ errorMap = map[ProblemType]errorMetadata{
+ ErrorAccountDoesNotExistType: {
+ typ: officialACMEPrefix + ErrorAccountDoesNotExistType.String(),
+ details: "Account does not exist",
+ status: 400,
+ },
+ ErrorAlreadyRevokedType: {
+ typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(),
+ details: "Certificate already Revoked",
+ status: 400,
+ },
+ ErrorBadCSRType: {
+ typ: officialACMEPrefix + ErrorBadCSRType.String(),
+ details: "The CSR is unacceptable",
+ status: 400,
+ },
+ ErrorBadNonceType: {
+ typ: officialACMEPrefix + ErrorBadNonceType.String(),
+ details: "Unacceptable anti-replay nonce",
+ status: 400,
+ },
+ ErrorBadPublicKeyType: {
+ typ: officialACMEPrefix + ErrorBadPublicKeyType.String(),
+ details: "The jws was signed by a public key the server does not support",
+ status: 400,
+ },
+ ErrorBadRevocationReasonType: {
+ typ: officialACMEPrefix + ErrorBadRevocationReasonType.String(),
+ details: "The revocation reason provided is not allowed by the server",
+ status: 400,
+ },
+ ErrorBadSignatureAlgorithmType: {
+ typ: officialACMEPrefix + ErrorBadSignatureAlgorithmType.String(),
+ details: "The JWS was signed with an algorithm the server does not support",
+ status: 400,
+ },
+ ErrorCaaType: {
+ typ: officialACMEPrefix + ErrorCaaType.String(),
+ details: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
+ status: 400,
+ },
+ ErrorCompoundType: {
+ typ: officialACMEPrefix + ErrorCompoundType.String(),
+ details: "Specific error conditions are indicated in the “subproblems” array",
+ status: 400,
+ },
+ ErrorConnectionType: {
+ typ: officialACMEPrefix + ErrorConnectionType.String(),
+ details: "The server could not connect to validation target",
+ status: 400,
+ },
+ ErrorDNSType: {
+ typ: officialACMEPrefix + ErrorDNSType.String(),
+ details: "There was a problem with a DNS query during identifier validation",
+ status: 400,
+ },
+ ErrorExternalAccountRequiredType: {
+ typ: officialACMEPrefix + ErrorExternalAccountRequiredType.String(),
+ details: "The request must include a value for the \"externalAccountBinding\" field",
+ status: 400,
+ },
+ ErrorIncorrectResponseType: {
+ typ: officialACMEPrefix + ErrorIncorrectResponseType.String(),
+ details: "Response received didn't match the challenge's requirements",
+ status: 400,
+ },
+ ErrorInvalidContactType: {
+ typ: officialACMEPrefix + ErrorInvalidContactType.String(),
+ details: "A contact URL for an account was invalid",
+ status: 400,
+ },
+ ErrorMalformedType: {
+ typ: officialACMEPrefix + ErrorMalformedType.String(),
+ details: "The request message was malformed",
+ status: 400,
+ },
+ ErrorOrderNotReadyType: {
+ typ: officialACMEPrefix + ErrorOrderNotReadyType.String(),
+ details: "The request attempted to finalize an order that is not ready to be finalized",
+ status: 400,
+ },
+ ErrorRateLimitedType: {
+ typ: officialACMEPrefix + ErrorRateLimitedType.String(),
+ details: "The request exceeds a rate limit",
+ status: 400,
+ },
+ ErrorRejectedIdentifierType: {
+ typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(),
+ details: "The server will not issue certificates for the identifier",
+ status: 400,
+ },
+ ErrorNotImplementedType: {
+ typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(),
+ details: "The requested operation is not implemented",
+ status: 501,
+ },
+ ErrorTLSType: {
+ typ: officialACMEPrefix + ErrorTLSType.String(),
+ details: "The server received a TLS error during validation",
+ status: 400,
+ },
+ ErrorUnauthorizedType: {
+ typ: officialACMEPrefix + ErrorUnauthorizedType.String(),
+ details: "The client lacks sufficient authorization",
+ status: 401,
+ },
+ ErrorUnsupportedContactType: {
+ typ: officialACMEPrefix + ErrorUnsupportedContactType.String(),
+ details: "A contact URL for an account used an unsupported protocol scheme",
+ status: 400,
+ },
+ ErrorUnsupportedIdentifierType: {
+ typ: officialACMEPrefix + ErrorUnsupportedIdentifierType.String(),
+ details: "An identifier is of an unsupported type",
+ status: 400,
+ },
+ ErrorUserActionRequiredType: {
+ typ: officialACMEPrefix + ErrorUserActionRequiredType.String(),
+ details: "Visit the “instance” URL and take actions specified there",
+ status: 400,
+ },
+ ErrorServerInternalType: errorServerInternalMetadata,
+ }
+)
+
+// Error represents an ACME
+type Error struct {
+ Type string `json:"type"`
+ Detail string `json:"detail"`
+ Subproblems []interface{} `json:"subproblems,omitempty"`
+ Identifier interface{} `json:"identifier,omitempty"`
+ Err error `json:"-"`
+ Status int `json:"-"`
+}
+
+// NewError creates a new Error type.
+func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
+ return newError(pt, errors.Errorf(msg, args...))
+}
+
+func newError(pt ProblemType, err error) *Error {
+ meta, ok := errorMap[pt]
+ if !ok {
+ meta = errorServerInternalMetadata
+ return &Error{
+ Type: meta.typ,
+ Detail: meta.details,
+ Status: meta.status,
+ Err: err,
+ }
+ }
+
+ return &Error{
+ Type: meta.typ,
+ Detail: meta.details,
+ Status: meta.status,
+ Err: err,
+ }
+}
+
+// NewErrorISE creates a new ErrorServerInternalType Error.
+func NewErrorISE(msg string, args ...interface{}) *Error {
+ return NewError(ErrorServerInternalType, msg, args...)
+}
+
+// WrapError attempts to wrap the internal error.
+func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error {
switch e := err.(type) {
case nil:
return nil
case *Error:
if e.Err == nil {
- e.Err = errors.New(wrap + "; " + e.Detail)
+ e.Err = errors.Errorf(msg+"; "+e.Detail, args...)
} else {
- e.Err = errors.Wrap(e.Err, wrap)
+ e.Err = errors.Wrapf(e.Err, msg, args...)
}
return e
default:
- return ServerInternalErr(errors.Wrap(err, wrap))
+ return newError(typ, errors.Wrapf(err, msg, args...))
}
}
-// Error implements the error interface.
+// WrapErrorISE shortcut to wrap an internal server error type.
+func WrapErrorISE(err error, msg string, args ...interface{}) *Error {
+ return WrapError(ErrorServerInternalType, err, msg, args...)
+}
+
+// StatusCode returns the status code and implements the StatusCoder interface.
+func (e *Error) StatusCode() int {
+ return e.Status
+}
+
+// Error allows AError to implement the error interface.
func (e *Error) Error() string {
- if e.Err == nil {
- return e.Detail
- }
- return e.Err.Error()
+ return e.Detail
}
// Cause returns the internal error and implements the Causer interface.
@@ -412,70 +344,35 @@ func (e *Error) Cause() error {
return e.Err
}
-// Official returns true if this error's type is listed in §6.7 of RFC 8555.
-// Error types in §6.7 are registered under IETF urn namespace:
-//
-// "urn:ietf:params:acme:error:"
-//
-// and should include the namespace as a prefix when appearing as a problem
-// document.
-//
-// RFC 8555 also says:
-//
-// This list is not exhaustive. The server MAY return errors whose
-// "type" field is set to a URI other than those defined above. Servers
-// MUST NOT use the ACME URN namespace for errors not listed in the
-// appropriate IANA registry (see Section 9.6). Clients SHOULD display
-// the "detail" field of all errors.
-//
-// In this case Official returns `false` so that a different namespace can
-// be used.
-func (e *Error) Official() bool {
- return e.Type != notImplemented
-}
-
-// ToACME returns an acme representation of the problem type.
-// For official errors, the IETF ACME namespace is prepended to the error type.
-// For our own errors, we use an (yet) unregistered smallstep acme namespace.
-func (e *Error) ToACME() *AError {
- prefix := "urn:step:acme:error"
- if e.Official() {
- prefix = "urn:ietf:params:acme:error:"
+// ToLog implements the EnableLogger interface.
+func (e *Error) ToLog() (interface{}, error) {
+ b, err := json.Marshal(e)
+ if err != nil {
+ return nil, WrapErrorISE(err, "error marshaling acme.Error for logging")
}
- ae := &AError{
- Type: prefix + e.Type.String(),
- Detail: e.Error(),
- Status: e.Status,
+ return string(b), nil
+}
+
+// WriteError writes to w a JSON representation of the given error.
+func WriteError(w http.ResponseWriter, err *Error) {
+ w.Header().Set("Content-Type", "application/problem+json")
+ w.WriteHeader(err.StatusCode())
+
+ // Write errors in the response writer
+ if rl, ok := w.(logging.ResponseLogger); ok {
+ rl.WithFields(map[string]interface{}{
+ "error": err.Err,
+ })
+ if os.Getenv("STEPDEBUG") == "1" {
+ if e, ok := err.Err.(errs.StackTracer); ok {
+ rl.WithFields(map[string]interface{}{
+ "stack-trace": fmt.Sprintf("%+v", e),
+ })
+ }
+ }
}
- if e.Identifier != nil {
- ae.Identifier = *e.Identifier
+
+ if err := json.NewEncoder(w).Encode(err); err != nil {
+ log.Println(err)
}
- for _, p := range e.Sub {
- ae.Subproblems = append(ae.Subproblems, p.ToACME())
- }
- return ae
-}
-
-// StatusCode returns the status code and implements the StatusCode interface.
-func (e *Error) StatusCode() int {
- return e.Status
-}
-
-// AError is the error type as seen in acme request/responses.
-type AError struct {
- Type string `json:"type"`
- Detail string `json:"detail"`
- Identifier interface{} `json:"identifier,omitempty"`
- Subproblems []interface{} `json:"subproblems,omitempty"`
- Status int `json:"-"`
-}
-
-// Error allows AError to implement the error interface.
-func (ae *AError) Error() string {
- return ae.Detail
-}
-
-// StatusCode returns the status code and implements the StatusCode interface.
-func (ae *AError) StatusCode() int {
- return ae.Status
}
diff --git a/acme/nonce.go b/acme/nonce.go
index db680f08..25c86360 100644
--- a/acme/nonce.go
+++ b/acme/nonce.go
@@ -1,73 +1,9 @@
package acme
-import (
- "encoding/base64"
- "encoding/json"
- "time"
+// Nonce represents an ACME nonce type.
+type Nonce string
- "github.com/pkg/errors"
- "github.com/smallstep/nosql"
- "github.com/smallstep/nosql/database"
-)
-
-// nonce contains nonce metadata used in the ACME protocol.
-type nonce struct {
- ID string
- Created time.Time
-}
-
-// newNonce creates, stores, and returns an ACME replay-nonce.
-func newNonce(db nosql.DB) (*nonce, error) {
- _id, err := randID()
- if err != nil {
- return nil, err
- }
-
- id := base64.RawURLEncoding.EncodeToString([]byte(_id))
- n := &nonce{
- ID: id,
- Created: clock.Now(),
- }
- b, err := json.Marshal(n)
- if err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce"))
- }
- _, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b)
- switch {
- case err != nil:
- return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce"))
- case !swapped:
- return nil, ServerInternalErr(errors.New("error storing nonce; " +
- "value has changed since last read"))
- default:
- return n, nil
- }
-}
-
-// useNonce verifies that the nonce is valid (by checking if it exists),
-// and if so, consumes the nonce resource by deleting it from the database.
-func useNonce(db nosql.DB, nonce string) error {
- err := db.Update(&database.Tx{
- Operations: []*database.TxEntry{
- {
- Bucket: nonceTable,
- Key: []byte(nonce),
- Cmd: database.Get,
- },
- {
- Bucket: nonceTable,
- Key: []byte(nonce),
- Cmd: database.Delete,
- },
- },
- })
-
- switch {
- case nosql.IsErrNotFound(err):
- return BadNonceErr(nil)
- case err != nil:
- return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce))
- default:
- return nil
- }
+// String implements the ToString interface.
+func (n Nonce) String() string {
+ return string(n)
}
diff --git a/acme/nonce_test.go b/acme/nonce_test.go
deleted file mode 100644
index 6aa467a0..00000000
--- a/acme/nonce_test.go
+++ /dev/null
@@ -1,163 +0,0 @@
-package acme
-
-import (
- "testing"
- "time"
-
- "github.com/pkg/errors"
- "github.com/smallstep/assert"
- "github.com/smallstep/certificates/db"
- "github.com/smallstep/nosql"
- "github.com/smallstep/nosql/database"
-)
-
-func TestNewNonce(t *testing.T) {
- type test struct {
- db nosql.DB
- err *Error
- id *string
- }
- tests := map[string]func(t *testing.T) test{
- "fail/cmpAndSwap-error": func(t *testing.T) test {
- return test{
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, nonceTable)
- assert.Equals(t, old, nil)
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.Errorf("error storing nonce: force")),
- }
- },
- "fail/cmpAndSwap-false": func(t *testing.T) test {
- return test{
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, nonceTable)
- assert.Equals(t, old, nil)
- return nil, false, nil
- },
- },
- err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")),
- }
- },
- "ok": func(t *testing.T) test {
- var _id string
- id := &_id
- return test{
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, nonceTable)
- assert.Equals(t, old, nil)
- *id = string(key)
- return nil, true, nil
- },
- },
- id: id,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if n, err := newNonce(tc.db); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, n.ID, *tc.id)
-
- assert.True(t, n.Created.Before(time.Now().Add(time.Minute)))
- assert.True(t, n.Created.After(time.Now().Add(-time.Minute)))
- }
- }
- })
- }
-}
-
-func TestUseNonce(t *testing.T) {
- type test struct {
- id string
- db nosql.DB
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/update-not-found": func(t *testing.T) test {
- id := "foo"
- return test{
- db: &db.MockNoSQLDB{
- MUpdate: func(tx *database.Tx) error {
- assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
- assert.Equals(t, tx.Operations[0].Key, []byte(id))
- assert.Equals(t, tx.Operations[0].Cmd, database.Get)
-
- assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
- assert.Equals(t, tx.Operations[1].Key, []byte(id))
- assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
- return database.ErrNotFound
- },
- },
- id: id,
- err: BadNonceErr(nil),
- }
- },
- "fail/update-error": func(t *testing.T) test {
- id := "foo"
- return test{
- db: &db.MockNoSQLDB{
- MUpdate: func(tx *database.Tx) error {
- assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
- assert.Equals(t, tx.Operations[0].Key, []byte(id))
- assert.Equals(t, tx.Operations[0].Cmd, database.Get)
-
- assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
- assert.Equals(t, tx.Operations[1].Key, []byte(id))
- assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
- return errors.New("force")
- },
- },
- id: id,
- err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)),
- }
- },
- "ok": func(t *testing.T) test {
- id := "foo"
- return test{
- db: &db.MockNoSQLDB{
- MUpdate: func(tx *database.Tx) error {
- assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
- assert.Equals(t, tx.Operations[0].Key, []byte(id))
- assert.Equals(t, tx.Operations[0].Cmd, database.Get)
-
- assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
- assert.Equals(t, tx.Operations[1].Key, []byte(id))
- assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
-
- return nil
- },
- },
- id: id,
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if err := useNonce(tc.db, tc.id); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- }
- })
- }
-}
diff --git a/acme/order.go b/acme/order.go
index 574477ca..a003fe9a 100644
--- a/acme/order.go
+++ b/acme/order.go
@@ -6,351 +6,129 @@ import (
"encoding/json"
"sort"
"strings"
- "sync"
"time"
- "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
- "github.com/smallstep/nosql"
"go.step.sm/crypto/x509util"
)
-var defaultOrderExpiry = time.Hour * 24
-
-// Mutex for locking ordersByAccount index operations.
-var ordersByAccountMux sync.Mutex
+// Identifier encodes the type that an order pertains to.
+type Identifier struct {
+ Type string `json:"type"`
+ Value string `json:"value"`
+}
// Order contains order metadata for the ACME protocol order type.
type Order struct {
- Status string `json:"status"`
- Expires string `json:"expires,omitempty"`
- Identifiers []Identifier `json:"identifiers"`
- NotBefore string `json:"notBefore,omitempty"`
- NotAfter string `json:"notAfter,omitempty"`
- Error interface{} `json:"error,omitempty"`
- Authorizations []string `json:"authorizations"`
- Finalize string `json:"finalize"`
- Certificate string `json:"certificate,omitempty"`
- ID string `json:"-"`
+ ID string `json:"id"`
+ AccountID string `json:"-"`
+ ProvisionerID string `json:"-"`
+ Status Status `json:"status"`
+ ExpiresAt time.Time `json:"expires"`
+ Identifiers []Identifier `json:"identifiers"`
+ NotBefore time.Time `json:"notBefore"`
+ NotAfter time.Time `json:"notAfter"`
+ Error *Error `json:"error,omitempty"`
+ AuthorizationIDs []string `json:"-"`
+ AuthorizationURLs []string `json:"authorizations"`
+ FinalizeURL string `json:"finalize"`
+ CertificateID string `json:"-"`
+ CertificateURL string `json:"certificate,omitempty"`
}
// ToLog enables response logging.
func (o *Order) ToLog() (interface{}, error) {
b, err := json.Marshal(o)
if err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging"))
+ return nil, WrapErrorISE(err, "error marshaling order for logging")
}
return string(b), nil
}
-// GetID returns the Order ID.
-func (o *Order) GetID() string {
- return o.ID
-}
-
-// OrderOptions options with which to create a new Order.
-type OrderOptions struct {
- AccountID string `json:"accID"`
- Identifiers []Identifier `json:"identifiers"`
- NotBefore time.Time `json:"notBefore"`
- NotAfter time.Time `json:"notAfter"`
- backdate time.Duration
- defaultDuration time.Duration
-}
-
-type order struct {
- ID string `json:"id"`
- AccountID string `json:"accountID"`
- Created time.Time `json:"created"`
- Expires time.Time `json:"expires,omitempty"`
- Status string `json:"status"`
- Identifiers []Identifier `json:"identifiers"`
- NotBefore time.Time `json:"notBefore,omitempty"`
- NotAfter time.Time `json:"notAfter,omitempty"`
- Error *Error `json:"error,omitempty"`
- Authorizations []string `json:"authorizations"`
- Certificate string `json:"certificate,omitempty"`
-}
-
-// newOrder returns a new Order type.
-func newOrder(db nosql.DB, ops OrderOptions) (*order, error) {
- id, err := randID()
- if err != nil {
- return nil, err
- }
-
- authzs := make([]string, len(ops.Identifiers))
- for i, identifier := range ops.Identifiers {
- az, err := newAuthz(db, ops.AccountID, identifier)
- if err != nil {
- return nil, err
- }
- authzs[i] = az.getID()
- }
-
+// UpdateStatus updates the ACME Order Status if necessary.
+// Changes to the order are saved using the database interface.
+func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
now := clock.Now()
- var backdate time.Duration
- nbf := ops.NotBefore
- if nbf.IsZero() {
- nbf = now
- backdate = -1 * ops.backdate
- }
- naf := ops.NotAfter
- if naf.IsZero() {
- naf = nbf.Add(ops.defaultDuration)
- }
- o := &order{
- ID: id,
- AccountID: ops.AccountID,
- Created: now,
- Status: StatusPending,
- Expires: now.Add(defaultOrderExpiry),
- Identifiers: ops.Identifiers,
- NotBefore: nbf.Add(backdate),
- NotAfter: naf,
- Authorizations: authzs,
- }
- if err := o.save(db, nil); err != nil {
- return nil, err
- }
-
- var oidHelper = orderIDsByAccount{}
- _, err = oidHelper.addOrderID(db, ops.AccountID, o.ID)
- if err != nil {
- return nil, err
- }
- return o, nil
-}
-
-type orderIDsByAccount struct{}
-
-// addOrderID adds an order ID to a users index of in progress order IDs.
-// This method will also cull any orders that are no longer in the `pending`
-// state from the index before returning it.
-func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) ([]string, error) {
- ordersByAccountMux.Lock()
- defer ordersByAccountMux.Unlock()
-
- // Update the "order IDs by account ID" index
- oids, err := oiba.unsafeGetOrderIDsByAccount(db, accID)
- if err != nil {
- return nil, err
- }
- newOids := append(oids, oid)
- if err = orderIDs(newOids).save(db, oids, accID); err != nil {
- // Delete the entire order if storing the index fails.
- db.Del(orderTable, []byte(oid))
- return nil, err
- }
- return newOids, nil
-}
-
-// unsafeGetOrderIDsByAccount retrieves a list of Order IDs that were created by the
-// account.
-func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) {
- b, err := db.Get(ordersByAccountIDTable, []byte(accID))
- if err != nil {
- if nosql.IsErrNotFound(err) {
- return []string{}, nil
- }
- return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID))
- }
- var oids []string
- if err := json.Unmarshal(b, &oids); err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID))
- }
-
- // Remove any order that is not in PENDING state and update the stored list
- // before returning.
- //
- // According to RFC 8555:
- // The server SHOULD include pending orders and SHOULD NOT include orders
- // that are invalid in the array of URLs.
- pendOids := []string{}
- for _, oid := range oids {
- o, err := getOrder(db, oid)
- if err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID))
- }
- if o, err = o.updateStatus(db); err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID))
- }
- if o.Status == StatusPending {
- pendOids = append(pendOids, oid)
- }
- }
- // If the number of pending orders is less than the number of orders in the
- // list, then update the pending order list.
- if len(pendOids) != len(oids) {
- if err = orderIDs(pendOids).save(db, oids, accID); err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+
- "len(orderIDs) = %d", len(pendOids)))
- }
- }
-
- return pendOids, nil
-}
-
-type orderIDs []string
-
-// save is used to update the list of orderIDs keyed by ACME account ID
-// stored in the database.
-//
-// This method always converts empty lists to 'nil' when storing to the DB. We
-// do this to avoid any confusion between an empty list and a nil value in the
-// db.
-func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error {
- var (
- err error
- oldb []byte
- newb []byte
- )
- if len(old) == 0 {
- oldb = nil
- } else {
- oldb, err = json.Marshal(old)
- if err != nil {
- return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice"))
- }
- }
- if len(oids) == 0 {
- newb = nil
- } else {
- newb, err = json.Marshal(oids)
- if err != nil {
- return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice"))
- }
- }
- _, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb)
- switch {
- case err != nil:
- return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID))
- case !swapped:
- return ServerInternalErr(errors.Errorf("error storing order IDs "+
- "for account %s; order IDs changed since last read", accID))
- default:
- return nil
- }
-}
-
-func (o *order) save(db nosql.DB, old *order) error {
- var (
- err error
- oldB []byte
- )
- if old == nil {
- oldB = nil
- } else {
- if oldB, err = json.Marshal(old); err != nil {
- return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
- }
- }
-
- newB, err := json.Marshal(o)
- if err != nil {
- return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order"))
- }
-
- _, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB)
- switch {
- case err != nil:
- return ServerInternalErr(errors.Wrap(err, "error storing order"))
- case !swapped:
- return ServerInternalErr(errors.New("error storing order; " +
- "value has changed since last read"))
- default:
- return nil
- }
-}
-
-// updateStatus updates order status if necessary.
-func (o *order) updateStatus(db nosql.DB) (*order, error) {
- _newOrder := *o
- newOrder := &_newOrder
-
- now := time.Now().UTC()
switch o.Status {
case StatusInvalid:
- return o, nil
+ return nil
case StatusValid:
- return o, nil
+ return nil
case StatusReady:
- // check expiry
- if now.After(o.Expires) {
- newOrder.Status = StatusInvalid
- newOrder.Error = MalformedErr(errors.New("order has expired"))
+ // Check expiry
+ if now.After(o.ExpiresAt) {
+ o.Status = StatusInvalid
+ o.Error = NewError(ErrorMalformedType, "order has expired")
break
}
- return o, nil
+ return nil
case StatusPending:
- // check expiry
- if now.After(o.Expires) {
- newOrder.Status = StatusInvalid
- newOrder.Error = MalformedErr(errors.New("order has expired"))
+ // Check expiry
+ if now.After(o.ExpiresAt) {
+ o.Status = StatusInvalid
+ o.Error = NewError(ErrorMalformedType, "order has expired")
break
}
- var count = map[string]int{
+ var count = map[Status]int{
StatusValid: 0,
StatusInvalid: 0,
StatusPending: 0,
}
- for _, azID := range o.Authorizations {
- az, err := getAuthz(db, azID)
+ for _, azID := range o.AuthorizationIDs {
+ az, err := db.GetAuthorization(ctx, azID)
if err != nil {
- return nil, err
+ return WrapErrorISE(err, "error getting authorization ID %s", azID)
}
- if az, err = az.updateStatus(db); err != nil {
- return nil, err
+ if err = az.UpdateStatus(ctx, db); err != nil {
+ return WrapErrorISE(err, "error updating authorization ID %s", azID)
}
- st := az.getStatus()
+ st := az.Status
count[st]++
}
switch {
case count[StatusInvalid] > 0:
- newOrder.Status = StatusInvalid
+ o.Status = StatusInvalid
// No change in the order status, so just return the order as is -
// without writing any changes.
case count[StatusPending] > 0:
- return newOrder, nil
+ return nil
- case count[StatusValid] == len(o.Authorizations):
- newOrder.Status = StatusReady
+ case count[StatusValid] == len(o.AuthorizationIDs):
+ o.Status = StatusReady
default:
- return nil, ServerInternalErr(errors.New("unexpected authz status"))
+ return NewErrorISE("unexpected authz status")
}
default:
- return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status))
+ return NewErrorISE("unrecognized order status: %s", o.Status)
}
-
- if err := newOrder.save(db, o); err != nil {
- return nil, err
+ if err := db.UpdateOrder(ctx, o); err != nil {
+ return WrapErrorISE(err, "error updating order")
}
- return newOrder, nil
+ return nil
}
-// finalize signs a certificate if the necessary conditions for Order completion
+// Finalize signs a certificate if the necessary conditions for Order completion
// have been met.
-func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) (*order, error) {
- var err error
- if o, err = o.updateStatus(db); err != nil {
- return nil, err
+func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth CertificateAuthority, p Provisioner) error {
+ if err := o.UpdateStatus(ctx, db); err != nil {
+ return err
}
switch o.Status {
case StatusInvalid:
- return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID))
+ return NewError(ErrorOrderNotReadyType, "order %s has been abandoned", o.ID)
case StatusValid:
- return o, nil
+ return nil
case StatusPending:
- return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID))
+ return NewError(ErrorOrderNotReadyType, "order %s is not ready", o.ID)
case StatusReady:
break
default:
- return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID))
+ return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID)
}
// RFC8555: The CSR MUST indicate the exact same set of requested
@@ -361,12 +139,12 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut
if csr.Subject.CommonName != "" {
csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
}
- csr.DNSNames = uniqueLowerNames(csr.DNSNames)
+ csr.DNSNames = uniqueSortedLowerNames(csr.DNSNames)
orderNames := make([]string, len(o.Identifiers))
for i, n := range o.Identifiers {
orderNames[i] = n.Value
}
- orderNames = uniqueLowerNames(orderNames)
+ orderNames = uniqueSortedLowerNames(orderNames)
// Validate identifier names against CSR alternative names.
//
@@ -374,13 +152,15 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut
// absence of other SANs as they will only be set if the templates allows
// them.
if len(csr.DNSNames) != len(orderNames) {
- return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames))
+ return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
+ "CSR names = %v, Order names = %v", csr.DNSNames, orderNames)
}
sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames))
for i := range csr.DNSNames {
if csr.DNSNames[i] != orderNames[i] {
- return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames))
+ return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
+ "CSR names = %v, Order names = %v", csr.DNSNames, orderNames)
}
sans[i] = x509util.SubjectAlternativeName{
Type: x509util.DNSType,
@@ -389,10 +169,10 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut
}
// Get authorizations from the ACME provisioner.
- ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
+ ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOps, err := p.AuthorizeSign(ctx, "")
if err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner"))
+ return WrapErrorISE(err, "error retrieving authorization options from ACME provisioner")
}
// Template data
@@ -402,82 +182,41 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut
templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data)
if err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error creating template options from ACME provisioner"))
+ return WrapErrorISE(err, "error creating template options from ACME provisioner")
}
signOps = append(signOps, templateOptions)
- // Create and store a new certificate.
+ // Sign a new certificate.
certChain, err := auth.Sign(csr, provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
}, signOps...)
if err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID))
+ return WrapErrorISE(err, "error signing certificate for order %s", o.ID)
}
- cert, err := newCert(db, CertOptions{
+ cert := &Certificate{
AccountID: o.AccountID,
OrderID: o.ID,
Leaf: certChain[0],
Intermediates: certChain[1:],
- })
- if err != nil {
- return nil, err
+ }
+ if err := db.CreateCertificate(ctx, cert); err != nil {
+ return WrapErrorISE(err, "error creating certificate for order %s", o.ID)
}
- _newOrder := *o
- newOrder := &_newOrder
- newOrder.Certificate = cert.ID
- newOrder.Status = StatusValid
- if err := newOrder.save(db, o); err != nil {
- return nil, err
+ o.CertificateID = cert.ID
+ o.Status = StatusValid
+ if err = db.UpdateOrder(ctx, o); err != nil {
+ return WrapErrorISE(err, "error updating order %s", o.ID)
}
- return newOrder, nil
+ return nil
}
-// getOrder retrieves and unmarshals an ACME Order type from the database.
-func getOrder(db nosql.DB, id string) (*order, error) {
- b, err := db.Get(orderTable, []byte(id))
- if nosql.IsErrNotFound(err) {
- return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id))
- } else if err != nil {
- return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id))
- }
- var o order
- if err := json.Unmarshal(b, &o); err != nil {
- return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order"))
- }
- return &o, nil
-}
-
-// toACME converts the internal Order type into the public acmeOrder type for
-// presentation in the ACME protocol.
-func (o *order) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Order, error) {
- azs := make([]string, len(o.Authorizations))
- for i, aid := range o.Authorizations {
- azs[i] = dir.getLink(ctx, AuthzLink, true, aid)
- }
- ao := &Order{
- Status: o.Status,
- Expires: o.Expires.Format(time.RFC3339),
- Identifiers: o.Identifiers,
- NotBefore: o.NotBefore.Format(time.RFC3339),
- NotAfter: o.NotAfter.Format(time.RFC3339),
- Authorizations: azs,
- Finalize: dir.getLink(ctx, FinalizeLink, true, o.ID),
- ID: o.ID,
- }
-
- if o.Certificate != "" {
- ao.Certificate = dir.getLink(ctx, CertificateLink, true, o.Certificate)
- }
- return ao, nil
-}
-
-// uniqueLowerNames returns the set of all unique names in the input after all
+// uniqueSortedLowerNames returns the set of all unique names in the input after all
// of them are lowercased. The returned names will be in their lowercased form
// and sorted alphabetically.
-func uniqueLowerNames(names []string) (unique []string) {
+func uniqueSortedLowerNames(names []string) (unique []string) {
nameMap := make(map[string]int, len(names))
for _, name := range names {
nameMap[strings.ToLower(name)] = 1
diff --git a/acme/order_test.go b/acme/order_test.go
index e6a8f057..993a92f2 100644
--- a/acme/order_test.go
+++ b/acme/order_test.go
@@ -5,865 +5,232 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
- "fmt"
- "net"
- "net/url"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
- "github.com/smallstep/certificates/db"
- "github.com/smallstep/nosql"
- "github.com/smallstep/nosql/database"
)
-var certDuration = 6 * time.Hour
-
-func defaultOrderOps() OrderOptions {
- return OrderOptions{
- AccountID: "accID",
- Identifiers: []Identifier{
- {Type: "dns", Value: "acme.example.com"},
- {Type: "dns", Value: "step.example.com"},
- },
- NotBefore: clock.Now(),
- NotAfter: clock.Now().Add(certDuration),
- }
-}
-
-func newO() (*order, error) {
- mockdb := &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), true, nil
- },
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
- },
- }
- return newOrder(mockdb, defaultOrderOps())
-}
-
-func Test_getOrder(t *testing.T) {
+func TestOrder_UpdateStatus(t *testing.T) {
type test struct {
- id string
- db nosql.DB
- o *order
+ o *Order
err *Error
+ db DB
}
tests := map[string]func(t *testing.T) test{
- "fail/not-found": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
+ "ok/already-invalid": func(t *testing.T) test {
+ o := &Order{
+ Status: StatusInvalid,
+ }
return test{
- o: o,
- id: o.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
- },
- },
- err: MalformedErr(errors.Errorf("order %s not found: not found", o.ID)),
+ o: o,
}
},
- "fail/db-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
+ "ok/already-valid": func(t *testing.T) test {
+ o := &Order{
+ Status: StatusInvalid,
+ }
return test{
- o: o,
- id: o.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.Errorf("error loading order %s: force", o.ID)),
+ o: o,
}
},
- "fail/unmarshal-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- return test{
- o: o,
- id: o.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(o.ID))
- return nil, nil
- },
- },
- err: ServerInternalErr(errors.New("error unmarshaling order: unexpected end of JSON input")),
+ "fail/error-unexpected-status": func(t *testing.T) test {
+ o := &Order{
+ Status: "foo",
}
- },
- "ok": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
- return test{
- o: o,
- id: o.ID,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(o.ID))
- return b, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if o, err := getOrder(tc.db, tc.id); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.o.ID, o.ID)
- assert.Equals(t, tc.o.AccountID, o.AccountID)
- assert.Equals(t, tc.o.Status, o.Status)
- assert.Equals(t, tc.o.Identifiers, o.Identifiers)
- assert.Equals(t, tc.o.Created, o.Created)
- assert.Equals(t, tc.o.Expires, o.Expires)
- assert.Equals(t, tc.o.Authorizations, o.Authorizations)
- assert.Equals(t, tc.o.NotBefore, o.NotBefore)
- assert.Equals(t, tc.o.NotAfter, o.NotAfter)
- assert.Equals(t, tc.o.Certificate, o.Certificate)
- assert.Equals(t, tc.o.Error, o.Error)
- }
- }
- })
- }
-}
-
-func TestOrderToACME(t *testing.T) {
- dir := newDirectory("ca.smallstep.com", "acme")
- prov := newProv()
- provName := url.PathEscape(prov.GetName())
- baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
- ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
- ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
-
- type test struct {
- o *order
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "ok/no-cert": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- return test{o: o}
- },
- "ok/cert": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusValid
- o.Certificate = "cert-id"
- return test{o: o}
- },
- }
- for name, run := range tests {
- tc := run(t)
- t.Run(name, func(t *testing.T) {
- acmeOrder, err := tc.o.toACME(ctx, nil, dir)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, acmeOrder.ID, tc.o.ID)
- assert.Equals(t, acmeOrder.Status, tc.o.Status)
- assert.Equals(t, acmeOrder.Identifiers, tc.o.Identifiers)
- assert.Equals(t, acmeOrder.Finalize,
- fmt.Sprintf("%s/acme/%s/order/%s/finalize", baseURL.String(), provName, tc.o.ID))
- if tc.o.Certificate != "" {
- assert.Equals(t, acmeOrder.Certificate, fmt.Sprintf("%s/acme/%s/certificate/%s", baseURL.String(), provName, tc.o.Certificate))
- }
-
- expiry, err := time.Parse(time.RFC3339, acmeOrder.Expires)
- assert.FatalError(t, err)
- assert.Equals(t, expiry.String(), tc.o.Expires.String())
- nbf, err := time.Parse(time.RFC3339, acmeOrder.NotBefore)
- assert.FatalError(t, err)
- assert.Equals(t, nbf.String(), tc.o.NotBefore.String())
- naf, err := time.Parse(time.RFC3339, acmeOrder.NotAfter)
- assert.FatalError(t, err)
- assert.Equals(t, naf.String(), tc.o.NotAfter.String())
- }
- }
- })
- }
-}
-
-func TestOrderSave(t *testing.T) {
- type test struct {
- o, old *order
- db nosql.DB
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/old-nil/swap-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
return test{
o: o,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error storing order: force")),
+ err: NewErrorISE("unrecognized order status: %s", o.Status),
}
},
- "fail/old-nil/swap-false": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- return test{
- o: o,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), false, nil
- },
- },
- err: ServerInternalErr(errors.New("error storing order; value has changed since last read")),
+ "ok/ready-expired": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusReady,
+ ExpiresAt: now.Add(-5 * time.Minute),
}
- },
- "ok/old-nil": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
return test{
- o: o,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, nil)
- assert.Equals(t, b, newval)
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, []byte(o.ID), key)
- return nil, true, nil
- },
- },
- }
- },
- "ok/old-not-nil": func(t *testing.T) test {
- oldo, err := newO()
- assert.FatalError(t, err)
- o, err := newO()
- assert.FatalError(t, err)
-
- oldb, err := json.Marshal(oldo)
- assert.FatalError(t, err)
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
- return test{
- o: o,
- old: oldo,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, oldb)
- assert.Equals(t, b, newval)
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, []byte(o.ID), key)
- return []byte("foo"), true, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if err := tc.o.save(tc.db, tc.old); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- assert.Nil(t, tc.err)
- }
- })
- }
-}
-
-func Test_newOrder(t *testing.T) {
- type test struct {
- ops OrderOptions
- db nosql.DB
- err *Error
- authzs *([]string)
- }
- tests := map[string]func(t *testing.T) test{
- "fail/unexpected-identifier-type": func(t *testing.T) test {
- ops := defaultOrderOps()
- ops.Identifiers[0].Type = "foo"
- return test{
- ops: ops,
- err: MalformedErr(errors.New("unexpected authz type foo")),
- }
- },
- "fail/save-order-error": func(t *testing.T) test {
- count := 0
- return test{
- ops: defaultOrderOps(),
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count >= 8 {
- return nil, false, errors.New("force")
- }
- count++
- return nil, true, nil
- },
- },
- err: ServerInternalErr(errors.New("error storing order: force")),
- }
- },
- "fail/get-orderIDs-error": func(t *testing.T) test {
- count := 0
- ops := defaultOrderOps()
- return test{
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count >= 9 {
- return nil, false, errors.New("force")
- }
- count++
- return nil, true, nil
- },
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.Errorf("error loading orderIDs for account %s: force", ops.AccountID)),
- }
- },
- "fail/save-orderIDs-error": func(t *testing.T) test {
- count := 0
- var (
- _oid = ""
- oid = &_oid
- )
- ops := defaultOrderOps()
- return test{
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count >= 9 {
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte(ops.AccountID))
- return nil, false, errors.New("force")
- } else if count == 8 {
- *oid = string(key)
- }
- count++
- return nil, true, nil
- },
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
- },
- MDel: func(bucket, key []byte) error {
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte(*oid))
+ o: o,
+ db: &MockDB{
+ MockUpdateOrder: func(ctx context.Context, updo *Order) error {
+ assert.Equals(t, updo.ID, o.ID)
+ assert.Equals(t, updo.AccountID, o.AccountID)
+ assert.Equals(t, updo.Status, StatusInvalid)
+ assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
return nil
},
},
- err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s: force", ops.AccountID)),
}
},
- "ok": func(t *testing.T) test {
- count := 0
- authzs := &([]string{})
- var (
- _oid = ""
- oid = &_oid
- )
- ops := defaultOrderOps()
+ "fail/ready-expired-db.UpdateOrder-error": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusReady,
+ ExpiresAt: now.Add(-5 * time.Minute),
+ }
return test{
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count >= 9 {
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte(ops.AccountID))
- assert.Equals(t, old, nil)
- newB, err := json.Marshal([]string{*oid})
- assert.FatalError(t, err)
- assert.Equals(t, newval, newB)
- } else if count == 8 {
- *oid = string(key)
- } else if count == 7 {
- *authzs = append(*authzs, string(key))
- } else if count == 3 {
- *authzs = []string{string(key)}
- }
- count++
- return nil, true, nil
- },
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
+ o: o,
+ db: &MockDB{
+ MockUpdateOrder: func(ctx context.Context, updo *Order) error {
+ assert.Equals(t, updo.ID, o.ID)
+ assert.Equals(t, updo.AccountID, o.AccountID)
+ assert.Equals(t, updo.Status, StatusInvalid)
+ assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
+ return errors.New("force")
},
},
- authzs: authzs,
+ err: NewErrorISE("error updating order: force"),
}
},
- "ok/validity-bounds-not-set": func(t *testing.T) test {
- count := 0
- authzs := &([]string{})
- var (
- _oid = ""
- oid = &_oid
- )
- ops := defaultOrderOps()
- ops.backdate = time.Minute
- ops.defaultDuration = 12 * time.Hour
- ops.NotBefore = time.Time{}
- ops.NotAfter = time.Time{}
+ "ok/pending-expired": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusPending,
+ ExpiresAt: now.Add(-5 * time.Minute),
+ }
return test{
- ops: ops,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count >= 9 {
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte(ops.AccountID))
- assert.Equals(t, old, nil)
- newB, err := json.Marshal([]string{*oid})
- assert.FatalError(t, err)
- assert.Equals(t, newval, newB)
- } else if count == 8 {
- *oid = string(key)
- } else if count == 7 {
- *authzs = append(*authzs, string(key))
- } else if count == 3 {
- *authzs = []string{string(key)}
- }
- count++
- return nil, true, nil
- },
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
- },
- },
- authzs: authzs,
- }
- },
- }
- for name, run := range tests {
- tc := run(t)
- t.Run(name, func(t *testing.T) {
- o, err := newOrder(tc.db, tc.ops)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, o.AccountID, tc.ops.AccountID)
- assert.Equals(t, o.Status, StatusPending)
- assert.Equals(t, o.Identifiers, tc.ops.Identifiers)
- assert.Equals(t, o.Error, nil)
- assert.Equals(t, o.Certificate, "")
- assert.Equals(t, o.Authorizations, *tc.authzs)
+ o: o,
+ db: &MockDB{
+ MockUpdateOrder: func(ctx context.Context, updo *Order) error {
+ assert.Equals(t, updo.ID, o.ID)
+ assert.Equals(t, updo.AccountID, o.AccountID)
+ assert.Equals(t, updo.Status, StatusInvalid)
+ assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
- assert.True(t, o.Created.Before(time.Now().UTC().Add(time.Minute)))
- assert.True(t, o.Created.After(time.Now().UTC().Add(-1*time.Minute)))
-
- expiry := o.Created.Add(defaultExpiryDuration)
- assert.True(t, o.Expires.Before(expiry.Add(time.Minute)))
- assert.True(t, o.Expires.After(expiry.Add(-1*time.Minute)))
-
- nbf := tc.ops.NotBefore
- now := time.Now().UTC()
- if !tc.ops.NotBefore.IsZero() {
- assert.Equals(t, o.NotBefore, tc.ops.NotBefore)
- } else {
- nbf = o.NotBefore.Add(tc.ops.backdate)
- assert.True(t, o.NotBefore.Before(now.Add(-tc.ops.backdate+time.Second)))
- assert.True(t, o.NotBefore.Add(tc.ops.backdate+2*time.Second).After(now))
- }
- if !tc.ops.NotAfter.IsZero() {
- assert.Equals(t, o.NotAfter, tc.ops.NotAfter)
- } else {
- naf := nbf.Add(tc.ops.defaultDuration)
- assert.Equals(t, o.NotAfter, naf)
- }
- }
- }
- })
- }
-}
-
-func TestOrderIDs_save(t *testing.T) {
- accID := "acc-id"
- newOids := func() orderIDs {
- return []string{"1", "2"}
- }
- type test struct {
- oids, old orderIDs
- db nosql.DB
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "fail/old-nil/swap-error": func(t *testing.T) test {
- return test{
- oids: newOids(),
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s: force", accID)),
- }
- },
- "fail/old-nil/swap-false": func(t *testing.T) test {
- return test{
- oids: newOids(),
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return []byte("foo"), false, nil
- },
- },
- err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s; order IDs changed since last read", accID)),
- }
- },
- "ok/old-nil": func(t *testing.T) test {
- oids := newOids()
- b, err := json.Marshal(oids)
- assert.FatalError(t, err)
- return test{
- oids: oids,
- old: nil,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, nil)
- assert.Equals(t, b, newval)
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte(accID))
- return nil, true, nil
- },
- },
- }
- },
- "ok/old-not-nil": func(t *testing.T) test {
- oldOids := newOids()
- oids := append(oldOids, "3")
-
- oldb, err := json.Marshal(oldOids)
- assert.FatalError(t, err)
- b, err := json.Marshal(oids)
- assert.FatalError(t, err)
- return test{
- oids: oids,
- old: oldOids,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, b)
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte(accID))
- return nil, true, nil
- },
- },
- }
- },
- "ok/new-empty-saved-as-nil": func(t *testing.T) test {
- oldOids := newOids()
- oids := []string{}
-
- oldb, err := json.Marshal(oldOids)
- assert.FatalError(t, err)
- return test{
- oids: oids,
- old: oldOids,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, old, oldb)
- assert.Equals(t, newval, nil)
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte(accID))
- return nil, true, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- if err := tc.oids.save(tc.db, tc.old, accID); err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- assert.Nil(t, tc.err)
- }
- })
- }
-}
-
-func TestOrderUpdateStatus(t *testing.T) {
- type test struct {
- o, res *order
- err *Error
- db nosql.DB
- }
- tests := map[string]func(t *testing.T) test{
- "fail/already-invalid": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusInvalid
- return test{
- o: o,
- res: o,
- }
- },
- "fail/already-valid": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusValid
- return test{
- o: o,
- res: o,
- }
- },
- "fail/unexpected-status": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusDeactivated
- return test{
- o: o,
- res: o,
- err: ServerInternalErr(errors.New("unrecognized order status: deactivated")),
- }
- },
- "fail/save-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Expires = time.Now().UTC().Add(-time.Minute)
- return test{
- o: o,
- res: o,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error storing order: force")),
- }
- },
- "ok/expired": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Expires = time.Now().UTC().Add(-time.Minute)
-
- _o := *o
- clone := &_o
- clone.Error = MalformedErr(errors.New("order has expired"))
- clone.Status = StatusInvalid
- return test{
- o: o,
- res: clone,
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, true, nil
- },
- },
- }
- },
- "fail/get-authz-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- return test{
- o: o,
- res: o,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error loading authz")),
- }
- },
- "ok/still-pending": func(t *testing.T) test {
- az1, err := newAz()
- assert.FatalError(t, err)
- az2, err := newAz()
- assert.FatalError(t, err)
- az3, err := newAz()
- assert.FatalError(t, err)
-
- ch1, err := newHTTPCh()
- assert.FatalError(t, err)
- ch2, err := newTLSALPNCh()
- assert.FatalError(t, err)
- ch3, err := newDNSCh()
- assert.FatalError(t, err)
-
- ch1b, err := json.Marshal(ch1)
- assert.FatalError(t, err)
- ch2b, err := json.Marshal(ch2)
- assert.FatalError(t, err)
- ch3b, err := json.Marshal(ch3)
- assert.FatalError(t, err)
-
- o, err := newO()
- assert.FatalError(t, err)
- o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()}
-
- _az3, ok := az3.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az3.baseAuthz.Status = StatusValid
-
- b1, err := json.Marshal(az1)
- assert.FatalError(t, err)
- b2, err := json.Marshal(az2)
- assert.FatalError(t, err)
- b3, err := json.Marshal(az3)
- assert.FatalError(t, err)
-
- count := 0
- return test{
- o: o,
- res: o,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- var ret []byte
- switch count {
- case 0:
- ret = b1
- case 1:
- ret = ch1b
- case 2:
- ret = ch2b
- case 3:
- ret = ch3b
- case 4:
- ret = b2
- case 5:
- ret = ch1b
- case 6:
- ret = ch2b
- case 7:
- ret = ch3b
- case 8:
- ret = b3
- default:
- return nil, errors.New("unexpected count")
- }
- count++
- return ret, nil
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, true, nil
+ err := NewError(ErrorMalformedType, "order has expired")
+ assert.HasPrefix(t, updo.Error.Err.Error(), err.Err.Error())
+ assert.Equals(t, updo.Error.Type, err.Type)
+ assert.Equals(t, updo.Error.Detail, err.Detail)
+ assert.Equals(t, updo.Error.Status, err.Status)
+ assert.Equals(t, updo.Error.Detail, err.Detail)
+ return nil
},
},
}
},
"ok/invalid": func(t *testing.T) test {
- az1, err := newAz()
- assert.FatalError(t, err)
- az2, err := newAz()
- assert.FatalError(t, err)
- az3, err := newAz()
- assert.FatalError(t, err)
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusPending,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ }
+ az1 := &Authorization{
+ ID: "a",
+ Status: StatusValid,
+ }
+ az2 := &Authorization{
+ ID: "b",
+ Status: StatusInvalid,
+ }
- ch1, err := newHTTPCh()
- assert.FatalError(t, err)
- ch2, err := newTLSALPNCh()
- assert.FatalError(t, err)
- ch3, err := newDNSCh()
- assert.FatalError(t, err)
-
- ch1b, err := json.Marshal(ch1)
- assert.FatalError(t, err)
- ch2b, err := json.Marshal(ch2)
- assert.FatalError(t, err)
- ch3b, err := json.Marshal(ch3)
- assert.FatalError(t, err)
-
- o, err := newO()
- assert.FatalError(t, err)
- o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()}
-
- _az3, ok := az3.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az3.baseAuthz.Status = StatusInvalid
-
- b1, err := json.Marshal(az1)
- assert.FatalError(t, err)
- b2, err := json.Marshal(az2)
- assert.FatalError(t, err)
- b3, err := json.Marshal(az3)
- assert.FatalError(t, err)
-
- _o := *o
- clone := &_o
- clone.Status = StatusInvalid
-
- count := 0
return test{
- o: o,
- res: clone,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- var ret []byte
- switch count {
- case 0:
- ret = b1
- case 1:
- ret = ch1b
- case 2:
- ret = ch2b
- case 3:
- ret = ch3b
- case 4:
- ret = b2
- case 5:
- ret = ch1b
- case 6:
- ret = ch2b
- case 7:
- ret = ch3b
- case 8:
- ret = b3
- default:
- return nil, errors.New("unexpected count")
- }
- count++
- return ret, nil
+ o: o,
+ db: &MockDB{
+ MockUpdateOrder: func(ctx context.Context, updo *Order) error {
+ assert.Equals(t, updo.ID, o.ID)
+ assert.Equals(t, updo.AccountID, o.AccountID)
+ assert.Equals(t, updo.Status, StatusInvalid)
+ assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
+ return nil
},
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, true, nil
+ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
+ switch id {
+ case az1.ID:
+ return az1, nil
+ case az2.ID:
+ return az2, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected authz key %s", id))
+ return nil, errors.New("force")
+ }
+ },
+ },
+ }
+ },
+ "ok/still-pending": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusPending,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ }
+ az1 := &Authorization{
+ ID: "a",
+ Status: StatusValid,
+ }
+ az2 := &Authorization{
+ ID: "b",
+ Status: StatusPending,
+ }
+
+ return test{
+ o: o,
+ db: &MockDB{
+ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
+ switch id {
+ case az1.ID:
+ return az1, nil
+ case az2.ID:
+ return az2, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected authz key %s", id))
+ return nil, errors.New("force")
+ }
+ },
+ },
+ }
+ },
+ "ok/valid": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusPending,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ }
+ az1 := &Authorization{
+ ID: "a",
+ Status: StatusValid,
+ }
+ az2 := &Authorization{
+ ID: "b",
+ Status: StatusValid,
+ }
+
+ return test{
+ o: o,
+ db: &MockDB{
+ MockUpdateOrder: func(ctx context.Context, updo *Order) error {
+ assert.Equals(t, updo.ID, o.ID)
+ assert.Equals(t, updo.AccountID, o.AccountID)
+ assert.Equals(t, updo.Status, StatusReady)
+ assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
+ return nil
+ },
+ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
+ switch id {
+ case az1.ID:
+ return az1, nil
+ case az2.ID:
+ return az2, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected authz key %s", id))
+ return nil, errors.New("force")
+ }
},
},
}
@@ -872,25 +239,24 @@ func TestOrderUpdateStatus(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
- o, err := tc.o.updateStatus(tc.db)
- if err != nil {
+ if err := tc.o.UpdateStatus(context.Background(), tc.db); err != nil {
if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
+ switch k := err.(type) {
+ case *Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
}
} else {
- if assert.Nil(t, tc.err) {
- expB, err := json.Marshal(tc.res)
- assert.FatalError(t, err)
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
- assert.Equals(t, expB, b)
- }
+ assert.Nil(t, tc.err)
}
})
+
}
}
@@ -917,820 +283,456 @@ func (m *mockSignAuth) LoadProvisionerByID(id string) (provisioner.Interface, er
return m.ret1.(provisioner.Interface), m.err
}
-func TestOrderFinalize(t *testing.T) {
- prov := newProv()
+func TestOrder_Finalize(t *testing.T) {
type test struct {
- o, res *order
- err *Error
- db nosql.DB
- csr *x509.CertificateRequest
- sa SignAuthority
- prov Provisioner
+ o *Order
+ err *Error
+ db DB
+ ca CertificateAuthority
+ csr *x509.CertificateRequest
+ prov Provisioner
}
tests := map[string]func(t *testing.T) test{
- "fail/already-invalid": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusInvalid
+ "fail/invalid": func(t *testing.T) test {
+ o := &Order{
+ ID: "oid",
+ Status: StatusInvalid,
+ }
return test{
o: o,
- err: OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)),
+ err: NewError(ErrorOrderNotReadyType, "order %s has been abandoned", o.ID),
+ }
+ },
+ "fail/pending": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusPending,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ }
+ az1 := &Authorization{
+ ID: "a",
+ Status: StatusValid,
+ }
+ az2 := &Authorization{
+ ID: "b",
+ Status: StatusPending,
+ ExpiresAt: now.Add(5 * time.Minute),
+ }
+
+ return test{
+ o: o,
+ db: &MockDB{
+ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
+ switch id {
+ case az1.ID:
+ return az1, nil
+ case az2.ID:
+ return az2, nil
+ default:
+ assert.FatalError(t, errors.Errorf("unexpected authz key %s", id))
+ return nil, errors.New("force")
+ }
+ },
+ },
+ err: NewError(ErrorOrderNotReadyType, "order %s is not ready", o.ID),
}
},
"ok/already-valid": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusValid
- o.Certificate = "cert-id"
+ o := &Order{
+ ID: "oid",
+ Status: StatusValid,
+ }
return test{
- o: o,
- res: o,
+ o: o,
}
},
- "fail/still-pending": func(t *testing.T) test {
- az1, err := newAz()
- assert.FatalError(t, err)
- az2, err := newAz()
- assert.FatalError(t, err)
- az3, err := newAz()
- assert.FatalError(t, err)
+ "fail/error-unexpected-status": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: "foo",
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ }
- ch1, err := newHTTPCh()
- assert.FatalError(t, err)
- ch2, err := newTLSALPNCh()
- assert.FatalError(t, err)
- ch3, err := newDNSCh()
- assert.FatalError(t, err)
-
- ch1b, err := json.Marshal(ch1)
- assert.FatalError(t, err)
- ch2b, err := json.Marshal(ch2)
- assert.FatalError(t, err)
- ch3b, err := json.Marshal(ch3)
- assert.FatalError(t, err)
-
- o, err := newO()
- assert.FatalError(t, err)
- o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()}
-
- _az3, ok := az3.(*dnsAuthz)
- assert.Fatal(t, ok)
- _az3.baseAuthz.Status = StatusValid
-
- b1, err := json.Marshal(az1)
- assert.FatalError(t, err)
- b2, err := json.Marshal(az2)
- assert.FatalError(t, err)
- b3, err := json.Marshal(az3)
- assert.FatalError(t, err)
-
- count := 0
return test{
o: o,
- res: o,
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- var ret []byte
- switch count {
- case 0:
- ret = b1
- case 1:
- ret = ch1b
- case 2:
- ret = ch2b
- case 3:
- ret = ch3b
- case 4:
- ret = b2
- case 5:
- ret = ch1b
- case 6:
- ret = ch2b
- case 7:
- ret = ch3b
- case 8:
- ret = b3
- default:
- return nil, errors.New("unexpected count")
- }
- count++
- return ret, nil
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, true, nil
- },
+ err: NewErrorISE("unrecognized order status: %s", o.Status),
+ }
+ },
+ "fail/error-names-length-mismatch": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusReady,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ Identifiers: []Identifier{
+ {Type: "dns", Value: "foo.internal"},
+ {Type: "dns", Value: "bar.internal"},
},
- err: OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)),
}
- },
- "fail/ready/csr-names-match-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
+ orderNames := []string{"bar.internal", "foo.internal"}
csr := &x509.CertificateRequest{
Subject: pkix.Name{
- CommonName: "acme.example.com",
+ CommonName: "foo.internal",
},
- DNSNames: []string{"acme.example.com", "fail.smallstep.com"},
}
+
return test{
o: o,
csr: csr,
- err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")),
+ err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
+ "CSR names = %v, Order names = %v", []string{"foo.internal"}, orderNames),
}
},
- "fail/ready/csr-names-match-error-2": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
+ "fail/error-names-mismatch": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusReady,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ Identifiers: []Identifier{
+ {Type: "dns", Value: "foo.internal"},
+ {Type: "dns", Value: "bar.internal"},
+ },
+ }
+ orderNames := []string{"bar.internal", "foo.internal"}
csr := &x509.CertificateRequest{
Subject: pkix.Name{
- CommonName: "",
+ CommonName: "foo.internal",
},
- DNSNames: []string{"acme.example.com"},
+ DNSNames: []string{"zap.internal"},
}
+
return test{
o: o,
csr: csr,
- err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")),
+ err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+
+ "CSR names = %v, Order names = %v", []string{"foo.internal", "zap.internal"}, orderNames),
}
},
- "fail/ready/no-ipAddresses": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
+ "fail/error-provisioner-auth": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusReady,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ Identifiers: []Identifier{
+ {Type: "dns", Value: "foo.internal"},
+ {Type: "dns", Value: "bar.internal"},
+ },
+ }
csr := &x509.CertificateRequest{
Subject: pkix.Name{
- CommonName: "",
+ CommonName: "foo.internal",
},
- // DNSNames: []string{"acme.example.com", "step.example.com"},
- IPAddresses: []net.IP{net.ParseIP("1.1.1.1")},
+ DNSNames: []string{"bar.internal"},
}
+
return test{
o: o,
csr: csr,
- err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")),
- }
- },
- "fail/ready/no-emailAddresses": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
- csr := &x509.CertificateRequest{
- Subject: pkix.Name{
- CommonName: "",
- },
- // DNSNames: []string{"acme.example.com", "step.example.com"},
- EmailAddresses: []string{"max@smallstep.com", "mariano@smallstep.com"},
- }
- return test{
- o: o,
- csr: csr,
- err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")),
- }
- },
- "fail/ready/no-URIs": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
- u, err := url.Parse("https://google.com")
- assert.FatalError(t, err)
- csr := &x509.CertificateRequest{
- Subject: pkix.Name{
- CommonName: "",
- },
- // DNSNames: []string{"acme.example.com", "step.example.com"},
- URIs: []*url.URL{u},
- }
- return test{
- o: o,
- csr: csr,
- err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")),
- }
- },
- "fail/ready/provisioner-auth-sign-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
- csr := &x509.CertificateRequest{
- Subject: pkix.Name{
- CommonName: "acme.example.com",
- },
- DNSNames: []string{"step.example.com", "acme.example.com"},
- }
- return test{
- o: o,
- csr: csr,
- err: ServerInternalErr(errors.New("error retrieving authorization options from ACME provisioner: force")),
prov: &MockProvisioner{
MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
+ assert.Equals(t, token, "")
return nil, errors.New("force")
},
},
+ err: NewErrorISE("error retrieving authorization options from ACME provisioner: force"),
}
},
- "fail/ready/sign-cert-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
+ "fail/error-template-options": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusReady,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ Identifiers: []Identifier{
+ {Type: "dns", Value: "foo.internal"},
+ {Type: "dns", Value: "bar.internal"},
+ },
+ }
csr := &x509.CertificateRequest{
Subject: pkix.Name{
- CommonName: "acme.example.com",
+ CommonName: "foo.internal",
},
- DNSNames: []string{"step.example.com", "acme.example.com"},
+ DNSNames: []string{"bar.internal"},
}
+
return test{
o: o,
csr: csr,
- err: ServerInternalErr(errors.Errorf("error generating certificate for order %s: force", o.ID)),
- sa: &mockSignAuth{
- err: errors.New("force"),
- },
- }
- },
- "fail/ready/store-cert-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
- csr := &x509.CertificateRequest{
- Subject: pkix.Name{
- CommonName: "acme.example.com",
- },
- DNSNames: []string{"step.example.com", "acme.example.com"},
- }
- crt := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: "acme.example.com",
- },
- }
- inter := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: "intermediate",
- },
- }
- return test{
- o: o,
- csr: csr,
- err: ServerInternalErr(errors.Errorf("error storing certificate: force")),
- sa: &mockSignAuth{
- ret1: crt, ret2: inter,
- },
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("force")
- },
- },
- }
- },
- "fail/ready/store-order-error": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
- csr := &x509.CertificateRequest{
- Subject: pkix.Name{
- CommonName: "acme.example.com",
- },
- DNSNames: []string{"acme.example.com", "step.example.com"},
- }
- crt := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: "acme.example.com",
- },
- }
- inter := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: "intermediate",
- },
- }
- count := 0
- return test{
- o: o,
- csr: csr,
- err: ServerInternalErr(errors.Errorf("error storing order: force")),
- sa: &mockSignAuth{
- ret1: crt, ret2: inter,
- },
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 1 {
- return nil, false, errors.New("force")
- }
- count++
- return nil, true, nil
- },
- },
- }
- },
- "ok/ready/sign": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
- csr := &x509.CertificateRequest{
- Subject: pkix.Name{
- CommonName: "acme.example.com",
- },
- DNSNames: []string{"acme.example.com", "step.example.com"},
- }
- crt := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: "acme.example.com",
- },
- }
- inter := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: "intermediate",
- },
- }
-
- _o := *o
- clone := &_o
- clone.Status = StatusValid
-
- count := 0
- return test{
- o: o,
- res: clone,
- csr: csr,
- sa: &mockSignAuth{
- sign: func(csr *x509.CertificateRequest, pops provisioner.SignOptions, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) {
- assert.Equals(t, len(signOps), 6)
- return []*x509.Certificate{crt, inter}, nil
- },
- },
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 0 {
- clone.Certificate = string(key)
- }
- count++
- return nil, true, nil
- },
- },
- }
- },
- "ok/ready/no-sans": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
- o.Identifiers = []Identifier{
- {Type: "dns", Value: "step.example.com"},
- }
-
- csr := &x509.CertificateRequest{
- Subject: pkix.Name{
- CommonName: "step.example.com",
- },
- }
- crt := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: "step.example.com",
- },
- DNSNames: []string{"step.example.com"},
- }
- inter := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: "intermediate",
- },
- }
-
- clone := *o
- clone.Status = StatusValid
- count := 0
- return test{
- o: o,
- res: &clone,
- csr: csr,
- sa: &mockSignAuth{
- sign: func(csr *x509.CertificateRequest, pops provisioner.SignOptions, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) {
- assert.Equals(t, len(signOps), 6)
- return []*x509.Certificate{crt, inter}, nil
- },
- },
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 0 {
- clone.Certificate = string(key)
- }
- count++
- return nil, true, nil
- },
- },
- }
- },
- "ok/ready/sans-and-name": func(t *testing.T) test {
- o, err := newO()
- assert.FatalError(t, err)
- o.Status = StatusReady
-
- csr := &x509.CertificateRequest{
- Subject: pkix.Name{
- CommonName: "acme.example.com",
- },
- DNSNames: []string{"step.example.com"},
- }
- crt := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: "acme.example.com",
- },
- DNSNames: []string{"acme.example.com", "step.example.com"},
- }
- inter := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: "intermediate",
- },
- }
-
- clone := *o
- clone.Status = StatusValid
- count := 0
- return test{
- o: o,
- res: &clone,
- csr: csr,
- sa: &mockSignAuth{
- sign: func(csr *x509.CertificateRequest, pops provisioner.SignOptions, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) {
- assert.Equals(t, len(signOps), 6)
- return []*x509.Certificate{crt, inter}, nil
- },
- },
- db: &db.MockNoSQLDB{
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- if count == 0 {
- clone.Certificate = string(key)
- }
- count++
- return nil, true, nil
- },
- },
- }
- },
- }
- for name, run := range tests {
- t.Run(name, func(t *testing.T) {
- tc := run(t)
- p := tc.prov
- if p == nil {
- p = prov
- }
- o, err := tc.o.finalize(tc.db, tc.csr, tc.sa, p)
- if err != nil {
- if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
- }
- } else {
- if assert.Nil(t, tc.err) {
- expB, err := json.Marshal(tc.res)
- assert.FatalError(t, err)
- b, err := json.Marshal(o)
- assert.FatalError(t, err)
- assert.Equals(t, expB, b)
- }
- }
- })
- }
-}
-
-func Test_getOrderIDsByAccount(t *testing.T) {
- type test struct {
- id string
- db nosql.DB
- res []string
- err *Error
- }
- tests := map[string]func(t *testing.T) test{
- "ok/not-found": func(t *testing.T) test {
- return test{
- id: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, database.ErrNotFound
- },
- },
- res: []string{},
- }
- },
- "fail/db-error": func(t *testing.T) test {
- return test{
- id: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- return nil, errors.New("force")
- },
- },
- err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")),
- }
- },
- "fail/unmarshal-error": func(t *testing.T) test {
- return test{
- id: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte("foo"))
+ prov: &MockProvisioner{
+ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
+ assert.Equals(t, token, "")
return nil, nil
},
- },
- err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")),
- }
- },
- "fail/error-loading-order-from-order-IDs": func(t *testing.T) test {
- oids := []string{"o1", "o2", "o3"}
- boids, err := json.Marshal(oids)
- assert.FatalError(t, err)
- dbHit := 0
- return test{
- id: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- dbHit++
- switch dbHit {
- case 1:
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte("foo"))
- return boids, nil
- case 2:
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte("o1"))
- return nil, errors.New("force")
- default:
- assert.FatalError(t, errors.New("should not be here"))
- return nil, nil
+ MgetOptions: func() *provisioner.Options {
+ return &provisioner.Options{
+ X509: &provisioner.X509Options{
+ TemplateData: json.RawMessage([]byte("fo{o")),
+ },
}
},
},
- err: ServerInternalErr(errors.New("error loading order o1 for account foo: error loading order o1: force")),
+ err: NewErrorISE("error creating template options from ACME provisioner: error unmarshaling template data: invalid character 'o' in literal false (expecting 'a')"),
}
},
- "fail/error-updating-order-from-order-IDs": func(t *testing.T) test {
- oids := []string{"o1", "o2", "o3"}
- boids, err := json.Marshal(oids)
- assert.FatalError(t, err)
+ "fail/error-ca-sign": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusReady,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ Identifiers: []Identifier{
+ {Type: "dns", Value: "foo.internal"},
+ {Type: "dns", Value: "bar.internal"},
+ },
+ }
+ csr := &x509.CertificateRequest{
+ Subject: pkix.Name{
+ CommonName: "foo.internal",
+ },
+ DNSNames: []string{"bar.internal"},
+ }
- o, err := newO()
- assert.FatalError(t, err)
- bo, err := json.Marshal(o)
- assert.FatalError(t, err)
-
- dbHit := 0
return test{
- id: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- dbHit++
- switch dbHit {
- case 1:
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte("foo"))
- return boids, nil
- case 2:
- assert.Equals(t, bucket, orderTable)
- assert.Equals(t, key, []byte("o1"))
- return bo, nil
- case 3:
- assert.Equals(t, bucket, authzTable)
- assert.Equals(t, key, []byte(o.Authorizations[0]))
- return nil, errors.New("force")
- default:
- assert.FatalError(t, errors.New("should not be here"))
- return nil, nil
- }
+ o: o,
+ csr: csr,
+ prov: &MockProvisioner{
+ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
+ assert.Equals(t, token, "")
+ return nil, nil
+ },
+ MgetOptions: func() *provisioner.Options {
+ return nil
},
},
- err: ServerInternalErr(errors.Errorf("error updating order o1 for account foo: error loading authz %s: force", o.Authorizations[0])),
+ ca: &mockSignAuth{
+ sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
+ assert.Equals(t, _csr, csr)
+ return nil, errors.New("force")
+ },
+ },
+ err: NewErrorISE("error signing certificate for order oID: force"),
}
},
- "ok/no-change-to-pending-orders": func(t *testing.T) test {
- oids := []string{"o1", "o2", "o3"}
- boids, err := json.Marshal(oids)
- assert.FatalError(t, err)
+ "fail/error-db.CreateCertificate": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusReady,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ Identifiers: []Identifier{
+ {Type: "dns", Value: "foo.internal"},
+ {Type: "dns", Value: "bar.internal"},
+ },
+ }
+ csr := &x509.CertificateRequest{
+ Subject: pkix.Name{
+ CommonName: "foo.internal",
+ },
+ DNSNames: []string{"bar.internal"},
+ }
- o, err := newO()
- assert.FatalError(t, err)
- bo, err := json.Marshal(o)
- assert.FatalError(t, err)
-
- az, err := newAz()
- assert.FatalError(t, err)
- baz, err := json.Marshal(az)
- assert.FatalError(t, err)
-
- ch, err := newDNSCh()
- assert.FatalError(t, err)
- bch, err := json.Marshal(ch)
- assert.FatalError(t, err)
+ foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}}
+ bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}}
+ baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}}
return test{
- id: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- switch string(bucket) {
- case string(ordersByAccountIDTable):
- assert.Equals(t, key, []byte("foo"))
- return boids, nil
- case string(orderTable):
- return bo, nil
- case string(authzTable):
- return baz, nil
- case string(challengeTable):
- return bch, nil
- default:
- assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket))
- return nil, nil
- }
+ o: o,
+ csr: csr,
+ prov: &MockProvisioner{
+ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
+ assert.Equals(t, token, "")
+ return nil, nil
},
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- return nil, false, errors.New("should not be attempting to store anything")
+ MgetOptions: func() *provisioner.Options {
+ return nil
},
},
- res: oids,
+ ca: &mockSignAuth{
+ sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
+ assert.Equals(t, _csr, csr)
+ return []*x509.Certificate{foo, bar, baz}, nil
+ },
+ },
+ db: &MockDB{
+ MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
+ assert.Equals(t, cert.AccountID, o.AccountID)
+ assert.Equals(t, cert.OrderID, o.ID)
+ assert.Equals(t, cert.Leaf, foo)
+ assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz})
+ return errors.New("force")
+ },
+ },
+ err: NewErrorISE("error creating certificate for order oID: force"),
}
},
- "fail/error-storing-new-oids": func(t *testing.T) test {
- oids := []string{"o1", "o2", "o3"}
- boids, err := json.Marshal(oids)
- assert.FatalError(t, err)
+ "fail/error-db.UpdateOrder": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusReady,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ Identifiers: []Identifier{
+ {Type: "dns", Value: "foo.internal"},
+ {Type: "dns", Value: "bar.internal"},
+ },
+ }
+ csr := &x509.CertificateRequest{
+ Subject: pkix.Name{
+ CommonName: "foo.internal",
+ },
+ DNSNames: []string{"bar.internal"},
+ }
- o, err := newO()
- assert.FatalError(t, err)
- bo, err := json.Marshal(o)
- assert.FatalError(t, err)
+ foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}}
+ bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}}
+ baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}}
- invalidOrder, err := newO()
- assert.FatalError(t, err)
- invalidOrder.Status = StatusInvalid
- binvalidOrder, err := json.Marshal(invalidOrder)
- assert.FatalError(t, err)
-
- az, err := newAz()
- assert.FatalError(t, err)
- baz, err := json.Marshal(az)
- assert.FatalError(t, err)
-
- ch, err := newDNSCh()
- assert.FatalError(t, err)
- bch, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- dbGetOrder := 0
return test{
- id: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- switch string(bucket) {
- case string(ordersByAccountIDTable):
- assert.Equals(t, key, []byte("foo"))
- return boids, nil
- case string(orderTable):
- dbGetOrder++
- if dbGetOrder == 1 {
- return binvalidOrder, nil
- }
- return bo, nil
- case string(authzTable):
- return baz, nil
- case string(challengeTable):
- return bch, nil
- default:
- assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket))
- return nil, nil
- }
+ o: o,
+ csr: csr,
+ prov: &MockProvisioner{
+ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
+ assert.Equals(t, token, "")
+ return nil, nil
},
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte("foo"))
- return nil, false, errors.New("force")
+ MgetOptions: func() *provisioner.Options {
+ return nil
},
},
- err: ServerInternalErr(errors.New("error storing orderIDs as part of getOrderIDsByAccount logic: len(orderIDs) = 2: error storing order IDs for account foo: force")),
+ ca: &mockSignAuth{
+ sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
+ assert.Equals(t, _csr, csr)
+ return []*x509.Certificate{foo, bar, baz}, nil
+ },
+ },
+ db: &MockDB{
+ MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
+ cert.ID = "certID"
+ assert.Equals(t, cert.AccountID, o.AccountID)
+ assert.Equals(t, cert.OrderID, o.ID)
+ assert.Equals(t, cert.Leaf, foo)
+ assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz})
+ return nil
+ },
+ MockUpdateOrder: func(ctx context.Context, updo *Order) error {
+ assert.Equals(t, updo.CertificateID, "certID")
+ assert.Equals(t, updo.Status, StatusValid)
+ assert.Equals(t, updo.ID, o.ID)
+ assert.Equals(t, updo.AccountID, o.AccountID)
+ assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
+ assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs)
+ assert.Equals(t, updo.Identifiers, o.Identifiers)
+ return errors.New("force")
+ },
+ },
+ err: NewErrorISE("error updating order oID: force"),
}
},
- "ok": func(t *testing.T) test {
- oids := []string{"o1", "o2", "o3", "o4"}
- boids, err := json.Marshal(oids)
- assert.FatalError(t, err)
-
- o, err := newO()
- assert.FatalError(t, err)
- bo, err := json.Marshal(o)
- assert.FatalError(t, err)
-
- invalidOrder, err := newO()
- assert.FatalError(t, err)
- invalidOrder.Status = StatusInvalid
- binvalidOrder, err := json.Marshal(invalidOrder)
- assert.FatalError(t, err)
-
- az, err := newAz()
- assert.FatalError(t, err)
- baz, err := json.Marshal(az)
- assert.FatalError(t, err)
-
- ch, err := newDNSCh()
- assert.FatalError(t, err)
- bch, err := json.Marshal(ch)
- assert.FatalError(t, err)
-
- dbGetOrder := 0
- return test{
- id: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- switch string(bucket) {
- case string(ordersByAccountIDTable):
- assert.Equals(t, key, []byte("foo"))
- return boids, nil
- case string(orderTable):
- dbGetOrder++
- if dbGetOrder == 1 || dbGetOrder == 3 {
- return binvalidOrder, nil
- }
- return bo, nil
- case string(authzTable):
- return baz, nil
- case string(challengeTable):
- return bch, nil
- default:
- assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket))
- return nil, nil
- }
- },
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte("foo"))
- return nil, true, nil
- },
+ "ok/new-cert": func(t *testing.T) test {
+ now := clock.Now()
+ o := &Order{
+ ID: "oID",
+ AccountID: "accID",
+ Status: StatusReady,
+ ExpiresAt: now.Add(5 * time.Minute),
+ AuthorizationIDs: []string{"a", "b"},
+ Identifiers: []Identifier{
+ {Type: "dns", Value: "foo.internal"},
+ {Type: "dns", Value: "bar.internal"},
},
- res: []string{"o2", "o4"},
}
- },
- "ok/no-pending-orders": func(t *testing.T) test {
- oids := []string{"o1"}
- boids, err := json.Marshal(oids)
- assert.FatalError(t, err)
+ csr := &x509.CertificateRequest{
+ Subject: pkix.Name{
+ CommonName: "foo.internal",
+ },
+ DNSNames: []string{"bar.internal"},
+ }
- invalidOrder, err := newO()
- assert.FatalError(t, err)
- invalidOrder.Status = StatusInvalid
- binvalidOrder, err := json.Marshal(invalidOrder)
- assert.FatalError(t, err)
+ foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}}
+ bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}}
+ baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}}
return test{
- id: "foo",
- db: &db.MockNoSQLDB{
- MGet: func(bucket, key []byte) ([]byte, error) {
- switch string(bucket) {
- case string(ordersByAccountIDTable):
- assert.Equals(t, key, []byte("foo"))
- return boids, nil
- case string(orderTable):
- return binvalidOrder, nil
- default:
- assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket))
- return nil, nil
- }
+ o: o,
+ csr: csr,
+ prov: &MockProvisioner{
+ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
+ assert.Equals(t, token, "")
+ return nil, nil
},
- MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
- assert.Equals(t, bucket, ordersByAccountIDTable)
- assert.Equals(t, key, []byte("foo"))
- assert.Equals(t, old, boids)
- assert.Nil(t, newval)
- return nil, true, nil
+ MgetOptions: func() *provisioner.Options {
+ return nil
+ },
+ },
+ ca: &mockSignAuth{
+ sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
+ assert.Equals(t, _csr, csr)
+ return []*x509.Certificate{foo, bar, baz}, nil
+ },
+ },
+ db: &MockDB{
+ MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
+ cert.ID = "certID"
+ assert.Equals(t, cert.AccountID, o.AccountID)
+ assert.Equals(t, cert.OrderID, o.ID)
+ assert.Equals(t, cert.Leaf, foo)
+ assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz})
+ return nil
+ },
+ MockUpdateOrder: func(ctx context.Context, updo *Order) error {
+ assert.Equals(t, updo.CertificateID, "certID")
+ assert.Equals(t, updo.Status, StatusValid)
+ assert.Equals(t, updo.ID, o.ID)
+ assert.Equals(t, updo.AccountID, o.AccountID)
+ assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
+ assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs)
+ assert.Equals(t, updo.Identifiers, o.Identifiers)
+ return nil
},
},
- res: []string{},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
- var oiba = orderIDsByAccount{}
- if oids, err := oiba.unsafeGetOrderIDsByAccount(tc.db, tc.id); err != nil {
+ if err := tc.o.Finalize(context.Background(), tc.db, tc.csr, tc.ca, tc.prov); err != nil {
if assert.NotNil(t, tc.err) {
- ae, ok := err.(*Error)
- assert.True(t, ok)
- assert.HasPrefix(t, ae.Error(), tc.err.Error())
- assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
- assert.Equals(t, ae.Type, tc.err.Type)
+ switch k := err.(type) {
+ case *Error:
+ assert.Equals(t, k.Type, tc.err.Type)
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ assert.Equals(t, k.Status, tc.err.Status)
+ assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
+ assert.Equals(t, k.Detail, tc.err.Detail)
+ default:
+ assert.FatalError(t, errors.New("unexpected error type"))
+ }
}
} else {
- if assert.Nil(t, tc.err) {
- assert.Equals(t, tc.res, oids)
- }
+ assert.Nil(t, tc.err)
}
})
}
diff --git a/acme/status.go b/acme/status.go
new file mode 100644
index 00000000..d9aae82d
--- /dev/null
+++ b/acme/status.go
@@ -0,0 +1,20 @@
+package acme
+
+// Status represents an ACME status.
+type Status string
+
+var (
+ // StatusValid -- valid
+ StatusValid = Status("valid")
+ // StatusInvalid -- invalid
+ StatusInvalid = Status("invalid")
+ // StatusPending -- pending; e.g. an Order that is not ready to be finalized.
+ StatusPending = Status("pending")
+ // StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
+ StatusDeactivated = Status("deactivated")
+ // StatusReady -- ready; e.g. for an Order that is ready to be finalized.
+ StatusReady = Status("ready")
+ //statusExpired = "expired"
+ //statusActive = "active"
+ //statusProcessing = "processing"
+)
diff --git a/api/errors.go b/api/errors.go
index 3e5dec47..438b873d 100644
--- a/api/errors.go
+++ b/api/errors.go
@@ -17,14 +17,14 @@ import (
func WriteError(w http.ResponseWriter, err error) {
switch k := err.(type) {
case *acme.Error:
- w.Header().Set("Content-Type", "application/problem+json")
- err = k.ToACME()
+ acme.WriteError(w, k)
+ return
case *scep.Error:
- // TODO: check if this is correct; and should we do some more processing?
w.Header().Set("Content-Type", "text/plain")
default:
w.Header().Set("Content-Type", "application/json")
}
+
cause := errors.Cause(err)
if sc, ok := err.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
diff --git a/authority/provisioner/method.go b/authority/provisioner/method.go
index 775ed96f..f5cd5221 100644
--- a/authority/provisioner/method.go
+++ b/authority/provisioner/method.go
@@ -56,8 +56,7 @@ func NewContextWithMethod(ctx context.Context, method Method) context.Context {
return context.WithValue(ctx, methodKey{}, method)
}
-// MethodFromContext returns the Method saved in ctx. Returns Sign if the given
-// context has no Method associated with it.
+// MethodFromContext returns the Method saved in ctx.
func MethodFromContext(ctx context.Context) Method {
m, _ := ctx.Value(methodKey{}).(Method)
return m
diff --git a/authority/ssh_test.go b/authority/ssh_test.go
index b5cce1fd..1662260c 100644
--- a/authority/ssh_test.go
+++ b/authority/ssh_test.go
@@ -450,7 +450,7 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("Match exec \"step ssh check-host %h\"\n\tUserKnownHostsFile /home/user/.step/ssh/known_hosts\n\tProxyCommand step ssh proxycommand %r %h %p\n")},
}
hostOutputWithUserData := []templates.Output{
- {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("TrustedUserCAKeys /etc/ssh/ca.pub\nHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\nHostKey /etc/ssh/ssh_host_ecdsa_key")},
+ {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
}
tmplConfigErr := &templates.Templates{
diff --git a/authority/testdata/templates/sshd_config.tpl b/authority/testdata/templates/sshd_config.tpl
index 5ce01fc4..c8e4b884 100644
--- a/authority/testdata/templates/sshd_config.tpl
+++ b/authority/testdata/templates/sshd_config.tpl
@@ -1,3 +1,4 @@
-TrustedUserCAKeys /etc/ssh/ca.pub
-HostCertificate /etc/ssh/{{.User.Certificate}}
-HostKey /etc/ssh/{{.User.Key}}
\ No newline at end of file
+Match all
+ TrustedUserCAKeys /etc/ssh/ca.pub
+ HostCertificate /etc/ssh/{{.User.Certificate}}
+ HostKey /etc/ssh/{{.User.Key}}
\ No newline at end of file
diff --git a/ca/acmeClient.go b/ca/acmeClient.go
index deb8a3a2..5633dac5 100644
--- a/ca/acmeClient.go
+++ b/ca/acmeClient.go
@@ -21,7 +21,7 @@ import (
type ACMEClient struct {
client *http.Client
dirLoc string
- dir *acme.Directory
+ dir *acmeAPI.Directory
acc *acme.Account
Key *jose.JSONWebKey
kid string
@@ -53,7 +53,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC
if resp.StatusCode >= 400 {
return nil, readACMEError(resp.Body)
}
- var dir acme.Directory
+ var dir acmeAPI.Directory
if err := readJSON(resp.Body, &dir); err != nil {
return nil, errors.Wrapf(err, "error reading %s", endpoint)
}
@@ -93,7 +93,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC
// GetDirectory makes a directory request to the ACME api and returns an
// ACME directory object.
-func (c *ACMEClient) GetDirectory() (*acme.Directory, error) {
+func (c *ACMEClient) GetDirectory() (*acmeAPI.Directory, error) {
return c.dir, nil
}
@@ -231,7 +231,7 @@ func (c *ACMEClient) ValidateChallenge(url string) error {
}
// GetAuthz returns the Authz at the given path.
-func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) {
+func (c *ACMEClient) GetAuthz(url string) (*acme.Authorization, error) {
resp, err := c.post(nil, url, withKid(c))
if err != nil {
return nil, err
@@ -240,7 +240,7 @@ func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) {
return nil, readACMEError(resp.Body)
}
- var az acme.Authz
+ var az acme.Authorization
if err := readJSON(resp.Body, &az); err != nil {
return nil, errors.Wrapf(err, "error reading %s", url)
}
@@ -320,7 +320,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) {
if c.acc == nil {
return nil, errors.New("acme client not configured with account")
}
- resp, err := c.post(nil, c.acc.Orders, withKid(c))
+ resp, err := c.post(nil, c.acc.OrdersURL, withKid(c))
if err != nil {
return nil, err
}
@@ -330,7 +330,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) {
var orders []string
if err := readJSON(resp.Body, &orders); err != nil {
- return nil, errors.Wrapf(err, "error reading %s", c.acc.Orders)
+ return nil, errors.Wrapf(err, "error reading %s", c.acc.OrdersURL)
}
return orders, nil
@@ -342,7 +342,7 @@ func readACMEError(r io.ReadCloser) error {
if err != nil {
return errors.Wrap(err, "error reading from body")
}
- ae := new(acme.AError)
+ ae := new(acme.Error)
err = json.Unmarshal(b, &ae)
// If we successfully marshaled to an ACMEError then return the ACMEError.
if err != nil || len(ae.Error()) == 0 {
diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go
index 25d74b9d..f5963de4 100644
--- a/ca/acmeClient_test.go
+++ b/ca/acmeClient_test.go
@@ -31,18 +31,17 @@ func TestNewACMEClient(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
NewAccount: srv.URL + "/bar",
NewOrder: srv.URL + "/baz",
- NewAuthz: srv.URL + "/zap",
RevokeCert: srv.URL + "/zip",
KeyChange: srv.URL + "/blorp",
}
acc := acme.Account{
- Contact: []string{"max", "mariano"},
- Status: "valid",
- Orders: "orders-url",
+ Contact: []string{"max", "mariano"},
+ Status: "valid",
+ OrdersURL: "orders-url",
}
tests := map[string]func(t *testing.T) test{
"fail/client-option-error": func(t *testing.T) test {
@@ -58,7 +57,7 @@ func TestNewACMEClient(t *testing.T) {
"fail/get-directory": func(t *testing.T) test {
return test{
ops: []ClientOption{WithTransport(http.DefaultTransport)},
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -76,7 +75,7 @@ func TestNewACMEClient(t *testing.T) {
ops: []ClientOption{WithTransport(http.DefaultTransport)},
r1: dir,
rc1: 200,
- r2: acme.AccountDoesNotExistErr(nil).ToACME(),
+ r2: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
rc2: 400,
err: errors.New("Account does not exist"),
}
@@ -142,11 +141,10 @@ func TestNewACMEClient(t *testing.T) {
func TestACMEClient_GetDirectory(t *testing.T) {
c := &ACMEClient{
- dir: &acme.Directory{
+ dir: &acmeAPI.Directory{
NewNonce: "/foo",
NewAccount: "/bar",
NewOrder: "/baz",
- NewAuthz: "/zap",
RevokeCert: "/zip",
KeyChange: "/blorp",
},
@@ -166,7 +164,7 @@ func TestACMEClient_GetNonce(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
@@ -185,7 +183,7 @@ func TestACMEClient_GetNonce(t *testing.T) {
tests := map[string]func(t *testing.T) test{
"fail/GET-nonce": func(t *testing.T) test {
return test{
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -237,7 +235,7 @@ func TestACMEClient_post(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
@@ -248,9 +246,9 @@ func TestACMEClient_post(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
acc := acme.Account{
- Contact: []string{"max", "mariano"},
- Status: "valid",
- Orders: "orders-url",
+ Contact: []string{"max", "mariano"},
+ Status: "valid",
+ OrdersURL: "orders-url",
}
ac := &ACMEClient{
client: &http.Client{
@@ -266,7 +264,7 @@ func TestACMEClient_post(t *testing.T) {
"fail/account-not-configured": func(t *testing.T) test {
return test{
client: &ACMEClient{},
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("acme client not configured with account"),
}
@@ -274,7 +272,7 @@ func TestACMEClient_post(t *testing.T) {
"fail/GET-nonce": func(t *testing.T) test {
return test{
client: ac,
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -365,7 +363,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
NewOrder: srv.URL + "/bar",
}
@@ -376,20 +374,21 @@ func TestACMEClient_NewOrder(t *testing.T) {
assert.FatalError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
+ now := time.Now().UTC().Round(time.Second)
nor := acmeAPI.NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "example.com"},
{Type: "dns", Value: "acme.example.com"},
},
- NotBefore: time.Now(),
- NotAfter: time.Now().Add(time.Minute),
+ NotBefore: now,
+ NotAfter: now.Add(time.Minute),
}
norb, err := json.Marshal(nor)
assert.FatalError(t, err)
ord := acme.Order{
- Status: "valid",
- Expires: "soon",
- Finalize: "finalize-url",
+ Status: "valid",
+ ExpiresAt: now, // "soon"
+ FinalizeURL: "finalize-url",
}
ac := &ACMEClient{
client: &http.Client{
@@ -404,7 +403,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
tests := map[string]func(t *testing.T) test{
"fail/client-post": func(t *testing.T) test {
return test{
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -413,7 +412,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
return test{
r1: []byte{},
rc1: 200,
- r2: acme.MalformedErr(nil).ToACME(),
+ r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc2: 400,
ops: []withHeaderOption{withKid(ac)},
err: errors.New("The request message was malformed"),
@@ -498,7 +497,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
@@ -509,9 +508,9 @@ func TestACMEClient_GetOrder(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
ord := acme.Order{
- Status: "valid",
- Expires: "soon",
- Finalize: "finalize-url",
+ Status: "valid",
+ ExpiresAt: time.Now().UTC().Round(time.Second), // "soon"
+ FinalizeURL: "finalize-url",
}
ac := &ACMEClient{
client: &http.Client{
@@ -526,7 +525,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
tests := map[string]func(t *testing.T) test{
"fail/client-post": func(t *testing.T) test {
return test{
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -535,7 +534,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
return test{
r1: []byte{},
rc1: 200,
- r2: acme.MalformedErr(nil).ToACME(),
+ r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc2: 400,
err: errors.New("The request message was malformed"),
}
@@ -618,7 +617,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
@@ -628,9 +627,9 @@ func TestACMEClient_GetAuthz(t *testing.T) {
assert.FatalError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
- az := acme.Authz{
+ az := acme.Authorization{
Status: "valid",
- Expires: "soon",
+ ExpiresAt: time.Now().UTC().Round(time.Second),
Identifier: acme.Identifier{Type: "dns", Value: "example.com"},
}
ac := &ACMEClient{
@@ -646,7 +645,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
tests := map[string]func(t *testing.T) test{
"fail/client-post": func(t *testing.T) test {
return test{
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -655,7 +654,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
return test{
r1: []byte{},
rc1: 200,
- r2: acme.MalformedErr(nil).ToACME(),
+ r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc2: 400,
err: errors.New("The request message was malformed"),
}
@@ -738,7 +737,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
@@ -766,7 +765,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
tests := map[string]func(t *testing.T) test{
"fail/client-post": func(t *testing.T) test {
return test{
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -775,7 +774,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
return test{
r1: []byte{},
rc1: 200,
- r2: acme.MalformedErr(nil).ToACME(),
+ r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc2: 400,
err: errors.New("The request message was malformed"),
}
@@ -859,7 +858,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
@@ -887,7 +886,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
tests := map[string]func(t *testing.T) test{
"fail/client-post": func(t *testing.T) test {
return test{
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -896,7 +895,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
return test{
r1: []byte{},
rc1: 200,
- r2: acme.MalformedErr(nil).ToACME(),
+ r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc2: 400,
err: errors.New("The request message was malformed"),
}
@@ -976,7 +975,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
@@ -987,10 +986,10 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
ord := acme.Order{
- Status: "valid",
- Expires: "soon",
- Finalize: "finalize-url",
- Certificate: "cert-url",
+ Status: "valid",
+ ExpiresAt: time.Now(), // "soon"
+ FinalizeURL: "finalize-url",
+ CertificateURL: "cert-url",
}
_csr, err := pemutil.Read("../authority/testdata/certs/foo.csr")
assert.FatalError(t, err)
@@ -1012,7 +1011,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
tests := map[string]func(t *testing.T) test{
"fail/client-post": func(t *testing.T) test {
return test{
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -1021,7 +1020,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
return test{
r1: []byte{},
rc1: 200,
- r2: acme.MalformedErr(nil).ToACME(),
+ r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc2: 400,
err: errors.New("The request message was malformed"),
}
@@ -1101,7 +1100,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
@@ -1121,9 +1120,9 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
Key: jwk,
kid: "foobar",
acc: &acme.Account{
- Contact: []string{"max", "mariano"},
- Status: "valid",
- Orders: srv.URL + "/orders-url",
+ Contact: []string{"max", "mariano"},
+ Status: "valid",
+ OrdersURL: srv.URL + "/orders-url",
},
}
@@ -1137,7 +1136,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
"fail/client-post": func(t *testing.T) test {
return test{
client: ac,
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -1147,7 +1146,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
client: ac,
r1: []byte{},
rc1: 200,
- r2: acme.MalformedErr(nil).ToACME(),
+ r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc2: 400,
err: errors.New("The request message was malformed"),
}
@@ -1198,7 +1197,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
assert.Equals(t, hdr.Nonce, expectedNonce)
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
assert.Fatal(t, ok)
- assert.Equals(t, jwsURL, ac.acc.Orders)
+ assert.Equals(t, jwsURL, ac.acc.OrdersURL)
assert.Equals(t, hdr.KeyID, ac.kid)
payload, err := jws.Verify(ac.Key.Public())
@@ -1232,7 +1231,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
srv := httptest.NewServer(nil)
defer srv.Close()
- dir := acme.Directory{
+ dir := acmeAPI.Directory{
NewNonce: srv.URL + "/foo",
}
// Retrieve transport from options.
@@ -1259,16 +1258,16 @@ func TestACMEClient_GetCertificate(t *testing.T) {
Key: jwk,
kid: "foobar",
acc: &acme.Account{
- Contact: []string{"max", "mariano"},
- Status: "valid",
- Orders: srv.URL + "/orders-url",
+ Contact: []string{"max", "mariano"},
+ Status: "valid",
+ OrdersURL: srv.URL + "/orders-url",
},
}
tests := map[string]func(t *testing.T) test{
"fail/client-post": func(t *testing.T) test {
return test{
- r1: acme.MalformedErr(nil).ToACME(),
+ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc1: 400,
err: errors.New("The request message was malformed"),
}
@@ -1277,7 +1276,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
return test{
r1: []byte{},
rc1: 200,
- r2: acme.MalformedErr(nil).ToACME(),
+ r2: acme.NewError(acme.ErrorMalformedType, "malformed request"),
rc2: 400,
err: errors.New("The request message was malformed"),
}
diff --git a/ca/ca.go b/ca/ca.go
index 6452d9f5..56f4c1f8 100644
--- a/ca/ca.go
+++ b/ca/ca.go
@@ -14,6 +14,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api"
+ acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/db"
@@ -149,23 +150,29 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) {
}
prefix := "acme"
- acmeAuth, err := acme.New(auth, acme.AuthorityOptions{
+ var acmeDB acme.DB
+ if config.DB == nil {
+ acmeDB = nil
+ } else {
+ acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
+ if err != nil {
+ return nil, errors.Wrap(err, "error configuring ACME DB interface")
+ }
+ }
+ acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
Backdate: *config.AuthorityConfig.Backdate,
- DB: auth.GetDatabase().(nosql.DB),
+ DB: acmeDB,
DNS: dns,
Prefix: prefix,
+ CA: auth,
})
- if err != nil {
- return nil, errors.Wrap(err, "error creating ACME authority")
- }
- acmeRouterHandler := acmeAPI.New(acmeAuth)
mux.Route("/"+prefix, func(r chi.Router) {
- acmeRouterHandler.Route(r)
+ acmeHandler.Route(r)
})
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
// of the ACME spec.
mux.Route("/2.0/"+prefix, func(r chi.Router) {
- acmeRouterHandler.Route(r)
+ acmeHandler.Route(r)
})
if ca.shouldServeSCEPEndpoints() {
diff --git a/ca/client.go b/ca/client.go
index be55ba53..b9593162 100644
--- a/ca/client.go
+++ b/ca/client.go
@@ -57,6 +57,7 @@ func newInsecureClient() *uaClient {
return &uaClient{
Client: &http.Client{
Transport: &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
},
diff --git a/debian/rules b/debian/rules
index 6683ef21..f5b70196 100755
--- a/debian/rules
+++ b/debian/rules
@@ -4,8 +4,10 @@ override_dh_install-arch:
dh_install --arch
build:
- make bootstrap
dh build
+override_dh_auto_build:
+ dh_auto_build -- build
+
%:
dh $@
diff --git a/docs/provisioners.md b/docs/provisioners.md
index 445bb650..63275916 100644
--- a/docs/provisioners.md
+++ b/docs/provisioners.md
@@ -80,7 +80,7 @@ Example `claims`:
use this value.
* `enableSSHCA`: enable all provisioners to generate SSH Certificates.
- The deault value is `false`. You can enable this option per provisioner
+ The default value is `false`. You can enable this option per provisioner
by setting it to `true` in the provisioner claims.
## Provisioner Types
diff --git a/templates/values.go b/templates/values.go
index fd4ee4c2..972b1d55 100644
--- a/templates/values.go
+++ b/templates/values.go
@@ -99,9 +99,10 @@ var DefaultSSHTemplateData = map[string]string{
`,
// sshd_config.tpl adds the configuration to support certificates
- "sshd_config.tpl": `TrustedUserCAKeys /etc/ssh/ca.pub
-HostCertificate /etc/ssh/{{.User.Certificate}}
-HostKey /etc/ssh/{{.User.Key}}`,
+ "sshd_config.tpl": `Match all
+ TrustedUserCAKeys /etc/ssh/ca.pub
+ HostCertificate /etc/ssh/{{.User.Certificate}}
+ HostKey /etc/ssh/{{.User.Key}}`,
// ca.tpl contains the public key used to authorized clients
"ca.tpl": `{{.Step.SSH.UserKey.Type}} {{.Step.SSH.UserKey.Marshal | toString | b64enc}}