Merge branch 'master' into onboarding

This commit is contained in:
Mariano Cano 2019-09-26 15:36:19 -07:00 committed by GitHub
commit be07334164
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
75 changed files with 16168 additions and 332 deletions

View file

@ -49,6 +49,8 @@ linters:
- misspell - misspell
- ineffassign - ineffassign
- deadcode - deadcode
- staticcheck
- unused
run: run:
skip-dirs: skip-dirs:
@ -63,6 +65,6 @@ issues:
# golangci.com configuration # golangci.com configuration
# https://github.com/golangci/golangci/wiki/Configuration # https://github.com/golangci/golangci/wiki/Configuration
service: service:
golangci-lint-version: 1.17.x # use the fixed version to not introduce new linters unexpectedly golangci-lint-version: 1.18.x # use the fixed version to not introduce new linters unexpectedly
prepare: prepare:
- echo "here I can run custom commands, but no preparation needed for this repo" - echo "here I can run custom commands, but no preparation needed for this repo"

4
Gopkg.lock generated
View file

@ -233,7 +233,7 @@
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:9c1b7052fa8f2c918efd60ed5ae3c70ccbba08967c58ec71067535449a3ba220" digest = "1:7d03323edb817ca94efaee5489cde6acd06ceeaca9e6eee106d2d6a90deca997"
name = "github.com/smallstep/nosql" name = "github.com/smallstep/nosql"
packages = [ packages = [
".", ".",
@ -243,7 +243,7 @@
"mysql", "mysql",
] ]
pruneopts = "UT" pruneopts = "UT"
revision = "a0934e12468769d8cbede3ed316c47a4b88de4ca" revision = "f80b3f432de0662f07ebd58fe52b0a119fe5dcd9"
[[projects]] [[projects]]
branch = "master" branch = "master"

View file

@ -8,9 +8,6 @@ SRC=$(shell find . -type f -name '*.go' -not -path "./vendor/*")
GOOS_OVERRIDE ?= GOOS_OVERRIDE ?=
OUTPUT_ROOT=output/ OUTPUT_ROOT=output/
# Set shell to bash for `echo -e`
SHELL := /bin/bash
all: build test lint all: build test lint
.PHONY: all .PHONY: all
@ -97,16 +94,7 @@ generate:
test: test:
$Q $(GOFLAGS) go test -short -coverprofile=coverage.out ./... $Q $(GOFLAGS) go test -short -coverprofile=coverage.out ./...
vtest: .PHONY: test
$(Q)for d in $$(go list ./... | grep -v vendor); do \
echo -e "TESTS FOR: for \033[0;35m$$d\033[0m"; \
$(GOFLAGS) go test -v -bench=. -run=. -short -coverprofile=vcoverage.out $$d; \
out=$$?; \
if [[ $$out -ne 0 ]]; then ret=$$out; fi;\
rm -f profile.coverage.out; \
done; exit $$ret;
.PHONY: test vtest
integrate: integration integrate: integration
@ -125,7 +113,7 @@ fmt:
lint: lint:
$Q LOG_LEVEL=error golangci-lint run $Q LOG_LEVEL=error golangci-lint run
.PHONY: $(LINTERS) lint fmt .PHONY: lint fmt
######################################### #########################################
# Install # Install

View file

@ -32,7 +32,7 @@ It's super easy to get started and to operate `step-ca` thanks to [streamlined i
### A private certificate authority you run yourself ### A private certificate authority you run yourself
- Issue client and server certificates to VMs, containers, devices, and people using internal hostnames and emails - Issue client and server certificates to VMs, containers, devices, and people using internal hostnames and emails
- [RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliant certificates that work **for TLS and HTTPS** (SSH coming soon!) - [RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliant certificates that work **for TLS and HTTPS**
- Choose key types (RSA, ECDSA, EdDSA) & lifetimes to suit your needs - Choose key types (RSA, ECDSA, EdDSA) & lifetimes to suit your needs
- [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with **fully automated** enrollment, renewal, and revocation - [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with **fully automated** enrollment, renewal, and revocation
- Fast, stable, and capable of high availability deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries - Fast, stable, and capable of high availability deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries
@ -46,7 +46,19 @@ It's super easy to get started and to operate `step-ca` thanks to [streamlined i
- [Instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/) for VMs on AWS, GCP, and Azure - [Instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/) for VMs on AWS, GCP, and Azure
- [Single-use short-lived tokens](https://smallstep.com/docs/design-doc.html#jwk-provisioner) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc. - [Single-use short-lived tokens](https://smallstep.com/docs/design-doc.html#jwk-provisioner) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
- Use an existing certificate from another CA (e.g., using a device certificate like [Twilio's Trust OnBoard](https://www.twilio.com/wireless/trust-onboard)) *coming soon* - Use an existing certificate from another CA (e.g., using a device certificate like [Twilio's Trust OnBoard](https://www.twilio.com/wireless/trust-onboard)) *coming soon*
- ACMEv2 (RFC8555) support so you can **run your own private ACME server** *[coming soon](https://github.com/smallstep/certificates/tree/acme)*
### [Your own private ACME Server](https://smallstep.com/blog/private-acme-server/)
- Issue certificates using ACMEv2 ([RFC8555](https://tools.ietf.org/html/rfc8555)), **the protocol used by Let's Encrypt**
- Great for [using ACME in development & pre-production](https://smallstep.com/blog/private-acme-server/#local-development-pre-production)
- Supports the `http-01` and `dns-01` ACME challenge types
- Works with any compliant ACME client including [certbot](https://smallstep.com/blog/private-acme-server/#certbot-uploads-acme-certbot-png-certbot-example), [acme.sh](https://smallstep.com/blog/private-acme-server/#acme-sh-uploads-acme-acme-sh-png-acme-sh-example), [Caddy](https://smallstep.com/blog/private-acme-server/#caddy-uploads-acme-caddy-png-caddy-example), and [traefik](https://smallstep.com/blog/private-acme-server/#traefik-uploads-acme-traefik-png-traefik-example)
- Get certificates programmatically (e.g., in [Go](https://smallstep.com/blog/private-acme-server/#golang-uploads-acme-golang-png-go-example), [Python](https://smallstep.com/blog/private-acme-server/#python-uploads-acme-python-png-python-example), [Node.js](https://smallstep.com/blog/private-acme-server/#node-js-uploads-acme-node-js-png-node-js-example))
### [SSH Certificates](https://smallstep.com/blog/use-ssh-certificates/)
- Use [certificate authentication for SSH](https://smallstep.com/blog/use-ssh-certificates/): connect SSH to SSO, improve security, and eliminate warnings & errors
- Issue SSH user certificates using OAuth OIDC
- Issue SSH host certificates to cloud VMs using instance identity documents
### Easy certificate management and automation via [`step` CLI](https://github.com/smallstep/cli) [integration](https://smallstep.com/docs/cli/ca/) ### Easy certificate management and automation via [`step` CLI](https://github.com/smallstep/cli) [integration](https://smallstep.com/docs/cli/ca/)

214
acme/account.go Normal file
View file

@ -0,0 +1,214 @@
package acme
import (
"encoding/json"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/jose"
"github.com/smallstep/nosql"
)
// Account is a subset of the internal account type containing only those
// attributes required for responses in the ACME protocol.
type Account struct {
Contact []string `json:"contact,omitempty"`
Status string `json:"status"`
Orders string `json:"orders"`
ID string `json:"-"`
Key *jose.JSONWebKey `json:"-"`
}
// ToLog enables response logging.
func (a *Account) ToLog() (interface{}, error) {
b, err := json.Marshal(a)
if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging"))
}
return string(b), nil
}
// GetID returns the account ID.
func (a *Account) GetID() string {
return a.ID
}
// GetKey returns the JWK associated with the account.
func (a *Account) GetKey() *jose.JSONWebKey {
return a.Key
}
// IsValid returns true if the Account is valid.
func (a *Account) IsValid() bool {
return a.Status == StatusValid
}
// AccountOptions are the options needed to create a new ACME account.
type AccountOptions struct {
Key *jose.JSONWebKey
Contact []string
}
// account represents an ACME account.
type account struct {
ID string `json:"id"`
Created time.Time `json:"created"`
Deactivated time.Time `json:"deactivated"`
Key *jose.JSONWebKey `json:"key"`
Contact []string `json:"contact,omitempty"`
Status string `json:"status"`
}
// newAccount returns a new acme account type.
func newAccount(db nosql.DB, ops AccountOptions) (*account, error) {
id, err := randID()
if err != nil {
return nil, err
}
a := &account{
ID: id,
Key: ops.Key,
Contact: ops.Contact,
Status: "valid",
Created: clock.Now(),
}
return a, a.saveNew(db)
}
// toACME converts the internal Account type into the public acmeAccount
// type for presentation in the ACME protocol.
func (a *account) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Account, error) {
return &Account{
Status: a.Status,
Contact: a.Contact,
Orders: dir.getLink(OrdersByAccountLink, URLSafeProvisionerName(p), true, a.ID),
Key: a.Key,
ID: a.ID,
}, nil
}
// save writes the Account to the DB.
// If the account is new then the necessary indices will be created.
// Else, the account in the DB will be updated.
func (a *account) saveNew(db nosql.DB) error {
kid, err := keyToID(a.Key)
if err != nil {
return err
}
kidB := []byte(kid)
// Set the jwkID -> acme account ID index
_, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID))
switch {
case err != nil:
return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index"))
case !swapped:
return ServerInternalErr(errors.Errorf("key-id to account-id index already exists"))
default:
if err = a.save(db, nil); err != nil {
db.Del(accountByKeyIDTable, kidB)
return err
}
return nil
}
}
func (a *account) save(db nosql.DB, old *account) error {
var (
err error
oldB []byte
)
if old == nil {
oldB = nil
} else {
if oldB, err = json.Marshal(old); err != nil {
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
}
}
b, err := json.Marshal(*a)
if err != nil {
return errors.Wrap(err, "error marshaling new account object")
}
// Set the Account
_, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b)
switch {
case err != nil:
return ServerInternalErr(errors.Wrap(err, "error storing account"))
case !swapped:
return ServerInternalErr(errors.New("error storing account; " +
"value has changed since last read"))
default:
return nil
}
}
// update updates the acme account object stored in the database if,
// and only if, the account has not changed since the last read.
func (a *account) update(db nosql.DB, contact []string) (*account, error) {
b := *a
b.Contact = contact
if err := b.save(db, a); err != nil {
return nil, err
}
return &b, nil
}
// deactivate deactivates the acme account.
func (a *account) deactivate(db nosql.DB) (*account, error) {
b := *a
b.Status = StatusDeactivated
b.Deactivated = clock.Now()
if err := b.save(db, a); err != nil {
return nil, err
}
return &b, nil
}
// getAccountByID retrieves the account with the given ID.
func getAccountByID(db nosql.DB, id string) (*account, error) {
ab, err := db.Get(accountTable, []byte(id))
if err != nil {
if nosql.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id))
}
return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id))
}
a := new(account)
if err = json.Unmarshal(ab, a); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account"))
}
return a, nil
}
// getAccountByKeyID retrieves Id associated with the given Kid.
func getAccountByKeyID(db nosql.DB, kid string) (*account, error) {
id, err := db.Get(accountByKeyIDTable, []byte(kid))
if err != nil {
if nosql.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid))
}
return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index"))
}
return getAccountByID(db, string(id))
}
// getOrderIDsByAccount retrieves a list of Order IDs that were created by the
// account.
func getOrderIDsByAccount(db nosql.DB, id string) ([]string, error) {
b, err := db.Get(ordersByAccountIDTable, []byte(id))
if err != nil {
if nosql.IsErrNotFound(err) {
return []string{}, nil
}
return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", id))
}
var orderIDs []string
if err := json.Unmarshal(b, &orderIDs); err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", id))
}
return orderIDs, nil
}

844
acme/account_test.go Normal file
View file

@ -0,0 +1,844 @@
package acme
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/cli/jose"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
)
var (
defaultDisableRenewal = false
globalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
}
)
func newProv() provisioner.Interface {
// Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
Name: "test@acme-provisioner.com",
}
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
fmt.Printf("%v", err)
}
return p
}
func newAcc() (*account, error) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
if err != nil {
return nil, err
}
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, true, nil
},
}
return newAccount(mockdb, AccountOptions{
Key: jwk, Contact: []string{"foo", "bar"},
})
}
func TestGetAccountByID(t *testing.T) {
type test struct {
id string
db nosql.DB
acc *account
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{
acc: acc,
id: acc.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
},
err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)),
}
},
"fail/db-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{
acc: acc,
id: acc.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{
acc: acc,
id: acc.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
return nil, nil
},
},
err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")),
}
},
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
return test{
acc: acc,
id: acc.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
return b, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if acc, err := getAccountByID(tc.db, tc.id); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.acc.ID, acc.ID)
assert.Equals(t, tc.acc.Status, acc.Status)
assert.Equals(t, tc.acc.Created, acc.Created)
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
assert.Equals(t, tc.acc.Contact, acc.Contact)
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
}
}
})
}
}
func TestGetAccountByKeyID(t *testing.T) {
type test struct {
kid string
db nosql.DB
acc *account
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/kid-not-found": func(t *testing.T) test {
return test{
kid: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
},
err: MalformedErr(errors.Errorf("account with key id foo not found: not found")),
}
},
"fail/db-error": func(t *testing.T) test {
return test{
kid: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error loading key-account index: force")),
}
},
"fail/getAccount-error": func(t *testing.T) test {
count := 0
return test{
kid: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if count == 0 {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte("foo"))
count++
return []byte("bar"), nil
}
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error loading account bar: force")),
}
},
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
count := 0
return test{
kid: acc.Key.KeyID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
var ret []byte
switch count {
case 0:
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(acc.Key.KeyID))
ret = []byte(acc.ID)
case 1:
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
ret = b
}
count++
return ret, nil
},
},
acc: acc,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.acc.ID, acc.ID)
assert.Equals(t, tc.acc.Status, acc.Status)
assert.Equals(t, tc.acc.Created, acc.Created)
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
assert.Equals(t, tc.acc.Contact, acc.Contact)
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
}
}
})
}
}
func TestGetAccountIDsByAccount(t *testing.T) {
type test struct {
id string
db nosql.DB
res []string
err *Error
}
tests := map[string]func(t *testing.T) test{
"ok/not-found": func(t *testing.T) test {
return test{
id: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
},
res: []string{},
}
},
"fail/db-error": func(t *testing.T) test {
return test{
id: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
return test{
id: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte("foo"))
return nil, nil
},
},
err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")),
}
},
"ok": func(t *testing.T) test {
oids := []string{"foo", "bar", "baz"}
b, err := json.Marshal(oids)
assert.FatalError(t, err)
return test{
id: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte("foo"))
return b, nil
},
},
res: oids,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if oids, err := getOrderIDsByAccount(tc.db, tc.id); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.res, oids)
}
}
})
}
}
func TestAccountToACME(t *testing.T) {
dir := newDirectory("ca.smallstep.com", "acme")
prov := newProv()
type test struct {
acc *account
err *Error
}
tests := map[string]func(t *testing.T) test{
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{acc: acc}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
acmeAccount, err := tc.acc.toACME(nil, dir, prov)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, acmeAccount.ID, tc.acc.ID)
assert.Equals(t, acmeAccount.Status, tc.acc.Status)
assert.Equals(t, acmeAccount.Contact, tc.acc.Contact)
assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID)
assert.Equals(t, acmeAccount.Orders,
fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s/orders", URLSafeProvisionerName(prov), tc.acc.ID))
}
}
})
}
}
func TestAccountSave(t *testing.T) {
type test struct {
acc, old *account
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/old-nil/swap-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{
acc: acc,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing account: force")),
}
},
"fail/old-nil/swap-false": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{
acc: acc,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return []byte("foo"), false, nil
},
},
err: ServerInternalErr(errors.New("error storing account; value has changed since last read")),
}
},
"ok/old-nil": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
return test{
acc: acc,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, old, nil)
assert.Equals(t, b, newval)
assert.Equals(t, bucket, accountTable)
assert.Equals(t, []byte(acc.ID), key)
return nil, true, nil
},
},
}
},
"ok/old-not-nil": func(t *testing.T) test {
oldAcc, err := newAcc()
assert.FatalError(t, err)
acc, err := newAcc()
assert.FatalError(t, err)
oldb, err := json.Marshal(oldAcc)
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
return test{
acc: acc,
old: oldAcc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, old, oldb)
assert.Equals(t, newval, b)
assert.Equals(t, bucket, accountTable)
assert.Equals(t, []byte(acc.ID), key)
return []byte("foo"), true, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := tc.acc.save(tc.db, tc.old); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestAccountSaveNew(t *testing.T) {
type test struct {
acc *account
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/keyToID-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
acc.Key.Key = "foo"
return test{
acc: acc,
err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")),
}
},
"fail/swap-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
kid, err := keyToID(acc.Key)
assert.FatalError(t, err)
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
assert.Equals(t, old, nil)
assert.Equals(t, newval, []byte(acc.ID))
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
}
},
"fail/swap-false": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
kid, err := keyToID(acc.Key)
assert.FatalError(t, err)
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
assert.Equals(t, old, nil)
assert.Equals(t, newval, []byte(acc.ID))
return nil, false, nil
},
},
err: ServerInternalErr(errors.New("key-id to account-id index already exists")),
}
},
"fail/save-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
kid, err := keyToID(acc.Key)
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
count := 0
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 0 {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
assert.Equals(t, old, nil)
assert.Equals(t, newval, []byte(acc.ID))
count++
return nil, true, nil
}
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, nil)
assert.Equals(t, newval, b)
return nil, false, errors.New("force")
},
MDel: func(bucket, key []byte) error {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
return nil
},
},
err: ServerInternalErr(errors.New("error storing account: force")),
}
},
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
kid, err := keyToID(acc.Key)
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
count := 0
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 0 {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
assert.Equals(t, old, nil)
assert.Equals(t, newval, []byte(acc.ID))
count++
return nil, true, nil
}
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, nil)
assert.Equals(t, newval, b)
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := tc.acc.saveNew(tc.db); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestAccountUpdate(t *testing.T) {
type test struct {
acc *account
contact []string
db nosql.DB
res []byte
err *Error
}
contact := []string{"foo", "bar"}
tests := map[string]func(t *testing.T) test{
"fail/save-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
oldb, err := json.Marshal(acc)
assert.FatalError(t, err)
_acc := *acc
clone := &_acc
clone.Contact = contact
b, err := json.Marshal(clone)
assert.FatalError(t, err)
return test{
acc: acc,
contact: contact,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, oldb)
assert.Equals(t, newval, b)
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing account: force")),
}
},
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
oldb, err := json.Marshal(acc)
assert.FatalError(t, err)
_acc := *acc
clone := &_acc
clone.Contact = contact
b, err := json.Marshal(clone)
assert.FatalError(t, err)
return test{
acc: acc,
contact: contact,
res: b,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, oldb)
assert.Equals(t, newval, b)
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
acc, err := tc.acc.update(tc.db, tc.contact)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
b, err := json.Marshal(acc)
assert.FatalError(t, err)
assert.Equals(t, b, tc.res)
}
}
})
}
}
func TestAccountDeactivate(t *testing.T) {
type test struct {
acc *account
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/save-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
oldb, err := json.Marshal(acc)
assert.FatalError(t, err)
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, oldb)
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing account: force")),
}
},
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
oldb, err := json.Marshal(acc)
assert.FatalError(t, err)
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, oldb)
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
acc, err := tc.acc.deactivate(tc.db)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, acc.ID, tc.acc.ID)
assert.Equals(t, acc.Contact, tc.acc.Contact)
assert.Equals(t, acc.Status, StatusDeactivated)
assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID)
assert.Equals(t, acc.Created, tc.acc.Created)
assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute)))
assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute)))
}
}
})
}
}
func TestNewAccount(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
kid, err := keyToID(jwk)
assert.FatalError(t, err)
ops := AccountOptions{
Key: jwk,
Contact: []string{"foo", "bar"},
}
type test struct {
ops AccountOptions
db nosql.DB
err *Error
id *string
}
tests := map[string]func(t *testing.T) test{
"fail/store-error": func(t *testing.T) test {
return test{
ops: ops,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
}
},
"ok": func(t *testing.T) test {
var _id string
id := &_id
count := 0
return test{
ops: ops,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
switch count {
case 0:
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
case 1:
assert.Equals(t, bucket, accountTable)
*id = string(key)
}
count++
return nil, true, nil
},
},
id: id,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
acc, err := newAccount(tc.db, tc.ops)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, acc.ID, *tc.id)
assert.Equals(t, acc.Status, StatusValid)
assert.Equals(t, acc.Contact, ops.Contact)
assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID)
assert.True(t, acc.Deactivated.IsZero())
assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute)))
assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute)))
}
}
})
}
}

213
acme/api/account.go Normal file
View file

@ -0,0 +1,213 @@
package api
import (
"encoding/json"
"net/http"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/logging"
)
// NewAccountRequest represents the payload for a new account request.
type NewAccountRequest struct {
Contact []string `json:"contact"`
OnlyReturnExisting bool `json:"onlyReturnExisting"`
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"`
}
func validateContacts(cs []string) error {
for _, c := range cs {
if len(c) == 0 {
return acme.MalformedErr(errors.New("contact cannot be empty string"))
}
}
return nil
}
// Validate validates a new-account request body.
func (n *NewAccountRequest) Validate() error {
if n.OnlyReturnExisting && len(n.Contact) > 0 {
return acme.MalformedErr(errors.New("incompatible input; onlyReturnExisting must be alone"))
}
return validateContacts(n.Contact)
}
// UpdateAccountRequest represents an update-account request.
type UpdateAccountRequest struct {
Contact []string `json:"contact"`
Status string `json:"status"`
}
// IsDeactivateRequest returns true if the update request is a deactivation
// request, false otherwise.
func (u *UpdateAccountRequest) IsDeactivateRequest() bool {
return u.Status == acme.StatusDeactivated
}
// Validate validates a update-account request body.
func (u *UpdateAccountRequest) Validate() error {
switch {
case len(u.Status) > 0 && len(u.Contact) > 0:
return acme.MalformedErr(errors.New("incompatible input; contact and " +
"status updates are mutually exclusive"))
case len(u.Contact) > 0:
if err := validateContacts(u.Contact); err != nil {
return err
}
return nil
case len(u.Status) > 0:
if u.Status != acme.StatusDeactivated {
return acme.MalformedErr(errors.Errorf("cannot update account "+
"status to %s, only deactivated", u.Status))
}
return nil
default:
return acme.MalformedErr(errors.Errorf("empty update request"))
}
}
// NewAccount is the handler resource for creating new ACME accounts.
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
payload, err := payloadFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
var nar NewAccountRequest
if err := json.Unmarshal(payload.value, &nar); err != nil {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
"failed to unmarshal new-account request payload")))
return
}
if err := nar.Validate(); err != nil {
api.WriteError(w, err)
return
}
httpStatus := http.StatusCreated
acc, err := accountFromContext(r)
if err != nil {
acmeErr, ok := err.(*acme.Error)
if !ok || acmeErr.Status != http.StatusNotFound {
// Something went wrong ...
api.WriteError(w, err)
return
}
// Account does not exist //
if nar.OnlyReturnExisting {
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
return
}
jwk, err := jwkFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
if acc, err = h.Auth.NewAccount(prov, acme.AccountOptions{
Key: jwk,
Contact: nar.Contact,
}); err != nil {
api.WriteError(w, err)
return
}
} else {
// Account exists //
httpStatus = http.StatusOK
}
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink,
acme.URLSafeProvisionerName(prov), true, acc.GetID()))
api.JSONStatus(w, acc, httpStatus)
return
}
// GetUpdateAccount is the api for updating an ACME account.
func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
acc, err := accountFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
payload, err := payloadFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
if !payload.isPostAsGet {
var uar UpdateAccountRequest
if err := json.Unmarshal(payload.value, &uar); err != nil {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal new-account request payload")))
return
}
if err := uar.Validate(); err != nil {
api.WriteError(w, err)
return
}
var err error
if uar.IsDeactivateRequest() {
acc, err = h.Auth.DeactivateAccount(prov, acc.GetID())
} else {
acc, err = h.Auth.UpdateAccount(prov, acc.GetID(), uar.Contact)
}
if err != nil {
api.WriteError(w, err)
return
}
}
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, acc.GetID()))
api.JSON(w, acc)
return
}
func logOrdersByAccount(w http.ResponseWriter, oids []string) {
if rl, ok := w.(logging.ResponseLogger); ok {
m := map[string]interface{}{
"orders": oids,
}
rl.WithFields(m)
}
}
// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account.
func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
acc, err := accountFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
accID := chi.URLParam(r, "accID")
if acc.ID != accID {
api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param")))
return
}
orders, err := h.Auth.GetOrdersByAccount(prov, acc.GetID())
if err != nil {
api.WriteError(w, err)
return
}
api.JSON(w, orders)
logOrdersByAccount(w, orders)
return
}

790
acme/api/account_test.go Normal file
View file

@ -0,0 +1,790 @@
package api
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/jose"
)
var (
defaultDisableRenewal = false
globalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
}
)
func newProv() provisioner.Interface {
// Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
Name: "test@acme-provisioner.com",
}
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
fmt.Printf("%v", err)
}
return p
}
func TestNewAccountRequestValidate(t *testing.T) {
type test struct {
nar *NewAccountRequest
err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/incompatible-input": func(t *testing.T) test {
return test{
nar: &NewAccountRequest{
OnlyReturnExisting: true,
Contact: []string{"foo", "bar"},
},
err: acme.MalformedErr(errors.Errorf("incompatible input; onlyReturnExisting must be alone")),
}
},
"fail/bad-contact": func(t *testing.T) test {
return test{
nar: &NewAccountRequest{
Contact: []string{"foo", ""},
},
err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
}
},
"ok": func(t *testing.T) test {
return test{
nar: &NewAccountRequest{
Contact: []string{"foo", "bar"},
},
}
},
"ok/onlyReturnExisting": func(t *testing.T) test {
return test{
nar: &NewAccountRequest{
OnlyReturnExisting: true,
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
if err := tc.nar.Validate(); err != nil {
if assert.NotNil(t, err) {
ae, ok := err.(*acme.Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestUpdateAccountRequestValidate(t *testing.T) {
type test struct {
uar *UpdateAccountRequest
err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/incompatible-input": func(t *testing.T) test {
return test{
uar: &UpdateAccountRequest{
Contact: []string{"foo", "bar"},
Status: "foo",
},
err: acme.MalformedErr(errors.Errorf("incompatible input; " +
"contact and status updates are mutually exclusive")),
}
},
"fail/bad-contact": func(t *testing.T) test {
return test{
uar: &UpdateAccountRequest{
Contact: []string{"foo", ""},
},
err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
}
},
"fail/bad-status": func(t *testing.T) test {
return test{
uar: &UpdateAccountRequest{
Status: "foo",
},
err: acme.MalformedErr(errors.Errorf("cannot update account " +
"status to foo, only deactivated")),
}
},
"ok/contact": func(t *testing.T) test {
return test{
uar: &UpdateAccountRequest{
Contact: []string{"foo", "bar"},
},
}
},
"ok/status": func(t *testing.T) test {
return test{
uar: &UpdateAccountRequest{
Status: "deactivated",
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
if err := tc.uar.Validate(); err != nil {
if assert.NotNil(t, err) {
ae, ok := err.(*acme.Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestHandlerGetOrdersByAccount(t *testing.T) {
oids := []string{
"https://ca.smallstep.com/acme/order/foo",
"https://ca.smallstep.com/acme/order/bar",
}
accID := "account-id"
prov := newProv()
// Request with chi context
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("accID", accID)
url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID)
type test struct {
auth acme.Interface
ctx context.Context
statusCode int
problem *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
auth: &mockAcmeAuthority{},
ctx: ctx,
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "foo"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{},
ctx: ctx,
statusCode: 401,
problem: acme.UnauthorizedErr(errors.New("account ID does not match url param")),
}
},
"fail/getOrdersByAccount-error": func(t *testing.T) test {
acc := &acme.Account{ID: accID}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
err: acme.ServerInternalErr(errors.New("force")),
},
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("force")),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: accID}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
getOrdersByAccount: func(p provisioner.Interface, id string) ([]string, error) {
assert.Equals(t, p, prov)
assert.Equals(t, id, acc.ID)
return oids, nil
},
},
ctx: ctx,
statusCode: 200,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetOrdersByAccount(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
var ae acme.AError
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
prob := tc.problem.ToACME()
assert.Equals(t, ae.Type, prob.Type)
assert.Equals(t, ae.Detail, prob.Detail)
assert.Equals(t, ae.Identifier, prob.Identifier)
assert.Equals(t, ae.Subproblems, prob.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(oids)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}
func TestHandlerNewAccount(t *testing.T) {
accID := "accountID"
acc := acme.Account{
ID: accID,
Status: "valid",
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
}
prov := newProv()
url := "https://ca.smallstep.com/acme/new-account"
type test struct {
auth acme.Interface
ctx context.Context
statusCode int
problem *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-payload": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
}
},
"fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{
ctx: ctx,
statusCode: 400,
problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
}
},
"fail/malformed-payload-error": func(t *testing.T) test {
nar := &NewAccountRequest{
Contact: []string{"foo", ""},
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
}
},
"fail/no-existing-account": func(t *testing.T) test {
nar := &NewAccountRequest{
OnlyReturnExisting: true,
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/no-jwk": func(t *testing.T) test {
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
}
},
"fail/nil-jwk": func(t *testing.T) test {
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
}
},
"fail/NewAccount-error": func(t *testing.T) test {
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
return test{
auth: &mockAcmeAuthority{
newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
assert.Equals(t, p, prov)
assert.Equals(t, ops.Contact, nar.Contact)
assert.Equals(t, ops.Key, jwk)
return nil, acme.ServerInternalErr(errors.New("force"))
},
},
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("force")),
}
},
"ok/new-account": func(t *testing.T) test {
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
return test{
auth: &mockAcmeAuthority{
newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
assert.Equals(t, p, prov)
assert.Equals(t, ops.Contact, nar.Contact)
assert.Equals(t, ops.Key, jwk)
return &acc, nil
},
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, typ, acme.AccountLink)
assert.True(t, abs)
assert.Equals(t, in, []string{accID})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
acme.URLSafeProvisionerName(prov), accID)
},
},
ctx: ctx,
statusCode: 201,
}
},
"ok/return-existing": func(t *testing.T) test {
nar := &NewAccountRequest{
OnlyReturnExisting: true,
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, accContextKey, &acc)
return test{
auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, typ, acme.AccountLink)
assert.True(t, abs)
assert.Equals(t, in, []string{accID})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
acme.URLSafeProvisionerName(prov), accID)
},
},
ctx: ctx,
statusCode: 200,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.NewAccount(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
var ae acme.AError
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
prob := tc.problem.ToACME()
assert.Equals(t, ae.Type, prob.Type)
assert.Equals(t, ae.Detail, prob.Detail)
assert.Equals(t, ae.Identifier, prob.Identifier)
assert.Equals(t, ae.Subproblems, prob.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(acc)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
acme.URLSafeProvisionerName(prov), accID)})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}
func TestHandlerGetUpdateAccount(t *testing.T) {
accID := "accountID"
acc := acme.Account{
ID: accID,
Status: "valid",
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
}
prov := newProv()
// Request with chi context
url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s", accID)
type test struct {
auth acme.Interface
ctx context.Context
statusCode int
problem *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
ctx: ctx,
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/no-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
}
},
"fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{
ctx: ctx,
statusCode: 400,
problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
}
},
"fail/malformed-payload-error": func(t *testing.T) test {
uar := &UpdateAccountRequest{
Contact: []string{"foo", ""},
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
}
},
"fail/Deactivate-error": func(t *testing.T) test {
uar := &UpdateAccountRequest{
Status: "deactivated",
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
auth: &mockAcmeAuthority{
deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) {
assert.Equals(t, p, prov)
assert.Equals(t, id, accID)
return nil, acme.ServerInternalErr(errors.New("force"))
},
},
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("force")),
}
},
"fail/UpdateAccount-error": func(t *testing.T) test {
uar := &UpdateAccountRequest{
Contact: []string{"foo", "bar"},
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
auth: &mockAcmeAuthority{
updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) {
assert.Equals(t, p, prov)
assert.Equals(t, id, accID)
assert.Equals(t, contacts, uar.Contact)
return nil, acme.ServerInternalErr(errors.New("force"))
},
},
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("force")),
}
},
"ok/deactivate": func(t *testing.T) test {
uar := &UpdateAccountRequest{
Status: "deactivated",
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
auth: &mockAcmeAuthority{
deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) {
assert.Equals(t, p, prov)
assert.Equals(t, id, accID)
return &acc, nil
},
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs)
assert.Equals(t, in, []string{accID})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
acme.URLSafeProvisionerName(prov), accID)
},
},
ctx: ctx,
statusCode: 200,
}
},
"ok/new-account": func(t *testing.T) test {
uar := &UpdateAccountRequest{
Contact: []string{"foo", "bar"},
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
auth: &mockAcmeAuthority{
updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) {
assert.Equals(t, p, prov)
assert.Equals(t, id, accID)
assert.Equals(t, contacts, uar.Contact)
return &acc, nil
},
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs)
assert.Equals(t, in, []string{accID})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
acme.URLSafeProvisionerName(prov), accID)
},
},
ctx: ctx,
statusCode: 200,
}
},
"ok/post-as-get": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
return test{
auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs)
assert.Equals(t, in, []string{accID})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
acme.URLSafeProvisionerName(prov), accID)
},
},
ctx: ctx,
statusCode: 200,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetUpdateAccount(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
var ae acme.AError
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
prob := tc.problem.ToACME()
assert.Equals(t, ae.Type, prob.Type)
assert.Equals(t, ae.Detail, prob.Detail)
assert.Equals(t, ae.Identifier, prob.Identifier)
assert.Equals(t, ae.Subproblems, prob.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(acc)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
acme.URLSafeProvisionerName(prov), accID)})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}

214
acme/api/handler.go Normal file
View file

@ -0,0 +1,214 @@
package api
import (
"fmt"
"net/http"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/jose"
)
func link(url, typ string) string {
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ)
}
type contextKey string
const (
accContextKey = contextKey("acc")
jwsContextKey = contextKey("jws")
jwkContextKey = contextKey("jwk")
payloadContextKey = contextKey("payload")
provisionerContextKey = contextKey("provisioner")
)
type payloadInfo struct {
value []byte
isPostAsGet bool
isEmptyJSON bool
}
func accountFromContext(r *http.Request) (*acme.Account, error) {
val, ok := r.Context().Value(accContextKey).(*acme.Account)
if !ok || val == nil {
return nil, acme.AccountDoesNotExistErr(nil)
}
return val, nil
}
func jwkFromContext(r *http.Request) (*jose.JSONWebKey, error) {
val, ok := r.Context().Value(jwkContextKey).(*jose.JSONWebKey)
if !ok || val == nil {
return nil, acme.ServerInternalErr(errors.Errorf("jwk expected in request context"))
}
return val, nil
}
func jwsFromContext(r *http.Request) (*jose.JSONWebSignature, error) {
val, ok := r.Context().Value(jwsContextKey).(*jose.JSONWebSignature)
if !ok || val == nil {
return nil, acme.ServerInternalErr(errors.Errorf("jws expected in request context"))
}
return val, nil
}
func payloadFromContext(r *http.Request) (*payloadInfo, error) {
val, ok := r.Context().Value(payloadContextKey).(*payloadInfo)
if !ok || val == nil {
return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context"))
}
return val, nil
}
func provisionerFromContext(r *http.Request) (provisioner.Interface, error) {
val, ok := r.Context().Value(provisionerContextKey).(provisioner.Interface)
if !ok || val == nil {
return nil, acme.ServerInternalErr(errors.Errorf("provisioner expected in request context"))
}
return val, nil
}
// New returns a new ACME API router.
func New(acmeAuth acme.Interface) api.RouterHandler {
return &Handler{acmeAuth}
}
// Handler is the ACME request handler.
type Handler struct {
Auth acme.Interface
}
// Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) {
getLink := h.Auth.GetLink
// Standard ACME API
r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce)))
r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce)))
r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory)))
r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory)))
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))
}
extractPayloadByKid := func(next nextHTTP) nextHTTP {
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))
}
r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false), extractPayloadByJWK(h.NewAccount))
r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.GetUpdateAccount))
r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false), extractPayloadByKid(h.NewOrder))
r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount)))
r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz)))
r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, "{chID}"), extractPayloadByKid(h.GetChallenge))
r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
}
// GetNonce just sets the right header since a Nonce is added to each response
// by middleware by default.
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
if r.Method == "HEAD" {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusNoContent)
}
return
}
// GetDirectory is the ACME resource for returning a directory configuration
// for client configuration.
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
dir := h.Auth.GetDirectory(prov)
api.JSON(w, dir)
return
}
// GetAuthz ACME api for retrieving an Authz.
func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
acc, err := accountFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
authz, err := h.Auth.GetAuthz(prov, acc.GetID(), chi.URLParam(r, "authzID"))
if err != nil {
api.WriteError(w, err)
return
}
w.Header().Set("Location", h.Auth.GetLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, authz.GetID()))
api.JSON(w, authz)
return
}
// GetChallenge ACME api for retrieving a Challenge.
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
acc, err := accountFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
// Just verify that the payload was set, since we're not strictly adhering
// to ACME V2 spec for reasons specified below.
_, err = payloadFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
// NOTE: We should be checking that the request is either a POST-as-GET, or
// that the payload is an empty JSON block ({}). However, older ACME clients
// still send a vestigial body (rather than an empty JSON block) and
// strict enforcement would render these clients broken. For the time being
// we'll just ignore the body.
var (
ch *acme.Challenge
chID = chi.URLParam(r, "chID")
)
ch, err = h.Auth.ValidateChallenge(prov, acc.GetID(), chID, acc.GetKey())
if err != nil {
api.WriteError(w, err)
return
}
getLink := h.Auth.GetLink
w.Header().Add("Link", link(getLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, ch.GetAuthzID()), "up"))
w.Header().Set("Location", getLink(acme.ChallengeLink, acme.URLSafeProvisionerName(prov), true, ch.GetID()))
api.JSON(w, ch)
return
}
// GetCertificate ACME api for retrieving a Certificate.
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
acc, err := accountFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
certID := chi.URLParam(r, "certID")
certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID)
if err != nil {
api.WriteError(w, err)
return
}
w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
w.Write(certBytes)
return
}

771
acme/api/handler_test.go Normal file
View file

@ -0,0 +1,771 @@
package api
import (
"bytes"
"context"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/jose"
)
type mockAcmeAuthority struct {
deactivateAccount func(provisioner.Interface, string) (*acme.Account, error)
finalizeOrder func(p provisioner.Interface, accID string, id string, csr *x509.CertificateRequest) (*acme.Order, error)
getAccount func(p provisioner.Interface, id string) (*acme.Account, error)
getAccountByKey func(provisioner.Interface, *jose.JSONWebKey) (*acme.Account, error)
getAuthz func(p provisioner.Interface, accID string, id string) (*acme.Authz, error)
getCertificate func(accID string, id string) ([]byte, error)
getChallenge func(p provisioner.Interface, accID string, id string) (*acme.Challenge, error)
getDirectory func(provisioner.Interface) *acme.Directory
getLink func(acme.Link, string, bool, ...string) string
getOrder func(p provisioner.Interface, accID string, id string) (*acme.Order, error)
getOrdersByAccount func(p provisioner.Interface, id string) ([]string, error)
loadProvisionerByID func(string) (provisioner.Interface, error)
newAccount func(provisioner.Interface, acme.AccountOptions) (*acme.Account, error)
newNonce func() (string, error)
newOrder func(provisioner.Interface, acme.OrderOptions) (*acme.Order, error)
updateAccount func(provisioner.Interface, string, []string) (*acme.Account, error)
useNonce func(string) error
validateChallenge func(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error)
ret1 interface{}
err error
}
func (m *mockAcmeAuthority) DeactivateAccount(p provisioner.Interface, id string) (*acme.Account, error) {
if m.deactivateAccount != nil {
return m.deactivateAccount(p, id)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(*acme.Account), m.err
}
func (m *mockAcmeAuthority) FinalizeOrder(p provisioner.Interface, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) {
if m.finalizeOrder != nil {
return m.finalizeOrder(p, accID, id, csr)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(*acme.Order), m.err
}
func (m *mockAcmeAuthority) GetAccount(p provisioner.Interface, id string) (*acme.Account, error) {
if m.getAccount != nil {
return m.getAccount(p, id)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(*acme.Account), m.err
}
func (m *mockAcmeAuthority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) {
if m.getAccountByKey != nil {
return m.getAccountByKey(p, jwk)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(*acme.Account), m.err
}
func (m *mockAcmeAuthority) GetAuthz(p provisioner.Interface, accID, id string) (*acme.Authz, error) {
if m.getAuthz != nil {
return m.getAuthz(p, accID, id)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(*acme.Authz), m.err
}
func (m *mockAcmeAuthority) GetCertificate(accID, id string) ([]byte, error) {
if m.getCertificate != nil {
return m.getCertificate(accID, id)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.([]byte), m.err
}
func (m *mockAcmeAuthority) GetChallenge(p provisioner.Interface, accID, id string) (*acme.Challenge, error) {
if m.getChallenge != nil {
return m.getChallenge(p, accID, id)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(*acme.Challenge), m.err
}
func (m *mockAcmeAuthority) GetDirectory(p provisioner.Interface) *acme.Directory {
if m.getDirectory != nil {
return m.getDirectory(p)
}
return m.ret1.(*acme.Directory)
}
func (m *mockAcmeAuthority) GetLink(typ acme.Link, provID string, abs bool, in ...string) string {
if m.getLink != nil {
return m.getLink(typ, provID, abs, in...)
}
return m.ret1.(string)
}
func (m *mockAcmeAuthority) GetOrder(p provisioner.Interface, accID, id string) (*acme.Order, error) {
if m.getOrder != nil {
return m.getOrder(p, accID, id)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(*acme.Order), m.err
}
func (m *mockAcmeAuthority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) {
if m.getOrdersByAccount != nil {
return m.getOrdersByAccount(p, id)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.([]string), m.err
}
func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) {
if m.loadProvisionerByID != nil {
return m.loadProvisionerByID(provID)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(provisioner.Interface), m.err
}
func (m *mockAcmeAuthority) NewAccount(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
if m.newAccount != nil {
return m.newAccount(p, ops)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(*acme.Account), m.err
}
func (m *mockAcmeAuthority) NewNonce() (string, error) {
if m.newNonce != nil {
return m.newNonce()
} else if m.err != nil {
return "", m.err
}
return m.ret1.(string), m.err
}
func (m *mockAcmeAuthority) NewOrder(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
if m.newOrder != nil {
return m.newOrder(p, ops)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(*acme.Order), m.err
}
func (m *mockAcmeAuthority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*acme.Account, error) {
if m.updateAccount != nil {
return m.updateAccount(p, id, contact)
} else if m.err != nil {
return nil, m.err
}
return m.ret1.(*acme.Account), m.err
}
func (m *mockAcmeAuthority) UseNonce(nonce string) error {
if m.useNonce != nil {
return m.useNonce(nonce)
}
return m.err
}
func (m *mockAcmeAuthority) ValidateChallenge(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
switch {
case m.validateChallenge != nil:
return m.validateChallenge(p, accID, id, jwk)
case m.err != nil:
return nil, m.err
default:
return m.ret1.(*acme.Challenge), m.err
}
}
func TestHandlerGetNonce(t *testing.T) {
tests := []struct {
name string
statusCode int
}{
{"GET", 204},
{"HEAD", 200},
}
// Request with chi context
req := httptest.NewRequest("GET", "http://ca.smallstep.com/nonce", nil)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(nil).(*Handler)
w := httptest.NewRecorder()
req.Method = tt.name
h.GetNonce(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("Handler.GetNonce StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
})
}
}
func TestHandlerGetDirectory(t *testing.T) {
auth := acme.NewAuthority(nil, "ca.smallstep.com", "acme", nil)
prov := newProv()
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/directory", acme.URLSafeProvisionerName(prov))
expDir := acme.Directory{
NewNonce: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", acme.URLSafeProvisionerName(prov)),
NewAccount: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", acme.URLSafeProvisionerName(prov)),
NewOrder: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", acme.URLSafeProvisionerName(prov)),
RevokeCert: fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", acme.URLSafeProvisionerName(prov)),
KeyChange: fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", acme.URLSafeProvisionerName(prov)),
}
type test struct {
ctx context.Context
statusCode int
problem *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"ok": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 200,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := New(auth).(*Handler)
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetDirectory(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
var ae acme.AError
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
prob := tc.problem.ToACME()
assert.Equals(t, ae.Type, prob.Type)
assert.Equals(t, ae.Detail, prob.Detail)
assert.Equals(t, ae.Identifier, prob.Identifier)
assert.Equals(t, ae.Subproblems, prob.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
var dir acme.Directory
json.Unmarshal(bytes.TrimSpace(body), &dir)
assert.Equals(t, dir, expDir)
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}
func TestHandlerGetAuthz(t *testing.T) {
expiry := time.Now().UTC().Add(6 * time.Hour)
az := acme.Authz{
ID: "authzID",
Identifier: acme.Identifier{
Type: "dns",
Value: "example.com",
},
Status: "pending",
Expires: expiry.Format(time.RFC3339),
Wildcard: false,
Challenges: []*acme.Challenge{
{
Type: "http-01",
Status: "pending",
Token: "tok2",
URL: "https://ca.smallstep.com/acme/challenge/chHTTPID",
ID: "chHTTP01ID",
AuthzID: "authzID",
},
{
Type: "dns-01",
Status: "pending",
Token: "tok2",
URL: "https://ca.smallstep.com/acme/challenge/chDNSID",
ID: "chDNSID",
AuthzID: "authzID",
},
},
}
prov := newProv()
// Request with chi context
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("authzID", az.ID)
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/challenge/%s",
acme.URLSafeProvisionerName(prov), az.ID)
type test struct {
auth acme.Interface
ctx context.Context
statusCode int
problem *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
auth: &mockAcmeAuthority{},
ctx: ctx,
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/getAuthz-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
err: acme.ServerInternalErr(errors.New("force")),
},
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("force")),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
getAuthz: func(p provisioner.Interface, accID, id string) (*acme.Authz, error) {
assert.Equals(t, p, prov)
assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, az.ID)
return &az, nil
},
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, typ, acme.AuthzLink)
assert.True(t, abs)
assert.Equals(t, in, []string{az.ID})
return url
},
},
ctx: ctx,
statusCode: 200,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetAuthz(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
var ae acme.AError
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
prob := tc.problem.ToACME()
assert.Equals(t, ae.Type, prob.Type)
assert.Equals(t, ae.Detail, prob.Detail)
assert.Equals(t, ae.Identifier, prob.Identifier)
assert.Equals(t, ae.Subproblems, prob.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
//var gotAz acme.Authz
//assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &gotAz))
expB, err := json.Marshal(az)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], []string{url})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}
func TestHandlerGetCertificate(t *testing.T) {
leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt")
assert.FatalError(t, err)
inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt")
assert.FatalError(t, err)
root, err := pemutil.ReadCertificate("../../authority/testdata/certs/root_ca.crt")
assert.FatalError(t, err)
certBytes := append(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: leaf.Raw,
}), pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: inter.Raw,
})...)
certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: root.Raw,
})...)
certID := "certID"
prov := newProv()
// Request with chi context
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("certID", certID)
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/certificate/%s",
acme.URLSafeProvisionerName(prov), certID)
type test struct {
auth acme.Interface
ctx context.Context
statusCode int
problem *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), accContextKey, nil)
return test{
auth: &mockAcmeAuthority{},
ctx: ctx,
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/getCertificate-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
err: acme.ServerInternalErr(errors.New("force")),
},
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("force")),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
getCertificate: func(accID, id string) ([]byte, error) {
assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, certID)
return certBytes, nil
},
},
ctx: ctx,
statusCode: 200,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetCertificate(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
var ae acme.AError
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
prob := tc.problem.ToACME()
assert.Equals(t, ae.Type, prob.Type)
assert.Equals(t, ae.Detail, prob.Detail)
assert.Equals(t, ae.Identifier, prob.Identifier)
assert.Equals(t, ae.Subproblems, prob.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes))
assert.Equals(t, res.Header["Content-Type"], []string{"application/pem-certificate-chain; charset=utf-8"})
}
})
}
}
func ch() acme.Challenge {
return acme.Challenge{
Type: "http-01",
Status: "pending",
Token: "tok2",
URL: "https://ca.smallstep.com/acme/challenge/chID",
ID: "chID",
AuthzID: "authzID",
}
}
func TestHandlerGetChallenge(t *testing.T) {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("chID", "chID")
url := fmt.Sprintf("http://ca.smallstep.com/acme/challenge/%s", "chID")
prov := newProv()
type test struct {
auth acme.Interface
ctx context.Context
statusCode int
ch acme.Challenge
problem *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
ctx: ctx,
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
}
},
"fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
}
},
"fail/validate-challenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
err: acme.UnauthorizedErr(nil),
},
ctx: ctx,
statusCode: 401,
problem: acme.UnauthorizedErr(nil),
}
},
"fail/get-challenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
err: acme.UnauthorizedErr(nil),
},
ctx: ctx,
statusCode: 401,
problem: acme.UnauthorizedErr(nil),
}
},
"ok/validate-challenge": func(t *testing.T) test {
key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
acc := &acme.Account{ID: "accID", Key: key}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ch := ch()
ch.Status = "valid"
ch.Validated = time.Now().UTC().Format(time.RFC3339)
count := 0
return test{
auth: &mockAcmeAuthority{
validateChallenge: func(p provisioner.Interface, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
assert.Equals(t, p, prov)
assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, ch.ID)
assert.Equals(t, jwk.KeyID, key.KeyID)
return &ch, nil
},
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
var ret string
switch count {
case 0:
assert.Equals(t, typ, acme.AuthzLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs)
assert.Equals(t, in, []string{ch.AuthzID})
ret = fmt.Sprintf("https://ca.smallstep.com/acme/authz/%s", ch.AuthzID)
case 1:
assert.Equals(t, typ, acme.ChallengeLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs)
assert.Equals(t, in, []string{ch.ID})
ret = url
}
count++
return ret
},
},
ctx: ctx,
statusCode: 200,
ch: ch,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetChallenge(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
var ae acme.AError
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
prob := tc.problem.ToACME()
assert.Equals(t, ae.Type, prob.Type)
assert.Equals(t, ae.Detail, prob.Detail)
assert.Equals(t, ae.Identifier, prob.Identifier)
assert.Equals(t, ae.Subproblems, prob.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(tc.ch)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<https://ca.smallstep.com/acme/authz/%s>;rel=\"up\"", tc.ch.AuthzID)})
assert.Equals(t, res.Header["Location"], []string{url})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}

377
acme/api/middleware.go Normal file
View file

@ -0,0 +1,377 @@
package api
import (
"context"
"crypto/rsa"
"io/ioutil"
"net/http"
"net/url"
"strings"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/jose"
"github.com/smallstep/nosql"
)
type nextHTTP = func(http.ResponseWriter, *http.Request)
func logNonce(w http.ResponseWriter, nonce string) {
if rl, ok := w.(logging.ResponseLogger); ok {
m := map[string]interface{}{
"nonce": nonce,
}
rl.WithFields(m)
}
}
// addNonce is a middleware that adds a nonce to the response header.
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
nonce, err := h.Auth.NewNonce()
if err != nil {
api.WriteError(w, err)
return
}
w.Header().Set("Replay-Nonce", nonce)
w.Header().Set("Cache-Control", "no-store")
logNonce(w, nonce)
next(w, r)
return
}
}
// addDirLink is a middleware that adds a 'Link' response reader with the
// directory index url.
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
w.Header().Add("Link", link(h.Auth.GetLink(acme.DirectoryLink, acme.URLSafeProvisionerName(prov), true), "index"))
next(w, r)
return
}
}
// verifyContentType is a middleware that verifies that content type is
// application/jose+json.
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
ct := r.Header.Get("Content-Type")
var expected []string
if strings.Contains(r.URL.Path, h.Auth.GetLink(acme.CertificateLink, acme.URLSafeProvisionerName(prov), false, "")) {
// GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
} else {
// By default every request should have content-type applictaion/jose+json.
expected = []string{"application/jose+json"}
}
for _, e := range expected {
if ct == e {
next(w, r)
return
}
}
api.WriteError(w, acme.MalformedErr(errors.Errorf(
"expected content-type to be in %s, but got %s", expected, ct)))
return
}
}
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body")))
return
}
jws, err := jose.ParseJWS(string(body))
if err != nil {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body")))
return
}
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
next(w, r.WithContext(ctx))
return
}
}
// validateJWS checks the request body for to verify that it meets ACME
// requirements for a JWS.
//
// The JWS MUST NOT have multiple signatures
// The JWS Unencoded Payload Option [RFC7797] MUST NOT be used
// The JWS Unprotected Header [RFC7515] MUST NOT be used
// The JWS Payload MUST NOT be detached
// The JWS Protected Header MUST include the following fields:
// * “alg” (Algorithm)
// * This field MUST NOT contain “none” or a Message Authentication Code
// (MAC) algorithm (e.g. one in which the algorithm registry description
// mentions MAC/HMAC).
// * “nonce” (defined in Section 6.5)
// * “url” (defined in Section 6.4)
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
jws, err := jwsFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
if len(jws.Signatures) == 0 {
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature")))
return
}
if len(jws.Signatures) > 1 {
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature")))
return
}
sig := jws.Signatures[0]
uh := sig.Unprotected
if len(uh.KeyID) > 0 ||
uh.JSONWebKey != nil ||
len(uh.Algorithm) > 0 ||
len(uh.Nonce) > 0 ||
len(uh.ExtraHeaders) > 0 {
api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used")))
return
}
hdr := sig.Protected
switch hdr.Algorithm {
case jose.RS256, jose.RS384, jose.RS512:
if hdr.JSONWebKey != nil {
switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey:
if k.Size() < keys.MinRSAKeyBytes {
api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+
"keys must be at least %d bits (%d bytes) in size",
8*keys.MinRSAKeyBytes, keys.MinRSAKeyBytes)))
return
}
default:
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")))
return
}
}
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good
default:
api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm)))
return
}
// Check the validity/freshness of the Nonce.
if err := h.Auth.UseNonce(hdr.Nonce); err != nil {
api.WriteError(w, err)
return
}
// Check that the JWS url matches the requested url.
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
if !ok {
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header")))
return
}
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
if jwsURL != reqURL.String() {
api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)))
return
}
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")))
return
}
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")))
return
}
next(w, r)
return
}
}
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
// in the context. Make sure to parse and validate the JWS before running this
// middleware.
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
jws, err := jwsFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
jwk := jws.Signatures[0].Protected.JSONWebKey
if jwk == nil {
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header")))
return
}
if !jwk.Valid() {
api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header")))
return
}
ctx = context.WithValue(ctx, jwkContextKey, jwk)
acc, err := h.Auth.GetAccountByKey(prov, jwk)
switch {
case nosql.IsErrNotFound(err):
// For NewAccount requests ...
break
case err != nil:
api.WriteError(w, err)
return
default:
if !acc.IsValid() {
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
return
}
ctx = context.WithValue(ctx, accContextKey, acc)
}
next(w, r.WithContext(ctx))
return
}
}
// lookupProvisioner loads the provisioner associated with the request.
// Responsds 404 if the provisioner does not exist.
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
name := chi.URLParam(r, "provisionerID")
provID, err := url.PathUnescape(name)
if err != nil {
api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name)))
return
}
p, err := h.Auth.LoadProvisionerByID("acme/" + provID)
if err != nil {
api.WriteError(w, err)
return
}
if p.GetType() != provisioner.TypeACME {
api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME")))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, p)
next(w, r.WithContext(ctx))
return
}
}
// lookupJWK loads the JWK associated with the acme account referenced by the
// kid parameter of the signed payload.
// Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
jws, err := jwsFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
kidPrefix := h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, "")
kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) {
api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+
"required prefix; expected %s, but got %s", kidPrefix, kid)))
return
}
accID := strings.TrimPrefix(kid, kidPrefix)
acc, err := h.Auth.GetAccount(prov, accID)
switch {
case nosql.IsErrNotFound(err):
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
return
case err != nil:
api.WriteError(w, err)
return
default:
if !acc.IsValid() {
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
return
}
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
next(w, r.WithContext(ctx))
return
}
}
}
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
// Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
jws, err := jwsFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
jwk, err := jwkFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match")))
return
}
payload, err := jws.Verify(jwk)
if err != nil {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws")))
return
}
ctx := context.WithValue(r.Context(), payloadContextKey, &payloadInfo{
value: payload,
isPostAsGet: string(payload) == "",
isEmptyJSON: string(payload) == "{}",
})
next(w, r.WithContext(ctx))
return
}
}
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
if !payload.isPostAsGet {
api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET")))
return
}
next(w, r)
return
}
}

1550
acme/api/middleware_test.go Normal file

File diff suppressed because it is too large Load diff

164
acme/api/order.go Normal file
View file

@ -0,0 +1,164 @@
package api
import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"net/http"
"time"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
)
// NewOrderRequest represents the body for a NewOrder request.
type NewOrderRequest struct {
Identifiers []acme.Identifier `json:"identifiers"`
NotBefore time.Time `json:"notBefore,omitempty"`
NotAfter time.Time `json:"notAfter,omitempty"`
}
// Validate validates a new-order request body.
func (n *NewOrderRequest) Validate() error {
if len(n.Identifiers) == 0 {
return acme.MalformedErr(errors.Errorf("identifiers list cannot be empty"))
}
for _, id := range n.Identifiers {
if id.Type != "dns" {
return acme.MalformedErr(errors.Errorf("identifier type unsupported: %s", id.Type))
}
}
return nil
}
// FinalizeRequest captures the body for a Finalize order request.
type FinalizeRequest struct {
CSR string `json:"csr"`
csr *x509.CertificateRequest
}
// Validate validates a finalize request body.
func (f *FinalizeRequest) Validate() error {
var err error
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
if err != nil {
return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr"))
}
f.csr, err = x509.ParseCertificateRequest(csrBytes)
if err != nil {
return acme.MalformedErr(errors.Wrap(err, "unable to parse csr"))
}
if err = f.csr.CheckSignature(); err != nil {
return acme.MalformedErr(errors.Wrap(err, "csr failed signature check"))
}
return nil
}
// NewOrder ACME api for creating a new order.
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
acc, err := accountFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
payload, err := payloadFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
"failed to unmarshal new-order request payload")))
return
}
if err := nor.Validate(); err != nil {
api.WriteError(w, err)
return
}
o, err := h.Auth.NewOrder(prov, acme.OrderOptions{
AccountID: acc.GetID(),
Identifiers: nor.Identifiers,
NotBefore: nor.NotBefore,
NotAfter: nor.NotAfter,
})
if err != nil {
api.WriteError(w, err)
return
}
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID()))
api.JSONStatus(w, o, http.StatusCreated)
return
}
// GetOrder ACME api for retrieving an order.
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
acc, err := accountFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
oid := chi.URLParam(r, "ordID")
o, err := h.Auth.GetOrder(prov, acc.GetID(), oid)
if err != nil {
api.WriteError(w, err)
return
}
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID()))
api.JSON(w, o)
return
}
// FinalizeOrder attemptst to finalize an order and create a certificate.
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
acc, err := accountFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
payload, err := payloadFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
var fr FinalizeRequest
if err := json.Unmarshal(payload.value, &fr); err != nil {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload")))
return
}
if err := fr.Validate(); err != nil {
api.WriteError(w, err)
return
}
oid := chi.URLParam(r, "ordID")
o, err := h.Auth.FinalizeOrder(prov, acc.GetID(), oid, fr.csr)
if err != nil {
api.WriteError(w, err)
return
}
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.ID))
api.JSON(w, o)
return
}

757
acme/api/order_test.go Normal file
View file

@ -0,0 +1,757 @@
package api
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/pemutil"
)
func TestNewOrderRequestValidate(t *testing.T) {
type test struct {
nor *NewOrderRequest
nbf, naf time.Time
err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-identifiers": func(t *testing.T) test {
return test{
nor: &NewOrderRequest{},
err: acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")),
}
},
"fail/bad-identifier": func(t *testing.T) test {
return test{
nor: &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "example.com"},
{Type: "foo", Value: "bar.com"},
},
},
err: acme.MalformedErr(errors.Errorf("identifier type unsupported: foo")),
}
},
"ok": func(t *testing.T) test {
nbf := time.Now().UTC().Add(time.Minute)
naf := time.Now().UTC().Add(5 * time.Minute)
return test{
nor: &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "example.com"},
{Type: "dns", Value: "bar.com"},
},
NotAfter: naf,
NotBefore: nbf,
},
nbf: nbf,
naf: naf,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
if err := tc.nor.Validate(); err != nil {
if assert.NotNil(t, err) {
ae, ok := err.(*acme.Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
if tc.nbf.IsZero() {
assert.True(t, tc.nor.NotBefore.Before(time.Now().Add(time.Minute)))
assert.True(t, tc.nor.NotBefore.After(time.Now().Add(-time.Minute)))
} else {
assert.Equals(t, tc.nor.NotBefore, tc.nbf)
}
if tc.naf.IsZero() {
assert.True(t, tc.nor.NotAfter.Before(time.Now().Add(24*time.Hour)))
assert.True(t, tc.nor.NotAfter.After(time.Now().Add(24*time.Hour-time.Minute)))
} else {
assert.Equals(t, tc.nor.NotAfter, tc.naf)
}
}
}
})
}
}
func TestFinalizeRequestValidate(t *testing.T) {
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
assert.FatalError(t, err)
csr, ok := _csr.(*x509.CertificateRequest)
assert.Fatal(t, ok)
type test struct {
fr *FinalizeRequest
err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/parse-csr-error": func(t *testing.T) test {
return test{
fr: &FinalizeRequest{},
err: acme.MalformedErr(errors.Errorf("unable to parse csr: asn1: syntax error: sequence truncated")),
}
},
"fail/invalid-csr-signature": func(t *testing.T) test {
b, err := pemutil.Read("../../authority/testdata/certs/badsig.csr")
assert.FatalError(t, err)
c, ok := b.(*x509.CertificateRequest)
assert.Fatal(t, ok)
return test{
fr: &FinalizeRequest{
CSR: base64.RawURLEncoding.EncodeToString(c.Raw),
},
err: acme.MalformedErr(errors.Errorf("csr failed signature check: x509: ECDSA verification failure")),
}
},
"ok": func(t *testing.T) test {
return test{
fr: &FinalizeRequest{
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
if err := tc.fr.Validate(); err != nil {
if assert.NotNil(t, err) {
ae, ok := err.(*acme.Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.fr.csr.Raw, csr.Raw)
}
}
})
}
}
func TestHandlerGetOrder(t *testing.T) {
expiry := time.Now().UTC().Add(6 * time.Hour)
nbf := time.Now().UTC()
naf := time.Now().UTC().Add(24 * time.Hour)
o := acme.Order{
ID: "orderID",
Expires: expiry.Format(time.RFC3339),
NotBefore: nbf.Format(time.RFC3339),
NotAfter: naf.Format(time.RFC3339),
Identifiers: []acme.Identifier{
{
Type: "dns",
Value: "example.com",
},
{
Type: "dns",
Value: "*.smallstep.com",
},
},
Status: "pending",
Authorizations: []string{"foo", "bar"},
}
// Request with chi context
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID)
prov := newProv()
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s",
acme.URLSafeProvisionerName(prov), o.ID)
type test struct {
auth acme.Interface
ctx context.Context
statusCode int
problem *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
auth: &mockAcmeAuthority{},
ctx: ctx,
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/getOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
err: acme.ServerInternalErr(errors.New("force")),
},
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("force")),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
getOrder: func(p provisioner.Interface, accID, id string) (*acme.Order, error) {
assert.Equals(t, p, prov)
assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, o.ID)
return &o, nil
},
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
assert.Equals(t, typ, acme.OrderLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs)
assert.Equals(t, in, []string{o.ID})
return url
},
},
ctx: ctx,
statusCode: 200,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetOrder(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
var ae acme.AError
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
prob := tc.problem.ToACME()
assert.Equals(t, ae.Type, prob.Type)
assert.Equals(t, ae.Detail, prob.Detail)
assert.Equals(t, ae.Identifier, prob.Identifier)
assert.Equals(t, ae.Subproblems, prob.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(o)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], []string{url})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}
func TestHandlerNewOrder(t *testing.T) {
expiry := time.Now().UTC().Add(6 * time.Hour)
nbf := time.Now().UTC().Add(5 * time.Hour)
naf := nbf.Add(17 * time.Hour)
o := acme.Order{
ID: "orderID",
Expires: expiry.Format(time.RFC3339),
NotBefore: nbf.Format(time.RFC3339),
NotAfter: naf.Format(time.RFC3339),
Identifiers: []acme.Identifier{
{Type: "dns", Value: "example.com"},
{Type: "dns", Value: "bar.com"},
},
Status: "pending",
Authorizations: []string{"foo", "bar"},
}
prov := newProv()
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order",
acme.URLSafeProvisionerName(prov))
type test struct {
auth acme.Interface
ctx context.Context
statusCode int
problem *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
ctx: ctx,
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
}
},
"fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{
ctx: ctx,
statusCode: 400,
problem: acme.MalformedErr(errors.New("failed to unmarshal new-order request payload: unexpected end of JSON input")),
}
},
"fail/malformed-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
nor := &NewOrderRequest{}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
problem: acme.MalformedErr(errors.New("identifiers list cannot be empty")),
}
},
"fail/NewOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
nor := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "example.com"},
{Type: "dns", Value: "bar.com"},
},
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
auth: &mockAcmeAuthority{
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
assert.Equals(t, p, prov)
assert.Equals(t, ops.AccountID, acc.ID)
assert.Equals(t, ops.Identifiers, nor.Identifiers)
return nil, acme.MalformedErr(errors.New("force"))
},
},
ctx: ctx,
statusCode: 400,
problem: acme.MalformedErr(errors.New("force")),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
nor := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "example.com"},
{Type: "dns", Value: "bar.com"},
},
NotBefore: nbf,
NotAfter: naf,
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
auth: &mockAcmeAuthority{
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
assert.Equals(t, p, prov)
assert.Equals(t, ops.AccountID, acc.ID)
assert.Equals(t, ops.Identifiers, nor.Identifiers)
assert.Equals(t, ops.NotBefore, nbf)
assert.Equals(t, ops.NotAfter, naf)
return &o, nil
},
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
assert.Equals(t, typ, acme.OrderLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs)
assert.Equals(t, in, []string{o.ID})
return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)
},
},
ctx: ctx,
statusCode: 201,
}
},
"ok/default-naf-nbf": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
nor := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "example.com"},
{Type: "dns", Value: "bar.com"},
},
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
auth: &mockAcmeAuthority{
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
assert.Equals(t, p, prov)
assert.Equals(t, ops.AccountID, acc.ID)
assert.Equals(t, ops.Identifiers, nor.Identifiers)
assert.True(t, ops.NotBefore.IsZero())
assert.True(t, ops.NotAfter.IsZero())
return &o, nil
},
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
assert.Equals(t, typ, acme.OrderLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs)
assert.Equals(t, in, []string{o.ID})
return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)
},
},
ctx: ctx,
statusCode: 201,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.NewOrder(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
var ae acme.AError
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
prob := tc.problem.ToACME()
assert.Equals(t, ae.Type, prob.Type)
assert.Equals(t, ae.Detail, prob.Detail)
assert.Equals(t, ae.Identifier, prob.Identifier)
assert.Equals(t, ae.Subproblems, prob.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(o)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}
func TestHandlerFinalizeOrder(t *testing.T) {
expiry := time.Now().UTC().Add(6 * time.Hour)
nbf := time.Now().UTC().Add(5 * time.Hour)
naf := nbf.Add(17 * time.Hour)
o := acme.Order{
ID: "orderID",
Expires: expiry.Format(time.RFC3339),
NotBefore: nbf.Format(time.RFC3339),
NotAfter: naf.Format(time.RFC3339),
Identifiers: []acme.Identifier{
{Type: "dns", Value: "example.com"},
{Type: "dns", Value: "bar.com"},
},
Status: "valid",
Authorizations: []string{"foo", "bar"},
Certificate: "https://ca.smallstep.com/acme/certificate/certID",
}
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
assert.FatalError(t, err)
csr, ok := _csr.(*x509.CertificateRequest)
assert.Fatal(t, ok)
// Request with chi context
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID)
prov := newProv()
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s/finalize",
acme.URLSafeProvisionerName(prov), o.ID)
type test struct {
auth acme.Interface
ctx context.Context
statusCode int
problem *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
auth: &mockAcmeAuthority{},
ctx: ctx,
statusCode: 404,
problem: acme.AccountDoesNotExistErr(nil),
}
},
"fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
}
},
"fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{
ctx: ctx,
statusCode: 400,
problem: acme.MalformedErr(errors.New("failed to unmarshal finalize-order request payload: unexpected end of JSON input")),
}
},
"fail/malformed-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
fr := &FinalizeRequest{}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
problem: acme.MalformedErr(errors.New("unable to parse csr: asn1: syntax error: sequence truncated")),
}
},
"fail/FinalizeOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
nor := &FinalizeRequest{
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
assert.Equals(t, p, prov)
assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, o.ID)
assert.Equals(t, incsr.Raw, csr.Raw)
return nil, acme.MalformedErr(errors.New("force"))
},
},
ctx: ctx,
statusCode: 400,
problem: acme.MalformedErr(errors.New("force")),
}
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
nor := &FinalizeRequest{
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
auth: &mockAcmeAuthority{
finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
assert.Equals(t, p, prov)
assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, o.ID)
assert.Equals(t, incsr.Raw, csr.Raw)
return &o, nil
},
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
assert.Equals(t, typ, acme.OrderLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs)
assert.Equals(t, in, []string{o.ID})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s",
acme.URLSafeProvisionerName(prov), o.ID)
},
},
ctx: ctx,
statusCode: 200,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.FinalizeOrder(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
var ae acme.AError
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
prob := tc.problem.ToACME()
assert.Equals(t, ae.Type, prob.Type)
assert.Equals(t, ae.Detail, prob.Detail)
assert.Equals(t, ae.Identifier, prob.Identifier)
assert.Equals(t, ae.Subproblems, prob.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
expB, err := json.Marshal(o)
assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s",
acme.URLSafeProvisionerName(prov), o.ID)})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
}
})
}
}

263
acme/authority.go Normal file
View file

@ -0,0 +1,263 @@
package acme
import (
"crypto"
"crypto/x509"
"encoding/base64"
"net"
"net/http"
"net/url"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/jose"
"github.com/smallstep/nosql"
)
// Interface is the acme authority interface.
type Interface interface {
DeactivateAccount(provisioner.Interface, string) (*Account, error)
FinalizeOrder(provisioner.Interface, string, string, *x509.CertificateRequest) (*Order, error)
GetAccount(provisioner.Interface, string) (*Account, error)
GetAccountByKey(provisioner.Interface, *jose.JSONWebKey) (*Account, error)
GetAuthz(provisioner.Interface, string, string) (*Authz, error)
GetCertificate(string, string) ([]byte, error)
GetDirectory(provisioner.Interface) *Directory
GetLink(Link, string, bool, ...string) string
GetOrder(provisioner.Interface, string, string) (*Order, error)
GetOrdersByAccount(provisioner.Interface, string) ([]string, error)
LoadProvisionerByID(string) (provisioner.Interface, error)
NewAccount(provisioner.Interface, AccountOptions) (*Account, error)
NewNonce() (string, error)
NewOrder(provisioner.Interface, OrderOptions) (*Order, error)
UpdateAccount(provisioner.Interface, string, []string) (*Account, error)
UseNonce(string) error
ValidateChallenge(provisioner.Interface, string, string, *jose.JSONWebKey) (*Challenge, error)
}
// Authority is the layer that handles all ACME interactions.
type Authority struct {
db nosql.DB
dir *directory
signAuth SignAuthority
}
// NewAuthority returns a new Authority that implements the ACME interface.
func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) *Authority {
return &Authority{
db: db, dir: newDirectory(dns, prefix), signAuth: signAuth,
}
}
// GetLink returns the requested link from the directory.
func (a *Authority) GetLink(typ Link, provID string, abs bool, inputs ...string) string {
return a.dir.getLink(typ, provID, abs, inputs...)
}
// GetDirectory returns the ACME directory object.
func (a *Authority) GetDirectory(p provisioner.Interface) *Directory {
name := url.PathEscape(p.GetName())
return &Directory{
NewNonce: a.dir.getLink(NewNonceLink, name, true),
NewAccount: a.dir.getLink(NewAccountLink, name, true),
NewOrder: a.dir.getLink(NewOrderLink, name, true),
RevokeCert: a.dir.getLink(RevokeCertLink, name, true),
KeyChange: a.dir.getLink(KeyChangeLink, name, true),
}
}
// LoadProvisionerByID calls out to the SignAuthority interface to load a
// provisioner by ID.
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
return a.signAuth.LoadProvisionerByID(id)
}
// NewNonce generates, stores, and returns a new ACME nonce.
func (a *Authority) NewNonce() (string, error) {
n, err := newNonce(a.db)
if err != nil {
return "", err
}
return n.ID, nil
}
// UseNonce consumes the given nonce if it is valid, returns error otherwise.
func (a *Authority) UseNonce(nonce string) error {
return useNonce(a.db, nonce)
}
// NewAccount creates, stores, and returns a new ACME account.
func (a *Authority) NewAccount(p provisioner.Interface, ao AccountOptions) (*Account, error) {
acc, err := newAccount(a.db, ao)
if err != nil {
return nil, err
}
return acc.toACME(a.db, a.dir, p)
}
// UpdateAccount updates an ACME account.
func (a *Authority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*Account, error) {
acc, err := getAccountByID(a.db, id)
if err != nil {
return nil, ServerInternalErr(err)
}
if acc, err = acc.update(a.db, contact); err != nil {
return nil, err
}
return acc.toACME(a.db, a.dir, p)
}
// GetAccount returns an ACME account.
func (a *Authority) GetAccount(p provisioner.Interface, id string) (*Account, error) {
acc, err := getAccountByID(a.db, id)
if err != nil {
return nil, err
}
return acc.toACME(a.db, a.dir, p)
}
// DeactivateAccount deactivates an ACME account.
func (a *Authority) DeactivateAccount(p provisioner.Interface, id string) (*Account, error) {
acc, err := getAccountByID(a.db, id)
if err != nil {
return nil, err
}
if acc, err = acc.deactivate(a.db); err != nil {
return nil, err
}
return acc.toACME(a.db, a.dir, p)
}
func keyToID(jwk *jose.JSONWebKey) (string, error) {
kid, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint"))
}
return base64.RawURLEncoding.EncodeToString(kid), nil
}
// GetAccountByKey returns the ACME associated with the jwk id.
func (a *Authority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*Account, error) {
kid, err := keyToID(jwk)
if err != nil {
return nil, err
}
acc, err := getAccountByKeyID(a.db, kid)
if err != nil {
return nil, err
}
return acc.toACME(a.db, a.dir, p)
}
// GetOrder returns an ACME order.
func (a *Authority) GetOrder(p provisioner.Interface, accID, orderID string) (*Order, error) {
o, err := getOrder(a.db, orderID)
if err != nil {
return nil, err
}
if accID != o.AccountID {
return nil, UnauthorizedErr(errors.New("account does not own order"))
}
if o, err = o.updateStatus(a.db); err != nil {
return nil, err
}
return o.toACME(a.db, a.dir, p)
}
// GetOrdersByAccount returns the list of order urls owned by the account.
func (a *Authority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) {
oids, err := getOrderIDsByAccount(a.db, id)
if err != nil {
return nil, err
}
var ret = []string{}
for _, oid := range oids {
o, err := getOrder(a.db, oid)
if err != nil {
return nil, ServerInternalErr(err)
}
if o.Status == StatusInvalid {
continue
}
ret = append(ret, a.dir.getLink(OrderLink, URLSafeProvisionerName(p), true, o.ID))
}
return ret, nil
}
// NewOrder generates, stores, and returns a new ACME order.
func (a *Authority) NewOrder(p provisioner.Interface, ops OrderOptions) (*Order, error) {
order, err := newOrder(a.db, ops)
if err != nil {
return nil, Wrap(err, "error creating order")
}
return order.toACME(a.db, a.dir, p)
}
// FinalizeOrder attempts to finalize an order and generate a new certificate.
func (a *Authority) FinalizeOrder(p provisioner.Interface, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) {
o, err := getOrder(a.db, orderID)
if err != nil {
return nil, err
}
if accID != o.AccountID {
return nil, UnauthorizedErr(errors.New("account does not own order"))
}
o, err = o.finalize(a.db, csr, a.signAuth, p)
if err != nil {
return nil, Wrap(err, "error finalizing order")
}
return o.toACME(a.db, a.dir, p)
}
// GetAuthz retrieves and attempts to update the status on an ACME authz
// before returning.
func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*Authz, error) {
az, err := getAuthz(a.db, authzID)
if err != nil {
return nil, err
}
if accID != az.getAccountID() {
return nil, UnauthorizedErr(errors.New("account does not own authz"))
}
az, err = az.updateStatus(a.db)
if err != nil {
return nil, Wrap(err, "error updating authz status")
}
return az.toACME(a.db, a.dir, p)
}
// ValidateChallenge attempts to validate the challenge.
func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) {
ch, err := getChallenge(a.db, chID)
if err != nil {
return nil, err
}
if accID != ch.getAccountID() {
return nil, UnauthorizedErr(errors.New("account does not own challenge"))
}
client := http.Client{
Timeout: time.Duration(30 * time.Second),
}
ch, err = ch.validate(a.db, jwk, validateOptions{
httpGet: client.Get,
lookupTxt: net.LookupTXT,
})
if err != nil {
return nil, Wrap(err, "error attempting challenge validation")
}
return ch.toACME(a.db, a.dir, p)
}
// GetCertificate retrieves the Certificate by ID.
func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) {
cert, err := getCert(a.db, certID)
if err != nil {
return nil, err
}
if accID != cert.AccountID {
return nil, UnauthorizedErr(errors.New("account does not own certificate"))
}
return cert.toACME(a.db, a.dir)
}

1474
acme/authority_test.go Normal file

File diff suppressed because it is too large Load diff

337
acme/authz.go Normal file
View file

@ -0,0 +1,337 @@
package acme
import (
"encoding/json"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/nosql"
)
var defaultExpiryDuration = time.Hour * 24
// Authz is a subset of the Authz type containing only those attributes
// required for responses in the ACME protocol.
type Authz struct {
Identifier Identifier `json:"identifier"`
Status string `json:"status"`
Expires string `json:"expires"`
Challenges []*Challenge `json:"challenges"`
Wildcard bool `json:"wildcard"`
ID string `json:"-"`
}
// ToLog enables response logging.
func (a *Authz) ToLog() (interface{}, error) {
b, err := json.Marshal(a)
if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging"))
}
return string(b), nil
}
// GetID returns the Authz ID.
func (a *Authz) GetID() string {
return a.ID
}
// authz is the interface that the various authz types must implement.
type authz interface {
save(nosql.DB, authz) error
clone() *baseAuthz
getID() string
getAccountID() string
getType() string
getIdentifier() Identifier
getStatus() string
getExpiry() time.Time
getWildcard() bool
getChallenges() []string
getCreated() time.Time
updateStatus(db nosql.DB) (authz, error)
toACME(nosql.DB, *directory, provisioner.Interface) (*Authz, error)
}
// baseAuthz is the base authz type that others build from.
type baseAuthz struct {
ID string `json:"id"`
AccountID string `json:"accountID"`
Identifier Identifier `json:"identifier"`
Status string `json:"status"`
Expires time.Time `json:"expires"`
Challenges []string `json:"challenges"`
Wildcard bool `json:"wildcard"`
Created time.Time `json:"created"`
Error *Error `json:"error"`
}
func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) {
id, err := randID()
if err != nil {
return nil, err
}
now := clock.Now()
ba := &baseAuthz{
ID: id,
AccountID: accID,
Status: StatusPending,
Created: now,
Expires: now.Add(defaultExpiryDuration),
Identifier: identifier,
}
if strings.HasPrefix(identifier.Value, "*.") {
ba.Wildcard = true
ba.Identifier = Identifier{
Value: strings.TrimPrefix(identifier.Value, "*."),
Type: identifier.Type,
}
}
return ba, nil
}
// getID returns the ID of the authz.
func (ba *baseAuthz) getID() string {
return ba.ID
}
// getAccountID returns the Account ID that created the authz.
func (ba *baseAuthz) getAccountID() string {
return ba.AccountID
}
// getType returns the type of the authz.
func (ba *baseAuthz) getType() string {
return ba.Identifier.Type
}
// getIdentifier returns the identifier for the authz.
func (ba *baseAuthz) getIdentifier() Identifier {
return ba.Identifier
}
// getStatus returns the status of the authz.
func (ba *baseAuthz) getStatus() string {
return ba.Status
}
// getWildcard returns true if the authz identifier has a '*', false otherwise.
func (ba *baseAuthz) getWildcard() bool {
return ba.Wildcard
}
// getChallenges returns the authz challenge IDs.
func (ba *baseAuthz) getChallenges() []string {
return ba.Challenges
}
// getExpiry returns the expiration time of the authz.
func (ba *baseAuthz) getExpiry() time.Time {
return ba.Expires
}
// getCreated returns the created time of the authz.
func (ba *baseAuthz) getCreated() time.Time {
return ba.Created
}
// toACME converts the internal Authz type into the public acmeAuthz type for
// presentation in the ACME protocol.
func (ba *baseAuthz) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Authz, error) {
var chs = make([]*Challenge, len(ba.Challenges))
for i, chID := range ba.Challenges {
ch, err := getChallenge(db, chID)
if err != nil {
return nil, err
}
chs[i], err = ch.toACME(db, dir, p)
if err != nil {
return nil, err
}
}
return &Authz{
Identifier: ba.Identifier,
Status: ba.getStatus(),
Challenges: chs,
Wildcard: ba.getWildcard(),
Expires: ba.Expires.Format(time.RFC3339),
ID: ba.ID,
}, nil
}
func (ba *baseAuthz) save(db nosql.DB, old authz) error {
var (
err error
oldB, newB []byte
)
if old == nil {
oldB = nil
} else {
if oldB, err = json.Marshal(old); err != nil {
return ServerInternalErr(errors.Wrap(err, "error marshaling old authz"))
}
}
if newB, err = json.Marshal(ba); err != nil {
return ServerInternalErr(errors.Wrap(err, "error marshaling new authz"))
}
_, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB)
switch {
case err != nil:
return ServerInternalErr(errors.Wrapf(err, "error storing authz"))
case !swapped:
return ServerInternalErr(errors.Errorf("error storing authz; " +
"value has changed since last read"))
default:
return nil
}
}
func (ba *baseAuthz) clone() *baseAuthz {
u := *ba
return &u
}
func (ba *baseAuthz) parent() authz {
return &dnsAuthz{ba}
}
// updateStatus attempts to update the status on a baseAuthz and stores the
// updating object if necessary.
func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) {
newAuthz := ba.clone()
now := time.Now().UTC()
switch ba.Status {
case StatusInvalid:
return ba.parent(), nil
case StatusValid:
return ba.parent(), nil
case StatusPending:
// check expiry
if now.After(ba.Expires) {
newAuthz.Status = StatusInvalid
newAuthz.Error = MalformedErr(errors.New("authz has expired"))
break
}
var isValid = false
for _, chID := range ba.Challenges {
ch, err := getChallenge(db, chID)
if err != nil {
return ba, err
}
if ch.getStatus() == StatusValid {
isValid = true
break
}
}
if !isValid {
return ba.parent(), nil
}
newAuthz.Status = StatusValid
newAuthz.Error = nil
default:
return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status))
}
if err := newAuthz.save(db, ba); err != nil {
return ba, err
}
return newAuthz.parent(), nil
}
// unmarshalAuthz unmarshals an authz type into the correct sub-type.
func unmarshalAuthz(data []byte) (authz, error) {
var getType struct {
Identifier Identifier `json:"identifier"`
}
if err := json.Unmarshal(data, &getType); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type"))
}
switch getType.Identifier.Type {
case "dns":
var ba baseAuthz
if err := json.Unmarshal(data, &ba); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz"))
}
return &dnsAuthz{&ba}, nil
default:
return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s",
getType.Identifier.Type))
}
}
// dnsAuthz represents a dns acme authorization.
type dnsAuthz struct {
*baseAuthz
}
// newAuthz returns a new acme authorization object based on the identifier
// type.
func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) {
switch identifier.Type {
case "dns":
a, err = newDNSAuthz(db, accID, identifier)
default:
err = MalformedErr(errors.Errorf("unexpected authz type %s",
identifier.Type))
}
return
}
// newDNSAuthz returns a new dns acme authorization object.
func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) {
ba, err := newBaseAuthz(accID, identifier)
if err != nil {
return nil, err
}
ba.Challenges = []string{}
if !ba.Wildcard {
// http challenges are only permitted if the DNS is not a wildcard dns.
ch1, err := newHTTP01Challenge(db, ChallengeOptions{
AccountID: accID,
AuthzID: ba.ID,
Identifier: ba.Identifier})
if err != nil {
return nil, Wrap(err, "error creating http challenge")
}
ba.Challenges = append(ba.Challenges, ch1.getID())
}
ch2, err := newDNS01Challenge(db, ChallengeOptions{
AccountID: accID,
AuthzID: ba.ID,
Identifier: identifier})
if err != nil {
return nil, Wrap(err, "error creating dns challenge")
}
ba.Challenges = append(ba.Challenges, ch2.getID())
da := &dnsAuthz{ba}
if err := da.save(db, nil); err != nil {
return nil, err
}
return da, nil
}
// getAuthz retrieves and unmarshals an ACME authz type from the database.
func getAuthz(db nosql.DB, id string) (authz, error) {
b, err := db.Get(authzTable, []byte(id))
if nosql.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id))
} else if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id))
}
az, err := unmarshalAuthz(b)
if err != nil {
return nil, err
}
return az, nil
}

