Merge branch 'master' into crl-support
# Conflicts: # api/api.go # authority/config/config.go # cas/softcas/softcas.go # db/db.go
This commit is contained in:
commit
60671b07d7
182 changed files with 21493 additions and 2534 deletions
56
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
56
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
|
@ -0,0 +1,56 @@
|
|||
name: Bug Report
|
||||
description: File a bug report
|
||||
title: "[Bug]: "
|
||||
labels: ["bug", "needs triage"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for taking the time to fill out this bug report!
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: Tell us how to reproduce this issue.
|
||||
placeholder: These are the steps!
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: your-env
|
||||
attributes:
|
||||
label: Your Environment
|
||||
value: |-
|
||||
* OS -
|
||||
* `step-ca` Version -
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: What did you expect to happen?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: actual-behavior
|
||||
attributes:
|
||||
label: Actual Behavior
|
||||
description: What happens instead?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: context
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Add any other context about the problem here.
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: contributing
|
||||
attributes:
|
||||
label: Contributing
|
||||
value: |
|
||||
Vote on this issue by adding a 👍 reaction.
|
||||
To contribute a fix for this issue, leave a comment (and link to your pull request, if you've opened one already).
|
||||
validations:
|
||||
required: false
|
27
.github/ISSUE_TEMPLATE/bug_report.md
vendored
27
.github/ISSUE_TEMPLATE/bug_report.md
vendored
|
@ -1,27 +0,0 @@
|
|||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: bug, needs triage
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
### Subject of the issue
|
||||
Describe your issue here.
|
||||
|
||||
### Your environment
|
||||
* OS -
|
||||
* Version -
|
||||
|
||||
### Steps to reproduce
|
||||
Tell us how to reproduce this issue. Please provide a working demo, you can use [this template](https://plnkr.co/edit/XorWgI?p=preview) as a base.
|
||||
|
||||
### Expected behaviour
|
||||
Tell us what should happen
|
||||
|
||||
### Actual behaviour
|
||||
Tell us what happens instead
|
||||
|
||||
### Additional context
|
||||
Add any other context about the problem here.
|
12
.github/ISSUE_TEMPLATE/documentation-request.md
vendored
12
.github/ISSUE_TEMPLATE/documentation-request.md
vendored
|
@ -1,12 +1,20 @@
|
|||
---
|
||||
name: Documentation Request
|
||||
about: Request documentation for a feature
|
||||
title: ''
|
||||
labels: documentation, needs triage
|
||||
title: '[Docs]:'
|
||||
labels: docs, needs triage
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Hello!
|
||||
<!-- Please leave this section as-is, it's designed to help others in the community know how to interact with our GitHub issues. -->
|
||||
|
||||
- Vote on this issue by adding a 👍 reaction
|
||||
- If you want to document this feature, comment to let us know (we'll work with you on design, scheduling, etc.)
|
||||
|
||||
## Affected area/feature
|
||||
|
||||
<!---
|
||||
Tell us which feature you'd like to see documented.
|
||||
- Where would you like that documentation to live (command line usage output, website, github markdown on the repo)?
|
||||
|
|
17
.github/ISSUE_TEMPLATE/enhancement.md
vendored
17
.github/ISSUE_TEMPLATE/enhancement.md
vendored
|
@ -1,13 +1,24 @@
|
|||
---
|
||||
name: Enhancement
|
||||
about: Suggest an enhancement to step certificates
|
||||
about: Suggest an enhancement to step-ca
|
||||
title: ''
|
||||
labels: enhancement, needs triage
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
### What would you like to be added
|
||||
## Hello!
|
||||
<!-- Please leave this section as-is,
|
||||
it's designed to help others in the community know how to interact with our GitHub issues. -->
|
||||
|
||||
- Vote on this issue by adding a 👍 reaction
|
||||
- If you want to implement this feature, comment to let us know (we'll work with you on design, scheduling, etc.)
|
||||
|
||||
### Why this is needed
|
||||
## Issue details
|
||||
|
||||
<!-- Enhancement requests are most helpful when they describe the problem you're having
|
||||
as well as articulating the potential solution you'd like to see built. -->
|
||||
|
||||
## Why is this needed?
|
||||
|
||||
<!-- Let us know why you think this enhancement would be good for the project or community. -->
|
||||
|
|
20
.github/PULL_REQUEST_TEMPLATE
vendored
20
.github/PULL_REQUEST_TEMPLATE
vendored
|
@ -1,4 +1,20 @@
|
|||
### Description
|
||||
Please describe your pull request.
|
||||
<!---
|
||||
Please provide answers in the spaces below each prompt, where applicable.
|
||||
Not every PR requires responses for each prompt.
|
||||
Use your discretion.
|
||||
-->
|
||||
#### Name of feature:
|
||||
|
||||
#### Pain or issue this feature alleviates:
|
||||
|
||||
#### Why is this important to the project (if not answered above):
|
||||
|
||||
#### Is there documentation on how to use this feature? If so, where?
|
||||
|
||||
#### In what environments or workflows is this feature supported?
|
||||
|
||||
#### In what environments or workflows is this feature explicitly NOT supported (if any)?
|
||||
|
||||
#### Supporting links/other PRs/issues:
|
||||
|
||||
💔Thank you!
|
||||
|
|
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
|
@ -33,7 +33,7 @@ jobs:
|
|||
uses: golangci/golangci-lint-action@v2
|
||||
with:
|
||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||
version: 'v1.45.0'
|
||||
version: 'v1.45.2'
|
||||
|
||||
# Optional: working directory, useful for monorepos
|
||||
# working-directory: somedir
|
||||
|
@ -139,7 +139,7 @@ jobs:
|
|||
name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@5a54d7e660bda43b405e8463261b3d25631ffe86 # v2.7.0
|
||||
with:
|
||||
version: latest
|
||||
version: 'v1.7.0'
|
||||
args: release --rm-dist
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.PAT }}
|
||||
|
|
7
.github/workflows/test.yml
vendored
7
.github/workflows/test.yml
vendored
|
@ -33,7 +33,7 @@ jobs:
|
|||
uses: golangci/golangci-lint-action@v2
|
||||
with:
|
||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||
version: 'v1.45.0'
|
||||
version: 'v1.45.2'
|
||||
|
||||
# Optional: working directory, useful for monorepos
|
||||
# working-directory: somedir
|
||||
|
@ -59,8 +59,9 @@ jobs:
|
|||
-
|
||||
name: Codecov
|
||||
if: matrix.go == '1.18'
|
||||
uses: codecov/codecov-action@v1.2.1
|
||||
uses: codecov/codecov-action@v2
|
||||
with:
|
||||
file: ./coverage.out # optional
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
files: ./coverage.out # optional
|
||||
name: codecov-umbrella # optional
|
||||
fail_ci_if_error: true # optional (default = false)
|
||||
|
|
|
@ -230,42 +230,3 @@ scoop:
|
|||
# Your app's license
|
||||
# Default is empty.
|
||||
license: "Apache-2.0"
|
||||
|
||||
#dockers:
|
||||
# - dockerfile: docker/Dockerfile
|
||||
# goos: linux
|
||||
# goarch: amd64
|
||||
# use_buildx: true
|
||||
# image_templates:
|
||||
# - "smallstep/step-cli:latest"
|
||||
# - "smallstep/step-cli:{{ .Tag }}"
|
||||
# build_flag_templates:
|
||||
# - "--platform=linux/amd64"
|
||||
# - dockerfile: docker/Dockerfile
|
||||
# goos: linux
|
||||
# goarch: 386
|
||||
# use_buildx: true
|
||||
# image_templates:
|
||||
# - "smallstep/step-cli:latest"
|
||||
# - "smallstep/step-cli:{{ .Tag }}"
|
||||
# build_flag_templates:
|
||||
# - "--platform=linux/386"
|
||||
# - dockerfile: docker/Dockerfile
|
||||
# goos: linux
|
||||
# goarch: arm
|
||||
# goarm: 7
|
||||
# use_buildx: true
|
||||
# image_templates:
|
||||
# - "smallstep/step-cli:latest"
|
||||
# - "smallstep/step-cli:{{ .Tag }}"
|
||||
# build_flag_templates:
|
||||
# - "--platform=linux/arm/v7"
|
||||
# - dockerfile: docker/Dockerfile
|
||||
# goos: linux
|
||||
# goarch: arm64
|
||||
# use_buildx: true
|
||||
# image_templates:
|
||||
# - "smallstep/step-cli:latest"
|
||||
# - "smallstep/step-cli:{{ .Tag }}"
|
||||
# build_flag_templates:
|
||||
# - "--platform=linux/arm64/v8"
|
||||
|
|
66
CHANGELOG.md
66
CHANGELOG.md
|
@ -4,17 +4,71 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
||||
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased - 0.18.3] - DATE
|
||||
### TEMPLATE -- do not alter or remove
|
||||
---
|
||||
## [x.y.z] - aaaa-bb-cc
|
||||
### Added
|
||||
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
|
||||
- Added support for `extraNames` in X.509 templates.
|
||||
### Changed
|
||||
- Made SCEP CA URL paths dynamic
|
||||
- Support two latest versions of Go (1.17, 1.18)
|
||||
### Deprecated
|
||||
### Removed
|
||||
### Fixed
|
||||
### Security
|
||||
---
|
||||
|
||||
## [Unreleased]
|
||||
### Changed
|
||||
- Certificates signed by an issuer using an RSA key will be signed using the same algorithm as the issuer certificate was signed with. The signature will no longer default to PKCS #1. For example, if the issuer certificate was signed using RSA-PSS with SHA-256, a new certificate will also be signed using RSA-PSS with SHA-256.
|
||||
|
||||
## [0.20.0] - 2022-05-26
|
||||
### Added
|
||||
- Added Kubernetes auth method for Vault RAs.
|
||||
- Added support for reporting provisioners to linkedca.
|
||||
- Added support for certificate policies on authority level.
|
||||
- Added a Dockerfile with a step-ca build with HSM support.
|
||||
- A few new WithXX methods for instantiating authorities
|
||||
### Changed
|
||||
- Context usage in HTTP APIs.
|
||||
- Changed authentication for Vault RAs.
|
||||
- Error message returned to client when authenticating with expired certificate.
|
||||
- Strip padding from ACME CSRs.
|
||||
### Deprecated
|
||||
- HTTP API handler types.
|
||||
### Fixed
|
||||
- Fixed SSH revocation.
|
||||
- CA client dial context for js/wasm target.
|
||||
- Incomplete `extraNames` support in templates.
|
||||
- SCEP GET request support.
|
||||
- Large SCEP request handling.
|
||||
|
||||
## [0.19.0] - 2022-04-19
|
||||
### Added
|
||||
- Added support for certificate renewals after expiry using the claim `allowRenewalAfterExpiry`.
|
||||
- Added support for `extraNames` in X.509 templates.
|
||||
- Added `armv5` builds.
|
||||
- Added RA support using a Vault instance as the CA.
|
||||
- Added `WithX509SignerFunc` authority option.
|
||||
- Added a new `/roots.pem` endpoint to download the CA roots in PEM format.
|
||||
- Added support for Azure `Managed Identity` tokens.
|
||||
- Added support for automatic configuration of linked RAs.
|
||||
- Added support for the `--context` flag. It's now possible to start the
|
||||
CA with `step-ca --context=abc` to use the configuration from context `abc`.
|
||||
When a context has been configured and no configuration file is provided
|
||||
on startup, the configuration for the current context is used.
|
||||
- Added startup info logging and option to skip it (`--quiet`).
|
||||
- Added support for renaming the CA (Common Name).
|
||||
### Changed
|
||||
- Made SCEP CA URL paths dynamic.
|
||||
- Support two latest versions of Go (1.17, 1.18).
|
||||
- Upgrade go.step.sm/crypto to v0.16.1.
|
||||
- Upgrade go.step.sm/linkedca to v0.15.0.
|
||||
### Deprecated
|
||||
- Go 1.16 support.
|
||||
### Removed
|
||||
### Fixed
|
||||
- Fixed admin credentials on RAs.
|
||||
- Fixed ACME HTTP-01 challenges for IPv6 identifiers.
|
||||
- Various improvements under the hood.
|
||||
### Security
|
||||
|
||||
## [0.18.2] - 2022-03-01
|
||||
### Added
|
||||
|
@ -49,7 +103,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||
- Support for multiple certificate authority contexts.
|
||||
- Support for generating extractable keys and certificates on a pkcs#11 module.
|
||||
### Changed
|
||||
- Support two latest versions of golang (1.16, 1.17)
|
||||
- Support two latest versions of Go (1.16, 1.17)
|
||||
### Deprecated
|
||||
- go 1.15 support
|
||||
|
||||
|
|
2
Makefile
2
Makefile
|
@ -151,7 +151,7 @@ integration: bin/$(BINNAME)
|
|||
#########################################
|
||||
|
||||
fmt:
|
||||
$Q gofmt -l -w $(SRC)
|
||||
$Q gofmt -l -s -w $(SRC)
|
||||
|
||||
lint:
|
||||
$Q golangci-lint run --timeout=30m
|
||||
|
|
|
@ -54,7 +54,7 @@ Setting up a *public key infrastructure* (PKI) is out of reach for many small te
|
|||
- [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation
|
||||
- Capable of high availability (HA) deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries
|
||||
- Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca)
|
||||
- [Badger, BoltDB, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
|
||||
- [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
|
||||
|
||||
### ⚙️ Many ways to automate
|
||||
|
||||
|
@ -68,6 +68,7 @@ You can issue certificates in exchange for:
|
|||
- [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure
|
||||
- [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
|
||||
- A trusted X.509 certificate (X5C provisioner)
|
||||
- A host certificate from your Nebula network
|
||||
- A SCEP challenge (SCEP provisioner)
|
||||
- An SSH host certificates needing renewal (the SSHPOP provisioner)
|
||||
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners)
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
"time"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
)
|
||||
|
||||
// Account is a subset of the internal account type containing only those
|
||||
|
@ -43,15 +45,63 @@ func KeyToID(jwk *jose.JSONWebKey) (string, error) {
|
|||
return base64.RawURLEncoding.EncodeToString(kid), nil
|
||||
}
|
||||
|
||||
// PolicyNames contains ACME account level policy names
|
||||
type PolicyNames struct {
|
||||
DNSNames []string `json:"dns"`
|
||||
IPRanges []string `json:"ips"`
|
||||
}
|
||||
|
||||
// X509Policy contains ACME account level X.509 policy
|
||||
type X509Policy struct {
|
||||
Allowed PolicyNames `json:"allow"`
|
||||
Denied PolicyNames `json:"deny"`
|
||||
AllowWildcardNames bool `json:"allowWildcardNames"`
|
||||
}
|
||||
|
||||
// Policy is an ACME Account level policy
|
||||
type Policy struct {
|
||||
X509 X509Policy `json:"x509"`
|
||||
}
|
||||
|
||||
func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &policy.X509NameOptions{
|
||||
DNSDomains: p.X509.Allowed.DNSNames,
|
||||
IPRanges: p.X509.Allowed.IPRanges,
|
||||
}
|
||||
}
|
||||
func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &policy.X509NameOptions{
|
||||
DNSDomains: p.X509.Denied.DNSNames,
|
||||
IPRanges: p.X509.Denied.IPRanges,
|
||||
}
|
||||
}
|
||||
|
||||
// AreWildcardNamesAllowed returns if wildcard names
|
||||
// like *.example.com are allowed to be signed.
|
||||
// Defaults to false.
|
||||
func (p *Policy) AreWildcardNamesAllowed() bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
return p.X509.AllowWildcardNames
|
||||
}
|
||||
|
||||
// ExternalAccountKey is an ACME External Account Binding key.
|
||||
type ExternalAccountKey struct {
|
||||
ID string `json:"id"`
|
||||
ProvisionerID string `json:"provisionerID"`
|
||||
Reference string `json:"reference"`
|
||||
AccountID string `json:"-"`
|
||||
KeyBytes []byte `json:"-"`
|
||||
HmacKey []byte `json:"-"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
BoundAt time.Time `json:"boundAt,omitempty"`
|
||||
Policy *Policy `json:"policy,omitempty"`
|
||||
}
|
||||
|
||||
// AlreadyBound returns whether this EAK is already bound to
|
||||
|
@ -68,6 +118,6 @@ func (eak *ExternalAccountKey) BindTo(account *Account) error {
|
|||
}
|
||||
eak.AccountID = account.ID
|
||||
eak.BoundAt = time.Now()
|
||||
eak.KeyBytes = []byte{} // clearing the key bytes; can only be used once
|
||||
eak.HmacKey = []byte{} // clearing the key bytes; can only be used once
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -7,8 +7,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func TestKeyToID(t *testing.T) {
|
||||
|
@ -95,7 +96,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
},
|
||||
acct: &Account{
|
||||
ID: "accountID",
|
||||
|
@ -108,7 +109,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
AccountID: "someAccountID",
|
||||
BoundAt: boundAt,
|
||||
},
|
||||
|
@ -138,7 +139,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
|
|||
assert.Equals(t, ae.Subproblems, tt.err.Subproblems)
|
||||
} else {
|
||||
assert.Equals(t, eak.AccountID, acct.ID)
|
||||
assert.Equals(t, eak.KeyBytes, []byte{})
|
||||
assert.Equals(t, eak.HmacKey, []byte{})
|
||||
assert.NotNil(t, eak.BoundAt)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -67,8 +67,11 @@ func (u *UpdateAccountRequest) Validate() error {
|
|||
}
|
||||
|
||||
// NewAccount is the handler resource for creating new ACME accounts.
|
||||
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -114,7 +117,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
eak, err := h.validateExternalAccountBinding(ctx, &nar)
|
||||
eak, err := validateExternalAccountBinding(ctx, &nar)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -125,18 +128,17 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
Contact: nar.Contact,
|
||||
Status: acme.StatusValid,
|
||||
}
|
||||
if err := h.db.CreateAccount(ctx, acc); err != nil {
|
||||
if err := db.CreateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
|
||||
return
|
||||
}
|
||||
|
||||
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
|
||||
err := eak.BindTo(acc)
|
||||
if err != nil {
|
||||
if err := eak.BindTo(acc); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
||||
if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
|
||||
return
|
||||
}
|
||||
|
@ -147,15 +149,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
httpStatus = http.StatusOK
|
||||
}
|
||||
|
||||
h.linker.LinkAccount(ctx, acc)
|
||||
linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
|
||||
w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
|
||||
render.JSONStatus(w, acc, httpStatus)
|
||||
}
|
||||
|
||||
// GetOrUpdateAccount is the api for updating an ACME account.
|
||||
func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -187,16 +192,16 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
|||
acc.Contact = uar.Contact
|
||||
}
|
||||
|
||||
if err := h.db.UpdateAccount(ctx, acc); err != nil {
|
||||
if err := db.UpdateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h.linker.LinkAccount(ctx, acc)
|
||||
linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
|
||||
render.JSON(w, acc)
|
||||
}
|
||||
|
||||
|
@ -210,8 +215,11 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
|||
}
|
||||
|
||||
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
|
||||
func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||
func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -222,13 +230,14 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
||||
return
|
||||
}
|
||||
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
|
||||
|
||||
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkOrdersByAccountID(ctx, orders)
|
||||
linker.LinkOrdersByAccountID(ctx, orders)
|
||||
|
||||
render.JSON(w, orders)
|
||||
logOrdersByAccount(w, orders)
|
||||
|
|
|
@ -13,10 +13,12 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -29,6 +31,22 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
type fakeProvisioner struct{}
|
||||
|
||||
func (*fakeProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeProvisioner) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*fakeProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { return nil }
|
||||
func (*fakeProvisioner) GetID() string { return "" }
|
||||
func (*fakeProvisioner) GetName() string { return "" }
|
||||
func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 }
|
||||
func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil }
|
||||
|
||||
func newProv() acme.Provisioner {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
|
@ -41,6 +59,19 @@ func newProv() acme.Provisioner {
|
|||
return p
|
||||
}
|
||||
|
||||
func newProvWithOptions(options *provisioner.Options) acme.Provisioner {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-<test>provisioner.com",
|
||||
Options: options,
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func newACMEProv(t *testing.T) *provisioner.ACME {
|
||||
p := newProv()
|
||||
a, ok := p.(*provisioner.ACME)
|
||||
|
@ -50,6 +81,15 @@ func newACMEProv(t *testing.T) *provisioner.ACME {
|
|||
return a
|
||||
}
|
||||
|
||||
func newACMEProvWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME {
|
||||
p := newProvWithOptions(options)
|
||||
a, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
t.Fatal("not a valid ACME provisioner")
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func createEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) (*jose.JSONWebSignature, error) {
|
||||
signer, err := jose.NewSigner(
|
||||
jose.SigningKey{
|
||||
|
@ -296,10 +336,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: accID}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
|
||||
|
@ -315,11 +354,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrdersByAccountID(w, req)
|
||||
GetOrdersByAccountID(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -363,6 +402,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -371,6 +411,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
"fail/nil-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -379,6 +420,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "failed to "+
|
||||
|
@ -393,6 +435,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||
|
@ -405,8 +448,9 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -418,9 +462,10 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jwk expected in request context"),
|
||||
|
@ -432,10 +477,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jwk expected in request context"),
|
||||
|
@ -454,9 +500,9 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"),
|
||||
|
@ -471,7 +517,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -501,18 +547,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
scepProvisioner := &provisioner.SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "test@scep-<test>provisioner.com",
|
||||
}
|
||||
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
|
||||
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"),
|
||||
|
@ -551,14 +590,13 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
eak := &acme.ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
return test{
|
||||
|
@ -599,8 +637,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
||||
|
@ -635,11 +672,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
Status: acme.StatusValid,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
acc: acc,
|
||||
statusCode: 200,
|
||||
|
@ -664,8 +701,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = false
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
|
||||
|
@ -719,8 +755,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -735,7 +770,7 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
},
|
||||
|
@ -759,11 +794,11 @@ func TestHandler_NewAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.NewAccount(w, req)
|
||||
NewAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -814,6 +849,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -822,6 +858,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -830,6 +867,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
"fail/no-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -839,6 +877,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -848,6 +887,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
|
||||
|
@ -862,6 +902,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
ctx := context.WithValue(context.Background(), accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
|
||||
|
@ -894,10 +935,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||
|
@ -914,11 +954,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
uar := &UpdateAccountRequest{}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
|
@ -929,10 +969,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
|
||||
|
@ -946,11 +985,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok/post-as-get": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
|
@ -959,11 +998,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrUpdateAccount(w, req)
|
||||
GetOrUpdateAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
|
|
@ -4,8 +4,9 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
)
|
||||
|
||||
// ExternalAccountBinding represents the ACME externalAccountBinding JWS
|
||||
|
@ -16,7 +17,7 @@ type ExternalAccountBinding struct {
|
|||
}
|
||||
|
||||
// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account.
|
||||
func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
|
||||
func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context")
|
||||
|
@ -47,7 +48,8 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc
|
|||
return nil, acmeErr
|
||||
}
|
||||
|
||||
externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
||||
if err != nil {
|
||||
if _, ok := err.(*acme.Error); ok {
|
||||
return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key")
|
||||
|
@ -55,11 +57,19 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc
|
|||
return nil, acme.WrapErrorISE(err, "error retrieving external account key")
|
||||
}
|
||||
|
||||
if externalAccountKey == nil {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key")
|
||||
}
|
||||
|
||||
if len(externalAccountKey.HmacKey) == 0 {
|
||||
return nil, acme.NewError(acme.ErrorServerInternalType, "external account binding key with id '%s' does not have secret bytes", keyID)
|
||||
}
|
||||
|
||||
if externalAccountKey.AlreadyBound() {
|
||||
return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt)
|
||||
}
|
||||
|
||||
payload, err := eabJWS.Verify(externalAccountKey.KeyBytes)
|
||||
payload, err := eabJWS.Verify(externalAccountKey.HmacKey)
|
||||
if err != nil {
|
||||
return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature")
|
||||
}
|
||||
|
@ -97,12 +107,12 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool {
|
|||
|
||||
// validateEABJWS verifies the contents of the External Account Binding JWS.
|
||||
// The protected header of the JWS MUST meet the following criteria:
|
||||
// o The "alg" field MUST indicate a MAC-based algorithm
|
||||
// o The "kid" field MUST contain the key identifier provided by the CA
|
||||
// o The "nonce" field MUST NOT be present
|
||||
// o The "url" field MUST be set to the same value as the outer JWS
|
||||
//
|
||||
// - The "alg" field MUST indicate a MAC-based algorithm
|
||||
// - The "kid" field MUST contain the key identifier provided by the CA
|
||||
// - The "nonce" field MUST NOT be present
|
||||
// - The "url" field MUST be set to the same value as the outer JWS
|
||||
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
|
||||
|
||||
if jws == nil {
|
||||
return "", acme.NewErrorISE("no JWS provided")
|
||||
}
|
||||
|
|
|
@ -9,10 +9,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
func Test_keysAreEqual(t *testing.T) {
|
||||
|
@ -98,8 +99,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
prov := newACMEProv(t)
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
|
@ -143,8 +143,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
createdAt := time.Now()
|
||||
return test{
|
||||
|
@ -154,7 +153,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: createdAt,
|
||||
}, nil
|
||||
},
|
||||
|
@ -168,7 +167,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: createdAt,
|
||||
},
|
||||
err: nil,
|
||||
|
@ -189,17 +188,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
scepProvisioner := &provisioner.SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "test@scep-<test>provisioner.com",
|
||||
}
|
||||
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
|
||||
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"),
|
||||
|
@ -218,8 +210,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
|
@ -264,8 +255,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
|
@ -310,8 +300,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -358,8 +347,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -408,8 +396,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -426,6 +413,112 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
err: acme.NewErrorISE("error retrieving external account key"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetExternalAccountKey-nil": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
|
||||
rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
|
||||
assert.FatalError(t, err)
|
||||
eab := &ExternalAccountBinding{}
|
||||
err = json.Unmarshal(rawEABJWS, &eab)
|
||||
assert.FatalError(t, err)
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
ExternalAccountBinding: eab,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
|
||||
so.WithHeader("url", url)
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||
Key: jwk.Key,
|
||||
}, so)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := signer.Sign(payloadBytes)
|
||||
assert.FatalError(t, err)
|
||||
raw, err := jws.CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
nar: &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
ExternalAccountBinding: eab,
|
||||
},
|
||||
eak: nil,
|
||||
err: acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetExternalAccountKey-no-keybytes": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
|
||||
rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
|
||||
assert.FatalError(t, err)
|
||||
eab := &ExternalAccountBinding{}
|
||||
err = json.Unmarshal(rawEABJWS, &eab)
|
||||
assert.FatalError(t, err)
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
ExternalAccountBinding: eab,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
|
||||
so.WithHeader("url", url)
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||
Key: jwk.Key,
|
||||
}, so)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := signer.Sign(payloadBytes)
|
||||
assert.FatalError(t, err)
|
||||
raw, err := jws.CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
createdAt := time.Now()
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
return &acme.ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
CreatedAt: createdAt,
|
||||
AccountID: "some-account-id",
|
||||
HmacKey: []byte{},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
nar: &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
ExternalAccountBinding: eab,
|
||||
},
|
||||
eak: nil,
|
||||
err: acme.NewError(acme.ErrorServerInternalType, "external account binding key with id 'eakID' does not have secret bytes"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetExternalAccountKey-wrong-provisioner": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
|
@ -458,8 +551,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -506,8 +598,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
createdAt := time.Now()
|
||||
boundAt := time.Now().Add(1 * time.Second)
|
||||
|
@ -520,6 +611,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
Reference: "testeak",
|
||||
CreatedAt: createdAt,
|
||||
AccountID: "some-account-id",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
BoundAt: boundAt,
|
||||
}, nil
|
||||
},
|
||||
|
@ -565,8 +657,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -575,7 +666,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
KeyBytes: []byte{1, 2, 3, 4},
|
||||
HmacKey: []byte{1, 2, 3, 4},
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
},
|
||||
|
@ -623,8 +714,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -633,7 +723,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
},
|
||||
|
@ -678,8 +768,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -688,7 +777,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
},
|
||||
|
@ -734,8 +823,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
ctx := context.WithValue(context.Background(), jwkContextKey, nil)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -744,7 +832,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
ID: "eakID",
|
||||
ProvisionerID: provID,
|
||||
Reference: "testeak",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
},
|
||||
|
@ -762,10 +850,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
db: tc.db,
|
||||
}
|
||||
got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar)
|
||||
ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
|
||||
got, err := validateExternalAccountBinding(ctx, tc.nar)
|
||||
wantErr := tc.err != nil
|
||||
gotErr := err != nil
|
||||
if wantErr != gotErr {
|
||||
|
@ -787,7 +873,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
|
|||
} else {
|
||||
assert.NotNil(t, tc.eak)
|
||||
assert.Equals(t, got.ID, tc.eak.ID)
|
||||
assert.Equals(t, got.KeyBytes, tc.eak.KeyBytes)
|
||||
assert.Equals(t, got.HmacKey, tc.eak.HmacKey)
|
||||
assert.Equals(t, got.ProvisionerID, tc.eak.ProvisionerID)
|
||||
assert.Equals(t, got.Reference, tc.eak.Reference)
|
||||
assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt)
|
||||
|
|
|
@ -2,12 +2,10 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -16,6 +14,7 @@ import (
|
|||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
|
@ -39,111 +38,152 @@ type payloadInfo struct {
|
|||
isEmptyJSON bool
|
||||
}
|
||||
|
||||
// Handler is the ACME API request handler.
|
||||
type Handler struct {
|
||||
db acme.DB
|
||||
backdate provisioner.Duration
|
||||
ca acme.CertificateAuthority
|
||||
linker Linker
|
||||
validateChallengeOptions *acme.ValidateChallengeOptions
|
||||
prerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
}
|
||||
|
||||
// HandlerOptions required to create a new ACME API request handler.
|
||||
type HandlerOptions struct {
|
||||
Backdate provisioner.Duration
|
||||
// DB storage backend that impements the acme.DB interface.
|
||||
// DB storage backend that implements the acme.DB interface.
|
||||
//
|
||||
// Deprecated: use acme.NewContex(context.Context, acme.DB)
|
||||
DB acme.DB
|
||||
|
||||
// CA is the certificate authority interface.
|
||||
//
|
||||
// Deprecated: use authority.NewContext(context.Context, *authority.Authority)
|
||||
CA acme.CertificateAuthority
|
||||
|
||||
// Backdate is the duration that the CA will subtract from the current time
|
||||
// to set the NotBefore in the certificate.
|
||||
Backdate provisioner.Duration
|
||||
|
||||
// DNS the host used to generate accurate ACME links. By default the authority
|
||||
// will use the Host from the request, so this value will only be used if
|
||||
// request.Host is empty.
|
||||
DNS string
|
||||
|
||||
// Prefix is a URL path prefix under which the ACME api is served. This
|
||||
// prefix is required to generate accurate ACME links.
|
||||
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
|
||||
// "acme" is the prefix from which the ACME api is accessed.
|
||||
Prefix string
|
||||
CA acme.CertificateAuthority
|
||||
|
||||
// PrerequisitesChecker checks if all prerequisites for serving ACME are
|
||||
// met by the CA configuration.
|
||||
PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
}
|
||||
|
||||
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// handler is the ACME API request handler.
|
||||
type handler struct {
|
||||
opts *HandlerOptions
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface. For backward compatibility
|
||||
// this route adds will add a new middleware that will set the ACME components
|
||||
// on the context.
|
||||
//
|
||||
// Note: this method is deprecated in step-ca, other applications can still use
|
||||
// this to support ACME, but the recommendation is to use use
|
||||
// api.Route(api.Router) and acme.NewContext() instead.
|
||||
func (h *handler) Route(r api.Router) {
|
||||
client := acme.NewClient()
|
||||
linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix)
|
||||
route(r, func(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
|
||||
ctx = authority.NewContext(ctx, ca)
|
||||
}
|
||||
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// NewHandler returns a new ACME API handler.
|
||||
func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
prerequisitesChecker := func(ctx context.Context) (bool, error) {
|
||||
// by default all prerequisites are met
|
||||
return true, nil
|
||||
}
|
||||
if ops.PrerequisitesChecker != nil {
|
||||
prerequisitesChecker = ops.PrerequisitesChecker
|
||||
}
|
||||
return &Handler{
|
||||
ca: ops.CA,
|
||||
db: ops.DB,
|
||||
backdate: ops.Backdate,
|
||||
linker: NewLinker(ops.DNS, ops.Prefix),
|
||||
validateChallengeOptions: &acme.ValidateChallengeOptions{
|
||||
HTTPGet: client.Get,
|
||||
LookupTxt: net.LookupTXT,
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(dialer, network, addr, config)
|
||||
},
|
||||
},
|
||||
prerequisitesChecker: prerequisitesChecker,
|
||||
//
|
||||
// Note: this method is deprecated in step-ca, other applications can still use
|
||||
// this to support ACME, but the recommendation is to use use
|
||||
// api.Route(api.Router) and acme.NewContext() instead.
|
||||
func NewHandler(opts HandlerOptions) api.RouterHandler {
|
||||
return &handler{
|
||||
opts: &opts,
|
||||
}
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
getPath := h.linker.GetUnescapedPathSuffix
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
||||
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
||||
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
||||
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
||||
// Route traffic and implement the Router interface. This method requires that
|
||||
// all the acme components, authority, db, client, linker, and prerequisite
|
||||
// checker to be present in the context.
|
||||
func Route(r api.Router) {
|
||||
route(r, nil)
|
||||
}
|
||||
|
||||
func route(r api.Router, middleware func(next nextHTTP) nextHTTP) {
|
||||
commonMiddleware := func(next nextHTTP) nextHTTP {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
// Linker middleware gets the provisioner and current url from the
|
||||
// request and sets them in the context.
|
||||
linker := acme.MustLinkerFromContext(r.Context())
|
||||
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
|
||||
}
|
||||
if middleware != nil {
|
||||
handler = middleware(handler)
|
||||
}
|
||||
return handler
|
||||
}
|
||||
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
||||
return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))))
|
||||
return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
|
||||
}
|
||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
|
||||
return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))
|
||||
return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
|
||||
return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next)))
|
||||
return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
|
||||
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
|
||||
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
|
||||
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented))
|
||||
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder))
|
||||
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
|
||||
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID)))
|
||||
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
|
||||
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
|
||||
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
|
||||
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
||||
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert))
|
||||
getPath := acme.GetUnescapedPathSuffix
|
||||
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||
r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||
r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||
commonMiddleware(GetDirectory))
|
||||
r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||
commonMiddleware(GetDirectory))
|
||||
|
||||
r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
|
||||
extractPayloadByJWK(NewAccount))
|
||||
r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(GetOrUpdateAccount))
|
||||
r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(NotImplemented))
|
||||
r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
|
||||
extractPayloadByKid(NewOrder))
|
||||
r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetOrder)))
|
||||
r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
|
||||
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
|
||||
extractPayloadByKid(FinalizeOrder))
|
||||
r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetAuthorization)))
|
||||
r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
|
||||
extractPayloadByKid(GetChallenge))
|
||||
r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetCertificate)))
|
||||
r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
|
||||
extractPayloadByKidOrJWK(RevokeCert))
|
||||
}
|
||||
|
||||
// GetNonce just sets the right header since a Nonce is added to each response
|
||||
// by middleware by default.
|
||||
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||
func GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
|
@ -179,7 +219,7 @@ func (d *Directory) ToLog() (interface{}, error) {
|
|||
|
||||
// GetDirectory is the ACME resource for returning a directory configuration
|
||||
// for client configuration.
|
||||
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
func GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
|
@ -187,12 +227,13 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
render.JSON(w, &Directory{
|
||||
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
|
||||
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
|
||||
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
|
||||
RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType),
|
||||
KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType),
|
||||
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
|
||||
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
|
||||
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
|
||||
RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
|
||||
KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
|
||||
Meta: Meta{
|
||||
ExternalAccountRequired: acmeProv.RequireEAB,
|
||||
},
|
||||
|
@ -201,19 +242,22 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
||||
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
||||
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
|
||||
func NotImplemented(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
||||
}
|
||||
|
||||
// GetAuthorization ACME api for retrieving an Authz.
|
||||
func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
func GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||
az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
|
||||
return
|
||||
|
@ -223,20 +267,23 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
|||
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
|
||||
return
|
||||
}
|
||||
if err = az.UpdateStatus(ctx, h.db); err != nil {
|
||||
if err = az.UpdateStatus(ctx, db); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkAuthorization(ctx, az)
|
||||
linker.LinkAuthorization(ctx, az)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
|
||||
render.JSON(w, az)
|
||||
}
|
||||
|
||||
// GetChallenge ACME api for retrieving a Challenge.
|
||||
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -257,7 +304,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
|||
// we'll just ignore the body.
|
||||
|
||||
azID := chi.URLParam(r, "authzID")
|
||||
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||
ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
|
||||
return
|
||||
|
@ -273,29 +320,31 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
|
||||
if err = ch.Validate(ctx, db, jwk); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkChallenge(ctx, ch, azID)
|
||||
linker.LinkChallenge(ctx, ch, azID)
|
||||
|
||||
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
|
||||
render.JSON(w, ch)
|
||||
}
|
||||
|
||||
// GetCertificate ACME api for retrieving a Certificate.
|
||||
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
func GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
certID := chi.URLParam(r, "certID")
|
||||
|
||||
cert, err := h.db.GetCertificate(ctx, certID)
|
||||
certID := chi.URLParam(r, "certID")
|
||||
cert, err := db.GetCertificate(ctx, certID)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
|
||||
return
|
||||
|
|
|
@ -3,6 +3,7 @@ package api
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
|
@ -19,11 +20,33 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
get func(url string) (*http.Response, error)
|
||||
lookupTxt func(name string) ([]string, error)
|
||||
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) }
|
||||
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
|
||||
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return m.tlsDial(network, addr, config)
|
||||
}
|
||||
|
||||
func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) {
|
||||
t.Helper()
|
||||
fn := mustAuthority
|
||||
t.Cleanup(func() {
|
||||
mustAuthority = fn
|
||||
})
|
||||
mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_GetNonce(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -38,10 +61,10 @@ func TestHandler_GetNonce(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
// h := &Handler{}
|
||||
w := httptest.NewRecorder()
|
||||
req.Method = tt.name
|
||||
h.GetNonce(w, req)
|
||||
GetNonce(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -52,7 +75,8 @@ func TestHandler_GetNonce(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandler_GetDirectory(t *testing.T) {
|
||||
linker := NewLinker("ca.smallstep.com", "acme")
|
||||
linker := acme.NewLinker("ca.smallstep.com", "acme")
|
||||
_ = linker
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
|
@ -61,23 +85,14 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"),
|
||||
err: acme.NewErrorISE("provisioner is not in context"),
|
||||
}
|
||||
},
|
||||
"fail/different-provisioner": func(t *testing.T) test {
|
||||
prov := &provisioner.SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "test@scep-<test>provisioner.com",
|
||||
}
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
|
@ -88,8 +103,7 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
expDir := Directory{
|
||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||
|
@ -108,8 +122,7 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
prov.RequireEAB = true
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
expDir := Directory{
|
||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||
|
@ -130,11 +143,11 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: linker}
|
||||
ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetDirectory(w, req)
|
||||
GetDirectory(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -219,7 +232,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
|
@ -285,10 +298,9 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
|
||||
|
@ -304,11 +316,11 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAuthorization(w, req)
|
||||
GetAuthorization(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -447,11 +459,11 @@ func TestHandler_GetCertificate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetCertificate(w, req)
|
||||
GetCertificate(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -491,7 +503,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
|
||||
type test struct {
|
||||
db acme.DB
|
||||
vco *acme.ValidateChallengeOptions
|
||||
vc acme.Client
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
ch *acme.Challenge
|
||||
|
@ -500,6 +512,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -507,6 +520,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), accContextKey, nil),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -516,6 +530,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -523,10 +538,11 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload expected in request context"),
|
||||
|
@ -534,7 +550,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/db.GetChallenge-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -553,7 +569,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -572,7 +588,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/no-jwk": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -591,7 +607,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/nil-jwk": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||
|
@ -611,7 +627,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"fail/validate-challenge-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
|
@ -639,8 +655,8 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
return acme.NewErrorISE("force")
|
||||
},
|
||||
},
|
||||
vco: &acme.ValidateChallengeOptions{
|
||||
HTTPGet: func(string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -651,14 +667,13 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
_pub := _jwk.Public()
|
||||
ctx = context.WithValue(ctx, jwkContextKey, &_pub)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -690,8 +705,8 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
URL: u,
|
||||
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
||||
},
|
||||
vco: &acme.ValidateChallengeOptions{
|
||||
HTTPGet: func(string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -703,11 +718,11 @@ func TestHandler_GetChallenge(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
||||
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetChallenge(w, req)
|
||||
GetChallenge(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
|
||||
|
@ -31,39 +30,11 @@ func logNonce(w http.ResponseWriter, nonce string) {
|
|||
}
|
||||
}
|
||||
|
||||
// baseURLFromRequest determines the base URL which should be used for
|
||||
// constructing link URLs in e.g. the ACME directory result by taking the
|
||||
// request Host into consideration.
|
||||
//
|
||||
// If the Request.Host is an empty string, we return an empty string, to
|
||||
// indicate that the configured URL values should be used instead. If this
|
||||
// function returns a non-empty result, then this should be used in
|
||||
// constructing ACME link URLs.
|
||||
func baseURLFromRequest(r *http.Request) *url.URL {
|
||||
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
|
||||
// for an implementation that allows HTTP requests using the x-forwarded-proto
|
||||
// header.
|
||||
|
||||
if r.Host == "" {
|
||||
return nil
|
||||
}
|
||||
return &url.URL{Scheme: "https", Host: r.Host}
|
||||
}
|
||||
|
||||
// baseURLFromRequest is a middleware that extracts and caches the baseURL
|
||||
// from the request.
|
||||
// E.g. https://ca.smallstep.com/
|
||||
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// addNonce is a middleware that adds a nonce to the response header.
|
||||
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||
func addNonce(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
nonce, err := h.db.CreateNonce(r.Context())
|
||||
db := acme.MustDatabaseFromContext(r.Context())
|
||||
nonce, err := db.CreateNonce(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -77,25 +48,31 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
|||
|
||||
// addDirLink is a middleware that adds a 'Link' response reader with the
|
||||
// directory index url.
|
||||
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
||||
func addDirLink(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index"))
|
||||
ctx := r.Context()
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// verifyContentType is a middleware that verifies that content type is
|
||||
// application/jose+json.
|
||||
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||
func verifyContentType(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var expected []string
|
||||
p, err := provisionerFromContext(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")}
|
||||
u := &url.URL{
|
||||
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
|
||||
}
|
||||
|
||||
var expected []string
|
||||
if strings.Contains(r.URL.String(), u.EscapedPath()) {
|
||||
// GET /certificate requests allow a greater range of content types.
|
||||
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||
|
@ -117,7 +94,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
|
||||
func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||
func parseJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
|
@ -142,17 +119,19 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
|||
// The JWS Unprotected Header [RFC7515] MUST NOT be used
|
||||
// The JWS Payload MUST NOT be detached
|
||||
// The JWS Protected Header MUST include the following fields:
|
||||
// * “alg” (Algorithm)
|
||||
// * This field MUST NOT contain “none” or a Message Authentication Code
|
||||
// (MAC) algorithm (e.g. one in which the algorithm registry description
|
||||
// mentions MAC/HMAC).
|
||||
// * “nonce” (defined in Section 6.5)
|
||||
// * “url” (defined in Section 6.4)
|
||||
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||
// - “alg” (Algorithm).
|
||||
// This field MUST NOT contain “none” or a Message Authentication Code
|
||||
// (MAC) algorithm (e.g. one in which the algorithm registry description
|
||||
// mentions MAC/HMAC).
|
||||
// - “nonce” (defined in Section 6.5)
|
||||
// - “url” (defined in Section 6.4)
|
||||
// - Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||
func validateJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(r.Context())
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -202,7 +181,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// Check the validity/freshness of the Nonce.
|
||||
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||
if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
@ -235,10 +214,12 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
|||
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
|
||||
// in the context. Make sure to parse and validate the JWS before running this
|
||||
// middleware.
|
||||
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||
func extractJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(r.Context())
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -264,7 +245,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
|||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
|
||||
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key
|
||||
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||
acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||
switch {
|
||||
case errors.Is(err, acme.ErrNotFound):
|
||||
// For NewAccount and Revoke requests ...
|
||||
|
@ -283,63 +264,44 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
|||
}
|
||||
}
|
||||
|
||||
// lookupProvisioner loads the provisioner associated with the request.
|
||||
// Responds 404 if the provisioner does not exist.
|
||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||
name, err := url.PathUnescape(nameEscaped)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
return
|
||||
}
|
||||
p, err := h.ca.LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
acmeProv, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// checkPrerequisites checks if all prerequisites for serving ACME
|
||||
// are met by the CA configuration.
|
||||
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
|
||||
func checkPrerequisites(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ok, err := h.prerequisitesChecker(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||
return
|
||||
// If the function is not set assume that all prerequisites are met.
|
||||
checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
|
||||
if ok {
|
||||
ok, err := checkFunc(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
return
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
return
|
||||
}
|
||||
next(w, r.WithContext(ctx))
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// lookupJWK loads the JWK associated with the acme account referenced by the
|
||||
// kid parameter of the signed payload.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||
func lookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
|
||||
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
|
||||
kid := jws.Signatures[0].Protected.KeyID
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
|
@ -349,7 +311,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
accID := strings.TrimPrefix(kid, kidPrefix)
|
||||
acc, err := h.db.GetAccount(ctx, accID)
|
||||
acc, err := db.GetAccount(ctx, accID)
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||
|
@ -372,7 +334,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
|||
|
||||
// extractOrLookupJWK forwards handling to either extractJWK or
|
||||
// lookupJWK based on the presence of a JWK or a KID, respectively.
|
||||
func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||
func extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
|
@ -385,13 +347,13 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
|||
// and it can be used to check if a JWK exists. This flow is used when the ACME client
|
||||
// signed the payload with a certificate private key.
|
||||
if canExtractJWKFrom(jws) {
|
||||
h.extractJWK(next)(w, r)
|
||||
extractJWK(next)(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// default to looking up the JWK based on KeyID. This flow is used when the ACME client
|
||||
// signed the payload with an account private key.
|
||||
h.lookupJWK(next)(w, r)
|
||||
lookupJWK(next)(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -408,7 +370,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
|
|||
|
||||
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
|
@ -440,7 +402,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
|
||||
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
||||
func isPostAsGet(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, err := payloadFromContext(r.Context())
|
||||
if err != nil {
|
||||
|
@ -462,16 +424,12 @@ type ContextKey string
|
|||
const (
|
||||
// accContextKey account key
|
||||
accContextKey = ContextKey("acc")
|
||||
// baseURLContextKey baseURL key
|
||||
baseURLContextKey = ContextKey("baseURL")
|
||||
// jwsContextKey jws key
|
||||
jwsContextKey = ContextKey("jws")
|
||||
// jwkContextKey jwk key
|
||||
jwkContextKey = ContextKey("jwk")
|
||||
// payloadContextKey payload key
|
||||
payloadContextKey = ContextKey("payload")
|
||||
// provisionerContextKey provisioner key
|
||||
provisionerContextKey = ContextKey("provisioner")
|
||||
)
|
||||
|
||||
// accountFromContext searches the context for an ACME account. Returns the
|
||||
|
@ -484,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) {
|
|||
return val, nil
|
||||
}
|
||||
|
||||
// baseURLFromContext returns the baseURL if one is stored in the context.
|
||||
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
|
||||
if !ok || val == nil {
|
||||
return nil
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
|
||||
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
|
||||
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
|
||||
|
@ -514,29 +463,26 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
|
|||
// provisionerFromContext searches the context for a provisioner. Returns the
|
||||
// provisioner or an error.
|
||||
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
|
||||
val := ctx.Value(provisionerContextKey)
|
||||
if val == nil {
|
||||
p, ok := acme.ProvisionerFromContext(ctx)
|
||||
if !ok || p == nil {
|
||||
return nil, acme.NewErrorISE("provisioner expected in request context")
|
||||
}
|
||||
pval, ok := val.(acme.Provisioner)
|
||||
if !ok || pval == nil {
|
||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||
}
|
||||
return pval, nil
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
|
||||
// pointer to an ACME provisioner or an error.
|
||||
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
p, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
acmeProv, ok := prov.(*provisioner.ACME)
|
||||
if !ok || acmeProv == nil {
|
||||
ap, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||
}
|
||||
return acmeProv, nil
|
||||
|
||||
return ap, nil
|
||||
}
|
||||
|
||||
// payloadFromContext searches the context for a payload. Returns the payload
|
||||
|
|
|
@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) {
|
|||
w.Write(testBody)
|
||||
}
|
||||
|
||||
func Test_baseURLFromRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
targetURL string
|
||||
expectedResult *url.URL
|
||||
requestPreparer func(*http.Request)
|
||||
}{
|
||||
{
|
||||
"HTTPS host pass-through failed.",
|
||||
"https://my.dummy.host",
|
||||
&url.URL{Scheme: "https", Host: "my.dummy.host"},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Port pass-through failed",
|
||||
"https://host.with.port:8080",
|
||||
&url.URL{Scheme: "https", Host: "host.with.port:8080"},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Explicit host from Request.Host was not used.",
|
||||
"https://some.target.host:8080",
|
||||
&url.URL{Scheme: "https", Host: "proxied.host"},
|
||||
func(r *http.Request) {
|
||||
r.Host = "proxied.host"
|
||||
},
|
||||
},
|
||||
{
|
||||
"Missing Request.Host value did not result in empty string result.",
|
||||
"https://some.host",
|
||||
nil,
|
||||
func(r *http.Request) {
|
||||
r.Host = ""
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
request := httptest.NewRequest("GET", tc.targetURL, nil)
|
||||
if tc.requestPreparer != nil {
|
||||
tc.requestPreparer(request)
|
||||
}
|
||||
result := baseURLFromRequest(request)
|
||||
if result == nil || tc.expectedResult == nil {
|
||||
assert.Equals(t, result, tc.expectedResult)
|
||||
} else if result.String() != tc.expectedResult.String() {
|
||||
t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_baseURLFromRequest(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req.Host = "test.ca.smallstep.com:8080"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
next := func(w http.ResponseWriter, r *http.Request) {
|
||||
bu := baseURLFromContext(r.Context())
|
||||
if assert.NotNil(t, bu) {
|
||||
assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080")
|
||||
assert.Equals(t, bu.Scheme, "https")
|
||||
func newBaseContext(ctx context.Context, args ...interface{}) context.Context {
|
||||
for _, a := range args {
|
||||
switch v := a.(type) {
|
||||
case acme.DB:
|
||||
ctx = acme.NewDatabaseContext(ctx, v)
|
||||
case acme.Linker:
|
||||
ctx = acme.NewLinkerContext(ctx, v)
|
||||
case acme.PrerequisitesChecker:
|
||||
ctx = acme.NewPrerequisitesCheckerContext(ctx, v)
|
||||
}
|
||||
}
|
||||
|
||||
h.baseURLFromRequest(next)(w, req)
|
||||
|
||||
req = httptest.NewRequest("GET", "/foo", nil)
|
||||
req.Host = ""
|
||||
|
||||
next = func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, baseURLFromContext(r.Context()), nil)
|
||||
}
|
||||
|
||||
h.baseURLFromRequest(next)(w, req)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func TestHandler_addNonce(t *testing.T) {
|
||||
|
@ -139,10 +74,10 @@ func TestHandler_addNonce(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
ctx := newBaseContext(context.Background(), tc.db)
|
||||
req := httptest.NewRequest("GET", u, nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.addNonce(testNext)(w, req)
|
||||
addNonce(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) {
|
|||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
type test struct {
|
||||
link string
|
||||
linker Linker
|
||||
statusCode int
|
||||
ctx context.Context
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"ok": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
|
||||
statusCode: 200,
|
||||
|
@ -195,11 +128,10 @@ func TestHandler_addDirLink(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: tc.linker}
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.addDirLink(testNext)(w, req)
|
||||
addDirLink(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
||||
type test struct {
|
||||
h Handler
|
||||
ctx context.Context
|
||||
contentType string
|
||||
err *acme.Error
|
||||
|
@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/provisioner-not-set": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
url: u,
|
||||
ctx: context.Background(),
|
||||
contentType: "foo",
|
||||
|
@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
},
|
||||
"fail/general-bad-content-type": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
url: u,
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "foo",
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"),
|
||||
|
@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
},
|
||||
"fail/certificate-bad-content-type": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "foo",
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"),
|
||||
|
@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "application/jose+json",
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/certificate/pkix-cert": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "application/pkix-cert",
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/certificate/jose+json": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "application/jose+json",
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/certificate/pkcs7-mime": func(t *testing.T) test {
|
||||
return test{
|
||||
h: Handler{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
contentType: "application/pkcs7-mime",
|
||||
statusCode: 200,
|
||||
}
|
||||
|
@ -326,7 +236,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|||
req = req.WithContext(tc.ctx)
|
||||
req.Header.Add("Content-Type", tc.contentType)
|
||||
w := httptest.NewRecorder()
|
||||
tc.h.verifyContentType(testNext)(w, req)
|
||||
verifyContentType(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -390,11 +300,11 @@ func TestHandler_isPostAsGet(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
// h := &Handler{}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.isPostAsGet(testNext)(w, req)
|
||||
isPostAsGet(testNext)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -481,10 +391,10 @@ func TestHandler_parseJWS(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
// h := &Handler{}
|
||||
req := httptest.NewRequest("GET", u, tc.body)
|
||||
w := httptest.NewRecorder()
|
||||
h.parseJWS(tc.next)(w, req)
|
||||
parseJWS(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -679,11 +589,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
// h := &Handler{}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.verifyAndExtractJWSPayload(tc.next)(w, req)
|
||||
verifyAndExtractJWSPayload(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
type test struct {
|
||||
linker Linker
|
||||
linker acme.Linker
|
||||
db acme.DB
|
||||
ctx context.Context
|
||||
next func(http.ResponseWriter, *http.Request)
|
||||
|
@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
}
|
||||
},
|
||||
"fail/nil-jws": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
_jws, err := _signer.Sign([]byte("baz"))
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
|
||||
|
@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
_parsed, err := jose.ParseJWS(_raw)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _parsed)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{},
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
|
||||
}
|
||||
},
|
||||
"fail/account-not-found": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||
assert.Equals(t, accID, accID)
|
||||
|
@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/GetAccount-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
|
@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-valid": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "deactivated"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
|
@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "valid", Key: jwk}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, id, accID)
|
||||
|
@ -881,11 +791,11 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: tc.linker}
|
||||
ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.lookupJWK(tc.next)(w, req)
|
||||
lookupJWK(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
db: &acme.MockDB{},
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
}
|
||||
},
|
||||
"fail/nil-jws": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
|
||||
|
@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, _jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
|
||||
}
|
||||
},
|
||||
"fail/GetAccountByKey-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-valid": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "deactivated"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{Status: "valid"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok/no-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
|
@ -1077,11 +991,11 @@ func TestHandler_extractJWK(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
ctx := newBaseContext(tc.ctx, tc.db)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.extractJWK(tc.next)(w, req)
|
||||
extractJWK(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
"fail/nil-jws": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, nil),
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
"fail/no-signature": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
|
||||
|
@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
|
||||
|
@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
|
||||
|
@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
|
||||
|
@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
|
||||
|
@ -1444,11 +1365,11 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
ctx := newBaseContext(tc.ctx, tc.db)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.validateJWS(tc.next)(w, req)
|
||||
validateJWS(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
u := "https://ca.smallstep.com/acme/account"
|
||||
type test struct {
|
||||
db acme.DB
|
||||
linker Linker
|
||||
linker acme.Linker
|
||||
statusCode int
|
||||
ctx context.Context
|
||||
err *acme.Error
|
||||
|
@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
|
||||
assert.Equals(t, kid, pub.KeyID)
|
||||
|
@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
linker: NewLinker("test.ca.smallstep.com", "acme"),
|
||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
|
@ -1628,11 +1548,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: tc.linker}
|
||||
ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.extractOrLookupJWK(tc.next)(w, req)
|
||||
extractOrLookupJWK(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||
baseURL, provName)
|
||||
type test struct {
|
||||
linker Linker
|
||||
linker acme.Linker
|
||||
ctx context.Context
|
||||
prerequisitesChecker func(context.Context) (bool, error)
|
||||
next func(http.ResponseWriter, *http.Request)
|
||||
|
@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/prerequisites-nok": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
linker: acme.NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1717,11 +1634,11 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
|
||||
ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker)
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.checkPrerequisites(tc.next)(w, req)
|
||||
checkPrerequisites(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
|
|
@ -16,6 +16,8 @@ import (
|
|||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// NewOrderRequest represents the body for a NewOrder request.
|
||||
|
@ -37,6 +39,8 @@ func (n *NewOrderRequest) Validate() error {
|
|||
if id.Type == acme.IP && net.ParseIP(id.Value) == nil {
|
||||
return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value)
|
||||
}
|
||||
// TODO(hs): add some validations for DNS domains?
|
||||
// TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -50,7 +54,13 @@ type FinalizeRequest struct {
|
|||
// Validate validates a finalize request body.
|
||||
func (f *FinalizeRequest) Validate() error {
|
||||
var err error
|
||||
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
|
||||
// RFC 8555 isn't 100% conclusive about using raw base64-url encoding for the
|
||||
// CSR specifically, instead of "normal" base64-url encoding (incl. padding).
|
||||
// By trimming the padding from CSRs submitted by ACME clients that use
|
||||
// base64-url encoding instead of raw base64-url encoding, these are also
|
||||
// supported. This was reported in https://github.com/smallstep/certificates/issues/939
|
||||
// to be the case for a Synology DSM NAS system.
|
||||
csrBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(f.CSR, "="))
|
||||
if err != nil {
|
||||
return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr")
|
||||
}
|
||||
|
@ -68,8 +78,12 @@ var defaultOrderExpiry = time.Hour * 24
|
|||
var defaultOrderBackdate = time.Minute
|
||||
|
||||
// NewOrder ACME api for creating a new order.
|
||||
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
func NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ca := mustAuthority(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -85,6 +99,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var nor NewOrderRequest
|
||||
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
|
@ -97,6 +112,48 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// TODO(hs): gather all errors, so that we can build one response with ACME subproblems
|
||||
// include the nor.Validate() error here too, like in the example in the ACME RFC?
|
||||
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var eak *acme.ExternalAccountKey
|
||||
if acmeProv.RequireEAB {
|
||||
if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
acmePolicy, err := newACMEPolicyEngine(eak)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine"))
|
||||
return
|
||||
}
|
||||
|
||||
for _, identifier := range nor.Identifiers {
|
||||
// evaluate the ACME account level policy
|
||||
if err = isIdentifierAllowed(acmePolicy, identifier); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
// evaluate the provisioner level policy
|
||||
orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
|
||||
if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
// evaluate the authority level policy
|
||||
if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
// New order.
|
||||
o := &acme.Order{
|
||||
|
@ -117,7 +174,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
ExpiresAt: o.ExpiresAt,
|
||||
Status: acme.StatusPending,
|
||||
}
|
||||
if err := h.newAuthorization(ctx, az); err != nil {
|
||||
if err := newAuthorization(ctx, az); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
@ -136,18 +193,32 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
|
||||
}
|
||||
|
||||
if err := h.db.CreateOrder(ctx, o); err != nil {
|
||||
if err := db.CreateOrder(ctx, o); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkOrder(ctx, o)
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSONStatus(w, o, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error {
|
||||
if acmePolicy == nil {
|
||||
return nil
|
||||
}
|
||||
return acmePolicy.AreSANsAllowed([]string{identifier.Value})
|
||||
}
|
||||
|
||||
func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error) {
|
||||
if eak == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return policy.NewX509PolicyEngine(eak.Policy)
|
||||
}
|
||||
|
||||
func newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||
if strings.HasPrefix(az.Identifier.Value, "*.") {
|
||||
az.Wildcard = true
|
||||
az.Identifier = acme.Identifier{
|
||||
|
@ -163,6 +234,8 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
|
|||
if err != nil {
|
||||
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
|
||||
}
|
||||
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
az.Challenges = make([]*acme.Challenge, len(chTypes))
|
||||
for i, typ := range chTypes {
|
||||
ch := &acme.Challenge{
|
||||
|
@ -172,20 +245,23 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
|
|||
Token: az.Token,
|
||||
Status: acme.StatusPending,
|
||||
}
|
||||
if err := h.db.CreateChallenge(ctx, ch); err != nil {
|
||||
if err := db.CreateChallenge(ctx, ch); err != nil {
|
||||
return acme.WrapErrorISE(err, "error creating challenge")
|
||||
}
|
||||
az.Challenges[i] = ch
|
||||
}
|
||||
if err = h.db.CreateAuthorization(ctx, az); err != nil {
|
||||
if err = db.CreateAuthorization(ctx, az); err != nil {
|
||||
return acme.WrapErrorISE(err, "error creating authorization")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOrder ACME api for retrieving an order.
|
||||
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
func GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -196,7 +272,8 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
return
|
||||
|
@ -211,20 +288,23 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
|||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||
return
|
||||
}
|
||||
if err = o.UpdateStatus(ctx, h.db); err != nil {
|
||||
if err = o.UpdateStatus(ctx, db); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkOrder(ctx, o)
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
}
|
||||
|
||||
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
||||
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
// FinalizeOrder attempts to finalize an order and create a certificate.
|
||||
func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -251,7 +331,7 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
return
|
||||
|
@ -266,14 +346,16 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
|||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||
return
|
||||
}
|
||||
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
|
||||
|
||||
ca := mustAuthority(ctx)
|
||||
if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
|
||||
return
|
||||
}
|
||||
|
||||
h.linker.LinkOrder(ctx, o)
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
}
|
||||
|
||||
|
|
|
@ -16,9 +16,13 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"go.step.sm/crypto/pemutil"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
func TestNewOrderRequest_Validate(t *testing.T) {
|
||||
|
@ -206,6 +210,13 @@ func TestFinalizeRequestValidate(t *testing.T) {
|
|||
},
|
||||
}
|
||||
},
|
||||
"ok/padding": func(t *testing.T) test {
|
||||
return test{
|
||||
fr: &FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw) + "==", // add intentional padding
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
|
@ -276,15 +287,17 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
db: &acme.MockDB{},
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -294,6 +307,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -301,9 +315,10 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), nil)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -311,7 +326,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"fail/db.GetOrder-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
|
@ -325,7 +340,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
|
@ -341,7 +356,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"fail/provisioner-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
|
@ -357,7 +372,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"fail/order-update-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
|
@ -381,10 +396,9 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
|
||||
|
@ -421,11 +435,11 @@ func TestHandler_GetOrder(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrder(w, req)
|
||||
GetOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -636,8 +650,8 @@ func TestHandler_newAuthorization(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
h := &Handler{db: tc.db}
|
||||
if err := h.newAuthorization(context.Background(), tc.az); err != nil {
|
||||
ctx := newBaseContext(context.Background(), tc.db)
|
||||
if err := newAuthorization(ctx, tc.az); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
|
@ -667,6 +681,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
baseURL.String(), escProvName)
|
||||
|
||||
type test struct {
|
||||
ca acme.CertificateAuthority
|
||||
db acme.DB
|
||||
ctx context.Context
|
||||
nor *NewOrderRequest
|
||||
|
@ -677,15 +692,17 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
db: &acme.MockDB{},
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -695,6 +712,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -702,9 +720,10 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -713,8 +732,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
"fail/no-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
|
@ -722,21 +742,23 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("paylod does not exist"),
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"),
|
||||
|
@ -747,13 +769,230 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
fr := &NewOrderRequest{}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail/acmeProvisionerFromContext-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
fr := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "zap.internal"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), &acme.MockProvisioner{})
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: acme.NewErrorISE("error retrieving external account binding key: force"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetExternalAccountKeyByAccountID-error": func(t *testing.T) test {
|
||||
acmeProv := newACMEProv(t)
|
||||
acmeProv.RequireEAB = true
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
fr := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "zap.internal"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: acme.NewErrorISE("error retrieving external account binding key: force"),
|
||||
}
|
||||
},
|
||||
"fail/newACMEPolicyEngine-error": func(t *testing.T) test {
|
||||
acmeProv := newACMEProv(t)
|
||||
acmeProv.RequireEAB = true
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
fr := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "zap.internal"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return &acme.ExternalAccountKey{
|
||||
Policy: &acme.Policy{
|
||||
X509: acme.X509Policy{
|
||||
Allowed: acme.PolicyNames{
|
||||
DNSNames: []string{"**.local"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
err: acme.NewErrorISE("error creating ACME policy engine"),
|
||||
}
|
||||
},
|
||||
"fail/isIdentifierAllowed-error": func(t *testing.T) test {
|
||||
acmeProv := newACMEProv(t)
|
||||
acmeProv.RequireEAB = true
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
fr := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "zap.internal"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return &acme.ExternalAccountKey{
|
||||
Policy: &acme.Policy{
|
||||
X509: acme.X509Policy{
|
||||
Allowed: acme.PolicyNames{
|
||||
DNSNames: []string{"*.local"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
|
||||
}
|
||||
},
|
||||
"fail/prov.AuthorizeOrderIdentifier-error": func(t *testing.T) test {
|
||||
options := &provisioner.Options{
|
||||
X509: &provisioner.X509Options{
|
||||
AllowedNames: &policy.X509NameOptions{
|
||||
DNSDomains: []string{"*.local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
provWithPolicy := newACMEProvWithOptions(t, options)
|
||||
provWithPolicy.RequireEAB = true
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
fr := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "zap.internal"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return &acme.ExternalAccountKey{
|
||||
Policy: &acme.Policy{
|
||||
X509: acme.X509Policy{
|
||||
Allowed: acme.PolicyNames{
|
||||
DNSNames: []string{"*.internal"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
|
||||
}
|
||||
},
|
||||
"fail/ca.AreSANsAllowed-error": func(t *testing.T) test {
|
||||
options := &provisioner.Options{
|
||||
X509: &provisioner.X509Options{
|
||||
AllowedNames: &policy.X509NameOptions{
|
||||
DNSDomains: []string{"*.internal"},
|
||||
},
|
||||
},
|
||||
}
|
||||
provWithPolicy := newACMEProvWithOptions(t, options)
|
||||
provWithPolicy.RequireEAB = true
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
fr := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "zap.internal"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
ca: &mockCA{
|
||||
MockAreSANsallowed: func(ctx context.Context, sans []string) error {
|
||||
return errors.New("force: not authorized by authority")
|
||||
},
|
||||
},
|
||||
db: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return &acme.ExternalAccountKey{
|
||||
Policy: &acme.Policy{
|
||||
X509: acme.X509Policy{
|
||||
Allowed: acme.PolicyNames{
|
||||
DNSNames: []string{"*.internal"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
|
||||
}
|
||||
},
|
||||
"fail/error-h.newAuthorization": func(t *testing.T) test {
|
||||
|
@ -765,12 +1004,13 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
assert.Equals(t, ch.AccountID, "accID")
|
||||
|
@ -780,6 +1020,11 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
assert.Equals(t, ch.Value, "zap.internal")
|
||||
return errors.New("force")
|
||||
},
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: acme.NewErrorISE("error creating challenge: force"),
|
||||
}
|
||||
|
@ -793,7 +1038,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
var (
|
||||
|
@ -804,6 +1049,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
switch count {
|
||||
|
@ -849,6 +1095,11 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
|
||||
return errors.New("force")
|
||||
},
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: acme.NewErrorISE("error creating order: force"),
|
||||
}
|
||||
|
@ -863,10 +1114,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3, ch4 **acme.Challenge
|
||||
az1ID, az2ID *string
|
||||
|
@ -876,6 +1126,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
nor: nor,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
switch chCount {
|
||||
|
@ -945,6 +1196,11 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID})
|
||||
return nil
|
||||
},
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
vr: func(t *testing.T, o *acme.Order) {
|
||||
now := clock.Now()
|
||||
|
@ -978,10 +1234,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3 **acme.Challenge
|
||||
az1ID *string
|
||||
|
@ -991,6 +1246,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
nor: nor,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
switch count {
|
||||
|
@ -1037,6 +1293,11 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
|
||||
return nil
|
||||
},
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
vr: func(t *testing.T, o *acme.Order) {
|
||||
now := clock.Now()
|
||||
|
@ -1070,10 +1331,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3 **acme.Challenge
|
||||
az1ID *string
|
||||
|
@ -1083,6 +1343,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
nor: nor,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
switch count {
|
||||
|
@ -1129,6 +1390,11 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
|
||||
return nil
|
||||
},
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
vr: func(t *testing.T, o *acme.Order) {
|
||||
now := clock.Now()
|
||||
|
@ -1161,10 +1427,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3 **acme.Challenge
|
||||
az1ID *string
|
||||
|
@ -1174,6 +1439,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
nor: nor,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
switch count {
|
||||
|
@ -1220,6 +1486,11 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
|
||||
return nil
|
||||
},
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
vr: func(t *testing.T, o *acme.Order) {
|
||||
testBufferDur := 5 * time.Second
|
||||
|
@ -1253,10 +1524,9 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
var (
|
||||
ch1, ch2, ch3 **acme.Challenge
|
||||
az1ID *string
|
||||
|
@ -1266,6 +1536,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
nor: nor,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
switch count {
|
||||
|
@ -1312,11 +1583,119 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
|
||||
return nil
|
||||
},
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
vr: func(t *testing.T, o *acme.Order) {
|
||||
testBufferDur := 5 * time.Second
|
||||
orderExpiry := now.Add(defaultOrderExpiry)
|
||||
|
||||
assert.Equals(t, o.ID, "ordID")
|
||||
assert.Equals(t, o.Status, acme.StatusPending)
|
||||
assert.Equals(t, o.Identifiers, nor.Identifiers)
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
|
||||
assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
|
||||
assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
|
||||
assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
|
||||
assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf))
|
||||
assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry))
|
||||
assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry))
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/default-naf-nbf-with-policy": func(t *testing.T) test {
|
||||
options := &provisioner.Options{
|
||||
X509: &provisioner.X509Options{
|
||||
AllowedNames: &policy.X509NameOptions{
|
||||
DNSDomains: []string{"*.internal"},
|
||||
},
|
||||
},
|
||||
}
|
||||
provWithPolicy := newACMEProvWithOptions(t, options)
|
||||
provWithPolicy.RequireEAB = true
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "zap.internal"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
var (
|
||||
ch1, ch2, ch3 **acme.Challenge
|
||||
az1ID *string
|
||||
count = 0
|
||||
)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
nor: nor,
|
||||
ca: &mockCA{},
|
||||
db: &acme.MockDB{
|
||||
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
|
||||
switch count {
|
||||
case 0:
|
||||
ch.ID = "dns"
|
||||
assert.Equals(t, ch.Type, acme.DNS01)
|
||||
ch1 = &ch
|
||||
case 1:
|
||||
ch.ID = "http"
|
||||
assert.Equals(t, ch.Type, acme.HTTP01)
|
||||
ch2 = &ch
|
||||
case 2:
|
||||
ch.ID = "tls"
|
||||
assert.Equals(t, ch.Type, acme.TLSALPN01)
|
||||
ch3 = &ch
|
||||
default:
|
||||
assert.FatalError(t, errors.New("test logic error"))
|
||||
return errors.New("force")
|
||||
}
|
||||
count++
|
||||
assert.Equals(t, ch.AccountID, "accID")
|
||||
assert.NotEquals(t, ch.Token, "")
|
||||
assert.Equals(t, ch.Status, acme.StatusPending)
|
||||
assert.Equals(t, ch.Value, "zap.internal")
|
||||
return nil
|
||||
},
|
||||
MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
|
||||
az.ID = "az1ID"
|
||||
az1ID = &az.ID
|
||||
assert.Equals(t, az.AccountID, "accID")
|
||||
assert.NotEquals(t, az.Token, "")
|
||||
assert.Equals(t, az.Status, acme.StatusPending)
|
||||
assert.Equals(t, az.Identifier, nor.Identifiers[0])
|
||||
assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
|
||||
assert.Equals(t, az.Wildcard, false)
|
||||
return nil
|
||||
},
|
||||
MockCreateOrder: func(ctx context.Context, o *acme.Order) error {
|
||||
o.ID = "ordID"
|
||||
assert.Equals(t, o.AccountID, "accID")
|
||||
assert.Equals(t, o.ProvisionerID, prov.GetID())
|
||||
assert.Equals(t, o.Status, acme.StatusPending)
|
||||
assert.Equals(t, o.Identifiers, nor.Identifiers)
|
||||
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
|
||||
return nil
|
||||
},
|
||||
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, prov.GetID(), provisionerID)
|
||||
assert.Equals(t, "accID", accountID)
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
vr: func(t *testing.T, o *acme.Order) {
|
||||
now := clock.Now()
|
||||
testBufferDur := 5 * time.Second
|
||||
orderExpiry := now.Add(defaultOrderExpiry)
|
||||
expNbf := now.Add(-defaultOrderBackdate)
|
||||
expNaf := now.Add(prov.DefaultTLSCertDuration())
|
||||
|
||||
assert.Equals(t, o.ID, "ordID")
|
||||
assert.Equals(t, o.Status, acme.StatusPending)
|
||||
assert.Equals(t, o.Identifiers, nor.Identifiers)
|
||||
|
@ -1334,11 +1713,12 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||
mockMustAuthority(t, tc.ca)
|
||||
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.NewOrder(w, req)
|
||||
NewOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1371,6 +1751,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandler_FinalizeOrder(t *testing.T) {
|
||||
mockMustAuthority(t, &mockCA{})
|
||||
prov := newProv()
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
|
@ -1429,15 +1810,17 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
db: &acme.MockDB{},
|
||||
ctx: acme.NewProvisionerContext(context.Background(), prov),
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
|
||||
|
@ -1447,6 +1830,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -1454,9 +1838,10 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -1465,8 +1850,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
"fail/no-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
|
@ -1474,21 +1860,23 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("paylod does not exist"),
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"),
|
||||
|
@ -1499,10 +1887,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
fr := &FinalizeRequest{}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"),
|
||||
|
@ -1511,7 +1900,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
"fail/db.GetOrder-error": func(t *testing.T) test {
|
||||
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -1526,7 +1915,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -1543,7 +1932,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"fail/provisioner-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -1560,7 +1949,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"fail/order-finalize-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
|
@ -1585,10 +1974,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
db: &acme.MockDB{
|
||||
|
@ -1624,11 +2012,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.FinalizeOrder(w, req)
|
||||
FinalizeOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
|
|
@ -26,9 +26,11 @@ type revokePayload struct {
|
|||
}
|
||||
|
||||
// RevokeCert attempts to revoke a certificate.
|
||||
func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -69,7 +71,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
serial := certToBeRevoked.SerialNumber.String()
|
||||
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
|
||||
dbCert, err := db.GetCertificateBySerial(ctx, serial)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
||||
return
|
||||
|
@ -87,7 +89,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||
acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||
if acmeErr != nil {
|
||||
render.Error(w, acmeErr)
|
||||
return
|
||||
|
@ -103,7 +105,8 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
|
||||
ca := mustAuthority(ctx)
|
||||
hasBeenRevokedBefore, err := ca.IsRevoked(serial)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
||||
return
|
||||
|
@ -130,14 +133,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
options := revokeOptions(serial, certToBeRevoked, reasonCode)
|
||||
err = h.ca.Revoke(ctx, options)
|
||||
err = ca.Revoke(ctx, options)
|
||||
if err != nil {
|
||||
render.Error(w, wrapRevokeErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
logRevoke(w, options)
|
||||
w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index"))
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||
w.Write(nil)
|
||||
}
|
||||
|
||||
|
@ -148,7 +151,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
|
||||
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
|
||||
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
|
||||
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
||||
func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
||||
if !account.IsValid() {
|
||||
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
|
||||
}
|
||||
|
|
|
@ -24,14 +24,16 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
)
|
||||
|
||||
// v is a utility function to return the pointer to an integer
|
||||
|
@ -274,14 +276,22 @@ func jwsFinal(sha crypto.Hash, sig []byte, phead, payload string) ([]byte, error
|
|||
}
|
||||
|
||||
type mockCA struct {
|
||||
MockIsRevoked func(sn string) (bool, error)
|
||||
MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error
|
||||
MockIsRevoked func(sn string) (bool, error)
|
||||
MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error
|
||||
MockAreSANsallowed func(ctx context.Context, sans []string) error
|
||||
}
|
||||
|
||||
func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error {
|
||||
if m.MockAreSANsallowed != nil {
|
||||
return m.MockAreSANsallowed(ctx, sans)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCA) IsRevoked(sn string) (bool, error) {
|
||||
if m.MockIsRevoked != nil {
|
||||
return m.MockIsRevoked(sn)
|
||||
|
@ -511,6 +521,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/no-jws": func(t *testing.T) test {
|
||||
ctx := context.Background()
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -519,6 +530,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/nil-jws": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("jws expected in request context"),
|
||||
|
@ -527,6 +539,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -534,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, nil)
|
||||
ctx = acme.NewProvisionerContext(ctx, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner does not exist"),
|
||||
|
@ -543,8 +557,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
|
@ -552,9 +567,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("payload does not exist"),
|
||||
|
@ -563,9 +579,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/unmarshal-payload": func(t *testing.T) test {
|
||||
malformedPayload := []byte(`{"payload":malformed?}`)
|
||||
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx = acme.NewProvisionerContext(ctx, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload})
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("error unmarshaling payload"),
|
||||
|
@ -577,10 +594,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: &acme.Error{
|
||||
|
@ -596,10 +614,11 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
emptyPayloadBytes, err := json.Marshal(emptyPayload)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
return test{
|
||||
db: &acme.MockDB{},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
err: &acme.Error{
|
||||
|
@ -610,7 +629,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/db.GetCertificateBySerial": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
db := &acme.MockDB{
|
||||
|
@ -628,7 +647,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
"fail/different-certificate-contents": func(t *testing.T) test {
|
||||
aDifferentCert, _, err := generateCertKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
db := &acme.MockDB{
|
||||
|
@ -647,7 +666,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
db := &acme.MockDB{
|
||||
|
@ -666,7 +685,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
|
@ -687,11 +706,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-valid": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -717,11 +735,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/account-not-authorized": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -771,10 +788,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
unauthorizedPayloadBytes, err := json.Marshal(jwsPayload)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -798,11 +814,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/certificate-revoked-check-fails": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -832,7 +847,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/certificate-already-revoked": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -870,7 +885,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -908,7 +923,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
}
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -940,7 +955,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/ca.Revoke": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -972,7 +987,7 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"fail/ca.Revoke-already-revoked": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
|
@ -1003,11 +1018,10 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
},
|
||||
"ok/using-account-key": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -1031,10 +1045,9 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(jwsBytes))
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
|
||||
ctx = context.WithValue(ctx, jwsContextKey, jws)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
db := &acme.MockDB{
|
||||
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
|
||||
|
@ -1057,11 +1070,12 @@ func TestHandler_RevokeCert(t *testing.T) {
|
|||
for name, setup := range tests {
|
||||
tc := setup(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca}
|
||||
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
|
||||
mockMustAuthority(t, tc.ca)
|
||||
req := httptest.NewRequest("POST", revokeURL, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.RevokeCert(w, req)
|
||||
RevokeCert(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
@ -1198,8 +1212,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) {
|
|||
for name, setup := range tests {
|
||||
tc := setup(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db}
|
||||
acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
|
||||
// h := &Handler{db: tc.db}
|
||||
acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
|
||||
|
||||
expectError := tc.err != nil
|
||||
gotError := acmeErr != nil
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) {
|
|||
// type using the DB interface.
|
||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||
// updated.
|
||||
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error {
|
||||
// If already valid or invalid then return without performing validation.
|
||||
if ch.Status != StatusPending {
|
||||
return nil
|
||||
}
|
||||
switch ch.Type {
|
||||
case HTTP01:
|
||||
return http01Validate(ctx, ch, db, jwk, vo)
|
||||
return http01Validate(ctx, ch, db, jwk)
|
||||
case DNS01:
|
||||
return dns01Validate(ctx, ch, db, jwk, vo)
|
||||
return dns01Validate(ctx, ch, db, jwk)
|
||||
case TLSALPN01:
|
||||
return tlsalpn01Validate(ctx, ch, db, jwk, vo)
|
||||
return tlsalpn01Validate(ctx, ch, db, jwk)
|
||||
default:
|
||||
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||
|
||||
resp, err := vo.HTTPGet(u.String())
|
||||
vc := MustClientFromContext(ctx)
|
||||
resp, err := vc.Get(u.String())
|
||||
if err != nil {
|
||||
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||
"error doing http GET for url %s", u))
|
||||
|
@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
|
|||
return nil
|
||||
}
|
||||
|
||||
// http01ChallengeHost checks if a Challenge value is an IPv6 address
|
||||
// and adds square brackets if that's the case, so that it can be used
|
||||
// as a hostname. Returns the original Challenge value as the host to
|
||||
// use in other cases.
|
||||
func http01ChallengeHost(value string) string {
|
||||
if ip := net.ParseIP(value); ip != nil && ip.To4() == nil {
|
||||
value = "[" + value + "]"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func tlsAlert(err error) uint8 {
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
|
@ -130,7 +141,7 @@ func tlsAlert(err error) uint8 {
|
|||
return 0
|
||||
}
|
||||
|
||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
config := &tls.Config{
|
||||
NextProtos: []string{"acme-tls/1"},
|
||||
// https://tools.ietf.org/html/rfc8737#section-4
|
||||
|
@ -143,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
|
||||
hostPort := net.JoinHostPort(ch.Value, "443")
|
||||
|
||||
conn, err := vo.TLSDial("tcp", hostPort, config)
|
||||
vc := MustClientFromContext(ctx)
|
||||
conn, err := vc.TLSDial("tcp", hostPort, config)
|
||||
if err != nil {
|
||||
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
|
||||
// client and server protocols. When this happens the connection is
|
||||
|
@ -242,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
|
||||
}
|
||||
|
||||
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
// Normalize domain for wildcard DNS names
|
||||
// This is done to avoid making TXT lookups for domains like
|
||||
// _acme-challenge.*.example.com
|
||||
// Instead perform txt lookup for _acme-challenge.example.com
|
||||
domain := strings.TrimPrefix(ch.Value, "*.")
|
||||
|
||||
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain)
|
||||
vc := MustClientFromContext(ctx)
|
||||
txtRecords, err := vc.LookupTxt("_acme-challenge." + domain)
|
||||
if err != nil {
|
||||
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
|
||||
"error looking up TXT records for domain %s", domain))
|
||||
|
@ -365,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type httpGetter func(string) (*http.Response, error)
|
||||
type lookupTxt func(string) ([]string, error)
|
||||
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
|
||||
// ValidateChallengeOptions are ACME challenge validator functions.
|
||||
type ValidateChallengeOptions struct {
|
||||
HTTPGet httpGetter
|
||||
LookupTxt lookupTxt
|
||||
TLSDial tlsDialer
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
|
@ -23,11 +24,23 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"go.step.sm/crypto/jose"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
get func(url string) (*http.Response, error)
|
||||
lookupTxt func(name string) ([]string, error)
|
||||
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) }
|
||||
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
|
||||
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return m.tlsDial(network, addr, config)
|
||||
}
|
||||
|
||||
func Test_storeError(t *testing.T) {
|
||||
type test struct {
|
||||
ch *Challenge
|
||||
|
@ -228,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) {
|
|||
func TestChallenge_Validate(t *testing.T) {
|
||||
type test struct {
|
||||
ch *Challenge
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
srv *httptest.Server
|
||||
|
@ -272,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -308,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -343,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -380,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -415,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
}
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -465,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -492,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
defer tc.srv.Close()
|
||||
}
|
||||
|
||||
if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
@ -523,7 +537,7 @@ func (errReader) Close() error {
|
|||
|
||||
func TestHTTP01Validate(t *testing.T) {
|
||||
type test struct {
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
|
@ -540,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -574,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -607,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: errReader(0),
|
||||
|
@ -644,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: errReader(0),
|
||||
|
@ -680,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: errReader(0),
|
||||
}, nil
|
||||
|
@ -703,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
jwk.Key = "foo"
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||
}, nil
|
||||
|
@ -729,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||
}, nil
|
||||
|
@ -771,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||
}, nil
|
||||
|
@ -814,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
|
||||
}, nil
|
||||
|
@ -856,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
|
||||
}, nil
|
||||
|
@ -886,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
@ -910,7 +925,7 @@ func TestDNS01Validate(t *testing.T) {
|
|||
fulldomain := "*.zap.internal"
|
||||
domain := strings.TrimPrefix(fulldomain, "*.")
|
||||
type test struct {
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
|
@ -927,8 +942,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -962,8 +977,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -1000,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo"}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1025,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", "bar"}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1067,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", "bar"}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1110,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", expected}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1155,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", expected}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1185,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
@ -1205,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error)
|
||||
|
||||
func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) {
|
||||
srv := httptest.NewUnstartedServer(http.NewServeMux())
|
||||
|
||||
|
@ -1308,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
type test struct {
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
|
@ -1320,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
ch := makeTLSCh()
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -1350,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
ch := makeTLSCh()
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -1383,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1412,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.Client(&noopConn{}, config), nil
|
||||
},
|
||||
},
|
||||
|
@ -1442,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.Client(&noopConn{}, config), nil
|
||||
},
|
||||
},
|
||||
|
@ -1478,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
|
||||
},
|
||||
},
|
||||
|
@ -1515,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
|
||||
},
|
||||
},
|
||||
|
@ -1561,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1604,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1648,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1691,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1735,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
srv: srv,
|
||||
jwk: jwk,
|
||||
|
@ -1757,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1796,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1840,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1883,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1923,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1962,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2007,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2053,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2099,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2143,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2188,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2225,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2252,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
defer tc.srv.Close()
|
||||
}
|
||||
|
||||
if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
@ -2350,3 +2369,34 @@ func Test_serverName(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_http01ChallengeHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "dns",
|
||||
value: "www.example.com",
|
||||
want: "www.example.com",
|
||||
},
|
||||
{
|
||||
name: "ipv4",
|
||||
value: "127.0.0.1",
|
||||
want: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
value: "::1",
|
||||
want: "[::1]",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := http01ChallengeHost(tt.value); got != tt.want {
|
||||
t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
79
acme/client.go
Normal file
79
acme/client.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client is the interface used to verify ACME challenges.
|
||||
type Client interface {
|
||||
// Get issues an HTTP GET to the specified URL.
|
||||
Get(url string) (*http.Response, error)
|
||||
|
||||
// LookupTXT returns the DNS TXT records for the given domain name.
|
||||
LookupTxt(name string) ([]string, error)
|
||||
|
||||
// TLSDial connects to the given network address using net.Dialer and then
|
||||
// initiates a TLS handshake, returning the resulting TLS connection.
|
||||
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
type clientKey struct{}
|
||||
|
||||
// NewClientContext adds the given client to the context.
|
||||
func NewClientContext(ctx context.Context, c Client) context.Context {
|
||||
return context.WithValue(ctx, clientKey{}, c)
|
||||
}
|
||||
|
||||
// ClientFromContext returns the current client from the given context.
|
||||
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
|
||||
c, ok = ctx.Value(clientKey{}).(Client)
|
||||
return
|
||||
}
|
||||
|
||||
// MustClientFromContext returns the current client from the given context. It will
|
||||
// return a new instance of the client if it does not exist.
|
||||
func MustClientFromContext(ctx context.Context) Client {
|
||||
c, ok := ClientFromContext(ctx)
|
||||
if !ok {
|
||||
return NewClient()
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type client struct {
|
||||
http *http.Client
|
||||
dialer *net.Dialer
|
||||
}
|
||||
|
||||
// NewClient returns an implementation of Client for verifying ACME challenges.
|
||||
func NewClient() Client {
|
||||
return &client{
|
||||
http: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
dialer: &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Get(url string) (*http.Response, error) {
|
||||
return c.http.Get(url)
|
||||
}
|
||||
|
||||
func (c *client) LookupTxt(name string) ([]string, error) {
|
||||
return net.LookupTXT(name)
|
||||
}
|
||||
|
||||
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(c.dialer, network, addr, config)
|
||||
}
|
103
acme/common.go
103
acme/common.go
|
@ -9,14 +9,6 @@ import (
|
|||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// CertificateAuthority is the interface implemented by a CA authority.
|
||||
type CertificateAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
IsRevoked(sn string) (bool, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock struct{}
|
||||
|
||||
|
@ -27,9 +19,56 @@ func (c *Clock) Now() time.Time {
|
|||
|
||||
var clock Clock
|
||||
|
||||
// CertificateAuthority is the interface implemented by a CA authority.
|
||||
type CertificateAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
AreSANsAllowed(ctx context.Context, sans []string) error
|
||||
IsRevoked(sn string) (bool, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// NewContext adds the given acme components to the context.
|
||||
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
|
||||
ctx = NewDatabaseContext(ctx, db)
|
||||
ctx = NewClientContext(ctx, client)
|
||||
ctx = NewLinkerContext(ctx, linker)
|
||||
// Prerequisite checker is optional.
|
||||
if fn != nil {
|
||||
ctx = NewPrerequisitesCheckerContext(ctx, fn)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// PrerequisitesChecker is a function that checks if all prerequisites for
|
||||
// serving ACME are met by the CA configuration.
|
||||
type PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
|
||||
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
|
||||
// always true.
|
||||
func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type prerequisitesKey struct{}
|
||||
|
||||
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
|
||||
// context.
|
||||
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
|
||||
return context.WithValue(ctx, prerequisitesKey{}, fn)
|
||||
}
|
||||
|
||||
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
|
||||
// context.
|
||||
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
|
||||
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
|
||||
return fn, ok && fn != nil
|
||||
}
|
||||
|
||||
// Provisioner is an interface that implements a subset of the provisioner.Interface --
|
||||
// only those methods required by the ACME api/authority.
|
||||
type Provisioner interface {
|
||||
AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error
|
||||
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
|
||||
AuthorizeRevoke(ctx context.Context, token string) error
|
||||
GetID() string
|
||||
|
@ -38,16 +77,40 @@ type Provisioner interface {
|
|||
GetOptions() *provisioner.Options
|
||||
}
|
||||
|
||||
type provisionerKey struct{}
|
||||
|
||||
// NewProvisionerContext adds the given provisioner to the context.
|
||||
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
|
||||
return context.WithValue(ctx, provisionerKey{}, v)
|
||||
}
|
||||
|
||||
// ProvisionerFromContext returns the current provisioner from the given context.
|
||||
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
|
||||
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current provisioner from the given context.
|
||||
// It will panic if it's not in the context.
|
||||
func MustProvisionerFromContext(ctx context.Context) Provisioner {
|
||||
if v, ok := ProvisionerFromContext(ctx); !ok {
|
||||
panic("acme provisioner is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// MockProvisioner for testing
|
||||
type MockProvisioner struct {
|
||||
Mret1 interface{}
|
||||
Merr error
|
||||
MgetID func() string
|
||||
MgetName func() string
|
||||
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
MauthorizeRevoke func(ctx context.Context, token string) error
|
||||
MdefaultTLSCertDuration func() time.Duration
|
||||
MgetOptions func() *provisioner.Options
|
||||
Mret1 interface{}
|
||||
Merr error
|
||||
MgetID func() string
|
||||
MgetName func() string
|
||||
MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error
|
||||
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
MauthorizeRevoke func(ctx context.Context, token string) error
|
||||
MdefaultTLSCertDuration func() time.Duration
|
||||
MgetOptions func() *provisioner.Options
|
||||
}
|
||||
|
||||
// GetName mock
|
||||
|
@ -58,6 +121,14 @@ func (m *MockProvisioner) GetName() string {
|
|||
return m.Mret1.(string)
|
||||
}
|
||||
|
||||
// AuthorizeOrderIdentifiers mock
|
||||
func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
|
||||
if m.MauthorizeOrderIdentifier != nil {
|
||||
return m.MauthorizeOrderIdentifier(ctx, identifier)
|
||||
}
|
||||
return m.Merr
|
||||
}
|
||||
|
||||
// AuthorizeSign mock
|
||||
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
if m.MauthorizeSign != nil {
|
||||
|
|
35
acme/db.go
35
acme/db.go
|
@ -23,6 +23,7 @@ type DB interface {
|
|||
GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
|
||||
GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
|
||||
GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
|
||||
GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
|
||||
DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error
|
||||
UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
|
||||
|
||||
|
@ -48,6 +49,29 @@ type DB interface {
|
|||
UpdateOrder(ctx context.Context, o *Order) error
|
||||
}
|
||||
|
||||
type dbKey struct{}
|
||||
|
||||
// NewDatabaseContext adds the given acme database to the context.
|
||||
func NewDatabaseContext(ctx context.Context, db DB) context.Context {
|
||||
return context.WithValue(ctx, dbKey{}, db)
|
||||
}
|
||||
|
||||
// DatabaseFromContext returns the current acme database from the given context.
|
||||
func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
|
||||
db, ok = ctx.Value(dbKey{}).(DB)
|
||||
return
|
||||
}
|
||||
|
||||
// MustDatabaseFromContext returns the current database from the given context.
|
||||
// It will panic if it's not in the context.
|
||||
func MustDatabaseFromContext(ctx context.Context) DB {
|
||||
if db, ok := DatabaseFromContext(ctx); !ok {
|
||||
panic("acme database is not in the context")
|
||||
} else {
|
||||
return db
|
||||
}
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
// a mock in tests.
|
||||
type MockDB struct {
|
||||
|
@ -60,6 +84,7 @@ type MockDB struct {
|
|||
MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
|
||||
MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
|
||||
MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
|
||||
MockGetExternalAccountKeyByAccountID func(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
|
||||
MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error
|
||||
MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
|
||||
|
||||
|
@ -168,6 +193,16 @@ func (m *MockDB) GetExternalAccountKeyByReference(ctx context.Context, provision
|
|||
return m.MockRet1.(*ExternalAccountKey), m.MockError
|
||||
}
|
||||
|
||||
// GetExternalAccountKeyByAccountID mock
|
||||
func (m *MockDB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) {
|
||||
if m.MockGetExternalAccountKeyByAccountID != nil {
|
||||
return m.MockGetExternalAccountKeyByAccountID(ctx, provisionerID, accountID)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*ExternalAccountKey), m.MockError
|
||||
}
|
||||
|
||||
// DeleteExternalAccountKey mock
|
||||
func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error {
|
||||
if m.MockDeleteExternalAccountKey != nil {
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
nosqlDB "github.com/smallstep/nosql"
|
||||
)
|
||||
|
@ -23,7 +24,7 @@ type dbExternalAccountKey struct {
|
|||
ProvisionerID string `json:"provisionerID"`
|
||||
Reference string `json:"reference"`
|
||||
AccountID string `json:"accountID,omitempty"`
|
||||
KeyBytes []byte `json:"key"`
|
||||
HmacKey []byte `json:"key"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
BoundAt time.Time `json:"boundAt"`
|
||||
}
|
||||
|
@ -72,7 +73,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer
|
|||
ID: keyID,
|
||||
ProvisionerID: provisionerID,
|
||||
Reference: reference,
|
||||
KeyBytes: random,
|
||||
HmacKey: random,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
|
||||
|
@ -99,7 +100,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer
|
|||
ProvisionerID: dbeak.ProvisionerID,
|
||||
Reference: dbeak.Reference,
|
||||
AccountID: dbeak.AccountID,
|
||||
KeyBytes: dbeak.KeyBytes,
|
||||
HmacKey: dbeak.HmacKey,
|
||||
CreatedAt: dbeak.CreatedAt,
|
||||
BoundAt: dbeak.BoundAt,
|
||||
}, nil
|
||||
|
@ -124,7 +125,7 @@ func (db *DB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID st
|
|||
ProvisionerID: dbeak.ProvisionerID,
|
||||
Reference: dbeak.Reference,
|
||||
AccountID: dbeak.AccountID,
|
||||
KeyBytes: dbeak.KeyBytes,
|
||||
HmacKey: dbeak.HmacKey,
|
||||
CreatedAt: dbeak.CreatedAt,
|
||||
BoundAt: dbeak.BoundAt,
|
||||
}, nil
|
||||
|
@ -191,7 +192,7 @@ func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor
|
|||
}
|
||||
keys = append(keys, &acme.ExternalAccountKey{
|
||||
ID: eak.ID,
|
||||
KeyBytes: eak.KeyBytes,
|
||||
HmacKey: eak.HmacKey,
|
||||
ProvisionerID: eak.ProvisionerID,
|
||||
Reference: eak.Reference,
|
||||
AccountID: eak.AccountID,
|
||||
|
@ -226,6 +227,10 @@ func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerI
|
|||
return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID)
|
||||
}
|
||||
|
||||
func (db *DB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
|
||||
externalAccountKeyMutex.Lock()
|
||||
defer externalAccountKeyMutex.Unlock()
|
||||
|
@ -252,7 +257,7 @@ func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string
|
|||
ProvisionerID: eak.ProvisionerID,
|
||||
Reference: eak.Reference,
|
||||
AccountID: eak.AccountID,
|
||||
KeyBytes: eak.KeyBytes,
|
||||
HmacKey: eak.HmacKey,
|
||||
CreatedAt: eak.CreatedAt,
|
||||
BoundAt: eak.BoundAt,
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
certdb "github.com/smallstep/certificates/db"
|
||||
|
@ -32,7 +33,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: "ref",
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(dbeak)
|
||||
|
@ -108,7 +109,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) {
|
|||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, dbeak.ID, tc.dbeak.ID)
|
||||
assert.Equals(t, dbeak.KeyBytes, tc.dbeak.KeyBytes)
|
||||
assert.Equals(t, dbeak.HmacKey, tc.dbeak.HmacKey)
|
||||
assert.Equals(t, dbeak.ProvisionerID, tc.dbeak.ProvisionerID)
|
||||
assert.Equals(t, dbeak.Reference, tc.dbeak.Reference)
|
||||
assert.Equals(t, dbeak.CreatedAt, tc.dbeak.CreatedAt)
|
||||
|
@ -136,7 +137,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: "ref",
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(dbeak)
|
||||
|
@ -154,7 +155,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: "ref",
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
@ -179,7 +180,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: "aDifferentProvID",
|
||||
Reference: "ref",
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(dbeak)
|
||||
|
@ -197,7 +198,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: "ref",
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
},
|
||||
acmeErr: acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created"),
|
||||
|
@ -225,7 +226,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
|
|||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, eak.ID, tc.eak.ID)
|
||||
assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes)
|
||||
assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
|
||||
assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
|
||||
assert.Equals(t, eak.Reference, tc.eak.Reference)
|
||||
assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
|
||||
|
@ -255,7 +256,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
dbref := &dbExternalAccountKeyReference{
|
||||
|
@ -288,7 +289,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
},
|
||||
err: nil,
|
||||
|
@ -392,7 +393,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
|
|||
assert.Equals(t, eak.AccountID, tc.eak.AccountID)
|
||||
assert.Equals(t, eak.BoundAt, tc.eak.BoundAt)
|
||||
assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
|
||||
assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes)
|
||||
assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
|
||||
assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
|
||||
assert.Equals(t, eak.Reference, tc.eak.Reference)
|
||||
}
|
||||
|
@ -420,7 +421,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b1, err := json.Marshal(dbeak1)
|
||||
|
@ -430,7 +431,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b2, err := json.Marshal(dbeak2)
|
||||
|
@ -440,7 +441,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
ProvisionerID: "aDifferentProvID",
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b3, err := json.Marshal(dbeak3)
|
||||
|
@ -513,7 +514,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
|
@ -521,7 +522,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
},
|
||||
},
|
||||
|
@ -598,7 +599,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
|
|||
assert.Equals(t, "", nextCursor)
|
||||
for i, eak := range eaks {
|
||||
assert.Equals(t, eak.ID, tc.eaks[i].ID)
|
||||
assert.Equals(t, eak.KeyBytes, tc.eaks[i].KeyBytes)
|
||||
assert.Equals(t, eak.HmacKey, tc.eaks[i].HmacKey)
|
||||
assert.Equals(t, eak.ProvisionerID, tc.eaks[i].ProvisionerID)
|
||||
assert.Equals(t, eak.Reference, tc.eaks[i].Reference)
|
||||
assert.Equals(t, eak.CreatedAt, tc.eaks[i].CreatedAt)
|
||||
|
@ -627,7 +628,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
dbref := &dbExternalAccountKeyReference{
|
||||
|
@ -707,7 +708,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: "aDifferentProvID",
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(dbeak)
|
||||
|
@ -730,7 +731,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
dbref := &dbExternalAccountKeyReference{
|
||||
|
@ -780,7 +781,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
dbref := &dbExternalAccountKeyReference{
|
||||
|
@ -830,7 +831,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
dbref := &dbExternalAccountKeyReference{
|
||||
|
@ -953,7 +954,7 @@ func TestDB_CreateExternalAccountKey(t *testing.T) {
|
|||
assert.Equals(t, string(key), dbeak.ID)
|
||||
assert.Equals(t, eak.ProvisionerID, dbeak.ProvisionerID)
|
||||
assert.Equals(t, eak.Reference, dbeak.Reference)
|
||||
assert.Equals(t, 32, len(dbeak.KeyBytes))
|
||||
assert.Equals(t, 32, len(dbeak.HmacKey))
|
||||
assert.False(t, dbeak.CreatedAt.IsZero())
|
||||
assert.Equals(t, dbeak.AccountID, eak.AccountID)
|
||||
assert.True(t, dbeak.BoundAt.IsZero())
|
||||
|
@ -1078,7 +1079,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(dbeak)
|
||||
|
@ -1096,7 +1097,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
return test{
|
||||
|
@ -1120,7 +1121,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
assert.Equals(t, dbNew.AccountID, dbeak.AccountID)
|
||||
assert.Equals(t, dbNew.CreatedAt, dbeak.CreatedAt)
|
||||
assert.Equals(t, dbNew.BoundAt, dbeak.BoundAt)
|
||||
assert.Equals(t, dbNew.KeyBytes, dbeak.KeyBytes)
|
||||
assert.Equals(t, dbNew.HmacKey, dbeak.HmacKey)
|
||||
return nu, true, nil
|
||||
},
|
||||
},
|
||||
|
@ -1148,7 +1149,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: "aDifferentProvID",
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(newDBEAK)
|
||||
|
@ -1174,7 +1175,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(newDBEAK)
|
||||
|
@ -1200,7 +1201,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
ProvisionerID: provID,
|
||||
Reference: ref,
|
||||
AccountID: "",
|
||||
KeyBytes: []byte{1, 3, 3, 7},
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: now,
|
||||
}
|
||||
b, err := json.Marshal(newDBEAK)
|
||||
|
@ -1237,7 +1238,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
|
|||
assert.Equals(t, dbeak.AccountID, tc.eak.AccountID)
|
||||
assert.Equals(t, dbeak.CreatedAt, tc.eak.CreatedAt)
|
||||
assert.Equals(t, dbeak.BoundAt, tc.eak.BoundAt)
|
||||
assert.Equals(t, dbeak.KeyBytes, tc.eak.KeyBytes)
|
||||
assert.Equals(t, dbeak.HmacKey, tc.eak.HmacKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,100 +1,19 @@
|
|||
package api
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// NewLinker returns a new Directory type.
|
||||
func NewLinker(dns, prefix string) Linker {
|
||||
_, _, err := net.SplitHostPort(dns)
|
||||
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
||||
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
||||
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
||||
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
||||
// these cases, then the input dns is not changed.
|
||||
lastIndex := strings.LastIndex(dns, ":")
|
||||
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
||||
if ip := net.ParseIP(hostPart); ip != nil {
|
||||
dns = "[" + hostPart + "]:" + portPart
|
||||
} else if ip := net.ParseIP(dns); ip != nil {
|
||||
dns = "[" + dns + "]"
|
||||
}
|
||||
}
|
||||
return &linker{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Linker interface for generating links for ACME resources.
|
||||
type Linker interface {
|
||||
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
|
||||
|
||||
LinkOrder(ctx context.Context, o *acme.Order)
|
||||
LinkAccount(ctx context.Context, o *acme.Account)
|
||||
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
|
||||
LinkAuthorization(ctx context.Context, o *acme.Authorization)
|
||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||
}
|
||||
|
||||
// linker generates ACME links.
|
||||
type linker struct {
|
||||
prefix string
|
||||
dns string
|
||||
}
|
||||
|
||||
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||
switch typ {
|
||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||
case ChallengeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||
case OrdersByAccountLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||
case FinalizeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// GetLink is a helper for GetLinkExplicit
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||
var (
|
||||
provName string
|
||||
baseURL = baseURLFromContext(ctx)
|
||||
u = url.URL{}
|
||||
)
|
||||
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
|
||||
provName = p.GetName()
|
||||
}
|
||||
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
|
||||
if baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
|
||||
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
|
||||
|
||||
// If no Scheme is set, then default to https.
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
|
||||
// If no Host is set, then use the default (first DNS attr in the ca.json).
|
||||
if u.Host == "" {
|
||||
u.Host = l.dns
|
||||
}
|
||||
|
||||
u.Path = l.prefix + u.Path
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// LinkType captures the link type.
|
||||
type LinkType int
|
||||
|
||||
|
@ -160,8 +79,155 @@ func (l LinkType) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||
switch typ {
|
||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||
case ChallengeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||
case OrdersByAccountLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||
case FinalizeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// NewLinker returns a new Directory type.
|
||||
func NewLinker(dns, prefix string) Linker {
|
||||
_, _, err := net.SplitHostPort(dns)
|
||||
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
||||
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
||||
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
||||
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
||||
// these cases, then the input dns is not changed.
|
||||
lastIndex := strings.LastIndex(dns, ":")
|
||||
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
||||
if ip := net.ParseIP(hostPart); ip != nil {
|
||||
dns = "[" + hostPart + "]:" + portPart
|
||||
} else if ip := net.ParseIP(dns); ip != nil {
|
||||
dns = "[" + dns + "]"
|
||||
}
|
||||
}
|
||||
return &linker{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Linker interface for generating links for ACME resources.
|
||||
type Linker interface {
|
||||
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||
Middleware(http.Handler) http.Handler
|
||||
LinkOrder(ctx context.Context, o *Order)
|
||||
LinkAccount(ctx context.Context, o *Account)
|
||||
LinkChallenge(ctx context.Context, o *Challenge, azID string)
|
||||
LinkAuthorization(ctx context.Context, o *Authorization)
|
||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||
}
|
||||
|
||||
type linkerKey struct{}
|
||||
|
||||
// NewLinkerContext adds the given linker to the context.
|
||||
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
|
||||
return context.WithValue(ctx, linkerKey{}, v)
|
||||
}
|
||||
|
||||
// LinkerFromContext returns the current linker from the given context.
|
||||
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
|
||||
v, ok = ctx.Value(linkerKey{}).(Linker)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current linker from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustLinkerFromContext(ctx context.Context) Linker {
|
||||
if v, ok := LinkerFromContext(ctx); !ok {
|
||||
panic("acme linker is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
type baseURLKey struct{}
|
||||
|
||||
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
|
||||
var u *url.URL
|
||||
if r.Host != "" {
|
||||
u = &url.URL{Scheme: "https", Host: r.Host}
|
||||
}
|
||||
return context.WithValue(ctx, baseURLKey{}, u)
|
||||
}
|
||||
|
||||
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
|
||||
return u
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// linker generates ACME links.
|
||||
type linker struct {
|
||||
prefix string
|
||||
dns string
|
||||
}
|
||||
|
||||
// Middleware gets the provisioner and current url from the request and sets
|
||||
// them in the context so we can use the linker to create ACME links.
|
||||
func (l *linker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add base url to the context.
|
||||
ctx := newBaseURLContext(r.Context(), r)
|
||||
|
||||
// Add provisioner to the context.
|
||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||
name, err := url.PathUnescape(nameEscaped)
|
||||
if err != nil {
|
||||
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
return
|
||||
}
|
||||
|
||||
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
acmeProv, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// GetLink is a helper for GetLinkExplicit.
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||
var name string
|
||||
if p, ok := ProvisionerFromContext(ctx); ok {
|
||||
name = p.GetName()
|
||||
}
|
||||
|
||||
var u url.URL
|
||||
if baseURL := baseURLFromContext(ctx); baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
if u.Host == "" {
|
||||
u.Host = l.dns
|
||||
}
|
||||
|
||||
u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// LinkOrder sets the ACME links required by an ACME order.
|
||||
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
||||
func (l *linker) LinkOrder(ctx context.Context, o *Order) {
|
||||
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
|
||||
for i, azID := range o.AuthorizationIDs {
|
||||
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
|
||||
|
@ -173,17 +239,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
|||
}
|
||||
|
||||
// LinkAccount sets the ACME links required by an ACME account.
|
||||
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
|
||||
func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
|
||||
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
|
||||
}
|
||||
|
||||
// LinkChallenge sets the ACME links required by an ACME challenge.
|
||||
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
|
||||
func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
|
||||
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
|
||||
}
|
||||
|
||||
// LinkAuthorization sets the ACME links required by an ACME authorization.
|
||||
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
|
||||
func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
|
||||
for _, ch := range az.Challenges {
|
||||
l.LinkChallenge(ctx, ch, az.ID)
|
||||
}
|
|
@ -1,21 +1,38 @@
|
|||
package api
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
prefix := "acme"
|
||||
linker := NewLinker(dns, prefix)
|
||||
func mockProvisioner(t *testing.T) Provisioner {
|
||||
t.Helper()
|
||||
var defaultDisableRenewal = false
|
||||
|
||||
getPath := linker.GetUnescapedPathSuffix
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-<test>provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func TestGetUnescapedPathSuffix(t *testing.T) {
|
||||
getPath := GetUnescapedPathSuffix
|
||||
|
||||
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
|
||||
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
|
||||
|
@ -32,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLinker_DNS(t *testing.T) {
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
type test struct {
|
||||
name string
|
||||
dns string
|
||||
|
@ -117,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) {
|
|||
linker := NewLinker(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
// No provisioner and no BaseURL from request
|
||||
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
|
||||
// Provisioner: yes, BaseURL: no
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
|
||||
// Provisioner: no, BaseURL: yes
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
|
@ -163,37 +180,37 @@ func TestLinker_GetLink(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkOrder(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
oid := "orderID"
|
||||
certID := "certID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
o *acme.Order
|
||||
validate func(o *acme.Order)
|
||||
o *Order
|
||||
validate func(o *Order)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"no-authz-and-no-cert": {
|
||||
o: &acme.Order{
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{})
|
||||
assert.Equals(t, o.CertificateURL, "")
|
||||
},
|
||||
},
|
||||
"one-authz-and-cert": {
|
||||
o: &acme.Order{
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo"},
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
|
@ -202,12 +219,12 @@ func TestLinker_LinkOrder(t *testing.T) {
|
|||
},
|
||||
},
|
||||
"many-authz": {
|
||||
o: &acme.Order{
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo", "bar", "zap"},
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
|
@ -228,24 +245,24 @@ func TestLinker_LinkOrder(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkAccount(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
accID := "accountID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
a *acme.Account
|
||||
validate func(o *acme.Account)
|
||||
a *Account
|
||||
validate func(o *Account)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
a: &acme.Account{
|
||||
a: &Account{
|
||||
ID: accID,
|
||||
},
|
||||
validate: func(a *acme.Account) {
|
||||
validate: func(a *Account) {
|
||||
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
|
||||
},
|
||||
},
|
||||
|
@ -260,25 +277,25 @@ func TestLinker_LinkAccount(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkChallenge(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
chID := "chID"
|
||||
azID := "azID"
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
ch *acme.Challenge
|
||||
validate func(o *acme.Challenge)
|
||||
ch *Challenge
|
||||
validate func(o *Challenge)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
ch: &acme.Challenge{
|
||||
ch: &Challenge{
|
||||
ID: chID,
|
||||
},
|
||||
validate: func(ch *acme.Challenge) {
|
||||
validate: func(ch *Challenge) {
|
||||
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
|
||||
},
|
||||
},
|
||||
|
@ -293,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkAuthorization(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
chID0 := "chID-0"
|
||||
chID1 := "chID-1"
|
||||
|
@ -305,20 +322,20 @@ func TestLinker_LinkAuthorization(t *testing.T) {
|
|||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
az *acme.Authorization
|
||||
validate func(o *acme.Authorization)
|
||||
az *Authorization
|
||||
validate func(o *Authorization)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
az: &acme.Authorization{
|
||||
az: &Authorization{
|
||||
ID: azID,
|
||||
Challenges: []*acme.Challenge{
|
||||
Challenges: []*Challenge{
|
||||
{ID: chID0},
|
||||
{ID: chID1},
|
||||
{ID: chID2},
|
||||
},
|
||||
},
|
||||
validate: func(az *acme.Authorization) {
|
||||
validate: func(az *Authorization) {
|
||||
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
|
||||
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
|
||||
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
|
||||
|
@ -335,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
|
@ -268,6 +268,7 @@ func TestOrder_UpdateStatus(t *testing.T) {
|
|||
|
||||
type mockSignAuth struct {
|
||||
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
areSANsAllowed func(ctx context.Context, sans []string) error
|
||||
loadProvisionerByName func(string) (provisioner.Interface, error)
|
||||
ret1, ret2 interface{}
|
||||
err error
|
||||
|
@ -282,6 +283,13 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S
|
|||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error {
|
||||
if m.areSANsAllowed != nil {
|
||||
return m.areSANsAllowed(ctx, sans)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByName != nil {
|
||||
return m.loadProvisionerByName(name)
|
||||
|
|
109
api/api.go
109
api/api.go
|
@ -35,7 +35,6 @@ type Authority interface {
|
|||
SSHAuthority
|
||||
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||
GetTLSOptions() *config.TLSOptions
|
||||
Root(shasum string) (*x509.Certificate, error)
|
||||
|
@ -53,6 +52,11 @@ type Authority interface {
|
|||
GetCertificateRevocationList() ([]byte, error)
|
||||
}
|
||||
|
||||
// mustAuthority will be replaced on unit tests.
|
||||
var mustAuthority = func(ctx context.Context) Authority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// TimeDuration is an alias of provisioner.TimeDuration
|
||||
type TimeDuration = provisioner.TimeDuration
|
||||
|
||||
|
@ -244,49 +248,54 @@ type caHandler struct {
|
|||
Authority Authority
|
||||
}
|
||||
|
||||
// New creates a new RouterHandler with the CA endpoints.
|
||||
func New(auth Authority) RouterHandler {
|
||||
return &caHandler{
|
||||
Authority: auth,
|
||||
}
|
||||
// Route configures the http request router.
|
||||
func (h *caHandler) Route(r Router) {
|
||||
Route(r)
|
||||
}
|
||||
|
||||
func (h *caHandler) Route(r Router) {
|
||||
r.MethodFunc("GET", "/version", h.Version)
|
||||
r.MethodFunc("GET", "/health", h.Health)
|
||||
r.MethodFunc("GET", "/root/{sha}", h.Root)
|
||||
r.MethodFunc("POST", "/sign", h.Sign)
|
||||
r.MethodFunc("POST", "/renew", h.Renew)
|
||||
r.MethodFunc("POST", "/rekey", h.Rekey)
|
||||
r.MethodFunc("POST", "/revoke", h.Revoke)
|
||||
r.MethodFunc("GET", "/crl", h.CRL)
|
||||
r.MethodFunc("GET", "/provisioners", h.Provisioners)
|
||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
|
||||
r.MethodFunc("GET", "/roots", h.Roots)
|
||||
r.MethodFunc("GET", "/roots.pem", h.RootsPEM)
|
||||
r.MethodFunc("GET", "/federation", h.Federation)
|
||||
// New creates a new RouterHandler with the CA endpoints.
|
||||
//
|
||||
// Deprecated: Use api.Route(r Router)
|
||||
func New(auth Authority) RouterHandler {
|
||||
return &caHandler{}
|
||||
}
|
||||
|
||||
func Route(r Router) {
|
||||
r.MethodFunc("GET", "/version", Version)
|
||||
r.MethodFunc("GET", "/health", Health)
|
||||
r.MethodFunc("GET", "/root/{sha}", Root)
|
||||
r.MethodFunc("POST", "/sign", Sign)
|
||||
r.MethodFunc("POST", "/renew", Renew)
|
||||
r.MethodFunc("POST", "/rekey", Rekey)
|
||||
r.MethodFunc("POST", "/revoke", Revoke)
|
||||
r.MethodFunc("GET", "/crl", CRL)
|
||||
r.MethodFunc("GET", "/provisioners", Provisioners)
|
||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey)
|
||||
r.MethodFunc("GET", "/roots", Roots)
|
||||
r.MethodFunc("GET", "/roots.pem", RootsPEM)
|
||||
r.MethodFunc("GET", "/federation", Federation)
|
||||
// SSH CA
|
||||
r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
|
||||
r.MethodFunc("POST", "/ssh/renew", h.SSHRenew)
|
||||
r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke)
|
||||
r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey)
|
||||
r.MethodFunc("GET", "/ssh/roots", h.SSHRoots)
|
||||
r.MethodFunc("GET", "/ssh/federation", h.SSHFederation)
|
||||
r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
|
||||
r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts)
|
||||
r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion)
|
||||
r.MethodFunc("POST", "/ssh/sign", SSHSign)
|
||||
r.MethodFunc("POST", "/ssh/renew", SSHRenew)
|
||||
r.MethodFunc("POST", "/ssh/revoke", SSHRevoke)
|
||||
r.MethodFunc("POST", "/ssh/rekey", SSHRekey)
|
||||
r.MethodFunc("GET", "/ssh/roots", SSHRoots)
|
||||
r.MethodFunc("GET", "/ssh/federation", SSHFederation)
|
||||
r.MethodFunc("POST", "/ssh/config", SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost)
|
||||
r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts)
|
||||
r.MethodFunc("POST", "/ssh/bastion", SSHBastion)
|
||||
|
||||
// For compatibility with old code:
|
||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||
r.MethodFunc("POST", "/sign-ssh", h.SSHSign)
|
||||
r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts)
|
||||
r.MethodFunc("POST", "/re-sign", Renew)
|
||||
r.MethodFunc("POST", "/sign-ssh", SSHSign)
|
||||
r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts)
|
||||
}
|
||||
|
||||
// Version is an HTTP handler that returns the version of the server.
|
||||
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
|
||||
v := h.Authority.Version()
|
||||
func Version(w http.ResponseWriter, r *http.Request) {
|
||||
v := mustAuthority(r.Context()).Version()
|
||||
render.JSON(w, VersionResponse{
|
||||
Version: v.Version,
|
||||
RequireClientAuthentication: v.RequireClientAuthentication,
|
||||
|
@ -294,17 +303,17 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Health is an HTTP handler that returns the status of the server.
|
||||
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
|
||||
func Health(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSON(w, HealthResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
|
||||
// certificate for the given SHA256.
|
||||
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
||||
func Root(w http.ResponseWriter, r *http.Request) {
|
||||
sha := chi.URLParam(r, "sha")
|
||||
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
|
||||
// Load root certificate with the
|
||||
cert, err := h.Authority.Root(sum)
|
||||
cert, err := mustAuthority(r.Context()).Root(sum)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||
return
|
||||
|
@ -322,18 +331,19 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
|||
}
|
||||
|
||||
// Provisioners returns the list of provisioners configured in the authority.
|
||||
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
func Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := h.Authority.GetProvisioners(cursor, limit)
|
||||
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionersResponse{
|
||||
Provisioners: p,
|
||||
NextCursor: next,
|
||||
|
@ -341,19 +351,20 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// ProvisionerKey returns the encrypted key of a provisioner by it's key id.
|
||||
func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||
func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||
kid := chi.URLParam(r, "kid")
|
||||
key, err := h.Authority.GetEncryptedKey(kid)
|
||||
key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
|
||||
if err != nil {
|
||||
render.Error(w, errs.NotFoundErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionerKeyResponse{key})
|
||||
}
|
||||
|
||||
// Roots returns all the root certificates for the CA.
|
||||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := h.Authority.GetRoots()
|
||||
func Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
|
||||
return
|
||||
|
@ -370,8 +381,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// RootsPEM returns all the root certificates for the CA in PEM format.
|
||||
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := h.Authority.GetRoots()
|
||||
func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -393,8 +404,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Federation returns all the public certificates in the federation.
|
||||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := h.Authority.GetFederation()
|
||||
func Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := mustAuthority(r.Context()).GetFederation()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||
return
|
||||
|
|
|
@ -171,10 +171,21 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
|
|||
return csr
|
||||
}
|
||||
|
||||
func mockMustAuthority(t *testing.T, a Authority) {
|
||||
t.Helper()
|
||||
fn := mustAuthority
|
||||
t.Cleanup(func() {
|
||||
mustAuthority = fn
|
||||
})
|
||||
mustAuthority = func(ctx context.Context) Authority {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
type mockAuthority struct {
|
||||
ret1, ret2 interface{}
|
||||
err error
|
||||
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
||||
authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||
getTLSOptions func() *authority.TLSOptions
|
||||
root func(shasum string) (*x509.Certificate, error)
|
||||
|
@ -207,12 +218,8 @@ func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) {
|
|||
|
||||
// TODO: remove once Authorize is deprecated.
|
||||
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return m.AuthorizeSign(ott)
|
||||
}
|
||||
|
||||
func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
||||
if m.authorizeSign != nil {
|
||||
return m.authorizeSign(ott)
|
||||
if m.authorize != nil {
|
||||
return m.authorize(ctx, ott)
|
||||
}
|
||||
return m.ret1.([]provisioner.SignOption), m.err
|
||||
}
|
||||
|
@ -793,11 +800,10 @@ func Test_caHandler_Route(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Health(t *testing.T) {
|
||||
func Test_Health(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h := New(&mockAuthority{}).(*caHandler)
|
||||
h.Health(w, req)
|
||||
Health(w, req)
|
||||
|
||||
res := w.Result()
|
||||
if res.StatusCode != 200 {
|
||||
|
@ -815,7 +821,7 @@ func Test_caHandler_Health(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Root(t *testing.T) {
|
||||
func Test_Root(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
root *x509.Certificate
|
||||
|
@ -836,9 +842,9 @@ func Test_caHandler_Root(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
|
||||
mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err})
|
||||
w := httptest.NewRecorder()
|
||||
h.Root(w, req)
|
||||
Root(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -859,7 +865,7 @@ func Test_caHandler_Root(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Sign(t *testing.T) {
|
||||
func Test_Sign(t *testing.T) {
|
||||
csr := parseCertificateRequest(csrPEM)
|
||||
valid, err := json.Marshal(SignRequest{
|
||||
CsrPEM: CertificateRequest{csr},
|
||||
|
@ -900,18 +906,18 @@ func Test_caHandler_Sign(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return tt.certAttrOpts, tt.autherr
|
||||
},
|
||||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
|
||||
w := httptest.NewRecorder()
|
||||
h.Sign(logging.NewResponseLogger(w), req)
|
||||
Sign(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -932,7 +938,7 @@ func Test_caHandler_Sign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Renew(t *testing.T) {
|
||||
func Test_Renew(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1022,7 +1028,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
|
||||
|
@ -1043,12 +1049,12 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
||||
req.TLS = tt.tls
|
||||
req.Header = tt.header
|
||||
w := httptest.NewRecorder()
|
||||
h.Renew(logging.NewResponseLogger(w), req)
|
||||
Renew(logging.NewResponseLogger(w), req)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
@ -1077,7 +1083,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Rekey(t *testing.T) {
|
||||
func Test_Rekey(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1108,16 +1114,16 @@ func Test_caHandler_Rekey(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
h.Rekey(logging.NewResponseLogger(w), req)
|
||||
Rekey(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -1138,7 +1144,7 @@ func Test_caHandler_Rekey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Provisioners(t *testing.T) {
|
||||
func Test_Provisioners(t *testing.T) {
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
}
|
||||
|
@ -1204,10 +1210,8 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &caHandler{
|
||||
Authority: tt.fields.Authority,
|
||||
}
|
||||
h.Provisioners(tt.args.w, tt.args.r)
|
||||
mockMustAuthority(t, tt.fields.Authority)
|
||||
Provisioners(tt.args.w, tt.args.r)
|
||||
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
res := rec.Result()
|
||||
|
@ -1242,7 +1246,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_ProvisionerKey(t *testing.T) {
|
||||
func Test_ProvisionerKey(t *testing.T) {
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
}
|
||||
|
@ -1274,10 +1278,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &caHandler{
|
||||
Authority: tt.fields.Authority,
|
||||
}
|
||||
h.ProvisionerKey(tt.args.w, tt.args.r)
|
||||
mockMustAuthority(t, tt.fields.Authority)
|
||||
ProvisionerKey(tt.args.w, tt.args.r)
|
||||
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
res := rec.Result()
|
||||
|
@ -1302,7 +1304,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Roots(t *testing.T) {
|
||||
func Test_Roots(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1323,11 +1325,11 @@ func Test_caHandler_Roots(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
|
||||
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
h.Roots(w, req)
|
||||
Roots(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -1364,10 +1366,10 @@ func Test_caHandler_RootsPEM(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler)
|
||||
mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err})
|
||||
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.RootsPEM(w, req)
|
||||
RootsPEM(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -1388,7 +1390,7 @@ func Test_caHandler_RootsPEM(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_Federation(t *testing.T) {
|
||||
func Test_Federation(t *testing.T) {
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
@ -1409,11 +1411,11 @@ func Test_caHandler_Federation(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
|
||||
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
h.Federation(w, req)
|
||||
Federation(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
|
|
@ -3,16 +3,20 @@ package read
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// JSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
// pointed to by v.
|
||||
func JSON(r io.Reader, v interface{}) error {
|
||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
||||
return errs.BadRequestErr(err, "error decoding json")
|
||||
|
@ -21,11 +25,42 @@ func JSON(r io.Reader, v interface{}) error {
|
|||
}
|
||||
|
||||
// ProtoJSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
// pointed to by m.
|
||||
func ProtoJSON(r io.Reader, m proto.Message) error {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return errs.BadRequestErr(err, "error reading request body")
|
||||
}
|
||||
return protojson.Unmarshal(data, m)
|
||||
|
||||
switch err := protojson.Unmarshal(data, m); {
|
||||
case errors.Is(err, proto.Error):
|
||||
return badProtoJSONError(err.Error())
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// badProtoJSONError is an error type that is returned by ProtoJSON
|
||||
// when a proto message cannot be unmarshaled. Usually this is caused
|
||||
// by an error in the request body.
|
||||
type badProtoJSONError string
|
||||
|
||||
// Error implements error for badProtoJSONError
|
||||
func (e badProtoJSONError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// Render implements render.RenderableError for badProtoJSONError
|
||||
func (e badProtoJSONError) Render(w http.ResponseWriter) {
|
||||
v := struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Type: "badRequest",
|
||||
Detail: "bad request",
|
||||
// trim the proto prefix for the message
|
||||
Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")),
|
||||
}
|
||||
render.JSONStatus(w, v, http.StatusBadRequest)
|
||||
}
|
||||
|
|
|
@ -1,10 +1,21 @@
|
|||
package read
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
@ -44,3 +55,110 @@ func TestJSON(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtoJSON(t *testing.T) {
|
||||
|
||||
p := new(linkedca.Policy) // TODO(hs): can we use something different, so we don't need the import?
|
||||
|
||||
type args struct {
|
||||
r io.Reader
|
||||
m proto.Message
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "fail/io.ReadAll",
|
||||
args: args{
|
||||
r: iotest.ErrReader(errors.New("read error")),
|
||||
m: p,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail/proto",
|
||||
args: args{
|
||||
r: strings.NewReader(`{?}`),
|
||||
m: p,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
r: strings.NewReader(`{"x509":{}}`),
|
||||
m: p,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ProtoJSON(tt.args.r, tt.args.m)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
switch err.(type) {
|
||||
case badProtoJSONError:
|
||||
assert.Contains(t, err.Error(), "syntax error")
|
||||
case *errs.Error:
|
||||
var ee *errs.Error
|
||||
if errors.As(err, &ee) {
|
||||
assert.Equal(t, http.StatusBadRequest, ee.Status)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m))
|
||||
assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_badProtoJSONError_Render(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e badProtoJSONError
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "bad proto normal space",
|
||||
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
|
||||
expected: "syntax error (line 1:2): invalid value ?",
|
||||
},
|
||||
{
|
||||
name: "bad proto non breaking space",
|
||||
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
|
||||
expected: "syntax error (line 1:2): invalid value ?",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tt.e.Render(w)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
v := struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Message string `json:"message"`
|
||||
}{}
|
||||
|
||||
assert.NoError(t, json.Unmarshal(data, &v))
|
||||
assert.Equal(t, "badRequest", v.Type)
|
||||
assert.Equal(t, "bad request", v.Detail)
|
||||
assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error {
|
|||
}
|
||||
|
||||
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
|
||||
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
func Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
render.Error(w, errs.BadRequest("missing client certificate"))
|
||||
return
|
||||
|
@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||
a := mustAuthority(r.Context())
|
||||
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
||||
return
|
||||
|
@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
|||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
|
14
api/renew.go
14
api/renew.go
|
@ -16,14 +16,15 @@ const (
|
|||
|
||||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
// new one.
|
||||
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||
cert, err := h.getPeerCertificate(r)
|
||||
func Renew(w http.ResponseWriter, r *http.Request) {
|
||||
cert, err := getPeerCertificate(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Renew(cert)
|
||||
a := mustAuthority(r.Context())
|
||||
certChain, err := a.Renew(cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
return
|
||||
|
@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
|||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
||||
func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
return r.TLS.PeerCertificates[0], nil
|
||||
}
|
||||
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
|
||||
ctx := r.Context()
|
||||
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||
}
|
||||
}
|
||||
return nil, errs.BadRequest("missing client certificate")
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) {
|
|||
// NOTE: currently only Passive revocation is supported.
|
||||
//
|
||||
// TODO: Add CRL and OCSP support.
|
||||
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body RevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
PassiveOnly: body.Passive,
|
||||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod)
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
// A token indicates that we are using the api via a provisioner token,
|
||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
if len(body.OTT) > 0 {
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
opts.MTLS = true
|
||||
}
|
||||
|
||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -108,7 +108,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusOK,
|
||||
auth: &mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
|
@ -152,7 +152,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
statusCode: http.StatusOK,
|
||||
tls: cs,
|
||||
auth: &mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, ri *authority.RevokeOptions) error {
|
||||
|
@ -187,7 +187,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusInternalServerError,
|
||||
auth: &mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
|
@ -209,7 +209,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusForbidden,
|
||||
auth: &mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
|
@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
for name, _tc := range tests {
|
||||
tc := _tc(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*caHandler)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
|
||||
if tc.tls != nil {
|
||||
req.TLS = tc.tls
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
h.Revoke(logging.NewResponseLogger(w), req)
|
||||
Revoke(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
12
api/sign.go
12
api/sign.go
|
@ -49,7 +49,7 @@ type SignResponse struct {
|
|||
// Sign is an HTTP handler that reads a certificate request and an
|
||||
// one-time-token (ott) from the body and creates a new certificate with the
|
||||
// information in the certificate request.
|
||||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||
func Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -68,13 +68,17 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
|||
TemplateData: body.TemplateData,
|
||||
}
|
||||
|
||||
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
||||
ctx := r.Context()
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
|
||||
return
|
||||
|
@ -89,6 +93,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
|||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
TLSOptions: a.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
|
45
api/ssh.go
45
api/ssh.go
|
@ -250,7 +250,7 @@ type SSHBastionResponse struct {
|
|||
// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHSignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -288,13 +288,16 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||
cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
return
|
||||
|
@ -302,7 +305,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var addUserCertificate *SSHCertificate
|
||||
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
||||
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||
addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
return
|
||||
|
@ -315,7 +318,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
if cr := body.IdentityCSR.CertificateRequest; cr != nil {
|
||||
ctx := authority.NewContextWithSkipTokenReuse(r.Context())
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
|
@ -327,7 +330,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
|
||||
})
|
||||
|
||||
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
||||
certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||
return
|
||||
|
@ -344,8 +347,9 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
|
||||
// certificates.
|
||||
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHRoots(r.Context())
|
||||
func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -369,8 +373,9 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
||||
// for user and host certificates.
|
||||
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := h.Authority.GetSSHFederation(r.Context())
|
||||
func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -394,7 +399,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
|
||||
// and servers.
|
||||
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHConfigRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -405,7 +410,8 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
|
||||
ctx := r.Context()
|
||||
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -426,7 +432,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHCheckPrincipalRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -437,7 +443,8 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
|
||||
ctx := r.Context()
|
||||
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -448,13 +455,14 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
||||
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
var cert *x509.Certificate
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
cert = r.TLS.PeerCertificates[0]
|
||||
}
|
||||
|
||||
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
|
||||
ctx := r.Context()
|
||||
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -465,7 +473,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// SSHBastion provides returns the bastion configured if any.
|
||||
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHBastionRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -476,7 +484,8 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
|
||||
ctx := r.Context()
|
||||
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
|
|
@ -39,7 +39,7 @@ type SSHRekeyResponse struct {
|
|||
// SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRekeyRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -59,7 +59,10 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
|
||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
|
@ -70,7 +73,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||
newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||
return
|
||||
|
@ -80,7 +83,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
|
|
|
@ -37,7 +37,7 @@ type SSHRenewResponse struct {
|
|||
// SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token
|
||||
// (ott) from the body and creates a new SSH certificate with the information in
|
||||
// the request.
|
||||
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRenewRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -51,7 +51,10 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
|
||||
_, err := h.Authority.Authorize(ctx, body.OTT)
|
||||
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
_, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
|
@ -62,7 +65,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
|
||||
newCert, err := a.RenewSSH(ctx, oldCert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||
return
|
||||
|
@ -72,7 +75,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||
|
||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
|
@ -85,7 +88,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
|
||||
func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
||||
func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -105,7 +108,7 @@ func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfte
|
|||
cert.NotAfter = notAfter
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Renew(cert)
|
||||
certChain, err := mustAuthority(r.Context()).Renew(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
|||
// Revoke supports handful of different methods that revoke a Certificate.
|
||||
//
|
||||
// NOTE: currently only Passive revocation is supported.
|
||||
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
|
@ -68,16 +68,19 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
// A token indicates that we are using the api via a provisioner token,
|
||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
|
||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHSign(t *testing.T) {
|
||||
func Test_SSHSign(t *testing.T) {
|
||||
user, err := getSignedUserCertificate()
|
||||
assert.FatalError(t, err)
|
||||
host, err := getSignedHostCertificate()
|
||||
|
@ -315,8 +315,8 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return []provisioner.SignOption{}, tt.authErr
|
||||
},
|
||||
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
|
@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
|||
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
return tt.tlsSignCerts, tt.tlsSignErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHSign(logging.NewResponseLogger(w), req)
|
||||
SSHSign(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHRoots(t *testing.T) {
|
||||
func Test_SSHRoots(t *testing.T) {
|
||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||
assert.FatalError(t, err)
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
|
@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHRoots(logging.NewResponseLogger(w), req)
|
||||
SSHRoots(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHFederation(t *testing.T) {
|
||||
func Test_SSHFederation(t *testing.T) {
|
||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||
assert.FatalError(t, err)
|
||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||
|
@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHFederation(logging.NewResponseLogger(w), req)
|
||||
SSHFederation(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHConfig(t *testing.T) {
|
||||
func Test_SSHConfig(t *testing.T) {
|
||||
userOutput := []templates.Output{
|
||||
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
|
||||
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
|
||||
|
@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||
return tt.output, tt.err
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHConfig(logging.NewResponseLogger(w), req)
|
||||
SSHConfig(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHCheckHost(t *testing.T) {
|
||||
func Test_SSHCheckHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req string
|
||||
|
@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
|
||||
return tt.exists, tt.err
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||
SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHGetHosts(t *testing.T) {
|
||||
func Test_SSHGetHosts(t *testing.T) {
|
||||
hosts := []authority.Host{
|
||||
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
|
||||
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
|
||||
|
@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
|
||||
return tt.hosts, tt.err
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHGetHosts(logging.NewResponseLogger(w), req)
|
||||
SSHGetHosts(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_caHandler_SSHBastion(t *testing.T) {
|
||||
func Test_SSHBastion(t *testing.T) {
|
||||
bastion := &authority.Bastion{
|
||||
Hostname: "bastion.local",
|
||||
}
|
||||
|
@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
mockMustAuthority(t, &mockAuthority{
|
||||
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||
return tt.bastion, tt.bastionErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
|
||||
w := httptest.NewRecorder()
|
||||
h.SSHBastion(logging.NewResponseLogger(w), req)
|
||||
SSHBastion(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
|
|
|
@ -1,22 +1,15 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
const (
|
||||
// provisionerContextKey provisioner key
|
||||
provisionerContextKey = ContextKey("provisioner")
|
||||
)
|
||||
|
||||
// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests
|
||||
|
@ -40,78 +33,121 @@ type GetExternalAccountKeysResponse struct {
|
|||
|
||||
// requireEABEnabled is a middleware that ensures ACME EAB is enabled
|
||||
// before serving requests that act on ACME EAB credentials.
|
||||
func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP {
|
||||
func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
provName := chi.URLParam(r, "provisionerName")
|
||||
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
|
||||
acmeProvisioner := prov.GetDetails().GetACME()
|
||||
if acmeProvisioner == nil {
|
||||
render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName()))
|
||||
return
|
||||
}
|
||||
if !eabEnabled {
|
||||
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
|
||||
|
||||
if !acmeProvisioner.RequireEab {
|
||||
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName()))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
next(w, r.WithContext(ctx))
|
||||
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME
|
||||
// provisioner is set to true and thus has EAB enabled.
|
||||
func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) {
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil {
|
||||
return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName)
|
||||
}
|
||||
|
||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID())
|
||||
}
|
||||
|
||||
details := prov.GetDetails()
|
||||
if details == nil {
|
||||
return false, nil, admin.NewErrorISE("error getting details for provisioner with ID: %s", p.GetID())
|
||||
}
|
||||
|
||||
acmeProvisioner := details.GetACME()
|
||||
if acmeProvisioner == nil {
|
||||
return false, nil, admin.NewErrorISE("error getting ACME details for provisioner with ID: %s", p.GetID())
|
||||
}
|
||||
|
||||
return acmeProvisioner.GetRequireEab(), prov, nil
|
||||
}
|
||||
|
||||
type acmeAdminResponderInterface interface {
|
||||
// ACMEAdminResponder is responsible for writing ACME admin responses
|
||||
type ACMEAdminResponder interface {
|
||||
GetExternalAccountKeys(w http.ResponseWriter, r *http.Request)
|
||||
CreateExternalAccountKey(w http.ResponseWriter, r *http.Request)
|
||||
DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// ACMEAdminResponder is responsible for writing ACME admin responses
|
||||
type ACMEAdminResponder struct{}
|
||||
// acmeAdminResponder implements ACMEAdminResponder.
|
||||
type acmeAdminResponder struct{}
|
||||
|
||||
// NewACMEAdminResponder returns a new ACMEAdminResponder
|
||||
func NewACMEAdminResponder() *ACMEAdminResponder {
|
||||
return &ACMEAdminResponder{}
|
||||
func NewACMEAdminResponder() ACMEAdminResponder {
|
||||
return &acmeAdminResponder{}
|
||||
}
|
||||
|
||||
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
|
||||
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
|
||||
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
|
||||
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey {
|
||||
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
eak := &linkedca.EABKey{
|
||||
Id: k.ID,
|
||||
HmacKey: k.HmacKey,
|
||||
Provisioner: k.ProvisionerID,
|
||||
Reference: k.Reference,
|
||||
Account: k.AccountID,
|
||||
CreatedAt: timestamppb.New(k.CreatedAt),
|
||||
BoundAt: timestamppb.New(k.BoundAt),
|
||||
}
|
||||
|
||||
if k.Policy != nil {
|
||||
eak.Policy = &linkedca.Policy{
|
||||
X509: &linkedca.X509Policy{
|
||||
Allow: &linkedca.X509Names{},
|
||||
Deny: &linkedca.X509Names{},
|
||||
},
|
||||
}
|
||||
eak.Policy.X509.Allow.Dns = k.Policy.X509.Allowed.DNSNames
|
||||
eak.Policy.X509.Allow.Ips = k.Policy.X509.Allowed.IPRanges
|
||||
eak.Policy.X509.Deny.Dns = k.Policy.X509.Denied.DNSNames
|
||||
eak.Policy.X509.Deny.Ips = k.Policy.X509.Denied.IPRanges
|
||||
eak.Policy.X509.AllowWildcardNames = k.Policy.X509.AllowWildcardNames
|
||||
}
|
||||
|
||||
return eak
|
||||
}
|
||||
|
||||
func linkedEAKToCertificates(k *linkedca.EABKey) *acme.ExternalAccountKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
eak := &acme.ExternalAccountKey{
|
||||
ID: k.Id,
|
||||
ProvisionerID: k.Provisioner,
|
||||
Reference: k.Reference,
|
||||
AccountID: k.Account,
|
||||
HmacKey: k.HmacKey,
|
||||
CreatedAt: k.CreatedAt.AsTime(),
|
||||
BoundAt: k.BoundAt.AsTime(),
|
||||
}
|
||||
|
||||
if policy := k.GetPolicy(); policy != nil {
|
||||
eak.Policy = &acme.Policy{}
|
||||
if x509 := policy.GetX509(); x509 != nil {
|
||||
eak.Policy.X509 = acme.X509Policy{}
|
||||
if allow := x509.GetAllow(); allow != nil {
|
||||
eak.Policy.X509.Allowed = acme.PolicyNames{}
|
||||
eak.Policy.X509.Allowed.DNSNames = allow.Dns
|
||||
eak.Policy.X509.Allowed.IPRanges = allow.Ips
|
||||
}
|
||||
if deny := x509.GetDeny(); deny != nil {
|
||||
eak.Policy.X509.Denied = acme.PolicyNames{}
|
||||
eak.Policy.X509.Denied.DNSNames = deny.Dns
|
||||
eak.Policy.X509.Denied.IPRanges = deny.Ips
|
||||
}
|
||||
eak.Policy.X509.AllowWildcardNames = x509.AllowWildcardNames
|
||||
}
|
||||
}
|
||||
|
||||
return eak
|
||||
}
|
||||
|
|
|
@ -4,20 +4,24 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
)
|
||||
|
||||
func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
||||
|
@ -29,109 +33,90 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
|||
return protojson.Unmarshal(data, m)
|
||||
}
|
||||
|
||||
func mockMustAuthority(t *testing.T, a adminAuthority) {
|
||||
t.Helper()
|
||||
fn := mustAuthority
|
||||
t.Cleanup(func() {
|
||||
mustAuthority = fn
|
||||
})
|
||||
mustAuthority = func(ctx context.Context) adminAuthority {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_requireEABEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
adminDB admin.DB
|
||||
auth adminAuthority
|
||||
next nextHTTP
|
||||
next http.HandlerFunc
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/h.provisionerHasEABEnabled": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
"fail/prov.GetDetails": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
}
|
||||
err := admin.NewErrorISE("error loading provisioner provName: force")
|
||||
err.Message = "error loading provisioner provName: force"
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
|
||||
err.Message = "error getting ACME details for provisioner 'provName'"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
err: err,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"fail/prov.GetDetails.GetACME": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
|
||||
err.Message = "error getting ACME details for provisioner 'provName'"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
err: err,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"ok/eab-disabled": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: false,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName")
|
||||
err.Message = "ACME EAB not enabled for provisioner provName"
|
||||
err.Message = "ACME EAB not enabled for provisioner 'provName'"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
err: err,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"ok/eab-enabled": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: true,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
ctx: ctx,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
|
@ -143,16 +128,9 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
acmeDB: nil,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.requireEABEnabled(tc.next)(w, req)
|
||||
requireEABEnabled(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -176,216 +154,6 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHandler_provisionerHasEABEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
adminDB admin.DB
|
||||
auth adminAuthority
|
||||
provisionerName string
|
||||
want bool
|
||||
err *admin.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetProvisioner": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"fail/prov.GetDetails": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: nil,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"fail/details.GetACME": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: nil,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"ok/eab-disabled": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "eab-disabled", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "eab-disabled",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
adminDB: db,
|
||||
auth: auth,
|
||||
provisionerName: "eab-disabled",
|
||||
want: false,
|
||||
}
|
||||
},
|
||||
"ok/eab-enabled": func(t *testing.T) test {
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "eab-enabled", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "eab-enabled",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
adminDB: db,
|
||||
auth: auth,
|
||||
provisionerName: "eab-enabled",
|
||||
want: true,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
acmeDB: nil,
|
||||
}
|
||||
got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName)
|
||||
if (err != nil) != (tc.err != nil) {
|
||||
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
|
||||
return
|
||||
}
|
||||
if tc.err != nil {
|
||||
assert.Type(t, &linkedca.Provisioner{}, prov)
|
||||
assert.Type(t, &admin.Error{}, err)
|
||||
adminError, _ := err.(*admin.Error)
|
||||
assert.Equals(t, tc.err.Type, adminError.Type)
|
||||
assert.Equals(t, tc.err.Status, adminError.Status)
|
||||
assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode())
|
||||
assert.Equals(t, tc.err.Message, adminError.Message)
|
||||
assert.Equals(t, tc.err.Detail, adminError.Detail)
|
||||
return
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
Reference string
|
||||
|
@ -585,3 +353,206 @@ func TestHandler_GetExternalAccountKeys(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_eakToLinked(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
k *acme.ExternalAccountKey
|
||||
want *linkedca.EABKey
|
||||
}{
|
||||
{
|
||||
name: "no-key",
|
||||
k: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no-policy",
|
||||
k: &acme.ExternalAccountKey{
|
||||
ID: "keyID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
AccountID: "accID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
|
||||
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
|
||||
Policy: nil,
|
||||
},
|
||||
want: &linkedca.EABKey{
|
||||
Id: "keyID",
|
||||
Provisioner: "provID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
Reference: "ref",
|
||||
Account: "accID",
|
||||
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
|
||||
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
|
||||
Policy: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with-policy",
|
||||
k: &acme.ExternalAccountKey{
|
||||
ID: "keyID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
AccountID: "accID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
|
||||
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
|
||||
Policy: &acme.Policy{
|
||||
X509: acme.X509Policy{
|
||||
Allowed: acme.PolicyNames{
|
||||
DNSNames: []string{"*.local"},
|
||||
IPRanges: []string{"10.0.0.0/24"},
|
||||
},
|
||||
Denied: acme.PolicyNames{
|
||||
DNSNames: []string{"badhost.local"},
|
||||
IPRanges: []string{"10.0.0.30"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &linkedca.EABKey{
|
||||
Id: "keyID",
|
||||
Provisioner: "provID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
Reference: "ref",
|
||||
Account: "accID",
|
||||
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
|
||||
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
|
||||
Policy: &linkedca.Policy{
|
||||
X509: &linkedca.X509Policy{
|
||||
Allow: &linkedca.X509Names{
|
||||
Dns: []string{"*.local"},
|
||||
Ips: []string{"10.0.0.0/24"},
|
||||
},
|
||||
Deny: &linkedca.X509Names{
|
||||
Dns: []string{"badhost.local"},
|
||||
Ips: []string{"10.0.0.30"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := eakToLinked(tt.k); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("eakToLinked() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_linkedEAKToCertificates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
k *linkedca.EABKey
|
||||
want *acme.ExternalAccountKey
|
||||
}{
|
||||
{
|
||||
name: "no-key",
|
||||
k: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no-policy",
|
||||
k: &linkedca.EABKey{
|
||||
Id: "keyID",
|
||||
Provisioner: "provID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
Reference: "ref",
|
||||
Account: "accID",
|
||||
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
|
||||
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
|
||||
Policy: nil,
|
||||
},
|
||||
want: &acme.ExternalAccountKey{
|
||||
ID: "keyID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
AccountID: "accID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
|
||||
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
|
||||
Policy: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no-x509-policy",
|
||||
k: &linkedca.EABKey{
|
||||
Id: "keyID",
|
||||
Provisioner: "provID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
Reference: "ref",
|
||||
Account: "accID",
|
||||
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
|
||||
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
|
||||
Policy: &linkedca.Policy{},
|
||||
},
|
||||
want: &acme.ExternalAccountKey{
|
||||
ID: "keyID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
AccountID: "accID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
|
||||
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
|
||||
Policy: &acme.Policy{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with-x509-policy",
|
||||
k: &linkedca.EABKey{
|
||||
Id: "keyID",
|
||||
Provisioner: "provID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
Reference: "ref",
|
||||
Account: "accID",
|
||||
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
|
||||
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
|
||||
Policy: &linkedca.Policy{
|
||||
X509: &linkedca.X509Policy{
|
||||
Allow: &linkedca.X509Names{
|
||||
Dns: []string{"*.local"},
|
||||
Ips: []string{"10.0.0.0/24"},
|
||||
},
|
||||
Deny: &linkedca.X509Names{
|
||||
Dns: []string{"badhost.local"},
|
||||
Ips: []string{"10.0.0.30"},
|
||||
},
|
||||
AllowWildcardNames: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &acme.ExternalAccountKey{
|
||||
ID: "keyID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
AccountID: "accID",
|
||||
HmacKey: []byte{1, 3, 3, 7},
|
||||
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
|
||||
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
|
||||
Policy: &acme.Policy{
|
||||
X509: acme.X509Policy{
|
||||
Allowed: acme.PolicyNames{
|
||||
DNSNames: []string{"*.local"},
|
||||
IPRanges: []string{"10.0.0.0/24"},
|
||||
},
|
||||
Denied: acme.PolicyNames{
|
||||
DNSNames: []string{"badhost.local"},
|
||||
IPRanges: []string{"10.0.0.30"},
|
||||
},
|
||||
AllowWildcardNames: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := linkedEAKToCertificates(tt.k); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("linkedEAKToCertificates() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,6 +29,10 @@ type adminAuthority interface {
|
|||
LoadProvisionerByID(id string) (provisioner.Interface, error)
|
||||
UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error
|
||||
RemoveProvisioner(ctx context.Context, id string) error
|
||||
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
|
||||
CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
RemoveAuthorityPolicy(ctx context.Context) error
|
||||
}
|
||||
|
||||
// CreateAdminRequest represents the body for a CreateAdmin request.
|
||||
|
@ -81,10 +85,10 @@ type DeleteResponse struct {
|
|||
}
|
||||
|
||||
// GetAdmin returns the requested admin, or an error.
|
||||
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
adm, ok := h.auth.LoadAdminByID(id)
|
||||
adm, ok := mustAuthority(r.Context()).LoadAdminByID(id)
|
||||
if !ok {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType,
|
||||
"admin %s not found", id))
|
||||
|
@ -94,7 +98,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// GetAdmins returns a segment of admins associated with the authority.
|
||||
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||
func GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
|
@ -102,7 +106,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
|
||||
admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
||||
return
|
||||
|
@ -114,7 +118,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// CreateAdmin creates a new admin.
|
||||
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body CreateAdminRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
|
@ -126,7 +130,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
|
||||
auth := mustAuthority(r.Context())
|
||||
p, err := auth.LoadProvisionerByName(body.Provisioner)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
||||
return
|
||||
|
@ -137,7 +142,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
Type: body.Type,
|
||||
}
|
||||
// Store to authority collection.
|
||||
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||
if err := auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
|
||||
return
|
||||
}
|
||||
|
@ -146,10 +151,10 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// DeleteAdmin deletes admin.
|
||||
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
|
||||
if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
||||
return
|
||||
}
|
||||
|
@ -158,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// UpdateAdmin updates an existing admin.
|
||||
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body UpdateAdminRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
|
@ -171,8 +176,8 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||
auth := mustAuthority(r.Context())
|
||||
adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
|
||||
return
|
||||
|
|
|
@ -14,11 +14,13 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type mockAdminAuthority struct {
|
||||
|
@ -37,6 +39,11 @@ type mockAdminAuthority struct {
|
|||
MockLoadProvisionerByID func(id string) (provisioner.Interface, error)
|
||||
MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error
|
||||
MockRemoveProvisioner func(ctx context.Context, id string) error
|
||||
|
||||
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
|
||||
MockCreateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
MockUpdateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
MockRemoveAuthorityPolicy func(ctx context.Context) error
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) IsAdminAPIEnabled() bool {
|
||||
|
@ -130,6 +137,34 @@ func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) e
|
|||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
if m.MockGetAuthorityPolicy != nil {
|
||||
return m.MockGetAuthorityPolicy(ctx)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Policy), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
|
||||
if m.MockCreateAuthorityPolicy != nil {
|
||||
return m.MockCreateAuthorityPolicy(ctx, adm, policy)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Policy), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
|
||||
if m.MockUpdateAuthorityPolicy != nil {
|
||||
return m.MockUpdateAuthorityPolicy(ctx, adm, policy)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Policy), m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error {
|
||||
if m.MockRemoveAuthorityPolicy != nil {
|
||||
return m.MockRemoveAuthorityPolicy(ctx)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func TestCreateAdminRequest_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
Subject string
|
||||
|
@ -317,14 +352,11 @@ func TestHandler_GetAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAdmin(w, req)
|
||||
GetAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -456,13 +488,10 @@ func TestHandler_GetAdmins(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAdmins(w, req)
|
||||
GetAdmins(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -640,13 +669,11 @@ func TestHandler_CreateAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateAdmin(w, req)
|
||||
CreateAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -732,13 +759,11 @@ func TestHandler_DeleteAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.DeleteAdmin(w, req)
|
||||
DeleteAdmin(w, req)
|
||||
res := w.Result()
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
|
@ -877,13 +902,11 @@ func TestHandler_UpdateAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateAdmin(w, req)
|
||||
UpdateAdmin(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
|
@ -1,56 +1,117 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
)
|
||||
|
||||
// Handler is the Admin API request handler.
|
||||
type Handler struct {
|
||||
adminDB admin.DB
|
||||
auth adminAuthority
|
||||
acmeDB acme.DB
|
||||
acmeResponder acmeAdminResponderInterface
|
||||
}
|
||||
|
||||
// NewHandler returns a new Authority Config Handler.
|
||||
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler {
|
||||
return &Handler{
|
||||
auth: auth,
|
||||
adminDB: adminDB,
|
||||
acmeDB: acmeDB,
|
||||
acmeResponder: acmeResponder,
|
||||
}
|
||||
acmeResponder ACMEAdminResponder
|
||||
policyResponder PolicyAdminResponder
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
//
|
||||
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
authnz := func(next nextHTTP) nextHTTP {
|
||||
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next))
|
||||
Route(r, h.acmeResponder, h.policyResponder)
|
||||
}
|
||||
|
||||
// NewHandler returns a new Authority Config Handler.
|
||||
//
|
||||
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
|
||||
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) api.RouterHandler {
|
||||
return &Handler{
|
||||
acmeResponder: acmeResponder,
|
||||
policyResponder: policyResponder,
|
||||
}
|
||||
}
|
||||
|
||||
var mustAuthority = func(ctx context.Context) adminAuthority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) {
|
||||
authnz := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return extractAuthorizeTokenAdmin(requireAPIEnabled(next))
|
||||
}
|
||||
|
||||
requireEABEnabled := func(next nextHTTP) nextHTTP {
|
||||
return h.requireEABEnabled(next)
|
||||
enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return checkAction(next, true)
|
||||
}
|
||||
|
||||
disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return checkAction(next, false)
|
||||
}
|
||||
|
||||
acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(loadProvisionerByName(requireEABEnabled(next)))
|
||||
}
|
||||
|
||||
authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(enabledInStandalone(next))
|
||||
}
|
||||
|
||||
provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(disabledInStandalone(loadProvisionerByName(next)))
|
||||
}
|
||||
|
||||
acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next)))))
|
||||
}
|
||||
|
||||
// Provisioners
|
||||
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners))
|
||||
r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner))
|
||||
r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner))
|
||||
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners))
|
||||
r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner))
|
||||
r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner))
|
||||
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner))
|
||||
|
||||
// Admins
|
||||
r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin))
|
||||
r.MethodFunc("GET", "/admins", authnz(h.GetAdmins))
|
||||
r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin))
|
||||
r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin))
|
||||
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin))
|
||||
r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin))
|
||||
r.MethodFunc("GET", "/admins", authnz(GetAdmins))
|
||||
r.MethodFunc("POST", "/admins", authnz(CreateAdmin))
|
||||
r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin))
|
||||
r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin))
|
||||
|
||||
// ACME External Account Binding Keys
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys)))
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys)))
|
||||
r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey)))
|
||||
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey)))
|
||||
// ACME responder
|
||||
if acmeResponder != nil {
|
||||
// ACME External Account Binding Keys
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
|
||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
|
||||
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey))
|
||||
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey))
|
||||
}
|
||||
|
||||
// Policy responder
|
||||
if policyResponder != nil {
|
||||
// Policy - Authority
|
||||
r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy))
|
||||
r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy))
|
||||
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy))
|
||||
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy))
|
||||
|
||||
// Policy - Provisioner
|
||||
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy))
|
||||
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy))
|
||||
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy))
|
||||
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy))
|
||||
|
||||
// Policy - ACME Account
|
||||
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
|
||||
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
|
||||
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
|
||||
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
|
||||
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
|
||||
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
|
||||
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
|
||||
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,22 +1,26 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||
|
||||
// requireAPIEnabled is a middleware that ensures the Administration API
|
||||
// is enabled before servicing requests.
|
||||
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
|
||||
func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.auth.IsAdminAPIEnabled() {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||
"administration API not enabled"))
|
||||
if !mustAuthority(r.Context()).IsAdminAPIEnabled() {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
|
@ -24,8 +28,9 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
|
||||
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
||||
func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
tok := r.Header.Get("Authorization")
|
||||
if tok == "" {
|
||||
render.Error(w, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
|
@ -33,22 +38,111 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
|||
return
|
||||
}
|
||||
|
||||
adm, err := h.auth.AuthorizeAdminToken(r, tok)
|
||||
ctx := r.Context()
|
||||
adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), adminContextKey, adm)
|
||||
ctx = linkedca.NewContextWithAdmin(ctx, adm)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// ContextKey is the key type for storing and searching for ACME request
|
||||
// essentials in the context of a request.
|
||||
type ContextKey string
|
||||
// loadProvisionerByName is a middleware that searches for a provisioner
|
||||
// by name and stores it in the context.
|
||||
func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
|
||||
const (
|
||||
// adminContextKey account key
|
||||
adminContextKey = ContextKey("admin")
|
||||
)
|
||||
ctx := r.Context()
|
||||
auth := mustAuthority(ctx)
|
||||
adminDB := admin.MustFromContext(ctx)
|
||||
name := chi.URLParam(r, "provisionerName")
|
||||
|
||||
// TODO(hs): distinguish 404 vs. 500
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// checkAction checks if an action is supported in standalone or not
|
||||
func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// actions allowed in standalone mode are always supported
|
||||
if supportedInStandalone {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// when an action is not supported in standalone mode and when
|
||||
// using a nosql.DB backend, actions are not supported
|
||||
if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||
"operation not supported in standalone mode"))
|
||||
return
|
||||
}
|
||||
|
||||
// continue to next http handler
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// loadExternalAccountKey is a middleware that searches for an ACME
|
||||
// External Account Key by reference or keyID and stores it in the context.
|
||||
func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
reference := chi.URLParam(r, "reference")
|
||||
keyID := chi.URLParam(r, "keyID")
|
||||
|
||||
var (
|
||||
eak *acme.ExternalAccountKey
|
||||
err error
|
||||
)
|
||||
|
||||
if keyID != "" {
|
||||
eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID)
|
||||
} else {
|
||||
eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, acme.ErrNotFound) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
|
||||
return
|
||||
}
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving ACME External Account Key"))
|
||||
return
|
||||
}
|
||||
|
||||
if eak == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
|
||||
return
|
||||
}
|
||||
|
||||
linkedEAK := eakToLinked(eak)
|
||||
|
||||
ctx = linkedca.NewContextWithExternalAccountKey(ctx, linkedEAK)
|
||||
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,25 +4,32 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
func TestHandler_requireAPIEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
auth adminAuthority
|
||||
next nextHTTP
|
||||
next http.HandlerFunc
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
|
@ -64,13 +71,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.requireAPIEnabled(tc.next)(w, req)
|
||||
requireAPIEnabled(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -102,7 +107,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
ctx context.Context
|
||||
auth adminAuthority
|
||||
req *http.Request
|
||||
next nextHTTP
|
||||
next http.HandlerFunc
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
|
@ -152,7 +157,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
req.Header["Authorization"] = []string{"token"}
|
||||
createdAt := time.Now()
|
||||
var deletedAt time.Time
|
||||
admin := &linkedca.Admin{
|
||||
adm := &linkedca.Admin{
|
||||
Id: "adminID",
|
||||
AuthorityId: "authorityID",
|
||||
Subject: "admin",
|
||||
|
@ -164,20 +169,15 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
auth := &mockAdminAuthority{
|
||||
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
|
||||
assert.Equals(t, "token", token)
|
||||
return admin, nil
|
||||
return adm, nil
|
||||
},
|
||||
}
|
||||
next := func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin
|
||||
adm, ok := a.(*linkedca.Admin)
|
||||
if !ok {
|
||||
t.Errorf("expected *linkedca.Admin; got %T", a)
|
||||
return
|
||||
}
|
||||
adm := linkedca.MustAdminFromContext(ctx) // verifying that the context now has a linkedca.Admin
|
||||
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
|
||||
if !cmp.Equal(admin, adm, opts...) {
|
||||
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...))
|
||||
if !cmp.Equal(adm, adm, opts...) {
|
||||
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...))
|
||||
}
|
||||
w.Write(nil) // mock response with status 200
|
||||
}
|
||||
|
@ -194,13 +194,459 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||
extractAuthorizeTokenAdmin(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
err := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||
|
||||
assert.Equals(t, tc.err.Type, err.Type)
|
||||
assert.Equals(t, tc.err.Message, err.Message)
|
||||
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_loadProvisionerByName(t *testing.T) {
|
||||
type test struct {
|
||||
adminDB admin.DB
|
||||
auth adminAuthority
|
||||
ctx context.Context
|
||||
next http.HandlerFunc
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
err := admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName")
|
||||
err.Message = "error loading provisioner provName: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: err,
|
||||
}
|
||||
},
|
||||
"fail/db.GetProvisioner": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
err := admin.WrapErrorISE(errors.New("force"), "error retrieving provisioner provName")
|
||||
err.Message = "error retrieving provisioner provName: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
statusCode: 500,
|
||||
err: err,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &mockAdminAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
adminDB: db,
|
||||
statusCode: 200,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
prov := linkedca.MustProvisionerFromContext(r.Context())
|
||||
assert.NotNil(t, prov)
|
||||
assert.Equals(t, "provID", prov.GetId())
|
||||
assert.Equals(t, "provName", prov.GetName())
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
loadProvisionerByName(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
err := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||
|
||||
assert.Equals(t, tc.err.Type, err.Type)
|
||||
assert.Equals(t, tc.err.Message, err.Message)
|
||||
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_checkAction(t *testing.T) {
|
||||
type test struct {
|
||||
adminDB admin.DB
|
||||
next http.HandlerFunc
|
||||
supportedInStandalone bool
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"standalone-nosql-supported": func(t *testing.T) test {
|
||||
return test{
|
||||
supportedInStandalone: true,
|
||||
adminDB: &nosql.DB{},
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"standalone-nosql-not-supported": func(t *testing.T) test {
|
||||
err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode")
|
||||
err.Message = "operation not supported in standalone mode"
|
||||
return test{
|
||||
supportedInStandalone: false,
|
||||
adminDB: &nosql.DB{},
|
||||
statusCode: 501,
|
||||
err: err,
|
||||
}
|
||||
},
|
||||
"standalone-no-nosql-not-supported": func(t *testing.T) test {
|
||||
err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported")
|
||||
err.Message = "operation not supported"
|
||||
return test{
|
||||
supportedInStandalone: false,
|
||||
adminDB: &admin.MockDB{},
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
statusCode: 200,
|
||||
err: err,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := admin.NewContext(context.Background(), tc.adminDB)
|
||||
req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
checkAction(tc.next, tc.supportedInStandalone)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
err := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||
|
||||
assert.Equals(t, tc.err.Type, err.Type)
|
||||
assert.Equals(t, tc.err.Message, err.Message)
|
||||
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_loadExternalAccountKey(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
acmeDB acme.DB
|
||||
next http.HandlerFunc
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/keyID-not-found-error": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("keyID", "key")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
|
||||
err.Message = "ACME External Account Key not found"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "key", keyID)
|
||||
return nil, acme.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
},
|
||||
"fail/keyID-error": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("keyID", "key")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key")
|
||||
err.Message = "error retrieving ACME External Account Key: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "key", keyID)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: err,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"fail/reference-not-found-error": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("reference", "ref")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
|
||||
err.Message = "ACME External Account Key not found"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "ref", reference)
|
||||
return nil, acme.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
},
|
||||
"fail/reference-error": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("reference", "ref")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key")
|
||||
err.Message = "error retrieving ACME External Account Key: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "ref", reference)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: err,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"fail/no-key": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("reference", "ref")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
|
||||
err.Message = "ACME External Account Key not found"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "ref", reference)
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: err,
|
||||
statusCode: 404,
|
||||
}
|
||||
},
|
||||
"ok/keyID": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("keyID", "eakID")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
|
||||
err.Message = "ACME External Account Key not found"
|
||||
createdAt := time.Now().Add(-1 * time.Hour)
|
||||
var boundAt time.Time
|
||||
eak := &acme.ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
ProvisionerID: "provID",
|
||||
CreatedAt: createdAt,
|
||||
BoundAt: boundAt,
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "eakID", keyID)
|
||||
return eak, nil
|
||||
},
|
||||
},
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context())
|
||||
assert.NotNil(t, eak)
|
||||
exp := &linkedca.EABKey{
|
||||
Id: "eakID",
|
||||
Provisioner: "provID",
|
||||
CreatedAt: timestamppb.New(createdAt),
|
||||
BoundAt: timestamppb.New(boundAt),
|
||||
}
|
||||
assert.Equals(t, exp, contextEAK)
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
err: nil,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/reference": func(t *testing.T) test {
|
||||
prov := &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
}
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("reference", "ref")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
|
||||
err.Message = "ACME External Account Key not found"
|
||||
createdAt := time.Now().Add(-1 * time.Hour)
|
||||
var boundAt time.Time
|
||||
eak := &acme.ExternalAccountKey{
|
||||
ID: "eakID",
|
||||
ProvisionerID: "provID",
|
||||
Reference: "ref",
|
||||
CreatedAt: createdAt,
|
||||
BoundAt: boundAt,
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
acmeDB: &acme.MockDB{
|
||||
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
|
||||
assert.Equals(t, "provID", provisionerID)
|
||||
assert.Equals(t, "ref", reference)
|
||||
return eak, nil
|
||||
},
|
||||
},
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context())
|
||||
assert.NotNil(t, eak)
|
||||
exp := &linkedca.EABKey{
|
||||
Id: "eakID",
|
||||
Provisioner: "provID",
|
||||
Reference: "ref",
|
||||
CreatedAt: timestamppb.New(createdAt),
|
||||
BoundAt: timestamppb.New(boundAt),
|
||||
}
|
||||
assert.Equals(t, exp, contextEAK)
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
err: nil,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB)
|
||||
req := httptest.NewRequest("GET", "/foo", nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
loadExternalAccountKey(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
|
496
authority/admin/api/policy.go
Normal file
496
authority/admin/api/policy.go
Normal file
|
@ -0,0 +1,496 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
)
|
||||
|
||||
// PolicyAdminResponder is the interface responsible for writing ACME admin
|
||||
// responses.
|
||||
type PolicyAdminResponder interface {
|
||||
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
GetProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// policyAdminResponder implements PolicyAdminResponder.
|
||||
type policyAdminResponder struct{}
|
||||
|
||||
// NewACMEAdminResponder returns a new PolicyAdminResponder.
|
||||
func NewPolicyAdminResponder() PolicyAdminResponder {
|
||||
return &policyAdminResponder{}
|
||||
}
|
||||
|
||||
// GetAuthorityPolicy handles the GET /admin/authority/policy request
|
||||
func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(r.Context())
|
||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, authorityPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
// CreateAuthorityPolicy handles the POST /admin/authority/policy request
|
||||
func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
|
||||
render.Error(w, adminErr)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
adm := linkedca.MustAdminFromContext(ctx)
|
||||
|
||||
var createdPolicy *linkedca.Policy
|
||||
if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, createdPolicy, http.StatusCreated)
|
||||
}
|
||||
|
||||
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
|
||||
func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
adm := linkedca.MustAdminFromContext(ctx)
|
||||
|
||||
var updatedPolicy *linkedca.Policy
|
||||
if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, updatedPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
|
||||
func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.RemoveAuthorityPolicy(ctx); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
|
||||
func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, provisionerPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
|
||||
func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
|
||||
render.Error(w, adminErr)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
prov.Policy = newPolicy
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error creating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
|
||||
}
|
||||
|
||||
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
|
||||
func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
prov.Policy = newPolicy
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
|
||||
func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
if prov.Policy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
// remove the policy
|
||||
prov.Policy = nil
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, eakPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
|
||||
render.Error(w, adminErr)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
eak.Policy = newPolicy
|
||||
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
eak.Policy = newPolicy
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
// remove the policy
|
||||
eak.Policy = nil
|
||||
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
// blockLinkedCA blocks all API operations on linked deployments
|
||||
func blockLinkedCA(ctx context.Context) error {
|
||||
// temporary blocking linked deployments
|
||||
adminDB := admin.MustFromContext(ctx)
|
||||
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() {
|
||||
return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isBadRequest checks if an error should result in a bad request error
|
||||
// returned to the client.
|
||||
func isBadRequest(err error) bool {
|
||||
var pe *authority.PolicyError
|
||||
isPolicyError := errors.As(err, &pe)
|
||||
return isPolicyError && (pe.Typ == authority.AdminLockOut || pe.Typ == authority.EvaluationFailure || pe.Typ == authority.ConfigurationFailure)
|
||||
}
|
||||
|
||||
func validatePolicy(p *linkedca.Policy) error {
|
||||
|
||||
// convert the policy; return early if nil
|
||||
options := policy.LinkedToCertificates(p)
|
||||
if options == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Initialize a temporary x509 allow/deny policy engine
|
||||
if _, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize a temporary SSH allow/deny policy engine for host certificates
|
||||
if _, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize a temporary SSH allow/deny policy engine for user certificates
|
||||
if _, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
2775
authority/admin/api/policy_test.go
Normal file
2775
authority/admin/api/policy_test.go
Normal file
File diff suppressed because it is too large
Load diff
|
@ -23,29 +23,31 @@ type GetProvisionersResponse struct {
|
|||
}
|
||||
|
||||
// GetProvisioner returns the requested provisioner, or an error.
|
||||
func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
func GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
|
||||
ctx := r.Context()
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
auth := mustAuthority(ctx)
|
||||
db := admin.MustFromContext(ctx)
|
||||
|
||||
if len(id) > 0 {
|
||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||
if p, err = auth.LoadProvisionerByID(id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
||||
prov, err := db.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -54,7 +56,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// GetProvisioners returns the given segment of provisioners associated with the authority.
|
||||
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||
func GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
|
@ -62,7 +64,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
p, next, err := h.auth.GetProvisioners(cursor, limit)
|
||||
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
return
|
||||
|
@ -74,7 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// CreateProvisioner creates a new prov.
|
||||
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var prov = new(linkedca.Provisioner)
|
||||
if err := read.ProtoJSON(r.Body, prov); err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -87,7 +89,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
|
||||
if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
||||
return
|
||||
}
|
||||
|
@ -95,27 +97,29 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// DeleteProvisioner deletes a provisioner.
|
||||
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
auth := mustAuthority(r.Context())
|
||||
|
||||
if len(id) > 0 {
|
||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||
if p, err = auth.LoadProvisionerByID(id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||
if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
||||
return
|
||||
}
|
||||
|
@ -124,23 +128,27 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// UpdateProvisioner updates an existing prov.
|
||||
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var nu = new(linkedca.Provisioner)
|
||||
if err := read.ProtoJSON(r.Body, nu); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
name := chi.URLParam(r, "name")
|
||||
_old, err := h.auth.LoadProvisionerByName(name)
|
||||
auth := mustAuthority(ctx)
|
||||
db := admin.MustFromContext(ctx)
|
||||
|
||||
p, err := auth.LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
||||
return
|
||||
}
|
||||
|
||||
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
|
||||
old, err := db.GetProvisioner(r.Context(), p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID()))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -171,7 +179,7 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||
if err := auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -8,18 +8,21 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
func TestHandler_GetProvisioner(t *testing.T) {
|
||||
|
@ -47,6 +50,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||
ctx: ctx,
|
||||
req: req,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
|
@ -71,6 +75,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||
ctx: ctx,
|
||||
req: req,
|
||||
auth: auth,
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
Type: admin.ErrorServerInternalType.String(),
|
||||
|
@ -153,13 +158,11 @@ func TestHandler_GetProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
}
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := tc.req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetProvisioner(w, req)
|
||||
GetProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -277,12 +280,10 @@ func TestHandler_GetProvisioners(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetProvisioners(w, req)
|
||||
GetProvisioners(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -335,12 +336,12 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
|||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{ // TODO(hs): this probably needs a better error
|
||||
Type: "",
|
||||
Status: 500,
|
||||
Detail: "",
|
||||
Message: "",
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: "badRequest",
|
||||
Status: 400,
|
||||
Detail: "bad request",
|
||||
Message: "proto: syntax error (line 1:2): invalid value !",
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -402,13 +403,11 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateProvisioner(w, req)
|
||||
CreateProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -423,9 +422,15 @@ func TestHandler_CreateProvisioner(t *testing.T) {
|
|||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||
|
||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
if strings.HasPrefix(tc.err.Message, "proto:") {
|
||||
assert.True(t, strings.Contains(adminErr.Message, "syntax error"))
|
||||
} else {
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -562,12 +567,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
req := tc.req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.DeleteProvisioner(w, req)
|
||||
DeleteProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -616,12 +619,13 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
return test{
|
||||
ctx: context.Background(),
|
||||
body: body,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{ // TODO(hs): this probably needs a better error
|
||||
Type: "",
|
||||
Status: 500,
|
||||
Detail: "",
|
||||
Message: "",
|
||||
adminDB: &admin.MockDB{},
|
||||
statusCode: 400,
|
||||
err: &admin.Error{
|
||||
Type: "badRequest",
|
||||
Status: 400,
|
||||
Detail: "bad request",
|
||||
Message: "proto: syntax error (line 1:2): invalid value !",
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -645,6 +649,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
return test{
|
||||
ctx: ctx,
|
||||
body: body,
|
||||
adminDB: &admin.MockDB{},
|
||||
auth: auth,
|
||||
statusCode: 500,
|
||||
err: &admin.Error{
|
||||
|
@ -1052,14 +1057,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
auth: tc.auth,
|
||||
adminDB: tc.adminDB,
|
||||
}
|
||||
mockMustAuthority(t, tc.auth)
|
||||
ctx := admin.NewContext(tc.ctx, tc.adminDB)
|
||||
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
|
||||
req = req.WithContext(tc.ctx)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateProvisioner(w, req)
|
||||
UpdateProvisioner(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||
|
@ -1074,9 +1077,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
|
|||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
|
||||
|
||||
assert.Equals(t, tc.err.Type, adminErr.Type)
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
assert.Equals(t, tc.err.Detail, adminErr.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
|
||||
if strings.HasPrefix(tc.err.Message, "proto:") {
|
||||
assert.True(t, strings.Contains(adminErr.Message, "syntax error"))
|
||||
} else {
|
||||
assert.Equals(t, tc.err.Message, adminErr.Message)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -69,6 +69,34 @@ type DB interface {
|
|||
GetAdmins(ctx context.Context) ([]*linkedca.Admin, error)
|
||||
UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error
|
||||
DeleteAdmin(ctx context.Context, id string) error
|
||||
|
||||
CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
|
||||
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
|
||||
UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
|
||||
DeleteAuthorityPolicy(ctx context.Context) error
|
||||
}
|
||||
|
||||
type dbKey struct{}
|
||||
|
||||
// NewContext adds the given admin database to the context.
|
||||
func NewContext(ctx context.Context, db DB) context.Context {
|
||||
return context.WithValue(ctx, dbKey{}, db)
|
||||
}
|
||||
|
||||
// FromContext returns the current admin database from the given context.
|
||||
func FromContext(ctx context.Context) (db DB, ok bool) {
|
||||
db, ok = ctx.Value(dbKey{}).(DB)
|
||||
return
|
||||
}
|
||||
|
||||
// MustFromContext returns the current admin database from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustFromContext(ctx context.Context) DB {
|
||||
if db, ok := FromContext(ctx); !ok {
|
||||
panic("admin database is not in the context")
|
||||
} else {
|
||||
return db
|
||||
}
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
|
@ -86,6 +114,11 @@ type MockDB struct {
|
|||
MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error
|
||||
MockDeleteAdmin func(ctx context.Context, id string) error
|
||||
|
||||
MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
|
||||
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
|
||||
MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
|
||||
MockDeleteAuthorityPolicy func(ctx context.Context) error
|
||||
|
||||
MockError error
|
||||
MockRet1 interface{}
|
||||
}
|
||||
|
@ -179,3 +212,35 @@ func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error {
|
|||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateAuthorityPolicy mock
|
||||
func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
if m.MockCreateAuthorityPolicy != nil {
|
||||
return m.MockCreateAuthorityPolicy(ctx, policy)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetAuthorityPolicy mock
|
||||
func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
if m.MockGetAuthorityPolicy != nil {
|
||||
return m.MockGetAuthorityPolicy(ctx)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Policy), m.MockError
|
||||
}
|
||||
|
||||
// UpdateAuthorityPolicy mock
|
||||
func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
if m.MockUpdateAuthorityPolicy != nil {
|
||||
return m.MockUpdateAuthorityPolicy(ctx, policy)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// DeleteAuthorityPolicy mock
|
||||
func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error {
|
||||
if m.MockDeleteAuthorityPolicy != nil {
|
||||
return m.MockDeleteAuthorityPolicy(ctx)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
|
|
@ -11,8 +11,9 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
adminsTable = []byte("admins")
|
||||
provisionersTable = []byte("provisioners")
|
||||
adminsTable = []byte("admins")
|
||||
provisionersTable = []byte("provisioners")
|
||||
authorityPoliciesTable = []byte("authority_policies")
|
||||
)
|
||||
|
||||
// DB is a struct that implements the AdminDB interface.
|
||||
|
@ -23,7 +24,7 @@ type DB struct {
|
|||
|
||||
// New configures and returns a new Authority DB backend implemented using a nosql DB.
|
||||
func New(db nosqlDB.DB, authorityID string) (*DB, error) {
|
||||
tables := [][]byte{adminsTable, provisionersTable}
|
||||
tables := [][]byte{adminsTable, provisionersTable, authorityPoliciesTable}
|
||||
for _, b := range tables {
|
||||
if err := db.CreateTable(b); err != nil {
|
||||
return nil, errors.Wrapf(err, "error creating table %s",
|
||||
|
|
339
authority/admin/db/nosql/policy.go
Normal file
339
authority/admin/db/nosql/policy.go
Normal file
|
@ -0,0 +1,339 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type dbX509Policy struct {
|
||||
Allow *dbX509Names `json:"allow,omitempty"`
|
||||
Deny *dbX509Names `json:"deny,omitempty"`
|
||||
AllowWildcardNames bool `json:"allow_wildcard_names,omitempty"`
|
||||
}
|
||||
|
||||
type dbX509Names struct {
|
||||
CommonNames []string `json:"cn,omitempty"`
|
||||
DNSDomains []string `json:"dns,omitempty"`
|
||||
IPRanges []string `json:"ip,omitempty"`
|
||||
EmailAddresses []string `json:"email,omitempty"`
|
||||
URIDomains []string `json:"uri,omitempty"`
|
||||
}
|
||||
|
||||
type dbSSHPolicy struct {
|
||||
// User contains SSH user certificate options.
|
||||
User *dbSSHUserPolicy `json:"user,omitempty"`
|
||||
// Host contains SSH host certificate options.
|
||||
Host *dbSSHHostPolicy `json:"host,omitempty"`
|
||||
}
|
||||
|
||||
type dbSSHHostPolicy struct {
|
||||
Allow *dbSSHHostNames `json:"allow,omitempty"`
|
||||
Deny *dbSSHHostNames `json:"deny,omitempty"`
|
||||
}
|
||||
|
||||
type dbSSHHostNames struct {
|
||||
DNSDomains []string `json:"dns,omitempty"`
|
||||
IPRanges []string `json:"ip,omitempty"`
|
||||
Principals []string `json:"principal,omitempty"`
|
||||
}
|
||||
|
||||
type dbSSHUserPolicy struct {
|
||||
Allow *dbSSHUserNames `json:"allow,omitempty"`
|
||||
Deny *dbSSHUserNames `json:"deny,omitempty"`
|
||||
}
|
||||
|
||||
type dbSSHUserNames struct {
|
||||
EmailAddresses []string `json:"email,omitempty"`
|
||||
Principals []string `json:"principal,omitempty"`
|
||||
}
|
||||
|
||||
type dbPolicy struct {
|
||||
X509 *dbX509Policy `json:"x509,omitempty"`
|
||||
SSH *dbSSHPolicy `json:"ssh,omitempty"`
|
||||
}
|
||||
|
||||
type dbAuthorityPolicy struct {
|
||||
ID string `json:"id"`
|
||||
AuthorityID string `json:"authorityID"`
|
||||
Policy *dbPolicy `json:"policy,omitempty"`
|
||||
}
|
||||
|
||||
func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy {
|
||||
if dbap == nil {
|
||||
return nil
|
||||
}
|
||||
return dbToLinked(dbap.Policy)
|
||||
}
|
||||
|
||||
func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) {
|
||||
data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "authority policy not found")
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("error loading authority policy: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalDBAuthorityPolicy(data []byte) (*dbAuthorityPolicy, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var dba = new(dbAuthorityPolicy)
|
||||
if err := json.Unmarshal(data, dba); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling policy bytes into dbAuthorityPolicy: %w", err)
|
||||
}
|
||||
return dba, nil
|
||||
}
|
||||
|
||||
func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*dbAuthorityPolicy, error) {
|
||||
data, err := db.getDBAuthorityPolicyBytes(ctx, authorityID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbap, err := db.unmarshalDBAuthorityPolicy(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dbap == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if dbap.AuthorityID != authorityID {
|
||||
return nil, admin.NewError(admin.ErrorAuthorityMismatchType,
|
||||
"authority policy is not owned by authority %s", authorityID)
|
||||
}
|
||||
return dbap, nil
|
||||
}
|
||||
|
||||
func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
|
||||
dbap := &dbAuthorityPolicy{
|
||||
ID: db.authorityID,
|
||||
AuthorityID: db.authorityID,
|
||||
Policy: linkedToDB(policy),
|
||||
}
|
||||
|
||||
if err := db.save(ctx, dbap.ID, dbap, nil, "authority_policy", authorityPoliciesTable); err != nil {
|
||||
return admin.WrapErrorISE(err, "error creating authority policy")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dbap.convert(), nil
|
||||
}
|
||||
|
||||
func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbap := &dbAuthorityPolicy{
|
||||
ID: db.authorityID,
|
||||
AuthorityID: db.authorityID,
|
||||
Policy: linkedToDB(policy),
|
||||
}
|
||||
|
||||
if err := db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable); err != nil {
|
||||
return admin.WrapErrorISE(err, "error updating authority policy")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error {
|
||||
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.save(ctx, old.ID, nil, old, "authority_policy", authorityPoliciesTable); err != nil {
|
||||
return admin.WrapErrorISE(err, "error deleting authority policy")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbToLinked(p *dbPolicy) *linkedca.Policy {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
r := &linkedca.Policy{}
|
||||
if x509 := p.X509; x509 != nil {
|
||||
r.X509 = &linkedca.X509Policy{}
|
||||
if allow := x509.Allow; allow != nil {
|
||||
r.X509.Allow = &linkedca.X509Names{}
|
||||
r.X509.Allow.Dns = allow.DNSDomains
|
||||
r.X509.Allow.Emails = allow.EmailAddresses
|
||||
r.X509.Allow.Ips = allow.IPRanges
|
||||
r.X509.Allow.Uris = allow.URIDomains
|
||||
r.X509.Allow.CommonNames = allow.CommonNames
|
||||
}
|
||||
if deny := x509.Deny; deny != nil {
|
||||
r.X509.Deny = &linkedca.X509Names{}
|
||||
r.X509.Deny.Dns = deny.DNSDomains
|
||||
r.X509.Deny.Emails = deny.EmailAddresses
|
||||
r.X509.Deny.Ips = deny.IPRanges
|
||||
r.X509.Deny.Uris = deny.URIDomains
|
||||
r.X509.Deny.CommonNames = deny.CommonNames
|
||||
}
|
||||
r.X509.AllowWildcardNames = x509.AllowWildcardNames
|
||||
}
|
||||
if ssh := p.SSH; ssh != nil {
|
||||
r.Ssh = &linkedca.SSHPolicy{}
|
||||
if host := ssh.Host; host != nil {
|
||||
r.Ssh.Host = &linkedca.SSHHostPolicy{}
|
||||
if allow := host.Allow; allow != nil {
|
||||
r.Ssh.Host.Allow = &linkedca.SSHHostNames{}
|
||||
r.Ssh.Host.Allow.Dns = allow.DNSDomains
|
||||
r.Ssh.Host.Allow.Ips = allow.IPRanges
|
||||
r.Ssh.Host.Allow.Principals = allow.Principals
|
||||
}
|
||||
if deny := host.Deny; deny != nil {
|
||||
r.Ssh.Host.Deny = &linkedca.SSHHostNames{}
|
||||
r.Ssh.Host.Deny.Dns = deny.DNSDomains
|
||||
r.Ssh.Host.Deny.Ips = deny.IPRanges
|
||||
r.Ssh.Host.Deny.Principals = deny.Principals
|
||||
}
|
||||
}
|
||||
if user := ssh.User; user != nil {
|
||||
r.Ssh.User = &linkedca.SSHUserPolicy{}
|
||||
if allow := user.Allow; allow != nil {
|
||||
r.Ssh.User.Allow = &linkedca.SSHUserNames{}
|
||||
r.Ssh.User.Allow.Emails = allow.EmailAddresses
|
||||
r.Ssh.User.Allow.Principals = allow.Principals
|
||||
}
|
||||
if deny := user.Deny; deny != nil {
|
||||
r.Ssh.User.Deny = &linkedca.SSHUserNames{}
|
||||
r.Ssh.User.Deny.Emails = deny.EmailAddresses
|
||||
r.Ssh.User.Deny.Principals = deny.Principals
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func linkedToDB(p *linkedca.Policy) *dbPolicy {
|
||||
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// return early if x509 nor SSH is set
|
||||
if p.GetX509() == nil && p.GetSsh() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r := &dbPolicy{}
|
||||
// fill x509 policy configuration
|
||||
if x509 := p.GetX509(); x509 != nil {
|
||||
r.X509 = &dbX509Policy{}
|
||||
if allow := x509.GetAllow(); allow != nil {
|
||||
r.X509.Allow = &dbX509Names{}
|
||||
if allow.Dns != nil {
|
||||
r.X509.Allow.DNSDomains = allow.Dns
|
||||
}
|
||||
if allow.Ips != nil {
|
||||
r.X509.Allow.IPRanges = allow.Ips
|
||||
}
|
||||
if allow.Emails != nil {
|
||||
r.X509.Allow.EmailAddresses = allow.Emails
|
||||
}
|
||||
if allow.Uris != nil {
|
||||
r.X509.Allow.URIDomains = allow.Uris
|
||||
}
|
||||
if allow.CommonNames != nil {
|
||||
r.X509.Allow.CommonNames = allow.CommonNames
|
||||
}
|
||||
}
|
||||
if deny := x509.GetDeny(); deny != nil {
|
||||
r.X509.Deny = &dbX509Names{}
|
||||
if deny.Dns != nil {
|
||||
r.X509.Deny.DNSDomains = deny.Dns
|
||||
}
|
||||
if deny.Ips != nil {
|
||||
r.X509.Deny.IPRanges = deny.Ips
|
||||
}
|
||||
if deny.Emails != nil {
|
||||
r.X509.Deny.EmailAddresses = deny.Emails
|
||||
}
|
||||
if deny.Uris != nil {
|
||||
r.X509.Deny.URIDomains = deny.Uris
|
||||
}
|
||||
if deny.CommonNames != nil {
|
||||
r.X509.Deny.CommonNames = deny.CommonNames
|
||||
}
|
||||
}
|
||||
|
||||
r.X509.AllowWildcardNames = x509.GetAllowWildcardNames()
|
||||
}
|
||||
|
||||
// fill ssh policy configuration
|
||||
if ssh := p.GetSsh(); ssh != nil {
|
||||
r.SSH = &dbSSHPolicy{}
|
||||
if host := ssh.GetHost(); host != nil {
|
||||
r.SSH.Host = &dbSSHHostPolicy{}
|
||||
if allow := host.GetAllow(); allow != nil {
|
||||
r.SSH.Host.Allow = &dbSSHHostNames{}
|
||||
if allow.Dns != nil {
|
||||
r.SSH.Host.Allow.DNSDomains = allow.Dns
|
||||
}
|
||||
if allow.Ips != nil {
|
||||
r.SSH.Host.Allow.IPRanges = allow.Ips
|
||||
}
|
||||
if allow.Principals != nil {
|
||||
r.SSH.Host.Allow.Principals = allow.Principals
|
||||
}
|
||||
}
|
||||
if deny := host.GetDeny(); deny != nil {
|
||||
r.SSH.Host.Deny = &dbSSHHostNames{}
|
||||
if deny.Dns != nil {
|
||||
r.SSH.Host.Deny.DNSDomains = deny.Dns
|
||||
}
|
||||
if deny.Ips != nil {
|
||||
r.SSH.Host.Deny.IPRanges = deny.Ips
|
||||
}
|
||||
if deny.Principals != nil {
|
||||
r.SSH.Host.Deny.Principals = deny.Principals
|
||||
}
|
||||
}
|
||||
}
|
||||
if user := ssh.GetUser(); user != nil {
|
||||
r.SSH.User = &dbSSHUserPolicy{}
|
||||
if allow := user.GetAllow(); allow != nil {
|
||||
r.SSH.User.Allow = &dbSSHUserNames{}
|
||||
if allow.Emails != nil {
|
||||
r.SSH.User.Allow.EmailAddresses = allow.Emails
|
||||
}
|
||||
if allow.Principals != nil {
|
||||
r.SSH.User.Allow.Principals = allow.Principals
|
||||
}
|
||||
}
|
||||
if deny := user.GetDeny(); deny != nil {
|
||||
r.SSH.User.Deny = &dbSSHUserNames{}
|
||||
if deny.Emails != nil {
|
||||
r.SSH.User.Deny.EmailAddresses = deny.Emails
|
||||
}
|
||||
if deny.Principals != nil {
|
||||
r.SSH.User.Deny.Principals = deny.Principals
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
1206
authority/admin/db/nosql/policy_test.go
Normal file
1206
authority/admin/db/nosql/policy_test.go
Normal file
File diff suppressed because it is too large
Load diff
|
@ -24,10 +24,12 @@ const (
|
|||
ErrorBadRequestType
|
||||
// ErrorNotImplementedType not implemented.
|
||||
ErrorNotImplementedType
|
||||
// ErrorUnauthorizedType internal server error.
|
||||
// ErrorUnauthorizedType unauthorized.
|
||||
ErrorUnauthorizedType
|
||||
// ErrorServerInternalType internal server error.
|
||||
ErrorServerInternalType
|
||||
// ErrorConflictType conflict.
|
||||
ErrorConflictType
|
||||
)
|
||||
|
||||
// String returns the string representation of the admin problem type,
|
||||
|
@ -48,6 +50,8 @@ func (ap ProblemType) String() string {
|
|||
return "unauthorized"
|
||||
case ErrorServerInternalType:
|
||||
return "internalServerError"
|
||||
case ErrorConflictType:
|
||||
return "conflict"
|
||||
default:
|
||||
return fmt.Sprintf("unsupported error type '%d'", int(ap))
|
||||
}
|
||||
|
@ -64,7 +68,7 @@ var (
|
|||
errorServerInternalMetadata = errorMetadata{
|
||||
typ: ErrorServerInternalType.String(),
|
||||
details: "the server experienced an internal error",
|
||||
status: 500,
|
||||
status: http.StatusInternalServerError,
|
||||
}
|
||||
errorMap = map[ProblemType]errorMetadata{
|
||||
ErrorNotFoundType: {
|
||||
|
@ -98,6 +102,11 @@ var (
|
|||
status: http.StatusUnauthorized,
|
||||
},
|
||||
ErrorServerInternalType: errorServerInternalMetadata,
|
||||
ErrorConflictType: {
|
||||
typ: ErrorConflictType.String(),
|
||||
details: "conflict",
|
||||
status: http.StatusConflict,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -59,12 +59,12 @@ func newSubProv(subject, prov string) subProv {
|
|||
return subProv{subject, prov}
|
||||
}
|
||||
|
||||
// LoadBySubProv a admin by the subject and provisioner name.
|
||||
// LoadBySubProv loads an admin by subject and provisioner name.
|
||||
func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) {
|
||||
return loadAdmin(c.bySubProv, newSubProv(sub, provName))
|
||||
}
|
||||
|
||||
// LoadByProvisioner a admin by the subject and provisioner name.
|
||||
// LoadByProvisioner loads admins by provisioner name.
|
||||
func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) {
|
||||
val, ok := c.byProv.Load(provName)
|
||||
if !ok {
|
||||
|
@ -78,7 +78,7 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool
|
|||
}
|
||||
|
||||
// Store adds an admin to the collection and enforces the uniqueness of
|
||||
// admin IDs and amdin subject <-> provisioner name combos.
|
||||
// admin IDs and admin subject <-> provisioner name combos.
|
||||
func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error {
|
||||
// Input validation.
|
||||
if adm.ProvisionerId != prov.GetID() {
|
||||
|
|
|
@ -49,7 +49,7 @@ func (a *Authority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov pr
|
|||
return admin.WrapErrorISE(err, "error creating admin")
|
||||
}
|
||||
if err := a.admins.Store(adm, prov); err != nil {
|
||||
if err := a.reloadAdminResources(ctx); err != nil {
|
||||
if err := a.ReloadAdminResources(ctx); err != nil {
|
||||
return admin.WrapErrorISE(err, "error reloading admin resources on failed admin store")
|
||||
}
|
||||
return admin.WrapErrorISE(err, "error storing admin in authority cache")
|
||||
|
@ -66,7 +66,7 @@ func (a *Authority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Adm
|
|||
return nil, admin.WrapErrorISE(err, "error updating cached admin %s", id)
|
||||
}
|
||||
if err := a.adminDB.UpdateAdmin(ctx, adm); err != nil {
|
||||
if err := a.reloadAdminResources(ctx); err != nil {
|
||||
if err := a.ReloadAdminResources(ctx); err != nil {
|
||||
return nil, admin.WrapErrorISE(err, "error reloading admin resources on failed admin update")
|
||||
}
|
||||
return nil, admin.WrapErrorISE(err, "error updating admin %s", id)
|
||||
|
@ -88,7 +88,7 @@ func (a *Authority) removeAdmin(ctx context.Context, id string) error {
|
|||
return admin.WrapErrorISE(err, "error removing admin %s from authority cache", id)
|
||||
}
|
||||
if err := a.adminDB.DeleteAdmin(ctx, id); err != nil {
|
||||
if err := a.reloadAdminResources(ctx); err != nil {
|
||||
if err := a.ReloadAdminResources(ctx); err != nil {
|
||||
return admin.WrapErrorISE(err, "error reloading admin resources on failed admin remove")
|
||||
}
|
||||
return admin.WrapErrorISE(err, "error deleting admin %s", id)
|
||||
|
|
|
@ -13,10 +13,16 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"go.step.sm/crypto/pemutil"
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||
"github.com/smallstep/certificates/authority/administrator"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/cas"
|
||||
casapi "github.com/smallstep/certificates/cas/apiv1"
|
||||
|
@ -27,9 +33,6 @@ import (
|
|||
"github.com/smallstep/certificates/scep"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"github.com/smallstep/nosql"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
"go.step.sm/linkedca"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Authority implements the Certificate Authority internal interface.
|
||||
|
@ -81,9 +84,16 @@ type Authority struct {
|
|||
authorizeRenewFunc provisioner.AuthorizeRenewFunc
|
||||
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
|
||||
|
||||
// Policy engines
|
||||
policyEngine *policy.Engine
|
||||
|
||||
adminMutex sync.RWMutex
|
||||
|
||||
// Do Not initialize the authority
|
||||
skipInit bool
|
||||
}
|
||||
|
||||
// Info contains information about the authority.
|
||||
type Info struct {
|
||||
StartTime time.Time
|
||||
RootX509Certs []*x509.Certificate
|
||||
|
@ -111,9 +121,11 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Initialize authority from options or configuration.
|
||||
if err := a.init(); err != nil {
|
||||
return nil, err
|
||||
if !a.skipInit {
|
||||
// Initialize authority from options or configuration.
|
||||
if err := a.init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return a, nil
|
||||
|
@ -149,16 +161,41 @@ func NewEmbedded(opts ...Option) (*Authority, error) {
|
|||
// Initialize config required fields.
|
||||
a.config.Init()
|
||||
|
||||
// Initialize authority from options or configuration.
|
||||
if err := a.init(); err != nil {
|
||||
return nil, err
|
||||
if !a.skipInit {
|
||||
// Initialize authority from options or configuration.
|
||||
if err := a.init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// reloadAdminResources reloads admins and provisioners from the DB.
|
||||
func (a *Authority) reloadAdminResources(ctx context.Context) error {
|
||||
type authorityKey struct{}
|
||||
|
||||
// NewContext adds the given authority to the context.
|
||||
func NewContext(ctx context.Context, a *Authority) context.Context {
|
||||
return context.WithValue(ctx, authorityKey{}, a)
|
||||
}
|
||||
|
||||
// FromContext returns the current authority from the given context.
|
||||
func FromContext(ctx context.Context) (a *Authority, ok bool) {
|
||||
a, ok = ctx.Value(authorityKey{}).(*Authority)
|
||||
return
|
||||
}
|
||||
|
||||
// MustFromContext returns the current authority from the given context. It will
|
||||
// panic if the authority is not in the context.
|
||||
func MustFromContext(ctx context.Context) *Authority {
|
||||
if a, ok := FromContext(ctx); !ok {
|
||||
panic("authority is not in the context")
|
||||
} else {
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadAdminResources reloads admins and provisioners from the DB.
|
||||
func (a *Authority) ReloadAdminResources(ctx context.Context) error {
|
||||
var (
|
||||
provList provisioner.List
|
||||
adminList []*linkedca.Admin
|
||||
|
@ -213,6 +250,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
|
|||
a.provisioners = provClxn
|
||||
a.config.AuthorityConfig.Admins = adminList
|
||||
a.admins = adminClxn
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -224,6 +262,7 @@ func (a *Authority) init() error {
|
|||
}
|
||||
|
||||
var err error
|
||||
ctx := NewContext(context.Background(), a)
|
||||
|
||||
// Set password if they are not set.
|
||||
var configPassword []byte
|
||||
|
@ -259,12 +298,27 @@ func (a *Authority) init() error {
|
|||
if a.config.KMS != nil {
|
||||
options = *a.config.KMS
|
||||
}
|
||||
a.keyManager, err = kms.New(context.Background(), options)
|
||||
a.keyManager, err = kms.New(ctx, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize linkedca client if necessary. On a linked RA, the issuer
|
||||
// configuration might come from majordomo.
|
||||
var linkedcaClient *linkedCaClient
|
||||
if a.config.AuthorityConfig.EnableAdmin && a.linkedCAToken != "" && a.adminDB == nil {
|
||||
linkedcaClient, err = newLinkedCAClient(a.linkedCAToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If authorityId is configured make sure it matches the one in the token
|
||||
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, linkedcaClient.authorityID) {
|
||||
return errors.New("error initializing linkedca: token authority and configured authority do not match")
|
||||
}
|
||||
linkedcaClient.Run()
|
||||
}
|
||||
|
||||
// Initialize the X.509 CA Service if it has not been set in the options.
|
||||
if a.x509CAService == nil {
|
||||
var options casapi.Options
|
||||
|
@ -272,6 +326,22 @@ func (a *Authority) init() error {
|
|||
options = *a.config.AuthorityConfig.Options
|
||||
}
|
||||
|
||||
// Configure linked RA
|
||||
if linkedcaClient != nil && options.CertificateAuthority == "" {
|
||||
conf, err := linkedcaClient.GetConfiguration(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if conf.RaConfig != nil {
|
||||
options.CertificateAuthority = conf.RaConfig.CaUrl
|
||||
options.CertificateAuthorityFingerprint = conf.RaConfig.Fingerprint
|
||||
options.CertificateIssuer = &casapi.CertificateIssuer{
|
||||
Type: conf.RaConfig.Provisioner.Type.String(),
|
||||
Provisioner: conf.RaConfig.Provisioner.Name,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the issuer password if passed in the flags.
|
||||
if options.CertificateIssuer != nil && a.issuerPassword != nil {
|
||||
options.CertificateIssuer.Password = string(a.issuerPassword)
|
||||
|
@ -292,7 +362,7 @@ func (a *Authority) init() error {
|
|||
}
|
||||
}
|
||||
|
||||
a.x509CAService, err = cas.New(context.Background(), options)
|
||||
a.x509CAService, err = cas.New(ctx, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -479,7 +549,7 @@ func (a *Authority) init() error {
|
|||
}
|
||||
}
|
||||
|
||||
a.scepService, err = scep.NewService(context.Background(), options)
|
||||
a.scepService, err = scep.NewService(ctx, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -491,40 +561,29 @@ func (a *Authority) init() error {
|
|||
// Initialize step-ca Admin Database if it's not already initialized using
|
||||
// WithAdminDB.
|
||||
if a.adminDB == nil {
|
||||
if a.linkedCAToken == "" {
|
||||
// Check if AuthConfig already exists
|
||||
if linkedcaClient != nil {
|
||||
a.adminDB = linkedcaClient
|
||||
} else {
|
||||
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Use the linkedca client as the admindb.
|
||||
client, err := newLinkedCAClient(a.linkedCAToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If authorityId is configured make sure it matches the one in the token
|
||||
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) {
|
||||
return errors.New("error initializing linkedca: token authority and configured authority do not match")
|
||||
}
|
||||
client.Run()
|
||||
a.adminDB = client
|
||||
}
|
||||
}
|
||||
|
||||
provs, err := a.adminDB.GetProvisioners(context.Background())
|
||||
provs, err := a.adminDB.GetProvisioners(ctx)
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
||||
}
|
||||
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
|
||||
// Create First Provisioner
|
||||
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password))
|
||||
prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password))
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error creating first provisioner")
|
||||
}
|
||||
|
||||
// Create first admin
|
||||
if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{
|
||||
if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{
|
||||
ProvisionerId: prov.Id,
|
||||
Subject: "step",
|
||||
Type: linkedca.Admin_SUPER_ADMIN,
|
||||
|
@ -535,7 +594,12 @@ func (a *Authority) init() error {
|
|||
}
|
||||
|
||||
// Load Provisioners and Admins
|
||||
if err := a.reloadAdminResources(context.Background()); err != nil {
|
||||
if err := a.ReloadAdminResources(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load x509 and SSH Policy Engines
|
||||
if err := a.reloadPolicyEngines(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -570,6 +634,15 @@ func (a *Authority) init() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetID returns the define authority id or a zero uuid.
|
||||
func (a *Authority) GetID() string {
|
||||
const zeroUUID = "00000000-0000-0000-0000-000000000000"
|
||||
if id := a.config.AuthorityConfig.AuthorityID; id != "" {
|
||||
return id
|
||||
}
|
||||
return zeroUUID
|
||||
}
|
||||
|
||||
// GetDatabase returns the authority database. If the configuration does not
|
||||
// define a database, GetDatabase will return a db.SimpleDB instance.
|
||||
func (a *Authority) GetDatabase() db.AuthDB {
|
||||
|
@ -581,6 +654,12 @@ func (a *Authority) GetAdminDatabase() admin.DB {
|
|||
return a.adminDB
|
||||
}
|
||||
|
||||
// GetConfig returns the config.
|
||||
func (a *Authority) GetConfig() *config.Config {
|
||||
return a.config
|
||||
}
|
||||
|
||||
// GetInfo returns information about the authority.
|
||||
func (a *Authority) GetInfo() Info {
|
||||
ai := Info{
|
||||
StartTime: a.startTime,
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"go.step.sm/crypto/jose"
|
||||
|
@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_GetID(t *testing.T) {
|
||||
type fields struct {
|
||||
authorityID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
}{
|
||||
{"ok", fields{""}, "00000000-0000-0000-0000-000000000000"},
|
||||
{"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &Authority{
|
||||
config: &config.Config{
|
||||
AuthorityConfig: &config.AuthConfig{
|
||||
AuthorityID: tt.fields.authorityID,
|
||||
},
|
||||
},
|
||||
}
|
||||
if got := a.GetID(); got != tt.want {
|
||||
t.Errorf("Authority.GetID() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
@ -41,14 +42,12 @@ func SkipTokenReuseFromContext(ctx context.Context) bool {
|
|||
return m
|
||||
}
|
||||
|
||||
// authorizeToken parses the token and returns the provisioner used to generate
|
||||
// the token. This method enforces the One-Time use policy (tokens can only be
|
||||
// used once).
|
||||
func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
|
||||
// Validate payload
|
||||
// getProvisionerFromToken extracts a provisioner from the given token without
|
||||
// doing any token validation.
|
||||
func (a *Authority) getProvisionerFromToken(token string) (provisioner.Interface, *Claims, error) {
|
||||
tok, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token")
|
||||
return nil, nil, fmt.Errorf("error parsing token: %w", err)
|
||||
}
|
||||
|
||||
// Get claims w/out verification. We need to look up the provisioner
|
||||
|
@ -56,7 +55,25 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
|||
// before we can look up the provisioner.
|
||||
var claims Claims
|
||||
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
|
||||
return nil, nil, fmt.Errorf("error unmarshaling token: %w", err)
|
||||
}
|
||||
|
||||
// This method will also validate the audiences for JWK provisioners.
|
||||
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
|
||||
}
|
||||
|
||||
return p, &claims, nil
|
||||
}
|
||||
|
||||
// authorizeToken parses the token and returns the provisioner used to generate
|
||||
// the token. This method enforces the One-Time use policy (tokens can only be
|
||||
// used once).
|
||||
func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
|
||||
p, claims, err := a.getProvisionerFromToken(token)
|
||||
if err != nil {
|
||||
return nil, errs.UnauthorizedErr(err)
|
||||
}
|
||||
|
||||
// TODO: use new persistence layer abstraction.
|
||||
|
@ -64,17 +81,10 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
|||
// This check is meant as a stopgap solution to the current lack of a persistence layer.
|
||||
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
|
||||
if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) {
|
||||
return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority")
|
||||
return nil, errs.Unauthorized("token issued before the bootstrap of certificate authority")
|
||||
}
|
||||
}
|
||||
|
||||
// This method will also validate the audiences for JWK provisioners.
|
||||
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
|
||||
if !ok {
|
||||
return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+
|
||||
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
|
||||
}
|
||||
|
||||
// Store the token to protect against reuse unless it's skipped.
|
||||
// If we cannot get a token id from the provisioner, just hash the token.
|
||||
if !SkipTokenReuseFromContext(ctx) {
|
||||
|
@ -130,22 +140,24 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
|
|||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||
// more than a few minutes.
|
||||
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||
Issuer: prov.GetName(),
|
||||
Time: time.Now().UTC(),
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "x5c.authorizeToken; invalid x5c claims")
|
||||
}
|
||||
|
||||
// validate audience: path matches the current path
|
||||
if r.URL.Path != claims.Audience[0] {
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
"x5c.authorizeToken; x5c token has invalid audience "+
|
||||
"claim (aud); expected %s, but got %s", r.URL.Path, claims.Audience)
|
||||
if !matchesAudience(claims.Audience, a.config.Audience(r.URL.Path)) {
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid audience claim (aud)")
|
||||
}
|
||||
|
||||
// validate issuer: old versions used the provisioner name, new version uses
|
||||
// 'step-admin-client/1.0'
|
||||
if claims.Issuer != "step-admin-client/1.0" && claims.Issuer != prov.GetName() {
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid issuer claim (iss)")
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
"x5c.authorizeToken; x5c token subject cannot be empty")
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token subject cannot be empty")
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -156,7 +168,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
|
|||
adminSANs := append([]string{leaf.Subject.CommonName}, leaf.DNSNames...)
|
||||
adminSANs = append(adminSANs, leaf.EmailAddresses...)
|
||||
for _, san := range adminSANs {
|
||||
if adm, ok = a.LoadAdminBySubProv(san, claims.Issuer); ok {
|
||||
if adm, ok = a.LoadAdminBySubProv(san, prov.GetName()); ok {
|
||||
adminFound = true
|
||||
break
|
||||
}
|
||||
|
@ -186,11 +198,10 @@ func (a *Authority) UseToken(token string, prov provisioner.Interface) error {
|
|||
}
|
||||
ok, err := a.db.UseToken(reuseKey, token)
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err,
|
||||
"authority.authorizeToken: failed when attempting to store token")
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token")
|
||||
}
|
||||
if !ok {
|
||||
return errs.Unauthorized("authority.authorizeToken: token already used")
|
||||
return errs.Unauthorized("token already used")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -249,10 +260,10 @@ func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisio
|
|||
// AuthorizeSign authorizes a signature request by validating and authenticating
|
||||
// a token that must be sent w/ the request.
|
||||
//
|
||||
// NOTE: This method is deprecated and should not be used. We make it available
|
||||
// in the short term os as not to break existing clients.
|
||||
// Deprecated: Use Authorize(context.Context, string) ([]provisioner.SignOption, error).
|
||||
func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) {
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
ctx := NewContext(context.Background(), a)
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
return a.Authorize(ctx, token)
|
||||
}
|
||||
|
||||
|
@ -285,9 +296,16 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
|||
if isRevoked {
|
||||
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
|
||||
}
|
||||
p, ok := a.provisioners.LoadByCertificate(cert)
|
||||
if !ok {
|
||||
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
||||
p, err := a.LoadProvisionerByCertificate(cert)
|
||||
if err != nil {
|
||||
var ok bool
|
||||
// For backward compatibility this method will also succeed if the
|
||||
// certificate does not have a provisioner extension. LoadByCertificate
|
||||
// returns the noop provisioner if this happens, and it allows
|
||||
// certificate renewals.
|
||||
if p, ok = a.provisioners.LoadByCertificate(cert); !ok {
|
||||
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
||||
}
|
||||
}
|
||||
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||
|
@ -386,8 +404,8 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
|
|||
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token"))
|
||||
}
|
||||
|
||||
p, ok := a.provisioners.LoadByCertificate(leaf)
|
||||
if !ok {
|
||||
p, err := a.LoadProvisionerByCertificate(leaf)
|
||||
if err != nil {
|
||||
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
|
||||
}
|
||||
if err := a.UseToken(ott, p); err != nil {
|
||||
|
@ -395,7 +413,6 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
|
|||
}
|
||||
|
||||
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||
Issuer: p.GetName(),
|
||||
Subject: leaf.Subject.CommonName,
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
|
@ -420,6 +437,12 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
|
|||
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
|
||||
}
|
||||
|
||||
// validate issuer: old versions used the provisioner name, new version uses
|
||||
// 'step-ca-client/1.0'
|
||||
if claims.Issuer != "step-ca-client/1.0" && claims.Issuer != p.GetName() {
|
||||
return nil, admin.NewError(admin.ErrorUnauthorizedType, "error validating renew token: invalid issuer claim (iss)")
|
||||
}
|
||||
|
||||
return leaf, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -114,7 +114,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: "foo",
|
||||
err: errors.New("authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -133,7 +133,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: raw,
|
||||
err: errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority"),
|
||||
err: errors.New("token issued before the bootstrap of certificate authority"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -155,7 +155,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: raw,
|
||||
err: errors.New("authority.authorizeToken: provisioner not found or invalid audience (https://example.com/revoke)"),
|
||||
err: errors.New("provisioner not found or invalid audience (https://example.com/revoke)"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -192,7 +192,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: _a,
|
||||
token: raw,
|
||||
err: errors.New("authority.authorizeToken: token already used"),
|
||||
err: errors.New("token already used"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -227,7 +227,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: _a,
|
||||
token: raw,
|
||||
err: errors.New("authority.authorizeToken: token already used"),
|
||||
err: errors.New("token already used"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -275,7 +275,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: _a,
|
||||
token: raw,
|
||||
err: errors.New("authority.authorizeToken: failed when attempting to store token: force"),
|
||||
err: errors.New("failed when attempting to store token: force"),
|
||||
code: http.StatusInternalServerError,
|
||||
}
|
||||
},
|
||||
|
@ -300,7 +300,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: _a,
|
||||
token: raw,
|
||||
err: errors.New("authority.authorizeToken: token already used"),
|
||||
err: errors.New("token already used"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -353,7 +353,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: "foo",
|
||||
err: errors.New("authority.authorizeRevoke: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.authorizeRevoke: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -437,7 +437,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: "foo",
|
||||
err: errors.New("authority.authorizeSign: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.authorizeSign: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -491,7 +491,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
|
|||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Len(t, 7, got)
|
||||
assert.Equals(t, 9, len(got)) // number of provisioner.SignOptions returned
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -524,7 +524,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
auth: a,
|
||||
token: "foo",
|
||||
ctx: context.Background(),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -533,7 +533,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
auth: a,
|
||||
token: "foo",
|
||||
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -559,7 +559,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
auth: a,
|
||||
token: "foo",
|
||||
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod),
|
||||
err: errors.New("authority.Authorize: authority.authorizeRevoke: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.Authorize: authority.authorizeRevoke: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -585,7 +585,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
auth: a,
|
||||
token: "foo",
|
||||
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSSHSign: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSSHSign: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -615,7 +615,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
auth: a,
|
||||
token: "foo",
|
||||
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSSHRenew: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -659,7 +659,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
auth: a,
|
||||
token: "foo",
|
||||
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -685,7 +685,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
auth: a,
|
||||
token: "foo",
|
||||
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.Authorize: authority.authorizeSSHRekey: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -847,6 +847,29 @@ func TestAuthority_authorizeRenew(t *testing.T) {
|
|||
cert: fooCrt,
|
||||
}
|
||||
},
|
||||
"ok/from db": func(t *testing.T) *authorizeTest {
|
||||
a := testAuthority(t)
|
||||
a.db = &db.MockAuthDB{
|
||||
MIsRevoked: func(key string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
||||
p, ok := a.provisioners.LoadByName("step-cli")
|
||||
if !ok {
|
||||
t.Fatal("provisioner step-cli not found")
|
||||
}
|
||||
return &db.CertificateData{
|
||||
Provisioner: &db.ProvisionerData{
|
||||
ID: p.GetID(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return &authorizeTest{
|
||||
auth: a,
|
||||
cert: fooCrt,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for name, genTestCase := range tests {
|
||||
|
@ -965,7 +988,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: "foo",
|
||||
err: errors.New("authority.authorizeSSHSign: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.authorizeSSHSign: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -1011,7 +1034,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
|
|||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Len(t, 7, got)
|
||||
assert.Len(t, 9, got) // number of provisioner.SignOptions returned
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -1059,7 +1082,7 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: "foo",
|
||||
err: errors.New("authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.authorizeSSHRenew: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -1167,7 +1190,7 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: "foo",
|
||||
err: errors.New("authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.authorizeSSHRevoke: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -1259,7 +1282,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: "foo",
|
||||
err: errors.New("authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"),
|
||||
err: errors.New("authority.authorizeSSHRekey: error parsing token"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -1322,7 +1345,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
|
|||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.cert.Serial, cert.Serial)
|
||||
assert.Len(t, 3, signOpts)
|
||||
assert.Len(t, 4, signOpts)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -1381,7 +1404,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
t1, c1 := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
|
@ -1400,7 +1423,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
t2, c2 := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
IssuedAt: jose.NewNumericDate(now),
|
||||
|
@ -1417,12 +1440,31 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
})
|
||||
return nil
|
||||
}))
|
||||
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
|
||||
t3, c3 := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
|
@ -1439,7 +1481,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
|
@ -1477,7 +1519,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
badSubject, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "bad-subject",
|
||||
Issuer: "step-cli",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
|
@ -1496,7 +1538,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/sign"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
|
@ -1515,7 +1557,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/sign"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)),
|
||||
Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
|
@ -1534,7 +1576,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/sign"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
|
@ -1554,7 +1596,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
badAudience, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/sign"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
Issuer: "step-ca-client/1.0",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
|
@ -1584,6 +1626,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
|||
}{
|
||||
{"ok", a1, args{ctx, t1}, c1, false},
|
||||
{"ok expired cert", a1, args{ctx, t2}, c2, false},
|
||||
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
|
||||
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
|
||||
{"fail token reuse", a1, args{ctx, t1}, nil, true},
|
||||
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
|
||||
|
|
|
@ -8,12 +8,15 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
cas "github.com/smallstep/certificates/cas/apiv1"
|
||||
"github.com/smallstep/certificates/db"
|
||||
kms "github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -26,27 +29,27 @@ var (
|
|||
DefaultBackdate = time.Minute
|
||||
// DefaultDisableRenewal disables renewals per provisioner.
|
||||
DefaultDisableRenewal = false
|
||||
// DefaultAllowRenewAfterExpiry allows renewals even if the certificate is
|
||||
// DefaultAllowRenewalAfterExpiry allows renewals even if the certificate is
|
||||
// expired.
|
||||
DefaultAllowRenewAfterExpiry = false
|
||||
DefaultAllowRenewalAfterExpiry = false
|
||||
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally
|
||||
// for all provisioners.
|
||||
DefaultEnableSSHCA = false
|
||||
// GlobalProvisionerClaims default claims for the Authority. Can be overridden
|
||||
// by provisioner specific claims.
|
||||
GlobalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
|
||||
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &DefaultEnableSSHCA,
|
||||
DisableRenewal: &DefaultDisableRenewal,
|
||||
AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry,
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
|
||||
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &DefaultEnableSSHCA,
|
||||
DisableRenewal: &DefaultDisableRenewal,
|
||||
AllowRenewalAfterExpiry: &DefaultAllowRenewalAfterExpiry,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -68,6 +71,7 @@ type Config struct {
|
|||
TLS *TLSOptions `json:"tls,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Templates *templates.Templates `json:"templates,omitempty"`
|
||||
CommonName string `json:"commonName,omitempty"`
|
||||
CRL *CRLConfig `json:"crl,omitempty"`
|
||||
}
|
||||
|
||||
|
@ -95,6 +99,7 @@ type AuthConfig struct {
|
|||
Admins []*linkedca.Admin `json:"-"`
|
||||
Template *ASN1DN `json:"template,omitempty"`
|
||||
Claims *provisioner.Claims `json:"claims,omitempty"`
|
||||
Policy *policy.Options `json:"policy,omitempty"`
|
||||
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
|
||||
Backdate *provisioner.Duration `json:"backdate,omitempty"`
|
||||
EnableAdmin bool `json:"enableAdmin,omitempty"`
|
||||
|
@ -180,6 +185,9 @@ func (c *Config) Init() {
|
|||
if c.AuthorityConfig == nil {
|
||||
c.AuthorityConfig = &AuthConfig{}
|
||||
}
|
||||
if c.CommonName == "" {
|
||||
c.CommonName = "Step Online CA"
|
||||
}
|
||||
c.AuthorityConfig.init()
|
||||
}
|
||||
|
||||
|
@ -311,6 +319,18 @@ func (c *Config) GetAudiences() provisioner.Audiences {
|
|||
return audiences
|
||||
}
|
||||
|
||||
// Audience returns the list of audiences for a given path.
|
||||
func (c *Config) Audience(path string) []string {
|
||||
audiences := make([]string, len(c.DNSNames)+1)
|
||||
for i, name := range c.DNSNames {
|
||||
hostname := toHostname(name)
|
||||
audiences[i] = "https://" + hostname + path
|
||||
}
|
||||
// For backward compatibility
|
||||
audiences[len(c.DNSNames)] = path
|
||||
return audiences
|
||||
}
|
||||
|
||||
func toHostname(name string) string {
|
||||
// ensure an IPv6 address is represented with square brackets when used as hostname
|
||||
if ip := net.ParseIP(name); ip != nil && ip.To4() == nil {
|
||||
|
|
|
@ -2,6 +2,7 @@ package config
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -317,3 +318,38 @@ func Test_toHostname(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Audience(t *testing.T) {
|
||||
type fields struct {
|
||||
DNSNames []string
|
||||
}
|
||||
type args struct {
|
||||
path string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{"ok", fields{[]string{
|
||||
"ca", "ca.example.com", "127.0.0.1", "::1",
|
||||
}}, args{"/path"}, []string{
|
||||
"https://ca/path",
|
||||
"https://ca.example.com/path",
|
||||
"https://127.0.0.1/path",
|
||||
"https://[::1]/path",
|
||||
"/path",
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Config{
|
||||
DNSNames: tt.fields.DNSNames,
|
||||
}
|
||||
if got := c.Audience(tt.args.path); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Config.Audience() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,15 +15,19 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/tlsutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"go.step.sm/linkedca"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
)
|
||||
|
||||
const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$"
|
||||
|
@ -34,6 +38,9 @@ type linkedCaClient struct {
|
|||
authorityID string
|
||||
}
|
||||
|
||||
// interface guard
|
||||
var _ admin.DB = (*linkedCaClient)(nil)
|
||||
|
||||
type linkedCAClaims struct {
|
||||
jose.Claims
|
||||
SANs []string `json:"sans"`
|
||||
|
@ -115,6 +122,13 @@ func newLinkedCAClient(token string) (*linkedCaClient, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// IsLinkedCA is a sentinel function that can be used to
|
||||
// check if a linkedCaClient is the underlying type of an
|
||||
// admin.DB interface.
|
||||
func (c *linkedCaClient) IsLinkedCA() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) Run() {
|
||||
c.renewer.Run()
|
||||
}
|
||||
|
@ -151,13 +165,21 @@ func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linked
|
|||
}
|
||||
|
||||
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
|
||||
resp, err := c.GetConfiguration(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Provisioners, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetConfiguration(ctx context.Context) (*linkedca.ConfigurationResponse, error) {
|
||||
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
||||
AuthorityId: c.authorityID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting provisioners")
|
||||
return nil, errors.Wrap(err, "error getting configuration")
|
||||
}
|
||||
return resp.Provisioners, nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
|
@ -204,11 +226,9 @@ func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Adm
|
|||
}
|
||||
|
||||
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
||||
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
||||
AuthorityId: c.authorityID,
|
||||
})
|
||||
resp, err := c.GetConfiguration(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting admins")
|
||||
return nil, err
|
||||
}
|
||||
return resp.Admins, nil
|
||||
}
|
||||
|
@ -228,12 +248,35 @@ func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
|
|||
return errors.Wrap(err, "error deleting admin")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error {
|
||||
func (c *linkedCaClient) GetCertificateData(serial string) (*db.CertificateData, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := c.client.GetCertificate(ctx, &linkedca.GetCertificateRequest{
|
||||
Serial: serial,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var pd *db.ProvisionerData
|
||||
if p := resp.Provisioner; p != nil {
|
||||
pd = &db.ProvisionerData{
|
||||
ID: p.Id, Name: p.Name, Type: p.Type.String(),
|
||||
}
|
||||
}
|
||||
return &db.CertificateData{
|
||||
Provisioner: pd,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) StoreCertificateChain(p provisioner.Interface, fullchain ...*x509.Certificate) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
|
||||
PemCertificate: serializeCertificateChain(fullchain[0]),
|
||||
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
|
||||
Provisioner: createProvisionerIdentity(p),
|
||||
})
|
||||
return errors.Wrap(err, "error posting certificate")
|
||||
}
|
||||
|
@ -246,18 +289,30 @@ func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullc
|
|||
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
|
||||
PemParentCertificate: serializeCertificateChain(parent),
|
||||
})
|
||||
return errors.Wrap(err, "error posting certificate")
|
||||
return errors.Wrap(err, "error posting renewed certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) StoreSSHCertificate(crt *ssh.Certificate) error {
|
||||
func (c *linkedCaClient) StoreSSHCertificate(p provisioner.Interface, crt *ssh.Certificate) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
|
||||
Certificate: string(ssh.MarshalAuthorizedKey(crt)),
|
||||
Provisioner: createProvisionerIdentity(p),
|
||||
})
|
||||
return errors.Wrap(err, "error posting ssh certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) StoreRenewedSSHCertificate(p provisioner.Interface, parent, crt *ssh.Certificate) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
|
||||
Certificate: string(ssh.MarshalAuthorizedKey(crt)),
|
||||
ParentCertificate: string(ssh.MarshalAuthorizedKey(parent)),
|
||||
Provisioner: createProvisionerIdentity(p),
|
||||
})
|
||||
return errors.Wrap(err, "error posting renewed ssh certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
@ -310,6 +365,33 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
|
|||
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
return errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
return errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error {
|
||||
return errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func createProvisionerIdentity(p provisioner.Interface) *linkedca.ProvisionerIdentity {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &linkedca.ProvisionerIdentity{
|
||||
Id: p.GetID(),
|
||||
Type: linkedca.Provisioner_Type(p.GetType()),
|
||||
Name: p.GetName(),
|
||||
}
|
||||
}
|
||||
|
||||
func serializeCertificate(crt *x509.Certificate) string {
|
||||
if crt == nil {
|
||||
return ""
|
||||
|
|
|
@ -266,6 +266,16 @@ func WithAdminDB(d admin.DB) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithProvisioners is an option to set the provisioner collection.
|
||||
//
|
||||
// Deprecated: provisioner collections will likely change
|
||||
func WithProvisioners(ps *provisioner.Collection) Option {
|
||||
return func(a *Authority) error {
|
||||
a.provisioners = ps
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithLinkedCAToken is an option to set the authentication token used to enable
|
||||
// linked ca.
|
||||
func WithLinkedCAToken(token string) Option {
|
||||
|
@ -284,6 +294,15 @@ func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithSkipInit is an option that allows the constructor to skip initializtion
|
||||
// of the authority.
|
||||
func WithSkipInit() Option {
|
||||
return func(a *Authority) error {
|
||||
a.skipInit = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) {
|
||||
var block *pem.Block
|
||||
var certs []*x509.Certificate
|
||||
|
|
265
authority/policy.go
Normal file
265
authority/policy.go
Normal file
|
@ -0,0 +1,265 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
authPolicy "github.com/smallstep/certificates/authority/policy"
|
||||
policy "github.com/smallstep/certificates/policy"
|
||||
)
|
||||
|
||||
type policyErrorType int
|
||||
|
||||
const (
|
||||
AdminLockOut policyErrorType = iota + 1
|
||||
StoreFailure
|
||||
ReloadFailure
|
||||
ConfigurationFailure
|
||||
EvaluationFailure
|
||||
InternalFailure
|
||||
)
|
||||
|
||||
type PolicyError struct {
|
||||
Typ policyErrorType
|
||||
Err error
|
||||
}
|
||||
|
||||
func (p *PolicyError) Error() string {
|
||||
return p.Err.Error()
|
||||
}
|
||||
|
||||
func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
p, err := a.adminDB.GetAuthorityPolicy(ctx)
|
||||
if err != nil {
|
||||
return nil, &PolicyError{
|
||||
Typ: InternalFailure,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil {
|
||||
return nil, &PolicyError{
|
||||
Typ: StoreFailure,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.reloadPolicyEngines(ctx); err != nil {
|
||||
return nil, &PolicyError{
|
||||
Typ: ReloadFailure,
|
||||
Err: fmt.Errorf("error reloading policy engines when creating authority policy: %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := a.adminDB.UpdateAuthorityPolicy(ctx, p); err != nil {
|
||||
return nil, &PolicyError{
|
||||
Typ: StoreFailure,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.reloadPolicyEngines(ctx); err != nil {
|
||||
return nil, &PolicyError{
|
||||
Typ: ReloadFailure,
|
||||
Err: fmt.Errorf("error reloading policy engines when updating authority policy: %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
if err := a.adminDB.DeleteAuthorityPolicy(ctx); err != nil {
|
||||
return &PolicyError{
|
||||
Typ: StoreFailure,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.reloadPolicyEngines(ctx); err != nil {
|
||||
return &PolicyError{
|
||||
Typ: ReloadFailure,
|
||||
Err: fmt.Errorf("error reloading policy engines when deleting authority policy: %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Authority) checkAuthorityPolicy(ctx context.Context, currentAdmin *linkedca.Admin, p *linkedca.Policy) error {
|
||||
|
||||
// no policy and thus nothing to evaluate; return early
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get all current admins from the database
|
||||
allAdmins, err := a.adminDB.GetAdmins(ctx)
|
||||
if err != nil {
|
||||
return &PolicyError{
|
||||
Typ: InternalFailure,
|
||||
Err: fmt.Errorf("error retrieving admins: %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
return a.checkPolicy(ctx, currentAdmin, allAdmins, p)
|
||||
}
|
||||
|
||||
func (a *Authority) checkProvisionerPolicy(ctx context.Context, provName string, p *linkedca.Policy) error {
|
||||
|
||||
// no policy and thus nothing to evaluate; return early
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get all admins for the provisioner; ignoring case in which they're not found
|
||||
allProvisionerAdmins, _ := a.admins.LoadByProvisioner(provName)
|
||||
|
||||
// check the policy; pass in nil as the current admin, as all admins for the
|
||||
// provisioner will be checked by looping through allProvisionerAdmins. Also,
|
||||
// the current admin may be a super admin not belonging to the provisioner, so
|
||||
// can't be blocked, but is not required to be in the policy, either.
|
||||
return a.checkPolicy(ctx, nil, allProvisionerAdmins, p)
|
||||
}
|
||||
|
||||
// checkPolicy checks if a new or updated policy configuration results in the user
|
||||
// locking themselves or other admins out of the CA.
|
||||
func (a *Authority) checkPolicy(ctx context.Context, currentAdmin *linkedca.Admin, otherAdmins []*linkedca.Admin, p *linkedca.Policy) error {
|
||||
|
||||
// convert the policy; return early if nil
|
||||
policyOptions := authPolicy.LinkedToCertificates(p)
|
||||
if policyOptions == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine, err := authPolicy.NewX509PolicyEngine(policyOptions.GetX509Options())
|
||||
if err != nil {
|
||||
return &PolicyError{
|
||||
Typ: ConfigurationFailure,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// when an empty X.509 policy is provided, the resulting engine is nil
|
||||
// and there's no policy to evaluate.
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(hs): Provide option to force the policy, even when the admin subject would be locked out?
|
||||
|
||||
// check if the admin user that instructed the authority policy to be
|
||||
// created or updated, would still be allowed when the provided policy
|
||||
// would be applied. This case is skipped when current admin is nil, which
|
||||
// is the case when a provisioner policy is checked.
|
||||
if currentAdmin != nil {
|
||||
sans := []string{currentAdmin.GetSubject()}
|
||||
if err := isAllowed(engine, sans); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// loop through admins to verify that none of them would be
|
||||
// locked out when the new policy were to be applied. Returns
|
||||
// an error with a message that includes the admin subject that
|
||||
// would be locked out.
|
||||
for _, adm := range otherAdmins {
|
||||
sans := []string{adm.GetSubject()}
|
||||
if err := isAllowed(engine, sans); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(hs): mask the error message for non-super admins?
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reloadPolicyEngines reloads x509 and SSH policy engines using
|
||||
// configuration stored in the DB or from the configuration file.
|
||||
func (a *Authority) reloadPolicyEngines(ctx context.Context) error {
|
||||
var (
|
||||
err error
|
||||
policyOptions *authPolicy.Options
|
||||
)
|
||||
|
||||
if a.config.AuthorityConfig.EnableAdmin {
|
||||
|
||||
// temporarily disable policy loading when LinkedCA is in use
|
||||
if _, ok := a.adminDB.(*linkedCaClient); ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx)
|
||||
if err != nil {
|
||||
var ae *admin.Error
|
||||
if isAdminError := errors.As(err, &ae); (isAdminError && ae.Type != admin.ErrorNotFoundType.String()) || !isAdminError {
|
||||
return fmt.Errorf("error getting policy to (re)load policy engines: %w", err)
|
||||
}
|
||||
}
|
||||
policyOptions = authPolicy.LinkedToCertificates(linkedPolicy)
|
||||
} else {
|
||||
policyOptions = a.config.AuthorityConfig.Policy
|
||||
}
|
||||
|
||||
engine, err := authPolicy.New(policyOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// only update the policy engine when no error was returned
|
||||
a.policyEngine = engine
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAllowed(engine authPolicy.X509Policy, sans []string) error {
|
||||
if err := engine.AreSANsAllowed(sans); err != nil {
|
||||
var policyErr *policy.NamePolicyError
|
||||
isNamePolicyError := errors.As(err, &policyErr)
|
||||
if isNamePolicyError && policyErr.Reason == policy.NotAllowed {
|
||||
return &PolicyError{
|
||||
Typ: AdminLockOut,
|
||||
Err: fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans),
|
||||
}
|
||||
}
|
||||
return &PolicyError{
|
||||
Typ: EvaluationFailure,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
114
authority/policy/engine.go
Normal file
114
authority/policy/engine.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Engine is a container for multiple policies.
|
||||
type Engine struct {
|
||||
x509Policy X509Policy
|
||||
sshUserPolicy UserPolicy
|
||||
sshHostPolicy HostPolicy
|
||||
}
|
||||
|
||||
// New returns a new Engine using Options.
|
||||
func New(options *Options) (*Engine, error) {
|
||||
|
||||
// if no options provided, return early
|
||||
if options == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
x509Policy X509Policy
|
||||
sshHostPolicy HostPolicy
|
||||
sshUserPolicy UserPolicy
|
||||
err error
|
||||
)
|
||||
|
||||
// initialize the x509 allow/deny policy engine
|
||||
if x509Policy, err = NewX509PolicyEngine(options.GetX509Options()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// initialize the SSH allow/deny policy engine for host certificates
|
||||
if sshHostPolicy, err = NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// initialize the SSH allow/deny policy engine for user certificates
|
||||
if sshUserPolicy, err = NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Engine{
|
||||
x509Policy: x509Policy,
|
||||
sshHostPolicy: sshHostPolicy,
|
||||
sshUserPolicy: sshUserPolicy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsX509CertificateAllowed evaluates an X.509 certificate against
|
||||
// the X.509 policy (if available) and returns an error if one of the
|
||||
// names in the certificate is not allowed.
|
||||
func (e *Engine) IsX509CertificateAllowed(cert *x509.Certificate) error {
|
||||
|
||||
// return early if there's no policy to evaluate
|
||||
if e == nil || e.x509Policy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// return result of X.509 policy evaluation
|
||||
return e.x509Policy.IsX509CertificateAllowed(cert)
|
||||
}
|
||||
|
||||
// AreSANsAllowed evaluates the slice of SANs against the X.509 policy
|
||||
// (if available) and returns an error if one of the SANs is not allowed.
|
||||
func (e *Engine) AreSANsAllowed(sans []string) error {
|
||||
|
||||
// return early if there's no policy to evaluate
|
||||
if e == nil || e.x509Policy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// return result of X.509 policy evaluation
|
||||
return e.x509Policy.AreSANsAllowed(sans)
|
||||
}
|
||||
|
||||
// IsSSHCertificateAllowed evaluates an SSH certificate against the
|
||||
// user or host policy (if configured) and returns an error if one of the
|
||||
// principals in the certificate is not allowed.
|
||||
func (e *Engine) IsSSHCertificateAllowed(cert *ssh.Certificate) error {
|
||||
|
||||
// return early if there's no policy to evaluate
|
||||
if e == nil || (e.sshHostPolicy == nil && e.sshUserPolicy == nil) {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch cert.CertType {
|
||||
case ssh.HostCert:
|
||||
// when no host policy engine is configured, but a user policy engine is
|
||||
// configured, the host certificate is denied.
|
||||
if e.sshHostPolicy == nil && e.sshUserPolicy != nil {
|
||||
return errors.New("authority not allowed to sign ssh host certificates")
|
||||
}
|
||||
|
||||
// return result of SSH host policy evaluation
|
||||
return e.sshHostPolicy.IsSSHCertificateAllowed(cert)
|
||||
case ssh.UserCert:
|
||||
// when no user policy engine is configured, but a host policy engine is
|
||||
// configured, the user certificate is denied.
|
||||
if e.sshUserPolicy == nil && e.sshHostPolicy != nil {
|
||||
return errors.New("authority not allowed to sign ssh user certificates")
|
||||
}
|
||||
|
||||
// return result of SSH user policy evaluation
|
||||
return e.sshUserPolicy.IsSSHCertificateAllowed(cert)
|
||||
default:
|
||||
return fmt.Errorf("unexpected ssh certificate type %q", cert.CertType)
|
||||
}
|
||||
}
|
194
authority/policy/options.go
Normal file
194
authority/policy/options.go
Normal file
|
@ -0,0 +1,194 @@
|
|||
package policy
|
||||
|
||||
// Options is a container for authority level x509 and SSH
|
||||
// policy configuration.
|
||||
type Options struct {
|
||||
X509 *X509PolicyOptions `json:"x509,omitempty"`
|
||||
SSH *SSHPolicyOptions `json:"ssh,omitempty"`
|
||||
}
|
||||
|
||||
// GetX509Options returns the x509 authority level policy
|
||||
// configuration
|
||||
func (o *Options) GetX509Options() *X509PolicyOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.X509
|
||||
}
|
||||
|
||||
// GetSSHOptions returns the SSH authority level policy
|
||||
// configuration
|
||||
func (o *Options) GetSSHOptions() *SSHPolicyOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.SSH
|
||||
}
|
||||
|
||||
// X509PolicyOptionsInterface is an interface for providers
|
||||
// of x509 allowed and denied names.
|
||||
type X509PolicyOptionsInterface interface {
|
||||
GetAllowedNameOptions() *X509NameOptions
|
||||
GetDeniedNameOptions() *X509NameOptions
|
||||
AreWildcardNamesAllowed() bool
|
||||
}
|
||||
|
||||
// X509PolicyOptions is a container for x509 allowed and denied
|
||||
// names.
|
||||
type X509PolicyOptions struct {
|
||||
// AllowedNames contains the x509 allowed names
|
||||
AllowedNames *X509NameOptions `json:"allow,omitempty"`
|
||||
|
||||
// DeniedNames contains the x509 denied names
|
||||
DeniedNames *X509NameOptions `json:"deny,omitempty"`
|
||||
|
||||
// AllowWildcardNames indicates if literal wildcard names
|
||||
// like *.example.com are allowed. Defaults to false.
|
||||
AllowWildcardNames bool `json:"allowWildcardNames,omitempty"`
|
||||
}
|
||||
|
||||
// X509NameOptions models the X509 name policy configuration.
|
||||
type X509NameOptions struct {
|
||||
CommonNames []string `json:"cn,omitempty"`
|
||||
DNSDomains []string `json:"dns,omitempty"`
|
||||
IPRanges []string `json:"ip,omitempty"`
|
||||
EmailAddresses []string `json:"email,omitempty"`
|
||||
URIDomains []string `json:"uri,omitempty"`
|
||||
}
|
||||
|
||||
// HasNames checks if the AllowedNameOptions has one or more
|
||||
// names configured.
|
||||
func (o *X509NameOptions) HasNames() bool {
|
||||
return len(o.CommonNames) > 0 ||
|
||||
len(o.DNSDomains) > 0 ||
|
||||
len(o.IPRanges) > 0 ||
|
||||
len(o.EmailAddresses) > 0 ||
|
||||
len(o.URIDomains) > 0
|
||||
}
|
||||
|
||||
// GetAllowedNameOptions returns x509 allowed name policy configuration
|
||||
func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedNameOptions returns the x509 denied name policy configuration
|
||||
func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.DeniedNames
|
||||
}
|
||||
|
||||
// AreWildcardNamesAllowed returns whether the authority allows
|
||||
// literal wildcard names to be signed.
|
||||
func (o *X509PolicyOptions) AreWildcardNamesAllowed() bool {
|
||||
if o == nil {
|
||||
return true
|
||||
}
|
||||
return o.AllowWildcardNames
|
||||
}
|
||||
|
||||
// SSHPolicyOptionsInterface is an interface for providers of
|
||||
// SSH user and host name policy configuration.
|
||||
type SSHPolicyOptionsInterface interface {
|
||||
GetAllowedUserNameOptions() *SSHNameOptions
|
||||
GetDeniedUserNameOptions() *SSHNameOptions
|
||||
GetAllowedHostNameOptions() *SSHNameOptions
|
||||
GetDeniedHostNameOptions() *SSHNameOptions
|
||||
}
|
||||
|
||||
// SSHPolicyOptions is a container for SSH user and host policy
|
||||
// configuration
|
||||
type SSHPolicyOptions struct {
|
||||
// User contains SSH user certificate options.
|
||||
User *SSHUserCertificateOptions `json:"user,omitempty"`
|
||||
// Host contains SSH host certificate options.
|
||||
Host *SSHHostCertificateOptions `json:"host,omitempty"`
|
||||
}
|
||||
|
||||
// GetAllowedUserNameOptions returns the SSH allowed user name policy
|
||||
// configuration.
|
||||
func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions {
|
||||
if o == nil || o.User == nil {
|
||||
return nil
|
||||
}
|
||||
return o.User.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedUserNameOptions returns the SSH denied user name policy
|
||||
// configuration.
|
||||
func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions {
|
||||
if o == nil || o.User == nil {
|
||||
return nil
|
||||
}
|
||||
return o.User.DeniedNames
|
||||
}
|
||||
|
||||
// GetAllowedHostNameOptions returns the SSH allowed host name policy
|
||||
// configuration.
|
||||
func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions {
|
||||
if o == nil || o.Host == nil {
|
||||
return nil
|
||||
}
|
||||
return o.Host.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedHostNameOptions returns the SSH denied host name policy
|
||||
// configuration.
|
||||
func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions {
|
||||
if o == nil || o.Host == nil {
|
||||
return nil
|
||||
}
|
||||
return o.Host.DeniedNames
|
||||
}
|
||||
|
||||
// SSHUserCertificateOptions is a collection of SSH user certificate options.
|
||||
type SSHUserCertificateOptions struct {
|
||||
// AllowedNames contains the names the provisioner is authorized to sign
|
||||
AllowedNames *SSHNameOptions `json:"allow,omitempty"`
|
||||
// DeniedNames contains the names the provisioner is not authorized to sign
|
||||
DeniedNames *SSHNameOptions `json:"deny,omitempty"`
|
||||
}
|
||||
|
||||
// SSHHostCertificateOptions is a collection of SSH host certificate options.
|
||||
// It's an alias of SSHUserCertificateOptions, as the options are the same
|
||||
// for both types of certificates.
|
||||
type SSHHostCertificateOptions SSHUserCertificateOptions
|
||||
|
||||
// SSHNameOptions models the SSH name policy configuration.
|
||||
type SSHNameOptions struct {
|
||||
DNSDomains []string `json:"dns,omitempty"`
|
||||
IPRanges []string `json:"ip,omitempty"`
|
||||
EmailAddresses []string `json:"email,omitempty"`
|
||||
Principals []string `json:"principal,omitempty"`
|
||||
}
|
||||
|
||||
// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the
|
||||
// names that a provisioner is authorized to sign SSH certificates for.
|
||||
func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the
|
||||
// names that a provisioner is NOT authorized to sign SSH certificates for.
|
||||
func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.DeniedNames
|
||||
}
|
||||
|
||||
// HasNames checks if the SSHNameOptions has one or more
|
||||
// names configured.
|
||||
func (o *SSHNameOptions) HasNames() bool {
|
||||
return len(o.DNSDomains) > 0 ||
|
||||
len(o.IPRanges) > 0 ||
|
||||
len(o.EmailAddresses) > 0 ||
|
||||
len(o.Principals) > 0
|
||||
}
|
45
authority/policy/options_test.go
Normal file
45
authority/policy/options_test.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestX509PolicyOptions_IsWildcardLiteralAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options *X509PolicyOptions
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil-options",
|
||||
options: nil,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "not-set",
|
||||
options: &X509PolicyOptions{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "set-true",
|
||||
options: &X509PolicyOptions{
|
||||
AllowWildcardNames: true,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "set-false",
|
||||
options: &X509PolicyOptions{
|
||||
AllowWildcardNames: false,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.options.AreWildcardNamesAllowed(); got != tt.want {
|
||||
t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
256
authority/policy/policy.go
Normal file
256
authority/policy/policy.go
Normal file
|
@ -0,0 +1,256 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/policy"
|
||||
)
|
||||
|
||||
// X509Policy is an alias for policy.X509NamePolicyEngine
|
||||
type X509Policy policy.X509NamePolicyEngine
|
||||
|
||||
// UserPolicy is an alias for policy.SSHNamePolicyEngine
|
||||
type UserPolicy policy.SSHNamePolicyEngine
|
||||
|
||||
// HostPolicy is an alias for policy.SSHNamePolicyEngine
|
||||
type HostPolicy policy.SSHNamePolicyEngine
|
||||
|
||||
// NewX509PolicyEngine creates a new x509 name policy engine
|
||||
func NewX509PolicyEngine(policyOptions X509PolicyOptionsInterface) (X509Policy, error) {
|
||||
|
||||
// return early if no policy engine options to configure
|
||||
if policyOptions == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
options := []policy.NamePolicyOption{}
|
||||
|
||||
allowed := policyOptions.GetAllowedNameOptions()
|
||||
if allowed != nil && allowed.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithPermittedCommonNames(allowed.CommonNames...),
|
||||
policy.WithPermittedDNSDomains(allowed.DNSDomains...),
|
||||
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...),
|
||||
policy.WithPermittedEmailAddresses(allowed.EmailAddresses...),
|
||||
policy.WithPermittedURIDomains(allowed.URIDomains...),
|
||||
)
|
||||
}
|
||||
|
||||
denied := policyOptions.GetDeniedNameOptions()
|
||||
if denied != nil && denied.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithExcludedCommonNames(denied.CommonNames...),
|
||||
policy.WithExcludedDNSDomains(denied.DNSDomains...),
|
||||
policy.WithExcludedIPsOrCIDRs(denied.IPRanges...),
|
||||
policy.WithExcludedEmailAddresses(denied.EmailAddresses...),
|
||||
policy.WithExcludedURIDomains(denied.URIDomains...),
|
||||
)
|
||||
}
|
||||
|
||||
// ensure no policy engine is returned when no name options were provided
|
||||
if len(options) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// check if configuration specifies that wildcard names are allowed
|
||||
if policyOptions.AreWildcardNamesAllowed() {
|
||||
options = append(options, policy.WithAllowLiteralWildcardNames())
|
||||
}
|
||||
|
||||
// enable subject common name verification by default
|
||||
options = append(options, policy.WithSubjectCommonNameVerification())
|
||||
|
||||
return policy.New(options...)
|
||||
}
|
||||
|
||||
type sshPolicyEngineType string
|
||||
|
||||
const (
|
||||
UserPolicyEngineType sshPolicyEngineType = "user"
|
||||
HostPolicyEngineType sshPolicyEngineType = "host"
|
||||
)
|
||||
|
||||
// newSSHUserPolicyEngine creates a new SSH user certificate policy engine
|
||||
func NewSSHUserPolicyEngine(policyOptions SSHPolicyOptionsInterface) (UserPolicy, error) {
|
||||
policyEngine, err := newSSHPolicyEngine(policyOptions, UserPolicyEngineType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return policyEngine, nil
|
||||
}
|
||||
|
||||
// newSSHHostPolicyEngine create a new SSH host certificate policy engine
|
||||
func NewSSHHostPolicyEngine(policyOptions SSHPolicyOptionsInterface) (HostPolicy, error) {
|
||||
policyEngine, err := newSSHPolicyEngine(policyOptions, HostPolicyEngineType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return policyEngine, nil
|
||||
}
|
||||
|
||||
// newSSHPolicyEngine creates a new SSH name policy engine
|
||||
func newSSHPolicyEngine(policyOptions SSHPolicyOptionsInterface, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) {
|
||||
|
||||
// return early if no policy engine options to configure
|
||||
if policyOptions == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
allowed *SSHNameOptions
|
||||
denied *SSHNameOptions
|
||||
)
|
||||
|
||||
switch typ {
|
||||
case UserPolicyEngineType:
|
||||
allowed = policyOptions.GetAllowedUserNameOptions()
|
||||
denied = policyOptions.GetDeniedUserNameOptions()
|
||||
case HostPolicyEngineType:
|
||||
allowed = policyOptions.GetAllowedHostNameOptions()
|
||||
denied = policyOptions.GetDeniedHostNameOptions()
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ)
|
||||
}
|
||||
|
||||
options := []policy.NamePolicyOption{}
|
||||
|
||||
if allowed != nil && allowed.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithPermittedDNSDomains(allowed.DNSDomains...),
|
||||
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...),
|
||||
policy.WithPermittedEmailAddresses(allowed.EmailAddresses...),
|
||||
policy.WithPermittedPrincipals(allowed.Principals...),
|
||||
)
|
||||
}
|
||||
|
||||
if denied != nil && denied.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithExcludedDNSDomains(denied.DNSDomains...),
|
||||
policy.WithExcludedIPsOrCIDRs(denied.IPRanges...),
|
||||
policy.WithExcludedEmailAddresses(denied.EmailAddresses...),
|
||||
policy.WithExcludedPrincipals(denied.Principals...),
|
||||
)
|
||||
}
|
||||
|
||||
// ensure no policy engine is returned when no name options were provided
|
||||
if len(options) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return policy.New(options...)
|
||||
}
|
||||
|
||||
func LinkedToCertificates(p *linkedca.Policy) *Options {
|
||||
|
||||
// return early
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// return early if x509 nor SSH is set
|
||||
if p.GetX509() == nil && p.GetSsh() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := &Options{}
|
||||
|
||||
// fill x509 policy configuration
|
||||
if x509 := p.GetX509(); x509 != nil {
|
||||
opts.X509 = &X509PolicyOptions{}
|
||||
if allow := x509.GetAllow(); allow != nil {
|
||||
opts.X509.AllowedNames = &X509NameOptions{}
|
||||
if allow.Dns != nil {
|
||||
opts.X509.AllowedNames.DNSDomains = allow.Dns
|
||||
}
|
||||
if allow.Ips != nil {
|
||||
opts.X509.AllowedNames.IPRanges = allow.Ips
|
||||
}
|
||||
if allow.Emails != nil {
|
||||
opts.X509.AllowedNames.EmailAddresses = allow.Emails
|
||||
}
|
||||
if allow.Uris != nil {
|
||||
opts.X509.AllowedNames.URIDomains = allow.Uris
|
||||
}
|
||||
if allow.CommonNames != nil {
|
||||
opts.X509.AllowedNames.CommonNames = allow.CommonNames
|
||||
}
|
||||
}
|
||||
if deny := x509.GetDeny(); deny != nil {
|
||||
opts.X509.DeniedNames = &X509NameOptions{}
|
||||
if deny.Dns != nil {
|
||||
opts.X509.DeniedNames.DNSDomains = deny.Dns
|
||||
}
|
||||
if deny.Ips != nil {
|
||||
opts.X509.DeniedNames.IPRanges = deny.Ips
|
||||
}
|
||||
if deny.Emails != nil {
|
||||
opts.X509.DeniedNames.EmailAddresses = deny.Emails
|
||||
}
|
||||
if deny.Uris != nil {
|
||||
opts.X509.DeniedNames.URIDomains = deny.Uris
|
||||
}
|
||||
if deny.CommonNames != nil {
|
||||
opts.X509.DeniedNames.CommonNames = deny.CommonNames
|
||||
}
|
||||
}
|
||||
|
||||
opts.X509.AllowWildcardNames = x509.GetAllowWildcardNames()
|
||||
}
|
||||
|
||||
// fill ssh policy configuration
|
||||
if ssh := p.GetSsh(); ssh != nil {
|
||||
opts.SSH = &SSHPolicyOptions{}
|
||||
if host := ssh.GetHost(); host != nil {
|
||||
opts.SSH.Host = &SSHHostCertificateOptions{}
|
||||
if allow := host.GetAllow(); allow != nil {
|
||||
opts.SSH.Host.AllowedNames = &SSHNameOptions{}
|
||||
if allow.Dns != nil {
|
||||
opts.SSH.Host.AllowedNames.DNSDomains = allow.Dns
|
||||
}
|
||||
if allow.Ips != nil {
|
||||
opts.SSH.Host.AllowedNames.IPRanges = allow.Ips
|
||||
}
|
||||
if allow.Principals != nil {
|
||||
opts.SSH.Host.AllowedNames.Principals = allow.Principals
|
||||
}
|
||||
}
|
||||
if deny := host.GetDeny(); deny != nil {
|
||||
opts.SSH.Host.DeniedNames = &SSHNameOptions{}
|
||||
if deny.Dns != nil {
|
||||
opts.SSH.Host.DeniedNames.DNSDomains = deny.Dns
|
||||
}
|
||||
if deny.Ips != nil {
|
||||
opts.SSH.Host.DeniedNames.IPRanges = deny.Ips
|
||||
}
|
||||
if deny.Principals != nil {
|
||||
opts.SSH.Host.DeniedNames.Principals = deny.Principals
|
||||
}
|
||||
}
|
||||
}
|
||||
if user := ssh.GetUser(); user != nil {
|
||||
opts.SSH.User = &SSHUserCertificateOptions{}
|
||||
if allow := user.GetAllow(); allow != nil {
|
||||
opts.SSH.User.AllowedNames = &SSHNameOptions{}
|
||||
if allow.Emails != nil {
|
||||
opts.SSH.User.AllowedNames.EmailAddresses = allow.Emails
|
||||
}
|
||||
if allow.Principals != nil {
|
||||
opts.SSH.User.AllowedNames.Principals = allow.Principals
|
||||
}
|
||||
}
|
||||
if deny := user.GetDeny(); deny != nil {
|
||||
opts.SSH.User.DeniedNames = &SSHNameOptions{}
|
||||
if deny.Emails != nil {
|
||||
opts.SSH.User.DeniedNames.EmailAddresses = deny.Emails
|
||||
}
|
||||
if deny.Principals != nil {
|
||||
opts.SSH.User.DeniedNames.Principals = deny.Principals
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
155
authority/policy/policy_test.go
Normal file
155
authority/policy/policy_test.go
Normal file
|
@ -0,0 +1,155 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
func TestPolicyToCertificates(t *testing.T) {
|
||||
type args struct {
|
||||
policy *linkedca.Policy
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Options
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
args: args{
|
||||
policy: nil,
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no-policy",
|
||||
args: args{
|
||||
&linkedca.Policy{},
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "partial-policy",
|
||||
args: args{
|
||||
&linkedca.Policy{
|
||||
X509: &linkedca.X509Policy{
|
||||
Allow: &linkedca.X509Names{
|
||||
Dns: []string{"*.local"},
|
||||
},
|
||||
AllowWildcardNames: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &Options{
|
||||
X509: &X509PolicyOptions{
|
||||
AllowedNames: &X509NameOptions{
|
||||
DNSDomains: []string{"*.local"},
|
||||
},
|
||||
AllowWildcardNames: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full-policy",
|
||||
args: args{
|
||||
&linkedca.Policy{
|
||||
X509: &linkedca.X509Policy{
|
||||
Allow: &linkedca.X509Names{
|
||||
Dns: []string{"step"},
|
||||
Ips: []string{"127.0.0.1/24"},
|
||||
Emails: []string{"*.example.com"},
|
||||
Uris: []string{"https://*.local"},
|
||||
CommonNames: []string{"some name"},
|
||||
},
|
||||
Deny: &linkedca.X509Names{
|
||||
Dns: []string{"bad"},
|
||||
Ips: []string{"127.0.0.30"},
|
||||
Emails: []string{"badhost.example.com"},
|
||||
Uris: []string{"https://badhost.local"},
|
||||
CommonNames: []string{"another name"},
|
||||
},
|
||||
AllowWildcardNames: true,
|
||||
},
|
||||
Ssh: &linkedca.SSHPolicy{
|
||||
Host: &linkedca.SSHHostPolicy{
|
||||
Allow: &linkedca.SSHHostNames{
|
||||
Dns: []string{"*.localhost"},
|
||||
Ips: []string{"127.0.0.1/24"},
|
||||
Principals: []string{"user"},
|
||||
},
|
||||
Deny: &linkedca.SSHHostNames{
|
||||
Dns: []string{"badhost.localhost"},
|
||||
Ips: []string{"127.0.0.40"},
|
||||
Principals: []string{"root"},
|
||||
},
|
||||
},
|
||||
User: &linkedca.SSHUserPolicy{
|
||||
Allow: &linkedca.SSHUserNames{
|
||||
Emails: []string{"@work"},
|
||||
Principals: []string{"user"},
|
||||
},
|
||||
Deny: &linkedca.SSHUserNames{
|
||||
Emails: []string{"root@work"},
|
||||
Principals: []string{"root"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &Options{
|
||||
X509: &X509PolicyOptions{
|
||||
AllowedNames: &X509NameOptions{
|
||||
DNSDomains: []string{"step"},
|
||||
IPRanges: []string{"127.0.0.1/24"},
|
||||
EmailAddresses: []string{"*.example.com"},
|
||||
URIDomains: []string{"https://*.local"},
|
||||
CommonNames: []string{"some name"},
|
||||
},
|
||||
DeniedNames: &X509NameOptions{
|
||||
DNSDomains: []string{"bad"},
|
||||
IPRanges: []string{"127.0.0.30"},
|
||||
EmailAddresses: []string{"badhost.example.com"},
|
||||
URIDomains: []string{"https://badhost.local"},
|
||||
CommonNames: []string{"another name"},
|
||||
},
|
||||
AllowWildcardNames: true,
|
||||
},
|
||||
SSH: &SSHPolicyOptions{
|
||||
Host: &SSHHostCertificateOptions{
|
||||
AllowedNames: &SSHNameOptions{
|
||||
DNSDomains: []string{"*.localhost"},
|
||||
IPRanges: []string{"127.0.0.1/24"},
|
||||
Principals: []string{"user"},
|
||||
},
|
||||
DeniedNames: &SSHNameOptions{
|
||||
DNSDomains: []string{"badhost.localhost"},
|
||||
IPRanges: []string{"127.0.0.40"},
|
||||
Principals: []string{"root"},
|
||||
},
|
||||
},
|
||||
User: &SSHUserCertificateOptions{
|
||||
AllowedNames: &SSHNameOptions{
|
||||
EmailAddresses: []string{"@work"},
|
||||
Principals: []string{"user"},
|
||||
},
|
||||
DeniedNames: &SSHNameOptions{
|
||||
EmailAddresses: []string{"root@work"},
|
||||
Principals: []string{"root"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := LinkedToCertificates(tt.args.policy)
|
||||
if !cmp.Equal(tt.want, got) {
|
||||
t.Errorf("policyToCertificates() diff=\n%s", cmp.Diff(tt.want, got))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
1625
authority/policy_test.go
Normal file
1625
authority/policy_test.go
Normal file
File diff suppressed because it is too large
Load diff
|
@ -3,6 +3,8 @@ package provisioner
|
|||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -23,7 +25,8 @@ type ACME struct {
|
|||
RequireEAB bool `json:"requireEAB,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
ctl *Controller
|
||||
|
||||
ctl *Controller
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
|
@ -71,7 +74,7 @@ func (p *ACME) DefaultTLSCertDuration() time.Duration {
|
|||
return p.ctl.Claimer.DefaultTLSCertDuration()
|
||||
}
|
||||
|
||||
// Init initializes and validates the fields of a JWK type.
|
||||
// Init initializes and validates the fields of an ACME type.
|
||||
func (p *ACME) Init(config Config) (err error) {
|
||||
switch {
|
||||
case p.Type == "":
|
||||
|
@ -80,15 +83,57 @@ func (p *ACME) Init(config Config) (err error) {
|
|||
return errors.New("provisioner name cannot be empty")
|
||||
}
|
||||
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
p.ctl, err = NewController(p, p.Claims, config, p.Options)
|
||||
return
|
||||
}
|
||||
|
||||
// ACMEIdentifierType encodes ACME Identifier types
|
||||
type ACMEIdentifierType string
|
||||
|
||||
const (
|
||||
// IP is the ACME ip identifier type
|
||||
IP ACMEIdentifierType = "ip"
|
||||
// DNS is the ACME dns identifier type
|
||||
DNS ACMEIdentifierType = "dns"
|
||||
)
|
||||
|
||||
// ACMEIdentifier encodes ACME Order Identifiers
|
||||
type ACMEIdentifier struct {
|
||||
Type ACMEIdentifierType
|
||||
Value string
|
||||
}
|
||||
|
||||
// AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a
|
||||
// certificate for an ACME Order Identifier.
|
||||
func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error {
|
||||
|
||||
x509Policy := p.ctl.getPolicy().getX509()
|
||||
|
||||
// identifier is allowed if no policy is configured
|
||||
if x509Policy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// assuming only valid identifiers (IP or DNS) are provided
|
||||
var err error
|
||||
switch identifier.Type {
|
||||
case IP:
|
||||
err = x509Policy.IsIPAllowed(net.ParseIP(identifier.Value))
|
||||
case DNS:
|
||||
err = x509Policy.IsDNSAllowed(identifier.Value)
|
||||
default:
|
||||
err = fmt.Errorf("invalid ACME identifier type '%s' provided", identifier.Type)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AuthorizeSign does not do any validation, because all validation is handled
|
||||
// in the ACME protocol. This method returns a list of modifiers / constraints
|
||||
// on the resulting certificate.
|
||||
func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
return []SignOption{
|
||||
opts := []SignOption{
|
||||
p,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeACME, p.Name, ""),
|
||||
newForceCNOption(p.ForceCN),
|
||||
|
@ -96,7 +141,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// validators
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
}, nil
|
||||
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// AuthorizeRevoke is called just before the certificate is to be revoked by
|
||||
|
|
|
@ -176,9 +176,10 @@ func TestACME_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
|
||||
assert.Len(t, 5, opts)
|
||||
assert.Equals(t, 7, len(opts)) // number of SignOptions returned
|
||||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case *ACME:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, TypeACME)
|
||||
assert.Equals(t, v.Name, tc.p.GetName())
|
||||
|
@ -192,6 +193,8 @@ func TestACME_AuthorizeSign(t *testing.T) {
|
|||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
|
||||
case *x509NamePolicyValidator:
|
||||
assert.Equals(t, nil, v.policyEngine)
|
||||
default:
|
||||
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
|
|
|
@ -17,10 +17,12 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// awsIssuer is the string used as issuer in the generated tokens.
|
||||
|
@ -49,22 +51,27 @@ const awsMetadataTokenTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds"
|
|||
// signature.
|
||||
//
|
||||
// The first certificate is used in:
|
||||
// ap-northeast-2, ap-south-1, ap-southeast-1, ap-southeast-2
|
||||
// eu-central-1, eu-north-1, eu-west-1, eu-west-2, eu-west-3
|
||||
// us-east-1, us-east-2, us-west-1, us-west-2
|
||||
// ca-central-1, sa-east-1
|
||||
//
|
||||
// ap-northeast-2, ap-south-1, ap-southeast-1, ap-southeast-2
|
||||
// eu-central-1, eu-north-1, eu-west-1, eu-west-2, eu-west-3
|
||||
// us-east-1, us-east-2, us-west-1, us-west-2
|
||||
// ca-central-1, sa-east-1
|
||||
//
|
||||
// The second certificate is used in:
|
||||
// eu-south-1
|
||||
//
|
||||
// eu-south-1
|
||||
//
|
||||
// The third certificate is used in:
|
||||
// ap-east-1
|
||||
//
|
||||
// ap-east-1
|
||||
//
|
||||
// The fourth certificate is used in:
|
||||
// af-south-1
|
||||
//
|
||||
// af-south-1
|
||||
//
|
||||
// The fifth certificate is used in:
|
||||
// me-south-1
|
||||
//
|
||||
// me-south-1
|
||||
const awsCertificate = `-----BEGIN CERTIFICATE-----
|
||||
MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV
|
||||
BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw
|
||||
|
@ -421,7 +428,7 @@ func (p *AWS) Init(config Config) (err error) {
|
|||
}
|
||||
|
||||
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
p.ctl, err = NewController(p, p.Claims, config, p.Options)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -467,6 +474,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
}
|
||||
|
||||
return append(so,
|
||||
p,
|
||||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
|
||||
|
@ -475,6 +483,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
defaultPublicKeyValidator{},
|
||||
commonNameValidator(payload.Claims.Subject),
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
|
||||
), nil
|
||||
}
|
||||
|
||||
|
@ -542,7 +551,7 @@ func (p *AWS) readURL(url string) ([]byte, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("Request for metadata returned non-successful status code %d",
|
||||
return nil, fmt.Errorf("request for metadata returned non-successful status code %d",
|
||||
resp.StatusCode)
|
||||
}
|
||||
|
||||
|
@ -575,7 +584,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
|
|||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode)
|
||||
return nil, fmt.Errorf("request for API token returned non-successful status code %d", resp.StatusCode)
|
||||
}
|
||||
token, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
|
@ -743,6 +752,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
signOptions = append(signOptions, templateOptions)
|
||||
|
||||
return append(signOptions,
|
||||
p,
|
||||
// Validate user SignSSHOptions.
|
||||
sshCertOptionsValidator(defaults),
|
||||
// Set the validity bounds if not set.
|
||||
|
@ -753,5 +763,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -642,11 +642,11 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, "foo.local"}, 6, http.StatusOK, false},
|
||||
{"ok", p2, args{t2, "instance-id"}, 10, http.StatusOK, false},
|
||||
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 10, http.StatusOK, false},
|
||||
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 10, http.StatusOK, false},
|
||||
{"ok", p1, args{t4, "instance-id"}, 6, http.StatusOK, false},
|
||||
{"ok", p1, args{t1, "foo.local"}, 8, http.StatusOK, false},
|
||||
{"ok", p2, args{t2, "instance-id"}, 12, http.StatusOK, false},
|
||||
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 12, http.StatusOK, false},
|
||||
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 12, http.StatusOK, false},
|
||||
{"ok", p1, args{t4, "instance-id"}, 8, http.StatusOK, false},
|
||||
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
|
||||
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
|
||||
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
|
||||
|
@ -673,9 +673,10 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
default:
|
||||
assert.Len(t, tt.wantLen, got)
|
||||
assert.Equals(t, tt.wantLen, len(got))
|
||||
for _, o := range got {
|
||||
switch v := o.(type) {
|
||||
case *AWS:
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, TypeAWS)
|
||||
|
@ -698,6 +699,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
assert.Equals(t, v, nil)
|
||||
case dnsNamesValidator:
|
||||
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
|
||||
case *x509NamePolicyValidator:
|
||||
assert.Equals(t, nil, v.policyEngine)
|
||||
default:
|
||||
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
|
@ -810,7 +813,6 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
|||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
if tt.wantSignErr {
|
||||
|
|
|
@ -13,10 +13,12 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
|
||||
|
@ -219,7 +221,7 @@ func (p *Azure) Init(config Config) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
p.ctl, err = NewController(p, p.Claims, config, p.Options)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -352,6 +354,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
}
|
||||
|
||||
return append(so,
|
||||
p,
|
||||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
|
||||
|
@ -359,6 +362,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// validators
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
|
||||
), nil
|
||||
}
|
||||
|
||||
|
@ -414,6 +418,7 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
signOptions = append(signOptions, templateOptions)
|
||||
|
||||
return append(signOptions,
|
||||
p,
|
||||
// Validate user SignSSHOptions.
|
||||
sshCertOptionsValidator(defaults),
|
||||
// Set the validity bounds if not set.
|
||||
|
@ -424,6 +429,8 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
|
||||
), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -474,11 +474,11 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, 5, http.StatusOK, false},
|
||||
{"ok", p2, args{t2}, 10, http.StatusOK, false},
|
||||
{"ok", p1, args{t11}, 5, http.StatusOK, false},
|
||||
{"ok", p5, args{t5}, 5, http.StatusOK, false},
|
||||
{"ok", p7, args{t7}, 5, http.StatusOK, false},
|
||||
{"ok", p1, args{t1}, 7, http.StatusOK, false},
|
||||
{"ok", p2, args{t2}, 12, http.StatusOK, false},
|
||||
{"ok", p1, args{t11}, 7, http.StatusOK, false},
|
||||
{"ok", p5, args{t5}, 7, http.StatusOK, false},
|
||||
{"ok", p7, args{t7}, 7, http.StatusOK, false},
|
||||
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
|
||||
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
|
||||
{"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true},
|
||||
|
@ -502,9 +502,10 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
default:
|
||||
assert.Len(t, tt.wantLen, got)
|
||||
assert.Equals(t, tt.wantLen, len(got))
|
||||
for _, o := range got {
|
||||
switch v := o.(type) {
|
||||
case *Azure:
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, TypeAzure)
|
||||
|
@ -527,6 +528,8 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
assert.Equals(t, v, nil)
|
||||
case dnsNamesValidator:
|
||||
assert.Equals(t, []string(v), []string{"virtualMachine"})
|
||||
case *x509NamePolicyValidator:
|
||||
assert.Equals(t, nil, v.policyEngine)
|
||||
default:
|
||||
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
|
|
|
@ -24,8 +24,8 @@ type Claims struct {
|
|||
EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
|
||||
|
||||
// Renewal properties
|
||||
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
||||
AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"`
|
||||
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
||||
AllowRenewalAfterExpiry *bool `json:"allowRenewalAfterExpiry,omitempty"`
|
||||
}
|
||||
|
||||
// Claimer is the type that controls claims. It provides an interface around the
|
||||
|
@ -44,22 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
|
|||
// Claims returns the merge of the inner and global claims.
|
||||
func (c *Claimer) Claims() Claims {
|
||||
disableRenewal := c.IsDisableRenewal()
|
||||
allowRenewAfterExpiry := c.AllowRenewAfterExpiry()
|
||||
allowRenewalAfterExpiry := c.AllowRenewalAfterExpiry()
|
||||
enableSSHCA := c.IsSSHCAEnabled()
|
||||
|
||||
return Claims{
|
||||
MinTLSDur: &Duration{c.MinTLSCertDuration()},
|
||||
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
||||
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
||||
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
|
||||
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
|
||||
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
|
||||
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
|
||||
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
|
||||
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
|
||||
EnableSSHCA: &enableSSHCA,
|
||||
DisableRenewal: &disableRenewal,
|
||||
AllowRenewAfterExpiry: &allowRenewAfterExpiry,
|
||||
MinTLSDur: &Duration{c.MinTLSCertDuration()},
|
||||
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
||||
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
||||
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
|
||||
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
|
||||
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
|
||||
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
|
||||
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
|
||||
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
|
||||
EnableSSHCA: &enableSSHCA,
|
||||
DisableRenewal: &disableRenewal,
|
||||
AllowRenewalAfterExpiry: &allowRenewalAfterExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -109,14 +109,14 @@ func (c *Claimer) IsDisableRenewal() bool {
|
|||
return *c.claims.DisableRenewal
|
||||
}
|
||||
|
||||
// AllowRenewAfterExpiry returns if the renewal flow is authorized if the
|
||||
// AllowRenewalAfterExpiry returns if the renewal flow is authorized if the
|
||||
// certificate is expired. If the property is not set within the provisioner
|
||||
// then the global value from the authority configuration will be used.
|
||||
func (c *Claimer) AllowRenewAfterExpiry() bool {
|
||||
if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil {
|
||||
return *c.global.AllowRenewAfterExpiry
|
||||
func (c *Claimer) AllowRenewalAfterExpiry() bool {
|
||||
if c.claims == nil || c.claims.AllowRenewalAfterExpiry == nil {
|
||||
return *c.global.AllowRenewalAfterExpiry
|
||||
}
|
||||
return *c.claims.AllowRenewAfterExpiry
|
||||
return *c.claims.AllowRenewalAfterExpiry
|
||||
}
|
||||
|
||||
// DefaultSSHCertDuration returns the default SSH certificate duration for the
|
||||
|
|
|
@ -3,6 +3,7 @@ package provisioner
|
|||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -21,14 +22,19 @@ type Controller struct {
|
|||
IdentityFunc GetIdentityFunc
|
||||
AuthorizeRenewFunc AuthorizeRenewFunc
|
||||
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||
policy *policyEngine
|
||||
}
|
||||
|
||||
// NewController initializes a new provisioner controller.
|
||||
func NewController(p Interface, claims *Claims, config Config) (*Controller, error) {
|
||||
func NewController(p Interface, claims *Claims, config Config, options *Options) (*Controller, error) {
|
||||
claimer, err := NewClaimer(claims, config.Claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
policy, err := newPolicyEngine(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Controller{
|
||||
Interface: p,
|
||||
Audiences: &config.Audiences,
|
||||
|
@ -36,6 +42,7 @@ func NewController(p Interface, claims *Claims, config Config) (*Controller, err
|
|||
IdentityFunc: config.GetIdentityFunc,
|
||||
AuthorizeRenewFunc: config.AuthorizeRenewFunc,
|
||||
AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc,
|
||||
policy: policy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -124,8 +131,10 @@ func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certif
|
|||
if now.Before(cert.NotBefore) {
|
||||
return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano))
|
||||
}
|
||||
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() {
|
||||
return errs.Unauthorized("certificate has expired")
|
||||
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewalAfterExpiry() {
|
||||
// return a custom 401 Unauthorized error with a clearer message for the client
|
||||
// TODO(hs): these errors likely need to be refactored as a whole; HTTP status codes shouldn't be in this layer.
|
||||
return errs.New(http.StatusUnauthorized, "The request lacked necessary authorization to be completed: certificate expired on %s", cert.NotAfter)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -144,7 +153,7 @@ func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Cert
|
|||
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
|
||||
return errs.Unauthorized("certificate is not yet valid")
|
||||
}
|
||||
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() {
|
||||
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewalAfterExpiry() {
|
||||
return errs.Unauthorized("certificate has expired")
|
||||
}
|
||||
|
||||
|
@ -192,3 +201,10 @@ func SanitizeSSHUserPrincipal(email string) string {
|
|||
}
|
||||
}, strings.ToLower(email))
|
||||
}
|
||||
|
||||
func (c *Controller) getPolicy() *policyEngine {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.policy
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
)
|
||||
|
||||
var trueValue = true
|
||||
|
@ -30,11 +32,40 @@ func mustDuration(t *testing.T, s string) *Duration {
|
|||
return d
|
||||
}
|
||||
|
||||
func mustNewPolicyEngine(t *testing.T, options *Options) *policyEngine {
|
||||
t.Helper()
|
||||
c, err := newPolicyEngine(options)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func TestNewController(t *testing.T) {
|
||||
options := &Options{
|
||||
X509: &X509Options{
|
||||
AllowedNames: &policy.X509NameOptions{
|
||||
DNSDomains: []string{"*.local"},
|
||||
},
|
||||
},
|
||||
SSH: &SSHOptions{
|
||||
Host: &policy.SSHHostCertificateOptions{
|
||||
AllowedNames: &policy.SSHNameOptions{
|
||||
DNSDomains: []string{"*.local"},
|
||||
},
|
||||
},
|
||||
User: &policy.SSHUserCertificateOptions{
|
||||
AllowedNames: &policy.SSHNameOptions{
|
||||
EmailAddresses: []string{"@example.com"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
type args struct {
|
||||
p Interface
|
||||
claims *Claims
|
||||
config Config
|
||||
p Interface
|
||||
claims *Claims
|
||||
config Config
|
||||
options *Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -45,7 +76,7 @@ func TestNewController(t *testing.T) {
|
|||
{"ok", args{&JWK{}, nil, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}}, &Controller{
|
||||
}, nil}, &Controller{
|
||||
Interface: &JWK{},
|
||||
Audiences: &testAudiences,
|
||||
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
||||
|
@ -55,24 +86,49 @@ func TestNewController(t *testing.T) {
|
|||
}, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}}, &Controller{
|
||||
}, nil}, &Controller{
|
||||
Interface: &JWK{},
|
||||
Audiences: &testAudiences,
|
||||
Claimer: mustClaimer(t, &Claims{
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}, globalProvisionerClaims),
|
||||
}, false},
|
||||
{"ok with claims and options", args{&JWK{}, &Claims{
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}, options}, &Controller{
|
||||
Interface: &JWK{},
|
||||
Audiences: &testAudiences,
|
||||
Claimer: mustClaimer(t, &Claims{
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}, globalProvisionerClaims),
|
||||
policy: mustNewPolicyEngine(t, options),
|
||||
}, false},
|
||||
{"fail claimer", args{&JWK{}, &Claims{
|
||||
MinTLSDur: mustDuration(t, "24h"),
|
||||
MaxTLSDur: mustDuration(t, "2h"),
|
||||
}, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}, nil}, nil, true},
|
||||
{"fail options", args{&JWK{}, &Claims{
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}, &Options{
|
||||
X509: &X509Options{
|
||||
AllowedNames: &policy.X509NameOptions{
|
||||
DNSDomains: []string{"**.local"},
|
||||
},
|
||||
},
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
|
||||
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config, tt.args.options)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -160,13 +216,13 @@ func TestController_AuthorizeRenew(t *testing.T) {
|
|||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||
return nil
|
||||
}}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
}}, false},
|
||||
|
@ -231,13 +287,13 @@ func TestController_AuthorizeSSHRenew(t *testing.T) {
|
|||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, false},
|
||||
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||
return nil
|
||||
}}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, false},
|
||||
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||
}}, false},
|
||||
|
@ -296,7 +352,7 @@ func TestDefaultAuthorizeRenew(t *testing.T) {
|
|||
}}, false},
|
||||
{"ok renew after expiry", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
|
||||
Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
|
@ -354,7 +410,7 @@ func TestDefaultAuthorizeSSHRenew(t *testing.T) {
|
|||
}}, false},
|
||||
{"ok renew after expiry", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
|
||||
Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||
|
|
|
@ -14,10 +14,12 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// gcpCertsURL is the url that serves Google OAuth2 public keys.
|
||||
|
@ -212,7 +214,7 @@ func (p *GCP) Init(config Config) (err error) {
|
|||
}
|
||||
|
||||
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
p.ctl, err = NewController(p, p.Claims, config, p.Options)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -262,6 +264,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
}
|
||||
|
||||
return append(so,
|
||||
p,
|
||||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
|
||||
|
@ -269,6 +272,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// validators
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
|
||||
), nil
|
||||
}
|
||||
|
||||
|
@ -421,6 +425,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
signOptions = append(signOptions, templateOptions)
|
||||
|
||||
return append(signOptions,
|
||||
p,
|
||||
// Validate user SignSSHOptions.
|
||||
sshCertOptionsValidator(defaults),
|
||||
// Set the validity bounds if not set.
|
||||
|
@ -431,5 +436,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -516,9 +516,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
|||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1}, 5, http.StatusOK, false},
|
||||
{"ok", p2, args{t2}, 10, http.StatusOK, false},
|
||||
{"ok", p3, args{t3}, 5, http.StatusOK, false},
|
||||
{"ok", p1, args{t1}, 7, http.StatusOK, false},
|
||||
{"ok", p2, args{t2}, 12, http.StatusOK, false},
|
||||
{"ok", p3, args{t3}, 7, http.StatusOK, false},
|
||||
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
|
||||
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
|
||||
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
|
||||
|
@ -545,9 +545,10 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
|||
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||
default:
|
||||
assert.Len(t, tt.wantLen, got)
|
||||
assert.Equals(t, tt.wantLen, len(got))
|
||||
for _, o := range got {
|
||||
switch v := o.(type) {
|
||||
case *GCP:
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, TypeGCP)
|
||||
|
@ -570,6 +571,8 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
|||
assert.Equals(t, v, nil)
|
||||
case dnsNamesValidator:
|
||||
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
|
||||
case *x509NamePolicyValidator:
|
||||
assert.Equals(t, nil, v.policyEngine)
|
||||
default:
|
||||
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
|
|
|
@ -7,10 +7,12 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// jwtPayload extends jwt.Claims with step attributes.
|
||||
|
@ -97,7 +99,7 @@ func (p *JWK) Init(config Config) (err error) {
|
|||
return errors.New("provisioner key cannot be empty")
|
||||
}
|
||||
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
p.ctl, err = NewController(p, p.Claims, config, p.Options)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -141,6 +143,7 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
|
|||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
|
||||
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke)
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
|
||||
}
|
||||
|
||||
|
@ -170,6 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
}
|
||||
|
||||
return []SignOption{
|
||||
p,
|
||||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
|
||||
|
@ -179,6 +183,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
defaultPublicKeyValidator{},
|
||||
defaultSANsValidator(claims.SANs),
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -187,6 +192,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// revocation status. Just confirms that the provisioner that created the
|
||||
// certificate was configured to allow renewals.
|
||||
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew)
|
||||
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||
}
|
||||
|
||||
|
@ -251,6 +257,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
}
|
||||
|
||||
return append(signOptions,
|
||||
p,
|
||||
// Set the validity bounds if not set.
|
||||
&sshDefaultDuration{p.ctl.Claimer},
|
||||
// Validate public key
|
||||
|
@ -259,11 +266,14 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
|
||||
), nil
|
||||
}
|
||||
|
||||
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
|
||||
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
_, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
|
||||
// TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke)
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
|
||||
}
|
||||
|
|
|
@ -297,9 +297,10 @@ func TestJWK_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
} else {
|
||||
if assert.NotNil(t, got) {
|
||||
assert.Len(t, 7, got)
|
||||
assert.Equals(t, 9, len(got))
|
||||
for _, o := range got {
|
||||
switch v := o.(type) {
|
||||
case *JWK:
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, TypeJWK)
|
||||
|
@ -316,6 +317,8 @@ func TestJWK_AuthorizeSign(t *testing.T) {
|
|||
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
|
||||
case defaultSANsValidator:
|
||||
assert.Equals(t, []string(v), tt.sans)
|
||||
case *x509NamePolicyValidator:
|
||||
assert.Equals(t, nil, v.policyEngine)
|
||||
default:
|
||||
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
|
|
|
@ -10,11 +10,13 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// NOTE: There can be at most one kubernetes service account provisioner configured
|
||||
|
@ -91,6 +93,7 @@ func (p *K8sSA) GetEncryptedKey() (string, string, bool) {
|
|||
|
||||
// Init initializes and validates the fields of a K8sSA type.
|
||||
func (p *K8sSA) Init(config Config) (err error) {
|
||||
|
||||
switch {
|
||||
case p.Type == "":
|
||||
return errors.New("provisioner type cannot be empty")
|
||||
|
@ -137,7 +140,7 @@ func (p *K8sSA) Init(config Config) (err error) {
|
|||
p.kauthn = k8s.AuthenticationV1()
|
||||
*/
|
||||
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
p.ctl, err = NewController(p, p.Claims, config, p.Options)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -231,6 +234,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
}
|
||||
|
||||
return []SignOption{
|
||||
p,
|
||||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeK8sSA, p.Name, ""),
|
||||
|
@ -238,6 +242,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// validators
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -270,6 +275,7 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
signOptions := []SignOption{templateOptions}
|
||||
|
||||
return append(signOptions,
|
||||
p,
|
||||
// Require type, key-id and principals in the SignSSHOptions.
|
||||
&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
|
||||
// Set the validity bounds if not set.
|
||||
|
@ -280,6 +286,8 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
|
||||
), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -280,9 +280,9 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
|
|||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
if assert.NotNil(t, opts) {
|
||||
tot := 0
|
||||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case *K8sSA:
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, TypeK8sSA)
|
||||
|
@ -295,12 +295,13 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
|
|||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
|
||||
case *x509NamePolicyValidator:
|
||||
assert.Equals(t, nil, v.policyEngine)
|
||||
default:
|
||||
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
tot++
|
||||
}
|
||||
assert.Equals(t, tot, 5)
|
||||
assert.Equals(t, 7, len(opts))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -367,9 +368,10 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
|
|||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
if assert.NotNil(t, opts) {
|
||||
tot := 0
|
||||
assert.Len(t, 8, opts)
|
||||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case Interface:
|
||||
case sshCertificateOptionsFunc:
|
||||
case *sshCertOptionsRequireValidator:
|
||||
assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
|
||||
|
@ -379,12 +381,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
|
|||
case *sshCertDefaultValidator:
|
||||
case *sshDefaultDuration:
|
||||
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||
case *sshNamePolicyValidator:
|
||||
assert.Equals(t, nil, v.userPolicyEngine)
|
||||
assert.Equals(t, nil, v.hostPolicyEngine)
|
||||
default:
|
||||
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
tot++
|
||||
}
|
||||
assert.Equals(t, tot, 6)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,3 +61,16 @@ func MethodFromContext(ctx context.Context) Method {
|
|||
m, _ := ctx.Value(methodKey{}).(Method)
|
||||
return m
|
||||
}
|
||||
|
||||
type tokenKey struct{}
|
||||
|
||||
// NewContextWithToken creates a new context with the given token.
|
||||
func NewContextWithToken(ctx context.Context, token string) context.Context {
|
||||
return context.WithValue(ctx, tokenKey{}, token)
|
||||
}
|
||||
|
||||
// TokenFromContext returns the token stored in the given context.
|
||||
func TokenFromContext(ctx context.Context) (string, bool) {
|
||||
token, ok := ctx.Value(tokenKey{}).(string)
|
||||
return token, ok
|
||||
}
|
||||
|
|
|
@ -10,12 +10,14 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
nebula "github.com/slackhq/nebula/cert"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x25519"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -61,7 +63,7 @@ func (p *Nebula) Init(config Config) (err error) {
|
|||
}
|
||||
|
||||
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
p.ctl, err = NewController(p, p.Claims, config, p.Options)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -144,6 +146,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
}
|
||||
|
||||
return []SignOption{
|
||||
p,
|
||||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeNebula, p.Name, ""),
|
||||
|
@ -160,6 +163,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
},
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -246,6 +250,7 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
|
|||
}
|
||||
|
||||
return append(signOptions,
|
||||
p,
|
||||
templateOptions,
|
||||
// Checks the validity bounds, and set the validity if has not been set.
|
||||
&sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter},
|
||||
|
@ -255,6 +260,8 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
|
|||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
|
||||
), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ func (p *noop) Init(config Config) error {
|
|||
}
|
||||
|
||||
func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
return []SignOption{}, nil
|
||||
return []SignOption{p}, nil
|
||||
}
|
||||
|
||||
func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
|
@ -50,7 +50,7 @@ func (p *noop) AuthorizeRevoke(ctx context.Context, token string) error {
|
|||
}
|
||||
|
||||
func (p *noop) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
return []SignOption{}, nil
|
||||
return []SignOption{p}, nil
|
||||
}
|
||||
|
||||
func (p *noop) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue