forked from TrueCloudLab/certificates
Merge branch 'master' into onboarding
This commit is contained in:
commit
be07334164
75 changed files with 16168 additions and 332 deletions
|
@ -49,6 +49,8 @@ linters:
|
|||
- misspell
|
||||
- ineffassign
|
||||
- deadcode
|
||||
- staticcheck
|
||||
- unused
|
||||
|
||||
run:
|
||||
skip-dirs:
|
||||
|
@ -63,6 +65,6 @@ issues:
|
|||
# golangci.com configuration
|
||||
# https://github.com/golangci/golangci/wiki/Configuration
|
||||
service:
|
||||
golangci-lint-version: 1.17.x # use the fixed version to not introduce new linters unexpectedly
|
||||
golangci-lint-version: 1.18.x # use the fixed version to not introduce new linters unexpectedly
|
||||
prepare:
|
||||
- echo "here I can run custom commands, but no preparation needed for this repo"
|
||||
|
|
4
Gopkg.lock
generated
4
Gopkg.lock
generated
|
@ -233,7 +233,7 @@
|
|||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:9c1b7052fa8f2c918efd60ed5ae3c70ccbba08967c58ec71067535449a3ba220"
|
||||
digest = "1:7d03323edb817ca94efaee5489cde6acd06ceeaca9e6eee106d2d6a90deca997"
|
||||
name = "github.com/smallstep/nosql"
|
||||
packages = [
|
||||
".",
|
||||
|
@ -243,7 +243,7 @@
|
|||
"mysql",
|
||||
]
|
||||
pruneopts = "UT"
|
||||
revision = "a0934e12468769d8cbede3ed316c47a4b88de4ca"
|
||||
revision = "f80b3f432de0662f07ebd58fe52b0a119fe5dcd9"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
|
|
16
Makefile
16
Makefile
|
@ -8,9 +8,6 @@ SRC=$(shell find . -type f -name '*.go' -not -path "./vendor/*")
|
|||
GOOS_OVERRIDE ?=
|
||||
OUTPUT_ROOT=output/
|
||||
|
||||
# Set shell to bash for `echo -e`
|
||||
SHELL := /bin/bash
|
||||
|
||||
all: build test lint
|
||||
|
||||
.PHONY: all
|
||||
|
@ -97,16 +94,7 @@ generate:
|
|||
test:
|
||||
$Q $(GOFLAGS) go test -short -coverprofile=coverage.out ./...
|
||||
|
||||
vtest:
|
||||
$(Q)for d in $$(go list ./... | grep -v vendor); do \
|
||||
echo -e "TESTS FOR: for \033[0;35m$$d\033[0m"; \
|
||||
$(GOFLAGS) go test -v -bench=. -run=. -short -coverprofile=vcoverage.out $$d; \
|
||||
out=$$?; \
|
||||
if [[ $$out -ne 0 ]]; then ret=$$out; fi;\
|
||||
rm -f profile.coverage.out; \
|
||||
done; exit $$ret;
|
||||
|
||||
.PHONY: test vtest
|
||||
.PHONY: test
|
||||
|
||||
integrate: integration
|
||||
|
||||
|
@ -125,7 +113,7 @@ fmt:
|
|||
lint:
|
||||
$Q LOG_LEVEL=error golangci-lint run
|
||||
|
||||
.PHONY: $(LINTERS) lint fmt
|
||||
.PHONY: lint fmt
|
||||
|
||||
#########################################
|
||||
# Install
|
||||
|
|
16
README.md
16
README.md
|
@ -32,7 +32,7 @@ It's super easy to get started and to operate `step-ca` thanks to [streamlined i
|
|||
### A private certificate authority you run yourself
|
||||
|
||||
- Issue client and server certificates to VMs, containers, devices, and people using internal hostnames and emails
|
||||
- [RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliant certificates that work **for TLS and HTTPS** (SSH coming soon!)
|
||||
- [RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliant certificates that work **for TLS and HTTPS**
|
||||
- Choose key types (RSA, ECDSA, EdDSA) & lifetimes to suit your needs
|
||||
- [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with **fully automated** enrollment, renewal, and revocation
|
||||
- Fast, stable, and capable of high availability deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries
|
||||
|
@ -46,7 +46,19 @@ It's super easy to get started and to operate `step-ca` thanks to [streamlined i
|
|||
- [Instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/) for VMs on AWS, GCP, and Azure
|
||||
- [Single-use short-lived tokens](https://smallstep.com/docs/design-doc.html#jwk-provisioner) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
|
||||
- Use an existing certificate from another CA (e.g., using a device certificate like [Twilio's Trust OnBoard](https://www.twilio.com/wireless/trust-onboard)) *coming soon*
|
||||
- ACMEv2 (RFC8555) support so you can **run your own private ACME server** *[coming soon](https://github.com/smallstep/certificates/tree/acme)*
|
||||
|
||||
### [Your own private ACME Server](https://smallstep.com/blog/private-acme-server/)
|
||||
- Issue certificates using ACMEv2 ([RFC8555](https://tools.ietf.org/html/rfc8555)), **the protocol used by Let's Encrypt**
|
||||
- Great for [using ACME in development & pre-production](https://smallstep.com/blog/private-acme-server/#local-development-pre-production)
|
||||
- Supports the `http-01` and `dns-01` ACME challenge types
|
||||
- Works with any compliant ACME client including [certbot](https://smallstep.com/blog/private-acme-server/#certbot-uploads-acme-certbot-png-certbot-example), [acme.sh](https://smallstep.com/blog/private-acme-server/#acme-sh-uploads-acme-acme-sh-png-acme-sh-example), [Caddy](https://smallstep.com/blog/private-acme-server/#caddy-uploads-acme-caddy-png-caddy-example), and [traefik](https://smallstep.com/blog/private-acme-server/#traefik-uploads-acme-traefik-png-traefik-example)
|
||||
- Get certificates programmatically (e.g., in [Go](https://smallstep.com/blog/private-acme-server/#golang-uploads-acme-golang-png-go-example), [Python](https://smallstep.com/blog/private-acme-server/#python-uploads-acme-python-png-python-example), [Node.js](https://smallstep.com/blog/private-acme-server/#node-js-uploads-acme-node-js-png-node-js-example))
|
||||
|
||||
### [SSH Certificates](https://smallstep.com/blog/use-ssh-certificates/)
|
||||
|
||||
- Use [certificate authentication for SSH](https://smallstep.com/blog/use-ssh-certificates/): connect SSH to SSO, improve security, and eliminate warnings & errors
|
||||
- Issue SSH user certificates using OAuth OIDC
|
||||
- Issue SSH host certificates to cloud VMs using instance identity documents
|
||||
|
||||
### Easy certificate management and automation via [`step` CLI](https://github.com/smallstep/cli) [integration](https://smallstep.com/docs/cli/ca/)
|
||||
|
||||
|
|
214
acme/account.go
Normal file
214
acme/account.go
Normal file
|
@ -0,0 +1,214 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Account is a subset of the internal account type containing only those
|
||||
// attributes required for responses in the ACME protocol.
|
||||
type Account struct {
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Orders string `json:"orders"`
|
||||
ID string `json:"-"`
|
||||
Key *jose.JSONWebKey `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (a *Account) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the account ID.
|
||||
func (a *Account) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
// GetKey returns the JWK associated with the account.
|
||||
func (a *Account) GetKey() *jose.JSONWebKey {
|
||||
return a.Key
|
||||
}
|
||||
|
||||
// IsValid returns true if the Account is valid.
|
||||
func (a *Account) IsValid() bool {
|
||||
return a.Status == StatusValid
|
||||
}
|
||||
|
||||
// AccountOptions are the options needed to create a new ACME account.
|
||||
type AccountOptions struct {
|
||||
Key *jose.JSONWebKey
|
||||
Contact []string
|
||||
}
|
||||
|
||||
// account represents an ACME account.
|
||||
type account struct {
|
||||
ID string `json:"id"`
|
||||
Created time.Time `json:"created"`
|
||||
Deactivated time.Time `json:"deactivated"`
|
||||
Key *jose.JSONWebKey `json:"key"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// newAccount returns a new acme account type.
|
||||
func newAccount(db nosql.DB, ops AccountOptions) (*account, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a := &account{
|
||||
ID: id,
|
||||
Key: ops.Key,
|
||||
Contact: ops.Contact,
|
||||
Status: "valid",
|
||||
Created: clock.Now(),
|
||||
}
|
||||
return a, a.saveNew(db)
|
||||
}
|
||||
|
||||
// toACME converts the internal Account type into the public acmeAccount
|
||||
// type for presentation in the ACME protocol.
|
||||
func (a *account) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Account, error) {
|
||||
return &Account{
|
||||
Status: a.Status,
|
||||
Contact: a.Contact,
|
||||
Orders: dir.getLink(OrdersByAccountLink, URLSafeProvisionerName(p), true, a.ID),
|
||||
Key: a.Key,
|
||||
ID: a.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// save writes the Account to the DB.
|
||||
// If the account is new then the necessary indices will be created.
|
||||
// Else, the account in the DB will be updated.
|
||||
func (a *account) saveNew(db nosql.DB) error {
|
||||
kid, err := keyToID(a.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
kidB := []byte(kid)
|
||||
|
||||
// Set the jwkID -> acme account ID index
|
||||
_, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID))
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("key-id to account-id index already exists"))
|
||||
default:
|
||||
if err = a.save(db, nil); err != nil {
|
||||
db.Del(accountByKeyIDTable, kidB)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *account) save(db nosql.DB, old *account) error {
|
||||
var (
|
||||
err error
|
||||
oldB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
||||
}
|
||||
}
|
||||
|
||||
b, err := json.Marshal(*a)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error marshaling new account object")
|
||||
}
|
||||
// Set the Account
|
||||
_, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error storing account"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.New("error storing account; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// update updates the acme account object stored in the database if,
|
||||
// and only if, the account has not changed since the last read.
|
||||
func (a *account) update(db nosql.DB, contact []string) (*account, error) {
|
||||
b := *a
|
||||
b.Contact = contact
|
||||
if err := b.save(db, a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
// deactivate deactivates the acme account.
|
||||
func (a *account) deactivate(db nosql.DB) (*account, error) {
|
||||
b := *a
|
||||
b.Status = StatusDeactivated
|
||||
b.Deactivated = clock.Now()
|
||||
if err := b.save(db, a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
// getAccountByID retrieves the account with the given ID.
|
||||
func getAccountByID(db nosql.DB, id string) (*account, error) {
|
||||
ab, err := db.Get(accountTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id))
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id))
|
||||
}
|
||||
|
||||
a := new(account)
|
||||
if err = json.Unmarshal(ab, a); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account"))
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// getAccountByKeyID retrieves Id associated with the given Kid.
|
||||
func getAccountByKeyID(db nosql.DB, kid string) (*account, error) {
|
||||
id, err := db.Get(accountByKeyIDTable, []byte(kid))
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid))
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index"))
|
||||
}
|
||||
return getAccountByID(db, string(id))
|
||||
}
|
||||
|
||||
// getOrderIDsByAccount retrieves a list of Order IDs that were created by the
|
||||
// account.
|
||||
func getOrderIDsByAccount(db nosql.DB, id string) ([]string, error) {
|
||||
b, err := db.Get(ordersByAccountIDTable, []byte(id))
|
||||
if err != nil {
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return []string{}, nil
|
||||
}
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", id))
|
||||
}
|
||||
var orderIDs []string
|
||||
if err := json.Unmarshal(b, &orderIDs); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", id))
|
||||
}
|
||||
return orderIDs, nil
|
||||
}
|
844
acme/account_test.go
Normal file
844
acme/account_test.go
Normal file
|
@ -0,0 +1,844 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDisableRenewal = false
|
||||
globalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}
|
||||
)
|
||||
|
||||
func newProv() provisioner.Interface {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func newAcc() (*account, error) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
return newAccount(mockdb, AccountOptions{
|
||||
Key: jwk, Contact: []string{"foo", "bar"},
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAccountByID(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
id: acc.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if acc, err := getAccountByID(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, acc.ID)
|
||||
assert.Equals(t, tc.acc.Status, acc.Status)
|
||||
assert.Equals(t, tc.acc.Created, acc.Created)
|
||||
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
||||
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
||||
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountByKeyID(t *testing.T) {
|
||||
type test struct {
|
||||
kid string
|
||||
db nosql.DB
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/kid-not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("account with key id foo not found: not found")),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading key-account index: force")),
|
||||
}
|
||||
},
|
||||
"fail/getAccount-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
kid: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte("foo"))
|
||||
count++
|
||||
return []byte("bar"), nil
|
||||
}
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading account bar: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
kid: acc.Key.KeyID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
var ret []byte
|
||||
switch count {
|
||||
case 0:
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(acc.Key.KeyID))
|
||||
ret = []byte(acc.ID)
|
||||
case 1:
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
ret = b
|
||||
}
|
||||
count++
|
||||
return ret, nil
|
||||
},
|
||||
},
|
||||
acc: acc,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.acc.ID, acc.ID)
|
||||
assert.Equals(t, tc.acc.Status, acc.Status)
|
||||
assert.Equals(t, tc.acc.Created, acc.Created)
|
||||
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
||||
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
||||
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountIDsByAccount(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
res []string
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok/not-found": func(t *testing.T) test {
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
res: []string{},
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||
assert.Equals(t, key, []byte("foo"))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
oids := []string{"foo", "bar", "baz"}
|
||||
b, err := json.Marshal(oids)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
id: "foo",
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||
assert.Equals(t, key, []byte("foo"))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
res: oids,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if oids, err := getOrderIDsByAccount(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.res, oids)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountToACME(t *testing.T) {
|
||||
dir := newDirectory("ca.smallstep.com", "acme")
|
||||
prov := newProv()
|
||||
|
||||
type test struct {
|
||||
acc *account
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{acc: acc}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acmeAccount, err := tc.acc.toACME(nil, dir, prov)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acmeAccount.ID, tc.acc.ID)
|
||||
assert.Equals(t, acmeAccount.Status, tc.acc.Status)
|
||||
assert.Equals(t, acmeAccount.Contact, tc.acc.Contact)
|
||||
assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID)
|
||||
assert.Equals(t, acmeAccount.Orders,
|
||||
fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s/orders", URLSafeProvisionerName(prov), tc.acc.ID))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSave(t *testing.T) {
|
||||
type test struct {
|
||||
acc, old *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok/old-nil": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, []byte(acc.ID), key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/old-not-nil": func(t *testing.T) test {
|
||||
oldAcc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
oldb, err := json.Marshal(oldAcc)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
old: oldAcc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, []byte(acc.ID), key)
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.acc.save(tc.db, tc.old); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSaveNew(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/keyToID-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
acc.Key.Key = "foo"
|
||||
return test{
|
||||
acc: acc,
|
||||
err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")),
|
||||
}
|
||||
},
|
||||
"fail/swap-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
||||
}
|
||||
},
|
||||
"fail/swap-false": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("key-id to account-id index already exists")),
|
||||
}
|
||||
},
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
count++
|
||||
return nil, true, nil
|
||||
}
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
MDel: func(bucket, key []byte) error {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
return nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(acc.Key)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
count := 0
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, []byte(acc.ID))
|
||||
count++
|
||||
return nil, true, nil
|
||||
}
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.acc.saveNew(tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountUpdate(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
contact []string
|
||||
db nosql.DB
|
||||
res []byte
|
||||
err *Error
|
||||
}
|
||||
contact := []string{"foo", "bar"}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_acc := *acc
|
||||
clone := &_acc
|
||||
clone.Contact = contact
|
||||
b, err := json.Marshal(clone)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
contact: contact,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
_acc := *acc
|
||||
clone := &_acc
|
||||
clone.Contact = contact
|
||||
b, err := json.Marshal(clone)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
acc: acc,
|
||||
contact: contact,
|
||||
res: b,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, newval, b)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
acc, err := tc.acc.update(tc.db, tc.contact)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
b, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, b, tc.res)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountDeactivate(t *testing.T) {
|
||||
type test struct {
|
||||
acc *account
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc, err := newAcc()
|
||||
assert.FatalError(t, err)
|
||||
oldb, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
acc: acc,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
assert.Equals(t, key, []byte(acc.ID))
|
||||
assert.Equals(t, old, oldb)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
acc, err := tc.acc.deactivate(tc.db)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, tc.acc.ID)
|
||||
assert.Equals(t, acc.Contact, tc.acc.Contact)
|
||||
assert.Equals(t, acc.Status, StatusDeactivated)
|
||||
assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID)
|
||||
assert.Equals(t, acc.Created, tc.acc.Created)
|
||||
|
||||
assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAccount(t *testing.T) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
kid, err := keyToID(jwk)
|
||||
assert.FatalError(t, err)
|
||||
ops := AccountOptions{
|
||||
Key: jwk,
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
type test struct {
|
||||
ops AccountOptions
|
||||
db nosql.DB
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/store-error": func(t *testing.T) test {
|
||||
return test{
|
||||
ops: ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var _id string
|
||||
id := &_id
|
||||
count := 0
|
||||
return test{
|
||||
ops: ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
switch count {
|
||||
case 0:
|
||||
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||
assert.Equals(t, key, []byte(kid))
|
||||
case 1:
|
||||
assert.Equals(t, bucket, accountTable)
|
||||
*id = string(key)
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acc, err := newAccount(tc.db, tc.ops)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acc.ID, *tc.id)
|
||||
assert.Equals(t, acc.Status, StatusValid)
|
||||
assert.Equals(t, acc.Contact, ops.Contact)
|
||||
assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID)
|
||||
|
||||
assert.True(t, acc.Deactivated.IsZero())
|
||||
|
||||
assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
213
acme/api/account.go
Normal file
213
acme/api/account.go
Normal file
|
@ -0,0 +1,213 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// NewAccountRequest represents the payload for a new account request.
|
||||
type NewAccountRequest struct {
|
||||
Contact []string `json:"contact"`
|
||||
OnlyReturnExisting bool `json:"onlyReturnExisting"`
|
||||
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"`
|
||||
}
|
||||
|
||||
func validateContacts(cs []string) error {
|
||||
for _, c := range cs {
|
||||
if len(c) == 0 {
|
||||
return acme.MalformedErr(errors.New("contact cannot be empty string"))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates a new-account request body.
|
||||
func (n *NewAccountRequest) Validate() error {
|
||||
if n.OnlyReturnExisting && len(n.Contact) > 0 {
|
||||
return acme.MalformedErr(errors.New("incompatible input; onlyReturnExisting must be alone"))
|
||||
}
|
||||
return validateContacts(n.Contact)
|
||||
}
|
||||
|
||||
// UpdateAccountRequest represents an update-account request.
|
||||
type UpdateAccountRequest struct {
|
||||
Contact []string `json:"contact"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// IsDeactivateRequest returns true if the update request is a deactivation
|
||||
// request, false otherwise.
|
||||
func (u *UpdateAccountRequest) IsDeactivateRequest() bool {
|
||||
return u.Status == acme.StatusDeactivated
|
||||
}
|
||||
|
||||
// Validate validates a update-account request body.
|
||||
func (u *UpdateAccountRequest) Validate() error {
|
||||
switch {
|
||||
case len(u.Status) > 0 && len(u.Contact) > 0:
|
||||
return acme.MalformedErr(errors.New("incompatible input; contact and " +
|
||||
"status updates are mutually exclusive"))
|
||||
case len(u.Contact) > 0:
|
||||
if err := validateContacts(u.Contact); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
case len(u.Status) > 0:
|
||||
if u.Status != acme.StatusDeactivated {
|
||||
return acme.MalformedErr(errors.Errorf("cannot update account "+
|
||||
"status to %s, only deactivated", u.Status))
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return acme.MalformedErr(errors.Errorf("empty update request"))
|
||||
}
|
||||
}
|
||||
|
||||
// NewAccount is the handler resource for creating new ACME accounts.
|
||||
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
var nar NewAccountRequest
|
||||
if err := json.Unmarshal(payload.value, &nar); err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
|
||||
"failed to unmarshal new-account request payload")))
|
||||
return
|
||||
}
|
||||
if err := nar.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
httpStatus := http.StatusCreated
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
acmeErr, ok := err.(*acme.Error)
|
||||
if !ok || acmeErr.Status != http.StatusNotFound {
|
||||
// Something went wrong ...
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Account does not exist //
|
||||
if nar.OnlyReturnExisting {
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
||||
return
|
||||
}
|
||||
jwk, err := jwkFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if acc, err = h.Auth.NewAccount(prov, acme.AccountOptions{
|
||||
Key: jwk,
|
||||
Contact: nar.Contact,
|
||||
}); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Account exists //
|
||||
httpStatus = http.StatusOK
|
||||
}
|
||||
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink,
|
||||
acme.URLSafeProvisionerName(prov), true, acc.GetID()))
|
||||
api.JSONStatus(w, acc, httpStatus)
|
||||
return
|
||||
}
|
||||
|
||||
// GetUpdateAccount is the api for updating an ACME account.
|
||||
func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !payload.isPostAsGet {
|
||||
var uar UpdateAccountRequest
|
||||
if err := json.Unmarshal(payload.value, &uar); err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal new-account request payload")))
|
||||
return
|
||||
}
|
||||
if err := uar.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
var err error
|
||||
if uar.IsDeactivateRequest() {
|
||||
acc, err = h.Auth.DeactivateAccount(prov, acc.GetID())
|
||||
} else {
|
||||
acc, err = h.Auth.UpdateAccount(prov, acc.GetID(), uar.Contact)
|
||||
}
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, acc.GetID()))
|
||||
api.JSON(w, acc)
|
||||
return
|
||||
}
|
||||
|
||||
func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
m := map[string]interface{}{
|
||||
"orders": oids,
|
||||
}
|
||||
rl.WithFields(m)
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account.
|
||||
func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
accID := chi.URLParam(r, "accID")
|
||||
if acc.ID != accID {
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param")))
|
||||
return
|
||||
}
|
||||
orders, err := h.Auth.GetOrdersByAccount(prov, acc.GetID())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
api.JSON(w, orders)
|
||||
logOrdersByAccount(w, orders)
|
||||
return
|
||||
}
|
790
acme/api/account_test.go
Normal file
790
acme/api/account_test.go
Normal file
|
@ -0,0 +1,790 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDisableRenewal = false
|
||||
globalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}
|
||||
)
|
||||
|
||||
func newProv() provisioner.Interface {
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func TestNewAccountRequestValidate(t *testing.T) {
|
||||
type test struct {
|
||||
nar *NewAccountRequest
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/incompatible-input": func(t *testing.T) test {
|
||||
return test{
|
||||
nar: &NewAccountRequest{
|
||||
OnlyReturnExisting: true,
|
||||
Contact: []string{"foo", "bar"},
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("incompatible input; onlyReturnExisting must be alone")),
|
||||
}
|
||||
},
|
||||
"fail/bad-contact": func(t *testing.T) test {
|
||||
return test{
|
||||
nar: &NewAccountRequest{
|
||||
Contact: []string{"foo", ""},
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
nar: &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/onlyReturnExisting": func(t *testing.T) test {
|
||||
return test{
|
||||
nar: &NewAccountRequest{
|
||||
OnlyReturnExisting: true,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.nar.Validate(); err != nil {
|
||||
if assert.NotNil(t, err) {
|
||||
ae, ok := err.(*acme.Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountRequestValidate(t *testing.T) {
|
||||
type test struct {
|
||||
uar *UpdateAccountRequest
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/incompatible-input": func(t *testing.T) test {
|
||||
return test{
|
||||
uar: &UpdateAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
Status: "foo",
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("incompatible input; " +
|
||||
"contact and status updates are mutually exclusive")),
|
||||
}
|
||||
},
|
||||
"fail/bad-contact": func(t *testing.T) test {
|
||||
return test{
|
||||
uar: &UpdateAccountRequest{
|
||||
Contact: []string{"foo", ""},
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
|
||||
}
|
||||
},
|
||||
"fail/bad-status": func(t *testing.T) test {
|
||||
return test{
|
||||
uar: &UpdateAccountRequest{
|
||||
Status: "foo",
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("cannot update account " +
|
||||
"status to foo, only deactivated")),
|
||||
}
|
||||
},
|
||||
"ok/contact": func(t *testing.T) test {
|
||||
return test{
|
||||
uar: &UpdateAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/status": func(t *testing.T) test {
|
||||
return test{
|
||||
uar: &UpdateAccountRequest{
|
||||
Status: "deactivated",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.uar.Validate(); err != nil {
|
||||
if assert.NotNil(t, err) {
|
||||
ae, ok := err.(*acme.Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetOrdersByAccount(t *testing.T) {
|
||||
oids := []string{
|
||||
"https://ca.smallstep.com/acme/order/foo",
|
||||
"https://ca.smallstep.com/acme/order/bar",
|
||||
}
|
||||
accID := "account-id"
|
||||
prov := newProv()
|
||||
|
||||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("accID", accID)
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "foo"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 401,
|
||||
problem: acme.UnauthorizedErr(errors.New("account ID does not match url param")),
|
||||
}
|
||||
},
|
||||
"fail/getOrdersByAccount-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: accID}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.ServerInternalErr(errors.New("force")),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: accID}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getOrdersByAccount: func(p provisioner.Interface, id string) ([]string, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, id, acc.ID)
|
||||
return oids, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrdersByAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(oids)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerNewAccount(t *testing.T) {
|
||||
accID := "accountID"
|
||||
acc := acme.Account{
|
||||
ID: accID,
|
||||
Status: "valid",
|
||||
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
|
||||
}
|
||||
prov := newProv()
|
||||
|
||||
url := "https://ca.smallstep.com/acme/new-account"
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", ""},
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
|
||||
}
|
||||
},
|
||||
"fail/no-existing-account": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
OnlyReturnExisting: true,
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/no-jwk": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-jwk": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/NewAccount-error": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, ops.Contact, nar.Contact)
|
||||
assert.Equals(t, ops.Key, jwk)
|
||||
return nil, acme.ServerInternalErr(errors.New("force"))
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok/new-account": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, ops.Contact, nar.Contact)
|
||||
assert.Equals(t, ops.Key, jwk)
|
||||
return &acc, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.Equals(t, typ, acme.AccountLink)
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{accID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
}
|
||||
},
|
||||
"ok/return-existing": func(t *testing.T) test {
|
||||
nar := &NewAccountRequest{
|
||||
OnlyReturnExisting: true,
|
||||
}
|
||||
b, err := json.Marshal(nar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.Equals(t, typ, acme.AccountLink)
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{accID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.NewAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"],
|
||||
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||
accID := "accountID"
|
||||
acc := acme.Account{
|
||||
ID: accID,
|
||||
Status: "valid",
|
||||
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
|
||||
}
|
||||
prov := newProv()
|
||||
|
||||
// Request with chi context
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s", accID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||
uar := &UpdateAccountRequest{
|
||||
Contact: []string{"foo", ""},
|
||||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
|
||||
}
|
||||
},
|
||||
"fail/Deactivate-error": func(t *testing.T) test {
|
||||
uar := &UpdateAccountRequest{
|
||||
Status: "deactivated",
|
||||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, id, accID)
|
||||
return nil, acme.ServerInternalErr(errors.New("force"))
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"fail/UpdateAccount-error": func(t *testing.T) test {
|
||||
uar := &UpdateAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, id, accID)
|
||||
assert.Equals(t, contacts, uar.Contact)
|
||||
return nil, acme.ServerInternalErr(errors.New("force"))
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok/deactivate": func(t *testing.T) test {
|
||||
uar := &UpdateAccountRequest{
|
||||
Status: "deactivated",
|
||||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, id, accID)
|
||||
return &acc, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.AccountLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{accID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/new-account": func(t *testing.T) test {
|
||||
uar := &UpdateAccountRequest{
|
||||
Contact: []string{"foo", "bar"},
|
||||
}
|
||||
b, err := json.Marshal(uar)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, id, accID)
|
||||
assert.Equals(t, contacts, uar.Contact)
|
||||
return &acc, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.AccountLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{accID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/post-as-get": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.AccountLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{accID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetUpdateAccount(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(acc)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"],
|
||||
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||
acme.URLSafeProvisionerName(prov), accID)})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
214
acme/api/handler.go
Normal file
214
acme/api/handler.go
Normal file
|
@ -0,0 +1,214 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func link(url, typ string) string {
|
||||
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ)
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
accContextKey = contextKey("acc")
|
||||
jwsContextKey = contextKey("jws")
|
||||
jwkContextKey = contextKey("jwk")
|
||||
payloadContextKey = contextKey("payload")
|
||||
provisionerContextKey = contextKey("provisioner")
|
||||
)
|
||||
|
||||
type payloadInfo struct {
|
||||
value []byte
|
||||
isPostAsGet bool
|
||||
isEmptyJSON bool
|
||||
}
|
||||
|
||||
func accountFromContext(r *http.Request) (*acme.Account, error) {
|
||||
val, ok := r.Context().Value(accContextKey).(*acme.Account)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.AccountDoesNotExistErr(nil)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func jwkFromContext(r *http.Request) (*jose.JSONWebKey, error) {
|
||||
val, ok := r.Context().Value(jwkContextKey).(*jose.JSONWebKey)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("jwk expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func jwsFromContext(r *http.Request) (*jose.JSONWebSignature, error) {
|
||||
val, ok := r.Context().Value(jwsContextKey).(*jose.JSONWebSignature)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("jws expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func payloadFromContext(r *http.Request) (*payloadInfo, error) {
|
||||
val, ok := r.Context().Value(payloadContextKey).(*payloadInfo)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func provisionerFromContext(r *http.Request) (provisioner.Interface, error) {
|
||||
val, ok := r.Context().Value(provisionerContextKey).(provisioner.Interface)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.ServerInternalErr(errors.Errorf("provisioner expected in request context"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// New returns a new ACME API router.
|
||||
func New(acmeAuth acme.Interface) api.RouterHandler {
|
||||
return &Handler{acmeAuth}
|
||||
}
|
||||
|
||||
// Handler is the ACME request handler.
|
||||
type Handler struct {
|
||||
Auth acme.Interface
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
getLink := h.Auth.GetLink
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce)))
|
||||
r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce)))
|
||||
r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory)))
|
||||
r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory)))
|
||||
|
||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))
|
||||
}
|
||||
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
||||
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))
|
||||
}
|
||||
|
||||
r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false), extractPayloadByJWK(h.NewAccount))
|
||||
r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.GetUpdateAccount))
|
||||
r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false), extractPayloadByKid(h.NewOrder))
|
||||
r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
|
||||
r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount)))
|
||||
r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
|
||||
r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz)))
|
||||
r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, "{chID}"), extractPayloadByKid(h.GetChallenge))
|
||||
r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
||||
}
|
||||
|
||||
// GetNonce just sets the right header since a Nonce is added to each response
|
||||
// by middleware by default.
|
||||
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "HEAD" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetDirectory is the ACME resource for returning a directory configuration
|
||||
// for client configuration.
|
||||
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
dir := h.Auth.GetDirectory(prov)
|
||||
api.JSON(w, dir)
|
||||
return
|
||||
}
|
||||
|
||||
// GetAuthz ACME api for retrieving an Authz.
|
||||
func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
authz, err := h.Auth.GetAuthz(prov, acc.GetID(), chi.URLParam(r, "authzID"))
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, authz.GetID()))
|
||||
api.JSON(w, authz)
|
||||
return
|
||||
}
|
||||
|
||||
// GetChallenge ACME api for retrieving a Challenge.
|
||||
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
// Just verify that the payload was set, since we're not strictly adhering
|
||||
// to ACME V2 spec for reasons specified below.
|
||||
_, err = payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// NOTE: We should be checking that the request is either a POST-as-GET, or
|
||||
// that the payload is an empty JSON block ({}). However, older ACME clients
|
||||
// still send a vestigial body (rather than an empty JSON block) and
|
||||
// strict enforcement would render these clients broken. For the time being
|
||||
// we'll just ignore the body.
|
||||
var (
|
||||
ch *acme.Challenge
|
||||
chID = chi.URLParam(r, "chID")
|
||||
)
|
||||
ch, err = h.Auth.ValidateChallenge(prov, acc.GetID(), chID, acc.GetKey())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
getLink := h.Auth.GetLink
|
||||
w.Header().Add("Link", link(getLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, ch.GetAuthzID()), "up"))
|
||||
w.Header().Set("Location", getLink(acme.ChallengeLink, acme.URLSafeProvisionerName(prov), true, ch.GetID()))
|
||||
api.JSON(w, ch)
|
||||
return
|
||||
}
|
||||
|
||||
// GetCertificate ACME api for retrieving a Certificate.
|
||||
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
certID := chi.URLParam(r, "certID")
|
||||
certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
|
||||
w.Write(certBytes)
|
||||
return
|
||||
}
|
771
acme/api/handler_test.go
Normal file
771
acme/api/handler_test.go
Normal file
|
@ -0,0 +1,771 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
type mockAcmeAuthority struct {
|
||||
deactivateAccount func(provisioner.Interface, string) (*acme.Account, error)
|
||||
finalizeOrder func(p provisioner.Interface, accID string, id string, csr *x509.CertificateRequest) (*acme.Order, error)
|
||||
getAccount func(p provisioner.Interface, id string) (*acme.Account, error)
|
||||
getAccountByKey func(provisioner.Interface, *jose.JSONWebKey) (*acme.Account, error)
|
||||
getAuthz func(p provisioner.Interface, accID string, id string) (*acme.Authz, error)
|
||||
getCertificate func(accID string, id string) ([]byte, error)
|
||||
getChallenge func(p provisioner.Interface, accID string, id string) (*acme.Challenge, error)
|
||||
getDirectory func(provisioner.Interface) *acme.Directory
|
||||
getLink func(acme.Link, string, bool, ...string) string
|
||||
getOrder func(p provisioner.Interface, accID string, id string) (*acme.Order, error)
|
||||
getOrdersByAccount func(p provisioner.Interface, id string) ([]string, error)
|
||||
loadProvisionerByID func(string) (provisioner.Interface, error)
|
||||
newAccount func(provisioner.Interface, acme.AccountOptions) (*acme.Account, error)
|
||||
newNonce func() (string, error)
|
||||
newOrder func(provisioner.Interface, acme.OrderOptions) (*acme.Order, error)
|
||||
updateAccount func(provisioner.Interface, string, []string) (*acme.Account, error)
|
||||
useNonce func(string) error
|
||||
validateChallenge func(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error)
|
||||
ret1 interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) DeactivateAccount(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||
if m.deactivateAccount != nil {
|
||||
return m.deactivateAccount(p, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Account), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) FinalizeOrder(p provisioner.Interface, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) {
|
||||
if m.finalizeOrder != nil {
|
||||
return m.finalizeOrder(p, accID, id, csr)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Order), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetAccount(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||
if m.getAccount != nil {
|
||||
return m.getAccount(p, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Account), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) {
|
||||
if m.getAccountByKey != nil {
|
||||
return m.getAccountByKey(p, jwk)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Account), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetAuthz(p provisioner.Interface, accID, id string) (*acme.Authz, error) {
|
||||
if m.getAuthz != nil {
|
||||
return m.getAuthz(p, accID, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Authz), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetCertificate(accID, id string) ([]byte, error) {
|
||||
if m.getCertificate != nil {
|
||||
return m.getCertificate(accID, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.([]byte), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetChallenge(p provisioner.Interface, accID, id string) (*acme.Challenge, error) {
|
||||
if m.getChallenge != nil {
|
||||
return m.getChallenge(p, accID, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Challenge), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetDirectory(p provisioner.Interface) *acme.Directory {
|
||||
if m.getDirectory != nil {
|
||||
return m.getDirectory(p)
|
||||
}
|
||||
return m.ret1.(*acme.Directory)
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetLink(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
if m.getLink != nil {
|
||||
return m.getLink(typ, provID, abs, in...)
|
||||
}
|
||||
return m.ret1.(string)
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetOrder(p provisioner.Interface, accID, id string) (*acme.Order, error) {
|
||||
if m.getOrder != nil {
|
||||
return m.getOrder(p, accID, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Order), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) {
|
||||
if m.getOrdersByAccount != nil {
|
||||
return m.getOrdersByAccount(p, id)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.([]string), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByID != nil {
|
||||
return m.loadProvisionerByID(provID)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) NewAccount(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
|
||||
if m.newAccount != nil {
|
||||
return m.newAccount(p, ops)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Account), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) NewNonce() (string, error) {
|
||||
if m.newNonce != nil {
|
||||
return m.newNonce()
|
||||
} else if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
return m.ret1.(string), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) NewOrder(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||
if m.newOrder != nil {
|
||||
return m.newOrder(p, ops)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Order), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*acme.Account, error) {
|
||||
if m.updateAccount != nil {
|
||||
return m.updateAccount(p, id, contact)
|
||||
} else if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.(*acme.Account), m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) UseNonce(nonce string) error {
|
||||
if m.useNonce != nil {
|
||||
return m.useNonce(nonce)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockAcmeAuthority) ValidateChallenge(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
|
||||
switch {
|
||||
case m.validateChallenge != nil:
|
||||
return m.validateChallenge(p, accID, id, jwk)
|
||||
case m.err != nil:
|
||||
return nil, m.err
|
||||
default:
|
||||
return m.ret1.(*acme.Challenge), m.err
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetNonce(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
}{
|
||||
{"GET", 204},
|
||||
{"HEAD", 200},
|
||||
}
|
||||
|
||||
// Request with chi context
|
||||
req := httptest.NewRequest("GET", "http://ca.smallstep.com/nonce", nil)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(nil).(*Handler)
|
||||
w := httptest.NewRecorder()
|
||||
req.Method = tt.name
|
||||
h.GetNonce(w, req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("Handler.GetNonce StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetDirectory(t *testing.T) {
|
||||
auth := acme.NewAuthority(nil, "ca.smallstep.com", "acme", nil)
|
||||
prov := newProv()
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/directory", acme.URLSafeProvisionerName(prov))
|
||||
|
||||
expDir := acme.Directory{
|
||||
NewNonce: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", acme.URLSafeProvisionerName(prov)),
|
||||
NewAccount: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", acme.URLSafeProvisionerName(prov)),
|
||||
NewOrder: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", acme.URLSafeProvisionerName(prov)),
|
||||
RevokeCert: fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", acme.URLSafeProvisionerName(prov)),
|
||||
KeyChange: fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", acme.URLSafeProvisionerName(prov)),
|
||||
}
|
||||
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetDirectory(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
var dir acme.Directory
|
||||
json.Unmarshal(bytes.TrimSpace(body), &dir)
|
||||
assert.Equals(t, dir, expDir)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetAuthz(t *testing.T) {
|
||||
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||
az := acme.Authz{
|
||||
ID: "authzID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "example.com",
|
||||
},
|
||||
Status: "pending",
|
||||
Expires: expiry.Format(time.RFC3339),
|
||||
Wildcard: false,
|
||||
Challenges: []*acme.Challenge{
|
||||
{
|
||||
Type: "http-01",
|
||||
Status: "pending",
|
||||
Token: "tok2",
|
||||
URL: "https://ca.smallstep.com/acme/challenge/chHTTPID",
|
||||
ID: "chHTTP01ID",
|
||||
AuthzID: "authzID",
|
||||
},
|
||||
{
|
||||
Type: "dns-01",
|
||||
Status: "pending",
|
||||
Token: "tok2",
|
||||
URL: "https://ca.smallstep.com/acme/challenge/chDNSID",
|
||||
ID: "chDNSID",
|
||||
AuthzID: "authzID",
|
||||
},
|
||||
},
|
||||
}
|
||||
prov := newProv()
|
||||
|
||||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("authzID", az.ID)
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/challenge/%s",
|
||||
acme.URLSafeProvisionerName(prov), az.ID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/getAuthz-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.ServerInternalErr(errors.New("force")),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getAuthz: func(p provisioner.Interface, accID, id string) (*acme.Authz, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, az.ID)
|
||||
return &az, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.Equals(t, typ, acme.AuthzLink)
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{az.ID})
|
||||
return url
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAuthz(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
//var gotAz acme.Authz
|
||||
//assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &gotAz))
|
||||
expB, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetCertificate(t *testing.T) {
|
||||
leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt")
|
||||
assert.FatalError(t, err)
|
||||
inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
root, err := pemutil.ReadCertificate("../../authority/testdata/certs/root_ca.crt")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
certBytes := append(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: leaf.Raw,
|
||||
}), pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: inter.Raw,
|
||||
})...)
|
||||
certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: root.Raw,
|
||||
})...)
|
||||
certID := "certID"
|
||||
|
||||
prov := newProv()
|
||||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("certID", certID)
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/certificate/%s",
|
||||
acme.URLSafeProvisionerName(prov), certID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), accContextKey, nil)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/getCertificate-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.ServerInternalErr(errors.New("force")),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getCertificate: func(accID, id string) ([]byte, error) {
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, certID)
|
||||
return certBytes, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetCertificate(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes))
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/pem-certificate-chain; charset=utf-8"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ch() acme.Challenge {
|
||||
return acme.Challenge{
|
||||
Type: "http-01",
|
||||
Status: "pending",
|
||||
Token: "tok2",
|
||||
URL: "https://ca.smallstep.com/acme/challenge/chID",
|
||||
ID: "chID",
|
||||
AuthzID: "authzID",
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetChallenge(t *testing.T) {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("chID", "chID")
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/challenge/%s", "chID")
|
||||
prov := newProv()
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
ch acme.Challenge
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/validate-challenge-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.UnauthorizedErr(nil),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 401,
|
||||
problem: acme.UnauthorizedErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/get-challenge-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.UnauthorizedErr(nil),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 401,
|
||||
problem: acme.UnauthorizedErr(nil),
|
||||
}
|
||||
},
|
||||
"ok/validate-challenge": func(t *testing.T) test {
|
||||
key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{ID: "accID", Key: key}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
ch := ch()
|
||||
ch.Status = "valid"
|
||||
ch.Validated = time.Now().UTC().Format(time.RFC3339)
|
||||
count := 0
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
validateChallenge: func(p provisioner.Interface, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, ch.ID)
|
||||
assert.Equals(t, jwk.KeyID, key.KeyID)
|
||||
return &ch, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
var ret string
|
||||
switch count {
|
||||
case 0:
|
||||
assert.Equals(t, typ, acme.AuthzLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{ch.AuthzID})
|
||||
ret = fmt.Sprintf("https://ca.smallstep.com/acme/authz/%s", ch.AuthzID)
|
||||
case 1:
|
||||
assert.Equals(t, typ, acme.ChallengeLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{ch.ID})
|
||||
ret = url
|
||||
}
|
||||
count++
|
||||
return ret
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
ch: ch,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetChallenge(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(tc.ch)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<https://ca.smallstep.com/acme/authz/%s>;rel=\"up\"", tc.ch.AuthzID)})
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
377
acme/api/middleware.go
Normal file
377
acme/api/middleware.go
Normal file
|
@ -0,0 +1,377 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/cli/crypto/keys"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||
|
||||
func logNonce(w http.ResponseWriter, nonce string) {
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
m := map[string]interface{}{
|
||||
"nonce": nonce,
|
||||
}
|
||||
rl.WithFields(m)
|
||||
}
|
||||
}
|
||||
|
||||
// addNonce is a middleware that adds a nonce to the response header.
|
||||
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
nonce, err := h.Auth.NewNonce()
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Replay-Nonce", nonce)
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
logNonce(w, nonce)
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// addDirLink is a middleware that adds a 'Link' response reader with the
|
||||
// directory index url.
|
||||
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
w.Header().Add("Link", link(h.Auth.GetLink(acme.DirectoryLink, acme.URLSafeProvisionerName(prov), true), "index"))
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// verifyContentType is a middleware that verifies that content type is
|
||||
// application/jose+json.
|
||||
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
ct := r.Header.Get("Content-Type")
|
||||
var expected []string
|
||||
if strings.Contains(r.URL.Path, h.Auth.GetLink(acme.CertificateLink, acme.URLSafeProvisionerName(prov), false, "")) {
|
||||
// GET /certificate requests allow a greater range of content types.
|
||||
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||
} else {
|
||||
// By default every request should have content-type applictaion/jose+json.
|
||||
expected = []string{"application/jose+json"}
|
||||
}
|
||||
for _, e := range expected {
|
||||
if ct == e {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf(
|
||||
"expected content-type to be in %s, but got %s", expected, ct)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
|
||||
func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body")))
|
||||
return
|
||||
}
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body")))
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// validateJWS checks the request body for to verify that it meets ACME
|
||||
// requirements for a JWS.
|
||||
//
|
||||
// The JWS MUST NOT have multiple signatures
|
||||
// The JWS Unencoded Payload Option [RFC7797] MUST NOT be used
|
||||
// The JWS Unprotected Header [RFC7515] MUST NOT be used
|
||||
// The JWS Payload MUST NOT be detached
|
||||
// The JWS Protected Header MUST include the following fields:
|
||||
// * “alg” (Algorithm)
|
||||
// * This field MUST NOT contain “none” or a Message Authentication Code
|
||||
// (MAC) algorithm (e.g. one in which the algorithm registry description
|
||||
// mentions MAC/HMAC).
|
||||
// * “nonce” (defined in Section 6.5)
|
||||
// * “url” (defined in Section 6.4)
|
||||
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if len(jws.Signatures) == 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature")))
|
||||
return
|
||||
}
|
||||
if len(jws.Signatures) > 1 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature")))
|
||||
return
|
||||
}
|
||||
|
||||
sig := jws.Signatures[0]
|
||||
uh := sig.Unprotected
|
||||
if len(uh.KeyID) > 0 ||
|
||||
uh.JSONWebKey != nil ||
|
||||
len(uh.Algorithm) > 0 ||
|
||||
len(uh.Nonce) > 0 ||
|
||||
len(uh.ExtraHeaders) > 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used")))
|
||||
return
|
||||
}
|
||||
hdr := sig.Protected
|
||||
switch hdr.Algorithm {
|
||||
case jose.RS256, jose.RS384, jose.RS512:
|
||||
if hdr.JSONWebKey != nil {
|
||||
switch k := hdr.JSONWebKey.Key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if k.Size() < keys.MinRSAKeyBytes {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+
|
||||
"keys must be at least %d bits (%d bytes) in size",
|
||||
8*keys.MinRSAKeyBytes, keys.MinRSAKeyBytes)))
|
||||
return
|
||||
}
|
||||
default:
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")))
|
||||
return
|
||||
}
|
||||
}
|
||||
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
|
||||
// we good
|
||||
default:
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm)))
|
||||
return
|
||||
}
|
||||
|
||||
// Check the validity/freshness of the Nonce.
|
||||
if err := h.Auth.UseNonce(hdr.Nonce); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the JWS url matches the requested url.
|
||||
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
||||
if !ok {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header")))
|
||||
return
|
||||
}
|
||||
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
|
||||
if jwsURL != reqURL.String() {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)))
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")))
|
||||
return
|
||||
}
|
||||
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
|
||||
// in the context. Make sure to parse and validate the JWS before running this
|
||||
// middleware.
|
||||
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jwk := jws.Signatures[0].Protected.JSONWebKey
|
||||
if jwk == nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header")))
|
||||
return
|
||||
}
|
||||
if !jwk.Valid() {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
acc, err := h.Auth.GetAccountByKey(prov, jwk)
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
// For NewAccount requests ...
|
||||
break
|
||||
case err != nil:
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
default:
|
||||
if !acc.IsValid() {
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
}
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// lookupProvisioner loads the provisioner associated with the request.
|
||||
// Responsds 404 if the provisioner does not exist.
|
||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
name := chi.URLParam(r, "provisionerID")
|
||||
provID, err := url.PathUnescape(name)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name)))
|
||||
return
|
||||
}
|
||||
p, err := h.Auth.LoadProvisionerByID("acme/" + provID)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if p.GetType() != provisioner.TypeACME {
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, p)
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// lookupJWK loads the JWK associated with the acme account referenced by the
|
||||
// kid parameter of the signed payload.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
kidPrefix := h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, "")
|
||||
kid := jws.Signatures[0].Protected.KeyID
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+
|
||||
"required prefix; expected %s, but got %s", kidPrefix, kid)))
|
||||
return
|
||||
}
|
||||
|
||||
accID := strings.TrimPrefix(kid, kidPrefix)
|
||||
acc, err := h.Auth.GetAccount(prov, accID)
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
||||
return
|
||||
case err != nil:
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
default:
|
||||
if !acc.IsValid() {
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
||||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
jws, err := jwsFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jwk, err := jwkFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||
api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match")))
|
||||
return
|
||||
}
|
||||
payload, err := jws.Verify(jwk)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws")))
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), payloadContextKey, &payloadInfo{
|
||||
value: payload,
|
||||
isPostAsGet: string(payload) == "",
|
||||
isEmptyJSON: string(payload) == "{}",
|
||||
})
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
|
||||
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if !payload.isPostAsGet {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET")))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
}
|
1550
acme/api/middleware_test.go
Normal file
1550
acme/api/middleware_test.go
Normal file
File diff suppressed because it is too large
Load diff
164
acme/api/order.go
Normal file
164
acme/api/order.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
)
|
||||
|
||||
// NewOrderRequest represents the body for a NewOrder request.
|
||||
type NewOrderRequest struct {
|
||||
Identifiers []acme.Identifier `json:"identifiers"`
|
||||
NotBefore time.Time `json:"notBefore,omitempty"`
|
||||
NotAfter time.Time `json:"notAfter,omitempty"`
|
||||
}
|
||||
|
||||
// Validate validates a new-order request body.
|
||||
func (n *NewOrderRequest) Validate() error {
|
||||
if len(n.Identifiers) == 0 {
|
||||
return acme.MalformedErr(errors.Errorf("identifiers list cannot be empty"))
|
||||
}
|
||||
for _, id := range n.Identifiers {
|
||||
if id.Type != "dns" {
|
||||
return acme.MalformedErr(errors.Errorf("identifier type unsupported: %s", id.Type))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FinalizeRequest captures the body for a Finalize order request.
|
||||
type FinalizeRequest struct {
|
||||
CSR string `json:"csr"`
|
||||
csr *x509.CertificateRequest
|
||||
}
|
||||
|
||||
// Validate validates a finalize request body.
|
||||
func (f *FinalizeRequest) Validate() error {
|
||||
var err error
|
||||
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
|
||||
if err != nil {
|
||||
return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr"))
|
||||
}
|
||||
f.csr, err = x509.ParseCertificateRequest(csrBytes)
|
||||
if err != nil {
|
||||
return acme.MalformedErr(errors.Wrap(err, "unable to parse csr"))
|
||||
}
|
||||
if err = f.csr.CheckSignature(); err != nil {
|
||||
return acme.MalformedErr(errors.Wrap(err, "csr failed signature check"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewOrder ACME api for creating a new order.
|
||||
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
var nor NewOrderRequest
|
||||
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
|
||||
"failed to unmarshal new-order request payload")))
|
||||
return
|
||||
}
|
||||
if err := nor.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
o, err := h.Auth.NewOrder(prov, acme.OrderOptions{
|
||||
AccountID: acc.GetID(),
|
||||
Identifiers: nor.Identifiers,
|
||||
NotBefore: nor.NotBefore,
|
||||
NotAfter: nor.NotAfter,
|
||||
})
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID()))
|
||||
api.JSONStatus(w, o, http.StatusCreated)
|
||||
return
|
||||
}
|
||||
|
||||
// GetOrder ACME api for retrieving an order.
|
||||
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
oid := chi.URLParam(r, "ordID")
|
||||
o, err := h.Auth.GetOrder(prov, acc.GetID(), oid)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID()))
|
||||
api.JSON(w, o)
|
||||
return
|
||||
}
|
||||
|
||||
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
||||
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
prov, err := provisionerFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acc, err := accountFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
var fr FinalizeRequest
|
||||
if err := json.Unmarshal(payload.value, &fr); err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload")))
|
||||
return
|
||||
}
|
||||
if err := fr.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
oid := chi.URLParam(r, "ordID")
|
||||
o, err := h.Auth.FinalizeOrder(prov, acc.GetID(), oid, fr.csr)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.ID))
|
||||
api.JSON(w, o)
|
||||
return
|
||||
}
|
757
acme/api/order_test.go
Normal file
757
acme/api/order_test.go
Normal file
|
@ -0,0 +1,757 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestNewOrderRequestValidate(t *testing.T) {
|
||||
type test struct {
|
||||
nor *NewOrderRequest
|
||||
nbf, naf time.Time
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-identifiers": func(t *testing.T) test {
|
||||
return test{
|
||||
nor: &NewOrderRequest{},
|
||||
err: acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")),
|
||||
}
|
||||
},
|
||||
"fail/bad-identifier": func(t *testing.T) test {
|
||||
return test{
|
||||
nor: &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "foo", Value: "bar.com"},
|
||||
},
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("identifier type unsupported: foo")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
nbf := time.Now().UTC().Add(time.Minute)
|
||||
naf := time.Now().UTC().Add(5 * time.Minute)
|
||||
return test{
|
||||
nor: &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
NotAfter: naf,
|
||||
NotBefore: nbf,
|
||||
},
|
||||
nbf: nbf,
|
||||
naf: naf,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.nor.Validate(); err != nil {
|
||||
if assert.NotNil(t, err) {
|
||||
ae, ok := err.(*acme.Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
if tc.nbf.IsZero() {
|
||||
assert.True(t, tc.nor.NotBefore.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, tc.nor.NotBefore.After(time.Now().Add(-time.Minute)))
|
||||
} else {
|
||||
assert.Equals(t, tc.nor.NotBefore, tc.nbf)
|
||||
}
|
||||
if tc.naf.IsZero() {
|
||||
assert.True(t, tc.nor.NotAfter.Before(time.Now().Add(24*time.Hour)))
|
||||
assert.True(t, tc.nor.NotAfter.After(time.Now().Add(24*time.Hour-time.Minute)))
|
||||
} else {
|
||||
assert.Equals(t, tc.nor.NotAfter, tc.naf)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRequestValidate(t *testing.T) {
|
||||
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
||||
assert.FatalError(t, err)
|
||||
csr, ok := _csr.(*x509.CertificateRequest)
|
||||
assert.Fatal(t, ok)
|
||||
type test struct {
|
||||
fr *FinalizeRequest
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/parse-csr-error": func(t *testing.T) test {
|
||||
return test{
|
||||
fr: &FinalizeRequest{},
|
||||
err: acme.MalformedErr(errors.Errorf("unable to parse csr: asn1: syntax error: sequence truncated")),
|
||||
}
|
||||
},
|
||||
"fail/invalid-csr-signature": func(t *testing.T) test {
|
||||
b, err := pemutil.Read("../../authority/testdata/certs/badsig.csr")
|
||||
assert.FatalError(t, err)
|
||||
c, ok := b.(*x509.CertificateRequest)
|
||||
assert.Fatal(t, ok)
|
||||
return test{
|
||||
fr: &FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(c.Raw),
|
||||
},
|
||||
err: acme.MalformedErr(errors.Errorf("csr failed signature check: x509: ECDSA verification failure")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
return test{
|
||||
fr: &FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if err := tc.fr.Validate(); err != nil {
|
||||
if assert.NotNil(t, err) {
|
||||
ae, ok := err.(*acme.Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.fr.csr.Raw, csr.Raw)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerGetOrder(t *testing.T) {
|
||||
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||
nbf := time.Now().UTC()
|
||||
naf := time.Now().UTC().Add(24 * time.Hour)
|
||||
o := acme.Order{
|
||||
ID: "orderID",
|
||||
Expires: expiry.Format(time.RFC3339),
|
||||
NotBefore: nbf.Format(time.RFC3339),
|
||||
NotAfter: naf.Format(time.RFC3339),
|
||||
Identifiers: []acme.Identifier{
|
||||
{
|
||||
Type: "dns",
|
||||
Value: "example.com",
|
||||
},
|
||||
{
|
||||
Type: "dns",
|
||||
Value: "*.smallstep.com",
|
||||
},
|
||||
},
|
||||
Status: "pending",
|
||||
Authorizations: []string{"foo", "bar"},
|
||||
}
|
||||
|
||||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("ordID", o.ID)
|
||||
prov := newProv()
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s",
|
||||
acme.URLSafeProvisionerName(prov), o.ID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/getOrder-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
err: acme.ServerInternalErr(errors.New("force")),
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
getOrder: func(p provisioner.Interface, accID, id string) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, o.ID)
|
||||
return &o, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.OrderLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{o.ID})
|
||||
return url
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(o)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"], []string{url})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerNewOrder(t *testing.T) {
|
||||
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||
nbf := time.Now().UTC().Add(5 * time.Hour)
|
||||
naf := nbf.Add(17 * time.Hour)
|
||||
o := acme.Order{
|
||||
ID: "orderID",
|
||||
Expires: expiry.Format(time.RFC3339),
|
||||
NotBefore: nbf.Format(time.RFC3339),
|
||||
NotAfter: naf.Format(time.RFC3339),
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
Status: "pending",
|
||||
Authorizations: []string{"foo", "bar"},
|
||||
}
|
||||
|
||||
prov := newProv()
|
||||
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order",
|
||||
acme.URLSafeProvisionerName(prov))
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("failed to unmarshal new-order request payload: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &NewOrderRequest{}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("identifiers list cannot be empty")),
|
||||
}
|
||||
},
|
||||
"fail/NewOrder-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, ops.AccountID, acc.ID)
|
||||
assert.Equals(t, ops.Identifiers, nor.Identifiers)
|
||||
return nil, acme.MalformedErr(errors.New("force"))
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
NotBefore: nbf,
|
||||
NotAfter: naf,
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, ops.AccountID, acc.ID)
|
||||
assert.Equals(t, ops.Identifiers, nor.Identifiers)
|
||||
assert.Equals(t, ops.NotBefore, nbf)
|
||||
assert.Equals(t, ops.NotAfter, naf)
|
||||
return &o, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.OrderLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{o.ID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
}
|
||||
},
|
||||
"ok/default-naf-nbf": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &NewOrderRequest{
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, ops.AccountID, acc.ID)
|
||||
assert.Equals(t, ops.Identifiers, nor.Identifiers)
|
||||
|
||||
assert.True(t, ops.NotBefore.IsZero())
|
||||
assert.True(t, ops.NotAfter.IsZero())
|
||||
return &o, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.OrderLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{o.ID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 201,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.NewOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(o)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"],
|
||||
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerFinalizeOrder(t *testing.T) {
|
||||
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||
nbf := time.Now().UTC().Add(5 * time.Hour)
|
||||
naf := nbf.Add(17 * time.Hour)
|
||||
o := acme.Order{
|
||||
ID: "orderID",
|
||||
Expires: expiry.Format(time.RFC3339),
|
||||
NotBefore: nbf.Format(time.RFC3339),
|
||||
NotAfter: naf.Format(time.RFC3339),
|
||||
Identifiers: []acme.Identifier{
|
||||
{Type: "dns", Value: "example.com"},
|
||||
{Type: "dns", Value: "bar.com"},
|
||||
},
|
||||
Status: "valid",
|
||||
Authorizations: []string{"foo", "bar"},
|
||||
Certificate: "https://ca.smallstep.com/acme/certificate/certID",
|
||||
}
|
||||
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
||||
assert.FatalError(t, err)
|
||||
csr, ok := _csr.(*x509.CertificateRequest)
|
||||
assert.Fatal(t, ok)
|
||||
|
||||
// Request with chi context
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("ordID", o.ID)
|
||||
prov := newProv()
|
||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s/finalize",
|
||||
acme.URLSafeProvisionerName(prov), o.ID)
|
||||
|
||||
type test struct {
|
||||
auth acme.Interface
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
problem *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.Background(),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-provisioner": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/no-account": func(t *testing.T) test {
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/nil-account": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{},
|
||||
ctx: ctx,
|
||||
statusCode: 404,
|
||||
problem: acme.AccountDoesNotExistErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/no-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/nil-payload": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("failed to unmarshal finalize-order request payload: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
fr := &FinalizeRequest{}
|
||||
b, err := json.Marshal(fr)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("unable to parse csr: asn1: syntax error: sequence truncated")),
|
||||
}
|
||||
},
|
||||
"fail/FinalizeOrder-error": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, o.ID)
|
||||
assert.Equals(t, incsr.Raw, csr.Raw)
|
||||
return nil, acme.MalformedErr(errors.New("force"))
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 400,
|
||||
problem: acme.MalformedErr(errors.New("force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
acc := &acme.Account{ID: "accID"}
|
||||
nor := &FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||
}
|
||||
b, err := json.Marshal(nor)
|
||||
assert.FatalError(t, err)
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||
return test{
|
||||
auth: &mockAcmeAuthority{
|
||||
finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
|
||||
assert.Equals(t, p, prov)
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
assert.Equals(t, id, o.ID)
|
||||
assert.Equals(t, incsr.Raw, csr.Raw)
|
||||
return &o, nil
|
||||
},
|
||||
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||
assert.Equals(t, typ, acme.OrderLink)
|
||||
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||
assert.True(t, abs)
|
||||
assert.Equals(t, in, []string{o.ID})
|
||||
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s",
|
||||
acme.URLSafeProvisionerName(prov), o.ID)
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := New(tc.auth).(*Handler)
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.FinalizeOrder(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||
var ae acme.AError
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
prob := tc.problem.ToACME()
|
||||
|
||||
assert.Equals(t, ae.Type, prob.Type)
|
||||
assert.Equals(t, ae.Detail, prob.Detail)
|
||||
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
expB, err := json.Marshal(o)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||
assert.Equals(t, res.Header["Location"],
|
||||
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s",
|
||||
acme.URLSafeProvisionerName(prov), o.ID)})
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
263
acme/authority.go
Normal file
263
acme/authority.go
Normal file
|
@ -0,0 +1,263 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Interface is the acme authority interface.
|
||||
type Interface interface {
|
||||
DeactivateAccount(provisioner.Interface, string) (*Account, error)
|
||||
FinalizeOrder(provisioner.Interface, string, string, *x509.CertificateRequest) (*Order, error)
|
||||
GetAccount(provisioner.Interface, string) (*Account, error)
|
||||
GetAccountByKey(provisioner.Interface, *jose.JSONWebKey) (*Account, error)
|
||||
GetAuthz(provisioner.Interface, string, string) (*Authz, error)
|
||||
GetCertificate(string, string) ([]byte, error)
|
||||
GetDirectory(provisioner.Interface) *Directory
|
||||
GetLink(Link, string, bool, ...string) string
|
||||
GetOrder(provisioner.Interface, string, string) (*Order, error)
|
||||
GetOrdersByAccount(provisioner.Interface, string) ([]string, error)
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
NewAccount(provisioner.Interface, AccountOptions) (*Account, error)
|
||||
NewNonce() (string, error)
|
||||
NewOrder(provisioner.Interface, OrderOptions) (*Order, error)
|
||||
UpdateAccount(provisioner.Interface, string, []string) (*Account, error)
|
||||
UseNonce(string) error
|
||||
ValidateChallenge(provisioner.Interface, string, string, *jose.JSONWebKey) (*Challenge, error)
|
||||
}
|
||||
|
||||
// Authority is the layer that handles all ACME interactions.
|
||||
type Authority struct {
|
||||
db nosql.DB
|
||||
dir *directory
|
||||
signAuth SignAuthority
|
||||
}
|
||||
|
||||
// NewAuthority returns a new Authority that implements the ACME interface.
|
||||
func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) *Authority {
|
||||
return &Authority{
|
||||
db: db, dir: newDirectory(dns, prefix), signAuth: signAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// GetLink returns the requested link from the directory.
|
||||
func (a *Authority) GetLink(typ Link, provID string, abs bool, inputs ...string) string {
|
||||
return a.dir.getLink(typ, provID, abs, inputs...)
|
||||
}
|
||||
|
||||
// GetDirectory returns the ACME directory object.
|
||||
func (a *Authority) GetDirectory(p provisioner.Interface) *Directory {
|
||||
name := url.PathEscape(p.GetName())
|
||||
return &Directory{
|
||||
NewNonce: a.dir.getLink(NewNonceLink, name, true),
|
||||
NewAccount: a.dir.getLink(NewAccountLink, name, true),
|
||||
NewOrder: a.dir.getLink(NewOrderLink, name, true),
|
||||
RevokeCert: a.dir.getLink(RevokeCertLink, name, true),
|
||||
KeyChange: a.dir.getLink(KeyChangeLink, name, true),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadProvisionerByID calls out to the SignAuthority interface to load a
|
||||
// provisioner by ID.
|
||||
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||
return a.signAuth.LoadProvisionerByID(id)
|
||||
}
|
||||
|
||||
// NewNonce generates, stores, and returns a new ACME nonce.
|
||||
func (a *Authority) NewNonce() (string, error) {
|
||||
n, err := newNonce(a.db)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return n.ID, nil
|
||||
}
|
||||
|
||||
// UseNonce consumes the given nonce if it is valid, returns error otherwise.
|
||||
func (a *Authority) UseNonce(nonce string) error {
|
||||
return useNonce(a.db, nonce)
|
||||
}
|
||||
|
||||
// NewAccount creates, stores, and returns a new ACME account.
|
||||
func (a *Authority) NewAccount(p provisioner.Interface, ao AccountOptions) (*Account, error) {
|
||||
acc, err := newAccount(a.db, ao)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// UpdateAccount updates an ACME account.
|
||||
func (a *Authority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(err)
|
||||
}
|
||||
if acc, err = acc.update(a.db, contact); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetAccount returns an ACME account.
|
||||
func (a *Authority) GetAccount(p provisioner.Interface, id string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// DeactivateAccount deactivates an ACME account.
|
||||
func (a *Authority) DeactivateAccount(p provisioner.Interface, id string) (*Account, error) {
|
||||
acc, err := getAccountByID(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if acc, err = acc.deactivate(a.db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
func keyToID(jwk *jose.JSONWebKey) (string, error) {
|
||||
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint"))
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(kid), nil
|
||||
}
|
||||
|
||||
// GetAccountByKey returns the ACME associated with the jwk id.
|
||||
func (a *Authority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*Account, error) {
|
||||
kid, err := keyToID(jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
acc, err := getAccountByKeyID(a.db, kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return acc.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetOrder returns an ACME order.
|
||||
func (a *Authority) GetOrder(p provisioner.Interface, accID, orderID string) (*Order, error) {
|
||||
o, err := getOrder(a.db, orderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != o.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
||||
}
|
||||
if o, err = o.updateStatus(a.db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return o.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetOrdersByAccount returns the list of order urls owned by the account.
|
||||
func (a *Authority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) {
|
||||
oids, err := getOrderIDsByAccount(a.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret = []string{}
|
||||
for _, oid := range oids {
|
||||
o, err := getOrder(a.db, oid)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(err)
|
||||
}
|
||||
if o.Status == StatusInvalid {
|
||||
continue
|
||||
}
|
||||
ret = append(ret, a.dir.getLink(OrderLink, URLSafeProvisionerName(p), true, o.ID))
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// NewOrder generates, stores, and returns a new ACME order.
|
||||
func (a *Authority) NewOrder(p provisioner.Interface, ops OrderOptions) (*Order, error) {
|
||||
order, err := newOrder(a.db, ops)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating order")
|
||||
}
|
||||
return order.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// FinalizeOrder attempts to finalize an order and generate a new certificate.
|
||||
func (a *Authority) FinalizeOrder(p provisioner.Interface, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) {
|
||||
o, err := getOrder(a.db, orderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != o.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
||||
}
|
||||
o, err = o.finalize(a.db, csr, a.signAuth, p)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error finalizing order")
|
||||
}
|
||||
return o.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetAuthz retrieves and attempts to update the status on an ACME authz
|
||||
// before returning.
|
||||
func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*Authz, error) {
|
||||
az, err := getAuthz(a.db, authzID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != az.getAccountID() {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own authz"))
|
||||
}
|
||||
az, err = az.updateStatus(a.db)
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error updating authz status")
|
||||
}
|
||||
return az.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// ValidateChallenge attempts to validate the challenge.
|
||||
func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) {
|
||||
ch, err := getChallenge(a.db, chID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != ch.getAccountID() {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own challenge"))
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: time.Duration(30 * time.Second),
|
||||
}
|
||||
ch, err = ch.validate(a.db, jwk, validateOptions{
|
||||
httpGet: client.Get,
|
||||
lookupTxt: net.LookupTXT,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error attempting challenge validation")
|
||||
}
|
||||
return ch.toACME(a.db, a.dir, p)
|
||||
}
|
||||
|
||||
// GetCertificate retrieves the Certificate by ID.
|
||||
func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) {
|
||||
cert, err := getCert(a.db, certID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accID != cert.AccountID {
|
||||
return nil, UnauthorizedErr(errors.New("account does not own certificate"))
|
||||
}
|
||||
return cert.toACME(a.db, a.dir)
|
||||
}
|
1474
acme/authority_test.go
Normal file
1474
acme/authority_test.go
Normal file
File diff suppressed because it is too large
Load diff
337
acme/authz.go
Normal file
337
acme/authz.go
Normal file
|
@ -0,0 +1,337 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
var defaultExpiryDuration = time.Hour * 24
|
||||
|
||||
// Authz is a subset of the Authz type containing only those attributes
|
||||
// required for responses in the ACME protocol.
|
||||
type Authz struct {
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status string `json:"status"`
|
||||
Expires string `json:"expires"`
|
||||
Challenges []*Challenge `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
ID string `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (a *Authz) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the Authz ID.
|
||||
func (a *Authz) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
// authz is the interface that the various authz types must implement.
|
||||
type authz interface {
|
||||
save(nosql.DB, authz) error
|
||||
clone() *baseAuthz
|
||||
getID() string
|
||||
getAccountID() string
|
||||
getType() string
|
||||
getIdentifier() Identifier
|
||||
getStatus() string
|
||||
getExpiry() time.Time
|
||||
getWildcard() bool
|
||||
getChallenges() []string
|
||||
getCreated() time.Time
|
||||
updateStatus(db nosql.DB) (authz, error)
|
||||
toACME(nosql.DB, *directory, provisioner.Interface) (*Authz, error)
|
||||
}
|
||||
|
||||
// baseAuthz is the base authz type that others build from.
|
||||
type baseAuthz struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Identifier Identifier `json:"identifier"`
|
||||
Status string `json:"status"`
|
||||
Expires time.Time `json:"expires"`
|
||||
Challenges []string `json:"challenges"`
|
||||
Wildcard bool `json:"wildcard"`
|
||||
Created time.Time `json:"created"`
|
||||
Error *Error `json:"error"`
|
||||
}
|
||||
|
||||
func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
ba := &baseAuthz{
|
||||
ID: id,
|
||||
AccountID: accID,
|
||||
Status: StatusPending,
|
||||
Created: now,
|
||||
Expires: now.Add(defaultExpiryDuration),
|
||||
Identifier: identifier,
|
||||
}
|
||||
|
||||
if strings.HasPrefix(identifier.Value, "*.") {
|
||||
ba.Wildcard = true
|
||||
ba.Identifier = Identifier{
|
||||
Value: strings.TrimPrefix(identifier.Value, "*."),
|
||||
Type: identifier.Type,
|
||||
}
|
||||
}
|
||||
|
||||
return ba, nil
|
||||
}
|
||||
|
||||
// getID returns the ID of the authz.
|
||||
func (ba *baseAuthz) getID() string {
|
||||
return ba.ID
|
||||
}
|
||||
|
||||
// getAccountID returns the Account ID that created the authz.
|
||||
func (ba *baseAuthz) getAccountID() string {
|
||||
return ba.AccountID
|
||||
}
|
||||
|
||||
// getType returns the type of the authz.
|
||||
func (ba *baseAuthz) getType() string {
|
||||
return ba.Identifier.Type
|
||||
}
|
||||
|
||||
// getIdentifier returns the identifier for the authz.
|
||||
func (ba *baseAuthz) getIdentifier() Identifier {
|
||||
return ba.Identifier
|
||||
}
|
||||
|
||||
// getStatus returns the status of the authz.
|
||||
func (ba *baseAuthz) getStatus() string {
|
||||
return ba.Status
|
||||
}
|
||||
|
||||
// getWildcard returns true if the authz identifier has a '*', false otherwise.
|
||||
func (ba *baseAuthz) getWildcard() bool {
|
||||
return ba.Wildcard
|
||||
}
|
||||
|
||||
// getChallenges returns the authz challenge IDs.
|
||||
func (ba *baseAuthz) getChallenges() []string {
|
||||
return ba.Challenges
|
||||
}
|
||||
|
||||
// getExpiry returns the expiration time of the authz.
|
||||
func (ba *baseAuthz) getExpiry() time.Time {
|
||||
return ba.Expires
|
||||
}
|
||||
|
||||
// getCreated returns the created time of the authz.
|
||||
func (ba *baseAuthz) getCreated() time.Time {
|
||||
return ba.Created
|
||||
}
|
||||
|
||||
// toACME converts the internal Authz type into the public acmeAuthz type for
|
||||
// presentation in the ACME protocol.
|
||||
func (ba *baseAuthz) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Authz, error) {
|
||||
var chs = make([]*Challenge, len(ba.Challenges))
|
||||
for i, chID := range ba.Challenges {
|
||||
ch, err := getChallenge(db, chID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chs[i], err = ch.toACME(db, dir, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &Authz{
|
||||
Identifier: ba.Identifier,
|
||||
Status: ba.getStatus(),
|
||||
Challenges: chs,
|
||||
Wildcard: ba.getWildcard(),
|
||||
Expires: ba.Expires.Format(time.RFC3339),
|
||||
ID: ba.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) save(db nosql.DB, old authz) error {
|
||||
var (
|
||||
err error
|
||||
oldB, newB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old authz"))
|
||||
}
|
||||
}
|
||||
if newB, err = json.Marshal(ba); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new authz"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error storing authz"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("error storing authz; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) clone() *baseAuthz {
|
||||
u := *ba
|
||||
return &u
|
||||
}
|
||||
|
||||
func (ba *baseAuthz) parent() authz {
|
||||
return &dnsAuthz{ba}
|
||||
}
|
||||
|
||||
// updateStatus attempts to update the status on a baseAuthz and stores the
|
||||
// updating object if necessary.
|
||||
func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) {
|
||||
newAuthz := ba.clone()
|
||||
|
||||
now := time.Now().UTC()
|
||||
switch ba.Status {
|
||||
case StatusInvalid:
|
||||
return ba.parent(), nil
|
||||
case StatusValid:
|
||||
return ba.parent(), nil
|
||||
case StatusPending:
|
||||
// check expiry
|
||||
if now.After(ba.Expires) {
|
||||
newAuthz.Status = StatusInvalid
|
||||
newAuthz.Error = MalformedErr(errors.New("authz has expired"))
|
||||
break
|
||||
}
|
||||
|
||||
var isValid = false
|
||||
for _, chID := range ba.Challenges {
|
||||
ch, err := getChallenge(db, chID)
|
||||
if err != nil {
|
||||
return ba, err
|
||||
}
|
||||
if ch.getStatus() == StatusValid {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
return ba.parent(), nil
|
||||
}
|
||||
newAuthz.Status = StatusValid
|
||||
newAuthz.Error = nil
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status))
|
||||
}
|
||||
|
||||
if err := newAuthz.save(db, ba); err != nil {
|
||||
return ba, err
|
||||
}
|
||||
return newAuthz.parent(), nil
|
||||
}
|
||||
|
||||
// unmarshalAuthz unmarshals an authz type into the correct sub-type.
|
||||
func unmarshalAuthz(data []byte) (authz, error) {
|
||||
var getType struct {
|
||||
Identifier Identifier `json:"identifier"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &getType); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type"))
|
||||
}
|
||||
|
||||
switch getType.Identifier.Type {
|
||||
case "dns":
|
||||
var ba baseAuthz
|
||||
if err := json.Unmarshal(data, &ba); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz"))
|
||||
}
|
||||
return &dnsAuthz{&ba}, nil
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s",
|
||||
getType.Identifier.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// dnsAuthz represents a dns acme authorization.
|
||||
type dnsAuthz struct {
|
||||
*baseAuthz
|
||||
}
|
||||
|
||||
// newAuthz returns a new acme authorization object based on the identifier
|
||||
// type.
|
||||
func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) {
|
||||
switch identifier.Type {
|
||||
case "dns":
|
||||
a, err = newDNSAuthz(db, accID, identifier)
|
||||
default:
|
||||
err = MalformedErr(errors.Errorf("unexpected authz type %s",
|
||||
identifier.Type))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// newDNSAuthz returns a new dns acme authorization object.
|
||||
func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) {
|
||||
ba, err := newBaseAuthz(accID, identifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ba.Challenges = []string{}
|
||||
if !ba.Wildcard {
|
||||
// http challenges are only permitted if the DNS is not a wildcard dns.
|
||||
ch1, err := newHTTP01Challenge(db, ChallengeOptions{
|
||||
AccountID: accID,
|
||||
AuthzID: ba.ID,
|
||||
Identifier: ba.Identifier})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating http challenge")
|
||||
}
|
||||
ba.Challenges = append(ba.Challenges, ch1.getID())
|
||||
}
|
||||
ch2, err := newDNS01Challenge(db, ChallengeOptions{
|
||||
AccountID: accID,
|
||||
AuthzID: ba.ID,
|
||||
Identifier: identifier})
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error creating dns challenge")
|
||||
}
|
||||
ba.Challenges = append(ba.Challenges, ch2.getID())
|
||||
|
||||
da := &dnsAuthz{ba}
|
||||
if err := da.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return da, nil
|
||||
}
|
||||
|
||||
// getAuthz retrieves and unmarshals an ACME authz type from the database.
|
||||
func getAuthz(db nosql.DB, id string) (authz, error) {
|
||||
b, err := db.Get(authzTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id))
|
||||
}
|
||||
az, err := unmarshalAuthz(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return az, nil
|
||||
}
|
809
acme/authz_test.go
Normal file
809
acme/authz_test.go
Normal file
|
@ -0,0 +1,809 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func newAz() (authz, error) {
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
}
|
||||
return newAuthz(mockdb, "1234", Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAuthz(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
az authz
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Identifier.Type = "foo"
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, key, []byte(az.getID()))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
id: az.getID(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, key, []byte(az.getID()))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if az, err := getAuthz(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.getID(), az.getID())
|
||||
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||||
assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier())
|
||||
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||||
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzClone(t *testing.T) {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
clone := az.clone()
|
||||
|
||||
assert.Equals(t, clone.getID(), az.getID())
|
||||
assert.Equals(t, clone.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, clone.getStatus(), az.getStatus())
|
||||
assert.Equals(t, clone.getIdentifier(), az.getIdentifier())
|
||||
assert.Equals(t, clone.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, clone.getCreated(), az.getCreated())
|
||||
assert.Equals(t, clone.getChallenges(), az.getChallenges())
|
||||
|
||||
clone.Status = StatusValid
|
||||
|
||||
assert.NotEquals(t, clone.getStatus(), az.getStatus())
|
||||
}
|
||||
|
||||
func TestNewAuthz(t *testing.T) {
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
accID := "1234"
|
||||
type test struct {
|
||||
iden Identifier
|
||||
db nosql.DB
|
||||
err *Error
|
||||
resChs *([]string)
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/unexpected-type": func(t *testing.T) test {
|
||||
return test{
|
||||
iden: Identifier{Type: "foo", Value: "acme.example.com"},
|
||||
err: MalformedErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"fail/new-http-chall-error": func(t *testing.T) test {
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")),
|
||||
}
|
||||
},
|
||||
"fail/new-dns-chall-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 1 {
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")),
|
||||
}
|
||||
},
|
||||
"fail/save-authz-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 2 {
|
||||
return nil, false, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
chs := &([]string{})
|
||||
count := 0
|
||||
return test{
|
||||
iden: iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 2 {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
az, err := unmarshalAuthz(newval)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, az.getID(), string(key))
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
assert.Equals(t, az.getIdentifier(), iden)
|
||||
assert.Equals(t, az.getWildcard(), false)
|
||||
|
||||
*chs = az.getChallenges()
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
resChs: chs,
|
||||
}
|
||||
},
|
||||
"ok/wildcard": func(t *testing.T) test {
|
||||
chs := &([]string{})
|
||||
count := 0
|
||||
_iden := Identifier{Type: "dns", Value: "*.acme.example.com"}
|
||||
return test{
|
||||
iden: _iden,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 1 {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, old, nil)
|
||||
|
||||
az, err := unmarshalAuthz(newval)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, az.getID(), string(key))
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
assert.Equals(t, az.getIdentifier(), iden)
|
||||
assert.Equals(t, az.getWildcard(), true)
|
||||
|
||||
*chs = az.getChallenges()
|
||||
// Verify that we only have 1 challenge instead of 2.
|
||||
assert.True(t, len(*chs) == 1)
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
resChs: chs,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
az, err := newAuthz(tc.db, accID, tc.iden)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, az.getAccountID(), accID)
|
||||
assert.Equals(t, az.getType(), "dns")
|
||||
assert.Equals(t, az.getStatus(), StatusPending)
|
||||
|
||||
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||
|
||||
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||
|
||||
assert.Equals(t, az.getChallenges(), *(tc.resChs))
|
||||
|
||||
if strings.HasPrefix(tc.iden.Value, "*.") {
|
||||
assert.True(t, az.getWildcard())
|
||||
assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*."))
|
||||
} else {
|
||||
assert.False(t, az.getWildcard())
|
||||
assert.Equals(t, az.getIdentifier().Value, tc.iden.Value)
|
||||
}
|
||||
|
||||
assert.True(t, az.getID() != "")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzToACME(t *testing.T) {
|
||||
dir := newDirectory("ca.smallstep.com", "acme")
|
||||
|
||||
var (
|
||||
ch1, ch2 challenge
|
||||
ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||||
err error
|
||||
)
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
ch1, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
} else if count == 1 {
|
||||
*ch2Bytes = newval
|
||||
ch2, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
count++
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
prov := newProv()
|
||||
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/getChallenge1-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"fail/getChallenge2-error": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 1 {
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
count := 0
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
return *ch2Bytes, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
acmeAz, err := az.toACME(tc.db, dir, prov)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, acmeAz.ID, az.getID())
|
||||
assert.Equals(t, acmeAz.Identifier, iden)
|
||||
assert.Equals(t, acmeAz.Status, StatusPending)
|
||||
|
||||
acmeCh1, err := ch1.toACME(nil, dir, prov)
|
||||
assert.FatalError(t, err)
|
||||
acmeCh2, err := ch2.toACME(nil, dir, prov)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
assert.Equals(t, acmeAz.Challenges[0], acmeCh1)
|
||||
assert.Equals(t, acmeAz.Challenges[1], acmeCh2)
|
||||
|
||||
expiry, err := time.Parse(time.RFC3339, acmeAz.Expires)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, expiry.String(), az.getExpiry().String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzSave(t *testing.T) {
|
||||
type test struct {
|
||||
az, old authz
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok/old-nil": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: nil,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, nil)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, []byte(az.getID()), key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/old-not-nil": func(t *testing.T) test {
|
||||
oldAz, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
oldb, err := json.Marshal(oldAz)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
old: oldAz,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, old, oldb)
|
||||
assert.Equals(t, b, newval)
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
assert.Equals(t, []byte(az.getID()), key)
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := tc.az.save(tc.db, tc.old); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzUnmarshal(t *testing.T) {
|
||||
type test struct {
|
||||
az authz
|
||||
azb []byte
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/nil": func(t *testing.T) test {
|
||||
return test{
|
||||
azb: nil,
|
||||
err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"fail/unexpected-type": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Identifier.Type = "foo"
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
azb: b,
|
||||
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||||
}
|
||||
},
|
||||
"ok/dns": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
az: az,
|
||||
azb: b,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if az, err := unmarshalAuthz(tc.azb); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.az.getID(), az.getID())
|
||||
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||||
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||||
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||||
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||||
assert.Equals(t, tc.az.getWildcard(), az.getWildcard())
|
||||
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthzUpdateStatus(t *testing.T) {
|
||||
type test struct {
|
||||
az, res authz
|
||||
err *Error
|
||||
db nosql.DB
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/already-invalid": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusInvalid
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
}
|
||||
},
|
||||
"fail/already-valid": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusValid
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
}
|
||||
},
|
||||
"fail/unexpected-status": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Status = StatusReady
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
err: ServerInternalErr(errors.New("unrecognized authz status: ready")),
|
||||
}
|
||||
},
|
||||
"fail/save-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||
}
|
||||
},
|
||||
"ok/expired": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||||
|
||||
clone := az.clone()
|
||||
clone.Error = MalformedErr(errors.New("authz has expired"))
|
||||
clone.Status = StatusInvalid
|
||||
return test{
|
||||
az: az,
|
||||
res: clone.parent(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/get-challenge-error": func(t *testing.T) test {
|
||||
az, err := newAz()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||
}
|
||||
},
|
||||
"ok/valid": func(t *testing.T) test {
|
||||
var (
|
||||
ch2 challenge
|
||||
ch1Bytes = &([]byte{})
|
||||
err error
|
||||
)
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
} else if count == 1 {
|
||||
ch2, err = unmarshalChallenge(newval)
|
||||
assert.FatalError(t, err)
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
_az, ok := az.(*dnsAuthz)
|
||||
assert.Fatal(t, ok)
|
||||
_az.baseAuthz.Error = MalformedErr(nil)
|
||||
|
||||
_ch, ok := ch2.(*dns01Challenge)
|
||||
assert.Fatal(t, ok)
|
||||
_ch.baseChallenge.Status = StatusValid
|
||||
chb, err := json.Marshal(ch2)
|
||||
|
||||
clone := az.clone()
|
||||
clone.Status = StatusValid
|
||||
clone.Error = nil
|
||||
|
||||
count = 0
|
||||
return test{
|
||||
az: az,
|
||||
res: clone.parent(),
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
count++
|
||||
return chb, nil
|
||||
},
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/still-pending": func(t *testing.T) test {
|
||||
var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||||
|
||||
count := 0
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if count == 0 {
|
||||
*ch1Bytes = newval
|
||||
} else if count == 1 {
|
||||
*ch2Bytes = newval
|
||||
}
|
||||
count++
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
iden := Identifier{
|
||||
Type: "dns", Value: "acme.example.com",
|
||||
}
|
||||
az, err := newAuthz(mockdb, "1234", iden)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
count = 0
|
||||
return test{
|
||||
az: az,
|
||||
res: az,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
if count == 0 {
|
||||
count++
|
||||
return *ch1Bytes, nil
|
||||
}
|
||||
count++
|
||||
return *ch2Bytes, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
az, err := tc.az.updateStatus(tc.db)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
expB, err := json.Marshal(tc.res)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(az)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, expB, b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
89
acme/certificate.go
Normal file
89
acme/certificate.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type certificate struct {
|
||||
ID string `json:"id"`
|
||||
Created time.Time `json:"created"`
|
||||
AccountID string `json:"accountID"`
|
||||
OrderID string `json:"orderID"`
|
||||
Leaf []byte `json:"leaf"`
|
||||
Intermediates []byte `json:"intermediates"`
|
||||
}
|
||||
|
||||
// CertOptions options with which to create and store a cert object.
|
||||
type CertOptions struct {
|
||||
AccountID string
|
||||
OrderID string
|
||||
Leaf *x509.Certificate
|
||||
Intermediates []*x509.Certificate
|
||||
}
|
||||
|
||||
func newCert(db nosql.DB, ops CertOptions) (*certificate, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leaf := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: ops.Leaf.Raw,
|
||||
})
|
||||
var intermediates []byte
|
||||
for _, cert := range ops.Intermediates {
|
||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
|
||||
cert := &certificate{
|
||||
ID: id,
|
||||
AccountID: ops.AccountID,
|
||||
OrderID: ops.OrderID,
|
||||
Leaf: leaf,
|
||||
Intermediates: intermediates,
|
||||
Created: time.Now().UTC(),
|
||||
}
|
||||
certB, err := json.Marshal(cert)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling certificate"))
|
||||
}
|
||||
|
||||
_, swapped, err := db.CmpAndSwap(certTable, []byte(id), nil, certB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error storing certificate"))
|
||||
case !swapped:
|
||||
return nil, ServerInternalErr(errors.New("error storing certificate; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return cert, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *certificate) toACME(db nosql.DB, dir *directory) ([]byte, error) {
|
||||
return append(c.Leaf, c.Intermediates...), nil
|
||||
}
|
||||
|
||||
func getCert(db nosql.DB, id string) (*certificate, error) {
|
||||
b, err := db.Get(certTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate"))
|
||||
}
|
||||
var cert certificate
|
||||
if err := json.Unmarshal(b, &cert); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate"))
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
253
acme/certificate_test.go
Normal file
253
acme/certificate_test.go
Normal file
|
@ -0,0 +1,253 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func defaultCertOps() (*CertOptions, error) {
|
||||
crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CertOptions{
|
||||
AccountID: "accID",
|
||||
OrderID: "ordID",
|
||||
Leaf: crt,
|
||||
Intermediates: []*x509.Certificate{inter, root},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newcert() (*certificate, error) {
|
||||
ops, err := defaultCertOps()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockdb := &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, true, nil
|
||||
},
|
||||
}
|
||||
return newCert(mockdb, *ops)
|
||||
}
|
||||
|
||||
func TestNewCert(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
ops CertOptions
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing certificate: force")),
|
||||
}
|
||||
},
|
||||
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
ops, err := defaultCertOps()
|
||||
assert.FatalError(t, err)
|
||||
var _id string
|
||||
id := &_id
|
||||
return test{
|
||||
ops: *ops,
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, old, nil)
|
||||
*id = string(key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if cert, err := newCert(tc.db, tc.ops); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, cert.ID, *tc.id)
|
||||
assert.Equals(t, cert.AccountID, tc.ops.AccountID)
|
||||
assert.Equals(t, cert.OrderID, tc.ops.OrderID)
|
||||
|
||||
leaf := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: tc.ops.Leaf.Raw,
|
||||
})
|
||||
var intermediates []byte
|
||||
for _, cert := range tc.ops.Intermediates {
|
||||
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})...)
|
||||
}
|
||||
assert.Equals(t, cert.Leaf, leaf)
|
||||
assert.Equals(t, cert.Intermediates, intermediates)
|
||||
|
||||
assert.True(t, cert.Created.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, cert.Created.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCert(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
cert *certificate
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/not-found": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
},
|
||||
err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)),
|
||||
}
|
||||
},
|
||||
"fail/db-error": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error loading certificate: force")),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal-error": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(cert)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
cert: cert,
|
||||
id: cert.ID,
|
||||
db: &db.MockNoSQLDB{
|
||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||
assert.Equals(t, bucket, certTable)
|
||||
assert.Equals(t, key, []byte(cert.ID))
|
||||
return b, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if cert, err := getCert(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.cert.ID, cert.ID)
|
||||
assert.Equals(t, tc.cert.AccountID, cert.AccountID)
|
||||
assert.Equals(t, tc.cert.OrderID, cert.OrderID)
|
||||
assert.Equals(t, tc.cert.Created, cert.Created)
|
||||
assert.Equals(t, tc.cert.Leaf, cert.Leaf)
|
||||
assert.Equals(t, tc.cert.Intermediates, cert.Intermediates)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateToACME(t *testing.T) {
|
||||
cert, err := newcert()
|
||||
assert.FatalError(t, err)
|
||||
acmeCert, err := cert.toACME(nil, nil)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert)
|
||||
}
|
445
acme/challenge.go
Normal file
445
acme/challenge.go
Normal file
|
@ -0,0 +1,445 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/jose"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
// Challenge is a subset of the challenge type containing only those attributes
|
||||
// required for responses in the ACME protocol.
|
||||
type Challenge struct {
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Validated string `json:"validated,omitempty"`
|
||||
URL string `json:"url"`
|
||||
Error *AError `json:"error,omitempty"`
|
||||
ID string `json:"-"`
|
||||
AuthzID string `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (c *Challenge) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the Challenge ID.
|
||||
func (c *Challenge) GetID() string {
|
||||
return c.ID
|
||||
}
|
||||
|
||||
// GetAuthzID returns the parent Authz ID that owns the Challenge.
|
||||
func (c *Challenge) GetAuthzID() string {
|
||||
return c.AuthzID
|
||||
}
|
||||
|
||||
type httpGetter func(string) (*http.Response, error)
|
||||
type lookupTxt func(string) ([]string, error)
|
||||
|
||||
type validateOptions struct {
|
||||
httpGet httpGetter
|
||||
lookupTxt lookupTxt
|
||||
}
|
||||
|
||||
// challenge is the interface ACME challenege types must implement.
|
||||
type challenge interface {
|
||||
save(db nosql.DB, swap challenge) error
|
||||
validate(nosql.DB, *jose.JSONWebKey, validateOptions) (challenge, error)
|
||||
getType() string
|
||||
getError() *AError
|
||||
getValue() string
|
||||
getStatus() string
|
||||
getID() string
|
||||
getAuthzID() string
|
||||
getToken() string
|
||||
clone() *baseChallenge
|
||||
getAccountID() string
|
||||
getValidated() time.Time
|
||||
getCreated() time.Time
|
||||
toACME(nosql.DB, *directory, provisioner.Interface) (*Challenge, error)
|
||||
}
|
||||
|
||||
// ChallengeOptions is the type used to created a new Challenge.
|
||||
type ChallengeOptions struct {
|
||||
AccountID string
|
||||
AuthzID string
|
||||
Identifier Identifier
|
||||
}
|
||||
|
||||
// baseChallenge is the base Challenge type that others build from.
|
||||
type baseChallenge struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
AuthzID string `json:"authzID"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Value string `json:"value"`
|
||||
Validated time.Time `json:"validated"`
|
||||
Created time.Time `json:"created"`
|
||||
Error *AError `json:"error"`
|
||||
}
|
||||
|
||||
func newBaseChallenge(accountID, authzID string) (*baseChallenge, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error generating random id for ACME challenge")
|
||||
}
|
||||
token, err := randID()
|
||||
if err != nil {
|
||||
return nil, Wrap(err, "error generating token for ACME challenge")
|
||||
}
|
||||
|
||||
return &baseChallenge{
|
||||
ID: id,
|
||||
AccountID: accountID,
|
||||
AuthzID: authzID,
|
||||
Status: StatusPending,
|
||||
Token: token,
|
||||
Created: clock.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getID returns the id of the baseChallenge.
|
||||
func (bc *baseChallenge) getID() string {
|
||||
return bc.ID
|
||||
}
|
||||
|
||||
// getAuthzID returns the authz ID of the baseChallenge.
|
||||
func (bc *baseChallenge) getAuthzID() string {
|
||||
return bc.AuthzID
|
||||
}
|
||||
|
||||
// getAccountID returns the account id of the baseChallenge.
|
||||
func (bc *baseChallenge) getAccountID() string {
|
||||
return bc.AccountID
|
||||
}
|
||||
|
||||
// getType returns the type of the baseChallenge.
|
||||
func (bc *baseChallenge) getType() string {
|
||||
return bc.Type
|
||||
}
|
||||
|
||||
// getValue returns the type of the baseChallenge.
|
||||
func (bc *baseChallenge) getValue() string {
|
||||
return bc.Value
|
||||
}
|
||||
|
||||
// getStatus returns the status of the baseChallenge.
|
||||
func (bc *baseChallenge) getStatus() string {
|
||||
return bc.Status
|
||||
}
|
||||
|
||||
// getToken returns the token of the baseChallenge.
|
||||
func (bc *baseChallenge) getToken() string {
|
||||
return bc.Token
|
||||
}
|
||||
|
||||
// getValidated returns the validated time of the baseChallenge.
|
||||
func (bc *baseChallenge) getValidated() time.Time {
|
||||
return bc.Validated
|
||||
}
|
||||
|
||||
// getCreated returns the created time of the baseChallenge.
|
||||
func (bc *baseChallenge) getCreated() time.Time {
|
||||
return bc.Created
|
||||
}
|
||||
|
||||
// getCreated returns the created time of the baseChallenge.
|
||||
func (bc *baseChallenge) getError() *AError {
|
||||
return bc.Error
|
||||
}
|
||||
|
||||
// toACME converts the internal Challenge type into the public acmeChallenge
|
||||
// type for presentation in the ACME protocol.
|
||||
func (bc *baseChallenge) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Challenge, error) {
|
||||
ac := &Challenge{
|
||||
Type: bc.getType(),
|
||||
Status: bc.getStatus(),
|
||||
Token: bc.getToken(),
|
||||
URL: dir.getLink(ChallengeLink, URLSafeProvisionerName(p), true, bc.getID()),
|
||||
ID: bc.getID(),
|
||||
AuthzID: bc.getAuthzID(),
|
||||
}
|
||||
if !bc.Validated.IsZero() {
|
||||
ac.Validated = bc.Validated.Format(time.RFC3339)
|
||||
}
|
||||
if bc.Error != nil {
|
||||
ac.Error = bc.Error
|
||||
}
|
||||
return ac, nil
|
||||
}
|
||||
|
||||
// save writes the challenge to disk. For new challenges 'old' should be nil,
|
||||
// otherwise 'old' should be a pointer to the acme challenge as it was at the
|
||||
// start of the request. This method will fail if the value currently found
|
||||
// in the bucket/row does not match the value of 'old'.
|
||||
func (bc *baseChallenge) save(db nosql.DB, old challenge) error {
|
||||
newB, err := json.Marshal(bc)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err,
|
||||
"error marshaling new acme challenge"))
|
||||
}
|
||||
var oldB []byte
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
oldB, err = json.Marshal(old)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err,
|
||||
"error marshaling old acme challenge"))
|
||||
}
|
||||
}
|
||||
|
||||
_, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error saving acme challenge"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.New("error saving acme challenge; " +
|
||||
"acme challenge has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (bc *baseChallenge) clone() *baseChallenge {
|
||||
u := *bc
|
||||
return &u
|
||||
}
|
||||
|
||||
func (bc *baseChallenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
||||
return nil, ServerInternalErr(errors.New("unimplemented"))
|
||||
}
|
||||
|
||||
func (bc *baseChallenge) storeError(db nosql.DB, err *Error) error {
|
||||
clone := bc.clone()
|
||||
clone.Error = err.ToACME()
|
||||
return clone.save(db, bc)
|
||||
}
|
||||
|
||||
// unmarshalChallenge unmarshals a challenge type into the correct sub-type.
|
||||
func unmarshalChallenge(data []byte) (challenge, error) {
|
||||
var getType struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &getType); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type"))
|
||||
}
|
||||
|
||||
switch getType.Type {
|
||||
case "dns-01":
|
||||
var bc baseChallenge
|
||||
if err := json.Unmarshal(data, &bc); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
|
||||
"challenge type into dns01Challenge"))
|
||||
}
|
||||
return &dns01Challenge{&bc}, nil
|
||||
case "http-01":
|
||||
var bc baseChallenge
|
||||
if err := json.Unmarshal(data, &bc); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
|
||||
"challenge type into http01Challenge"))
|
||||
}
|
||||
return &http01Challenge{&bc}, nil
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unexpected challenge type %s", getType.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// http01Challenge represents an http-01 acme challenge.
|
||||
type http01Challenge struct {
|
||||
*baseChallenge
|
||||
}
|
||||
|
||||
// newHTTP01Challenge returns a new acme http-01 challenge.
|
||||
func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
|
||||
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bc.Type = "http-01"
|
||||
bc.Value = ops.Identifier.Value
|
||||
|
||||
hc := &http01Challenge{bc}
|
||||
if err := hc.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hc, nil
|
||||
}
|
||||
|
||||
// Validate attempts to validate the challenge. If the challenge has been
|
||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||
// updated.
|
||||
func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
||||
// If already valid or invalid then return without performing validation.
|
||||
if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid {
|
||||
return hc, nil
|
||||
}
|
||||
url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token)
|
||||
|
||||
resp, err := vo.httpGet(url)
|
||||
if err != nil {
|
||||
if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err,
|
||||
"error doing http GET for url %s", url))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hc, nil
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if err = hc.storeError(db,
|
||||
ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d",
|
||||
url, resp.StatusCode))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hc, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+
|
||||
"response body for url %s", url))
|
||||
}
|
||||
keyAuth := strings.Trim(string(body), "\r\n")
|
||||
|
||||
expected, err := KeyAuthorization(hc.Token, jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if keyAuth != expected {
|
||||
if err = hc.storeError(db,
|
||||
RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+
|
||||
"expected %s, but got %s", expected, keyAuth))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hc, nil
|
||||
}
|
||||
|
||||
// Update and store the challenge.
|
||||
upd := &http01Challenge{hc.baseChallenge.clone()}
|
||||
upd.Status = StatusValid
|
||||
upd.Error = nil
|
||||
upd.Validated = clock.Now()
|
||||
|
||||
if err := upd.save(db, hc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upd, nil
|
||||
}
|
||||
|
||||
// dns01Challenge represents an dns-01 acme challenge.
|
||||
type dns01Challenge struct {
|
||||
*baseChallenge
|
||||
}
|
||||
|
||||
// newDNS01Challenge returns a new acme dns-01 challenge.
|
||||
func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
|
||||
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bc.Type = "dns-01"
|
||||
bc.Value = ops.Identifier.Value
|
||||
|
||||
dc := &dns01Challenge{bc}
|
||||
if err := dc.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dc, nil
|
||||
}
|
||||
|
||||
// KeyAuthorization creates the ACME key authorization value from a token
|
||||
// and a jwk.
|
||||
func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
|
||||
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint"))
|
||||
}
|
||||
encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
|
||||
return fmt.Sprintf("%s.%s", token, encPrint), nil
|
||||
}
|
||||
|
||||
// validate attempts to validate the challenge. If the challenge has been
|
||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||
// updated.
|
||||
func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
||||
// If already valid or invalid then return without performing validation.
|
||||
if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid {
|
||||
return dc, nil
|
||||
}
|
||||
|
||||
txtRecords, err := vo.lookupTxt("_acme-challenge." + dc.Value)
|
||||
if err != nil {
|
||||
if err = dc.storeError(db,
|
||||
DNSErr(errors.Wrapf(err, "error looking up TXT "+
|
||||
"records for domain %s", dc.Value))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dc, nil
|
||||
}
|
||||
|
||||
expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h := sha256.Sum256([]byte(expectedKeyAuth))
|
||||
expected := base64.RawURLEncoding.EncodeToString(h[:])
|
||||
var found bool
|
||||
for _, r := range txtRecords {
|
||||
if r == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if err = dc.storeError(db,
|
||||
RejectedIdentifierErr(errors.Errorf("keyAuthorization "+
|
||||
"does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dc, nil
|
||||
}
|
||||
|
||||
// Update and store the challenge.
|
||||
upd := &dns01Challenge{dc.baseChallenge.clone()}
|
||||
upd.Status = StatusValid
|
||||
upd.Error = nil
|
||||
upd.Validated = time.Now().UTC()
|
||||
|
||||
if err := upd.save(db, dc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upd, nil
|
||||
}
|
||||
|
||||
// getChallenge retrieves and unmarshals an ACME challenge type from the database.
|
||||
func getChallenge(db nosql.DB, id string) (challenge, error) {
|
||||
b, err := db.Get(challengeTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id))
|
||||
}
|
||||
ch, err := unmarshalChallenge(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ch, nil
|
||||
}
|
1093
acme/challenge_test.go
Normal file
1093
acme/challenge_test.go
Normal file
File diff suppressed because it is too large
Load diff
76
acme/common.go
Normal file
76
acme/common.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
)
|
||||
|
||||
// SignAuthority is the interface implemented by a CA authority.
|
||||
type SignAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// Identifier encodes the type that an order pertains to.
|
||||
type Identifier struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
var (
|
||||
accountTable = []byte("acme-accounts")
|
||||
accountByKeyIDTable = []byte("acme-keyID-accountID-index")
|
||||
authzTable = []byte("acme-authzs")
|
||||
challengeTable = []byte("acme-challenges")
|
||||
nonceTable = []byte("nonce-table")
|
||||
orderTable = []byte("acme-orders")
|
||||
ordersByAccountIDTable = []byte("acme-account-orders-index")
|
||||
certTable = []byte("acme-certs")
|
||||
)
|
||||
|
||||
var (
|
||||
// StatusValid -- valid
|
||||
StatusValid = "valid"
|
||||
// StatusInvalid -- invalid
|
||||
StatusInvalid = "invalid"
|
||||
// StatusPending -- pending; e.g. an Order that is not ready to be finalized.
|
||||
StatusPending = "pending"
|
||||
// StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
|
||||
StatusDeactivated = "deactivated"
|
||||
// StatusReady -- ready; e.g. for an Order that is ready to be finalized.
|
||||
StatusReady = "ready"
|
||||
//statusExpired = "expired"
|
||||
//statusActive = "active"
|
||||
//statusProcessing = "processing"
|
||||
)
|
||||
|
||||
var idLen = 32
|
||||
|
||||
func randID() (val string, err error) {
|
||||
val, err = randutil.Alphanumeric(idLen)
|
||||
if err != nil {
|
||||
return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID"))
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock int
|
||||
|
||||
// Now returns the UTC time rounded to seconds.
|
||||
func (c *Clock) Now() time.Time {
|
||||
return time.Now().UTC().Round(time.Second)
|
||||
}
|
||||
|
||||
var clock = new(Clock)
|
||||
|
||||
// URLSafeProvisionerName returns a path escaped version of the ACME provisioner
|
||||
// ID that is safe to use in URL paths.
|
||||
func URLSafeProvisionerName(p provisioner.Interface) string {
|
||||
return url.PathEscape(p.GetName())
|
||||
}
|
120
acme/directory.go
Normal file
120
acme/directory.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Directory represents an ACME directory for configuring clients.
|
||||
type Directory struct {
|
||||
NewNonce string `json:"newNonce,omitempty"`
|
||||
NewAccount string `json:"newAccount,omitempty"`
|
||||
NewOrder string `json:"newOrder,omitempty"`
|
||||
NewAuthz string `json:"newAuthz,omitempty"`
|
||||
RevokeCert string `json:"revokeCert,omitempty"`
|
||||
KeyChange string `json:"keyChange,omitempty"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging for the Directory type.
|
||||
func (d *Directory) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
type directory struct {
|
||||
prefix, dns string
|
||||
}
|
||||
|
||||
// newDirectory returns a new Directory type.
|
||||
func newDirectory(dns, prefix string) *directory {
|
||||
return &directory{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Link captures the link type.
|
||||
type Link int
|
||||
|
||||
const (
|
||||
// NewNonceLink new-nonce
|
||||
NewNonceLink Link = iota
|
||||
// NewAccountLink new-account
|
||||
NewAccountLink
|
||||
// AccountLink account
|
||||
AccountLink
|
||||
// OrderLink order
|
||||
OrderLink
|
||||
// NewOrderLink new-order
|
||||
NewOrderLink
|
||||
// OrdersByAccountLink list of orders owned by account
|
||||
OrdersByAccountLink
|
||||
// FinalizeLink finalize order
|
||||
FinalizeLink
|
||||
// NewAuthzLink authz
|
||||
NewAuthzLink
|
||||
// AuthzLink new-authz
|
||||
AuthzLink
|
||||
// ChallengeLink challenge
|
||||
ChallengeLink
|
||||
// CertificateLink certificate
|
||||
CertificateLink
|
||||
// DirectoryLink directory
|
||||
DirectoryLink
|
||||
// RevokeCertLink revoke certificate
|
||||
RevokeCertLink
|
||||
// KeyChangeLink key rollover
|
||||
KeyChangeLink
|
||||
)
|
||||
|
||||
func (l Link) String() string {
|
||||
switch l {
|
||||
case NewNonceLink:
|
||||
return "new-nonce"
|
||||
case NewAccountLink:
|
||||
return "new-account"
|
||||
case AccountLink:
|
||||
return "account"
|
||||
case NewOrderLink:
|
||||
return "new-order"
|
||||
case OrderLink:
|
||||
return "order"
|
||||
case NewAuthzLink:
|
||||
return "new-authz"
|
||||
case AuthzLink:
|
||||
return "authz"
|
||||
case ChallengeLink:
|
||||
return "challenge"
|
||||
case CertificateLink:
|
||||
return "certificate"
|
||||
case DirectoryLink:
|
||||
return "directory"
|
||||
case RevokeCertLink:
|
||||
return "revoke-cert"
|
||||
case KeyChangeLink:
|
||||
return "key-change"
|
||||
default:
|
||||
return "unexpected"
|
||||
}
|
||||
}
|
||||
|
||||
// getLink returns an absolute or partial path to the given resource.
|
||||
func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs ...string) string {
|
||||
var link string
|
||||
switch typ {
|
||||
case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink:
|
||||
link = fmt.Sprintf("/%s/%s", provisionerName, typ.String())
|
||||
case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0])
|
||||
case OrdersByAccountLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0])
|
||||
case FinalizeLink:
|
||||
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0])
|
||||
}
|
||||
if abs {
|
||||
return fmt.Sprintf("https://%s/%s%s", d.dns, d.prefix, link)
|
||||
}
|
||||
return link
|
||||
}
|
60
acme/directory_test.go
Normal file
60
acme/directory_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func TestDirectoryGetLink(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
prefix := "acme"
|
||||
dir := newDirectory(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := newProv()
|
||||
provID := URLSafeProvisionerName(prov)
|
||||
|
||||
assert.Equals(t, dir.getLink(NewNonceLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", provID))
|
||||
assert.Equals(t, dir.getLink(NewNonceLink, provID, false), fmt.Sprintf("/%s/new-nonce", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(NewAccountLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID))
|
||||
assert.Equals(t, dir.getLink(NewAccountLink, provID, false), fmt.Sprintf("/%s/new-account", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(AccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provID))
|
||||
assert.Equals(t, dir.getLink(AccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(NewOrderLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", provID))
|
||||
assert.Equals(t, dir.getLink(NewOrderLink, provID, false), fmt.Sprintf("/%s/new-order", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(OrderLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234", provID))
|
||||
assert.Equals(t, dir.getLink(OrderLink, provID, false, id), fmt.Sprintf("/%s/order/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234/orders", provID))
|
||||
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234/orders", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(FinalizeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234/finalize", provID))
|
||||
assert.Equals(t, dir.getLink(FinalizeLink, provID, false, id), fmt.Sprintf("/%s/order/1234/finalize", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(NewAuthzLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-authz", provID))
|
||||
assert.Equals(t, dir.getLink(NewAuthzLink, provID, false), fmt.Sprintf("/%s/new-authz", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(AuthzLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/authz/1234", provID))
|
||||
assert.Equals(t, dir.getLink(AuthzLink, provID, false, id), fmt.Sprintf("/%s/authz/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(DirectoryLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/directory", provID))
|
||||
assert.Equals(t, dir.getLink(DirectoryLink, provID, false), fmt.Sprintf("/%s/directory", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(RevokeCertLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", provID))
|
||||
assert.Equals(t, dir.getLink(RevokeCertLink, provID, false), fmt.Sprintf("/%s/revoke-cert", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(KeyChangeLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", provID))
|
||||
assert.Equals(t, dir.getLink(KeyChangeLink, provID, false), fmt.Sprintf("/%s/key-change", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(ChallengeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/1234", provID))
|
||||
assert.Equals(t, dir.getLink(ChallengeLink, provID, false, id), fmt.Sprintf("/%s/challenge/1234", provID))
|
||||
|
||||
assert.Equals(t, dir.getLink(CertificateLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/1234", provID))
|
||||
assert.Equals(t, dir.getLink(CertificateLink, provID, false, id), fmt.Sprintf("/%s/certificate/1234", provID))
|
||||
}
|
439
acme/errors.go
Normal file
439
acme/errors.go
Normal file
|
@ -0,0 +1,439 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// AccountDoesNotExistErr returns a new acme error.
|
||||
func AccountDoesNotExistErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: accountDoesNotExistErr,
|
||||
Detail: "Account does not exist",
|
||||
Status: 404,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// AlreadyRevokedErr returns a new acme error.
|
||||
func AlreadyRevokedErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: alreadyRevokedErr,
|
||||
Detail: "Certificate already revoked",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// BadCSRErr returns a new acme error.
|
||||
func BadCSRErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: badCSRErr,
|
||||
Detail: "The CSR is unacceptable",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// BadNonceErr returns a new acme error.
|
||||
func BadNonceErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: badNonceErr,
|
||||
Detail: "Unacceptable anti-replay nonce",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// BadPublicKeyErr returns a new acme error.
|
||||
func BadPublicKeyErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: badPublicKeyErr,
|
||||
Detail: "The jws was signed by a public key the server does not support",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// BadRevocationReasonErr returns a new acme error.
|
||||
func BadRevocationReasonErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: badRevocationReasonErr,
|
||||
Detail: "The revocation reason provided is not allowed by the server",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// BadSignatureAlgorithmErr returns a new acme error.
|
||||
func BadSignatureAlgorithmErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: badSignatureAlgorithmErr,
|
||||
Detail: "The JWS was signed with an algorithm the server does not support",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// CaaErr returns a new acme error.
|
||||
func CaaErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: caaErr,
|
||||
Detail: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// CompoundErr returns a new acme error.
|
||||
func CompoundErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: compoundErr,
|
||||
Detail: "Specific error conditions are indicated in the “subproblems” array",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectionErr returns a new acme error.
|
||||
func ConnectionErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: connectionErr,
|
||||
Detail: "The server could not connect to validation target",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// DNSErr returns a new acme error.
|
||||
func DNSErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: dnsErr,
|
||||
Detail: "There was a problem with a DNS query during identifier validation",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// ExternalAccountRequiredErr returns a new acme error.
|
||||
func ExternalAccountRequiredErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: externalAccountRequiredErr,
|
||||
Detail: "The request must include a value for the \"externalAccountBinding\" field",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// IncorrectResponseErr returns a new acme error.
|
||||
func IncorrectResponseErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: incorrectResponseErr,
|
||||
Detail: "Response received didn't match the challenge's requirements",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidContactErr returns a new acme error.
|
||||
func InvalidContactErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: invalidContactErr,
|
||||
Detail: "A contact URL for an account was invalid",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// MalformedErr returns a new acme error.
|
||||
func MalformedErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: malformedErr,
|
||||
Detail: "The request message was malformed",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// OrderNotReadyErr returns a new acme error.
|
||||
func OrderNotReadyErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: orderNotReadyErr,
|
||||
Detail: "The request attempted to finalize an order that is not ready to be finalized",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitedErr returns a new acme error.
|
||||
func RateLimitedErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: rateLimitedErr,
|
||||
Detail: "The request exceeds a rate limit",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// RejectedIdentifierErr returns a new acme error.
|
||||
func RejectedIdentifierErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: rejectedIdentifierErr,
|
||||
Detail: "The server will not issue certificates for the identifier",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// ServerInternalErr returns a new acme error.
|
||||
func ServerInternalErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: serverInternalErr,
|
||||
Detail: "The server experienced an internal error",
|
||||
Status: 500,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// TLSErr returns a new acme error.
|
||||
func TLSErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: tlsErr,
|
||||
Detail: "The server received a TLS error during validation",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// UnauthorizedErr returns a new acme error.
|
||||
func UnauthorizedErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: unauthorizedErr,
|
||||
Detail: "The client lacks sufficient authorization",
|
||||
Status: 401,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// UnsupportedContactErr returns a new acme error.
|
||||
func UnsupportedContactErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: unsupportedContactErr,
|
||||
Detail: "A contact URL for an account used an unsupported protocol scheme",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// UnsupportedIdentifierErr returns a new acme error.
|
||||
func UnsupportedIdentifierErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: unsupportedIdentifierErr,
|
||||
Detail: "An identifier is of an unsupported type",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// UserActionRequiredErr returns a new acme error.
|
||||
func UserActionRequiredErr(err error) *Error {
|
||||
return &Error{
|
||||
Type: userActionRequiredErr,
|
||||
Detail: "Visit the “instance” URL and take actions specified there",
|
||||
Status: 400,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// ProbType is the type of the ACME problem.
|
||||
type ProbType int
|
||||
|
||||
const (
|
||||
// The request specified an account that does not exist
|
||||
accountDoesNotExistErr ProbType = iota
|
||||
// The request specified a certificate to be revoked that has already been revoked
|
||||
alreadyRevokedErr
|
||||
// The CSR is unacceptable (e.g., due to a short key)
|
||||
badCSRErr
|
||||
// The client sent an unacceptable anti-replay nonce
|
||||
badNonceErr
|
||||
// The JWS was signed by a public key the server does not support
|
||||
badPublicKeyErr
|
||||
// The revocation reason provided is not allowed by the server
|
||||
badRevocationReasonErr
|
||||
// The JWS was signed with an algorithm the server does not support
|
||||
badSignatureAlgorithmErr
|
||||
// Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate
|
||||
caaErr
|
||||
// Specific error conditions are indicated in the “subproblems” array.
|
||||
compoundErr
|
||||
// The server could not connect to validation target
|
||||
connectionErr
|
||||
// There was a problem with a DNS query during identifier validation
|
||||
dnsErr
|
||||
// The request must include a value for the “externalAccountBinding” field
|
||||
externalAccountRequiredErr
|
||||
// Response received didn’t match the challenge’s requirements
|
||||
incorrectResponseErr
|
||||
// A contact URL for an account was invalid
|
||||
invalidContactErr
|
||||
// The request message was malformed
|
||||
malformedErr
|
||||
// The request attempted to finalize an order that is not ready to be finalized
|
||||
orderNotReadyErr
|
||||
// The request exceeds a rate limit
|
||||
rateLimitedErr
|
||||
// The server will not issue certificates for the identifier
|
||||
rejectedIdentifierErr
|
||||
// The server experienced an internal error
|
||||
serverInternalErr
|
||||
// The server received a TLS error during validation
|
||||
tlsErr
|
||||
// The client lacks sufficient authorization
|
||||
unauthorizedErr
|
||||
// A contact URL for an account used an unsupported protocol scheme
|
||||
unsupportedContactErr
|
||||
// An identifier is of an unsupported type
|
||||
unsupportedIdentifierErr
|
||||
// Visit the “instance” URL and take actions specified there
|
||||
userActionRequiredErr
|
||||
)
|
||||
|
||||
// String returns the string representation of the acme problem type,
|
||||
// fulfilling the Stringer interface.
|
||||
func (ap ProbType) String() string {
|
||||
switch ap {
|
||||
case accountDoesNotExistErr:
|
||||
return "accountDoesNotExist"
|
||||
case alreadyRevokedErr:
|
||||
return "alreadyRevoked"
|
||||
case badCSRErr:
|
||||
return "badCSR"
|
||||
case badNonceErr:
|
||||
return "badNonce"
|
||||
case badPublicKeyErr:
|
||||
return "badPublicKey"
|
||||
case badRevocationReasonErr:
|
||||
return "badRevocationReason"
|
||||
case badSignatureAlgorithmErr:
|
||||
return "badSignatureAlgorithm"
|
||||
case caaErr:
|
||||
return "caa"
|
||||
case compoundErr:
|
||||
return "compound"
|
||||
case connectionErr:
|
||||
return "connection"
|
||||
case dnsErr:
|
||||
return "dns"
|
||||
case externalAccountRequiredErr:
|
||||
return "externalAccountRequired"
|
||||
case incorrectResponseErr:
|
||||
return "incorrectResponse"
|
||||
case invalidContactErr:
|
||||
return "invalidContact"
|
||||
case malformedErr:
|
||||
return "malformed"
|
||||
case orderNotReadyErr:
|
||||
return "orderNotReady"
|
||||
case rateLimitedErr:
|
||||
return "rateLimited"
|
||||
case rejectedIdentifierErr:
|
||||
return "rejectedIdentifier"
|
||||
case serverInternalErr:
|
||||
return "serverInternal"
|
||||
case tlsErr:
|
||||
return "tls"
|
||||
case unauthorizedErr:
|
||||
return "unauthorized"
|
||||
case unsupportedContactErr:
|
||||
return "unsupportedContact"
|
||||
case unsupportedIdentifierErr:
|
||||
return "unsupportedIdentifier"
|
||||
case userActionRequiredErr:
|
||||
return "userActionRequired"
|
||||
default:
|
||||
return "unsupported type"
|
||||
}
|
||||
}
|
||||
|
||||
// Error is an ACME error type complete with problem document.
|
||||
type Error struct {
|
||||
Type ProbType
|
||||
Detail string
|
||||
Err error
|
||||
Status int
|
||||
Sub []*Error
|
||||
Identifier *Identifier
|
||||
}
|
||||
|
||||
// Wrap attempts to wrap the internal error.
|
||||
func Wrap(err error, wrap string) *Error {
|
||||
switch e := err.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case *Error:
|
||||
if e.Err == nil {
|
||||
e.Err = errors.New(wrap + "; " + e.Detail)
|
||||
} else {
|
||||
e.Err = errors.Wrap(e.Err, wrap)
|
||||
}
|
||||
return e
|
||||
default:
|
||||
return ServerInternalErr(errors.Wrap(err, wrap))
|
||||
}
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *Error) Error() string {
|
||||
if e.Err == nil {
|
||||
return e.Detail
|
||||
}
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
// Cause returns the internal error and implements the Causer interface.
|
||||
func (e *Error) Cause() error {
|
||||
if e.Err == nil {
|
||||
return errors.New(e.Detail)
|
||||
}
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// ToACME returns an acme representation of the problem type.
|
||||
func (e *Error) ToACME() *AError {
|
||||
ae := &AError{
|
||||
Type: "urn:ietf:params:acme:error:" + e.Type.String(),
|
||||
Detail: e.Error(),
|
||||
Status: e.Status,
|
||||
}
|
||||
if e.Identifier != nil {
|
||||
ae.Identifier = *e.Identifier
|
||||
}
|
||||
for _, p := range e.Sub {
|
||||
ae.Subproblems = append(ae.Subproblems, p.ToACME())
|
||||
}
|
||||
return ae
|
||||
}
|
||||
|
||||
// StatusCode returns the status code and implements the StatusCode interface.
|
||||
func (e *Error) StatusCode() int {
|
||||
return e.Status
|
||||
}
|
||||
|
||||
// AError is the error type as seen in acme request/responses.
|
||||
type AError struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Identifier interface{} `json:"identifier,omitempty"`
|
||||
Subproblems []interface{} `json:"subproblems,omitempty"`
|
||||
Status int `json:"-"`
|
||||
}
|
||||
|
||||
// Error allows AError to implement the error interface.
|
||||
func (ae *AError) Error() string {
|
||||
return ae.Detail
|
||||
}
|
||||
|
||||
// StatusCode returns the status code and implements the StatusCode interface.
|
||||
func (ae *AError) StatusCode() int {
|
||||
return ae.Status
|
||||
}
|
73
acme/nonce.go
Normal file
73
acme/nonce.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
// nonce contains nonce metadata used in the ACME protocol.
|
||||
type nonce struct {
|
||||
ID string
|
||||
Created time.Time
|
||||
}
|
||||
|
||||
// newNonce creates, stores, and returns an ACME replay-nonce.
|
||||
func newNonce(db nosql.DB) (*nonce, error) {
|
||||
_id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
|
||||
n := &nonce{
|
||||
ID: id,
|
||||
Created: clock.Now(),
|
||||
}
|
||||
b, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce"))
|
||||
case !swapped:
|
||||
return nil, ServerInternalErr(errors.New("error storing nonce; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
// useNonce verifies that the nonce is valid (by checking if it exists),
|
||||
// and if so, consumes the nonce resource by deleting it from the database.
|
||||
func useNonce(db nosql.DB, nonce string) error {
|
||||
err := db.Update(&database.Tx{
|
||||
Operations: []*database.TxEntry{
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Get,
|
||||
},
|
||||
{
|
||||
Bucket: nonceTable,
|
||||
Key: []byte(nonce),
|
||||
Cmd: database.Delete,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
return BadNonceErr(nil)
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
163
acme/nonce_test.go
Normal file
163
acme/nonce_test.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
func TestNewNonce(t *testing.T) {
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err *Error
|
||||
id *string
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing nonce: force")),
|
||||
}
|
||||
},
|
||||
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
return nil, false, nil
|
||||
},
|
||||
},
|
||||
err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
var _id string
|
||||
id := &_id
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
assert.Equals(t, bucket, nonceTable)
|
||||
assert.Equals(t, old, nil)
|
||||
*id = string(key)
|
||||
return nil, true, nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if n, err := newNonce(tc.db); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, n.ID, *tc.id)
|
||||
|
||||
assert.True(t, n.Created.Before(time.Now().Add(time.Minute)))
|
||||
assert.True(t, n.Created.After(time.Now().Add(-time.Minute)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseNonce(t *testing.T) {
|
||||
type test struct {
|
||||
id string
|
||||
db nosql.DB
|
||||
err *Error
|
||||
}
|
||||
tests := map[string]func(t *testing.T) test{
|
||||
"fail/update-not-found": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return database.ErrNotFound
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
err: BadNonceErr(nil),
|
||||
}
|
||||
},
|
||||
"fail/update-error": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
return errors.New("force")
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
id := "foo"
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MUpdate: func(tx *database.Tx) error {
|
||||
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||
|
||||
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
id: id,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := useNonce(tc.db, tc.id); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
ae, ok := err.(*Error)
|
||||
assert.True(t, ok)
|
||||
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
342
acme/order.go
Normal file
342
acme/order.go
Normal file
|
@ -0,0 +1,342 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
var defaultOrderExpiry = time.Hour * 24
|
||||
|
||||
// Order contains order metadata for the ACME protocol order type.
|
||||
type Order struct {
|
||||
Status string `json:"status"`
|
||||
Expires string `json:"expires,omitempty"`
|
||||
Identifiers []Identifier `json:"identifiers"`
|
||||
NotBefore string `json:"notBefore,omitempty"`
|
||||
NotAfter string `json:"notAfter,omitempty"`
|
||||
Error interface{} `json:"error,omitempty"`
|
||||
Authorizations []string `json:"authorizations"`
|
||||
Finalize string `json:"finalize"`
|
||||
Certificate string `json:"certificate,omitempty"`
|
||||
ID string `json:"-"`
|
||||
}
|
||||
|
||||
// ToLog enables response logging.
|
||||
func (o *Order) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging"))
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GetID returns the Order ID.
|
||||
func (o *Order) GetID() string {
|
||||
return o.ID
|
||||
}
|
||||
|
||||
// OrderOptions options with which to create a new Order.
|
||||
type OrderOptions struct {
|
||||
AccountID string `json:"accID"`
|
||||
Identifiers []Identifier `json:"identifiers"`
|
||||
NotBefore time.Time `json:"notBefore"`
|
||||
NotAfter time.Time `json:"notAfter"`
|
||||
}
|
||||
|
||||
type order struct {
|
||||
ID string `json:"id"`
|
||||
AccountID string `json:"accountID"`
|
||||
Created time.Time `json:"created"`
|
||||
Expires time.Time `json:"expires,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Identifiers []Identifier `json:"identifiers"`
|
||||
NotBefore time.Time `json:"notBefore,omitempty"`
|
||||
NotAfter time.Time `json:"notAfter,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
Authorizations []string `json:"authorizations"`
|
||||
Certificate string `json:"certificate,omitempty"`
|
||||
}
|
||||
|
||||
// newOrder returns a new Order type.
|
||||
func newOrder(db nosql.DB, ops OrderOptions) (*order, error) {
|
||||
id, err := randID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authzs := make([]string, len(ops.Identifiers))
|
||||
for i, identifier := range ops.Identifiers {
|
||||
authz, err := newAuthz(db, ops.AccountID, identifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authzs[i] = authz.getID()
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
o := &order{
|
||||
ID: id,
|
||||
AccountID: ops.AccountID,
|
||||
Created: now,
|
||||
Status: StatusPending,
|
||||
Expires: now.Add(defaultOrderExpiry),
|
||||
Identifiers: ops.Identifiers,
|
||||
NotBefore: ops.NotBefore,
|
||||
NotAfter: ops.NotAfter,
|
||||
Authorizations: authzs,
|
||||
}
|
||||
if err := o.save(db, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update the "order IDs by account ID" index //
|
||||
oids, err := getOrderIDsByAccount(db, ops.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newOids := append(oids, o.ID)
|
||||
if err = orderIDs(newOids).save(db, oids, o.AccountID); err != nil {
|
||||
db.Del(orderTable, []byte(o.ID))
|
||||
return nil, err
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
type orderIDs []string
|
||||
|
||||
func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error {
|
||||
var (
|
||||
err error
|
||||
oldb []byte
|
||||
)
|
||||
if len(old) == 0 {
|
||||
oldb = nil
|
||||
} else {
|
||||
oldb, err = json.Marshal(old)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice"))
|
||||
}
|
||||
}
|
||||
newb, err := json.Marshal(oids)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice"))
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.Errorf("error storing order IDs "+
|
||||
"for account %s; order IDs changed since last read", accID))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (o *order) save(db nosql.DB, old *order) error {
|
||||
var (
|
||||
err error
|
||||
oldB []byte
|
||||
)
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
if oldB, err = json.Marshal(old); err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
||||
}
|
||||
}
|
||||
|
||||
newB, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order"))
|
||||
}
|
||||
|
||||
_, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return ServerInternalErr(errors.Wrap(err, "error storing order"))
|
||||
case !swapped:
|
||||
return ServerInternalErr(errors.New("error storing order; " +
|
||||
"value has changed since last read"))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// updateStatus updates order status if necessary.
|
||||
func (o *order) updateStatus(db nosql.DB) (*order, error) {
|
||||
_newOrder := *o
|
||||
newOrder := &_newOrder
|
||||
|
||||
now := time.Now().UTC()
|
||||
switch o.Status {
|
||||
case StatusInvalid:
|
||||
return o, nil
|
||||
case StatusValid:
|
||||
return o, nil
|
||||
case StatusReady:
|
||||
// check expiry
|
||||
if now.After(o.Expires) {
|
||||
newOrder.Status = StatusInvalid
|
||||
newOrder.Error = MalformedErr(errors.New("order has expired"))
|
||||
break
|
||||
}
|
||||
return o, nil
|
||||
case StatusPending:
|
||||
// check expiry
|
||||
if now.After(o.Expires) {
|
||||
newOrder.Status = StatusInvalid
|
||||
newOrder.Error = MalformedErr(errors.New("order has expired"))
|
||||
break
|
||||
}
|
||||
|
||||
var count = map[string]int{
|
||||
StatusValid: 0,
|
||||
StatusInvalid: 0,
|
||||
StatusPending: 0,
|
||||
}
|
||||
for _, azID := range o.Authorizations {
|
||||
authz, err := getAuthz(db, azID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authz, err = authz.updateStatus(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
st := authz.getStatus()
|
||||
count[st]++
|
||||
}
|
||||
switch {
|
||||
case count[StatusInvalid] > 0:
|
||||
newOrder.Status = StatusInvalid
|
||||
case count[StatusPending] > 0:
|
||||
break
|
||||
case count[StatusValid] == len(o.Authorizations):
|
||||
newOrder.Status = StatusReady
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.New("unexpected authz status"))
|
||||
}
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status))
|
||||
}
|
||||
|
||||
if err := newOrder.save(db, o); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newOrder, nil
|
||||
}
|
||||
|
||||
// finalize signs a certificate if the necessary conditions for Order completion
|
||||
// have been met.
|
||||
func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p provisioner.Interface) (*order, error) {
|
||||
var err error
|
||||
if o, err = o.updateStatus(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch o.Status {
|
||||
case StatusInvalid:
|
||||
return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID))
|
||||
case StatusValid:
|
||||
return o, nil
|
||||
case StatusPending:
|
||||
return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID))
|
||||
case StatusReady:
|
||||
break
|
||||
default:
|
||||
return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID))
|
||||
}
|
||||
|
||||
// Validate identifier names against CSR alternative names //
|
||||
csrNames := make(map[string]int)
|
||||
for _, n := range csr.DNSNames {
|
||||
csrNames[n] = 1
|
||||
}
|
||||
orderNames := make(map[string]int)
|
||||
for _, n := range o.Identifiers {
|
||||
orderNames[n.Value] = 1
|
||||
}
|
||||
if !reflect.DeepEqual(csrNames, orderNames) {
|
||||
return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly"))
|
||||
}
|
||||
|
||||
// Get authorizations from the ACME provisioner.
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
signOps, err := p.AuthorizeSign(ctx, "")
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner"))
|
||||
}
|
||||
|
||||
// Create and store a new certificate.
|
||||
leaf, inter, err := auth.Sign(csr, provisioner.Options{
|
||||
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
|
||||
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
|
||||
}, signOps...)
|
||||
if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID))
|
||||
}
|
||||
|
||||
cert, err := newCert(db, CertOptions{
|
||||
AccountID: o.AccountID,
|
||||
OrderID: o.ID,
|
||||
Leaf: leaf,
|
||||
Intermediates: []*x509.Certificate{inter},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_newOrder := *o
|
||||
newOrder := &_newOrder
|
||||
newOrder.Certificate = cert.ID
|
||||
newOrder.Status = StatusValid
|
||||
if err := newOrder.save(db, o); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newOrder, nil
|
||||
}
|
||||
|
||||
// getOrder retrieves and unmarshals an ACME Order type from the database.
|
||||
func getOrder(db nosql.DB, id string) (*order, error) {
|
||||
b, err := db.Get(orderTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id))
|
||||
} else if err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id))
|
||||
}
|
||||
var o order
|
||||
if err := json.Unmarshal(b, &o); err != nil {
|
||||
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order"))
|
||||
}
|
||||
return &o, nil
|
||||
}
|
||||
|
||||
// toACME converts the internal Order type into the public acmeOrder type for
|
||||
// presentation in the ACME protocol.
|
||||
func (o *order) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Order, error) {
|
||||
azs := make([]string, len(o.Authorizations))
|
||||
for i, aid := range o.Authorizations {
|
||||
azs[i] = dir.getLink(AuthzLink, URLSafeProvisionerName(p), true, aid)
|
||||
}
|
||||
ao := &Order{
|
||||
Status: o.Status,
|
||||
Expires: o.Expires.Format(time.RFC3339),
|
||||
Identifiers: o.Identifiers,
|
||||
NotBefore: o.NotBefore.Format(time.RFC3339),
|
||||
NotAfter: o.NotAfter.Format(time.RFC3339),
|
||||
Authorizations: azs,
|
||||
Finalize: dir.getLink(FinalizeLink, URLSafeProvisionerName(p), true, o.ID),
|
||||
ID: o.ID,
|
||||
}
|
||||
|
||||
if o.Certificate != "" {
|
||||
ao.Certificate = dir.getLink(CertificateLink, URLSafeProvisionerName(p), true, o.Certificate)
|
||||
}
|
||||
return ao, nil
|
||||
}
|
1129
acme/order_test.go
Normal file
1129
acme/order_test.go
Normal file
File diff suppressed because it is too large
Load diff
24
api/api.go
24
api/api.go
|
@ -28,8 +28,7 @@ import (
|
|||
// Authority is the interface implemented by a CA authority.
|
||||
type Authority interface {
|
||||
SSHAuthority
|
||||
// NOTE: Authorize will be deprecated in future releases. Please use the
|
||||
// context specific Authorize[Sign|Revoke|etc.] methods.
|
||||
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||
GetTLSOptions() *tlsutil.TLSOptions
|
||||
|
@ -37,6 +36,7 @@ type Authority interface {
|
|||
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
|
||||
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
|
||||
Revoke(*authority.RevokeOptions) error
|
||||
GetEncryptedKey(kid string) (string, error)
|
||||
|
@ -308,13 +308,12 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
logCertificate(w, cert)
|
||||
JSON(w, &SignResponse{
|
||||
JSONStatus(w, &SignResponse{
|
||||
ServerPEM: Certificate{cert},
|
||||
CaPEM: Certificate{root},
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
})
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
|
@ -331,13 +330,12 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
logCertificate(w, cert)
|
||||
JSON(w, &SignResponse{
|
||||
JSONStatus(w, &SignResponse{
|
||||
ServerPEM: Certificate{cert},
|
||||
CaPEM: Certificate{root},
|
||||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
})
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// Provisioners returns the list of provisioners configured in the authority.
|
||||
|
@ -383,10 +381,9 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
|||
certs[i] = Certificate{roots[i]}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
JSON(w, &RootsResponse{
|
||||
JSONStatus(w, &RootsResponse{
|
||||
Certificates: certs,
|
||||
})
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
// Federation returns all the public certificates in the federation.
|
||||
|
@ -402,10 +399,9 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
|||
certs[i] = Certificate{federated[i]}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
JSON(w, &FederationResponse{
|
||||
JSONStatus(w, &FederationResponse{
|
||||
Certificates: certs,
|
||||
})
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
var oidStepProvisioner = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}
|
||||
|
|
|
@ -506,6 +506,7 @@ type mockAuthority struct {
|
|||
signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||
loadProvisionerByID func(provID string) (provisioner.Interface, error)
|
||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||
revoke func(*authority.RevokeOptions) error
|
||||
getEncryptedKey func(kid string) (string, error)
|
||||
|
@ -581,6 +582,13 @@ func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (pr
|
|||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByID != nil {
|
||||
return m.loadProvisionerByID(provID)
|
||||
}
|
||||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Revoke(opts *authority.RevokeOptions) error {
|
||||
if m.revoke != nil {
|
||||
return m.revoke(opts)
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
|
@ -109,7 +110,13 @@ func NotFound(err error) error {
|
|||
|
||||
// WriteError writes to w a JSON representation of the given error.
|
||||
func WriteError(w http.ResponseWriter, err error) {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
w.Header().Set("Content-Type", "application/problem+json")
|
||||
err = k.ToACME()
|
||||
default:
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
cause := errors.Cause(err)
|
||||
if sc, ok := err.(StatusCoder); ok {
|
||||
w.WriteHeader(sc.StatusCode())
|
||||
|
|
|
@ -87,8 +87,6 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
logRevoke(w, opts)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
JSON(w, &RevokeResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
|
|
|
@ -74,7 +74,6 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input string
|
||||
auth Authority
|
||||
tls *tls.ConnectionState
|
||||
err error
|
||||
statusCode int
|
||||
expected []byte
|
||||
}
|
||||
|
|
|
@ -260,13 +260,6 @@ func Test_caHandler_SignSSH(t *testing.T) {
|
|||
})
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
}
|
||||
type args struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
req []byte
|
||||
|
|
33
api/utils.go
33
api/utils.go
|
@ -10,6 +10,11 @@ import (
|
|||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// EnableLogger is an interface that enables response logging for an object.
|
||||
type EnableLogger interface {
|
||||
ToLog() (interface{}, error)
|
||||
}
|
||||
|
||||
// LogError adds to the response writer the given error if it implements
|
||||
// logging.ResponseLogger. If it does not implement it, then writes the error
|
||||
// using the log package.
|
||||
|
@ -23,12 +28,40 @@ func LogError(rw http.ResponseWriter, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
// LogEnabledResponse log the response object if it implements the EnableLogger
|
||||
// interface.
|
||||
func LogEnabledResponse(rw http.ResponseWriter, v interface{}) {
|
||||
if el, ok := v.(EnableLogger); ok {
|
||||
out, err := el.ToLog()
|
||||
if err != nil {
|
||||
LogError(rw, err)
|
||||
return
|
||||
}
|
||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"response": out,
|
||||
})
|
||||
} else {
|
||||
log.Println(out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JSON writes the passed value into the http.ResponseWriter.
|
||||
func JSON(w http.ResponseWriter, v interface{}) {
|
||||
JSONStatus(w, v, http.StatusOK)
|
||||
}
|
||||
|
||||
// JSONStatus writes the given value into the http.ResponseWriter and the
|
||||
// given status is written as the status code of the response.
|
||||
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||
LogError(w, err)
|
||||
return
|
||||
}
|
||||
LogEnabledResponse(w, v)
|
||||
}
|
||||
|
||||
// ReadJSON reads JSON from the request body and stores it in the value
|
||||
|
|
|
@ -15,7 +15,9 @@ import (
|
|||
"github.com/smallstep/cli/crypto/x509util"
|
||||
)
|
||||
|
||||
const legacyAuthority = "step-certificate-authority"
|
||||
const (
|
||||
legacyAuthority = "step-certificate-authority"
|
||||
)
|
||||
|
||||
// Authority implements the Certificate Authority internal interface.
|
||||
type Authority struct {
|
||||
|
@ -24,7 +26,6 @@ type Authority struct {
|
|||
intermediateIdentity *x509util.Identity
|
||||
sshCAUserCertSignKey crypto.Signer
|
||||
sshCAHostCertSignKey crypto.Signer
|
||||
validateOnce bool
|
||||
certificates *sync.Map
|
||||
startTime time.Time
|
||||
provisioners *provisioner.Collection
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
type MockAuthDB struct {
|
||||
err error
|
||||
ret1, ret2 interface{}
|
||||
ret1 interface{}
|
||||
init func(*db.Config) (db.AuthDB, error)
|
||||
isRevoked func(string) (bool, error)
|
||||
revoke func(rci *db.RevokedCertificateInfo) error
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
|
@ -33,6 +35,12 @@ func (e *apiError) Error() string {
|
|||
return ret
|
||||
}
|
||||
|
||||
// ErrorResponse represents an error in JSON format.
|
||||
type ErrorResponse struct {
|
||||
Status int `json:"status"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// StatusCode returns an http status code indicating the type and severity of
|
||||
// the error.
|
||||
func (e *apiError) StatusCode() int {
|
||||
|
@ -41,3 +49,19 @@ func (e *apiError) StatusCode() int {
|
|||
}
|
||||
return e.code
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaller interface for the Error struct.
|
||||
func (e *apiError) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(&ErrorResponse{Status: e.code, Message: http.StatusText(e.code)})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
|
||||
func (e *apiError) UnmarshalJSON(data []byte) error {
|
||||
var er ErrorResponse
|
||||
if err := json.Unmarshal(data, &er); err != nil {
|
||||
return err
|
||||
}
|
||||
e.code = er.Status
|
||||
e.err = fmt.Errorf(er.Message)
|
||||
return nil
|
||||
}
|
||||
|
|
85
authority/provisioner/acme.go
Normal file
85
authority/provisioner/acme.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ACME is the acme provisioner type, an entity that can authorize the ACME
|
||||
// provisioning flow.
|
||||
type ACME struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
claimer *Claimer
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
func (p ACME) GetID() string {
|
||||
return "acme/" + p.Name
|
||||
}
|
||||
|
||||
// GetTokenID returns the identifier of the token.
|
||||
func (p *ACME) GetTokenID(ott string) (string, error) {
|
||||
return "", errors.New("acme provisioner does not implement GetTokenID")
|
||||
}
|
||||
|
||||
// GetName returns the name of the provisioner.
|
||||
func (p *ACME) GetName() string {
|
||||
return p.Name
|
||||
}
|
||||
|
||||
// GetType returns the type of provisioner.
|
||||
func (p *ACME) GetType() Type {
|
||||
return TypeACME
|
||||
}
|
||||
|
||||
// GetEncryptedKey returns the base provisioner encrypted key if it's defined.
|
||||
func (p *ACME) GetEncryptedKey() (string, string, bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Init initializes and validates the fields of a JWK type.
|
||||
func (p *ACME) Init(config Config) (err error) {
|
||||
switch {
|
||||
case p.Type == "":
|
||||
return errors.New("provisioner type cannot be empty")
|
||||
case p.Name == "":
|
||||
return errors.New("provisioner name cannot be empty")
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AuthorizeRevoke is not implemented yet for the ACME provisioner.
|
||||
func (p *ACME) AuthorizeRevoke(token string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (p *ACME) AuthorizeSign(ctx context.Context, _ string) ([]SignOption, error) {
|
||||
if m := MethodFromContext(ctx); m != SignMethod {
|
||||
return nil, errors.Errorf("unexpected method type %d in context", m)
|
||||
}
|
||||
return []SignOption{
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
newProvisionerExtensionOption(TypeACME, p.Name, ""),
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
defaultPublicKeyValidator{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeRenewal is not implemented for the ACME provisioner.
|
||||
func (p *ACME) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||
}
|
||||
return nil
|
||||
}
|
184
authority/provisioner/acme_test.go
Normal file
184
authority/provisioner/acme_test.go
Normal file
|
@ -0,0 +1,184 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
func TestACME_Getters(t *testing.T) {
|
||||
p, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
id := "acme/" + p.Name
|
||||
if got := p.GetID(); got != id {
|
||||
t.Errorf("ACME.GetID() = %v, want %v", got, id)
|
||||
}
|
||||
if got := p.GetName(); got != p.Name {
|
||||
t.Errorf("ACME.GetName() = %v, want %v", got, p.Name)
|
||||
}
|
||||
if got := p.GetType(); got != TypeACME {
|
||||
t.Errorf("ACME.GetType() = %v, want %v", got, TypeACME)
|
||||
}
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
if kid != "" || key != "" || ok == true {
|
||||
t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, "", "", false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestACME_Init(t *testing.T) {
|
||||
type ProvisionerValidateTest struct {
|
||||
p *ACME
|
||||
err error
|
||||
}
|
||||
tests := map[string]func(*testing.T) ProvisionerValidateTest{
|
||||
"fail-empty": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{
|
||||
Type: "ACME",
|
||||
},
|
||||
err: errors.New("provisioner name cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{Name: "foo"},
|
||||
err: errors.New("provisioner type cannot be empty"),
|
||||
}
|
||||
},
|
||||
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{Name: "foo", Type: "bar", Claims: &Claims{DefaultTLSDur: &Duration{0}}},
|
||||
err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &ACME{Name: "foo", Type: "bar"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
config := Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}
|
||||
for name, get := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := get(t)
|
||||
err := tc.p.Init(config)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestACME_AuthorizeRevoke(t *testing.T) {
|
||||
p, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
assert.Nil(t, p.AuthorizeRevoke(""))
|
||||
}
|
||||
|
||||
func TestACME_AuthorizeRenewal(t *testing.T) {
|
||||
p1, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *ACME
|
||||
args args
|
||||
err error
|
||||
}{
|
||||
{"ok", p1, args{nil}, nil},
|
||||
{"fail", p2, args{nil}, errors.Errorf("renew is disabled for provisioner %s", p2.GetID())},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.prov.AuthorizeRenewal(tt.args.cert); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestACME_AuthorizeSign(t *testing.T) {
|
||||
p1, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prov *ACME
|
||||
method Method
|
||||
err error
|
||||
}{
|
||||
{"fail/method", p1, SignSSHMethod, errors.New("unexpected method type 1 in context")},
|
||||
{"ok", p1, SignMethod, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContextWithMethod(context.Background(), tt.method)
|
||||
if got, err := tt.prov.AuthorizeSign(ctx, ""); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
} else {
|
||||
if assert.NotNil(t, got) {
|
||||
assert.Len(t, 4, got)
|
||||
|
||||
_pdd := got[0]
|
||||
pdd, ok := _pdd.(profileDefaultDuration)
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, pdd, profileDefaultDuration(86400000000000))
|
||||
|
||||
_peo := got[1]
|
||||
peo, ok := _peo.(*provisionerExtensionOption)
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, peo.Type, 6)
|
||||
assert.Equals(t, peo.Name, "test@acme-provisioner.com")
|
||||
assert.Equals(t, peo.CredentialID, "")
|
||||
assert.Equals(t, peo.KeyValuePairs, nil)
|
||||
|
||||
_vv := got[2]
|
||||
vv, ok := _vv.(*validityValidator)
|
||||
assert.True(t, ok)
|
||||
assert.Equals(t, vv.min, time.Duration(300000000000))
|
||||
assert.Equals(t, vv.max, time.Duration(86400000000000))
|
||||
|
||||
_dpkv := got[3]
|
||||
_, ok = _dpkv.(defaultPublicKeyValidator)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -470,6 +470,8 @@ func (p *AWS) authorizeSSHSign(claims *awsPayload) ([]SignOption, error) {
|
|||
&sshDefaultExtensionModifier{},
|
||||
// checks the validity bounds, and set the validity if has not been set
|
||||
&sshCertificateValidityModifier{p.claimer},
|
||||
// validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
), nil
|
||||
|
|
|
@ -377,6 +377,12 @@ func TestAWS_AuthorizeSign_SSH(t *testing.T) {
|
|||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
pub := key.Public().Key
|
||||
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.FatalError(t, err)
|
||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedHostOptions := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
|
||||
|
@ -394,6 +400,7 @@ func TestAWS_AuthorizeSign_SSH(t *testing.T) {
|
|||
type args struct {
|
||||
token string
|
||||
sshOpts SSHOptions
|
||||
key interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -403,15 +410,17 @@ func TestAWS_AuthorizeSign_SSH(t *testing.T) {
|
|||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false},
|
||||
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}}, expectedHostOptionsIP, false, false},
|
||||
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptionsHostname, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}}, expectedHostOptions, false, false},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}}, nil, false, true},
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-principal-ip", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, false, false},
|
||||
{"ok-principal-hostname", p1, args{t1, SSHOptions{Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -424,7 +433,7 @@ func TestAWS_AuthorizeSign_SSH(t *testing.T) {
|
|||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
|
|
|
@ -327,6 +327,8 @@ func (p *Azure) authorizeSSHSign(claims azurePayload, name string) ([]SignOption
|
|||
&sshDefaultExtensionModifier{},
|
||||
// checks the validity bounds, and set the validity if has not been set
|
||||
&sshCertificateValidityModifier{p.claimer},
|
||||
// validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
), nil
|
||||
|
|
|
@ -3,6 +3,8 @@ package provisioner
|
|||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
|
@ -325,6 +327,12 @@ func TestAzure_AuthorizeSign_SSH(t *testing.T) {
|
|||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
pub := key.Public().Key
|
||||
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.FatalError(t, err)
|
||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedHostOptions := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"virtualMachine"},
|
||||
|
@ -334,6 +342,7 @@ func TestAzure_AuthorizeSign_SSH(t *testing.T) {
|
|||
type args struct {
|
||||
token string
|
||||
sshOpts SSHOptions
|
||||
key interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -343,13 +352,15 @@ func TestAzure_AuthorizeSign_SSH(t *testing.T) {
|
|||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}}, expectedHostOptions, false, false},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}}, nil, false, true},
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -362,7 +373,7 @@ func TestAzure_AuthorizeSign_SSH(t *testing.T) {
|
|||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
|
|
|
@ -127,6 +127,8 @@ func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool)
|
|||
return c.Load("aws/" + string(provisioner.Name))
|
||||
case TypeGCP:
|
||||
return c.Load("gcp/" + string(provisioner.Name))
|
||||
case TypeACME:
|
||||
return c.Load("acme/" + string(provisioner.Name))
|
||||
default:
|
||||
return c.Load(string(provisioner.CredentialID))
|
||||
}
|
||||
|
@ -153,7 +155,7 @@ func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
|
|||
// provisioner IDs.
|
||||
func (c *Collection) Store(p Interface) error {
|
||||
// Store provisioner always in byID. ID must be unique.
|
||||
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true {
|
||||
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded {
|
||||
return errors.New("cannot add multiple provisioners with the same id")
|
||||
}
|
||||
|
||||
|
|
|
@ -133,15 +133,20 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p3, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p1.GetID(), p1)
|
||||
byID.Store(p2.GetID(), p2)
|
||||
byID.Store(p3.GetID(), p3)
|
||||
|
||||
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
|
||||
assert.FatalError(t, err)
|
||||
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
|
||||
assert.FatalError(t, err)
|
||||
ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "")
|
||||
assert.FatalError(t, err)
|
||||
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
|
@ -151,6 +156,9 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
|||
ok2Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok2Ext},
|
||||
}
|
||||
ok3Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok3Ext},
|
||||
}
|
||||
notFoundCert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{notFoundExt},
|
||||
}
|
||||
|
@ -176,6 +184,7 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
|||
}{
|
||||
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
|
||||
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
|
||||
{"ok3", fields{byID, testAudiences}, args{ok3Cert}, p3, true},
|
||||
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
|
||||
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
|
||||
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},
|
||||
|
|
|
@ -382,6 +382,8 @@ func (p *GCP) authorizeSSHSign(claims *gcpPayload) ([]SignOption, error) {
|
|||
&sshDefaultExtensionModifier{},
|
||||
// checks the validity bounds, and set the validity if has not been set
|
||||
&sshCertificateValidityModifier{p.claimer},
|
||||
// validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
), nil
|
||||
|
|
|
@ -3,6 +3,8 @@ package provisioner
|
|||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
|
@ -362,6 +364,12 @@ func TestGCP_AuthorizeSign_SSH(t *testing.T) {
|
|||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
pub := key.Public().Key
|
||||
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.FatalError(t, err)
|
||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedHostOptions := &SSHOptions{
|
||||
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
|
||||
|
@ -379,6 +387,7 @@ func TestGCP_AuthorizeSign_SSH(t *testing.T) {
|
|||
type args struct {
|
||||
token string
|
||||
sshOpts SSHOptions
|
||||
key interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -388,15 +397,17 @@ func TestGCP_AuthorizeSign_SSH(t *testing.T) {
|
|||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false},
|
||||
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}}, expectedHostOptionsPrincipal1, false, false},
|
||||
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}}, expectedHostOptionsPrincipal2, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}}, expectedHostOptions, false, false},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}}, nil, false, true},
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedHostOptions, false, false},
|
||||
{"ok-type", p1, args{t1, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false},
|
||||
{"ok-principal1", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, false, false},
|
||||
{"ok-principal2", p1, args{t1, SSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedHostOptions, false, true},
|
||||
{"fail-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, nil, false, true},
|
||||
{"fail-principal", p1, args{t1, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, false, true},
|
||||
{"fail-extra-principal", p1, args{t1, SSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -409,7 +420,7 @@ func TestGCP_AuthorizeSign_SSH(t *testing.T) {
|
|||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
|
|
|
@ -210,6 +210,8 @@ func (p *JWK) authorizeSSHSign(claims *jwtPayload) ([]SignOption, error) {
|
|||
&sshDefaultExtensionModifier{},
|
||||
// checks the validity bounds, and set the validity if has not been set
|
||||
&sshCertificateValidityModifier{p.claimer},
|
||||
// validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
), nil
|
||||
|
|
|
@ -3,6 +3,8 @@ package provisioner
|
|||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"strings"
|
||||
|
@ -356,6 +358,12 @@ func TestJWK_AuthorizeSign_SSH(t *testing.T) {
|
|||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
pub := key.Public().Key
|
||||
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.FatalError(t, err)
|
||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedUserOptions := &SSHOptions{
|
||||
|
@ -370,6 +378,7 @@ func TestJWK_AuthorizeSign_SSH(t *testing.T) {
|
|||
type args struct {
|
||||
token string
|
||||
sshOpts SSHOptions
|
||||
key interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -379,15 +388,17 @@ func TestJWK_AuthorizeSign_SSH(t *testing.T) {
|
|||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"user", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false},
|
||||
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false},
|
||||
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||
{"host", p1, args{t2, SSHOptions{}}, expectedHostOptions, false, false},
|
||||
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}}, expectedHostOptions, false, false},
|
||||
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
|
||||
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
|
||||
{"fail-signature", p1, args{failSig, SSHOptions{}}, nil, true, false},
|
||||
{"user", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false},
|
||||
{"user-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false},
|
||||
{"user-type", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false},
|
||||
{"user-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
|
||||
{"user-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
|
||||
{"host", p1, args{t2, SSHOptions{}, pub}, expectedHostOptions, false, false},
|
||||
{"host-type", p1, args{t2, SSHOptions{CertType: "host"}, pub}, expectedHostOptions, false, false},
|
||||
{"host-principals", p1, args{t2, SSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
|
||||
{"host-options", p1, args{t2, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
|
||||
{"fail-signature", p1, args{failSig, SSHOptions{}, pub}, nil, true, false},
|
||||
{"rail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -400,7 +411,7 @@ func TestJWK_AuthorizeSign_SSH(t *testing.T) {
|
|||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
|
|
|
@ -4,7 +4,10 @@ import (
|
|||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -55,6 +58,7 @@ type OIDC struct {
|
|||
Admins []string `json:"admins,omitempty"`
|
||||
Domains []string `json:"domains,omitempty"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
ListenAddress string `json:"listenAddress,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
configuration openIDConfiguration
|
||||
keyStore *keyStore
|
||||
|
@ -133,13 +137,27 @@ func (o *OIDC) Init(config Config) (err error) {
|
|||
return errors.New("configurationEndpoint cannot be empty")
|
||||
}
|
||||
|
||||
// Validate listenAddress if given
|
||||
if o.ListenAddress != "" {
|
||||
if _, _, err := net.SplitHostPort(o.ListenAddress); err != nil {
|
||||
return errors.Wrap(err, "error parsing listenAddress")
|
||||
}
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode and validate openid-configuration endpoint
|
||||
if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil {
|
||||
u, err := url.Parse(o.ConfigurationEndpoint)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s", o.ConfigurationEndpoint)
|
||||
}
|
||||
if !strings.Contains(u.Path, "/.well-known/openid-configuration") {
|
||||
u.Path = path.Join(u.Path, "/.well-known/openid-configuration")
|
||||
}
|
||||
if err := getAndDecode(u.String(), &o.configuration); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := o.configuration.Validate(); err != nil {
|
||||
|
@ -336,6 +354,8 @@ func (o *OIDC) authorizeSSHSign(claims *openIDPayload) ([]SignOption, error) {
|
|||
&sshDefaultExtensionModifier{},
|
||||
// checks the validity bounds, and set the validity if has not been set
|
||||
&sshCertificateValidityModifier{o.claimer},
|
||||
// validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// require all the fields in the SSH certificate
|
||||
&sshCertificateDefaultValidator{},
|
||||
), nil
|
||||
|
|
|
@ -3,6 +3,8 @@ package provisioner
|
|||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -79,6 +81,7 @@ func TestOIDC_Init(t *testing.T) {
|
|||
Claims *Claims
|
||||
Admins []string
|
||||
Domains []string
|
||||
ListenAddress string
|
||||
}
|
||||
type args struct {
|
||||
config Config
|
||||
|
@ -89,16 +92,21 @@ func TestOIDC_Init(t *testing.T) {
|
|||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
|
||||
{"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, []string{"foo@smallstep.com"}, nil}, args{config}, false},
|
||||
{"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, []string{"smallstep.com"}}, args{config}, false},
|
||||
{"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
|
||||
{"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||
{"no-type", fields{"", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||
{"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
|
||||
{"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil}, args{config}, true},
|
||||
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil}, args{config}, true},
|
||||
{"bad-claims", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", badClaims, nil, nil}, args{config}, true},
|
||||
{"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, false},
|
||||
{"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/.well-known/openid-configuration", nil, []string{"foo@smallstep.com"}, nil, ""}, args{config}, false},
|
||||
{"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, []string{"smallstep.com"}, ""}, args{config}, false},
|
||||
{"ok-listen-port", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ":10000"}, args{config}, false},
|
||||
{"ok-listen-host-port", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, "127.0.0.1:10000"}, args{config}, false},
|
||||
{"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL, nil, nil, nil, ""}, args{config}, false},
|
||||
{"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, true},
|
||||
{"no-type", fields{"", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, true},
|
||||
{"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, true},
|
||||
{"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil, ""}, args{config}, true},
|
||||
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/random", nil, nil, nil, ""}, args{config}, true},
|
||||
{"bad-claims", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/.well-known/openid-configuration", badClaims, nil, nil, ""}, args{config}, true},
|
||||
{"bad-parse-url", fields{"oidc", "name", "client-id", "client-secret", ":", nil, nil, nil, ""}, args{config}, true},
|
||||
{"bad-get-url", fields{"oidc", "name", "client-id", "client-secret", "https://", nil, nil, nil, ""}, args{config}, true},
|
||||
{"bad-listen-address", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, "127.0.0.1"}, args{config}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -109,9 +117,12 @@ func TestOIDC_Init(t *testing.T) {
|
|||
ConfigurationEndpoint: tt.fields.ConfigurationEndpoint,
|
||||
Claims: tt.fields.Claims,
|
||||
Admins: tt.fields.Admins,
|
||||
Domains: tt.fields.Domains,
|
||||
ListenAddress: tt.fields.ListenAddress,
|
||||
}
|
||||
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
|
||||
t.Errorf("OIDC.Init() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr == false {
|
||||
assert.Len(t, 2, p.keyStore.keySet.Keys)
|
||||
|
@ -343,6 +354,12 @@ func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
|
|||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
pub := key.Public().Key
|
||||
rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.FatalError(t, err)
|
||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
expectedUserOptions := &SSHOptions{
|
||||
|
@ -361,6 +378,7 @@ func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
|
|||
type args struct {
|
||||
token string
|
||||
sshOpts SSHOptions
|
||||
key interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -370,18 +388,20 @@ func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
|
|||
wantErr bool
|
||||
wantSignErr bool
|
||||
}{
|
||||
{"ok", p1, args{t1, SSHOptions{}}, expectedUserOptions, false, false},
|
||||
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}}, expectedUserOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||
{"admin", p3, args{okAdmin, SSHOptions{}}, expectedAdminOptions, false, false},
|
||||
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}}, expectedAdminOptions, false, false},
|
||||
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}}, expectedAdminOptions, false, false},
|
||||
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}}, expectedUserOptions, false, false},
|
||||
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}}, expectedHostOptions, false, false},
|
||||
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}}, nil, false, true},
|
||||
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}}, nil, false, true},
|
||||
{"fail-email", p3, args{failEmail, SSHOptions{}}, nil, true, false},
|
||||
{"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false},
|
||||
{"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false},
|
||||
{"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false},
|
||||
{"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
|
||||
{"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
|
||||
{"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, false, false},
|
||||
{"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, false, false},
|
||||
{"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub}, expectedAdminOptions, false, false},
|
||||
{"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false},
|
||||
{"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false},
|
||||
{"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true},
|
||||
{"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, false, true},
|
||||
{"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, false, true},
|
||||
{"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, true, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -394,7 +414,7 @@ func TestOIDC_AuthorizeSign_SSH(t *testing.T) {
|
|||
if err != nil {
|
||||
assert.Nil(t, got)
|
||||
} else if assert.NotNil(t, got) {
|
||||
cert, err := signSSHCertificate(key.Public().Key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
|
||||
if (err != nil) != tt.wantSignErr {
|
||||
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
|
||||
} else {
|
||||
|
|
|
@ -84,6 +84,8 @@ const (
|
|||
TypeAWS Type = 4
|
||||
// TypeAzure is used to indicate the Azure provisioners.
|
||||
TypeAzure Type = 5
|
||||
// TypeACME is used to indicate the ACME provisioners.
|
||||
TypeACME Type = 6
|
||||
|
||||
// RevokeAudienceKey is the key for the 'revoke' audiences in the audiences map.
|
||||
RevokeAudienceKey = "revoke"
|
||||
|
@ -104,6 +106,8 @@ func (t Type) String() string {
|
|||
return "AWS"
|
||||
case TypeAzure:
|
||||
return "Azure"
|
||||
case TypeACME:
|
||||
return "ACME"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
@ -151,6 +155,8 @@ func (l *List) UnmarshalJSON(data []byte) error {
|
|||
p = &AWS{}
|
||||
case "azure":
|
||||
p = &Azure{}
|
||||
case "acme":
|
||||
p = &ACME{}
|
||||
default:
|
||||
// Skip unsupported provisioners. A client using this method may be
|
||||
// compiled with a version of smallstep/certificates that does not
|
||||
|
@ -197,3 +203,93 @@ func SanitizeSSHUserPrincipal(email string) string {
|
|||
}
|
||||
}, strings.ToLower(email))
|
||||
}
|
||||
|
||||
// MockProvisioner for testing
|
||||
type MockProvisioner struct {
|
||||
Mret1, Mret2, Mret3 interface{}
|
||||
Merr error
|
||||
MgetID func() string
|
||||
MgetTokenID func(string) (string, error)
|
||||
MgetName func() string
|
||||
MgetType func() Type
|
||||
MgetEncryptedKey func() (string, string, bool)
|
||||
Minit func(Config) error
|
||||
MauthorizeRevoke func(ott string) error
|
||||
MauthorizeSign func(ctx context.Context, ott string) ([]SignOption, error)
|
||||
MauthorizeRenewal func(*x509.Certificate) error
|
||||
}
|
||||
|
||||
// GetID mock
|
||||
func (m *MockProvisioner) GetID() string {
|
||||
if m.MgetID != nil {
|
||||
return m.MgetID()
|
||||
}
|
||||
return m.Mret1.(string)
|
||||
}
|
||||
|
||||
// GetTokenID mock
|
||||
func (m *MockProvisioner) GetTokenID(token string) (string, error) {
|
||||
if m.MgetTokenID != nil {
|
||||
return m.MgetTokenID(token)
|
||||
}
|
||||
if m.Mret1 == nil {
|
||||
return "", m.Merr
|
||||
}
|
||||
return m.Mret1.(string), m.Merr
|
||||
}
|
||||
|
||||
// GetName mock
|
||||
func (m *MockProvisioner) GetName() string {
|
||||
if m.MgetName != nil {
|
||||
return m.MgetName()
|
||||
}
|
||||
return m.Mret1.(string)
|
||||
}
|
||||
|
||||
// GetType mock
|
||||
func (m *MockProvisioner) GetType() Type {
|
||||
if m.MgetType != nil {
|
||||
return m.MgetType()
|
||||
}
|
||||
return m.Mret1.(Type)
|
||||
}
|
||||
|
||||
// GetEncryptedKey mock
|
||||
func (m *MockProvisioner) GetEncryptedKey() (string, string, bool) {
|
||||
if m.MgetEncryptedKey != nil {
|
||||
return m.MgetEncryptedKey()
|
||||
}
|
||||
return m.Mret1.(string), m.Mret2.(string), m.Mret3.(bool)
|
||||
}
|
||||
|
||||
// Init mock
|
||||
func (m *MockProvisioner) Init(c Config) error {
|
||||
if m.Minit != nil {
|
||||
return m.Minit(c)
|
||||
}
|
||||
return m.Merr
|
||||
}
|
||||
|
||||
// AuthorizeRevoke mock
|
||||
func (m *MockProvisioner) AuthorizeRevoke(ott string) error {
|
||||
if m.MauthorizeRevoke != nil {
|
||||
return m.MauthorizeRevoke(ott)
|
||||
}
|
||||
return m.Merr
|
||||
}
|
||||
|
||||
// AuthorizeSign mock
|
||||
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]SignOption, error) {
|
||||
if m.MauthorizeSign != nil {
|
||||
return m.MauthorizeSign(ctx, ott)
|
||||
}
|
||||
return m.Mret1.([]SignOption), m.Merr
|
||||
}
|
||||
|
||||
// AuthorizeRenewal mock
|
||||
func (m *MockProvisioner) AuthorizeRenewal(c *x509.Certificate) error {
|
||||
if m.MauthorizeRenewal != nil {
|
||||
return m.MauthorizeRenewal(c)
|
||||
}
|
||||
return m.Merr
|
||||
}
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/binary"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/crypto/keys"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
|
@ -212,7 +216,7 @@ func (m *sshCertificateValidityModifier) Modify(cert *ssh.Certificate) error {
|
|||
}
|
||||
|
||||
if cert.ValidAfter == 0 {
|
||||
cert.ValidAfter = uint64(now().Unix())
|
||||
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
|
||||
}
|
||||
if cert.ValidBefore == 0 {
|
||||
t := time.Unix(int64(cert.ValidAfter), 0)
|
||||
|
@ -261,9 +265,11 @@ func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
|
|||
case len(cert.ValidPrincipals) == 0:
|
||||
return errors.New("ssh certificate valid principals cannot be empty")
|
||||
case cert.ValidAfter == 0:
|
||||
return errors.New("ssh certificate valid after cannot be 0")
|
||||
case cert.ValidBefore == 0:
|
||||
return errors.New("ssh certificate valid before cannot be 0")
|
||||
return errors.New("ssh certificate validAfter cannot be 0")
|
||||
case cert.ValidBefore < uint64(now().Unix()):
|
||||
return errors.New("ssh certificate validBefore cannot be in the past")
|
||||
case cert.ValidBefore < cert.ValidAfter:
|
||||
return errors.New("ssh certificate validBefore cannot be before validAfter")
|
||||
case cert.CertType == ssh.UserCert && len(cert.Extensions) == 0:
|
||||
return errors.New("ssh certificate extensions cannot be empty")
|
||||
case cert.SignatureKey == nil:
|
||||
|
@ -275,6 +281,36 @@ func (v *sshCertificateDefaultValidator) Valid(cert *ssh.Certificate) error {
|
|||
}
|
||||
}
|
||||
|
||||
// sshDefaultPublicKeyValidator implements a validator for the certificate key.
|
||||
type sshDefaultPublicKeyValidator struct{}
|
||||
|
||||
// Valid checks that certificate request common name matches the one configured.
|
||||
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error {
|
||||
if cert.Key == nil {
|
||||
return errors.New("ssh certificate key cannot be nil")
|
||||
}
|
||||
switch cert.Key.Type() {
|
||||
case ssh.KeyAlgoRSA:
|
||||
_, in, ok := sshParseString(cert.Key.Marshal())
|
||||
if !ok {
|
||||
return errors.New("ssh certificate key is invalid")
|
||||
}
|
||||
key, err := sshParseRSAPublicKey(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if key.Size() < keys.MinRSAKeyBytes {
|
||||
return errors.Errorf("ssh certificate key must be at least %d bits (%d bytes)",
|
||||
8*keys.MinRSAKeyBytes, keys.MinRSAKeyBytes)
|
||||
}
|
||||
return nil
|
||||
case ssh.KeyAlgoDSA:
|
||||
return errors.New("ssh certificate key algorithm (DSA) is not supported")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// sshCertTypeUInt32
|
||||
func sshCertTypeUInt32(ct string) uint32 {
|
||||
switch ct {
|
||||
|
@ -304,3 +340,41 @@ func containsAllMembers(group, subgroup []string) bool {
|
|||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func sshParseString(in []byte) (out, rest []byte, ok bool) {
|
||||
if len(in) < 4 {
|
||||
return
|
||||
}
|
||||
length := binary.BigEndian.Uint32(in)
|
||||
in = in[4:]
|
||||
if uint32(len(in)) < length {
|
||||
return
|
||||
}
|
||||
out = in[:length]
|
||||
rest = in[length:]
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func sshParseRSAPublicKey(in []byte) (*rsa.PublicKey, error) {
|
||||
var w struct {
|
||||
E *big.Int
|
||||
N *big.Int
|
||||
Rest []byte `ssh:"rest"`
|
||||
}
|
||||
if err := ssh.Unmarshal(in, &w); err != nil {
|
||||
return nil, errors.Wrap(err, "error unmarshalling public key")
|
||||
}
|
||||
if w.E.BitLen() > 24 {
|
||||
return nil, errors.New("invalid public key: exponent too large")
|
||||
}
|
||||
e := w.E.Int64()
|
||||
if e < 3 || e&1 == 0 {
|
||||
return nil, errors.New("invalid public key: incorrect exponent")
|
||||
}
|
||||
|
||||
var key rsa.PublicKey
|
||||
key.E = int(e)
|
||||
key.N = w.N
|
||||
return &key, nil
|
||||
}
|
||||
|
|
192
authority/provisioner/sign_ssh_options_test.go
Normal file
192
authority/provisioner/sign_ssh_options_test.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/crypto/keys"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func Test_sshCertificateDefaultValidator_Valid(t *testing.T) {
|
||||
pub, _, err := keys.GenerateDefaultKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
sshPub, err := ssh.NewPublicKey(pub)
|
||||
assert.FatalError(t, err)
|
||||
v := sshCertificateDefaultValidator{}
|
||||
tests := []struct {
|
||||
name string
|
||||
cert *ssh.Certificate
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"fail/zero-nonce",
|
||||
&ssh.Certificate{},
|
||||
errors.New("ssh certificate nonce cannot be empty"),
|
||||
},
|
||||
{
|
||||
"fail/nil-key",
|
||||
&ssh.Certificate{Nonce: []byte("foo")},
|
||||
errors.New("ssh certificate key cannot be nil"),
|
||||
},
|
||||
{
|
||||
"fail/zero-serial",
|
||||
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub},
|
||||
errors.New("ssh certificate serial cannot be 0"),
|
||||
},
|
||||
{
|
||||
"fail/unexpected-cert-type",
|
||||
// UserCert = 1, HostCert = 2
|
||||
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, CertType: 3, Serial: 1},
|
||||
errors.New("ssh certificate has an unknown type: 3"),
|
||||
},
|
||||
{
|
||||
"fail/empty-cert-key-id",
|
||||
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1},
|
||||
errors.New("ssh certificate key id cannot be empty"),
|
||||
},
|
||||
{
|
||||
"fail/empty-valid-principals",
|
||||
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo"},
|
||||
errors.New("ssh certificate valid principals cannot be empty"),
|
||||
},
|
||||
{
|
||||
"fail/zero-validAfter",
|
||||
&ssh.Certificate{
|
||||
Nonce: []byte("foo"),
|
||||
Key: sshPub,
|
||||
Serial: 1,
|
||||
CertType: 1,
|
||||
KeyId: "foo",
|
||||
ValidPrincipals: []string{"foo"},
|
||||
ValidAfter: 0,
|
||||
},
|
||||
errors.New("ssh certificate validAfter cannot be 0"),
|
||||
},
|
||||
{
|
||||
"fail/validBefore-past",
|
||||
&ssh.Certificate{
|
||||
Nonce: []byte("foo"),
|
||||
Key: sshPub,
|
||||
Serial: 1,
|
||||
CertType: 1,
|
||||
KeyId: "foo",
|
||||
ValidPrincipals: []string{"foo"},
|
||||
ValidAfter: uint64(time.Now().Add(-10 * time.Minute).Unix()),
|
||||
ValidBefore: uint64(time.Now().Add(-5 * time.Minute).Unix()),
|
||||
},
|
||||
errors.New("ssh certificate validBefore cannot be in the past"),
|
||||
},
|
||||
{
|
||||
"fail/validAfter-after-validBefore",
|
||||
&ssh.Certificate{
|
||||
Nonce: []byte("foo"),
|
||||
Key: sshPub,
|
||||
Serial: 1,
|
||||
CertType: 1,
|
||||
KeyId: "foo",
|
||||
ValidPrincipals: []string{"foo"},
|
||||
ValidAfter: uint64(time.Now().Add(15 * time.Minute).Unix()),
|
||||
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
|
||||
},
|
||||
errors.New("ssh certificate validBefore cannot be before validAfter"),
|
||||
},
|
||||
{
|
||||
"fail/empty-extensions",
|
||||
&ssh.Certificate{
|
||||
Nonce: []byte("foo"),
|
||||
Key: sshPub,
|
||||
Serial: 1,
|
||||
CertType: 1,
|
||||
KeyId: "foo",
|
||||
ValidPrincipals: []string{"foo"},
|
||||
ValidAfter: uint64(time.Now().Unix()),
|
||||
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
|
||||
},
|
||||
errors.New("ssh certificate extensions cannot be empty"),
|
||||
},
|
||||
{
|
||||
"fail/nil-signature-key",
|
||||
&ssh.Certificate{
|
||||
Nonce: []byte("foo"),
|
||||
Key: sshPub,
|
||||
Serial: 1,
|
||||
CertType: 1,
|
||||
KeyId: "foo",
|
||||
ValidPrincipals: []string{"foo"},
|
||||
ValidAfter: uint64(time.Now().Unix()),
|
||||
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
|
||||
Permissions: ssh.Permissions{
|
||||
Extensions: map[string]string{"foo": "bar"},
|
||||
},
|
||||
},
|
||||
errors.New("ssh certificate signature key cannot be nil"),
|
||||
},
|
||||
{
|
||||
"fail/nil-signature",
|
||||
&ssh.Certificate{
|
||||
Nonce: []byte("foo"),
|
||||
Key: sshPub,
|
||||
Serial: 1,
|
||||
CertType: 1,
|
||||
KeyId: "foo",
|
||||
ValidPrincipals: []string{"foo"},
|
||||
ValidAfter: uint64(time.Now().Unix()),
|
||||
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
|
||||
Permissions: ssh.Permissions{
|
||||
Extensions: map[string]string{"foo": "bar"},
|
||||
},
|
||||
SignatureKey: sshPub,
|
||||
},
|
||||
errors.New("ssh certificate signature cannot be nil"),
|
||||
},
|
||||
{
|
||||
"ok/userCert",
|
||||
&ssh.Certificate{
|
||||
Nonce: []byte("foo"),
|
||||
Key: sshPub,
|
||||
Serial: 1,
|
||||
CertType: 1,
|
||||
KeyId: "foo",
|
||||
ValidPrincipals: []string{"foo"},
|
||||
ValidAfter: uint64(time.Now().Unix()),
|
||||
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
|
||||
Permissions: ssh.Permissions{
|
||||
Extensions: map[string]string{"foo": "bar"},
|
||||
},
|
||||
SignatureKey: sshPub,
|
||||
Signature: &ssh.Signature{},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"ok/hostCert",
|
||||
&ssh.Certificate{
|
||||
Nonce: []byte("foo"),
|
||||
Key: sshPub,
|
||||
Serial: 1,
|
||||
CertType: 2,
|
||||
KeyId: "foo",
|
||||
ValidPrincipals: []string{"foo"},
|
||||
ValidAfter: uint64(time.Now().Unix()),
|
||||
ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()),
|
||||
SignatureKey: sshPub,
|
||||
Signature: &ssh.Signature{},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := v.Valid(tt.cert); err != nil {
|
||||
if assert.NotNil(t, tt.err) {
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -248,10 +248,6 @@ func TestTimeDuration_Unix(t *testing.T) {
|
|||
func TestTimeDuration_String(t *testing.T) {
|
||||
tm, fn := mockNow()
|
||||
defer fn()
|
||||
type fields struct {
|
||||
t time.Time
|
||||
d time.Duration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
timeDuration *TimeDuration
|
||||
|
|
|
@ -709,7 +709,7 @@ func generateJWKServer(n int) *httptest.Server {
|
|||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
case "/hits":
|
||||
writeJSON(w, hits)
|
||||
case "/openid-configuration", "/.well-known/openid-configuration":
|
||||
case "/.well-known/openid-configuration":
|
||||
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
|
||||
case "/random":
|
||||
keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
|
||||
|
@ -730,3 +730,15 @@ func generateJWKServer(n int) *httptest.Server {
|
|||
srv.Start()
|
||||
return srv
|
||||
}
|
||||
|
||||
func generateACME() (*ACME, error) {
|
||||
// Initialize provisioners
|
||||
p := &ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-provisioner.com",
|
||||
}
|
||||
if err := p.Init(Config{Claims: globalProvisionerClaims}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
|
|
@ -35,3 +35,13 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
|
|||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// LoadProvisionerByID returns an interface to the provisioner with the given ID.
|
||||
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||
p, ok := a.provisioners.Load(id)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("provisioner not found"),
|
||||
http.StatusNotFound, apiCtx{}}
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
|
8
authority/testdata/certs/badsig.csr
vendored
Normal file
8
authority/testdata/certs/badsig.csr
vendored
Normal file
|
@ -0,0 +1,8 @@
|
|||
-----BEGIN CERTIFICATE REQUEST-----
|
||||
MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI
|
||||
zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI
|
||||
cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ
|
||||
DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ
|
||||
ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5
|
||||
OI+cWOIc/IGwqZul/zEF5dani5ihOL7UwA==
|
||||
-----END CERTIFICATE REQUEST-----
|
8
authority/testdata/certs/foo.csr
vendored
Normal file
8
authority/testdata/certs/foo.csr
vendored
Normal file
|
@ -0,0 +1,8 @@
|
|||
-----BEGIN CERTIFICATE REQUEST-----
|
||||
MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI
|
||||
zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI
|
||||
cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ
|
||||
DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ
|
||||
ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5
|
||||
OI+cWOIc/IGwqZul/zEF5dani5ihOR7UwA==
|
||||
-----END CERTIFICATE REQUEST-----
|
|
@ -212,7 +212,6 @@ type RevokeOptions struct {
|
|||
MTLS bool
|
||||
Crt *x509.Certificate
|
||||
OTT string
|
||||
errCtxt map[string]interface{}
|
||||
}
|
||||
|
||||
// Revoke revokes a certificate.
|
||||
|
|
354
ca/acmeClient.go
Normal file
354
ca/acmeClient.go
Normal file
|
@ -0,0 +1,354 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
acmeAPI "github.com/smallstep/certificates/acme/api"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// ACMEClient implements an HTTP client to an ACME API.
|
||||
type ACMEClient struct {
|
||||
client *http.Client
|
||||
dirLoc string
|
||||
dir *acme.Directory
|
||||
acc *acme.Account
|
||||
Key *jose.JSONWebKey
|
||||
kid string
|
||||
}
|
||||
|
||||
// NewACMEClient initializes a new ACMEClient.
|
||||
func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*ACMEClient, error) {
|
||||
// Retrieve transport from options.
|
||||
o := new(clientOptions)
|
||||
if err := o.apply(opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr, err := o.getTransport(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ac := &ACMEClient{
|
||||
client: &http.Client{
|
||||
Transport: tr,
|
||||
},
|
||||
dirLoc: endpoint,
|
||||
}
|
||||
|
||||
resp, err := ac.client.Get(endpoint)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", endpoint)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
var dir acme.Directory
|
||||
if err := readJSON(resp.Body, &dir); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", endpoint)
|
||||
}
|
||||
|
||||
ac.dir = &dir
|
||||
|
||||
ac.Key, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nar := &acmeAPI.NewAccountRequest{
|
||||
Contact: contact,
|
||||
TermsOfServiceAgreed: true,
|
||||
}
|
||||
payload, err := json.Marshal(nar)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshaling new account request")
|
||||
}
|
||||
|
||||
resp, err = ac.post(payload, ac.dir.NewAccount, withJWK(ac))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
var acc acme.Account
|
||||
if err := readJSON(resp.Body, &acc); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", dir.NewAccount)
|
||||
}
|
||||
ac.acc = &acc
|
||||
ac.kid = resp.Header.Get("Location")
|
||||
|
||||
return ac, nil
|
||||
}
|
||||
|
||||
// GetDirectory makes a directory request to the ACME api and returns an
|
||||
// ACME directory object.
|
||||
func (c *ACMEClient) GetDirectory() (*acme.Directory, error) {
|
||||
return c.dir, nil
|
||||
}
|
||||
|
||||
// GetNonce makes a nonce request to the ACME api and returns an
|
||||
// ACME directory object.
|
||||
func (c *ACMEClient) GetNonce() (string, error) {
|
||||
resp, err := c.client.Get(c.dir.NewNonce)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return "", readACMEError(resp.Body)
|
||||
}
|
||||
return resp.Header.Get("Replay-Nonce"), nil
|
||||
}
|
||||
|
||||
type withHeaderOption func(so *jose.SignerOptions)
|
||||
|
||||
func withJWK(c *ACMEClient) withHeaderOption {
|
||||
return func(so *jose.SignerOptions) {
|
||||
so.WithHeader("jwk", c.Key.Public())
|
||||
}
|
||||
}
|
||||
|
||||
func withKid(c *ACMEClient) withHeaderOption {
|
||||
return func(so *jose.SignerOptions) {
|
||||
so.WithHeader("kid", c.kid)
|
||||
}
|
||||
}
|
||||
|
||||
// serialize serializes a json web signature and doesn't omit empty fields.
|
||||
func serialize(obj *jose.JSONWebSignature) (string, error) {
|
||||
raw, err := obj.CompactSerialize()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error serializing JWS")
|
||||
}
|
||||
parts := strings.Split(raw, ".")
|
||||
msg := struct {
|
||||
Protected string `json:"protected"`
|
||||
Payload string `json:"payload"`
|
||||
Signature string `json:"signature"`
|
||||
}{Protected: parts[0], Payload: parts[1], Signature: parts[2]}
|
||||
b, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error marshaling jws message")
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (c *ACMEClient) post(payload []byte, url string, headerOps ...withHeaderOption) (*http.Response, error) {
|
||||
if c.Key == nil {
|
||||
return nil, errors.New("acme client not configured with account")
|
||||
}
|
||||
nonce, err := c.GetNonce()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithHeader("nonce", nonce)
|
||||
so.WithHeader("url", url)
|
||||
for _, hop := range headerOps {
|
||||
hop(so)
|
||||
}
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(c.Key.Algorithm),
|
||||
Key: c.Key.Key,
|
||||
}, so)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error creating JWS signer")
|
||||
}
|
||||
signed, err := signer.Sign(payload)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("error signing payload: %s", strings.TrimPrefix(err.Error(), "square/go-jose: "))
|
||||
}
|
||||
raw, err := serialize(signed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.client.Post(url, "application/jose+json", strings.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", c.dir.NewOrder)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// NewOrder creates and returns the information for a new ACME order.
|
||||
func (c *ACMEClient) NewOrder(payload []byte) (*acme.Order, error) {
|
||||
resp, err := c.post(payload, c.dir.NewOrder, withKid(c))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
|
||||
var o acme.Order
|
||||
if err := readJSON(resp.Body, &o); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", c.dir.NewOrder)
|
||||
}
|
||||
o.ID = resp.Header.Get("Location")
|
||||
return &o, nil
|
||||
}
|
||||
|
||||
// GetChallenge returns the Challenge at the given path.
|
||||
// With the validate parameter set to True this method will attempt to validate the
|
||||
// challenge before returning it.
|
||||
func (c *ACMEClient) GetChallenge(url string) (*acme.Challenge, error) {
|
||||
resp, err := c.post(nil, url, withKid(c))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
|
||||
var ch acme.Challenge
|
||||
if err := readJSON(resp.Body, &ch); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", url)
|
||||
}
|
||||
return &ch, nil
|
||||
}
|
||||
|
||||
// ValidateChallenge returns the Challenge at the given path.
|
||||
// With the validate parameter set to True this method will attempt to validate the
|
||||
// challenge before returning it.
|
||||
func (c *ACMEClient) ValidateChallenge(url string) error {
|
||||
resp, err := c.post([]byte("{}"), url, withKid(c))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return readACMEError(resp.Body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAuthz returns the Authz at the given path.
|
||||
func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) {
|
||||
resp, err := c.post(nil, url, withKid(c))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
|
||||
var az acme.Authz
|
||||
if err := readJSON(resp.Body, &az); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", url)
|
||||
}
|
||||
return &az, nil
|
||||
}
|
||||
|
||||
// GetOrder returns the Order at the given path.
|
||||
func (c *ACMEClient) GetOrder(url string) (*acme.Order, error) {
|
||||
resp, err := c.post(nil, url, withKid(c))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
|
||||
var o acme.Order
|
||||
if err := readJSON(resp.Body, &o); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", url)
|
||||
}
|
||||
return &o, nil
|
||||
}
|
||||
|
||||
// FinalizeOrder makes a finalize request to the ACME api.
|
||||
func (c *ACMEClient) FinalizeOrder(url string, csr *x509.CertificateRequest) error {
|
||||
payload, err := json.Marshal(acmeAPI.FinalizeRequest{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error marshaling finalize request")
|
||||
}
|
||||
resp, err := c.post(payload, url, withKid(c))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return readACMEError(resp.Body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCertificate retrieves the certificate along with all intermediates.
|
||||
func (c *ACMEClient) GetCertificate(url string) (*x509.Certificate, []*x509.Certificate, error) {
|
||||
resp, err := c.post(nil, url, withKid(c))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, nil, readACMEError(resp.Body)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error reading GET certificate response")
|
||||
}
|
||||
|
||||
var certs []*x509.Certificate
|
||||
|
||||
block, rest := pem.Decode(bodyBytes)
|
||||
if block == nil {
|
||||
return nil, nil, errors.New("failed to parse any certificates from response")
|
||||
}
|
||||
for block != nil {
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error parsing certificate pem response")
|
||||
}
|
||||
certs = append(certs, cert)
|
||||
block, rest = pem.Decode(rest)
|
||||
}
|
||||
|
||||
return certs[0], certs[1:], nil
|
||||
}
|
||||
|
||||
// GetAccountOrders retrieves the orders belonging to the given account.
|
||||
func (c *ACMEClient) GetAccountOrders() ([]string, error) {
|
||||
if c.acc == nil {
|
||||
return nil, errors.New("acme client not configured with account")
|
||||
}
|
||||
resp, err := c.post(nil, c.acc.Orders, withKid(c))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readACMEError(resp.Body)
|
||||
}
|
||||
|
||||
var orders []string
|
||||
if err := readJSON(resp.Body, &orders); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", c.acc.Orders)
|
||||
}
|
||||
|
||||
return orders, nil
|
||||
}
|
||||
|
||||
func readACMEError(r io.ReadCloser) error {
|
||||
defer r.Close()
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error reading from body")
|
||||
}
|
||||
ae := new(acme.AError)
|
||||
err = json.Unmarshal(b, &ae)
|
||||
// If we successfully marshaled to an ACMEError then return the ACMEError.
|
||||
if err != nil || len(ae.Error()) == 0 {
|
||||
fmt.Printf("b = %s\n", b)
|
||||
// Throw up our hands.
|
||||
return errors.Errorf("%s", b)
|
||||
}
|
||||
return ae
|
||||
}
|
1355
ca/acmeClient_test.go
Normal file
1355
ca/acmeClient_test.go
Normal file
File diff suppressed because it is too large
Load diff
|
@ -570,8 +570,8 @@ func TestBootstrapListener(t *testing.T) {
|
|||
return
|
||||
}
|
||||
wg := new(sync.WaitGroup)
|
||||
go func() {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
http.Serve(lis, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("ok"))
|
||||
}))
|
||||
|
|
41
ca/ca.go
41
ca/ca.go
|
@ -3,18 +3,23 @@ package ca
|
|||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
acmeAPI "github.com/smallstep/certificates/acme/api"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/monitoring"
|
||||
"github.com/smallstep/certificates/server"
|
||||
"github.com/smallstep/nosql"
|
||||
)
|
||||
|
||||
type options struct {
|
||||
|
@ -100,13 +105,47 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) {
|
|||
mux := chi.NewRouter()
|
||||
handler := http.Handler(mux)
|
||||
|
||||
// Add api endpoints in / and /1.0
|
||||
// Add regular CA api endpoints in / and /1.0
|
||||
routerHandler := api.New(auth)
|
||||
routerHandler.Route(mux)
|
||||
mux.Route("/1.0", func(r chi.Router) {
|
||||
routerHandler.Route(r)
|
||||
})
|
||||
|
||||
//Add ACME api endpoints in /acme and /1.0/acme
|
||||
dns := config.DNSNames[0]
|
||||
u, err := url.Parse("https://" + config.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
port := u.Port()
|
||||
if port != "" && port != "443" {
|
||||
dns = fmt.Sprintf("%s:%s", dns, port)
|
||||
}
|
||||
|
||||
prefix := "acme"
|
||||
acmeAuth := acme.NewAuthority(auth.GetDatabase().(nosql.DB), dns, prefix, auth)
|
||||
acmeRouterHandler := acmeAPI.New(acmeAuth)
|
||||
mux.Route("/"+prefix, func(r chi.Router) {
|
||||
acmeRouterHandler.Route(r)
|
||||
})
|
||||
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
|
||||
// of the ACME spec.
|
||||
mux.Route("/2.0/"+prefix, func(r chi.Router) {
|
||||
acmeRouterHandler.Route(r)
|
||||
})
|
||||
|
||||
/*
|
||||
// helpful routine for logging all routes //
|
||||
walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
|
||||
fmt.Printf("%s %s\n", method, route)
|
||||
return nil
|
||||
}
|
||||
if err := chi.Walk(mux, walkFunc); err != nil {
|
||||
fmt.Printf("Logging err: %s\n", err.Error())
|
||||
}
|
||||
*/
|
||||
|
||||
// Add monitoring if configured
|
||||
if len(config.Monitoring) > 0 {
|
||||
m, err := monitoring.New(config.Monitoring)
|
||||
|
|
|
@ -163,8 +163,7 @@ func TestClient_Health(t *testing.T) {
|
|||
}
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Health()
|
||||
|
@ -224,8 +223,7 @@ func TestClient_Root(t *testing.T) {
|
|||
if req.RequestURI != expected {
|
||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
||||
}
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Root(tt.shasum)
|
||||
|
@ -303,8 +301,7 @@ func TestClient_Sign(t *testing.T) {
|
|||
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Sign(tt.request)
|
||||
|
@ -378,8 +375,7 @@ func TestClient_Revoke(t *testing.T) {
|
|||
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Revoke(tt.request, nil)
|
||||
|
@ -438,8 +434,7 @@ func TestClient_Renew(t *testing.T) {
|
|||
}
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Renew(nil)
|
||||
|
@ -502,8 +497,7 @@ func TestClient_Provisioners(t *testing.T) {
|
|||
if req.RequestURI != tt.expectedURI {
|
||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
|
||||
}
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Provisioners(tt.args...)
|
||||
|
@ -562,8 +556,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|||
if req.RequestURI != expected {
|
||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
||||
}
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.ProvisionerKey(tt.kid)
|
||||
|
@ -622,8 +615,7 @@ func TestClient_Roots(t *testing.T) {
|
|||
}
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Roots()
|
||||
|
@ -683,8 +675,7 @@ func TestClient_Federation(t *testing.T) {
|
|||
}
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Federation()
|
||||
|
@ -783,8 +774,7 @@ func TestClient_RootFingerprint(t *testing.T) {
|
|||
}
|
||||
|
||||
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.RootFingerprint()
|
||||
|
|
131
db/db.go
131
db/db.go
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -54,7 +55,7 @@ func New(c *Config) (AuthDB, error) {
|
|||
return nil, errors.Wrapf(err, "Error opening database of Type %s with source %s", c.Type, c.DataSource)
|
||||
}
|
||||
|
||||
tables := [][]byte{revokedCertsTable, certsTable}
|
||||
tables := [][]byte{revokedCertsTable, certsTable, usedOTTTable}
|
||||
for _, b := range tables {
|
||||
if err := db.CreateTable(b); err != nil {
|
||||
return nil, errors.Wrapf(err, "error creating table %s",
|
||||
|
@ -102,22 +103,20 @@ func (db *DB) IsRevoked(sn string) (bool, error) {
|
|||
|
||||
// Revoke adds a certificate to the revocation table.
|
||||
func (db *DB) Revoke(rci *RevokedCertificateInfo) error {
|
||||
isRvkd, err := db.IsRevoked(rci.Serial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isRvkd {
|
||||
return ErrAlreadyExists
|
||||
}
|
||||
rcib, err := json.Marshal(rci)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error marshaling revoked certificate info")
|
||||
}
|
||||
|
||||
if err = db.Set(revokedCertsTable, []byte(rci.Serial), rcib); err != nil {
|
||||
return errors.Wrap(err, "database Set error")
|
||||
}
|
||||
_, swapped, err := db.CmpAndSwap(revokedCertsTable, []byte(rci.Serial), nil, rcib)
|
||||
switch {
|
||||
case err != nil:
|
||||
return errors.Wrap(err, "error AuthDB CmpAndSwap")
|
||||
case !swapped:
|
||||
return ErrAlreadyExists
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// StoreCertificate stores a certificate PEM.
|
||||
|
@ -132,15 +131,11 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
|
|||
// for the first time, false otherwise.
|
||||
func (db *DB) UseToken(id, tok string) (bool, error) {
|
||||
_, swapped, err := db.CmpAndSwap(usedOTTTable, []byte(id), nil, []byte(tok))
|
||||
switch {
|
||||
case err != nil:
|
||||
if err != nil {
|
||||
return false, errors.Wrapf(err, "error storing used token %s/%s",
|
||||
string(usedOTTTable), id)
|
||||
case !swapped:
|
||||
return false, nil
|
||||
default:
|
||||
return true, nil
|
||||
}
|
||||
return swapped, nil
|
||||
}
|
||||
|
||||
// Shutdown sends a shutdown message to the database.
|
||||
|
@ -153,3 +148,105 @@ func (db *DB) Shutdown() error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockNoSQLDB //
|
||||
type MockNoSQLDB struct {
|
||||
Err error
|
||||
Ret1, Ret2 interface{}
|
||||
MGet func(bucket, key []byte) ([]byte, error)
|
||||
MSet func(bucket, key, value []byte) error
|
||||
MOpen func(dataSourceName string, opt ...database.Option) error
|
||||
MClose func() error
|
||||
MCreateTable func(bucket []byte) error
|
||||
MDeleteTable func(bucket []byte) error
|
||||
MDel func(bucket, key []byte) error
|
||||
MList func(bucket []byte) ([]*database.Entry, error)
|
||||
MUpdate func(tx *database.Tx) error
|
||||
MCmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error)
|
||||
}
|
||||
|
||||
// CmpAndSwap mock
|
||||
func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if m.MCmpAndSwap != nil {
|
||||
return m.MCmpAndSwap(bucket, key, old, newval)
|
||||
}
|
||||
if m.Ret1 == nil {
|
||||
return nil, false, m.Err
|
||||
}
|
||||
return m.Ret1.([]byte), m.Ret2.(bool), m.Err
|
||||
}
|
||||
|
||||
// Get mock
|
||||
func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) {
|
||||
if m.MGet != nil {
|
||||
return m.MGet(bucket, key)
|
||||
}
|
||||
if m.Ret1 == nil {
|
||||
return nil, m.Err
|
||||
}
|
||||
return m.Ret1.([]byte), m.Err
|
||||
}
|
||||
|
||||
// Set mock
|
||||
func (m *MockNoSQLDB) Set(bucket, key, value []byte) error {
|
||||
if m.MSet != nil {
|
||||
return m.MSet(bucket, key, value)
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// Open mock
|
||||
func (m *MockNoSQLDB) Open(dataSourceName string, opt ...database.Option) error {
|
||||
if m.MOpen != nil {
|
||||
return m.MOpen(dataSourceName, opt...)
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// Close mock
|
||||
func (m *MockNoSQLDB) Close() error {
|
||||
if m.MClose != nil {
|
||||
return m.MClose()
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// CreateTable mock
|
||||
func (m *MockNoSQLDB) CreateTable(bucket []byte) error {
|
||||
if m.MCreateTable != nil {
|
||||
return m.MCreateTable(bucket)
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// DeleteTable mock
|
||||
func (m *MockNoSQLDB) DeleteTable(bucket []byte) error {
|
||||
if m.MDeleteTable != nil {
|
||||
return m.MDeleteTable(bucket)
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// Del mock
|
||||
func (m *MockNoSQLDB) Del(bucket, key []byte) error {
|
||||
if m.MDel != nil {
|
||||
return m.MDel(bucket, key)
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// List mock
|
||||
func (m *MockNoSQLDB) List(bucket []byte) ([]*database.Entry, error) {
|
||||
if m.MList != nil {
|
||||
return m.MList(bucket)
|
||||
}
|
||||
return m.Ret1.([]*database.Entry), m.Err
|
||||
}
|
||||
|
||||
// Update mock
|
||||
func (m *MockNoSQLDB) Update(tx *database.Tx) error {
|
||||
if m.MUpdate != nil {
|
||||
return m.MUpdate(tx)
|
||||
}
|
||||
return m.Err
|
||||
}
|
||||
|
|
132
db/db_test.go
132
db/db_test.go
|
@ -8,97 +8,6 @@ import (
|
|||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
type MockNoSQLDB struct {
|
||||
err error
|
||||
ret1, ret2 interface{}
|
||||
get func(bucket, key []byte) ([]byte, error)
|
||||
set func(bucket, key, value []byte) error
|
||||
open func(dataSourceName string, opt ...database.Option) error
|
||||
close func() error
|
||||
createTable func(bucket []byte) error
|
||||
deleteTable func(bucket []byte) error
|
||||
del func(bucket, key []byte) error
|
||||
list func(bucket []byte) ([]*database.Entry, error)
|
||||
update func(tx *database.Tx) error
|
||||
cmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error)
|
||||
}
|
||||
|
||||
func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
if m.cmpAndSwap != nil {
|
||||
return m.cmpAndSwap(bucket, key, old, newval)
|
||||
}
|
||||
if m.ret1 == nil {
|
||||
return nil, false, m.err
|
||||
}
|
||||
return m.ret1.([]byte), m.ret2.(bool), m.err
|
||||
}
|
||||
|
||||
func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) {
|
||||
if m.get != nil {
|
||||
return m.get(bucket, key)
|
||||
}
|
||||
if m.ret1 == nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.ret1.([]byte), m.err
|
||||
}
|
||||
|
||||
func (m *MockNoSQLDB) Set(bucket, key, value []byte) error {
|
||||
if m.set != nil {
|
||||
return m.set(bucket, key, value)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockNoSQLDB) Open(dataSourceName string, opt ...database.Option) error {
|
||||
if m.open != nil {
|
||||
return m.open(dataSourceName, opt...)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockNoSQLDB) Close() error {
|
||||
if m.close != nil {
|
||||
return m.close()
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockNoSQLDB) CreateTable(bucket []byte) error {
|
||||
if m.createTable != nil {
|
||||
return m.createTable(bucket)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockNoSQLDB) DeleteTable(bucket []byte) error {
|
||||
if m.deleteTable != nil {
|
||||
return m.deleteTable(bucket)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockNoSQLDB) Del(bucket, key []byte) error {
|
||||
if m.del != nil {
|
||||
return m.del(bucket, key)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockNoSQLDB) List(bucket []byte) ([]*database.Entry, error) {
|
||||
if m.list != nil {
|
||||
return m.list(bucket)
|
||||
}
|
||||
return m.ret1.([]*database.Entry), m.err
|
||||
}
|
||||
|
||||
func (m *MockNoSQLDB) Update(tx *database.Tx) error {
|
||||
if m.update != nil {
|
||||
return m.update(tx)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func TestIsRevoked(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
key string
|
||||
|
@ -111,16 +20,16 @@ func TestIsRevoked(t *testing.T) {
|
|||
},
|
||||
"false/ErrNotFound": {
|
||||
key: "sn",
|
||||
db: &DB{&MockNoSQLDB{err: database.ErrNotFound, ret1: nil}, true},
|
||||
db: &DB{&MockNoSQLDB{Err: database.ErrNotFound, Ret1: nil}, true},
|
||||
},
|
||||
"error/checking bucket": {
|
||||
key: "sn",
|
||||
db: &DB{&MockNoSQLDB{err: errors.New("force"), ret1: nil}, true},
|
||||
db: &DB{&MockNoSQLDB{Err: errors.New("force"), Ret1: nil}, true},
|
||||
err: errors.New("error checking revocation bucket: force"),
|
||||
},
|
||||
"true": {
|
||||
key: "sn",
|
||||
db: &DB{&MockNoSQLDB{ret1: []byte("value")}, true},
|
||||
db: &DB{&MockNoSQLDB{Ret1: []byte("value")}, true},
|
||||
isRevoked: true,
|
||||
},
|
||||
}
|
||||
|
@ -148,41 +57,26 @@ func TestRevoke(t *testing.T) {
|
|||
"error/force isRevoked": {
|
||||
rci: &RevokedCertificateInfo{Serial: "sn"},
|
||||
db: &DB{&MockNoSQLDB{
|
||||
get: func(bucket []byte, sn []byte) ([]byte, error) {
|
||||
return nil, errors.New("force IsRevoked")
|
||||
MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
}, true},
|
||||
err: errors.New("error checking revocation bucket: force IsRevoked"),
|
||||
err: errors.New("error AuthDB CmpAndSwap: force"),
|
||||
},
|
||||
"error/was already revoked": {
|
||||
rci: &RevokedCertificateInfo{Serial: "sn"},
|
||||
db: &DB{&MockNoSQLDB{
|
||||
get: func(bucket []byte, sn []byte) ([]byte, error) {
|
||||
return nil, nil
|
||||
MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), false, nil
|
||||
},
|
||||
}, true},
|
||||
err: ErrAlreadyExists,
|
||||
},
|
||||
"error/database set": {
|
||||
rci: &RevokedCertificateInfo{Serial: "sn"},
|
||||
db: &DB{&MockNoSQLDB{
|
||||
get: func(bucket []byte, sn []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
set: func(bucket []byte, key []byte, value []byte) error {
|
||||
return errors.New("force")
|
||||
},
|
||||
}, true},
|
||||
err: errors.New("database Set error: force"),
|
||||
},
|
||||
"ok": {
|
||||
rci: &RevokedCertificateInfo{Serial: "sn"},
|
||||
db: &DB{&MockNoSQLDB{
|
||||
get: func(bucket []byte, sn []byte) ([]byte, error) {
|
||||
return nil, database.ErrNotFound
|
||||
},
|
||||
set: func(bucket []byte, key []byte, value []byte) error {
|
||||
return nil
|
||||
MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), true, nil
|
||||
},
|
||||
}, true},
|
||||
},
|
||||
|
@ -214,7 +108,7 @@ func TestUseToken(t *testing.T) {
|
|||
id: "id",
|
||||
tok: "token",
|
||||
db: &DB{&MockNoSQLDB{
|
||||
cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return nil, false, errors.New("force")
|
||||
},
|
||||
}, true},
|
||||
|
@ -227,7 +121,7 @@ func TestUseToken(t *testing.T) {
|
|||
id: "id",
|
||||
tok: "token",
|
||||
db: &DB{&MockNoSQLDB{
|
||||
cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("foo"), false, nil
|
||||
},
|
||||
}, true},
|
||||
|
@ -239,7 +133,7 @@ func TestUseToken(t *testing.T) {
|
|||
id: "id",
|
||||
tok: "token",
|
||||
db: &DB{&MockNoSQLDB{
|
||||
cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||
return []byte("bar"), true, nil
|
||||
},
|
||||
}, true},
|
||||
|
|
55
db/simple.go
55
db/simple.go
|
@ -6,6 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/nosql/database"
|
||||
)
|
||||
|
||||
// ErrNotImplemented is an error returned when an operation is Not Implemented.
|
||||
|
@ -61,3 +62,57 @@ func (s *SimpleDB) UseToken(id, tok string) (bool, error) {
|
|||
func (s *SimpleDB) Shutdown() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// nosql.DB interface implementation //
|
||||
|
||||
// Open opens the database available with the given options.
|
||||
func (s *SimpleDB) Open(dataSourceName string, opt ...database.Option) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
// Close closes the current database.
|
||||
func (s *SimpleDB) Close() error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
// Get returns the value stored in the given table/bucket and key.
|
||||
func (s *SimpleDB) Get(bucket, key []byte) ([]byte, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
// Set sets the given value in the given table/bucket and key.
|
||||
func (s *SimpleDB) Set(bucket, key, value []byte) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
// CmpAndSwap swaps the value at the given bucket and key if the current
|
||||
// value is equivalent to the oldValue input. Returns 'true' if the
|
||||
// swap was successful and 'false' otherwise.
|
||||
func (s *SimpleDB) CmpAndSwap(bucket, key, oldValue, newValue []byte) ([]byte, bool, error) {
|
||||
return nil, false, ErrNotImplemented
|
||||
}
|
||||
|
||||
// Del deletes the data in the given table/bucket and key.
|
||||
func (s *SimpleDB) Del(bucket, key []byte) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
// List returns a list of all the entries in a given table/bucket.
|
||||
func (s *SimpleDB) List(bucket []byte) ([]*database.Entry, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
// Update performs a transaction with multiple read-write commands.
|
||||
func (s *SimpleDB) Update(tx *database.Tx) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
// CreateTable creates a table or a bucket in the database.
|
||||
func (s *SimpleDB) CreateTable(bucket []byte) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
// DeleteTable deletes a table or a bucket in the database.
|
||||
func (s *SimpleDB) DeleteTable(bucket []byte) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
|
|
@ -15,30 +15,38 @@ e.g. `v1.0.2`
|
|||
* **Release Candidate**: not ready for public use, still testing. must have a
|
||||
`-rc*` suffix. e.g. `v1.0.2-rc` or `v1.0.2-rc.4`
|
||||
|
||||
1. **Update the version of step/cli**
|
||||
---
|
||||
1. **Release `cli` first**
|
||||
|
||||
```
|
||||
$ dep ensure -update github.com/smallstep/cli
|
||||
```
|
||||
If you plan to release [`cli`](https://github.com/smallstep/cli) as part of
|
||||
this release, `cli` must be released first. The `certificates` docker container
|
||||
depends on the `cli` container. Make certain to wait until the `cli` travis
|
||||
build has completed.
|
||||
|
||||
2. **Commit all changes.**
|
||||
2. **Update the version of step/cli**
|
||||
|
||||
<pre><code>
|
||||
<b>$ dep ensure -update github.com/smallstep/cli</b>
|
||||
</code></pre>
|
||||
|
||||
3. **Commit all changes.**
|
||||
|
||||
Make sure that the local checkout is up to date with the remote origin and
|
||||
that all local changes have been pushed.
|
||||
|
||||
```
|
||||
$ git pull --rebase origin master
|
||||
$ git push
|
||||
```
|
||||
<pre><code>
|
||||
<b>$ git pull --rebase origin master</b>
|
||||
<b>$ git push</b>
|
||||
</code></pre>
|
||||
|
||||
3. **Tag it!**
|
||||
4. **Tag it!**
|
||||
|
||||
1. **Find the most recent tag.**
|
||||
|
||||
```
|
||||
$ git fetch --tags
|
||||
$ git tag
|
||||
```
|
||||
<pre><code>
|
||||
<b>$ git fetch --tags</b>
|
||||
<b>$ git tag</b>
|
||||
</code></pre>
|
||||
|
||||
The new tag needs to be the logical successor of the most recent existing tag.
|
||||
See [versioning](#versioning) section for more information on version numbers.
|
||||
|
@ -47,14 +55,14 @@ e.g. `v1.0.2`
|
|||
|
||||
Is the new release a *release candidate* or a *standard release*?
|
||||
|
||||
1. Release Candidate
|
||||
1. **Release Candidate**
|
||||
|
||||
If the most recent tag is a standard release, say `v1.0.2`, then the version
|
||||
of the next release candidate should be `v1.0.3-rc.1`. If the most recent tag
|
||||
is a release candidate, say `v1.0.2-rc.3`, then the version of the next
|
||||
release candidate should be `v1.0.2-rc.4`.
|
||||
|
||||
2. Standard Release
|
||||
2. **Standard Release**
|
||||
|
||||
If the most recent tag is a standard release, say `v1.0.2`, then the version
|
||||
of the next standard release should be `v1.0.3`. If the most recent tag
|
||||
|
@ -64,21 +72,25 @@ e.g. `v1.0.2`
|
|||
|
||||
3. **Create a local tag.**
|
||||
|
||||
```
|
||||
$ git tag v1.0.3 # standard release
|
||||
<pre><code>
|
||||
# standard release
|
||||
<b>$ git tag v1.0.3</b>
|
||||
...or
|
||||
$ git tag v1.0.3-rc.1 # release candidate
|
||||
```
|
||||
# release candidate
|
||||
<b>$ git tag v1.0.3-rc.1</b>
|
||||
</code></pre>
|
||||
|
||||
4. **Push the new tag to the remote origin.**
|
||||
|
||||
```
|
||||
$ git push origin tag v1.0.3 # standard release
|
||||
<pre><code>
|
||||
# standard release
|
||||
<b>$ git push origin tag v1.0.3</b>
|
||||
...or
|
||||
$ git push origin tag v1.0.3-rc.1 # release candidate
|
||||
```
|
||||
# release candidate
|
||||
<b>$ git push origin tag v1.0.3-rc.1</b>
|
||||
</code></pre>
|
||||
|
||||
4. Check the build status at
|
||||
5. **Check the build status at**
|
||||
[Travis-CI](https://travis-ci.com/smallstep/certificates/builds/).
|
||||
|
||||
Travis will begin by verifying that there are no compilation or linting errors
|
||||
|
@ -97,15 +109,15 @@ e.g. `v1.0.2`
|
|||
|
||||
> **NOTE**: if you plan to release `cli` next then you can skip this step.
|
||||
|
||||
```
|
||||
$ cd archlinux
|
||||
<pre><code>
|
||||
<b>$ cd archlinux</b>
|
||||
|
||||
# Get up to date...
|
||||
$ git pull origin master
|
||||
$ make
|
||||
<b>$ git pull origin master</b>
|
||||
<b>$ make</b>
|
||||
|
||||
$ ./update --ca v1.0.3
|
||||
```
|
||||
<b>$ ./update --ca v1.0.3</b>
|
||||
</code></pre>
|
||||
|
||||
7. **Update the Helm packages**
|
||||
|
||||
|
@ -125,9 +137,9 @@ e.g. `v1.0.2`
|
|||
|
||||
Then create the step-certificates package running:
|
||||
|
||||
```sh
|
||||
$ helm package ./step-certificates
|
||||
```
|
||||
<pre><code>
|
||||
<b>$ helm package ./step-certificates</b>
|
||||
</code></pre>
|
||||
|
||||
A new file like `step-certificates-<version>.tgz` will be created.
|
||||
Now commit and push your changes (don't commit the tarball) to the master
|
||||
|
@ -136,15 +148,15 @@ e.g. `v1.0.2`
|
|||
Next checkout the `gh-pages` branch. `git add` the new tar-ball and update
|
||||
the index.yaml using the `helm repo index` command:
|
||||
|
||||
```sh
|
||||
$ git checkout gh-pages
|
||||
$ git add "step-certificates-<version>.tgz"
|
||||
$ helm repo index --merge index.yaml --url https://smallstep.github.io/helm-charts/ .
|
||||
$ git commit -a -m "Add package for step-certificates <appVersion>"
|
||||
$ git push origin gh-pages
|
||||
```
|
||||
<pre><code>
|
||||
<b>$ git checkout gh-pages</b>
|
||||
<b>$ git add "step-certificates-<version>.tgz"</b>
|
||||
<b>$ helm repo index --merge index.yaml --url https://smallstep.github.io/helm-charts/ .</b>
|
||||
<b>$ git commit -a -m "Add package for step-certificates <appVersion>"</b>
|
||||
<b>$ git push origin gh-pages</b>
|
||||
</code></pre>
|
||||
|
||||
*All Done!*
|
||||
***All Done!***
|
||||
|
||||
## Versioning
|
||||
|
||||
|
|
160
docs/acme.md
Normal file
160
docs/acme.md
Normal file
|
@ -0,0 +1,160 @@
|
|||
# Using ACME with `step-ca `
|
||||
|
||||
Let’s assume you’ve [installed
|
||||
`step-ca`](https://smallstep.com/docs/getting-started/#1-installing-step-and-step-ca)
|
||||
(e.g., using `brew install step`), have it running at `https://ca.internal`,
|
||||
and you’ve [bootstrapped your ACME client
|
||||
system(s)](https://smallstep.com/docs/getting-started/#bootstrapping) (or at
|
||||
least [installed your root
|
||||
certificate](https://smallstep.com/docs/cli/ca/root/) at
|
||||
`~/.step/certs/root_ca.crt`).
|
||||
|
||||
## Enabling ACME
|
||||
|
||||
To enable ACME, simply [add an ACME provisioner](https://smallstep.com/docs/cli/ca/provisioner/add/) to your `step-ca` configuration
|
||||
by running:
|
||||
|
||||
```
|
||||
$ step ca provisioner add my-acme-provisioner --type ACME
|
||||
```
|
||||
|
||||
> NOTE: The above command will add a new provisioner of type `ACME` and name
|
||||
> `my-acme-provisioner`. The name is used to identify the provisioner
|
||||
> (e.g. you cannot have two `ACME` provisioners with the same name).
|
||||
|
||||
Now restart or SIGHUP `step-ca` to pick up the new configuration.
|
||||
|
||||
That’s it.
|
||||
|
||||
## Configuring Clients
|
||||
|
||||
To configure an ACME client to connect to `step-ca` you need to:
|
||||
|
||||
1. Point the client at the right ACME directory URL
|
||||
2. Tell the client to trust your CA’s root certificate
|
||||
|
||||
Once certificates are issued, you’ll also need to ensure they’re renewed before
|
||||
they expire.
|
||||
|
||||
### Pointing Clients at the right ACME Directory URL
|
||||
|
||||
Most ACME clients connect to Let’s Encrypt by default. To connect to `step-ca`
|
||||
you need to point the client at the right [ACME directory
|
||||
URL](https://tools.ietf.org/html/rfc8555#section-7.1.1).
|
||||
|
||||
A single instance of `step-ca` can have multiple ACME provisioners, each with
|
||||
their own ACME directory URL that looks like:
|
||||
|
||||
```
|
||||
https://{ca-host}/acme/{provisioner-name}/directory
|
||||
```
|
||||
|
||||
We just added an ACME provisioner named “acme”. Its ACME directory URL is:
|
||||
|
||||
```
|
||||
https://ca.internal/acme/acme/directory
|
||||
```
|
||||
|
||||
### Telling clients to trust your CA’s root certificate
|
||||
|
||||
Communication between an ACME client and server [always uses
|
||||
HTTPS](https://tools.ietf.org/html/rfc8555#section-6.1). By default, client’s
|
||||
will validate the server’s HTTPS certificate using the public root certificates
|
||||
in your system’s [default
|
||||
trust](https://smallstep.com/blog/everything-pki.html#trust-stores) store.
|
||||
That’s fine when you’re connecting to Let’s Encrypt: it’s a public CA and its
|
||||
root certificate is in your system’s default trust store already. Your internal
|
||||
root certificate isn’t, so HTTPS connections from ACME clients to `step-ca` will
|
||||
fail.
|
||||
|
||||
There are two ways to address this problem:
|
||||
|
||||
1. Explicitly configure your ACME client to trust `step-ca`'s root certificate, or
|
||||
2. Add `step-ca`'s root certificate to your system’s default trust store (e.g.,
|
||||
using `[step certificate
|
||||
install](https://smallstep.com/docs/cli/certificate/install/)`)
|
||||
|
||||
If you’re using your CA for TLS in production, explicitly configuring your ACME
|
||||
client to only trust your root certificate is a better option. We’ll
|
||||
demonstrate this method with several clients below.
|
||||
|
||||
If you’re simulating Let’s Encrypt in pre-production, installing your root
|
||||
certificate is a more faithful simulation of production. Once your root
|
||||
certificate is installed, no additional client configuration is necessary.
|
||||
|
||||
> Caution: adding a root certificate to your system’s trust store is a global
|
||||
> operation. Certificates issued by your CA will be trusted everywhere,
|
||||
> including in web browsers.
|
||||
|
||||
### Example using [`certbot`](https://certbot.eff.org/)
|
||||
|
||||
[`certbot`](https://certbot.eff.org/) is the grandaddy of ACME clients. Built
|
||||
and supported by [the EFF](https://www.eff.org/), it’s the standard-bearer for
|
||||
production-grade command-line ACME.
|
||||
|
||||
To get a certificate from `step-ca` using `certbot` you need to:
|
||||
|
||||
1. Point `certbot` at your ACME directory URL using the `--`server flag.
|
||||
2. Tell `certbot` to trust your root certificate using the `REQUESTS_CA_BUNDLE` environment variable.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ sudo REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt \
|
||||
certbot certonly -n --standalone -d foo.internal \
|
||||
--server https://ca.internal/acme/acme/directory
|
||||
```
|
||||
|
||||
`sudo` is required in `certbot`'s [*standalone*
|
||||
mode](https://certbot.eff.org/docs/using.html#standalone) so it can listen on
|
||||
port 80 to complete the `http-01` challenge. If you already have a webserver
|
||||
running you can use [*webroot*
|
||||
mode](https://certbot.eff.org/docs/using.html#webroot) instead. With the
|
||||
[appropriate plugin](https://certbot.eff.org/docs/using.html#dns-plugins)
|
||||
`certbot` also supports the `dns-01` challenge for most popular DNS providers.
|
||||
Deeper integrations with [nginx](https://certbot.eff.org/docs/using.html#nginx)
|
||||
and [apache](https://certbot.eff.org/docs/using.html#apache) can even configure
|
||||
your server to use HTTPS automatically (we'll set this up ourselves later). All
|
||||
of this works with `step-ca`.
|
||||
|
||||
You can renew all of the certificates you've installed using `cerbot` by running:
|
||||
|
||||
```
|
||||
$ sudo REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt certbot renew
|
||||
```
|
||||
|
||||
You can automate renewal with a simple `cron` entry:
|
||||
|
||||
```
|
||||
*/15 * * * * root REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt certbot -q renew
|
||||
```
|
||||
|
||||
The `certbot` packages for some Linux distributions will create a `cron` entry
|
||||
or [systemd
|
||||
timer](https://stevenwestmoreland.com/2017/11/renewing-certbot-certificates-using-a-systemd-timer.html)
|
||||
like this for you. This entry won't work with `step-ca` because it [doesn't set
|
||||
the `REQUESTS_CA_BUNDLE` environment
|
||||
variable](https://github.com/certbot/certbot/issues/7170). You'll need to
|
||||
manually tweak it to do so.
|
||||
|
||||
More subtly, `certbot`'s default renewal job is tuned for Let's Encrypt's 90
|
||||
day certificate lifetimes: it's run every 12 hours, with actual renewals
|
||||
occurring for certificates within 30 days of expiry. By default, `step-ca`
|
||||
issues certificates with *much shorter* 24 hour lifetimes. The `cron` entry
|
||||
above accounts for this by running `certbot renew` every 15 minutes. You'll
|
||||
also want to configure your domain to only renew certificates when they're
|
||||
within a few hours of expiry by adding a line like:
|
||||
|
||||
```
|
||||
renew_before_expiry = 8 hours
|
||||
```
|
||||
|
||||
to the top of your renewal configuration (e.g., in `/etc/letsencrypt/renewal/foo.internal.conf`).
|
||||
|
||||
## Feedback
|
||||
|
||||
`step-ca` should work with any ACMEv2
|
||||
([RFC8555](https://tools.ietf.org/html/rfc8555)) compliant client that supports
|
||||
the http-01 or dns-01 challenge. If you run into any issues please let us know
|
||||
[on gitter](https://gitter.im/smallstep/community) or [in an
|
||||
issue](https://github.com/smallstep/certificates/issues/new?template=bug_report.md).
|
|
@ -111,6 +111,7 @@ is G-Suite.
|
|||
"configurationEndpoint": "https://accounts.google.com/.well-known/openid-configuration",
|
||||
"admins": ["you@smallstep.com"],
|
||||
"domains": ["smallstep.com"],
|
||||
"listenAddress": ":10000",
|
||||
"claims": {
|
||||
"maxTLSCertDuration": "8h",
|
||||
"defaultTLSCertDuration": "2h",
|
||||
|
@ -141,6 +142,12 @@ is G-Suite.
|
|||
* `domains` (optional): is the list of domains valid. If provided only the
|
||||
emails with the provided domains will be able to authenticate.
|
||||
|
||||
* `listenAddress` (optional): is the loopback address (`:port` or `host:port`)
|
||||
where the authorization server will redirect to complete the authorization
|
||||
flow. If it's not defined `step` will use `127.0.0.1` with a random port. This
|
||||
configuration is only required if the authorization server doesn't allow any
|
||||
port to be specified at the time of the request for loopback IP redirect URIs.
|
||||
|
||||
* `claims` (optional): overwrites the default claims set in the authority, see
|
||||
the [JWK](#jwk) section for all the options.
|
||||
|
||||
|
|
|
@ -30,9 +30,6 @@ func NewResponseLogger(w http.ResponseWriter) ResponseLogger {
|
|||
|
||||
func wrapLogger(w http.ResponseWriter) (rw ResponseLogger) {
|
||||
rw = &rwDefault{w, 200, 0, nil}
|
||||
if c, ok := w.(http.CloseNotifier); ok {
|
||||
rw = &rwCloseNotifier{rw, c}
|
||||
}
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
rw = &rwFlusher{rw, f}
|
||||
}
|
||||
|
@ -88,15 +85,6 @@ func (r *rwDefault) WithFields(fields map[string]interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
type rwCloseNotifier struct {
|
||||
ResponseLogger
|
||||
c http.CloseNotifier
|
||||
}
|
||||
|
||||
func (r *rwCloseNotifier) CloseNotify() <-chan bool {
|
||||
return r.CloseNotify()
|
||||
}
|
||||
|
||||
type rwFlusher struct {
|
||||
ResponseLogger
|
||||
f http.Flusher
|
||||
|
|
Loading…
Reference in a new issue