809
acme/authz_test.go Normal file
View file

@ -0,0 +1,809 @@
package acme
import (
"encoding/json"
"strings"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
)
func newAz() (authz, error) {
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return []byte("foo"), true, nil
},
}
return newAuthz(mockdb, "1234", Identifier{
Type: "dns", Value: "acme.example.com",
})
}
func TestGetAuthz(t *testing.T) {
type test struct {
id string
db nosql.DB
az authz
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
return test{
az: az,
id: az.getID(),
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
},
err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())),
}
},
"fail/db-error": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
return test{
az: az,
id: az.getID(),
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Identifier.Type = "foo"
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
az: az,
id: az.getID(),
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authzTable)
assert.Equals(t, key, []byte(az.getID()))
return b, nil
},
},
err: ServerInternalErr(errors.New("unexpected authz type foo")),
}
},
"ok": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
az: az,
id: az.getID(),
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authzTable)
assert.Equals(t, key, []byte(az.getID()))
return b, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if az, err := getAuthz(tc.db, tc.id); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.az.getID(), az.getID())
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
assert.Equals(t, tc.az.getStatus(), az.getStatus())
assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier())
assert.Equals(t, tc.az.getCreated(), az.getCreated())
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
}
}
})
}
}
func TestAuthzClone(t *testing.T) {
az, err := newAz()
assert.FatalError(t, err)
clone := az.clone()
assert.Equals(t, clone.getID(), az.getID())
assert.Equals(t, clone.getAccountID(), az.getAccountID())
assert.Equals(t, clone.getStatus(), az.getStatus())
assert.Equals(t, clone.getIdentifier(), az.getIdentifier())
assert.Equals(t, clone.getExpiry(), az.getExpiry())
assert.Equals(t, clone.getCreated(), az.getCreated())
assert.Equals(t, clone.getChallenges(), az.getChallenges())
clone.Status = StatusValid
assert.NotEquals(t, clone.getStatus(), az.getStatus())
}
func TestNewAuthz(t *testing.T) {
iden := Identifier{
Type: "dns", Value: "acme.example.com",
}
accID := "1234"
type test struct {
iden Identifier
db nosql.DB
err *Error
resChs *([]string)
}
tests := map[string]func(t *testing.T) test{
"fail/unexpected-type": func(t *testing.T) test {
return test{
iden: Identifier{Type: "foo", Value: "acme.example.com"},
err: MalformedErr(errors.New("unexpected authz type foo")),
}
},
"fail/new-http-chall-error": func(t *testing.T) test {
return test{
iden: iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")),
}
},
"fail/new-dns-chall-error": func(t *testing.T) test {
count := 0
return test{
iden: iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 1 {
return nil, false, errors.New("force")
}
count++
return nil, true, nil
},
},
err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")),
}
},
"fail/save-authz-error": func(t *testing.T) test {
count := 0
return test{
iden: iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 2 {
return nil, false, errors.New("force")
}
count++
return nil, true, nil
},
},
err: ServerInternalErr(errors.New("error storing authz: force")),
}
},
"ok": func(t *testing.T) test {
chs := &([]string{})
count := 0
return test{
iden: iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 2 {
assert.Equals(t, bucket, authzTable)
assert.Equals(t, old, nil)
az, err := unmarshalAuthz(newval)
assert.FatalError(t, err)
assert.Equals(t, az.getID(), string(key))
assert.Equals(t, az.getAccountID(), accID)
assert.Equals(t, az.getStatus(), StatusPending)
assert.Equals(t, az.getIdentifier(), iden)
assert.Equals(t, az.getWildcard(), false)
*chs = az.getChallenges()
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
expiry := az.getCreated().Add(defaultExpiryDuration)
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
}
count++
return nil, true, nil
},
},
resChs: chs,
}
},
"ok/wildcard": func(t *testing.T) test {
chs := &([]string{})
count := 0
_iden := Identifier{Type: "dns", Value: "*.acme.example.com"}
return test{
iden: _iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 1 {
assert.Equals(t, bucket, authzTable)
assert.Equals(t, old, nil)
az, err := unmarshalAuthz(newval)
assert.FatalError(t, err)
assert.Equals(t, az.getID(), string(key))
assert.Equals(t, az.getAccountID(), accID)
assert.Equals(t, az.getStatus(), StatusPending)
assert.Equals(t, az.getIdentifier(), iden)
assert.Equals(t, az.getWildcard(), true)
*chs = az.getChallenges()
// Verify that we only have 1 challenge instead of 2.
assert.True(t, len(*chs) == 1)
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
expiry := az.getCreated().Add(defaultExpiryDuration)
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
}
count++
return nil, true, nil
},
},
resChs: chs,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
az, err := newAuthz(tc.db, accID, tc.iden)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, az.getAccountID(), accID)
assert.Equals(t, az.getType(), "dns")
assert.Equals(t, az.getStatus(), StatusPending)
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
expiry := az.getCreated().Add(defaultExpiryDuration)
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
assert.Equals(t, az.getChallenges(), *(tc.resChs))
if strings.HasPrefix(tc.iden.Value, "*.") {
assert.True(t, az.getWildcard())
assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*."))
} else {
assert.False(t, az.getWildcard())
assert.Equals(t, az.getIdentifier().Value, tc.iden.Value)
}
assert.True(t, az.getID() != "")
}
}
})
}
}
func TestAuthzToACME(t *testing.T) {
dir := newDirectory("ca.smallstep.com", "acme")
var (
ch1, ch2 challenge
ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
err error
)
count := 0
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 0 {
*ch1Bytes = newval
ch1, err = unmarshalChallenge(newval)
assert.FatalError(t, err)
} else if count == 1 {
*ch2Bytes = newval
ch2, err = unmarshalChallenge(newval)
assert.FatalError(t, err)
}
count++
return []byte("foo"), true, nil
},
}
iden := Identifier{
Type: "dns", Value: "acme.example.com",
}
az, err := newAuthz(mockdb, "1234", iden)
assert.FatalError(t, err)
prov := newProv()
type test struct {
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/getChallenge1-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error loading challenge")),
}
},
"fail/getChallenge2-error": func(t *testing.T) test {
count := 0
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if count == 1 {
return nil, errors.New("force")
}
count++
return *ch1Bytes, nil
},
},
err: ServerInternalErr(errors.New("error loading challenge")),
}
},
"ok": func(t *testing.T) test {
count := 0
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if count == 0 {
count++
return *ch1Bytes, nil
}
return *ch2Bytes, nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
acmeAz, err := az.toACME(tc.db, dir, prov)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, acmeAz.ID, az.getID())
assert.Equals(t, acmeAz.Identifier, iden)
assert.Equals(t, acmeAz.Status, StatusPending)
acmeCh1, err := ch1.toACME(nil, dir, prov)
assert.FatalError(t, err)
acmeCh2, err := ch2.toACME(nil, dir, prov)
assert.FatalError(t, err)
assert.Equals(t, acmeAz.Challenges[0], acmeCh1)
assert.Equals(t, acmeAz.Challenges[1], acmeCh2)
expiry, err := time.Parse(time.RFC3339, acmeAz.Expires)
assert.FatalError(t, err)
assert.Equals(t, expiry.String(), az.getExpiry().String())
}
}
})
}
}
func TestAuthzSave(t *testing.T) {
type test struct {
az, old authz
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/old-nil/swap-error": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
return test{
az: az,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing authz: force")),
}
},
"fail/old-nil/swap-false": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
return test{
az: az,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return []byte("foo"), false, nil
},
},
err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")),
}
},
"ok/old-nil": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
az: az,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, old, nil)
assert.Equals(t, b, newval)
assert.Equals(t, bucket, authzTable)
assert.Equals(t, []byte(az.getID()), key)
return nil, true, nil
},
},
}
},
"ok/old-not-nil": func(t *testing.T) test {
oldAz, err := newAz()
assert.FatalError(t, err)
az, err := newAz()
assert.FatalError(t, err)
oldb, err := json.Marshal(oldAz)
assert.FatalError(t, err)
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
az: az,
old: oldAz,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, old, oldb)
assert.Equals(t, b, newval)
assert.Equals(t, bucket, authzTable)
assert.Equals(t, []byte(az.getID()), key)
return []byte("foo"), true, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := tc.az.save(tc.db, tc.old); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestAuthzUnmarshal(t *testing.T) {
type test struct {
az authz
azb []byte
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/nil": func(t *testing.T) test {
return test{
azb: nil,
err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")),
}
},
"fail/unexpected-type": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Identifier.Type = "foo"
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
azb: b,
err: ServerInternalErr(errors.New("unexpected authz type foo")),
}
},
"ok/dns": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
az: az,
azb: b,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if az, err := unmarshalAuthz(tc.azb); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.az.getID(), az.getID())
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
assert.Equals(t, tc.az.getStatus(), az.getStatus())
assert.Equals(t, tc.az.getCreated(), az.getCreated())
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
assert.Equals(t, tc.az.getWildcard(), az.getWildcard())
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
}
}
})
}
}
func TestAuthzUpdateStatus(t *testing.T) {
type test struct {
az, res authz
err *Error
db nosql.DB
}
tests := map[string]func(t *testing.T) test{
"fail/already-invalid": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Status = StatusInvalid
return test{
az: az,
res: az,
}
},
"fail/already-valid": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Status = StatusValid
return test{
az: az,
res: az,
}
},
"fail/unexpected-status": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Status = StatusReady
return test{
az: az,
res: az,
err: ServerInternalErr(errors.New("unrecognized authz status: ready")),
}
},
"fail/save-error": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
return test{
az: az,
res: az,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing authz: force")),
}
},
"ok/expired": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
clone := az.clone()
clone.Error = MalformedErr(errors.New("authz has expired"))
clone.Status = StatusInvalid
return test{
az: az,
res: clone.parent(),
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, true, nil
},
},
}
},
"fail/get-challenge-error": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
return test{
az: az,
res: az,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error loading challenge")),
}
},
"ok/valid": func(t *testing.T) test {
var (
ch2 challenge
ch1Bytes = &([]byte{})
err error
)
count := 0
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 0 {
*ch1Bytes = newval
} else if count == 1 {
ch2, err = unmarshalChallenge(newval)
assert.FatalError(t, err)
}
count++
return nil, true, nil
},
}
iden := Identifier{
Type: "dns", Value: "acme.example.com",
}
az, err := newAuthz(mockdb, "1234", iden)
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Error = MalformedErr(nil)
_ch, ok := ch2.(*dns01Challenge)
assert.Fatal(t, ok)
_ch.baseChallenge.Status = StatusValid
chb, err := json.Marshal(ch2)
clone := az.clone()
clone.Status = StatusValid
clone.Error = nil
count = 0
return test{
az: az,
res: clone.parent(),
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if count == 0 {
count++
return *ch1Bytes, nil
}
count++
return chb, nil
},
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, true, nil
},
},
}
},
"ok/still-pending": func(t *testing.T) test {
var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
count := 0
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 0 {
*ch1Bytes = newval
} else if count == 1 {
*ch2Bytes = newval
}
count++
return nil, true, nil
},
}
iden := Identifier{
Type: "dns", Value: "acme.example.com",
}
az, err := newAuthz(mockdb, "1234", iden)
assert.FatalError(t, err)
count = 0
return test{
az: az,
res: az,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if count == 0 {
count++
return *ch1Bytes, nil
}
count++
return *ch2Bytes, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
az, err := tc.az.updateStatus(tc.db)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
expB, err := json.Marshal(tc.res)
assert.FatalError(t, err)
b, err := json.Marshal(az)
assert.FatalError(t, err)
assert.Equals(t, expB, b)
}
}
})
}
}

89
acme/certificate.go Normal file
View file

@ -0,0 +1,89 @@
package acme
import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"time"
"github.com/pkg/errors"
"github.com/smallstep/nosql"
)
type certificate struct {
ID string `json:"id"`
Created time.Time `json:"created"`
AccountID string `json:"accountID"`
OrderID string `json:"orderID"`
Leaf []byte `json:"leaf"`
Intermediates []byte `json:"intermediates"`
}
// CertOptions options with which to create and store a cert object.
type CertOptions struct {
AccountID string
OrderID string
Leaf *x509.Certificate
Intermediates []*x509.Certificate
}
func newCert(db nosql.DB, ops CertOptions) (*certificate, error) {
id, err := randID()
if err != nil {
return nil, err
}
leaf := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: ops.Leaf.Raw,
})
var intermediates []byte
for _, cert := range ops.Intermediates {
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
})...)
}
cert := &certificate{
ID: id,
AccountID: ops.AccountID,
OrderID: ops.OrderID,
Leaf: leaf,
Intermediates: intermediates,
Created: time.Now().UTC(),
}
certB, err := json.Marshal(cert)
if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling certificate"))
}
_, swapped, err := db.CmpAndSwap(certTable, []byte(id), nil, certB)
switch {
case err != nil:
return nil, ServerInternalErr(errors.Wrap(err, "error storing certificate"))
case !swapped:
return nil, ServerInternalErr(errors.New("error storing certificate; " +
"value has changed since last read"))
default:
return cert, nil
}
}
func (c *certificate) toACME(db nosql.DB, dir *directory) ([]byte, error) {
return append(c.Leaf, c.Intermediates...), nil
}
func getCert(db nosql.DB, id string) (*certificate, error) {
b, err := db.Get(certTable, []byte(id))
if nosql.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id))
} else if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate"))
}
var cert certificate
if err := json.Unmarshal(b, &cert); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate"))
}
return &cert, nil
}

253
acme/certificate_test.go Normal file
View file

@ -0,0 +1,253 @@
package acme
import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/db"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
)
func defaultCertOps() (*CertOptions, error) {
crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt")
if err != nil {
return nil, err
}
inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt")
if err != nil {
return nil, err
}
root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt")
if err != nil {
return nil, err
}
return &CertOptions{
AccountID: "accID",
OrderID: "ordID",
Leaf: crt,
Intermediates: []*x509.Certificate{inter, root},
}, nil
}
func newcert() (*certificate, error) {
ops, err := defaultCertOps()
if err != nil {
return nil, err
}
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, true, nil
},
}
return newCert(mockdb, *ops)
}
func TestNewCert(t *testing.T) {
type test struct {
db nosql.DB
ops CertOptions
err *Error
id *string
}
tests := map[string]func(t *testing.T) test{
"fail/cmpAndSwap-error": func(t *testing.T) test {
ops, err := defaultCertOps()
assert.FatalError(t, err)
return test{
ops: *ops,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, certTable)
assert.Equals(t, old, nil)
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.Errorf("error storing certificate: force")),
}
},
"fail/cmpAndSwap-false": func(t *testing.T) test {
ops, err := defaultCertOps()
assert.FatalError(t, err)
return test{
ops: *ops,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, certTable)
assert.Equals(t, old, nil)
return nil, false, nil
},
},
err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")),
}
},
"ok": func(t *testing.T) test {
ops, err := defaultCertOps()
assert.FatalError(t, err)
var _id string
id := &_id
return test{
ops: *ops,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, certTable)
assert.Equals(t, old, nil)
*id = string(key)
return nil, true, nil
},
},
id: id,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if cert, err := newCert(tc.db, tc.ops); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, cert.ID, *tc.id)
assert.Equals(t, cert.AccountID, tc.ops.AccountID)
assert.Equals(t, cert.OrderID, tc.ops.OrderID)
leaf := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: tc.ops.Leaf.Raw,
})
var intermediates []byte
for _, cert := range tc.ops.Intermediates {
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
})...)
}
assert.Equals(t, cert.Leaf, leaf)
assert.Equals(t, cert.Intermediates, intermediates)
assert.True(t, cert.Created.Before(time.Now().Add(time.Minute)))
assert.True(t, cert.Created.After(time.Now().Add(-time.Minute)))
}
}
})
}
}
func TestGetCert(t *testing.T) {
type test struct {
id string
db nosql.DB
cert *certificate
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
cert, err := newcert()
assert.FatalError(t, err)
return test{
cert: cert,
id: cert.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, certTable)
assert.Equals(t, key, []byte(cert.ID))
return nil, database.ErrNotFound
},
},
err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)),
}
},
"fail/db-error": func(t *testing.T) test {
cert, err := newcert()
assert.FatalError(t, err)
return test{
cert: cert,
id: cert.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, certTable)
assert.Equals(t, key, []byte(cert.ID))
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error loading certificate: force")),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
cert, err := newcert()
assert.FatalError(t, err)
return test{
cert: cert,
id: cert.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, certTable)
assert.Equals(t, key, []byte(cert.ID))
return nil, nil
},
},
err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")),
}
},
"ok": func(t *testing.T) test {
cert, err := newcert()
assert.FatalError(t, err)
b, err := json.Marshal(cert)
assert.FatalError(t, err)
return test{
cert: cert,
id: cert.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, certTable)
assert.Equals(t, key, []byte(cert.ID))
return b, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if cert, err := getCert(tc.db, tc.id); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.cert.ID, cert.ID)
assert.Equals(t, tc.cert.AccountID, cert.AccountID)
assert.Equals(t, tc.cert.OrderID, cert.OrderID)
assert.Equals(t, tc.cert.Created, cert.Created)
assert.Equals(t, tc.cert.Leaf, cert.Leaf)
assert.Equals(t, tc.cert.Intermediates, cert.Intermediates)
}
}
})
}
}
func TestCertificateToACME(t *testing.T) {
cert, err := newcert()
assert.FatalError(t, err)
acmeCert, err := cert.toACME(nil, nil)
assert.FatalError(t, err)
assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert)
}

445
acme/challenge.go Normal file
View file

@ -0,0 +1,445 @@
package acme
import (
"crypto"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/jose"
"github.com/smallstep/nosql"
)
// Challenge is a subset of the challenge type containing only those attributes
// required for responses in the ACME protocol.
type Challenge struct {
Type string `json:"type"`
Status string `json:"status"`
Token string `json:"token"`
Validated string `json:"validated,omitempty"`
URL string `json:"url"`
Error *AError `json:"error,omitempty"`
ID string `json:"-"`
AuthzID string `json:"-"`
}
// ToLog enables response logging.
func (c *Challenge) ToLog() (interface{}, error) {
b, err := json.Marshal(c)
if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging"))
}
return string(b), nil
}
// GetID returns the Challenge ID.
func (c *Challenge) GetID() string {
return c.ID
}
// GetAuthzID returns the parent Authz ID that owns the Challenge.
func (c *Challenge) GetAuthzID() string {
return c.AuthzID
}
type httpGetter func(string) (*http.Response, error)
type lookupTxt func(string) ([]string, error)
type validateOptions struct {
httpGet httpGetter
lookupTxt lookupTxt
}
// challenge is the interface ACME challenege types must implement.
type challenge interface {
save(db nosql.DB, swap challenge) error
validate(nosql.DB, *jose.JSONWebKey, validateOptions) (challenge, error)
getType() string
getError() *AError
getValue() string
getStatus() string
getID() string
getAuthzID() string
getToken() string
clone() *baseChallenge
getAccountID() string
getValidated() time.Time
getCreated() time.Time
toACME(nosql.DB, *directory, provisioner.Interface) (*Challenge, error)
}
// ChallengeOptions is the type used to created a new Challenge.
type ChallengeOptions struct {
AccountID string
AuthzID string
Identifier Identifier
}
// baseChallenge is the base Challenge type that others build from.
type baseChallenge struct {
ID string `json:"id"`
AccountID string `json:"accountID"`
AuthzID string `json:"authzID"`
Type string `json:"type"`
Status string `json:"status"`
Token string `json:"token"`
Value string `json:"value"`
Validated time.Time `json:"validated"`
Created time.Time `json:"created"`
Error *AError `json:"error"`
}
func newBaseChallenge(accountID, authzID string) (*baseChallenge, error) {
id, err := randID()
if err != nil {
return nil, Wrap(err, "error generating random id for ACME challenge")
}
token, err := randID()
if err != nil {
return nil, Wrap(err, "error generating token for ACME challenge")
}
return &baseChallenge{
ID: id,
AccountID: accountID,
AuthzID: authzID,
Status: StatusPending,
Token: token,
Created: clock.Now(),
}, nil
}
// getID returns the id of the baseChallenge.
func (bc *baseChallenge) getID() string {
return bc.ID
}
// getAuthzID returns the authz ID of the baseChallenge.
func (bc *baseChallenge) getAuthzID() string {
return bc.AuthzID
}
// getAccountID returns the account id of the baseChallenge.
func (bc *baseChallenge) getAccountID() string {
return bc.AccountID
}
// getType returns the type of the baseChallenge.
func (bc *baseChallenge) getType() string {
return bc.Type
}
// getValue returns the type of the baseChallenge.
func (bc *baseChallenge) getValue() string {
return bc.Value
}
// getStatus returns the status of the baseChallenge.
func (bc *baseChallenge) getStatus() string {
return bc.Status
}
// getToken returns the token of the baseChallenge.
func (bc *baseChallenge) getToken() string {
return bc.Token
}
// getValidated returns the validated time of the baseChallenge.
func (bc *baseChallenge) getValidated() time.Time {
return bc.Validated
}
// getCreated returns the created time of the baseChallenge.
func (bc *baseChallenge) getCreated() time.Time {
return bc.Created
}
// getCreated returns the created time of the baseChallenge.
func (bc *baseChallenge) getError() *AError {
return bc.Error
}
// toACME converts the internal Challenge type into the public acmeChallenge
// type for presentation in the ACME protocol.
func (bc *baseChallenge) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Challenge, error) {
ac := &Challenge{
Type: bc.getType(),
Status: bc.getStatus(),
Token: bc.getToken(),
URL: dir.getLink(ChallengeLink, URLSafeProvisionerName(p), true, bc.getID()),
ID: bc.getID(),
AuthzID: bc.getAuthzID(),
}
if !bc.Validated.IsZero() {
ac.Validated = bc.Validated.Format(time.RFC3339)
}
if bc.Error != nil {
ac.Error = bc.Error
}
return ac, nil
}
// save writes the challenge to disk. For new challenges 'old' should be nil,
// otherwise 'old' should be a pointer to the acme challenge as it was at the
// start of the request. This method will fail if the value currently found
// in the bucket/row does not match the value of 'old'.
func (bc *baseChallenge) save(db nosql.DB, old challenge) error {
newB, err := json.Marshal(bc)
if err != nil {
return ServerInternalErr(errors.Wrap(err,
"error marshaling new acme challenge"))
}
var oldB []byte
if old == nil {
oldB = nil
} else {
oldB, err = json.Marshal(old)
if err != nil {
return ServerInternalErr(errors.Wrap(err,
"error marshaling old acme challenge"))
}
}
_, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB)
switch {
case err != nil:
return ServerInternalErr(errors.Wrap(err, "error saving acme challenge"))
case !swapped:
return ServerInternalErr(errors.New("error saving acme challenge; " +
"acme challenge has changed since last read"))
default:
return nil
}
}
func (bc *baseChallenge) clone() *baseChallenge {
u := *bc
return &u
}
func (bc *baseChallenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
return nil, ServerInternalErr(errors.New("unimplemented"))
}
func (bc *baseChallenge) storeError(db nosql.DB, err *Error) error {
clone := bc.clone()
clone.Error = err.ToACME()
return clone.save(db, bc)
}
// unmarshalChallenge unmarshals a challenge type into the correct sub-type.
func unmarshalChallenge(data []byte) (challenge, error) {
var getType struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &getType); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type"))
}
switch getType.Type {
case "dns-01":
var bc baseChallenge
if err := json.Unmarshal(data, &bc); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
"challenge type into dns01Challenge"))
}
return &dns01Challenge{&bc}, nil
case "http-01":
var bc baseChallenge
if err := json.Unmarshal(data, &bc); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
"challenge type into http01Challenge"))
}
return &http01Challenge{&bc}, nil
default:
return nil, ServerInternalErr(errors.Errorf("unexpected challenge type %s", getType.Type))
}
}
// http01Challenge represents an http-01 acme challenge.
type http01Challenge struct {
*baseChallenge
}
// newHTTP01Challenge returns a new acme http-01 challenge.
func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
if err != nil {
return nil, err
}
bc.Type = "http-01"
bc.Value = ops.Identifier.Value
hc := &http01Challenge{bc}
if err := hc.save(db, nil); err != nil {
return nil, err
}
return hc, nil
}
// Validate attempts to validate the challenge. If the challenge has been
// satisfactorily validated, the 'status' and 'validated' attributes are
// updated.
func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
// If already valid or invalid then return without performing validation.
if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid {
return hc, nil
}
url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token)
resp, err := vo.httpGet(url)
if err != nil {
if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err,
"error doing http GET for url %s", url))); err != nil {
return nil, err
}
return hc, nil
}
if resp.StatusCode >= 400 {
if err = hc.storeError(db,
ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d",
url, resp.StatusCode))); err != nil {
return nil, err
}
return hc, nil
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+
"response body for url %s", url))
}
keyAuth := strings.Trim(string(body), "\r\n")
expected, err := KeyAuthorization(hc.Token, jwk)
if err != nil {
return nil, err
}
if keyAuth != expected {
if err = hc.storeError(db,
RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+
"expected %s, but got %s", expected, keyAuth))); err != nil {
return nil, err
}
return hc, nil
}
// Update and store the challenge.
upd := &http01Challenge{hc.baseChallenge.clone()}
upd.Status = StatusValid
upd.Error = nil
upd.Validated = clock.Now()
if err := upd.save(db, hc); err != nil {
return nil, err
}
return upd, nil
}
// dns01Challenge represents an dns-01 acme challenge.
type dns01Challenge struct {
*baseChallenge
}
// newDNS01Challenge returns a new acme dns-01 challenge.
func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
if err != nil {
return nil, err
}
bc.Type = "dns-01"
bc.Value = ops.Identifier.Value
dc := &dns01Challenge{bc}
if err := dc.save(db, nil); err != nil {
return nil, err
}
return dc, nil
}
// KeyAuthorization creates the ACME key authorization value from a token
// and a jwk.
func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint"))
}
encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
return fmt.Sprintf("%s.%s", token, encPrint), nil
}
// validate attempts to validate the challenge. If the challenge has been
// satisfactorily validated, the 'status' and 'validated' attributes are
// updated.
func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
// If already valid or invalid then return without performing validation.
if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid {
return dc, nil
}
txtRecords, err := vo.lookupTxt("_acme-challenge." + dc.Value)
if err != nil {
if err = dc.storeError(db,
DNSErr(errors.Wrapf(err, "error looking up TXT "+
"records for domain %s", dc.Value))); err != nil {
return nil, err
}
return dc, nil
}
expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk)
if err != nil {
return nil, err
}
h := sha256.Sum256([]byte(expectedKeyAuth))
expected := base64.RawURLEncoding.EncodeToString(h[:])
var found bool
for _, r := range txtRecords {
if r == expected {
found = true
break
}
}
if !found {
if err = dc.storeError(db,
RejectedIdentifierErr(errors.Errorf("keyAuthorization "+
"does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil {
return nil, err
}
return dc, nil
}
// Update and store the challenge.
upd := &dns01Challenge{dc.baseChallenge.clone()}
upd.Status = StatusValid
upd.Error = nil
upd.Validated = time.Now().UTC()
if err := upd.save(db, dc); err != nil {
return nil, err
}
return upd, nil
}
// getChallenge retrieves and unmarshals an ACME challenge type from the database.
func getChallenge(db nosql.DB, id string) (challenge, error) {
b, err := db.Get(challengeTable, []byte(id))
if nosql.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id))
} else if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id))
}
ch, err := unmarshalChallenge(b)
if err != nil {
return nil, err
}
return ch, nil
}

1093
acme/challenge_test.go Normal file

File diff suppressed because it is too large Load diff

76
acme/common.go Normal file
View file

@ -0,0 +1,76 @@
package acme
import (
"crypto/x509"
"net/url"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/randutil"
)
// SignAuthority is the interface implemented by a CA authority.
type SignAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
LoadProvisionerByID(string) (provisioner.Interface, error)
}
// Identifier encodes the type that an order pertains to.
type Identifier struct {
Type string `json:"type"`
Value string `json:"value"`
}
var (
accountTable = []byte("acme-accounts")
accountByKeyIDTable = []byte("acme-keyID-accountID-index")
authzTable = []byte("acme-authzs")
challengeTable = []byte("acme-challenges")
nonceTable = []byte("nonce-table")
orderTable = []byte("acme-orders")
ordersByAccountIDTable = []byte("acme-account-orders-index")
certTable = []byte("acme-certs")
)
var (
// StatusValid -- valid
StatusValid = "valid"
// StatusInvalid -- invalid
StatusInvalid = "invalid"
// StatusPending -- pending; e.g. an Order that is not ready to be finalized.
StatusPending = "pending"
// StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
StatusDeactivated = "deactivated"
// StatusReady -- ready; e.g. for an Order that is ready to be finalized.
StatusReady = "ready"
//statusExpired = "expired"
//statusActive = "active"
//statusProcessing = "processing"
)
var idLen = 32
func randID() (val string, err error) {
val, err = randutil.Alphanumeric(idLen)
if err != nil {
return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID"))
}
return val, nil
}
// Clock that returns time in UTC rounded to seconds.
type Clock int
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Round(time.Second)
}
var clock = new(Clock)
// URLSafeProvisionerName returns a path escaped version of the ACME provisioner
// ID that is safe to use in URL paths.
func URLSafeProvisionerName(p provisioner.Interface) string {
return url.PathEscape(p.GetName())
}

120
acme/directory.go Normal file
View file

@ -0,0 +1,120 @@
package acme
import (
"encoding/json"
"fmt"
"github.com/pkg/errors"
)
// Directory represents an ACME directory for configuring clients.
type Directory struct {
NewNonce string `json:"newNonce,omitempty"`
NewAccount string `json:"newAccount,omitempty"`
NewOrder string `json:"newOrder,omitempty"`
NewAuthz string `json:"newAuthz,omitempty"`
RevokeCert string `json:"revokeCert,omitempty"`
KeyChange string `json:"keyChange,omitempty"`
}
// ToLog enables response logging for the Directory type.
func (d *Directory) ToLog() (interface{}, error) {
b, err := json.Marshal(d)
if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging"))
}
return string(b), nil
}
type directory struct {
prefix, dns string
}
// newDirectory returns a new Directory type.
func newDirectory(dns, prefix string) *directory {
return &directory{prefix: prefix, dns: dns}
}
// Link captures the link type.
type Link int
const (
// NewNonceLink new-nonce
NewNonceLink Link = iota
// NewAccountLink new-account
NewAccountLink
// AccountLink account
AccountLink
// OrderLink order
OrderLink
// NewOrderLink new-order
NewOrderLink
// OrdersByAccountLink list of orders owned by account
OrdersByAccountLink
// FinalizeLink finalize order
FinalizeLink
// NewAuthzLink authz
NewAuthzLink
// AuthzLink new-authz
AuthzLink
// ChallengeLink challenge
ChallengeLink
// CertificateLink certificate
CertificateLink
// DirectoryLink directory
DirectoryLink
// RevokeCertLink revoke certificate
RevokeCertLink
// KeyChangeLink key rollover
KeyChangeLink
)
func (l Link) String() string {
switch l {
case NewNonceLink:
return "new-nonce"
case NewAccountLink:
return "new-account"
case AccountLink:
return "account"
case NewOrderLink:
return "new-order"
case OrderLink:
return "order"
case NewAuthzLink:
return "new-authz"
case AuthzLink:
return "authz"
case ChallengeLink:
return "challenge"
case CertificateLink:
return "certificate"
case DirectoryLink:
return "directory"
case RevokeCertLink:
return "revoke-cert"
case KeyChangeLink:
return "key-change"
default:
return "unexpected"
}
}
// getLink returns an absolute or partial path to the given resource.
func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs ...string) string {
var link string
switch typ {
case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink:
link = fmt.Sprintf("/%s/%s", provisionerName, typ.String())
case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink:
link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0])
case OrdersByAccountLink:
link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0])
case FinalizeLink:
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0])
}
if abs {
return fmt.Sprintf("https://%s/%s%s", d.dns, d.prefix, link)
}
return link
}

60
acme/directory_test.go Normal file
View file

@ -0,0 +1,60 @@
package acme
import (
"fmt"
"testing"
"github.com/smallstep/assert"
)
func TestDirectoryGetLink(t *testing.T) {
dns := "ca.smallstep.com"
prefix := "acme"
dir := newDirectory(dns, prefix)
id := "1234"
prov := newProv()
provID := URLSafeProvisionerName(prov)
assert.Equals(t, dir.getLink(NewNonceLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", provID))
assert.Equals(t, dir.getLink(NewNonceLink, provID, false), fmt.Sprintf("/%s/new-nonce", provID))
assert.Equals(t, dir.getLink(NewAccountLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID))
assert.Equals(t, dir.getLink(NewAccountLink, provID, false), fmt.Sprintf("/%s/new-account", provID))
assert.Equals(t, dir.getLink(AccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provID))
assert.Equals(t, dir.getLink(AccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234", provID))
assert.Equals(t, dir.getLink(NewOrderLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", provID))
assert.Equals(t, dir.getLink(NewOrderLink, provID, false), fmt.Sprintf("/%s/new-order", provID))
assert.Equals(t, dir.getLink(OrderLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234", provID))
assert.Equals(t, dir.getLink(OrderLink, provID, false, id), fmt.Sprintf("/%s/order/1234", provID))
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234/orders", provID))
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234/orders", provID))
assert.Equals(t, dir.getLink(FinalizeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234/finalize", provID))
assert.Equals(t, dir.getLink(FinalizeLink, provID, false, id), fmt.Sprintf("/%s/order/1234/finalize", provID))
assert.Equals(t, dir.getLink(NewAuthzLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-authz", provID))
assert.Equals(t, dir.getLink(NewAuthzLink, provID, false), fmt.Sprintf("/%s/new-authz", provID))
assert.Equals(t, dir.getLink(AuthzLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/authz/1234", provID))
assert.Equals(t, dir.getLink(AuthzLink, provID, false, id), fmt.Sprintf("/%s/authz/1234", provID))
assert.Equals(t, dir.getLink(DirectoryLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/directory", provID))
assert.Equals(t, dir.getLink(DirectoryLink, provID, false), fmt.Sprintf("/%s/directory", provID))
assert.Equals(t, dir.getLink(RevokeCertLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", provID))
assert.Equals(t, dir.getLink(RevokeCertLink, provID, false), fmt.Sprintf("/%s/revoke-cert", provID))
assert.Equals(t, dir.getLink(KeyChangeLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", provID))
assert.Equals(t, dir.getLink(KeyChangeLink, provID, false), fmt.Sprintf("/%s/key-change", provID))
assert.Equals(t, dir.getLink(ChallengeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/1234", provID))
assert.Equals(t, dir.getLink(ChallengeLink, provID, false, id), fmt.Sprintf("/%s/challenge/1234", provID))
assert.Equals(t, dir.getLink(CertificateLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/1234", provID))
assert.Equals(t, dir.getLink(CertificateLink, provID, false, id), fmt.Sprintf("/%s/certificate/1234", provID))
}

439
acme/errors.go Normal file
View file

@ -0,0 +1,439 @@
package acme
import (
"github.com/pkg/errors"
)
// AccountDoesNotExistErr returns a new acme error.
func AccountDoesNotExistErr(err error) *Error {
return &Error{
Type: accountDoesNotExistErr,
Detail: "Account does not exist",
Status: 404,
Err: err,
}
}
// AlreadyRevokedErr returns a new acme error.
func AlreadyRevokedErr(err error) *Error {
return &Error{
Type: alreadyRevokedErr,
Detail: "Certificate already revoked",
Status: 400,
Err: err,
}
}
// BadCSRErr returns a new acme error.
func BadCSRErr(err error) *Error {
return &Error{
Type: badCSRErr,
Detail: "The CSR is unacceptable",
Status: 400,
Err: err,
}
}
// BadNonceErr returns a new acme error.
func BadNonceErr(err error) *Error {
return &Error{
Type: badNonceErr,
Detail: "Unacceptable anti-replay nonce",
Status: 400,
Err: err,
}
}
// BadPublicKeyErr returns a new acme error.
func BadPublicKeyErr(err error) *Error {
return &Error{
Type: badPublicKeyErr,
Detail: "The jws was signed by a public key the server does not support",
Status: 400,
Err: err,
}
}
// BadRevocationReasonErr returns a new acme error.
func BadRevocationReasonErr(err error) *Error {
return &Error{
Type: badRevocationReasonErr,
Detail: "The revocation reason provided is not allowed by the server",
Status: 400,
Err: err,
}
}
// BadSignatureAlgorithmErr returns a new acme error.
func BadSignatureAlgorithmErr(err error) *Error {
return &Error{
Type: badSignatureAlgorithmErr,
Detail: "The JWS was signed with an algorithm the server does not support",
Status: 400,
Err: err,
}
}
// CaaErr returns a new acme error.
func CaaErr(err error) *Error {
return &Error{
Type: caaErr,
Detail: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
Status: 400,
Err: err,
}
}
// CompoundErr returns a new acme error.
func CompoundErr(err error) *Error {
return &Error{
Type: compoundErr,
Detail: "Specific error conditions are indicated in the “subproblems” array",
Status: 400,
Err: err,
}
}
// ConnectionErr returns a new acme error.
func ConnectionErr(err error) *Error {
return &Error{
Type: connectionErr,
Detail: "The server could not connect to validation target",
Status: 400,
Err: err,
}
}
// DNSErr returns a new acme error.
func DNSErr(err error) *Error {
return &Error{
Type: dnsErr,
Detail: "There was a problem with a DNS query during identifier validation",
Status: 400,
Err: err,
}
}
// ExternalAccountRequiredErr returns a new acme error.
func ExternalAccountRequiredErr(err error) *Error {
return &Error{
Type: externalAccountRequiredErr,
Detail: "The request must include a value for the \"externalAccountBinding\" field",
Status: 400,
Err: err,
}
}
// IncorrectResponseErr returns a new acme error.
func IncorrectResponseErr(err error) *Error {
return &Error{
Type: incorrectResponseErr,
Detail: "Response received didn't match the challenge's requirements",
Status: 400,
Err: err,
}
}
// InvalidContactErr returns a new acme error.
func InvalidContactErr(err error) *Error {
return &Error{
Type: invalidContactErr,
Detail: "A contact URL for an account was invalid",
Status: 400,
Err: err,
}
}
// MalformedErr returns a new acme error.
func MalformedErr(err error) *Error {
return &Error{
Type: malformedErr,
Detail: "The request message was malformed",
Status: 400,
Err: err,
}
}
// OrderNotReadyErr returns a new acme error.
func OrderNotReadyErr(err error) *Error {
return &Error{
Type: orderNotReadyErr,
Detail: "The request attempted to finalize an order that is not ready to be finalized",
Status: 400,
Err: err,
}
}
// RateLimitedErr returns a new acme error.
func RateLimitedErr(err error) *Error {
return &Error{
Type: rateLimitedErr,
Detail: "The request exceeds a rate limit",
Status: 400,
Err: err,
}
}
// RejectedIdentifierErr returns a new acme error.
func RejectedIdentifierErr(err error) *Error {
return &Error{
Type: rejectedIdentifierErr,
Detail: "The server will not issue certificates for the identifier",
Status: 400,
Err: err,
}
}
// ServerInternalErr returns a new acme error.
func ServerInternalErr(err error) *Error {
return &Error{
Type: serverInternalErr,
Detail: "The server experienced an internal error",
Status: 500,
Err: err,
}
}
// TLSErr returns a new acme error.
func TLSErr(err error) *Error {
return &Error{
Type: tlsErr,
Detail: "The server received a TLS error during validation",
Status: 400,
Err: err,
}
}
// UnauthorizedErr returns a new acme error.
func UnauthorizedErr(err error) *Error {
return &Error{
Type: unauthorizedErr,
Detail: "The client lacks sufficient authorization",
Status: 401,
Err: err,
}
}
// UnsupportedContactErr returns a new acme error.
func UnsupportedContactErr(err error) *Error {
return &Error{
Type: unsupportedContactErr,
Detail: "A contact URL for an account used an unsupported protocol scheme",
Status: 400,
Err: err,
}
}
// UnsupportedIdentifierErr returns a new acme error.
func UnsupportedIdentifierErr(err error) *Error {
return &Error{
Type: unsupportedIdentifierErr,
Detail: "An identifier is of an unsupported type",
Status: 400,
Err: err,
}
}
// UserActionRequiredErr returns a new acme error.
func UserActionRequiredErr(err error) *Error {
return &Error{
Type: userActionRequiredErr,
Detail: "Visit the “instance” URL and take actions specified there",
Status: 400,
Err: err,
}
}
// ProbType is the type of the ACME problem.
type ProbType int
const (
// The request specified an account that does not exist
accountDoesNotExistErr ProbType = iota
// The request specified a certificate to be revoked that has already been revoked
alreadyRevokedErr
// The CSR is unacceptable (e.g., due to a short key)
badCSRErr
// The client sent an unacceptable anti-replay nonce
badNonceErr
// The JWS was signed by a public key the server does not support
badPublicKeyErr
// The revocation reason provided is not allowed by the server
badRevocationReasonErr
// The JWS was signed with an algorithm the server does not support
badSignatureAlgorithmErr
// Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate
caaErr
// Specific error conditions are indicated in the “subproblems” array.
compoundErr
// The server could not connect to validation target
connectionErr
// There was a problem with a DNS query during identifier validation
dnsErr
// The request must include a value for the “externalAccountBinding” field
externalAccountRequiredErr
// Response received didnt match the challenges requirements
incorrectResponseErr
// A contact URL for an account was invalid
invalidContactErr
// The request message was malformed
malformedErr
// The request attempted to finalize an order that is not ready to be finalized
orderNotReadyErr
// The request exceeds a rate limit
rateLimitedErr
// The server will not issue certificates for the identifier
rejectedIdentifierErr
// The server experienced an internal error
serverInternalErr
// The server received a TLS error during validation
tlsErr
// The client lacks sufficient authorization
unauthorizedErr
// A contact URL for an account used an unsupported protocol scheme
unsupportedContactErr
// An identifier is of an unsupported type
unsupportedIdentifierErr
// Visit the “instance” URL and take actions specified there
userActionRequiredErr
)
// String returns the string representation of the acme problem type,
// fulfilling the Stringer interface.
func (ap ProbType) String() string {
switch ap {
case accountDoesNotExistErr:
return "accountDoesNotExist"
case alreadyRevokedErr:
return "alreadyRevoked"
case badCSRErr:
return "badCSR"
case badNonceErr:
return "badNonce"
case badPublicKeyErr:
return "badPublicKey"
case badRevocationReasonErr:
return "badRevocationReason"
case badSignatureAlgorithmErr:
return "badSignatureAlgorithm"
case caaErr:
return "caa"
case compoundErr:
return "compound"
case connectionErr:
return "connection"
case dnsErr:
return "dns"
case externalAccountRequiredErr:
return "externalAccountRequired"
case incorrectResponseErr:
return "incorrectResponse"
case invalidContactErr:
return "invalidContact"
case malformedErr:
return "malformed"
case orderNotReadyErr:
return "orderNotReady"
case rateLimitedErr:
return "rateLimited"
case rejectedIdentifierErr:
return "rejectedIdentifier"
case serverInternalErr:
return "serverInternal"
case tlsErr:
return "tls"
case unauthorizedErr:
return "unauthorized"
case unsupportedContactErr:
return "unsupportedContact"
case unsupportedIdentifierErr:
return "unsupportedIdentifier"
case userActionRequiredErr:
return "userActionRequired"
default:
return "unsupported type"
}
}
// Error is an ACME error type complete with problem document.
type Error struct {
Type ProbType
Detail string
Err error
Status int
Sub []*Error
Identifier *Identifier
}
// Wrap attempts to wrap the internal error.
func Wrap(err error, wrap string) *Error {
switch e := err.(type) {
case nil:
return nil
case *Error:
if e.Err == nil {
e.Err = errors.New(wrap + "; " + e.Detail)
} else {
e.Err = errors.Wrap(e.Err, wrap)
}
return e
default:
return ServerInternalErr(errors.Wrap(err, wrap))
}
}
// Error implements the error interface.
func (e *Error) Error() string {
if e.Err == nil {
return e.Detail
}
return e.Err.Error()
}
// Cause returns the internal error and implements the Causer interface.
func (e *Error) Cause() error {
if e.Err == nil {
return errors.New(e.Detail)
}
return e.Err
}
// ToACME returns an acme representation of the problem type.
func (e *Error) ToACME() *AError {
ae := &AError{
Type: "urn:ietf:params:acme:error:" + e.Type.String(),
Detail: e.Error(),
Status: e.Status,
}
if e.Identifier != nil {
ae.Identifier = *e.Identifier
}
for _, p := range e.Sub {
ae.Subproblems = append(ae.Subproblems, p.ToACME())
}
return ae
}
// StatusCode returns the status code and implements the StatusCode interface.
func (e *Error) StatusCode() int {
return e.Status
}
// AError is the error type as seen in acme request/responses.
type AError struct {
Type string `json:"type"`
Detail string `json:"detail"`
Identifier interface{} `json:"identifier,omitempty"`
Subproblems []interface{} `json:"subproblems,omitempty"`
Status int `json:"-"`
}
// Error allows AError to implement the error interface.
func (ae *AError) Error() string {
return ae.Detail
}
// StatusCode returns the status code and implements the StatusCode interface.
func (ae *AError) StatusCode() int {
return ae.Status
}

73
acme/nonce.go Normal file
View file

@ -0,0 +1,73 @@
package acme
import (
"encoding/base64"
"encoding/json"
"time"
"github.com/pkg/errors"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
)
// nonce contains nonce metadata used in the ACME protocol.
type nonce struct {
ID string
Created time.Time
}
// newNonce creates, stores, and returns an ACME replay-nonce.
func newNonce(db nosql.DB) (*nonce, error) {
_id, err := randID()
if err != nil {
return nil, err
}
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
n := &nonce{
ID: id,
Created: clock.Now(),
}
b, err := json.Marshal(n)
if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce"))
}
_, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b)
switch {
case err != nil:
return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce"))
case !swapped:
return nil, ServerInternalErr(errors.New("error storing nonce; " +
"value has changed since last read"))
default:
return n, nil
}
}
// useNonce verifies that the nonce is valid (by checking if it exists),
// and if so, consumes the nonce resource by deleting it from the database.
func useNonce(db nosql.DB, nonce string) error {
err := db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: nonceTable,
Key: []byte(nonce),
Cmd: database.Get,
},
{
Bucket: nonceTable,
Key: []byte(nonce),
Cmd: database.Delete,
},
},
})
switch {
case nosql.IsErrNotFound(err):
return BadNonceErr(nil)
case err != nil:
return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce))
default:
return nil
}
}

163
acme/nonce_test.go Normal file
View file

@ -0,0 +1,163 @@
package acme
import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
)
func TestNewNonce(t *testing.T) {
type test struct {
db nosql.DB
err *Error
id *string
}
tests := map[string]func(t *testing.T) test{
"fail/cmpAndSwap-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, old, nil)
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.Errorf("error storing nonce: force")),
}
},
"fail/cmpAndSwap-false": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, old, nil)
return nil, false, nil
},
},
err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")),
}
},
"ok": func(t *testing.T) test {
var _id string
id := &_id
return test{
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, old, nil)
*id = string(key)
return nil, true, nil
},
},
id: id,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if n, err := newNonce(tc.db); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, n.ID, *tc.id)
assert.True(t, n.Created.Before(time.Now().Add(time.Minute)))
assert.True(t, n.Created.After(time.Now().Add(-time.Minute)))
}
}
})
}
}
func TestUseNonce(t *testing.T) {
type test struct {
id string
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/update-not-found": func(t *testing.T) test {
id := "foo"
return test{
db: &db.MockNoSQLDB{
MUpdate: func(tx *database.Tx) error {
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
assert.Equals(t, tx.Operations[0].Key, []byte(id))
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
assert.Equals(t, tx.Operations[1].Key, []byte(id))
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
return database.ErrNotFound
},
},
id: id,
err: BadNonceErr(nil),
}
},
"fail/update-error": func(t *testing.T) test {
id := "foo"
return test{
db: &db.MockNoSQLDB{
MUpdate: func(tx *database.Tx) error {
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
assert.Equals(t, tx.Operations[0].Key, []byte(id))
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
assert.Equals(t, tx.Operations[1].Key, []byte(id))
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
return errors.New("force")
},
},
id: id,
err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)),
}
},
"ok": func(t *testing.T) test {
id := "foo"
return test{
db: &db.MockNoSQLDB{
MUpdate: func(tx *database.Tx) error {
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
assert.Equals(t, tx.Operations[0].Key, []byte(id))
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
assert.Equals(t, tx.Operations[1].Key, []byte(id))
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
return nil
},
},
id: id,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := useNonce(tc.db, tc.id); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
}
})
}
}

342
acme/order.go Normal file
View file

@ -0,0 +1,342 @@
package acme
import (
"context"
"crypto/x509"
"encoding/json"
"reflect"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/nosql"
)
var defaultOrderExpiry = time.Hour * 24
// Order contains order metadata for the ACME protocol order type.
type Order struct {
Status string `json:"status"`
Expires string `json:"expires,omitempty"`
Identifiers []Identifier `json:"identifiers"`
NotBefore string `json:"notBefore,omitempty"`
NotAfter string `json:"notAfter,omitempty"`
Error interface{} `json:"error,omitempty"`
Authorizations []string `json:"authorizations"`
Finalize string `json:"finalize"`
Certificate string `json:"certificate,omitempty"`
ID string `json:"-"`
}
// ToLog enables response logging.
func (o *Order) ToLog() (interface{}, error) {
b, err := json.Marshal(o)
if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging"))
}
return string(b), nil
}
// GetID returns the Order ID.
func (o *Order) GetID() string {
return o.ID
}
// OrderOptions options with which to create a new Order.
type OrderOptions struct {
AccountID string `json:"accID"`
Identifiers []Identifier `json:"identifiers"`
NotBefore time.Time `json:"notBefore"`
NotAfter time.Time `json:"notAfter"`
}
type order struct {
ID string `json:"id"`
AccountID string `json:"accountID"`
Created time.Time `json:"created"`
Expires time.Time `json:"expires,omitempty"`
Status string `json:"status"`
Identifiers []Identifier `json:"identifiers"`
NotBefore time.Time `json:"notBefore,omitempty"`
NotAfter time.Time `json:"notAfter,omitempty"`
Error *Error `json:"error,omitempty"`
Authorizations []string `json:"authorizations"`
Certificate string `json:"certificate,omitempty"`
}
// newOrder returns a new Order type.
func newOrder(db nosql.DB, ops OrderOptions) (*order, error) {
id, err := randID()
if err != nil {
return nil, err
}
authzs := make([]string, len(ops.Identifiers))
for i, identifier := range ops.Identifiers {
authz, err := newAuthz(db, ops.AccountID, identifier)
if err != nil {
return nil, err
}
authzs[i] = authz.getID()
}
now := clock.Now()
o := &order{
ID: id,
AccountID: ops.AccountID,
Created: now,
Status: StatusPending,
Expires: now.Add(defaultOrderExpiry),
Identifiers: ops.Identifiers,
NotBefore: ops.NotBefore,
NotAfter: ops.NotAfter,
Authorizations: authzs,
}
if err := o.save(db, nil); err != nil {
return nil, err
}
// Update the "order IDs by account ID" index //
oids, err := getOrderIDsByAccount(db, ops.AccountID)
if err != nil {
return nil, err
}
newOids := append(oids, o.ID)
if err = orderIDs(newOids).save(db, oids, o.AccountID); err != nil {
db.Del(orderTable, []byte(o.ID))
return nil, err
}
return o, nil
}
type orderIDs []string
func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error {
var (
err error
oldb []byte
)
if len(old) == 0 {
oldb = nil
} else {
oldb, err = json.Marshal(old)
if err != nil {
return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice"))
}
}
newb, err := json.Marshal(oids)
if err != nil {
return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice"))
}
_, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb)
switch {
case err != nil:
return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID))
case !swapped:
return ServerInternalErr(errors.Errorf("error storing order IDs "+
"for account %s; order IDs changed since last read", accID))
default:
return nil
}
}
func (o *order) save(db nosql.DB, old *order) error {
var (
err error
oldB []byte
)
if old == nil {
oldB = nil
} else {
if oldB, err = json.Marshal(old); err != nil {
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
}
}
newB, err := json.Marshal(o)
if err != nil {
return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order"))
}
_, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB)
switch {
case err != nil:
return ServerInternalErr(errors.Wrap(err, "error storing order"))
case !swapped:
return ServerInternalErr(errors.New("error storing order; " +
"value has changed since last read"))
default:
return nil
}
}
// updateStatus updates order status if necessary.
func (o *order) updateStatus(db nosql.DB) (*order, error) {
_newOrder := *o
newOrder := &_newOrder
now := time.Now().UTC()
switch o.Status {
case StatusInvalid:
return o, nil
case StatusValid:
return o, nil
case StatusReady:
// check expiry
if now.After(o.Expires) {
newOrder.Status = StatusInvalid
newOrder.Error = MalformedErr(errors.New("order has expired"))
break
}
return o, nil
case StatusPending:
// check expiry
if now.After(o.Expires) {
newOrder.Status = StatusInvalid
newOrder.Error = MalformedErr(errors.New("order has expired"))
break
}
var count = map[string]int{
StatusValid: 0,
StatusInvalid: 0,
StatusPending: 0,
}
for _, azID := range o.Authorizations {
authz, err := getAuthz(db, azID)
if err != nil {
return nil, err
}
if authz, err = authz.updateStatus(db); err != nil {
return nil, err
}
st := authz.getStatus()
count[st]++
}
switch {
case count[StatusInvalid] > 0:
newOrder.Status = StatusInvalid
case count[StatusPending] > 0:
break
case count[StatusValid] == len(o.Authorizations):
newOrder.Status = StatusReady
default:
return nil, ServerInternalErr(errors.New("unexpected authz status"))
}
default:
return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status))
}
if err := newOrder.save(db, o); err != nil {
return nil, err
}
return newOrder, nil
}
// finalize signs a certificate if the necessary conditions for Order completion
// have been met.
func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p provisioner.Interface) (*order, error) {
var err error
if o, err = o.updateStatus(db); err != nil {
return nil, err
}
switch o.Status {
case StatusInvalid:
return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID))
case StatusValid:
return o, nil
case StatusPending:
return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID))
case StatusReady:
break
default:
return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID))
}
// Validate identifier names against CSR alternative names //
csrNames := make(map[string]int)
for _, n := range csr.DNSNames {
csrNames[n] = 1
}
orderNames := make(map[string]int)
for _, n := range o.Identifiers {
orderNames[n.Value] = 1
}
if !reflect.DeepEqual(csrNames, orderNames) {
return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly"))
}
// Get authorizations from the ACME provisioner.
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
signOps, err := p.AuthorizeSign(ctx, "")
if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner"))
}
// Create and store a new certificate.
leaf, inter, err := auth.Sign(csr, provisioner.Options{
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
}, signOps...)
if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID))
}
cert, err := newCert(db, CertOptions{
AccountID: o.AccountID,
OrderID: o.ID,
Leaf: leaf,
Intermediates: []*x509.Certificate{inter},
})
if err != nil {
return nil, err
}
_newOrder := *o
newOrder := &_newOrder
newOrder.Certificate = cert.ID
newOrder.Status = StatusValid
if err := newOrder.save(db, o); err != nil {
return nil, err
}
return newOrder, nil
}
// getOrder retrieves and unmarshals an ACME Order type from the database.
func getOrder(db nosql.DB, id string) (*order, error) {
b, err := db.Get(orderTable, []byte(id))
if nosql.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id))
} else if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id))
}
var o order
if err := json.Unmarshal(b, &o); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order"))
}
return &o, nil
}
// toACME converts the internal Order type into the public acmeOrder type for
// presentation in the ACME protocol.
func (o *order) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Order, error) {
azs := make([]string, len(o.Authorizations))
for i, aid := range o.Authorizations {
azs[i] = dir.getLink(AuthzLink, URLSafeProvisionerName(p), true, aid)
}
ao := &Order{
Status: o.Status,
Expires: o.Expires.Format(time.RFC3339),
Identifiers: o.Identifiers,
NotBefore: o.NotBefore.Format(time.RFC3339),
NotAfter: o.NotAfter.Format(time.RFC3339),
Authorizations: azs,
Finalize: dir.getLink(FinalizeLink, URLSafeProvisionerName(p), true, o.ID),
ID: o.ID,
}
if o.Certificate != "" {
ao.Certificate = dir.getLink(CertificateLink, URLSafeProvisionerName(p), true, o.Certificate)
}
return ao, nil
}

1129
acme/order_test.go Normal file

File diff suppressed because it is too large Load diff

View file

@ -28,8 +28,7 @@ import (
// Authority is the interface implemented by a CA authority. // Authority is the interface implemented by a CA authority.
type Authority interface { type Authority interface {
SSHAuthority SSHAuthority
// NOTE: Authorize will be deprecated in future releases. Please use the // context specifies the Authorize[Sign|Revoke|etc.] method.
// context specific Authorize[Sign|Revoke|etc.] methods.
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
AuthorizeSign(ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error)
GetTLSOptions() *tlsutil.TLSOptions GetTLSOptions() *tlsutil.TLSOptions
@ -37,6 +36,7 @@ type Authority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error) LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
LoadProvisionerByID(string) (provisioner.Interface, error)
GetProvisioners(cursor string, limit int) (provisioner.List, string, error) GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
Revoke(*authority.RevokeOptions) error Revoke(*authority.RevokeOptions) error
GetEncryptedKey(kid string) (string, error) GetEncryptedKey(kid string) (string, error)
@ -308,13 +308,12 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
return return
} }
w.WriteHeader(http.StatusCreated)
logCertificate(w, cert) logCertificate(w, cert)
JSON(w, &SignResponse{ JSONStatus(w, &SignResponse{
ServerPEM: Certificate{cert}, ServerPEM: Certificate{cert},
CaPEM: Certificate{root}, CaPEM: Certificate{root},
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: h.Authority.GetTLSOptions(),
}) }, http.StatusCreated)
} }
// Renew uses the information of certificate in the TLS connection to create a // Renew uses the information of certificate in the TLS connection to create a
@ -331,13 +330,12 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
return return
} }
w.WriteHeader(http.StatusCreated)
logCertificate(w, cert) logCertificate(w, cert)
JSON(w, &SignResponse{ JSONStatus(w, &SignResponse{
ServerPEM: Certificate{cert}, ServerPEM: Certificate{cert},
CaPEM: Certificate{root}, CaPEM: Certificate{root},
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: h.Authority.GetTLSOptions(),
}) }, http.StatusCreated)
} }
// Provisioners returns the list of provisioners configured in the authority. // Provisioners returns the list of provisioners configured in the authority.
@ -383,10 +381,9 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{roots[i]} certs[i] = Certificate{roots[i]}
} }
w.WriteHeader(http.StatusCreated) JSONStatus(w, &RootsResponse{
JSON(w, &RootsResponse{
Certificates: certs, Certificates: certs,
}) }, http.StatusCreated)
} }
// Federation returns all the public certificates in the federation. // Federation returns all the public certificates in the federation.
@ -402,10 +399,9 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{federated[i]} certs[i] = Certificate{federated[i]}
} }
w.WriteHeader(http.StatusCreated) JSONStatus(w, &FederationResponse{
JSON(w, &FederationResponse{
Certificates: certs, Certificates: certs,
}) }, http.StatusCreated)
} }
var oidStepProvisioner = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1} var oidStepProvisioner = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}

View file

@ -506,6 +506,7 @@ type mockAuthority struct {
signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
loadProvisionerByID func(provID string) (provisioner.Interface, error)
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
revoke func(*authority.RevokeOptions) error revoke func(*authority.RevokeOptions) error
getEncryptedKey func(kid string) (string, error) getEncryptedKey func(kid string) (string, error)
@ -581,6 +582,13 @@ func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (pr
return m.ret1.(provisioner.Interface), m.err return m.ret1.(provisioner.Interface), m.err
} }
func (m *mockAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) {
if m.loadProvisionerByID != nil {
return m.loadProvisionerByID(provID)
}
return m.ret1.(provisioner.Interface), m.err
}
func (m *mockAuthority) Revoke(opts *authority.RevokeOptions) error { func (m *mockAuthority) Revoke(opts *authority.RevokeOptions) error {
if m.revoke != nil { if m.revoke != nil {
return m.revoke(opts) return m.revoke(opts)

View file

@ -7,6 +7,7 @@ import (
"os" "os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
@ -109,7 +110,13 @@ func NotFound(err error) error {
// WriteError writes to w a JSON representation of the given error. // WriteError writes to w a JSON representation of the given error.
func WriteError(w http.ResponseWriter, err error) { func WriteError(w http.ResponseWriter, err error) {
switch k := err.(type) {
case *acme.Error:
w.Header().Set("Content-Type", "application/problem+json")
err = k.ToACME()
default:
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
}
cause := errors.Cause(err) cause := errors.Cause(err)
if sc, ok := err.(StatusCoder); ok { if sc, ok := err.(StatusCoder); ok {
w.WriteHeader(sc.StatusCode()) w.WriteHeader(sc.StatusCode())

View file

@ -87,8 +87,6 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
} }
logRevoke(w, opts) logRevoke(w, opts)
w.WriteHeader(http.StatusOK)
JSON(w, &RevokeResponse{Status: "ok"}) JSON(w, &RevokeResponse{Status: "ok"})
} }

View file

@ -74,7 +74,6 @@ func Test_caHandler_Revoke(t *testing.T) {
input string input string
auth Authority auth Authority
tls *tls.ConnectionState tls *tls.ConnectionState
err error
statusCode int statusCode int
expected []byte expected []byte
} }

View file

@ -260,13 +260,6 @@ func Test_caHandler_SignSSH(t *testing.T) {
}) })
assert.FatalError(t, err) assert.FatalError(t, err)
type fields struct {
Authority Authority
}
type args struct {
w http.ResponseWriter
r *http.Request
}
tests := []struct { tests := []struct {
name string name string
req []byte req []byte

View file

@ -10,6 +10,11 @@ import (
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
// EnableLogger is an interface that enables response logging for an object.
type EnableLogger interface {
ToLog() (interface{}, error)
}
// LogError adds to the response writer the given error if it implements // LogError adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error // logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package. // using the log package.
@ -23,12 +28,40 @@ func LogError(rw http.ResponseWriter, err error) {
} }
} }
// LogEnabledResponse log the response object if it implements the EnableLogger
// interface.
func LogEnabledResponse(rw http.ResponseWriter, v interface{}) {
if el, ok := v.(EnableLogger); ok {
out, err := el.ToLog()
if err != nil {
LogError(rw, err)
return
}
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"response": out,
})
} else {
log.Println(out)
}
}
}
// JSON writes the passed value into the http.ResponseWriter. // JSON writes the passed value into the http.ResponseWriter.
func JSON(w http.ResponseWriter, v interface{}) { func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK)
}
// JSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil { if err := json.NewEncoder(w).Encode(v); err != nil {
LogError(w, err) LogError(w, err)
return
} }
LogEnabledResponse(w, v)
} }
// ReadJSON reads JSON from the request body and stores it in the value // ReadJSON reads JSON from the request body and stores it in the value

View file

@ -15,7 +15,9 @@ import (
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
) )
const legacyAuthority = "step-certificate-authority" const (
legacyAuthority = "step-certificate-authority"
)
// Authority implements the Certificate Authority internal interface. // Authority implements the Certificate Authority internal interface.
type Authority struct { type Authority struct {
@ -24,7 +26,6 @@ type Authority struct {
intermediateIdentity *x509util.Identity intermediateIdentity *x509util.Identity
sshCAUserCertSignKey crypto.Signer sshCAUserCertSignKey crypto.Signer
sshCAHostCertSignKey crypto.Signer sshCAHostCertSignKey crypto.Signer
validateOnce bool
certificates *sync.Map certificates *sync.Map
startTime time.Time startTime time.Time
provisioners *provisioner.Collection provisioners *provisioner.Collection

View file

@ -8,7 +8,7 @@ import (
type MockAuthDB struct { type MockAuthDB struct {
err error err error
ret1, ret2 interface{} ret1 interface{}
init func(*db.Config) (db.AuthDB, error) init func(*db.Config) (db.AuthDB, error)
isRevoked func(string) (bool, error) isRevoked func(string) (bool, error)
revoke func(rci *db.RevokedCertificateInfo) error revoke func(rci *db.RevokedCertificateInfo) error

View file

@ -1,6 +1,8 @@
package authority package authority
import ( import (
"encoding/json"
"fmt"
"net/http" "net/http"
) )
@ -33,6 +35,12 @@ func (e *apiError) Error() string {
return ret return ret
} }
// ErrorResponse represents an error in JSON format.
type ErrorResponse struct {
Status int `json:"status"`
Message string `json:"message"`
}
// StatusCode returns an http status code indicating the type and severity of // StatusCode returns an http status code indicating the type and severity of
// the error. // the error.
func (e *apiError) StatusCode() int { func (e *apiError) StatusCode() int {
@ -41,3 +49,19 @@ func (e *apiError) StatusCode() int {
} }
return e.code return e.code
} }
// MarshalJSON implements json.Marshaller interface for the Error struct.
func (e *apiError) MarshalJSON() ([]byte, error) {
return json.Marshal(&ErrorResponse{Status: e.code, Message: http.StatusText(e.code)})
}
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
func (e *apiError) UnmarshalJSON(data []byte) error {
var er ErrorResponse
if err := json.Unmarshal(data, &er); err != nil {
return err
}
e.code = er.Status
e.err = fmt.Errorf(er.Message)
return nil
}

View file

@ -0,0 +1,85 @@
package provisioner
import (
"context"
"crypto/x509"
"github.com/pkg/errors"
)
// ACME is the acme provisioner type, an entity that can authorize the ACME
// provisioning flow.
type ACME struct {
Type string `json:"type"`
Name string `json:"name"`
Claims *Claims `json:"claims,omitempty"`
claimer *Claimer
}
// GetID returns the provisioner unique identifier.
func (p ACME) GetID() string {
return "acme/" + p.Name
}
// GetTokenID returns the identifier of the token.
func (p *ACME) GetTokenID(ott string) (string, error) {
return "", errors.New("acme provisioner does not implement GetTokenID")
}
// GetName returns the name of the provisioner.
func (p *ACME) GetName() string {
return p.Name
}
// GetType returns the type of provisioner.
func (p *ACME) GetType() Type {
return TypeACME
}
// GetEncryptedKey returns the base provisioner encrypted key if it's defined.
func (p *ACME) GetEncryptedKey() (string, string, bool) {
return "", "", false
}
// Init initializes and validates the fields of a JWK type.
func (p *ACME) Init(config Config) (err error) {
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
case p.Name == "":
return errors.New("provisioner name cannot be empty")
}
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
return err
}
// AuthorizeRevoke is not implemented yet for the ACME provisioner.
func (p *ACME) AuthorizeRevoke(token string) error {
return nil
}
// AuthorizeSign validates the given token.
func (p *ACME) AuthorizeSign(ctx context.Context, _ string) ([]SignOption, error) {
if m := MethodFromContext(ctx); m != SignMethod {
return nil, errors.Errorf("unexpected method type %d in context", m)
}
return []SignOption{
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeACME, p.Name, ""),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
defaultPublicKeyValidator{},
}, nil
}
// AuthorizeRenewal is not implemented for the ACME provisioner.
func (p *ACME) AuthorizeRenewal(cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
}
return nil
}

View file

@ -0,0 +1,184 @@
package provisioner
import (
"context"
"crypto/x509"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
)
func TestACME_Getters(t *testing.T) {
p, err := generateACME()
assert.FatalError(t, err)
id := "acme/" + p.Name
if got := p.GetID(); got != id {
t.Errorf("ACME.GetID() = %v, want %v", got, id)
}
if got := p.GetName(); got != p.Name {
t.Errorf("ACME.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeACME {
t.Errorf("ACME.GetType() = %v, want %v", got, TypeACME)
}
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false)
}
}
func TestACME_Init(t *testing.T) {
type ProvisionerValidateTest struct {
p *ACME
err error
}
tests := map[string]func(*testing.T) ProvisionerValidateTest{
"fail-empty": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &ACME{},
err: errors.New("provisioner type cannot be empty"),
}
},
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &ACME{
Type: "ACME",
},
err: errors.New("provisioner name cannot be empty"),
}
},
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &ACME{Name: "foo"},
err: errors.New("provisioner type cannot be empty"),
}
},
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &ACME{Name: "foo", Type: "bar", Claims: &Claims{DefaultTLSDur: &Duration{0}}},
err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"),
}
},
"ok": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &ACME{Name: "foo", Type: "bar"},
}
},
}
config := Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}
for name, get := range tests {
t.Run(name, func(t *testing.T) {
tc := get(t)
err := tc.p.Init(config)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestACME_AuthorizeRevoke(t *testing.T) {
p, err := generateACME()
assert.FatalError(t, err)
assert.Nil(t, p.AuthorizeRevoke(""))
}
func TestACME_AuthorizeRenewal(t *testing.T) {
p1, err := generateACME()
assert.FatalError(t, err)
p2, err := generateACME()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
prov *ACME
args args
err error
}{
{"ok", p1, args{nil}, nil},
{"fail", p2, args{nil}, errors.Errorf("renew is disabled for provisioner %s", p2.GetID())},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenewal(tt.args.cert); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
} else {
assert.Nil(t, tt.err)
}
})
}
}
func TestACME_AuthorizeSign(t *testing.T) {
p1, err := generateACME()
assert.FatalError(t, err)
tests := []struct {
name string
prov *ACME
method Method
err error
}{
{"fail/method", p1, SignSSHMethod, errors.New("unexpected method type 1 in context")},
{"ok", p1, SignMethod, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), tt.method)
if got, err := tt.prov.AuthorizeSign(ctx, ""); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
} else {
if assert.NotNil(t, got) {
assert.Len(t, 4, got)
_pdd := got[0]
pdd, ok := _pdd.(profileDefaultDuration)
assert.True(t, ok)
assert.Equals(t, pdd, profileDefaultDuration(86400000000000))
_peo := got[1]
peo, ok := _peo.(*provisionerExtensionOption)
assert.True(t, ok)
assert.Equals(t, peo.Type, 6)
assert.Equals(t, peo.Name, "test@acme-provisioner.com")
assert.Equals(t, peo.CredentialID, "")
assert.Equals(t, peo.KeyValuePairs, nil)
_vv := got[2]
vv, ok := _vv.(*validityValidator)
assert.True(t, ok)
assert.Equals(t, vv.min, time.Duration(300000000000))
assert.Equals(t, vv.max, time.Duration(86400000000000))
_dpkv := got[3]
_, ok = _dpkv.(defaultPublicKeyValidator)
assert.True(t, ok)
}
}
})
}
}

View file

@ -470,6 +470,8 @@ func (p *AWS) authorizeSSHSign(claims *awsPayload) ([]SignOption, error) {
&sshDefaultExtensionModifier{}, &sshDefaultExtensionModifier{},
// checks the validity bounds, and set the validity if has not been set // checks the validity bounds, and set the validity if has not been set
&sshCertificateValidityModifier{p.claimer}, &sshCertificateValidityModifier{p.claimer},
// validate public key
&sshDefaultPublicKeyValidator{},
// require all the fields in the SSH certificate // require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertificateDefaultValidator{},
), nil ), nil

View file

@ -377,6 +377,12 @@ func TestAWS_AuthorizeSign_SSH(t *testing.T) {
signer, err := generateJSONWebKey() signer, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
pub := key.Public().Key
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
assert.FatalError(t, err)
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SSHOptions{ expectedHostOptions := &SSHOptions{
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}, CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
@ -394,6 +400,7 @@ func TestAWS_AuthorizeSign_SSH(t *testing.T) {
type args struct { type args struct {
token string token string
sshOpts SSHOptions sshOpts SSHOptions
key interface{}
} }
tests := []struct { tests := []struct {
name string name string
@ -403,15 +410,17 @@ func TestAWS_AuthorizeSign_SSH(t *testing.T) {
wantErr bool wantErr bool
wantSignErr bool wantSignErr bool
}{ }{
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false}, {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false}, {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false}, {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}}, expectedHostOptionsIP, false, false}, {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false},
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptionsHostname, false, false}, {"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false}, {"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, false, false},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true}, {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true}, {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}}, nil, false, true}, {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, false, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -424,7 +433,7 @@ func TestAWS_AuthorizeSign_SSH(t *testing.T) {
if err != nil { if err != nil {
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr { if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else { } else {

View file

@ -327,6 +327,8 @@ func (p *Azure) authorizeSSHSign(claims azurePayload, name string) ([]SignOption
&sshDefaultExtensionModifier{}, &sshDefaultExtensionModifier{},
// checks the validity bounds, and set the validity if has not been set // checks the validity bounds, and set the validity if has not been set
&sshCertificateValidityModifier{p.claimer}, &sshCertificateValidityModifier{p.claimer},
// validate public key
&sshDefaultPublicKeyValidator{},
// require all the fields in the SSH certificate // require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertificateDefaultValidator{},
), nil ), nil

View file

@ -3,6 +3,8 @@ package provisioner
import ( import (
"context" "context"
"crypto" "crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
@ -325,6 +327,12 @@ func TestAzure_AuthorizeSign_SSH(t *testing.T) {
signer, err := generateJSONWebKey() signer, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
pub := key.Public().Key
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
assert.FatalError(t, err)
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SSHOptions{ expectedHostOptions := &SSHOptions{
CertType: "host", Principals: []string{"virtualMachine"}, CertType: "host", Principals: []string{"virtualMachine"},
@ -334,6 +342,7 @@ func TestAzure_AuthorizeSign_SSH(t *testing.T) {
type args struct { type args struct {
token string token string
sshOpts SSHOptions sshOpts SSHOptions
key interface{}
} }
tests := []struct { tests := []struct {
name string name string
@ -343,13 +352,15 @@ func TestAzure_AuthorizeSign_SSH(t *testing.T) {
wantErr bool wantErr bool
wantSignErr bool wantSignErr bool
}{ }{
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false}, {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false}, {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false}, {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false}, {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true}, {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true}, {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}}, nil, false, true}, {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, false, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -362,7 +373,7 @@ func TestAzure_AuthorizeSign_SSH(t *testing.T) {
if err != nil { if err != nil {
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr { if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else { } else {

View file

@ -127,6 +127,8 @@ func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool)
return c.Load("aws/" + string(provisioner.Name)) return c.Load("aws/" + string(provisioner.Name))
case TypeGCP: case TypeGCP:
return c.Load("gcp/" + string(provisioner.Name)) return c.Load("gcp/" + string(provisioner.Name))
case TypeACME:
return c.Load("acme/" + string(provisioner.Name))
default: default:
return c.Load(string(provisioner.CredentialID)) return c.Load(string(provisioner.CredentialID))
} }
@ -153,7 +155,7 @@ func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
// provisioner IDs. // provisioner IDs.
func (c *Collection) Store(p Interface) error { func (c *Collection) Store(p Interface) error {
// Store provisioner always in byID. ID must be unique. // Store provisioner always in byID. ID must be unique.
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true { if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded {
return errors.New("cannot add multiple provisioners with the same id") return errors.New("cannot add multiple provisioners with the same id")
} }

View file

@ -133,15 +133,20 @@ func TestCollection_LoadByCertificate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateOIDC() p2, err := generateOIDC()
assert.FatalError(t, err) assert.FatalError(t, err)
p3, err := generateACME()
assert.FatalError(t, err)
byID := new(sync.Map) byID := new(sync.Map)
byID.Store(p1.GetID(), p1) byID.Store(p1.GetID(), p1)
byID.Store(p2.GetID(), p2) byID.Store(p2.GetID(), p2)
byID.Store(p3.GetID(), p3)
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID) ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
assert.FatalError(t, err) assert.FatalError(t, err)
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID) ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
assert.FatalError(t, err) assert.FatalError(t, err)
ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "")
assert.FatalError(t, err)
notFoundExt, err := createProvisionerExtension(1, "foo", "bar") notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
assert.FatalError(t, err) assert.FatalError(t, err)
@ -151,6 +156,9 @@ func TestCollection_LoadByCertificate(t *testing.T) {
ok2Cert := &x509.Certificate{ ok2Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok2Ext}, Extensions: []pkix.Extension{ok2Ext},
} }
ok3Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok3Ext},
}
notFoundCert := &x509.Certificate{ notFoundCert := &x509.Certificate{
Extensions: []pkix.Extension{notFoundExt}, Extensions: []pkix.Extension{notFoundExt},
} }
@ -176,6 +184,7 @@ func TestCollection_LoadByCertificate(t *testing.T) {
}{ }{
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true}, {"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true}, {"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
{"ok3", fields{byID, testAudiences}, args{ok3Cert}, p3, true},
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true}, {"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false}, {"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false}, {"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},

View file

@ -382,6 +382,8 @@ func (p *GCP) authorizeSSHSign(claims *gcpPayload) ([]SignOption, error) {
&sshDefaultExtensionModifier{}, &sshDefaultExtensionModifier{},
// checks the validity bounds, and set the validity if has not been set // checks the validity bounds, and set the validity if has not been set
&sshCertificateValidityModifier{p.claimer}, &sshCertificateValidityModifier{p.claimer},
// validate public key
&sshDefaultPublicKeyValidator{},
// require all the fields in the SSH certificate // require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertificateDefaultValidator{},
), nil ), nil

View file

@ -3,6 +3,8 @@ package provisioner
import ( import (
"context" "context"
"crypto" "crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
@ -362,6 +364,12 @@ func TestGCP_AuthorizeSign_SSH(t *testing.T) {
signer, err := generateJSONWebKey() signer, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
pub := key.Public().Key
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
assert.FatalError(t, err)
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SSHOptions{ expectedHostOptions := &SSHOptions{
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}, CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
@ -379,6 +387,7 @@ func TestGCP_AuthorizeSign_SSH(t *testing.T) {
type args struct { type args struct {
token string token string
sshOpts SSHOptions sshOpts SSHOptions
key interface{}
} }
tests := []struct { tests := []struct {
name string name string
@ -388,15 +397,17 @@ func TestGCP_AuthorizeSign_SSH(t *testing.T) {
wantErr bool wantErr bool
wantSignErr bool wantSignErr bool
}{ }{
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false}, {"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false}, {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false}, {"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}}, expectedHostOptionsPrincipal1, false, false}, {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false},
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}}, expectedHostOptionsPrincipal2, false, false}, {"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false}, {"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, false, false},
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true}, {"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true}, {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}}, nil, false, true}, {"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, false, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -409,7 +420,7 @@ func TestGCP_AuthorizeSign_SSH(t *testing.T) {
if err != nil { if err != nil {
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr { if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else { } else {

View file

@ -210,6 +210,8 @@ func (p *JWK) authorizeSSHSign(claims *jwtPayload) ([]SignOption, error) {
&sshDefaultExtensionModifier{}, &sshDefaultExtensionModifier{},
// checks the validity bounds, and set the validity if has not been set // checks the validity bounds, and set the validity if has not been set
&sshCertificateValidityModifier{p.claimer}, &sshCertificateValidityModifier{p.claimer},
// validate public key
&sshDefaultPublicKeyValidator{},
// require all the fields in the SSH certificate // require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertificateDefaultValidator{},
), nil ), nil

View file

@ -3,6 +3,8 @@ package provisioner
import ( import (
"context" "context"
"crypto" "crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509" "crypto/x509"
"errors" "errors"
"strings" "strings"
@ -356,6 +358,12 @@ func TestJWK_AuthorizeSign_SSH(t *testing.T) {
signer, err := generateJSONWebKey() signer, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
pub := key.Public().Key
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
assert.FatalError(t, err)
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SSHOptions{ expectedUserOptions := &SSHOptions{
@ -370,6 +378,7 @@ func TestJWK_AuthorizeSign_SSH(t *testing.T) {
type args struct { type args struct {
token string token string
sshOpts SSHOptions sshOpts SSHOptions
key interface{}
} }
tests := []struct { tests := []struct {
name string name string
@ -379,15 +388,17 @@ func TestJWK_AuthorizeSign_SSH(t *testing.T) {
wantErr bool wantErr bool
wantSignErr bool wantSignErr bool
}{ }{
{"user", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false}, {"user", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false},
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false}, {"user-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false},
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false}, {"user-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false},
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false}, {"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
{"host", p1, args{t2, SSHOptions{}}, expectedHostOptions, false, false}, {"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false}, {"host", p1, args{t2, SSHOptions{}, pub}, expectedHostOptions, false, false},
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false}, {"host-type", p1, args{t2, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false}, {"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
{"fail-signature", p1, args{failSig, SSHOptions{}}, nil, true, false}, {"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
{"fail-signature", p1, args{failSig, SSHOptions{}, pub}, nil, true, false},
{"rail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -400,7 +411,7 @@ func TestJWK_AuthorizeSign_SSH(t *testing.T) {
if err != nil { if err != nil {
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr { if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else { } else {

View file

@ -4,7 +4,10 @@ import (
"context" "context"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"net"
"net/http" "net/http"
"net/url"
"path"
"strings" "strings"
"time" "time"
@ -55,6 +58,7 @@ type OIDC struct {
Admins []string `json:"admins,omitempty"` Admins []string `json:"admins,omitempty"`
Domains []string `json:"domains,omitempty"` Domains []string `json:"domains,omitempty"`
Groups []string `json:"groups,omitempty"` Groups []string `json:"groups,omitempty"`
ListenAddress string `json:"listenAddress,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
configuration openIDConfiguration configuration openIDConfiguration
keyStore *keyStore keyStore *keyStore
@ -133,13 +137,27 @@ func (o *OIDC) Init(config Config) (err error) {
return errors.New("configurationEndpoint cannot be empty") return errors.New("configurationEndpoint cannot be empty")
} }
// Validate listenAddress if given
if o.ListenAddress != "" {
if _, _, err := net.SplitHostPort(o.ListenAddress); err != nil {
return errors.Wrap(err, "error parsing listenAddress")
}
}
// Update claims with global ones // Update claims with global ones
if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil { if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil {
return err return err
} }
// Decode and validate openid-configuration endpoint // Decode and validate openid-configuration endpoint
if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil { u, err := url.Parse(o.ConfigurationEndpoint)
if err != nil {
return errors.Wrapf(err, "error parsing %s", o.ConfigurationEndpoint)
}
if !strings.Contains(u.Path, "/.well-known/openid-configuration") {
u.Path = path.Join(u.Path, "/.well-known/openid-configuration")
}
if err := getAndDecode(u.String(), &o.configuration); err != nil {
return err return err
} }
if err := o.configuration.Validate(); err != nil { if err := o.configuration.Validate(); err != nil {
@ -336,6 +354,8 @@ func (o *OIDC) authorizeSSHSign(claims *openIDPayload) ([]SignOption, error) {
&sshDefaultExtensionModifier{}, &sshDefaultExtensionModifier{},
// checks the validity bounds, and set the validity if has not been set // checks the validity bounds, and set the validity if has not been set
&sshCertificateValidityModifier{o.claimer}, &sshCertificateValidityModifier{o.claimer},
// validate public key
&sshDefaultPublicKeyValidator{},
// require all the fields in the SSH certificate // require all the fields in the SSH certificate
&sshCertificateDefaultValidator{}, &sshCertificateDefaultValidator{},
), nil ), nil

View file

@ -3,6 +3,8 @@ package provisioner
import ( import (
"context" "context"
"crypto" "crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"strings" "strings"
@ -79,6 +81,7 @@ func TestOIDC_Init(t *testing.T) {
Claims *Claims Claims *Claims
Admins []string Admins []string
Domains []string Domains []string
ListenAddress string
} }
type args struct { type args struct {
config Config config Config
@ -89,16 +92,21 @@ func TestOIDC_Init(t *testing.T) {
args args args args
wantErr bool wantErr bool
}{ }{
{"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false}, {"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, false},
{"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, []string{"foo@smallstep.com"}, nil}, args{config}, false}, {"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/.well-known/openid-configuration", nil, []string{"foo@smallstep.com"}, nil, ""}, args{config}, false},
{"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, []string{"smallstep.com"}}, args{config}, false}, {"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, []string{"smallstep.com"}, ""}, args{config}, false},
{"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false}, {"ok-listen-port", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ":10000"}, args{config}, false},
{"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true}, {"ok-listen-host-port", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, "127.0.0.1:10000"}, args{config}, false},
{"no-type", fields{"", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true}, {"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL, nil, nil, nil, ""}, args{config}, false},
{"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true}, {"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, true},
{"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil}, args{config}, true}, {"no-type", fields{"", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, true},
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil}, args{config}, true}, {"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, true},
{"bad-claims", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", badClaims, nil, nil}, args{config}, true}, {"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil, ""}, args{config}, true},
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/random", nil, nil, nil, ""}, args{config}, true},
{"bad-claims", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/.well-known/openid-configuration", badClaims, nil, nil, ""}, args{config}, true},
{"bad-parse-url", fields{"oidc", "name", "client-id", "client-secret", ":", nil, nil, nil, ""}, args{config}, true},
{"bad-get-url", fields{"oidc", "name", "client-id", "client-secret", "https://", nil, nil, nil, ""}, args{config}, true},
{"bad-listen-address", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, "127.0.0.1"}, args{config}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -109,9 +117,12 @@ func TestOIDC_Init(t *testing.T) {
ConfigurationEndpoint: tt.fields.ConfigurationEndpoint, ConfigurationEndpoint: tt.fields.ConfigurationEndpoint,
Claims: tt.fields.Claims, Claims: tt.fields.Claims,
Admins: tt.fields.Admins, Admins: tt.fields.Admins,
Domains: tt.fields.Domains,
ListenAddress: tt.fields.ListenAddress,
} }
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
t.Errorf("OIDC.Init() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.Init() error = %v, wantErr %v", err, tt.wantErr)
return
} }
if tt.wantErr == false { if tt.wantErr == false {
assert.Len(t, 2, p.keyStore.keySet.Keys) assert.Len(t, 2, p.keyStore.keySet.Keys)
@ -343,6 +354,12 @@ func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
signer, err := generateJSONWebKey() signer, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
pub := key.Public().Key
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
assert.FatalError(t, err)
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SSHOptions{ expectedUserOptions := &SSHOptions{
@ -361,6 +378,7 @@ func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
type args struct { type args struct {
token string token string
sshOpts SSHOptions sshOpts SSHOptions
key interface{}
} }
tests := []struct { tests := []struct {
name string name string
@ -370,18 +388,20 @@ func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
wantErr bool wantErr bool
wantSignErr bool wantSignErr bool
}{ }{
{"ok", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false}, {"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false},
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false}, {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false},
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false}, {"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false},
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false}, {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
{"admin", p3, args{okAdmin, SSHOptions{}}, expectedAdminOptions, false, false}, {"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}}, expectedAdminOptions, false, false}, {"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, false, false},
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}}, expectedAdminOptions, false, false}, {"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, false, false},
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false}, {"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub}, expectedAdminOptions, false, false},
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false}, {"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}}, nil, false, true}, {"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}}, nil, false, true}, {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true},
{"fail-email", p3, args{failEmail, SSHOptions{}}, nil, true, false}, {"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, false, true},
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, false, true},
{"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, true, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -394,7 +414,7 @@ func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
if err != nil { if err != nil {
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr { if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else { } else {

View file

@ -84,6 +84,8 @@ const (
TypeAWS Type = 4 TypeAWS Type = 4
// TypeAzure is used to indicate the Azure provisioners. // TypeAzure is used to indicate the Azure provisioners.
TypeAzure Type = 5 TypeAzure Type = 5
// TypeACME is used to indicate the ACME provisioners.
TypeACME Type = 6
// RevokeAudienceKey is the key for the 'revoke' audiences in the audiences map. // RevokeAudienceKey is the key for the 'revoke' audiences in the audiences map.
RevokeAudienceKey = "revoke" RevokeAudienceKey = "revoke"
@ -104,6 +106,8 @@ func (t Type) String() string {
return "AWS" return "AWS"
case TypeAzure: case TypeAzure:
return "Azure" return "Azure"
case TypeACME:
return "ACME"
default: default:
return "" return ""
} }
@ -151,6 +155,8 @@ func (l *List) UnmarshalJSON(data []byte) error {
p = &AWS{} p = &AWS{}
case "azure": case "azure":
p = &Azure{} p = &Azure{}
case "acme":
p = &ACME{}
default: default:
// Skip unsupported provisioners. A client using this method may be // Skip unsupported provisioners. A client using this method may be
// compiled with a version of smallstep/certificates that does not // compiled with a version of smallstep/certificates that does not
@ -197,3 +203,93 @@ func SanitizeSSHUserPrincipal(email string) string {
} }
}, strings.ToLower(email)) }, strings.ToLower(email))
} }
// MockProvisioner for testing
type MockProvisioner struct {
Mret1, Mret2, Mret3 interface{}
Merr error
MgetID func() string
MgetTokenID func(string) (string, error)
MgetName func() string
MgetType func() Type
MgetEncryptedKey func() (string, string, bool)
Minit func(Config) error
MauthorizeRevoke func(ott string) error
MauthorizeSign func(ctx context.Context, ott string) ([]SignOption, error)
MauthorizeRenewal func(*x509.Certificate) error
}
// GetID mock
func (m *MockProvisioner) GetID() string {
if m.MgetID != nil {
return m.MgetID()
}
return m.Mret1.(string)
}
// GetTokenID mock
func (m *MockProvisioner) GetTokenID(token string) (string, error) {
if m.MgetTokenID != nil {
return m.MgetTokenID(token)
}
if m.Mret1 == nil {
return "", m.Merr
}
return m.Mret1.(string), m.Merr
}
// GetName mock
func (m *MockProvisioner) GetName() string {
if m.MgetName != nil {
return m.MgetName()
}
return m.Mret1.(string)
}
// GetType mock
func (m *MockProvisioner) GetType() Type {
if m.MgetType != nil {
return m.MgetType()
}
return m.Mret1.(Type)
}
// GetEncryptedKey mock
func (m *MockProvisioner) GetEncryptedKey() (string, string, bool) {
if m.MgetEncryptedKey != nil {
return m.MgetEncryptedKey()
}
return m.Mret1.(string), m.Mret2.(string), m.Mret3.(bool)
}
// Init mock
func (m *MockProvisioner) Init(c Config) error {
if m.Minit != nil {
return m.Minit(c)
}
return m.Merr
}
// AuthorizeRevoke mock
func (m *MockProvisioner) AuthorizeRevoke(ott string) error {
if m.MauthorizeRevoke != nil {
return m.MauthorizeRevoke(ott)
}
return m.Merr
}
// AuthorizeSign mock
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]SignOption, error) {
if m.MauthorizeSign != nil {
return m.MauthorizeSign(ctx, ott)
}
return m.Mret1.([]SignOption), m.Merr
}
// AuthorizeRenewal mock
func (m *MockProvisioner) AuthorizeRenewal(c *x509.Certificate) error {
if m.MauthorizeRenewal != nil {
return m.MauthorizeRenewal(c)
}
return m.Merr
}

View file

@ -1,9 +1,13 @@
package provisioner package provisioner
import ( import (
"crypto/rsa"
"encoding/binary"
"math/big"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/cli/crypto/keys"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -212,7 +216,7 @@ func (m *sshCertificateValidityModifier) Modify(cert *ssh.Certificate) error {
} }
if cert.ValidAfter == 0 { if cert.ValidAfter == 0 {
cert.ValidAfter = uint64(now().Unix()) cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
} }
if cert.ValidBefore == 0 { if cert.ValidBefore == 0 {
t := time.Unix(int64(cert.ValidAfter), 0) t := time.Unix(int64(cert.ValidAfter), 0)
@ -261,9 +265,11 @@ func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
case len(cert.ValidPrincipals) == 0: case len(cert.ValidPrincipals) == 0:
return errors.New("ssh certificate valid principals cannot be empty") return errors.New("ssh certificate valid principals cannot be empty")
case cert.ValidAfter == 0: case cert.ValidAfter == 0:
return errors.New("ssh certificate valid after cannot be 0") return errors.New("ssh certificate validAfter cannot be 0")
case cert.ValidBefore == 0: case cert.ValidBefore < uint64(now().Unix()):
return errors.New("ssh certificate valid before cannot be 0") return errors.New("ssh certificate validBefore cannot be in the past")
case cert.ValidBefore < cert.ValidAfter:
return errors.New("ssh certificate validBefore cannot be before validAfter")
case cert.CertType == ssh.UserCert && len(cert.Extensions) == 0: case cert.CertType == ssh.UserCert && len(cert.Extensions) == 0:
return errors.New("ssh certificate extensions cannot be empty") return errors.New("ssh certificate extensions cannot be empty")
case cert.SignatureKey == nil: case cert.SignatureKey == nil:
@ -275,6 +281,36 @@ func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
} }
} }
// sshDefaultPublicKeyValidator implements a validator for the certificate key.
type sshDefaultPublicKeyValidator struct{}
// Valid checks that certificate request common name matches the one configured.
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error {
if cert.Key == nil {
return errors.New("ssh certificate key cannot be nil")
}
switch cert.Key.Type() {
case ssh.KeyAlgoRSA:
_, in, ok := sshParseString(cert.Key.Marshal())
if !ok {
return errors.New("ssh certificate key is invalid")
}
key, err := sshParseRSAPublicKey(in)
if err != nil {
return err
}
if key.Size() < keys.MinRSAKeyBytes {
return errors.Errorf("ssh certificate key must be at least %d bits (%d bytes)",
8*keys.MinRSAKeyBytes, keys.MinRSAKeyBytes)
}
return nil
case ssh.KeyAlgoDSA:
return errors.New("ssh certificate key algorithm (DSA) is not supported")
default:
return nil
}
}
// sshCertTypeUInt32 // sshCertTypeUInt32
func sshCertTypeUInt32(ct string) uint32 { func sshCertTypeUInt32(ct string) uint32 {
switch ct { switch ct {
@ -304,3 +340,41 @@ func containsAllMembers(group, subgroup []string) bool {
} }
return true return true
} }
func sshParseString(in []byte) (out, rest []byte, ok bool) {
if len(in) < 4 {
return
}
length := binary.BigEndian.Uint32(in)
in = in[4:]
if uint32(len(in)) < length {
return
}
out = in[:length]
rest = in[length:]
ok = true
return
}
func sshParseRSAPublicKey(in []byte) (*rsa.PublicKey, error) {
var w struct {
E *big.Int
N *big.Int
Rest []byte `ssh:"rest"`
}
if err := ssh.Unmarshal(in, &w); err != nil {
return nil, errors.Wrap(err, "error unmarshalling public key")
}
if w.E.BitLen() > 24 {
return nil, errors.New("invalid public key: exponent too large")
}
e := w.E.Int64()
if e < 3 || e&1 == 0 {
return nil, errors.New("invalid public key: incorrect exponent")
}
var key rsa.PublicKey
key.E = int(e)
key.N = w.N
return &key, nil
}

View file

@ -0,0 +1,192 @@
package provisioner
import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/cli/crypto/keys"
"golang.org/x/crypto/ssh"
)
func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
pub, _, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
sshPub, err := ssh.NewPublicKey(pub)
assert.FatalError(t, err)
v := sshCertificateDefaultValidator{}
tests := []struct {
name string
cert *ssh.Certificate
err error
}{
{
"fail/zero-nonce",
&ssh.Certificate{},
errors.New("ssh certificate nonce cannot be empty"),
},
{
"fail/nil-key",
&ssh.Certificate{Nonce: []byte("foo")},
errors.New("ssh certificate key cannot be nil"),
},
{
"fail/zero-serial",
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub},
errors.New("ssh certificate serial cannot be 0"),
},
{
"fail/unexpected-cert-type",
// UserCert = 1, HostCert = 2
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, CertType: 3, Serial: 1},
errors.New("ssh certificate has an unknown type: 3"),
},
{
"fail/empty-cert-key-id",
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1},
errors.New("ssh certificate key id cannot be empty"),
},
{
"fail/empty-valid-principals",
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo"},
errors.New("ssh certificate valid principals cannot be empty"),
},
{
"fail/zero-validAfter",
&ssh.Certificate{
Nonce: []byte("foo"),
Key: sshPub,
Serial: 1,
CertType: 1,
KeyId: "foo",
ValidPrincipals: []string{"foo"},
ValidAfter: 0,
},
errors.New("ssh certificate validAfter cannot be 0"),
},
{
"fail/validBefore-past",
&ssh.Certificate{
Nonce: []byte("foo"),
Key: sshPub,
Serial: 1,
CertType: 1,
KeyId: "foo",
ValidPrincipals: []string{"foo"},
ValidAfter: uint64(time.Now().Add(-10 * time.Minute).Unix()),
ValidBefore: uint64(time.Now().Add(-5 * time.Minute).Unix()),
},
errors.New("ssh certificate validBefore cannot be in the past"),
},
{
"fail/validAfter-after-validBefore",
&ssh.Certificate{
Nonce: []byte("foo"),
Key: sshPub,
Serial: 1,
CertType: 1,
KeyId: "foo",
ValidPrincipals: []string{"foo"},
ValidAfter: uint64(time.Now().Add(15 * time.Minute).Unix()),
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
},
errors.New("ssh certificate validBefore cannot be before validAfter"),
},
{
"fail/empty-extensions",
&ssh.Certificate{
Nonce: []byte("foo"),
Key: sshPub,
Serial: 1,
CertType: 1,
KeyId: "foo",
ValidPrincipals: []string{"foo"},
ValidAfter: uint64(time.Now().Unix()),
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
},
errors.New("ssh certificate extensions cannot be empty"),
},
{
"fail/nil-signature-key",
&ssh.Certificate{
Nonce: []byte("foo"),
Key: sshPub,
Serial: 1,
CertType: 1,
KeyId: "foo",
ValidPrincipals: []string{"foo"},
ValidAfter: uint64(time.Now().Unix()),
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
Permissions: ssh.Permissions{
Extensions: map[string]string{"foo": "bar"},
},
},
errors.New("ssh certificate signature key cannot be nil"),
},
{
"fail/nil-signature",
&ssh.Certificate{
Nonce: []byte("foo"),
Key: sshPub,
Serial: 1,
CertType: 1,
KeyId: "foo",
ValidPrincipals: []string{"foo"},
ValidAfter: uint64(time.Now().Unix()),
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
Permissions: ssh.Permissions{
Extensions: map[string]string{"foo": "bar"},
},
SignatureKey: sshPub,
},
errors.New("ssh certificate signature cannot be nil"),
},
{
"ok/userCert",
&ssh.Certificate{
Nonce: []byte("foo"),
Key: sshPub,
Serial: 1,
CertType: 1,
KeyId: "foo",
ValidPrincipals: []string{"foo"},
ValidAfter: uint64(time.Now().Unix()),
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
Permissions: ssh.Permissions{
Extensions: map[string]string{"foo": "bar"},
},
SignatureKey: sshPub,
Signature: &ssh.Signature{},
},
nil,
},
{
"ok/hostCert",
&ssh.Certificate{
Nonce: []byte("foo"),
Key: sshPub,
Serial: 1,
CertType: 2,
KeyId: "foo",
ValidPrincipals: []string{"foo"},
ValidAfter: uint64(time.Now().Unix()),
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
SignatureKey: sshPub,
Signature: &ssh.Signature{},
},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := v.Valid(tt.cert); err != nil {
if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
} else {
assert.Nil(t, tt.err)
}
})
}
}

View file

@ -248,10 +248,6 @@ func TestTimeDuration_Unix(t *testing.T) {
func TestTimeDuration_String(t *testing.T) { func TestTimeDuration_String(t *testing.T) {
tm, fn := mockNow() tm, fn := mockNow()
defer fn() defer fn()
type fields struct {
t time.Time
d time.Duration
}
tests := []struct { tests := []struct {
name string name string
timeDuration *TimeDuration timeDuration *TimeDuration

View file

@ -709,7 +709,7 @@ func generateJWKServer(n int) *httptest.Server {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case "/hits": case "/hits":
writeJSON(w, hits) writeJSON(w, hits)
case "/openid-configuration", "/.well-known/openid-configuration": case "/.well-known/openid-configuration":
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"}) writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
case "/random": case "/random":
keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet) keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
@ -730,3 +730,15 @@ func generateJWKServer(n int) *httptest.Server {
srv.Start() srv.Start()
return srv return srv
} }
func generateACME() (*ACME, error) {
// Initialize provisioners
p := &ACME{
Type: "ACME",
Name: "test@acme-provisioner.com",
}
if err := p.Init(Config{Claims: globalProvisionerClaims}); err != nil {
return nil, err
}
return p, nil
}

View file

@ -35,3 +35,13 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
} }
return p, nil return p, nil
} }
// LoadProvisionerByID returns an interface to the provisioner with the given ID.
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
p, ok := a.provisioners.Load(id)
if !ok {
return nil, &apiError{errors.Errorf("provisioner not found"),
http.StatusNotFound, apiCtx{}}
}
return p, nil
}

8
authority/testdata/certs/badsig.csr vendored Normal file
View file

@ -0,0 +1,8 @@
-----BEGIN CERTIFICATE REQUEST-----
MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI
zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI
cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ
DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ
ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5
OI+cWOIc/IGwqZul/zEF5dani5ihOL7UwA==
-----END CERTIFICATE REQUEST-----

8
authority/testdata/certs/foo.csr vendored Normal file
View file

@ -0,0 +1,8 @@
-----BEGIN CERTIFICATE REQUEST-----
MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI
zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI
cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ
DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ
ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5
OI+cWOIc/IGwqZul/zEF5dani5ihOR7UwA==
-----END CERTIFICATE REQUEST-----

View file

@ -212,7 +212,6 @@ type RevokeOptions struct {
MTLS bool MTLS bool
Crt *x509.Certificate Crt *x509.Certificate
OTT string OTT string
errCtxt map[string]interface{}
} }
// Revoke revokes a certificate. // Revoke revokes a certificate.

354
ca/acmeClient.go Normal file
View file

@ -0,0 +1,354 @@
package ca
import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api"
"github.com/smallstep/cli/jose"
)
// ACMEClient implements an HTTP client to an ACME API.
type ACMEClient struct {
client *http.Client
dirLoc string
dir *acme.Directory
acc *acme.Account
Key *jose.JSONWebKey
kid string
}
// NewACMEClient initializes a new ACMEClient.
func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*ACMEClient, error) {
// Retrieve transport from options.
o := new(clientOptions)
if err := o.apply(opts); err != nil {
return nil, err
}
tr, err := o.getTransport(endpoint)
if err != nil {
return nil, err
}
ac := &ACMEClient{
client: &http.Client{
Transport: tr,
},
dirLoc: endpoint,
}
resp, err := ac.client.Get(endpoint)
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", endpoint)
}
if resp.StatusCode >= 400 {
return nil, readACMEError(resp.Body)
}
var dir acme.Directory
if err := readJSON(resp.Body, &dir); err != nil {
return nil, errors.Wrapf(err, "error reading %s", endpoint)
}
ac.dir = &dir
ac.Key, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
if err != nil {
return nil, err
}
nar := &acmeAPI.NewAccountRequest{
Contact: contact,
TermsOfServiceAgreed: true,
}
payload, err := json.Marshal(nar)
if err != nil {
return nil, errors.Wrap(err, "error marshaling new account request")
}
resp, err = ac.post(payload, ac.dir.NewAccount, withJWK(ac))
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
return nil, readACMEError(resp.Body)
}
var acc acme.Account
if err := readJSON(resp.Body, &acc); err != nil {
return nil, errors.Wrapf(err, "error reading %s", dir.NewAccount)
}
ac.acc = &acc
ac.kid = resp.Header.Get("Location")
return ac, nil
}
// GetDirectory makes a directory request to the ACME api and returns an
// ACME directory object.
func (c *ACMEClient) GetDirectory() (*acme.Directory, error) {
return c.dir, nil
}
// GetNonce makes a nonce request to the ACME api and returns an
// ACME directory object.
func (c *ACMEClient) GetNonce() (string, error) {
resp, err := c.client.Get(c.dir.NewNonce)
if err != nil {
return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce)
}
if resp.StatusCode >= 400 {
return "", readACMEError(resp.Body)
}
return resp.Header.Get("Replay-Nonce"), nil
}
type withHeaderOption func(so *jose.SignerOptions)
func withJWK(c *ACMEClient) withHeaderOption {
return func(so *jose.SignerOptions) {
so.WithHeader("jwk", c.Key.Public())
}
}
func withKid(c *ACMEClient) withHeaderOption {
return func(so *jose.SignerOptions) {
so.WithHeader("kid", c.kid)
}
}
// serialize serializes a json web signature and doesn't omit empty fields.
func serialize(obj *jose.JSONWebSignature) (string, error) {
raw, err := obj.CompactSerialize()
if err != nil {
return "", errors.Wrap(err, "error serializing JWS")
}
parts := strings.Split(raw, ".")
msg := struct {
Protected string `json:"protected"`
Payload string `json:"payload"`
Signature string `json:"signature"`
}{Protected: parts[0], Payload: parts[1], Signature: parts[2]}
b, err := json.Marshal(msg)
if err != nil {
return "", errors.Wrap(err, "error marshaling jws message")
}
return string(b), nil
}
func (c *ACMEClient) post(payload []byte, url string, headerOps ...withHeaderOption) (*http.Response, error) {
if c.Key == nil {
return nil, errors.New("acme client not configured with account")
}
nonce, err := c.GetNonce()
if err != nil {
return nil, err
}
so := new(jose.SignerOptions)
so.WithHeader("nonce", nonce)
so.WithHeader("url", url)
for _, hop := range headerOps {
hop(so)
}
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(c.Key.Algorithm),
Key: c.Key.Key,
}, so)
if err != nil {
return nil, errors.Wrap(err, "error creating JWS signer")
}
signed, err := signer.Sign(payload)
if err != nil {
return nil, errors.Errorf("error signing payload: %s", strings.TrimPrefix(err.Error(), "square/go-jose: "))
}
raw, err := serialize(signed)
if err != nil {
return nil, err
}
resp, err := c.client.Post(url, "application/jose+json", strings.NewReader(raw))
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", c.dir.NewOrder)
}
return resp, nil
}
// NewOrder creates and returns the information for a new ACME order.
func (c *ACMEClient) NewOrder(payload []byte) (*acme.Order, error) {
resp, err := c.post(payload, c.dir.NewOrder, withKid(c))
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
return nil, readACMEError(resp.Body)
}
var o acme.Order
if err := readJSON(resp.Body, &o); err != nil {
return nil, errors.Wrapf(err, "error reading %s", c.dir.NewOrder)
}
o.ID = resp.Header.Get("Location")
return &o, nil
}
// GetChallenge returns the Challenge at the given path.
// With the validate parameter set to True this method will attempt to validate the
// challenge before returning it.
func (c *ACMEClient) GetChallenge(url string) (*acme.Challenge, error) {
resp, err := c.post(nil, url, withKid(c))
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
return nil, readACMEError(resp.Body)
}
var ch acme.Challenge
if err := readJSON(resp.Body, &ch); err != nil {
return nil, errors.Wrapf(err, "error reading %s", url)
}
return &ch, nil
}
// ValidateChallenge returns the Challenge at the given path.
// With the validate parameter set to True this method will attempt to validate the
// challenge before returning it.
func (c *ACMEClient) ValidateChallenge(url string) error {
resp, err := c.post([]byte("{}"), url, withKid(c))
if err != nil {
return err
}
if resp.StatusCode >= 400 {
return readACMEError(resp.Body)
}
return nil
}
// GetAuthz returns the Authz at the given path.
func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) {
resp, err := c.post(nil, url, withKid(c))
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
return nil, readACMEError(resp.Body)
}
var az acme.Authz
if err := readJSON(resp.Body, &az); err != nil {
return nil, errors.Wrapf(err, "error reading %s", url)
}
return &az, nil
}
// GetOrder returns the Order at the given path.
func (c *ACMEClient) GetOrder(url string) (*acme.Order, error) {
resp, err := c.post(nil, url, withKid(c))
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
return nil, readACMEError(resp.Body)
}
var o acme.Order
if err := readJSON(resp.Body, &o); err != nil {
return nil, errors.Wrapf(err, "error reading %s", url)
}
return &o, nil
}
// FinalizeOrder makes a finalize request to the ACME api.
func (c *ACMEClient) FinalizeOrder(url string, csr *x509.CertificateRequest) error {
payload, err := json.Marshal(acmeAPI.FinalizeRequest{
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
})
if err != nil {
return errors.Wrap(err, "error marshaling finalize request")
}
resp, err := c.post(payload, url, withKid(c))
if err != nil {
return err
}
if resp.StatusCode >= 400 {
return readACMEError(resp.Body)
}
return nil
}
// GetCertificate retrieves the certificate along with all intermediates.
func (c *ACMEClient) GetCertificate(url string) (*x509.Certificate, []*x509.Certificate, error) {
resp, err := c.post(nil, url, withKid(c))
if err != nil {
return nil, nil, err
}
if resp.StatusCode >= 400 {
return nil, nil, readACMEError(resp.Body)
}
defer resp.Body.Close()
bodyBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, nil, errors.Wrap(err, "error reading GET certificate response")
}
var certs []*x509.Certificate
block, rest := pem.Decode(bodyBytes)
if block == nil {
return nil, nil, errors.New("failed to parse any certificates from response")
}
for block != nil {
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, nil, errors.Wrap(err, "error parsing certificate pem response")
}
certs = append(certs, cert)
block, rest = pem.Decode(rest)
}
return certs[0], certs[1:], nil
}
// GetAccountOrders retrieves the orders belonging to the given account.
func (c *ACMEClient) GetAccountOrders() ([]string, error) {
if c.acc == nil {
return nil, errors.New("acme client not configured with account")
}
resp, err := c.post(nil, c.acc.Orders, withKid(c))
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
return nil, readACMEError(resp.Body)
}
var orders []string
if err := readJSON(resp.Body, &orders); err != nil {
return nil, errors.Wrapf(err, "error reading %s", c.acc.Orders)
}
return orders, nil
}
func readACMEError(r io.ReadCloser) error {
defer r.Close()
b, err := ioutil.ReadAll(r)
if err != nil {
return errors.Wrap(err, "error reading from body")
}
ae := new(acme.AError)
err = json.Unmarshal(b, &ae)
// If we successfully marshaled to an ACMEError then return the ACMEError.
if err != nil || len(ae.Error()) == 0 {
fmt.Printf("b = %s\n", b)
// Throw up our hands.
return errors.Errorf("%s", b)
}
return ae
}

1355
ca/acmeClient_test.go Normal file

File diff suppressed because it is too large Load diff

View file

@ -570,8 +570,8 @@ func TestBootstrapListener(t *testing.T) {
return return
} }
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
go func() {
wg.Add(1) wg.Add(1)
go func() {
http.Serve(lis, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Serve(lis, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok")) w.Write([]byte("ok"))
})) }))

View file

@ -3,18 +3,23 @@ package ca
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt"
"log" "log"
"net/http" "net/http"
"net/url"
"reflect" "reflect"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/monitoring" "github.com/smallstep/certificates/monitoring"
"github.com/smallstep/certificates/server" "github.com/smallstep/certificates/server"
"github.com/smallstep/nosql"
) )
type options struct { type options struct {
@ -100,13 +105,47 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) {
mux := chi.NewRouter() mux := chi.NewRouter()
handler := http.Handler(mux) handler := http.Handler(mux)
// Add api endpoints in / and /1.0 // Add regular CA api endpoints in / and /1.0
routerHandler := api.New(auth) routerHandler := api.New(auth)
routerHandler.Route(mux) routerHandler.Route(mux)
mux.Route("/1.0", func(r chi.Router) { mux.Route("/1.0", func(r chi.Router) {
routerHandler.Route(r) routerHandler.Route(r)
}) })
//Add ACME api endpoints in /acme and /1.0/acme
dns := config.DNSNames[0]
u, err := url.Parse("https://" + config.Address)
if err != nil {
return nil, err
}
port := u.Port()
if port != "" && port != "443" {
dns = fmt.Sprintf("%s:%s", dns, port)
}
prefix := "acme"
acmeAuth := acme.NewAuthority(auth.GetDatabase().(nosql.DB), dns, prefix, auth)
acmeRouterHandler := acmeAPI.New(acmeAuth)
mux.Route("/"+prefix, func(r chi.Router) {
acmeRouterHandler.Route(r)
})
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
// of the ACME spec.
mux.Route("/2.0/"+prefix, func(r chi.Router) {
acmeRouterHandler.Route(r)
})
/*
// helpful routine for logging all routes //
walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
fmt.Printf("%s %s\n", method, route)
return nil
}
if err := chi.Walk(mux, walkFunc); err != nil {
fmt.Printf("Logging err: %s\n", err.Error())
}
*/
// Add monitoring if configured // Add monitoring if configured
if len(config.Monitoring) > 0 { if len(config.Monitoring) > 0 {
m, err := monitoring.New(config.Monitoring) m, err := monitoring.New(config.Monitoring)

View file

@ -163,8 +163,7 @@ func TestClient_Health(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode) api.JSONStatus(w, tt.response, tt.responseCode)
api.JSON(w, tt.response)
}) })
got, err := c.Health() got, err := c.Health()
@ -224,8 +223,7 @@ func TestClient_Root(t *testing.T) {
if req.RequestURI != expected { if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
} }
w.WriteHeader(tt.responseCode) api.JSONStatus(w, tt.response, tt.responseCode)
api.JSON(w, tt.response)
}) })
got, err := c.Root(tt.shasum) got, err := c.Root(tt.shasum)
@ -303,8 +301,7 @@ func TestClient_Sign(t *testing.T) {
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
} }
} }
w.WriteHeader(tt.responseCode) api.JSONStatus(w, tt.response, tt.responseCode)
api.JSON(w, tt.response)
}) })
got, err := c.Sign(tt.request) got, err := c.Sign(tt.request)
@ -378,8 +375,7 @@ func TestClient_Revoke(t *testing.T) {
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
} }
} }
w.WriteHeader(tt.responseCode) api.JSONStatus(w, tt.response, tt.responseCode)
api.JSON(w, tt.response)
}) })
got, err := c.Revoke(tt.request, nil) got, err := c.Revoke(tt.request, nil)
@ -438,8 +434,7 @@ func TestClient_Renew(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode) api.JSONStatus(w, tt.response, tt.responseCode)
api.JSON(w, tt.response)
}) })
got, err := c.Renew(nil) got, err := c.Renew(nil)
@ -502,8 +497,7 @@ func TestClient_Provisioners(t *testing.T) {
if req.RequestURI != tt.expectedURI { if req.RequestURI != tt.expectedURI {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
} }
w.WriteHeader(tt.responseCode) api.JSONStatus(w, tt.response, tt.responseCode)
api.JSON(w, tt.response)
}) })
got, err := c.Provisioners(tt.args...) got, err := c.Provisioners(tt.args...)
@ -562,8 +556,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
if req.RequestURI != expected { if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
} }
w.WriteHeader(tt.responseCode) api.JSONStatus(w, tt.response, tt.responseCode)
api.JSON(w, tt.response)
}) })
got, err := c.ProvisionerKey(tt.kid) got, err := c.ProvisionerKey(tt.kid)
@ -622,8 +615,7 @@ func TestClient_Roots(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode) api.JSONStatus(w, tt.response, tt.responseCode)
api.JSON(w, tt.response)
}) })
got, err := c.Roots() got, err := c.Roots()
@ -683,8 +675,7 @@ func TestClient_Federation(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode) api.JSONStatus(w, tt.response, tt.responseCode)
api.JSON(w, tt.response)
}) })
got, err := c.Federation() got, err := c.Federation()
@ -783,8 +774,7 @@ func TestClient_RootFingerprint(t *testing.T) {
} }
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode) api.JSONStatus(w, tt.response, tt.responseCode)
api.JSON(w, tt.response)
}) })
got, err := c.RootFingerprint() got, err := c.RootFingerprint()

131
db/db.go
View file

@ -7,6 +7,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
) )
var ( var (
@ -54,7 +55,7 @@ func New(c *Config) (AuthDB, error) {
return nil, errors.Wrapf(err, "Error opening database of Type %s with source %s", c.Type, c.DataSource) return nil, errors.Wrapf(err, "Error opening database of Type %s with source %s", c.Type, c.DataSource)
} }
tables := [][]byte{revokedCertsTable, certsTable} tables := [][]byte{revokedCertsTable, certsTable, usedOTTTable}
for _, b := range tables { for _, b := range tables {
if err := db.CreateTable(b); err != nil { if err := db.CreateTable(b); err != nil {
return nil, errors.Wrapf(err, "error creating table %s", return nil, errors.Wrapf(err, "error creating table %s",
@ -102,22 +103,20 @@ func (db *DB) IsRevoked(sn string) (bool, error) {
// Revoke adds a certificate to the revocation table. // Revoke adds a certificate to the revocation table.
func (db *DB) Revoke(rci *RevokedCertificateInfo) error { func (db *DB) Revoke(rci *RevokedCertificateInfo) error {
isRvkd, err := db.IsRevoked(rci.Serial)
if err != nil {
return err
}
if isRvkd {
return ErrAlreadyExists
}
rcib, err := json.Marshal(rci) rcib, err := json.Marshal(rci)
if err != nil { if err != nil {
return errors.Wrap(err, "error marshaling revoked certificate info") return errors.Wrap(err, "error marshaling revoked certificate info")
} }
if err = db.Set(revokedCertsTable, []byte(rci.Serial), rcib); err != nil { _, swapped, err := db.CmpAndSwap(revokedCertsTable, []byte(rci.Serial), nil, rcib)
return errors.Wrap(err, "database Set error") switch {
} case err != nil:
return errors.Wrap(err, "error AuthDB CmpAndSwap")
case !swapped:
return ErrAlreadyExists
default:
return nil return nil
}
} }
// StoreCertificate stores a certificate PEM. // StoreCertificate stores a certificate PEM.
@ -132,15 +131,11 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
// for the first time, false otherwise. // for the first time, false otherwise.
func (db *DB) UseToken(id, tok string) (bool, error) { func (db *DB) UseToken(id, tok string) (bool, error) {
_, swapped, err := db.CmpAndSwap(usedOTTTable, []byte(id), nil, []byte(tok)) _, swapped, err := db.CmpAndSwap(usedOTTTable, []byte(id), nil, []byte(tok))
switch { if err != nil {
case err != nil:
return false, errors.Wrapf(err, "error storing used token %s/%s", return false, errors.Wrapf(err, "error storing used token %s/%s",
string(usedOTTTable), id) string(usedOTTTable), id)
case !swapped:
return false, nil
default:
return true, nil
} }
return swapped, nil
} }
// Shutdown sends a shutdown message to the database. // Shutdown sends a shutdown message to the database.
@ -153,3 +148,105 @@ func (db *DB) Shutdown() error {
} }
return nil return nil
} }
// MockNoSQLDB //
type MockNoSQLDB struct {
Err error
Ret1, Ret2 interface{}
MGet func(bucket, key []byte) ([]byte, error)
MSet func(bucket, key, value []byte) error
MOpen func(dataSourceName string, opt ...database.Option) error
MClose func() error
MCreateTable func(bucket []byte) error
MDeleteTable func(bucket []byte) error
MDel func(bucket, key []byte) error
MList func(bucket []byte) ([]*database.Entry, error)
MUpdate func(tx *database.Tx) error
MCmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error)
}
// CmpAndSwap mock
func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) {
if m.MCmpAndSwap != nil {
return m.MCmpAndSwap(bucket, key, old, newval)
}
if m.Ret1 == nil {
return nil, false, m.Err
}
return m.Ret1.([]byte), m.Ret2.(bool), m.Err
}
// Get mock
func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) {
if m.MGet != nil {
return m.MGet(bucket, key)
}
if m.Ret1 == nil {
return nil, m.Err
}
return m.Ret1.([]byte), m.Err
}
// Set mock
func (m *MockNoSQLDB) Set(bucket, key, value []byte) error {
if m.MSet != nil {
return m.MSet(bucket, key, value)
}
return m.Err
}
// Open mock
func (m *MockNoSQLDB) Open(dataSourceName string, opt ...database.Option) error {
if m.MOpen != nil {
return m.MOpen(dataSourceName, opt...)
}
return m.Err
}
// Close mock
func (m *MockNoSQLDB) Close() error {
if m.MClose != nil {
return m.MClose()
}
return m.Err
}
// CreateTable mock
func (m *MockNoSQLDB) CreateTable(bucket []byte) error {
if m.MCreateTable != nil {
return m.MCreateTable(bucket)
}
return m.Err
}
// DeleteTable mock
func (m *MockNoSQLDB) DeleteTable(bucket []byte) error {
if m.MDeleteTable != nil {
return m.MDeleteTable(bucket)
}
return m.Err
}
// Del mock
func (m *MockNoSQLDB) Del(bucket, key []byte) error {
if m.MDel != nil {
return m.MDel(bucket, key)
}
return m.Err
}
// List mock
func (m *MockNoSQLDB) List(bucket []byte) ([]*database.Entry, error) {
if m.MList != nil {
return m.MList(bucket)
}
return m.Ret1.([]*database.Entry), m.Err
}
// Update mock
func (m *MockNoSQLDB) Update(tx *database.Tx) error {
if m.MUpdate != nil {
return m.MUpdate(tx)
}
return m.Err
}

View file

@ -8,97 +8,6 @@ import (
"github.com/smallstep/nosql/database" "github.com/smallstep/nosql/database"
) )
type MockNoSQLDB struct {
err error
ret1, ret2 interface{}
get func(bucket, key []byte) ([]byte, error)
set func(bucket, key, value []byte) error
open func(dataSourceName string, opt ...database.Option) error
close func() error
createTable func(bucket []byte) error
deleteTable func(bucket []byte) error
del func(bucket, key []byte) error
list func(bucket []byte) ([]*database.Entry, error)
update func(tx *database.Tx) error
cmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error)
}
func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) {
if m.cmpAndSwap != nil {
return m.cmpAndSwap(bucket, key, old, newval)
}
if m.ret1 == nil {
return nil, false, m.err
}
return m.ret1.([]byte), m.ret2.(bool), m.err
}
func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) {
if m.get != nil {
return m.get(bucket, key)
}
if m.ret1 == nil {
return nil, m.err
}
return m.ret1.([]byte), m.err
}
func (m *MockNoSQLDB) Set(bucket, key, value []byte) error {
if m.set != nil {
return m.set(bucket, key, value)
}
return m.err
}
func (m *MockNoSQLDB) Open(dataSourceName string, opt ...database.Option) error {
if m.open != nil {
return m.open(dataSourceName, opt...)
}
return m.err
}
func (m *MockNoSQLDB) Close() error {
if m.close != nil {
return m.close()
}
return m.err
}
func (m *MockNoSQLDB) CreateTable(bucket []byte) error {
if m.createTable != nil {
return m.createTable(bucket)
}
return m.err
}
func (m *MockNoSQLDB) DeleteTable(bucket []byte) error {
if m.deleteTable != nil {
return m.deleteTable(bucket)
}
return m.err
}
func (m *MockNoSQLDB) Del(bucket, key []byte) error {
if m.del != nil {
return m.del(bucket, key)
}
return m.err
}
func (m *MockNoSQLDB) List(bucket []byte) ([]*database.Entry, error) {
if m.list != nil {
return m.list(bucket)
}
return m.ret1.([]*database.Entry), m.err
}
func (m *MockNoSQLDB) Update(tx *database.Tx) error {
if m.update != nil {
return m.update(tx)
}
return m.err
}
func TestIsRevoked(t *testing.T) { func TestIsRevoked(t *testing.T) {
tests := map[string]struct { tests := map[string]struct {
key string key string
@ -111,16 +20,16 @@ func TestIsRevoked(t *testing.T) {
}, },
"false/ErrNotFound": { "false/ErrNotFound": {
key: "sn", key: "sn",
db: &DB{&MockNoSQLDB{err: database.ErrNotFound, ret1: nil}, true}, db: &DB{&MockNoSQLDB{Err: database.ErrNotFound, Ret1: nil}, true},
}, },
"error/checking bucket": { "error/checking bucket": {
key: "sn", key: "sn",
db: &DB{&MockNoSQLDB{err: errors.New("force"), ret1: nil}, true}, db: &DB{&MockNoSQLDB{Err: errors.New("force"), Ret1: nil}, true},
err: errors.New("error checking revocation bucket: force"), err: errors.New("error checking revocation bucket: force"),
}, },
"true": { "true": {
key: "sn", key: "sn",
db: &DB{&MockNoSQLDB{ret1: []byte("value")}, true}, db: &DB{&MockNoSQLDB{Ret1: []byte("value")}, true},
isRevoked: true, isRevoked: true,
}, },
} }
@ -148,41 +57,26 @@ func TestRevoke(t *testing.T) {
"error/force isRevoked": { "error/force isRevoked": {
rci: &RevokedCertificateInfo{Serial: "sn"}, rci: &RevokedCertificateInfo{Serial: "sn"},
db: &DB{&MockNoSQLDB{ db: &DB{&MockNoSQLDB{
get: func(bucket []byte, sn []byte) ([]byte, error) { MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
return nil, errors.New("force IsRevoked") return nil, false, errors.New("force")
}, },
}, true}, }, true},
err: errors.New("error checking revocation bucket: force IsRevoked"), err: errors.New("error AuthDB CmpAndSwap: force"),
}, },
"error/was already revoked": { "error/was already revoked": {
rci: &RevokedCertificateInfo{Serial: "sn"}, rci: &RevokedCertificateInfo{Serial: "sn"},
db: &DB{&MockNoSQLDB{ db: &DB{&MockNoSQLDB{
get: func(bucket []byte, sn []byte) ([]byte, error) { MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
return nil, nil return []byte("foo"), false, nil
}, },
}, true}, }, true},
err: ErrAlreadyExists, err: ErrAlreadyExists,
}, },
"error/database set": {
rci: &RevokedCertificateInfo{Serial: "sn"},
db: &DB{&MockNoSQLDB{
get: func(bucket []byte, sn []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
set: func(bucket []byte, key []byte, value []byte) error {
return errors.New("force")
},
}, true},
err: errors.New("database Set error: force"),
},
"ok": { "ok": {
rci: &RevokedCertificateInfo{Serial: "sn"}, rci: &RevokedCertificateInfo{Serial: "sn"},
db: &DB{&MockNoSQLDB{ db: &DB{&MockNoSQLDB{
get: func(bucket []byte, sn []byte) ([]byte, error) { MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
return nil, database.ErrNotFound return []byte("foo"), true, nil
},
set: func(bucket []byte, key []byte, value []byte) error {
return nil
}, },
}, true}, }, true},
}, },
@ -214,7 +108,7 @@ func TestUseToken(t *testing.T) {
id: "id", id: "id",
tok: "token", tok: "token",
db: &DB{&MockNoSQLDB{ db: &DB{&MockNoSQLDB{
cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force") return nil, false, errors.New("force")
}, },
}, true}, }, true},
@ -227,7 +121,7 @@ func TestUseToken(t *testing.T) {
id: "id", id: "id",
tok: "token", tok: "token",
db: &DB{&MockNoSQLDB{ db: &DB{&MockNoSQLDB{
cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return []byte("foo"), false, nil return []byte("foo"), false, nil
}, },
}, true}, }, true},
@ -239,7 +133,7 @@ func TestUseToken(t *testing.T) {
id: "id", id: "id",
tok: "token", tok: "token",
db: &DB{&MockNoSQLDB{ db: &DB{&MockNoSQLDB{
cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return []byte("bar"), true, nil return []byte("bar"), true, nil
}, },
}, true}, }, true},

View file

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/nosql/database"
) )
// ErrNotImplemented is an error returned when an operation is Not Implemented. // ErrNotImplemented is an error returned when an operation is Not Implemented.
@ -61,3 +62,57 @@ func (s *SimpleDB) UseToken(id, tok string) (bool, error) {
func (s *SimpleDB) Shutdown() error { func (s *SimpleDB) Shutdown() error {
return nil return nil
} }
// nosql.DB interface implementation //
// Open opens the database available with the given options.
func (s *SimpleDB) Open(dataSourceName string, opt ...database.Option) error {
return ErrNotImplemented
}
// Close closes the current database.
func (s *SimpleDB) Close() error {
return ErrNotImplemented
}
// Get returns the value stored in the given table/bucket and key.
func (s *SimpleDB) Get(bucket, key []byte) ([]byte, error) {
return nil, ErrNotImplemented
}
// Set sets the given value in the given table/bucket and key.
func (s *SimpleDB) Set(bucket, key, value []byte) error {
return ErrNotImplemented
}
// CmpAndSwap swaps the value at the given bucket and key if the current
// value is equivalent to the oldValue input. Returns 'true' if the
// swap was successful and 'false' otherwise.
func (s *SimpleDB) CmpAndSwap(bucket, key, oldValue, newValue []byte) ([]byte, bool, error) {
return nil, false, ErrNotImplemented
}
// Del deletes the data in the given table/bucket and key.
func (s *SimpleDB) Del(bucket, key []byte) error {
return ErrNotImplemented
}
// List returns a list of all the entries in a given table/bucket.
func (s *SimpleDB) List(bucket []byte) ([]*database.Entry, error) {
return nil, ErrNotImplemented
}
// Update performs a transaction with multiple read-write commands.
func (s *SimpleDB) Update(tx *database.Tx) error {
return ErrNotImplemented
}
// CreateTable creates a table or a bucket in the database.
func (s *SimpleDB) CreateTable(bucket []byte) error {
return ErrNotImplemented
}
// DeleteTable deletes a table or a bucket in the database.
func (s *SimpleDB) DeleteTable(bucket []byte) error {
return ErrNotImplemented
}

View file

@ -15,30 +15,38 @@ e.g. `v1.0.2`
* **Release Candidate**: not ready for public use, still testing. must have a * **Release Candidate**: not ready for public use, still testing. must have a
`-rc*` suffix. e.g. `v1.0.2-rc` or `v1.0.2-rc.4` `-rc*` suffix. e.g. `v1.0.2-rc` or `v1.0.2-rc.4`
1. **Update the version of step/cli** ---
1. **Release `cli` first**
``` If you plan to release [`cli`](https://github.com/smallstep/cli) as part of
$ dep ensure -update github.com/smallstep/cli this release, `cli` must be released first. The `certificates` docker container
``` depends on the `cli` container. Make certain to wait until the `cli` travis
build has completed.
2. **Commit all changes.** 2. **Update the version of step/cli**
<pre><code>
<b>$ dep ensure -update github.com/smallstep/cli</b>
</code></pre>
3. **Commit all changes.**
Make sure that the local checkout is up to date with the remote origin and Make sure that the local checkout is up to date with the remote origin and
that all local changes have been pushed. that all local changes have been pushed.
``` <pre><code>
$ git pull --rebase origin master <b>$ git pull --rebase origin master</b>
$ git push <b>$ git push</b>
``` </code></pre>
3. **Tag it!** 4. **Tag it!**
1. **Find the most recent tag.** 1. **Find the most recent tag.**
``` <pre><code>
$ git fetch --tags <b>$ git fetch --tags</b>
$ git tag <b>$ git tag</b>
``` </code></pre>
The new tag needs to be the logical successor of the most recent existing tag. The new tag needs to be the logical successor of the most recent existing tag.
See [versioning](#versioning) section for more information on version numbers. See [versioning](#versioning) section for more information on version numbers.
@ -47,14 +55,14 @@ e.g. `v1.0.2`
Is the new release a *release candidate* or a *standard release*? Is the new release a *release candidate* or a *standard release*?
1. Release Candidate 1. **Release Candidate**
If the most recent tag is a standard release, say `v1.0.2`, then the version If the most recent tag is a standard release, say `v1.0.2`, then the version
of the next release candidate should be `v1.0.3-rc.1`. If the most recent tag of the next release candidate should be `v1.0.3-rc.1`. If the most recent tag
is a release candidate, say `v1.0.2-rc.3`, then the version of the next is a release candidate, say `v1.0.2-rc.3`, then the version of the next
release candidate should be `v1.0.2-rc.4`. release candidate should be `v1.0.2-rc.4`.
2. Standard Release 2. **Standard Release**
If the most recent tag is a standard release, say `v1.0.2`, then the version If the most recent tag is a standard release, say `v1.0.2`, then the version
of the next standard release should be `v1.0.3`. If the most recent tag of the next standard release should be `v1.0.3`. If the most recent tag
@ -64,21 +72,25 @@ e.g. `v1.0.2`
3. **Create a local tag.** 3. **Create a local tag.**
``` <pre><code>
$ git tag v1.0.3 # standard release # standard release
<b>$ git tag v1.0.3</b>
...or ...or
$ git tag v1.0.3-rc.1 # release candidate # release candidate
``` <b>$ git tag v1.0.3-rc.1</b>
</code></pre>
4. **Push the new tag to the remote origin.** 4. **Push the new tag to the remote origin.**
``` <pre><code>
$ git push origin tag v1.0.3 # standard release # standard release
<b>$ git push origin tag v1.0.3</b>
...or ...or
$ git push origin tag v1.0.3-rc.1 # release candidate # release candidate
``` <b>$ git push origin tag v1.0.3-rc.1</b>
</code></pre>
4. Check the build status at 5. **Check the build status at**
[Travis-CI](https://travis-ci.com/smallstep/certificates/builds/). [Travis-CI](https://travis-ci.com/smallstep/certificates/builds/).
Travis will begin by verifying that there are no compilation or linting errors Travis will begin by verifying that there are no compilation or linting errors
@ -97,15 +109,15 @@ e.g. `v1.0.2`
> **NOTE**: if you plan to release `cli` next then you can skip this step. > **NOTE**: if you plan to release `cli` next then you can skip this step.
``` <pre><code>
$ cd archlinux <b>$ cd archlinux</b>
# Get up to date... # Get up to date...
$ git pull origin master <b>$ git pull origin master</b>
$ make <b>$ make</b>
$ ./update --ca v1.0.3 <b>$ ./update --ca v1.0.3</b>
``` </code></pre>
7. **Update the Helm packages** 7. **Update the Helm packages**
@ -125,9 +137,9 @@ e.g. `v1.0.2`
Then create the step-certificates package running: Then create the step-certificates package running:
```sh <pre><code>
$ helm package ./step-certificates <b>$ helm package ./step-certificates</b>
``` </code></pre>
A new file like `step-certificates-<version>.tgz` will be created. A new file like `step-certificates-<version>.tgz` will be created.
Now commit and push your changes (don't commit the tarball) to the master Now commit and push your changes (don't commit the tarball) to the master
@ -136,15 +148,15 @@ e.g. `v1.0.2`
Next checkout the `gh-pages` branch. `git add` the new tar-ball and update Next checkout the `gh-pages` branch. `git add` the new tar-ball and update
the index.yaml using the `helm repo index` command: the index.yaml using the `helm repo index` command:
```sh <pre><code>
$ git checkout gh-pages <b>$ git checkout gh-pages</b>
$ git add "step-certificates-<version>.tgz" <b>$ git add "step-certificates-<version>.tgz"</b>
$ helm repo index --merge index.yaml --url https://smallstep.github.io/helm-charts/ . <b>$ helm repo index --merge index.yaml --url https://smallstep.github.io/helm-charts/ .</b>
$ git commit -a -m "Add package for step-certificates <appVersion>" <b>$ git commit -a -m "Add package for step-certificates <appVersion>"</b>
$ git push origin gh-pages <b>$ git push origin gh-pages</b>
``` </code></pre>
*All Done!* ***All Done!***
## Versioning ## Versioning

160
docs/acme.md Normal file
View file

@ -0,0 +1,160 @@
# Using ACME with `step-ca `
Lets assume youve [installed
`step-ca`](https://smallstep.com/docs/getting-started/#1-installing-step-and-step-ca)
(e.g., using `brew install step`), have it running at `https://ca.internal`,
and youve [bootstrapped your ACME client
system(s)](https://smallstep.com/docs/getting-started/#bootstrapping) (or at
least [installed your root
certificate](https://smallstep.com/docs/cli/ca/root/) at
`~/.step/certs/root_ca.crt`).
## Enabling ACME
To enable ACME, simply [add an ACME provisioner](https://smallstep.com/docs/cli/ca/provisioner/add/) to your `step-ca` configuration
by running:
```
$ step ca provisioner add my-acme-provisioner --type ACME
```
> NOTE: The above command will add a new provisioner of type `ACME` and name
> `my-acme-provisioner`. The name is used to identify the provisioner
> (e.g. you cannot have two `ACME` provisioners with the same name).
Now restart or SIGHUP `step-ca` to pick up the new configuration.
Thats it.
## Configuring Clients
To configure an ACME client to connect to `step-ca` you need to:
1. Point the client at the right ACME directory URL
2. Tell the client to trust your CAs root certificate
Once certificates are issued, youll also need to ensure theyre renewed before
they expire.
### Pointing Clients at the right ACME Directory URL
Most ACME clients connect to Lets Encrypt by default. To connect to `step-ca`
you need to point the client at the right [ACME directory
URL](https://tools.ietf.org/html/rfc8555#section-7.1.1).
A single instance of `step-ca` can have multiple ACME provisioners, each with
their own ACME directory URL that looks like:
```
https://{ca-host}/acme/{provisioner-name}/directory
```
We just added an ACME provisioner named “acme”. Its ACME directory URL is:
```
https://ca.internal/acme/acme/directory
```
### Telling clients to trust your CAs root certificate
Communication between an ACME client and server [always uses
HTTPS](https://tools.ietf.org/html/rfc8555#section-6.1). By default, clients
will validate the servers HTTPS certificate using the public root certificates
in your systems [default
trust](https://smallstep.com/blog/everything-pki.html#trust-stores) store.
Thats fine when youre connecting to Lets Encrypt: its a public CA and its
root certificate is in your systems default trust store already. Your internal
root certificate isnt, so HTTPS connections from ACME clients to `step-ca` will
fail.
There are two ways to address this problem:
1. Explicitly configure your ACME client to trust `step-ca`'s root certificate, or
2. Add `step-ca`'s root certificate to your systems default trust store (e.g.,
using `[step certificate
install](https://smallstep.com/docs/cli/certificate/install/)`)
If youre using your CA for TLS in production, explicitly configuring your ACME
client to only trust your root certificate is a better option. Well
demonstrate this method with several clients below.
If youre simulating Lets Encrypt in pre-production, installing your root
certificate is a more faithful simulation of production. Once your root
certificate is installed, no additional client configuration is necessary.
> Caution: adding a root certificate to your systems trust store is a global
> operation. Certificates issued by your CA will be trusted everywhere,
> including in web browsers.
### Example using [`certbot`](https://certbot.eff.org/)
[`certbot`](https://certbot.eff.org/) is the grandaddy of ACME clients. Built
and supported by [the EFF](https://www.eff.org/), its the standard-bearer for
production-grade command-line ACME.
To get a certificate from `step-ca` using `certbot` you need to:
1. Point `certbot` at your ACME directory URL using the `--`server flag.
2. Tell `certbot` to trust your root certificate using the `REQUESTS_CA_BUNDLE` environment variable.
For example:
```
$ sudo REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt \
certbot certonly -n --standalone -d foo.internal \
--server https://ca.internal/acme/acme/directory
```
`sudo` is required in `certbot`'s [*standalone*
mode](https://certbot.eff.org/docs/using.html#standalone) so it can listen on
port 80 to complete the `http-01` challenge. If you already have a webserver
running you can use [*webroot*
mode](https://certbot.eff.org/docs/using.html#webroot) instead. With the
[appropriate plugin](https://certbot.eff.org/docs/using.html#dns-plugins)
`certbot` also supports the `dns-01` challenge for most popular DNS providers.
Deeper integrations with [nginx](https://certbot.eff.org/docs/using.html#nginx)
and [apache](https://certbot.eff.org/docs/using.html#apache) can even configure
your server to use HTTPS automatically (we'll set this up ourselves later). All
of this works with `step-ca`.
You can renew all of the certificates you've installed using `cerbot` by running:
```
$ sudo REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt certbot renew
```
You can automate renewal with a simple `cron` entry:
```
*/15 * * * * root REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt certbot -q renew
```
The `certbot` packages for some Linux distributions will create a `cron` entry
or [systemd
timer](https://stevenwestmoreland.com/2017/11/renewing-certbot-certificates-using-a-systemd-timer.html)
like this for you. This entry won't work with `step-ca` because it [doesn't set
the `REQUESTS_CA_BUNDLE` environment
variable](https://github.com/certbot/certbot/issues/7170). You'll need to
manually tweak it to do so.
More subtly, `certbot`'s default renewal job is tuned for Let's Encrypt's 90
day certificate lifetimes: it's run every 12 hours, with actual renewals
occurring for certificates within 30 days of expiry. By default, `step-ca`
issues certificates with *much shorter* 24 hour lifetimes. The `cron` entry
above accounts for this by running `certbot renew` every 15 minutes. You'll
also want to configure your domain to only renew certificates when they're
within a few hours of expiry by adding a line like:
```
renew_before_expiry = 8 hours
```
to the top of your renewal configuration (e.g., in `/etc/letsencrypt/renewal/foo.internal.conf`).
## Feedback
`step-ca` should work with any ACMEv2
([RFC8555](https://tools.ietf.org/html/rfc8555)) compliant client that supports
the http-01 or dns-01 challenge. If you run into any issues please let us know
[on gitter](https://gitter.im/smallstep/community) or [in an
issue](https://github.com/smallstep/certificates/issues/new?template=bug_report.md).

View file

@ -111,6 +111,7 @@ is G-Suite.
"configurationEndpoint": "https://accounts.google.com/.well-known/openid-configuration", "configurationEndpoint": "https://accounts.google.com/.well-known/openid-configuration",
"admins": ["you@smallstep.com"], "admins": ["you@smallstep.com"],
"domains": ["smallstep.com"], "domains": ["smallstep.com"],
"listenAddress": ":10000",
"claims": { "claims": {
"maxTLSCertDuration": "8h", "maxTLSCertDuration": "8h",
"defaultTLSCertDuration": "2h", "defaultTLSCertDuration": "2h",
@ -141,6 +142,12 @@ is G-Suite.
* `domains` (optional): is the list of domains valid. If provided only the * `domains` (optional): is the list of domains valid. If provided only the
emails with the provided domains will be able to authenticate. emails with the provided domains will be able to authenticate.
* `listenAddress` (optional): is the loopback address (`:port` or `host:port`)
where the authorization server will redirect to complete the authorization
flow. If it's not defined `step` will use `127.0.0.1` with a random port. This
configuration is only required if the authorization server doesn't allow any
port to be specified at the time of the request for loopback IP redirect URIs.
* `claims` (optional): overwrites the default claims set in the authority, see * `claims` (optional): overwrites the default claims set in the authority, see
the [JWK](#jwk) section for all the options. the [JWK](#jwk) section for all the options.

View file

@ -30,9 +30,6 @@ func NewResponseLogger(w http.ResponseWriter) ResponseLogger {
func wrapLogger(w http.ResponseWriter) (rw ResponseLogger) { func wrapLogger(w http.ResponseWriter) (rw ResponseLogger) {
rw = &rwDefault{w, 200, 0, nil} rw = &rwDefault{w, 200, 0, nil}
if c, ok := w.(http.CloseNotifier); ok {
rw = &rwCloseNotifier{rw, c}
}
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
rw = &rwFlusher{rw, f} rw = &rwFlusher{rw, f}
} }
@ -88,15 +85,6 @@ func (r *rwDefault) WithFields(fields map[string]interface{}) {
} }
} }
type rwCloseNotifier struct {
ResponseLogger
c http.CloseNotifier
}
func (r *rwCloseNotifier) CloseNotify() <-chan bool {
return r.CloseNotify()
}
type rwFlusher struct { type rwFlusher struct {
ResponseLogger ResponseLogger
f http.Flusher f http.Flusher