forked from TrueCloudLab/certificates
commit
1807e240ea
54 changed files with 15687 additions and 184 deletions
|
@ -60,6 +60,7 @@ issues:
|
||||||
- declaration of "err" shadows declaration at line
|
- declaration of "err" shadows declaration at line
|
||||||
- should have a package comment, unless it's in another file for this package
|
- should have a package comment, unless it's in another file for this package
|
||||||
- error strings should not be capitalized or end with punctuation or a newline
|
- error strings should not be capitalized or end with punctuation or a newline
|
||||||
|
- declaration of "authz" shadows declaration at line
|
||||||
# golangci.com configuration
|
# golangci.com configuration
|
||||||
# https://github.com/golangci/golangci/wiki/Configuration
|
# https://github.com/golangci/golangci/wiki/Configuration
|
||||||
service:
|
service:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
language: go
|
language: go
|
||||||
go:
|
go:
|
||||||
- 1.12.x
|
- 1.13.x
|
||||||
addons:
|
addons:
|
||||||
apt:
|
apt:
|
||||||
packages:
|
packages:
|
||||||
|
|
2
Gopkg.lock
generated
2
Gopkg.lock
generated
|
@ -228,7 +228,7 @@
|
||||||
"utils",
|
"utils",
|
||||||
]
|
]
|
||||||
pruneopts = "UT"
|
pruneopts = "UT"
|
||||||
revision = "e097873f958542df7505184bee0fadfcf17027de"
|
revision = "ae6e517f70783467afe6199e12fb43309a7e693e"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
|
|
214
acme/account.go
Normal file
214
acme/account.go
Normal file
|
@ -0,0 +1,214 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Account is a subset of the internal account type containing only those
|
||||||
|
// attributes required for responses in the ACME protocol.
|
||||||
|
type Account struct {
|
||||||
|
Contact []string `json:"contact,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Orders string `json:"orders"`
|
||||||
|
ID string `json:"-"`
|
||||||
|
Key *jose.JSONWebKey `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToLog enables response logging.
|
||||||
|
func (a *Account) ToLog() (interface{}, error) {
|
||||||
|
b, err := json.Marshal(a)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging"))
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID returns the account ID.
|
||||||
|
func (a *Account) GetID() string {
|
||||||
|
return a.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKey returns the JWK associated with the account.
|
||||||
|
func (a *Account) GetKey() *jose.JSONWebKey {
|
||||||
|
return a.Key
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValid returns true if the Account is valid.
|
||||||
|
func (a *Account) IsValid() bool {
|
||||||
|
return a.Status == StatusValid
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountOptions are the options needed to create a new ACME account.
|
||||||
|
type AccountOptions struct {
|
||||||
|
Key *jose.JSONWebKey
|
||||||
|
Contact []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// account represents an ACME account.
|
||||||
|
type account struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Created time.Time `json:"created"`
|
||||||
|
Deactivated time.Time `json:"deactivated"`
|
||||||
|
Key *jose.JSONWebKey `json:"key"`
|
||||||
|
Contact []string `json:"contact,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// newAccount returns a new acme account type.
|
||||||
|
func newAccount(db nosql.DB, ops AccountOptions) (*account, error) {
|
||||||
|
id, err := randID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
a := &account{
|
||||||
|
ID: id,
|
||||||
|
Key: ops.Key,
|
||||||
|
Contact: ops.Contact,
|
||||||
|
Status: "valid",
|
||||||
|
Created: clock.Now(),
|
||||||
|
}
|
||||||
|
return a, a.saveNew(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// toACME converts the internal Account type into the public acmeAccount
|
||||||
|
// type for presentation in the ACME protocol.
|
||||||
|
func (a *account) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Account, error) {
|
||||||
|
return &Account{
|
||||||
|
Status: a.Status,
|
||||||
|
Contact: a.Contact,
|
||||||
|
Orders: dir.getLink(OrdersByAccountLink, URLSafeProvisionerName(p), true, a.ID),
|
||||||
|
Key: a.Key,
|
||||||
|
ID: a.ID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// save writes the Account to the DB.
|
||||||
|
// If the account is new then the necessary indices will be created.
|
||||||
|
// Else, the account in the DB will be updated.
|
||||||
|
func (a *account) saveNew(db nosql.DB) error {
|
||||||
|
kid, err := keyToID(a.Key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
kidB := []byte(kid)
|
||||||
|
|
||||||
|
// Set the jwkID -> acme account ID index
|
||||||
|
_, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID))
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index"))
|
||||||
|
case !swapped:
|
||||||
|
return ServerInternalErr(errors.Errorf("key-id to account-id index already exists"))
|
||||||
|
default:
|
||||||
|
if err = a.save(db, nil); err != nil {
|
||||||
|
db.Del(accountByKeyIDTable, kidB)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *account) save(db nosql.DB, old *account) error {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
oldB []byte
|
||||||
|
)
|
||||||
|
if old == nil {
|
||||||
|
oldB = nil
|
||||||
|
} else {
|
||||||
|
if oldB, err = json.Marshal(old); err != nil {
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := json.Marshal(*a)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error marshaling new account object")
|
||||||
|
}
|
||||||
|
// Set the Account
|
||||||
|
_, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error storing account"))
|
||||||
|
case !swapped:
|
||||||
|
return ServerInternalErr(errors.New("error storing account; " +
|
||||||
|
"value has changed since last read"))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// update updates the acme account object stored in the database if,
|
||||||
|
// and only if, the account has not changed since the last read.
|
||||||
|
func (a *account) update(db nosql.DB, contact []string) (*account, error) {
|
||||||
|
b := *a
|
||||||
|
b.Contact = contact
|
||||||
|
if err := b.save(db, a); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deactivate deactivates the acme account.
|
||||||
|
func (a *account) deactivate(db nosql.DB) (*account, error) {
|
||||||
|
b := *a
|
||||||
|
b.Status = StatusDeactivated
|
||||||
|
b.Deactivated = clock.Now()
|
||||||
|
if err := b.save(db, a); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAccountByID retrieves the account with the given ID.
|
||||||
|
func getAccountByID(db nosql.DB, id string) (*account, error) {
|
||||||
|
ab, err := db.Get(accountTable, []byte(id))
|
||||||
|
if err != nil {
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id))
|
||||||
|
}
|
||||||
|
return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id))
|
||||||
|
}
|
||||||
|
|
||||||
|
a := new(account)
|
||||||
|
if err = json.Unmarshal(ab, a); err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account"))
|
||||||
|
}
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAccountByKeyID retrieves Id associated with the given Kid.
|
||||||
|
func getAccountByKeyID(db nosql.DB, kid string) (*account, error) {
|
||||||
|
id, err := db.Get(accountByKeyIDTable, []byte(kid))
|
||||||
|
if err != nil {
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid))
|
||||||
|
}
|
||||||
|
return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index"))
|
||||||
|
}
|
||||||
|
return getAccountByID(db, string(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOrderIDsByAccount retrieves a list of Order IDs that were created by the
|
||||||
|
// account.
|
||||||
|
func getOrderIDsByAccount(db nosql.DB, id string) ([]string, error) {
|
||||||
|
b, err := db.Get(ordersByAccountIDTable, []byte(id))
|
||||||
|
if err != nil {
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return []string{}, nil
|
||||||
|
}
|
||||||
|
return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", id))
|
||||||
|
}
|
||||||
|
var orderIDs []string
|
||||||
|
if err := json.Unmarshal(b, &orderIDs); err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", id))
|
||||||
|
}
|
||||||
|
return orderIDs, nil
|
||||||
|
}
|
844
acme/account_test.go
Normal file
844
acme/account_test.go
Normal file
|
@ -0,0 +1,844 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
"github.com/smallstep/nosql/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultDisableRenewal = false
|
||||||
|
globalProvisionerClaims = provisioner.Claims{
|
||||||
|
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||||
|
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
|
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func newProv() provisioner.Interface {
|
||||||
|
// Initialize provisioners
|
||||||
|
p := &provisioner.ACME{
|
||||||
|
Type: "ACME",
|
||||||
|
Name: "test@acme-provisioner.com",
|
||||||
|
}
|
||||||
|
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||||
|
fmt.Printf("%v", err)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAcc() (*account, error) {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mockdb := &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return newAccount(mockdb, AccountOptions{
|
||||||
|
Key: jwk, Contact: []string{"foo", "bar"},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountByID(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
id string
|
||||||
|
db nosql.DB
|
||||||
|
acc *account
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
id: acc.ID,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, database.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db-error": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
id: acc.ID,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-error": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
id: acc.ID,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, key, []byte(acc.ID))
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
id: acc.ID,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, key, []byte(acc.ID))
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if acc, err := getAccountByID(tc.db, tc.id); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.acc.ID, acc.ID)
|
||||||
|
assert.Equals(t, tc.acc.Status, acc.Status)
|
||||||
|
assert.Equals(t, tc.acc.Created, acc.Created)
|
||||||
|
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
||||||
|
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
||||||
|
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountByKeyID(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
kid string
|
||||||
|
db nosql.DB
|
||||||
|
acc *account
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/kid-not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
kid: "foo",
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, database.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: MalformedErr(errors.Errorf("account with key id foo not found: not found")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
kid: "foo",
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error loading key-account index: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/getAccount-error": func(t *testing.T) test {
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
kid: "foo",
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
if count == 0 {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, key, []byte("foo"))
|
||||||
|
count++
|
||||||
|
return []byte("bar"), nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error loading account bar: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
kid: acc.Key.KeyID,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
var ret []byte
|
||||||
|
switch count {
|
||||||
|
case 0:
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, key, []byte(acc.Key.KeyID))
|
||||||
|
ret = []byte(acc.ID)
|
||||||
|
case 1:
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, key, []byte(acc.ID))
|
||||||
|
ret = b
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return ret, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acc: acc,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.acc.ID, acc.ID)
|
||||||
|
assert.Equals(t, tc.acc.Status, acc.Status)
|
||||||
|
assert.Equals(t, tc.acc.Created, acc.Created)
|
||||||
|
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
|
||||||
|
assert.Equals(t, tc.acc.Contact, acc.Contact)
|
||||||
|
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountIDsByAccount(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
id string
|
||||||
|
db nosql.DB
|
||||||
|
res []string
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"ok/not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
id: "foo",
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, database.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: []string{},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
id: "foo",
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
id: "foo",
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||||
|
assert.Equals(t, key, []byte("foo"))
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
oids := []string{"foo", "bar", "baz"}
|
||||||
|
b, err := json.Marshal(oids)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
id: "foo",
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||||
|
assert.Equals(t, key, []byte("foo"))
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: oids,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if oids, err := getOrderIDsByAccount(tc.db, tc.id); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.res, oids)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountToACME(t *testing.T) {
|
||||||
|
dir := newDirectory("ca.smallstep.com", "acme")
|
||||||
|
prov := newProv()
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
acc *account
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{acc: acc}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
acmeAccount, err := tc.acc.toACME(nil, dir, prov)
|
||||||
|
if err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, acmeAccount.ID, tc.acc.ID)
|
||||||
|
assert.Equals(t, acmeAccount.Status, tc.acc.Status)
|
||||||
|
assert.Equals(t, acmeAccount.Contact, tc.acc.Contact)
|
||||||
|
assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID)
|
||||||
|
assert.Equals(t, acmeAccount.Orders,
|
||||||
|
fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s/orders", URLSafeProvisionerName(prov), tc.acc.ID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountSave(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
acc, old *account
|
||||||
|
db nosql.DB
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
old: nil,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
old: nil,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return []byte("foo"), false, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error storing account; value has changed since last read")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/old-nil": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
old: nil,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
assert.Equals(t, b, newval)
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, []byte(acc.ID), key)
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/old-not-nil": func(t *testing.T) test {
|
||||||
|
oldAcc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
oldb, err := json.Marshal(oldAcc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
old: oldAcc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, old, oldb)
|
||||||
|
assert.Equals(t, newval, b)
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, []byte(acc.ID), key)
|
||||||
|
return []byte("foo"), true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if err := tc.acc.save(tc.db, tc.old); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountSaveNew(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
acc *account
|
||||||
|
db nosql.DB
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/keyToID-error": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc.Key.Key = "foo"
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/swap-error": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
kid, err := keyToID(acc.Key)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, key, []byte(kid))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
assert.Equals(t, newval, []byte(acc.ID))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/swap-false": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
kid, err := keyToID(acc.Key)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, key, []byte(kid))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
assert.Equals(t, newval, []byte(acc.ID))
|
||||||
|
return nil, false, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("key-id to account-id index already exists")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/save-error": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
kid, err := keyToID(acc.Key)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
if count == 0 {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, key, []byte(kid))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
assert.Equals(t, newval, []byte(acc.ID))
|
||||||
|
count++
|
||||||
|
return nil, true, nil
|
||||||
|
}
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, key, []byte(acc.ID))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
assert.Equals(t, newval, b)
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
MDel: func(bucket, key []byte) error {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, key, []byte(kid))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
kid, err := keyToID(acc.Key)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
if count == 0 {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, key, []byte(kid))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
assert.Equals(t, newval, []byte(acc.ID))
|
||||||
|
count++
|
||||||
|
return nil, true, nil
|
||||||
|
}
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, key, []byte(acc.ID))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
assert.Equals(t, newval, b)
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if err := tc.acc.saveNew(tc.db); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountUpdate(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
acc *account
|
||||||
|
contact []string
|
||||||
|
db nosql.DB
|
||||||
|
res []byte
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
contact := []string{"foo", "bar"}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/save-error": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
oldb, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
_acc := *acc
|
||||||
|
clone := &_acc
|
||||||
|
clone.Contact = contact
|
||||||
|
b, err := json.Marshal(clone)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
contact: contact,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, key, []byte(acc.ID))
|
||||||
|
assert.Equals(t, old, oldb)
|
||||||
|
assert.Equals(t, newval, b)
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
oldb, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
_acc := *acc
|
||||||
|
clone := &_acc
|
||||||
|
clone.Contact = contact
|
||||||
|
b, err := json.Marshal(clone)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
contact: contact,
|
||||||
|
res: b,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, key, []byte(acc.ID))
|
||||||
|
assert.Equals(t, old, oldb)
|
||||||
|
assert.Equals(t, newval, b)
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
acc, err := tc.acc.update(tc.db, tc.contact)
|
||||||
|
if err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
b, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, b, tc.res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountDeactivate(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
acc *account
|
||||||
|
db nosql.DB
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/save-error": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
oldb, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, key, []byte(acc.ID))
|
||||||
|
assert.Equals(t, old, oldb)
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error storing account: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc, err := newAcc()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
oldb, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, key, []byte(acc.ID))
|
||||||
|
assert.Equals(t, old, oldb)
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
acc, err := tc.acc.deactivate(tc.db)
|
||||||
|
if err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, acc.ID, tc.acc.ID)
|
||||||
|
assert.Equals(t, acc.Contact, tc.acc.Contact)
|
||||||
|
assert.Equals(t, acc.Status, StatusDeactivated)
|
||||||
|
assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID)
|
||||||
|
assert.Equals(t, acc.Created, tc.acc.Created)
|
||||||
|
|
||||||
|
assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute)))
|
||||||
|
assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAccount(t *testing.T) {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
kid, err := keyToID(jwk)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ops := AccountOptions{
|
||||||
|
Key: jwk,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
type test struct {
|
||||||
|
ops AccountOptions
|
||||||
|
db nosql.DB
|
||||||
|
err *Error
|
||||||
|
id *string
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/store-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ops: ops,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
var _id string
|
||||||
|
id := &_id
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
ops: ops,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
switch count {
|
||||||
|
case 0:
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, key, []byte(kid))
|
||||||
|
case 1:
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
*id = string(key)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
id: id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
acc, err := newAccount(tc.db, tc.ops)
|
||||||
|
if err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, acc.ID, *tc.id)
|
||||||
|
assert.Equals(t, acc.Status, StatusValid)
|
||||||
|
assert.Equals(t, acc.Contact, ops.Contact)
|
||||||
|
assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID)
|
||||||
|
|
||||||
|
assert.True(t, acc.Deactivated.IsZero())
|
||||||
|
|
||||||
|
assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute)))
|
||||||
|
assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
213
acme/api/account.go
Normal file
213
acme/api/account.go
Normal file
|
@ -0,0 +1,213 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewAccountRequest represents the payload for a new account request.
|
||||||
|
type NewAccountRequest struct {
|
||||||
|
Contact []string `json:"contact"`
|
||||||
|
OnlyReturnExisting bool `json:"onlyReturnExisting"`
|
||||||
|
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateContacts(cs []string) error {
|
||||||
|
for _, c := range cs {
|
||||||
|
if len(c) == 0 {
|
||||||
|
return acme.MalformedErr(errors.New("contact cannot be empty string"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates a new-account request body.
|
||||||
|
func (n *NewAccountRequest) Validate() error {
|
||||||
|
if n.OnlyReturnExisting && len(n.Contact) > 0 {
|
||||||
|
return acme.MalformedErr(errors.New("incompatible input; onlyReturnExisting must be alone"))
|
||||||
|
}
|
||||||
|
return validateContacts(n.Contact)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccountRequest represents an update-account request.
|
||||||
|
type UpdateAccountRequest struct {
|
||||||
|
Contact []string `json:"contact"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsDeactivateRequest returns true if the update request is a deactivation
|
||||||
|
// request, false otherwise.
|
||||||
|
func (u *UpdateAccountRequest) IsDeactivateRequest() bool {
|
||||||
|
return u.Status == acme.StatusDeactivated
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates a update-account request body.
|
||||||
|
func (u *UpdateAccountRequest) Validate() error {
|
||||||
|
switch {
|
||||||
|
case len(u.Status) > 0 && len(u.Contact) > 0:
|
||||||
|
return acme.MalformedErr(errors.New("incompatible input; contact and " +
|
||||||
|
"status updates are mutually exclusive"))
|
||||||
|
case len(u.Contact) > 0:
|
||||||
|
if err := validateContacts(u.Contact); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case len(u.Status) > 0:
|
||||||
|
if u.Status != acme.StatusDeactivated {
|
||||||
|
return acme.MalformedErr(errors.Errorf("cannot update account "+
|
||||||
|
"status to %s, only deactivated", u.Status))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return acme.MalformedErr(errors.Errorf("empty update request"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAccount is the handler resource for creating new ACME accounts.
|
||||||
|
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payload, err := payloadFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var nar NewAccountRequest
|
||||||
|
if err := json.Unmarshal(payload.value, &nar); err != nil {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
|
||||||
|
"failed to unmarshal new-account request payload")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := nar.Validate(); err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
httpStatus := http.StatusCreated
|
||||||
|
acc, err := accountFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
acmeErr, ok := err.(*acme.Error)
|
||||||
|
if !ok || acmeErr.Status != http.StatusNotFound {
|
||||||
|
// Something went wrong ...
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Account does not exist //
|
||||||
|
if nar.OnlyReturnExisting {
|
||||||
|
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jwk, err := jwkFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if acc, err = h.Auth.NewAccount(prov, acme.AccountOptions{
|
||||||
|
Key: jwk,
|
||||||
|
Contact: nar.Contact,
|
||||||
|
}); err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Account exists //
|
||||||
|
httpStatus = http.StatusOK
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink,
|
||||||
|
acme.URLSafeProvisionerName(prov), true, acc.GetID()))
|
||||||
|
api.JSONStatus(w, acc, httpStatus)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpdateAccount is the api for updating an ACME account.
|
||||||
|
func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
acc, err := accountFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payload, err := payloadFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !payload.isPostAsGet {
|
||||||
|
var uar UpdateAccountRequest
|
||||||
|
if err := json.Unmarshal(payload.value, &uar); err != nil {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal new-account request payload")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := uar.Validate(); err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
if uar.IsDeactivateRequest() {
|
||||||
|
acc, err = h.Auth.DeactivateAccount(prov, acc.GetID())
|
||||||
|
} else {
|
||||||
|
acc, err = h.Auth.UpdateAccount(prov, acc.GetID(), uar.Contact)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, acc.GetID()))
|
||||||
|
api.JSON(w, acc)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
||||||
|
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||||
|
m := map[string]interface{}{
|
||||||
|
"orders": oids,
|
||||||
|
}
|
||||||
|
rl.WithFields(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account.
|
||||||
|
func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
acc, err := accountFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accID := chi.URLParam(r, "accID")
|
||||||
|
if acc.ID != accID {
|
||||||
|
api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
orders, err := h.Auth.GetOrdersByAccount(prov, acc.GetID())
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
api.JSON(w, orders)
|
||||||
|
logOrdersByAccount(w, orders)
|
||||||
|
return
|
||||||
|
}
|
790
acme/api/account_test.go
Normal file
790
acme/api/account_test.go
Normal file
|
@ -0,0 +1,790 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultDisableRenewal = false
|
||||||
|
globalProvisionerClaims = provisioner.Claims{
|
||||||
|
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||||
|
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
|
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func newProv() provisioner.Interface {
|
||||||
|
// Initialize provisioners
|
||||||
|
p := &provisioner.ACME{
|
||||||
|
Type: "ACME",
|
||||||
|
Name: "test@acme-provisioner.com",
|
||||||
|
}
|
||||||
|
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
|
||||||
|
fmt.Printf("%v", err)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAccountRequestValidate(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
nar *NewAccountRequest
|
||||||
|
err *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/incompatible-input": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
nar: &NewAccountRequest{
|
||||||
|
OnlyReturnExisting: true,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
},
|
||||||
|
err: acme.MalformedErr(errors.Errorf("incompatible input; onlyReturnExisting must be alone")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/bad-contact": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
nar: &NewAccountRequest{
|
||||||
|
Contact: []string{"foo", ""},
|
||||||
|
},
|
||||||
|
err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
nar: &NewAccountRequest{
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/onlyReturnExisting": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
nar: &NewAccountRequest{
|
||||||
|
OnlyReturnExisting: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
if err := tc.nar.Validate(); err != nil {
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
ae, ok := err.(*acme.Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateAccountRequestValidate(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
uar *UpdateAccountRequest
|
||||||
|
err *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/incompatible-input": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
uar: &UpdateAccountRequest{
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Status: "foo",
|
||||||
|
},
|
||||||
|
err: acme.MalformedErr(errors.Errorf("incompatible input; " +
|
||||||
|
"contact and status updates are mutually exclusive")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/bad-contact": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
uar: &UpdateAccountRequest{
|
||||||
|
Contact: []string{"foo", ""},
|
||||||
|
},
|
||||||
|
err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/bad-status": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
uar: &UpdateAccountRequest{
|
||||||
|
Status: "foo",
|
||||||
|
},
|
||||||
|
err: acme.MalformedErr(errors.Errorf("cannot update account " +
|
||||||
|
"status to foo, only deactivated")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/contact": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
uar: &UpdateAccountRequest{
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/status": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
uar: &UpdateAccountRequest{
|
||||||
|
Status: "deactivated",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
if err := tc.uar.Validate(); err != nil {
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
ae, ok := err.(*acme.Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerGetOrdersByAccount(t *testing.T) {
|
||||||
|
oids := []string{
|
||||||
|
"https://ca.smallstep.com/acme/order/foo",
|
||||||
|
"https://ca.smallstep.com/acme/order/bar",
|
||||||
|
}
|
||||||
|
accID := "account-id"
|
||||||
|
prov := newProv()
|
||||||
|
|
||||||
|
// Request with chi context
|
||||||
|
chiCtx := chi.NewRouteContext()
|
||||||
|
chiCtx.URLParams.Add("accID", accID)
|
||||||
|
url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID)
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
auth acme.Interface
|
||||||
|
ctx context.Context
|
||||||
|
statusCode int
|
||||||
|
problem *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.Background(),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-account": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/account-id-mismatch": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "foo"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 401,
|
||||||
|
problem: acme.UnauthorizedErr(errors.New("account ID does not match url param")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/getOrdersByAccount-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: accID}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
err: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: accID}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
getOrdersByAccount: func(p provisioner.Interface, id string) ([]string, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, id, acc.ID)
|
||||||
|
return oids, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := New(tc.auth).(*Handler)
|
||||||
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.GetOrdersByAccount(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||||
|
var ae acme.AError
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
prob := tc.problem.ToACME()
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, prob.Type)
|
||||||
|
assert.Equals(t, ae.Detail, prob.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
expB, err := json.Marshal(oids)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerNewAccount(t *testing.T) {
|
||||||
|
accID := "accountID"
|
||||||
|
acc := acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
Status: "valid",
|
||||||
|
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
|
||||||
|
}
|
||||||
|
prov := newProv()
|
||||||
|
|
||||||
|
url := "https://ca.smallstep.com/acme/new-account"
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
auth acme.Interface
|
||||||
|
ctx context.Context
|
||||||
|
statusCode int
|
||||||
|
problem *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.Background(),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 400,
|
||||||
|
problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||||
|
nar := &NewAccountRequest{
|
||||||
|
Contact: []string{"foo", ""},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 400,
|
||||||
|
problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-existing-account": func(t *testing.T) test {
|
||||||
|
nar := &NewAccountRequest{
|
||||||
|
OnlyReturnExisting: true,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-jwk": func(t *testing.T) test {
|
||||||
|
nar := &NewAccountRequest{
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-jwk": func(t *testing.T) test {
|
||||||
|
nar := &NewAccountRequest{
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
ctx = context.WithValue(ctx, jwkContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/NewAccount-error": func(t *testing.T) test {
|
||||||
|
nar := &NewAccountRequest{
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, ops.Contact, nar.Contact)
|
||||||
|
assert.Equals(t, ops.Key, jwk)
|
||||||
|
return nil, acme.ServerInternalErr(errors.New("force"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/new-account": func(t *testing.T) test {
|
||||||
|
nar := &NewAccountRequest{
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, ops.Contact, nar.Contact)
|
||||||
|
assert.Equals(t, ops.Key, jwk)
|
||||||
|
return &acc, nil
|
||||||
|
},
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.Equals(t, typ, acme.AccountLink)
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{accID})
|
||||||
|
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), accID)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 201,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/return-existing": func(t *testing.T) test {
|
||||||
|
nar := &NewAccountRequest{
|
||||||
|
OnlyReturnExisting: true,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.Equals(t, typ, acme.AccountLink)
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{accID})
|
||||||
|
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), accID)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := New(tc.auth).(*Handler)
|
||||||
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.NewAccount(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||||
|
var ae acme.AError
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
prob := tc.problem.ToACME()
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, prob.Type)
|
||||||
|
assert.Equals(t, ae.Detail, prob.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
expB, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
|
assert.Equals(t, res.Header["Location"],
|
||||||
|
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), accID)})
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerGetUpdateAccount(t *testing.T) {
|
||||||
|
accID := "accountID"
|
||||||
|
acc := acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
Status: "valid",
|
||||||
|
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
|
||||||
|
}
|
||||||
|
prov := newProv()
|
||||||
|
|
||||||
|
// Request with chi context
|
||||||
|
url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s", accID)
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
auth acme.Interface
|
||||||
|
ctx context.Context
|
||||||
|
statusCode int
|
||||||
|
problem *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.Background(),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-account": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 400,
|
||||||
|
problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||||
|
uar := &UpdateAccountRequest{
|
||||||
|
Contact: []string{"foo", ""},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(uar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 400,
|
||||||
|
problem: acme.MalformedErr(errors.New("contact cannot be empty string")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/Deactivate-error": func(t *testing.T) test {
|
||||||
|
uar := &UpdateAccountRequest{
|
||||||
|
Status: "deactivated",
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(uar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, id, accID)
|
||||||
|
return nil, acme.ServerInternalErr(errors.New("force"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/UpdateAccount-error": func(t *testing.T) test {
|
||||||
|
uar := &UpdateAccountRequest{
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(uar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, id, accID)
|
||||||
|
assert.Equals(t, contacts, uar.Contact)
|
||||||
|
return nil, acme.ServerInternalErr(errors.New("force"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/deactivate": func(t *testing.T) test {
|
||||||
|
uar := &UpdateAccountRequest{
|
||||||
|
Status: "deactivated",
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(uar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, id, accID)
|
||||||
|
return &acc, nil
|
||||||
|
},
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
assert.Equals(t, typ, acme.AccountLink)
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{accID})
|
||||||
|
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), accID)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/new-account": func(t *testing.T) test {
|
||||||
|
uar := &UpdateAccountRequest{
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(uar)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, id, accID)
|
||||||
|
assert.Equals(t, contacts, uar.Contact)
|
||||||
|
return &acc, nil
|
||||||
|
},
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
assert.Equals(t, typ, acme.AccountLink)
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{accID})
|
||||||
|
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), accID)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/post-as-get": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, &acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
assert.Equals(t, typ, acme.AccountLink)
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{accID})
|
||||||
|
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), accID)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := New(tc.auth).(*Handler)
|
||||||
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.GetUpdateAccount(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||||
|
var ae acme.AError
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
prob := tc.problem.ToACME()
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, prob.Type)
|
||||||
|
assert.Equals(t, ae.Detail, prob.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
expB, err := json.Marshal(acc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
|
assert.Equals(t, res.Header["Location"],
|
||||||
|
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), accID)})
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
214
acme/api/handler.go
Normal file
214
acme/api/handler.go
Normal file
|
@ -0,0 +1,214 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func link(url, typ string) string {
|
||||||
|
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ)
|
||||||
|
}
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
const (
|
||||||
|
accContextKey = contextKey("acc")
|
||||||
|
jwsContextKey = contextKey("jws")
|
||||||
|
jwkContextKey = contextKey("jwk")
|
||||||
|
payloadContextKey = contextKey("payload")
|
||||||
|
provisionerContextKey = contextKey("provisioner")
|
||||||
|
)
|
||||||
|
|
||||||
|
type payloadInfo struct {
|
||||||
|
value []byte
|
||||||
|
isPostAsGet bool
|
||||||
|
isEmptyJSON bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func accountFromContext(r *http.Request) (*acme.Account, error) {
|
||||||
|
val, ok := r.Context().Value(accContextKey).(*acme.Account)
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil, acme.AccountDoesNotExistErr(nil)
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
func jwkFromContext(r *http.Request) (*jose.JSONWebKey, error) {
|
||||||
|
val, ok := r.Context().Value(jwkContextKey).(*jose.JSONWebKey)
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil, acme.ServerInternalErr(errors.Errorf("jwk expected in request context"))
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
func jwsFromContext(r *http.Request) (*jose.JSONWebSignature, error) {
|
||||||
|
val, ok := r.Context().Value(jwsContextKey).(*jose.JSONWebSignature)
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil, acme.ServerInternalErr(errors.Errorf("jws expected in request context"))
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
func payloadFromContext(r *http.Request) (*payloadInfo, error) {
|
||||||
|
val, ok := r.Context().Value(payloadContextKey).(*payloadInfo)
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context"))
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
func provisionerFromContext(r *http.Request) (provisioner.Interface, error) {
|
||||||
|
val, ok := r.Context().Value(provisionerContextKey).(provisioner.Interface)
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil, acme.ServerInternalErr(errors.Errorf("provisioner expected in request context"))
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new ACME API router.
|
||||||
|
func New(acmeAuth acme.Interface) api.RouterHandler {
|
||||||
|
return &Handler{acmeAuth}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler is the ACME request handler.
|
||||||
|
type Handler struct {
|
||||||
|
Auth acme.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route traffic and implement the Router interface.
|
||||||
|
func (h *Handler) Route(r api.Router) {
|
||||||
|
getLink := h.Auth.GetLink
|
||||||
|
// Standard ACME API
|
||||||
|
r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce)))
|
||||||
|
r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce)))
|
||||||
|
r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory)))
|
||||||
|
r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory)))
|
||||||
|
|
||||||
|
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||||
|
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))
|
||||||
|
}
|
||||||
|
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
||||||
|
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false), extractPayloadByJWK(h.NewAccount))
|
||||||
|
r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.GetUpdateAccount))
|
||||||
|
r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false), extractPayloadByKid(h.NewOrder))
|
||||||
|
r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
|
||||||
|
r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount)))
|
||||||
|
r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
|
||||||
|
r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz)))
|
||||||
|
r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, "{chID}"), extractPayloadByKid(h.GetChallenge))
|
||||||
|
r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNonce just sets the right header since a Nonce is added to each response
|
||||||
|
// by middleware by default.
|
||||||
|
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "HEAD" {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDirectory is the ACME resource for returning a directory configuration
|
||||||
|
// for client configuration.
|
||||||
|
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dir := h.Auth.GetDirectory(prov)
|
||||||
|
api.JSON(w, dir)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthz ACME api for retrieving an Authz.
|
||||||
|
func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
acc, err := accountFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authz, err := h.Auth.GetAuthz(prov, acc.GetID(), chi.URLParam(r, "authzID"))
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Location", h.Auth.GetLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, authz.GetID()))
|
||||||
|
api.JSON(w, authz)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChallenge ACME api for retrieving a Challenge.
|
||||||
|
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
acc, err := accountFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Just verify that the payload was set, since we're not strictly adhering
|
||||||
|
// to ACME V2 spec for reasons specified below.
|
||||||
|
_, err = payloadFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: We should be checking that the request is either a POST-as-GET, or
|
||||||
|
// that the payload is an empty JSON block ({}). However, older ACME clients
|
||||||
|
// still send a vestigial body (rather than an empty JSON block) and
|
||||||
|
// strict enforcement would render these clients broken. For the time being
|
||||||
|
// we'll just ignore the body.
|
||||||
|
var (
|
||||||
|
ch *acme.Challenge
|
||||||
|
chID = chi.URLParam(r, "chID")
|
||||||
|
)
|
||||||
|
ch, err = h.Auth.ValidateChallenge(prov, acc.GetID(), chID, acc.GetKey())
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
getLink := h.Auth.GetLink
|
||||||
|
w.Header().Add("Link", link(getLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, ch.GetAuthzID()), "up"))
|
||||||
|
w.Header().Set("Location", getLink(acme.ChallengeLink, acme.URLSafeProvisionerName(prov), true, ch.GetID()))
|
||||||
|
api.JSON(w, ch)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCertificate ACME api for retrieving a Certificate.
|
||||||
|
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||||
|
acc, err := accountFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
certID := chi.URLParam(r, "certID")
|
||||||
|
certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
|
||||||
|
w.Write(certBytes)
|
||||||
|
return
|
||||||
|
}
|
771
acme/api/handler_test.go
Normal file
771
acme/api/handler_test.go
Normal file
|
@ -0,0 +1,771 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/crypto/pemutil"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockAcmeAuthority struct {
|
||||||
|
deactivateAccount func(provisioner.Interface, string) (*acme.Account, error)
|
||||||
|
finalizeOrder func(p provisioner.Interface, accID string, id string, csr *x509.CertificateRequest) (*acme.Order, error)
|
||||||
|
getAccount func(p provisioner.Interface, id string) (*acme.Account, error)
|
||||||
|
getAccountByKey func(provisioner.Interface, *jose.JSONWebKey) (*acme.Account, error)
|
||||||
|
getAuthz func(p provisioner.Interface, accID string, id string) (*acme.Authz, error)
|
||||||
|
getCertificate func(accID string, id string) ([]byte, error)
|
||||||
|
getChallenge func(p provisioner.Interface, accID string, id string) (*acme.Challenge, error)
|
||||||
|
getDirectory func(provisioner.Interface) *acme.Directory
|
||||||
|
getLink func(acme.Link, string, bool, ...string) string
|
||||||
|
getOrder func(p provisioner.Interface, accID string, id string) (*acme.Order, error)
|
||||||
|
getOrdersByAccount func(p provisioner.Interface, id string) ([]string, error)
|
||||||
|
loadProvisionerByID func(string) (provisioner.Interface, error)
|
||||||
|
newAccount func(provisioner.Interface, acme.AccountOptions) (*acme.Account, error)
|
||||||
|
newNonce func() (string, error)
|
||||||
|
newOrder func(provisioner.Interface, acme.OrderOptions) (*acme.Order, error)
|
||||||
|
updateAccount func(provisioner.Interface, string, []string) (*acme.Account, error)
|
||||||
|
useNonce func(string) error
|
||||||
|
validateChallenge func(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error)
|
||||||
|
ret1 interface{}
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) DeactivateAccount(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||||
|
if m.deactivateAccount != nil {
|
||||||
|
return m.deactivateAccount(p, id)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Account), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) FinalizeOrder(p provisioner.Interface, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) {
|
||||||
|
if m.finalizeOrder != nil {
|
||||||
|
return m.finalizeOrder(p, accID, id, csr)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Order), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) GetAccount(p provisioner.Interface, id string) (*acme.Account, error) {
|
||||||
|
if m.getAccount != nil {
|
||||||
|
return m.getAccount(p, id)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Account), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) {
|
||||||
|
if m.getAccountByKey != nil {
|
||||||
|
return m.getAccountByKey(p, jwk)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Account), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) GetAuthz(p provisioner.Interface, accID, id string) (*acme.Authz, error) {
|
||||||
|
if m.getAuthz != nil {
|
||||||
|
return m.getAuthz(p, accID, id)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Authz), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) GetCertificate(accID, id string) ([]byte, error) {
|
||||||
|
if m.getCertificate != nil {
|
||||||
|
return m.getCertificate(accID, id)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.([]byte), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) GetChallenge(p provisioner.Interface, accID, id string) (*acme.Challenge, error) {
|
||||||
|
if m.getChallenge != nil {
|
||||||
|
return m.getChallenge(p, accID, id)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Challenge), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) GetDirectory(p provisioner.Interface) *acme.Directory {
|
||||||
|
if m.getDirectory != nil {
|
||||||
|
return m.getDirectory(p)
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Directory)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) GetLink(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
if m.getLink != nil {
|
||||||
|
return m.getLink(typ, provID, abs, in...)
|
||||||
|
}
|
||||||
|
return m.ret1.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) GetOrder(p provisioner.Interface, accID, id string) (*acme.Order, error) {
|
||||||
|
if m.getOrder != nil {
|
||||||
|
return m.getOrder(p, accID, id)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Order), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) {
|
||||||
|
if m.getOrdersByAccount != nil {
|
||||||
|
return m.getOrdersByAccount(p, id)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.([]string), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) {
|
||||||
|
if m.loadProvisionerByID != nil {
|
||||||
|
return m.loadProvisionerByID(provID)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(provisioner.Interface), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) NewAccount(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) {
|
||||||
|
if m.newAccount != nil {
|
||||||
|
return m.newAccount(p, ops)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Account), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) NewNonce() (string, error) {
|
||||||
|
if m.newNonce != nil {
|
||||||
|
return m.newNonce()
|
||||||
|
} else if m.err != nil {
|
||||||
|
return "", m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(string), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) NewOrder(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||||
|
if m.newOrder != nil {
|
||||||
|
return m.newOrder(p, ops)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Order), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*acme.Account, error) {
|
||||||
|
if m.updateAccount != nil {
|
||||||
|
return m.updateAccount(p, id, contact)
|
||||||
|
} else if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
return m.ret1.(*acme.Account), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) UseNonce(nonce string) error {
|
||||||
|
if m.useNonce != nil {
|
||||||
|
return m.useNonce(nonce)
|
||||||
|
}
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAcmeAuthority) ValidateChallenge(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
|
||||||
|
switch {
|
||||||
|
case m.validateChallenge != nil:
|
||||||
|
return m.validateChallenge(p, accID, id, jwk)
|
||||||
|
case m.err != nil:
|
||||||
|
return nil, m.err
|
||||||
|
default:
|
||||||
|
return m.ret1.(*acme.Challenge), m.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerGetNonce(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"GET", 204},
|
||||||
|
{"HEAD", 200},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request with chi context
|
||||||
|
req := httptest.NewRequest("GET", "http://ca.smallstep.com/nonce", nil)
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := New(nil).(*Handler)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req.Method = tt.name
|
||||||
|
h.GetNonce(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
if res.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("Handler.GetNonce StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerGetDirectory(t *testing.T) {
|
||||||
|
auth := acme.NewAuthority(nil, "ca.smallstep.com", "acme", nil)
|
||||||
|
prov := newProv()
|
||||||
|
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/directory", acme.URLSafeProvisionerName(prov))
|
||||||
|
|
||||||
|
expDir := acme.Directory{
|
||||||
|
NewNonce: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", acme.URLSafeProvisionerName(prov)),
|
||||||
|
NewAccount: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", acme.URLSafeProvisionerName(prov)),
|
||||||
|
NewOrder: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", acme.URLSafeProvisionerName(prov)),
|
||||||
|
RevokeCert: fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", acme.URLSafeProvisionerName(prov)),
|
||||||
|
KeyChange: fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", acme.URLSafeProvisionerName(prov)),
|
||||||
|
}
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
ctx context.Context
|
||||||
|
statusCode int
|
||||||
|
problem *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.Background(),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := New(auth).(*Handler)
|
||||||
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.GetDirectory(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||||
|
var ae acme.AError
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
prob := tc.problem.ToACME()
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, prob.Type)
|
||||||
|
assert.Equals(t, ae.Detail, prob.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
var dir acme.Directory
|
||||||
|
json.Unmarshal(bytes.TrimSpace(body), &dir)
|
||||||
|
assert.Equals(t, dir, expDir)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerGetAuthz(t *testing.T) {
|
||||||
|
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||||
|
az := acme.Authz{
|
||||||
|
ID: "authzID",
|
||||||
|
Identifier: acme.Identifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "example.com",
|
||||||
|
},
|
||||||
|
Status: "pending",
|
||||||
|
Expires: expiry.Format(time.RFC3339),
|
||||||
|
Wildcard: false,
|
||||||
|
Challenges: []*acme.Challenge{
|
||||||
|
{
|
||||||
|
Type: "http-01",
|
||||||
|
Status: "pending",
|
||||||
|
Token: "tok2",
|
||||||
|
URL: "https://ca.smallstep.com/acme/challenge/chHTTPID",
|
||||||
|
ID: "chHTTP01ID",
|
||||||
|
AuthzID: "authzID",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "dns-01",
|
||||||
|
Status: "pending",
|
||||||
|
Token: "tok2",
|
||||||
|
URL: "https://ca.smallstep.com/acme/challenge/chDNSID",
|
||||||
|
ID: "chDNSID",
|
||||||
|
AuthzID: "authzID",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
prov := newProv()
|
||||||
|
|
||||||
|
// Request with chi context
|
||||||
|
chiCtx := chi.NewRouteContext()
|
||||||
|
chiCtx.URLParams.Add("authzID", az.ID)
|
||||||
|
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/challenge/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), az.ID)
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
auth acme.Interface
|
||||||
|
ctx context.Context
|
||||||
|
statusCode int
|
||||||
|
problem *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.Background(),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-account": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/getAuthz-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
err: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
getAuthz: func(p provisioner.Interface, accID, id string) (*acme.Authz, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, accID, acc.ID)
|
||||||
|
assert.Equals(t, id, az.ID)
|
||||||
|
return &az, nil
|
||||||
|
},
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.Equals(t, typ, acme.AuthzLink)
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{az.ID})
|
||||||
|
return url
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := New(tc.auth).(*Handler)
|
||||||
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.GetAuthz(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||||
|
var ae acme.AError
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
prob := tc.problem.ToACME()
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, prob.Type)
|
||||||
|
assert.Equals(t, ae.Detail, prob.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
//var gotAz acme.Authz
|
||||||
|
//assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &gotAz))
|
||||||
|
expB, err := json.Marshal(az)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
|
assert.Equals(t, res.Header["Location"], []string{url})
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerGetCertificate(t *testing.T) {
|
||||||
|
leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
root, err := pemutil.ReadCertificate("../../authority/testdata/certs/root_ca.crt")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
certBytes := append(pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: leaf.Raw,
|
||||||
|
}), pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: inter.Raw,
|
||||||
|
})...)
|
||||||
|
certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: root.Raw,
|
||||||
|
})...)
|
||||||
|
certID := "certID"
|
||||||
|
|
||||||
|
prov := newProv()
|
||||||
|
// Request with chi context
|
||||||
|
chiCtx := chi.NewRouteContext()
|
||||||
|
chiCtx.URLParams.Add("certID", certID)
|
||||||
|
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/certificate/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), certID)
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
auth acme.Interface
|
||||||
|
ctx context.Context
|
||||||
|
statusCode int
|
||||||
|
problem *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-account": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), accContextKey, nil)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/getCertificate-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
err: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
getCertificate: func(accID, id string) ([]byte, error) {
|
||||||
|
assert.Equals(t, accID, acc.ID)
|
||||||
|
assert.Equals(t, id, certID)
|
||||||
|
return certBytes, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := New(tc.auth).(*Handler)
|
||||||
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.GetCertificate(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||||
|
var ae acme.AError
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
prob := tc.problem.ToACME()
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, prob.Type)
|
||||||
|
assert.Equals(t, ae.Detail, prob.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes))
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/pem-certificate-chain; charset=utf-8"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ch() acme.Challenge {
|
||||||
|
return acme.Challenge{
|
||||||
|
Type: "http-01",
|
||||||
|
Status: "pending",
|
||||||
|
Token: "tok2",
|
||||||
|
URL: "https://ca.smallstep.com/acme/challenge/chID",
|
||||||
|
ID: "chID",
|
||||||
|
AuthzID: "authzID",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerGetChallenge(t *testing.T) {
|
||||||
|
chiCtx := chi.NewRouteContext()
|
||||||
|
chiCtx.URLParams.Add("chID", "chID")
|
||||||
|
url := fmt.Sprintf("http://ca.smallstep.com/acme/challenge/%s", "chID")
|
||||||
|
prov := newProv()
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
auth acme.Interface
|
||||||
|
ctx context.Context
|
||||||
|
statusCode int
|
||||||
|
ch acme.Challenge
|
||||||
|
problem *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.Background(),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-account": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/validate-challenge-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
err: acme.UnauthorizedErr(nil),
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 401,
|
||||||
|
problem: acme.UnauthorizedErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/get-challenge-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
err: acme.UnauthorizedErr(nil),
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 401,
|
||||||
|
problem: acme.UnauthorizedErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/validate-challenge": func(t *testing.T) test {
|
||||||
|
key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{ID: "accID", Key: key}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
ch := ch()
|
||||||
|
ch.Status = "valid"
|
||||||
|
ch.Validated = time.Now().UTC().Format(time.RFC3339)
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
validateChallenge: func(p provisioner.Interface, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, accID, acc.ID)
|
||||||
|
assert.Equals(t, id, ch.ID)
|
||||||
|
assert.Equals(t, jwk.KeyID, key.KeyID)
|
||||||
|
return &ch, nil
|
||||||
|
},
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
var ret string
|
||||||
|
switch count {
|
||||||
|
case 0:
|
||||||
|
assert.Equals(t, typ, acme.AuthzLink)
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{ch.AuthzID})
|
||||||
|
ret = fmt.Sprintf("https://ca.smallstep.com/acme/authz/%s", ch.AuthzID)
|
||||||
|
case 1:
|
||||||
|
assert.Equals(t, typ, acme.ChallengeLink)
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{ch.ID})
|
||||||
|
ret = url
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return ret
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
ch: ch,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := New(tc.auth).(*Handler)
|
||||||
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.GetChallenge(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||||
|
var ae acme.AError
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
prob := tc.problem.ToACME()
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, prob.Type)
|
||||||
|
assert.Equals(t, ae.Detail, prob.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
expB, err := json.Marshal(tc.ch)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
|
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<https://ca.smallstep.com/acme/authz/%s>;rel=\"up\"", tc.ch.AuthzID)})
|
||||||
|
assert.Equals(t, res.Header["Location"], []string{url})
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
377
acme/api/middleware.go
Normal file
377
acme/api/middleware.go
Normal file
|
@ -0,0 +1,377 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rsa"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/certificates/logging"
|
||||||
|
"github.com/smallstep/cli/crypto/keys"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||||
|
|
||||||
|
func logNonce(w http.ResponseWriter, nonce string) {
|
||||||
|
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||||
|
m := map[string]interface{}{
|
||||||
|
"nonce": nonce,
|
||||||
|
}
|
||||||
|
rl.WithFields(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNonce is a middleware that adds a nonce to the response header.
|
||||||
|
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
nonce, err := h.Auth.NewNonce()
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Replay-Nonce", nonce)
|
||||||
|
w.Header().Set("Cache-Control", "no-store")
|
||||||
|
logNonce(w, nonce)
|
||||||
|
next(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addDirLink is a middleware that adds a 'Link' response reader with the
|
||||||
|
// directory index url.
|
||||||
|
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Add("Link", link(h.Auth.GetLink(acme.DirectoryLink, acme.URLSafeProvisionerName(prov), true), "index"))
|
||||||
|
next(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyContentType is a middleware that verifies that content type is
|
||||||
|
// application/jose+json.
|
||||||
|
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ct := r.Header.Get("Content-Type")
|
||||||
|
var expected []string
|
||||||
|
if strings.Contains(r.URL.Path, h.Auth.GetLink(acme.CertificateLink, acme.URLSafeProvisionerName(prov), false, "")) {
|
||||||
|
// GET /certificate requests allow a greater range of content types.
|
||||||
|
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||||
|
} else {
|
||||||
|
// By default every request should have content-type applictaion/jose+json.
|
||||||
|
expected = []string{"application/jose+json"}
|
||||||
|
}
|
||||||
|
for _, e := range expected {
|
||||||
|
if ct == e {
|
||||||
|
next(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf(
|
||||||
|
"expected content-type to be in %s, but got %s", expected, ct)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
|
||||||
|
func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jws, err := jose.ParseJWS(string(body))
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
|
||||||
|
next(w, r.WithContext(ctx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateJWS checks the request body for to verify that it meets ACME
|
||||||
|
// requirements for a JWS.
|
||||||
|
//
|
||||||
|
// The JWS MUST NOT have multiple signatures
|
||||||
|
// The JWS Unencoded Payload Option [RFC7797] MUST NOT be used
|
||||||
|
// The JWS Unprotected Header [RFC7515] MUST NOT be used
|
||||||
|
// The JWS Payload MUST NOT be detached
|
||||||
|
// The JWS Protected Header MUST include the following fields:
|
||||||
|
// * “alg” (Algorithm)
|
||||||
|
// * This field MUST NOT contain “none” or a Message Authentication Code
|
||||||
|
// (MAC) algorithm (e.g. one in which the algorithm registry description
|
||||||
|
// mentions MAC/HMAC).
|
||||||
|
// * “nonce” (defined in Section 6.5)
|
||||||
|
// * “url” (defined in Section 6.4)
|
||||||
|
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||||
|
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
jws, err := jwsFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(jws.Signatures) == 0 {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(jws.Signatures) > 1 {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sig := jws.Signatures[0]
|
||||||
|
uh := sig.Unprotected
|
||||||
|
if len(uh.KeyID) > 0 ||
|
||||||
|
uh.JSONWebKey != nil ||
|
||||||
|
len(uh.Algorithm) > 0 ||
|
||||||
|
len(uh.Nonce) > 0 ||
|
||||||
|
len(uh.ExtraHeaders) > 0 {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hdr := sig.Protected
|
||||||
|
switch hdr.Algorithm {
|
||||||
|
case jose.RS256, jose.RS384, jose.RS512:
|
||||||
|
if hdr.JSONWebKey != nil {
|
||||||
|
switch k := hdr.JSONWebKey.Key.(type) {
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
if k.Size() < keys.MinRSAKeyBytes {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+
|
||||||
|
"keys must be at least %d bits (%d bytes) in size",
|
||||||
|
8*keys.MinRSAKeyBytes, keys.MinRSAKeyBytes)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
|
||||||
|
// we good
|
||||||
|
default:
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the validity/freshness of the Nonce.
|
||||||
|
if err := h.Auth.UseNonce(hdr.Nonce); err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the JWS url matches the requested url.
|
||||||
|
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
||||||
|
if !ok {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
|
||||||
|
if jwsURL != reqURL.String() {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
|
||||||
|
// in the context. Make sure to parse and validate the JWS before running this
|
||||||
|
// middleware.
|
||||||
|
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jws, err := jwsFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jwk := jws.Signatures[0].Protected.JSONWebKey
|
||||||
|
if jwk == nil {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !jwk.Valid() {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
|
acc, err := h.Auth.GetAccountByKey(prov, jwk)
|
||||||
|
switch {
|
||||||
|
case nosql.IsErrNotFound(err):
|
||||||
|
// For NewAccount requests ...
|
||||||
|
break
|
||||||
|
case err != nil:
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if !acc.IsValid() {
|
||||||
|
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
}
|
||||||
|
next(w, r.WithContext(ctx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupProvisioner loads the provisioner associated with the request.
|
||||||
|
// Responsds 404 if the provisioner does not exist.
|
||||||
|
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
name := chi.URLParam(r, "provisionerID")
|
||||||
|
provID, err := url.PathUnescape(name)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p, err := h.Auth.LoadProvisionerByID("acme/" + provID)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if p.GetType() != provisioner.TypeACME {
|
||||||
|
api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, provisionerContextKey, p)
|
||||||
|
next(w, r.WithContext(ctx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupJWK loads the JWK associated with the acme account referenced by the
|
||||||
|
// kid parameter of the signed payload.
|
||||||
|
// Make sure to parse and validate the JWS before running this middleware.
|
||||||
|
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jws, err := jwsFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
kidPrefix := h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, "")
|
||||||
|
kid := jws.Signatures[0].Protected.KeyID
|
||||||
|
if !strings.HasPrefix(kid, kidPrefix) {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+
|
||||||
|
"required prefix; expected %s, but got %s", kidPrefix, kid)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accID := strings.TrimPrefix(kid, kidPrefix)
|
||||||
|
acc, err := h.Auth.GetAccount(prov, accID)
|
||||||
|
switch {
|
||||||
|
case nosql.IsErrNotFound(err):
|
||||||
|
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if !acc.IsValid() {
|
||||||
|
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
|
||||||
|
next(w, r.WithContext(ctx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
|
||||||
|
// Make sure to parse and validate the JWS before running this middleware.
|
||||||
|
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
jws, err := jwsFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jwk, err := jwkFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payload, err := jws.Verify(jwk)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(r.Context(), payloadContextKey, &payloadInfo{
|
||||||
|
value: payload,
|
||||||
|
isPostAsGet: string(payload) == "",
|
||||||
|
isEmptyJSON: string(payload) == "{}",
|
||||||
|
})
|
||||||
|
next(w, r.WithContext(ctx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
|
||||||
|
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
payload, err := payloadFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !payload.isPostAsGet {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
1550
acme/api/middleware_test.go
Normal file
1550
acme/api/middleware_test.go
Normal file
File diff suppressed because it is too large
Load diff
164
acme/api/order.go
Normal file
164
acme/api/order.go
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewOrderRequest represents the body for a NewOrder request.
|
||||||
|
type NewOrderRequest struct {
|
||||||
|
Identifiers []acme.Identifier `json:"identifiers"`
|
||||||
|
NotBefore time.Time `json:"notBefore,omitempty"`
|
||||||
|
NotAfter time.Time `json:"notAfter,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates a new-order request body.
|
||||||
|
func (n *NewOrderRequest) Validate() error {
|
||||||
|
if len(n.Identifiers) == 0 {
|
||||||
|
return acme.MalformedErr(errors.Errorf("identifiers list cannot be empty"))
|
||||||
|
}
|
||||||
|
for _, id := range n.Identifiers {
|
||||||
|
if id.Type != "dns" {
|
||||||
|
return acme.MalformedErr(errors.Errorf("identifier type unsupported: %s", id.Type))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinalizeRequest captures the body for a Finalize order request.
|
||||||
|
type FinalizeRequest struct {
|
||||||
|
CSR string `json:"csr"`
|
||||||
|
csr *x509.CertificateRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates a finalize request body.
|
||||||
|
func (f *FinalizeRequest) Validate() error {
|
||||||
|
var err error
|
||||||
|
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
|
||||||
|
if err != nil {
|
||||||
|
return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr"))
|
||||||
|
}
|
||||||
|
f.csr, err = x509.ParseCertificateRequest(csrBytes)
|
||||||
|
if err != nil {
|
||||||
|
return acme.MalformedErr(errors.Wrap(err, "unable to parse csr"))
|
||||||
|
}
|
||||||
|
if err = f.csr.CheckSignature(); err != nil {
|
||||||
|
return acme.MalformedErr(errors.Wrap(err, "csr failed signature check"))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOrder ACME api for creating a new order.
|
||||||
|
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
acc, err := accountFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payload, err := payloadFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var nor NewOrderRequest
|
||||||
|
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Wrap(err,
|
||||||
|
"failed to unmarshal new-order request payload")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := nor.Validate(); err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
o, err := h.Auth.NewOrder(prov, acme.OrderOptions{
|
||||||
|
AccountID: acc.GetID(),
|
||||||
|
Identifiers: nor.Identifiers,
|
||||||
|
NotBefore: nor.NotBefore,
|
||||||
|
NotAfter: nor.NotAfter,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID()))
|
||||||
|
api.JSONStatus(w, o, http.StatusCreated)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrder ACME api for retrieving an order.
|
||||||
|
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
acc, err := accountFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
oid := chi.URLParam(r, "ordID")
|
||||||
|
o, err := h.Auth.GetOrder(prov, acc.GetID(), oid)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID()))
|
||||||
|
api.JSON(w, o)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
||||||
|
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov, err := provisionerFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
acc, err := accountFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payload, err := payloadFromContext(r)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var fr FinalizeRequest
|
||||||
|
if err := json.Unmarshal(payload.value, &fr); err != nil {
|
||||||
|
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := fr.Validate(); err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oid := chi.URLParam(r, "ordID")
|
||||||
|
o, err := h.Auth.FinalizeOrder(prov, acc.GetID(), oid, fr.csr)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.ID))
|
||||||
|
api.JSON(w, o)
|
||||||
|
return
|
||||||
|
}
|
757
acme/api/order_test.go
Normal file
757
acme/api/order_test.go
Normal file
|
@ -0,0 +1,757 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/crypto/pemutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewOrderRequestValidate(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
nor *NewOrderRequest
|
||||||
|
nbf, naf time.Time
|
||||||
|
err *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-identifiers": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
nor: &NewOrderRequest{},
|
||||||
|
err: acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/bad-identifier": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
nor: &NewOrderRequest{
|
||||||
|
Identifiers: []acme.Identifier{
|
||||||
|
{Type: "dns", Value: "example.com"},
|
||||||
|
{Type: "foo", Value: "bar.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: acme.MalformedErr(errors.Errorf("identifier type unsupported: foo")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
nbf := time.Now().UTC().Add(time.Minute)
|
||||||
|
naf := time.Now().UTC().Add(5 * time.Minute)
|
||||||
|
return test{
|
||||||
|
nor: &NewOrderRequest{
|
||||||
|
Identifiers: []acme.Identifier{
|
||||||
|
{Type: "dns", Value: "example.com"},
|
||||||
|
{Type: "dns", Value: "bar.com"},
|
||||||
|
},
|
||||||
|
NotAfter: naf,
|
||||||
|
NotBefore: nbf,
|
||||||
|
},
|
||||||
|
nbf: nbf,
|
||||||
|
naf: naf,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
if err := tc.nor.Validate(); err != nil {
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
ae, ok := err.(*acme.Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
if tc.nbf.IsZero() {
|
||||||
|
assert.True(t, tc.nor.NotBefore.Before(time.Now().Add(time.Minute)))
|
||||||
|
assert.True(t, tc.nor.NotBefore.After(time.Now().Add(-time.Minute)))
|
||||||
|
} else {
|
||||||
|
assert.Equals(t, tc.nor.NotBefore, tc.nbf)
|
||||||
|
}
|
||||||
|
if tc.naf.IsZero() {
|
||||||
|
assert.True(t, tc.nor.NotAfter.Before(time.Now().Add(24*time.Hour)))
|
||||||
|
assert.True(t, tc.nor.NotAfter.After(time.Now().Add(24*time.Hour-time.Minute)))
|
||||||
|
} else {
|
||||||
|
assert.Equals(t, tc.nor.NotAfter, tc.naf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinalizeRequestValidate(t *testing.T) {
|
||||||
|
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
csr, ok := _csr.(*x509.CertificateRequest)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
type test struct {
|
||||||
|
fr *FinalizeRequest
|
||||||
|
err *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/parse-csr-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
fr: &FinalizeRequest{},
|
||||||
|
err: acme.MalformedErr(errors.Errorf("unable to parse csr: asn1: syntax error: sequence truncated")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/invalid-csr-signature": func(t *testing.T) test {
|
||||||
|
b, err := pemutil.Read("../../authority/testdata/certs/badsig.csr")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
c, ok := b.(*x509.CertificateRequest)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
return test{
|
||||||
|
fr: &FinalizeRequest{
|
||||||
|
CSR: base64.RawURLEncoding.EncodeToString(c.Raw),
|
||||||
|
},
|
||||||
|
err: acme.MalformedErr(errors.Errorf("csr failed signature check: x509: ECDSA verification failure")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
fr: &FinalizeRequest{
|
||||||
|
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
if err := tc.fr.Validate(); err != nil {
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
ae, ok := err.(*acme.Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.fr.csr.Raw, csr.Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerGetOrder(t *testing.T) {
|
||||||
|
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||||
|
nbf := time.Now().UTC()
|
||||||
|
naf := time.Now().UTC().Add(24 * time.Hour)
|
||||||
|
o := acme.Order{
|
||||||
|
ID: "orderID",
|
||||||
|
Expires: expiry.Format(time.RFC3339),
|
||||||
|
NotBefore: nbf.Format(time.RFC3339),
|
||||||
|
NotAfter: naf.Format(time.RFC3339),
|
||||||
|
Identifiers: []acme.Identifier{
|
||||||
|
{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "dns",
|
||||||
|
Value: "*.smallstep.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Status: "pending",
|
||||||
|
Authorizations: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request with chi context
|
||||||
|
chiCtx := chi.NewRouteContext()
|
||||||
|
chiCtx.URLParams.Add("ordID", o.ID)
|
||||||
|
prov := newProv()
|
||||||
|
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), o.ID)
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
auth acme.Interface
|
||||||
|
ctx context.Context
|
||||||
|
statusCode int
|
||||||
|
problem *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.Background(),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-account": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/getOrder-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
err: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
getOrder: func(p provisioner.Interface, accID, id string) (*acme.Order, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, accID, acc.ID)
|
||||||
|
assert.Equals(t, id, o.ID)
|
||||||
|
return &o, nil
|
||||||
|
},
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
assert.Equals(t, typ, acme.OrderLink)
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{o.ID})
|
||||||
|
return url
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := New(tc.auth).(*Handler)
|
||||||
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.GetOrder(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||||
|
var ae acme.AError
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
prob := tc.problem.ToACME()
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, prob.Type)
|
||||||
|
assert.Equals(t, ae.Detail, prob.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
expB, err := json.Marshal(o)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
|
assert.Equals(t, res.Header["Location"], []string{url})
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerNewOrder(t *testing.T) {
|
||||||
|
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||||
|
nbf := time.Now().UTC().Add(5 * time.Hour)
|
||||||
|
naf := nbf.Add(17 * time.Hour)
|
||||||
|
o := acme.Order{
|
||||||
|
ID: "orderID",
|
||||||
|
Expires: expiry.Format(time.RFC3339),
|
||||||
|
NotBefore: nbf.Format(time.RFC3339),
|
||||||
|
NotAfter: naf.Format(time.RFC3339),
|
||||||
|
Identifiers: []acme.Identifier{
|
||||||
|
{Type: "dns", Value: "example.com"},
|
||||||
|
{Type: "dns", Value: "bar.com"},
|
||||||
|
},
|
||||||
|
Status: "pending",
|
||||||
|
Authorizations: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
|
||||||
|
prov := newProv()
|
||||||
|
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order",
|
||||||
|
acme.URLSafeProvisionerName(prov))
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
auth acme.Interface
|
||||||
|
ctx context.Context
|
||||||
|
statusCode int
|
||||||
|
problem *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.Background(),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-account": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 400,
|
||||||
|
problem: acme.MalformedErr(errors.New("failed to unmarshal new-order request payload: unexpected end of JSON input")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
nor := &NewOrderRequest{}
|
||||||
|
b, err := json.Marshal(nor)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 400,
|
||||||
|
problem: acme.MalformedErr(errors.New("identifiers list cannot be empty")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/NewOrder-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
nor := &NewOrderRequest{
|
||||||
|
Identifiers: []acme.Identifier{
|
||||||
|
{Type: "dns", Value: "example.com"},
|
||||||
|
{Type: "dns", Value: "bar.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nor)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, ops.AccountID, acc.ID)
|
||||||
|
assert.Equals(t, ops.Identifiers, nor.Identifiers)
|
||||||
|
return nil, acme.MalformedErr(errors.New("force"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 400,
|
||||||
|
problem: acme.MalformedErr(errors.New("force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
nor := &NewOrderRequest{
|
||||||
|
Identifiers: []acme.Identifier{
|
||||||
|
{Type: "dns", Value: "example.com"},
|
||||||
|
{Type: "dns", Value: "bar.com"},
|
||||||
|
},
|
||||||
|
NotBefore: nbf,
|
||||||
|
NotAfter: naf,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nor)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, ops.AccountID, acc.ID)
|
||||||
|
assert.Equals(t, ops.Identifiers, nor.Identifiers)
|
||||||
|
assert.Equals(t, ops.NotBefore, nbf)
|
||||||
|
assert.Equals(t, ops.NotAfter, naf)
|
||||||
|
return &o, nil
|
||||||
|
},
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
assert.Equals(t, typ, acme.OrderLink)
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{o.ID})
|
||||||
|
return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 201,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/default-naf-nbf": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
nor := &NewOrderRequest{
|
||||||
|
Identifiers: []acme.Identifier{
|
||||||
|
{Type: "dns", Value: "example.com"},
|
||||||
|
{Type: "dns", Value: "bar.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nor)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, ops.AccountID, acc.ID)
|
||||||
|
assert.Equals(t, ops.Identifiers, nor.Identifiers)
|
||||||
|
|
||||||
|
assert.True(t, ops.NotBefore.IsZero())
|
||||||
|
assert.True(t, ops.NotAfter.IsZero())
|
||||||
|
return &o, nil
|
||||||
|
},
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
assert.Equals(t, typ, acme.OrderLink)
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{o.ID})
|
||||||
|
return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 201,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := New(tc.auth).(*Handler)
|
||||||
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.NewOrder(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||||
|
var ae acme.AError
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
prob := tc.problem.ToACME()
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, prob.Type)
|
||||||
|
assert.Equals(t, ae.Detail, prob.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
expB, err := json.Marshal(o)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
|
assert.Equals(t, res.Header["Location"],
|
||||||
|
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)})
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerFinalizeOrder(t *testing.T) {
|
||||||
|
expiry := time.Now().UTC().Add(6 * time.Hour)
|
||||||
|
nbf := time.Now().UTC().Add(5 * time.Hour)
|
||||||
|
naf := nbf.Add(17 * time.Hour)
|
||||||
|
o := acme.Order{
|
||||||
|
ID: "orderID",
|
||||||
|
Expires: expiry.Format(time.RFC3339),
|
||||||
|
NotBefore: nbf.Format(time.RFC3339),
|
||||||
|
NotAfter: naf.Format(time.RFC3339),
|
||||||
|
Identifiers: []acme.Identifier{
|
||||||
|
{Type: "dns", Value: "example.com"},
|
||||||
|
{Type: "dns", Value: "bar.com"},
|
||||||
|
},
|
||||||
|
Status: "valid",
|
||||||
|
Authorizations: []string{"foo", "bar"},
|
||||||
|
Certificate: "https://ca.smallstep.com/acme/certificate/certID",
|
||||||
|
}
|
||||||
|
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
csr, ok := _csr.(*x509.CertificateRequest)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
|
||||||
|
// Request with chi context
|
||||||
|
chiCtx := chi.NewRouteContext()
|
||||||
|
chiCtx.URLParams.Add("ordID", o.ID)
|
||||||
|
prov := newProv()
|
||||||
|
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s/finalize",
|
||||||
|
acme.URLSafeProvisionerName(prov), o.ID)
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
auth acme.Interface
|
||||||
|
ctx context.Context
|
||||||
|
statusCode int
|
||||||
|
problem *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/no-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.Background(),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-provisioner": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-account": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-account": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, nil)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 404,
|
||||||
|
problem: acme.AccountDoesNotExistErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/no-payload": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/nil-payload": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, nil)
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 500,
|
||||||
|
problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-payload-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 400,
|
||||||
|
problem: acme.MalformedErr(errors.New("failed to unmarshal finalize-order request payload: unexpected end of JSON input")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/malformed-payload-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
fr := &FinalizeRequest{}
|
||||||
|
b, err := json.Marshal(fr)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 400,
|
||||||
|
problem: acme.MalformedErr(errors.New("unable to parse csr: asn1: syntax error: sequence truncated")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/FinalizeOrder-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
nor := &FinalizeRequest{
|
||||||
|
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nor)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, accID, acc.ID)
|
||||||
|
assert.Equals(t, id, o.ID)
|
||||||
|
assert.Equals(t, incsr.Raw, csr.Raw)
|
||||||
|
return nil, acme.MalformedErr(errors.New("force"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 400,
|
||||||
|
problem: acme.MalformedErr(errors.New("force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{ID: "accID"}
|
||||||
|
nor := &FinalizeRequest{
|
||||||
|
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(nor)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
|
||||||
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
||||||
|
return test{
|
||||||
|
auth: &mockAcmeAuthority{
|
||||||
|
finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
|
||||||
|
assert.Equals(t, p, prov)
|
||||||
|
assert.Equals(t, accID, acc.ID)
|
||||||
|
assert.Equals(t, id, o.ID)
|
||||||
|
assert.Equals(t, incsr.Raw, csr.Raw)
|
||||||
|
return &o, nil
|
||||||
|
},
|
||||||
|
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string {
|
||||||
|
assert.Equals(t, typ, acme.OrderLink)
|
||||||
|
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
|
||||||
|
assert.True(t, abs)
|
||||||
|
assert.Equals(t, in, []string{o.ID})
|
||||||
|
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), o.ID)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := New(tc.auth).(*Handler)
|
||||||
|
req := httptest.NewRequest("GET", url, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.FinalizeOrder(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) {
|
||||||
|
var ae acme.AError
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
prob := tc.problem.ToACME()
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, prob.Type)
|
||||||
|
assert.Equals(t, ae.Detail, prob.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, prob.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, prob.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
expB, err := json.Marshal(o)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
|
assert.Equals(t, res.Header["Location"],
|
||||||
|
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s",
|
||||||
|
acme.URLSafeProvisionerName(prov), o.ID)})
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
263
acme/authority.go
Normal file
263
acme/authority.go
Normal file
|
@ -0,0 +1,263 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Interface is the acme authority interface.
|
||||||
|
type Interface interface {
|
||||||
|
DeactivateAccount(provisioner.Interface, string) (*Account, error)
|
||||||
|
FinalizeOrder(provisioner.Interface, string, string, *x509.CertificateRequest) (*Order, error)
|
||||||
|
GetAccount(provisioner.Interface, string) (*Account, error)
|
||||||
|
GetAccountByKey(provisioner.Interface, *jose.JSONWebKey) (*Account, error)
|
||||||
|
GetAuthz(provisioner.Interface, string, string) (*Authz, error)
|
||||||
|
GetCertificate(string, string) ([]byte, error)
|
||||||
|
GetDirectory(provisioner.Interface) *Directory
|
||||||
|
GetLink(Link, string, bool, ...string) string
|
||||||
|
GetOrder(provisioner.Interface, string, string) (*Order, error)
|
||||||
|
GetOrdersByAccount(provisioner.Interface, string) ([]string, error)
|
||||||
|
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||||
|
NewAccount(provisioner.Interface, AccountOptions) (*Account, error)
|
||||||
|
NewNonce() (string, error)
|
||||||
|
NewOrder(provisioner.Interface, OrderOptions) (*Order, error)
|
||||||
|
UpdateAccount(provisioner.Interface, string, []string) (*Account, error)
|
||||||
|
UseNonce(string) error
|
||||||
|
ValidateChallenge(provisioner.Interface, string, string, *jose.JSONWebKey) (*Challenge, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authority is the layer that handles all ACME interactions.
|
||||||
|
type Authority struct {
|
||||||
|
db nosql.DB
|
||||||
|
dir *directory
|
||||||
|
signAuth SignAuthority
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthority returns a new Authority that implements the ACME interface.
|
||||||
|
func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) *Authority {
|
||||||
|
return &Authority{
|
||||||
|
db: db, dir: newDirectory(dns, prefix), signAuth: signAuth,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLink returns the requested link from the directory.
|
||||||
|
func (a *Authority) GetLink(typ Link, provID string, abs bool, inputs ...string) string {
|
||||||
|
return a.dir.getLink(typ, provID, abs, inputs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDirectory returns the ACME directory object.
|
||||||
|
func (a *Authority) GetDirectory(p provisioner.Interface) *Directory {
|
||||||
|
name := url.PathEscape(p.GetName())
|
||||||
|
return &Directory{
|
||||||
|
NewNonce: a.dir.getLink(NewNonceLink, name, true),
|
||||||
|
NewAccount: a.dir.getLink(NewAccountLink, name, true),
|
||||||
|
NewOrder: a.dir.getLink(NewOrderLink, name, true),
|
||||||
|
RevokeCert: a.dir.getLink(RevokeCertLink, name, true),
|
||||||
|
KeyChange: a.dir.getLink(KeyChangeLink, name, true),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadProvisionerByID calls out to the SignAuthority interface to load a
|
||||||
|
// provisioner by ID.
|
||||||
|
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||||
|
return a.signAuth.LoadProvisionerByID(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNonce generates, stores, and returns a new ACME nonce.
|
||||||
|
func (a *Authority) NewNonce() (string, error) {
|
||||||
|
n, err := newNonce(a.db)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return n.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseNonce consumes the given nonce if it is valid, returns error otherwise.
|
||||||
|
func (a *Authority) UseNonce(nonce string) error {
|
||||||
|
return useNonce(a.db, nonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAccount creates, stores, and returns a new ACME account.
|
||||||
|
func (a *Authority) NewAccount(p provisioner.Interface, ao AccountOptions) (*Account, error) {
|
||||||
|
acc, err := newAccount(a.db, ao)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return acc.toACME(a.db, a.dir, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccount updates an ACME account.
|
||||||
|
func (a *Authority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*Account, error) {
|
||||||
|
acc, err := getAccountByID(a.db, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(err)
|
||||||
|
}
|
||||||
|
if acc, err = acc.update(a.db, contact); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return acc.toACME(a.db, a.dir, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount returns an ACME account.
|
||||||
|
func (a *Authority) GetAccount(p provisioner.Interface, id string) (*Account, error) {
|
||||||
|
acc, err := getAccountByID(a.db, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return acc.toACME(a.db, a.dir, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeactivateAccount deactivates an ACME account.
|
||||||
|
func (a *Authority) DeactivateAccount(p provisioner.Interface, id string) (*Account, error) {
|
||||||
|
acc, err := getAccountByID(a.db, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if acc, err = acc.deactivate(a.db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return acc.toACME(a.db, a.dir, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func keyToID(jwk *jose.JSONWebKey) (string, error) {
|
||||||
|
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||||
|
if err != nil {
|
||||||
|
return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint"))
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(kid), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByKey returns the ACME associated with the jwk id.
|
||||||
|
func (a *Authority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*Account, error) {
|
||||||
|
kid, err := keyToID(jwk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
acc, err := getAccountByKeyID(a.db, kid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return acc.toACME(a.db, a.dir, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrder returns an ACME order.
|
||||||
|
func (a *Authority) GetOrder(p provisioner.Interface, accID, orderID string) (*Order, error) {
|
||||||
|
o, err := getOrder(a.db, orderID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if accID != o.AccountID {
|
||||||
|
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
||||||
|
}
|
||||||
|
if o, err = o.updateStatus(a.db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return o.toACME(a.db, a.dir, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrdersByAccount returns the list of order urls owned by the account.
|
||||||
|
func (a *Authority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) {
|
||||||
|
oids, err := getOrderIDsByAccount(a.db, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ret = []string{}
|
||||||
|
for _, oid := range oids {
|
||||||
|
o, err := getOrder(a.db, oid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(err)
|
||||||
|
}
|
||||||
|
if o.Status == StatusInvalid {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ret = append(ret, a.dir.getLink(OrderLink, URLSafeProvisionerName(p), true, o.ID))
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOrder generates, stores, and returns a new ACME order.
|
||||||
|
func (a *Authority) NewOrder(p provisioner.Interface, ops OrderOptions) (*Order, error) {
|
||||||
|
order, err := newOrder(a.db, ops)
|
||||||
|
if err != nil {
|
||||||
|
return nil, Wrap(err, "error creating order")
|
||||||
|
}
|
||||||
|
return order.toACME(a.db, a.dir, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinalizeOrder attempts to finalize an order and generate a new certificate.
|
||||||
|
func (a *Authority) FinalizeOrder(p provisioner.Interface, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) {
|
||||||
|
o, err := getOrder(a.db, orderID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if accID != o.AccountID {
|
||||||
|
return nil, UnauthorizedErr(errors.New("account does not own order"))
|
||||||
|
}
|
||||||
|
o, err = o.finalize(a.db, csr, a.signAuth, p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, Wrap(err, "error finalizing order")
|
||||||
|
}
|
||||||
|
return o.toACME(a.db, a.dir, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthz retrieves and attempts to update the status on an ACME authz
|
||||||
|
// before returning.
|
||||||
|
func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*Authz, error) {
|
||||||
|
authz, err := getAuthz(a.db, authzID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if accID != authz.getAccountID() {
|
||||||
|
return nil, UnauthorizedErr(errors.New("account does not own authz"))
|
||||||
|
}
|
||||||
|
authz, err = authz.updateStatus(a.db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, Wrap(err, "error updating authz status")
|
||||||
|
}
|
||||||
|
return authz.toACME(a.db, a.dir, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateChallenge attempts to validate the challenge.
|
||||||
|
func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) {
|
||||||
|
ch, err := getChallenge(a.db, chID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if accID != ch.getAccountID() {
|
||||||
|
return nil, UnauthorizedErr(errors.New("account does not own challenge"))
|
||||||
|
}
|
||||||
|
client := http.Client{
|
||||||
|
Timeout: time.Duration(30 * time.Second),
|
||||||
|
}
|
||||||
|
ch, err = ch.validate(a.db, jwk, validateOptions{
|
||||||
|
httpGet: client.Get,
|
||||||
|
lookupTxt: net.LookupTXT,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, Wrap(err, "error attempting challenge validation")
|
||||||
|
}
|
||||||
|
return ch.toACME(a.db, a.dir, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCertificate retrieves the Certificate by ID.
|
||||||
|
func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) {
|
||||||
|
cert, err := getCert(a.db, certID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if accID != cert.AccountID {
|
||||||
|
return nil, UnauthorizedErr(errors.New("account does not own certificate"))
|
||||||
|
}
|
||||||
|
return cert.toACME(a.db, a.dir)
|
||||||
|
}
|
1474
acme/authority_test.go
Normal file
1474
acme/authority_test.go
Normal file
File diff suppressed because it is too large
Load diff
344
acme/authz.go
Normal file
344
acme/authz.go
Normal file
|
@ -0,0 +1,344 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
var defaultExpiryDuration = time.Hour * 24
|
||||||
|
|
||||||
|
// Authz is a subset of the Authz type containing only those attributes
|
||||||
|
// required for responses in the ACME protocol.
|
||||||
|
type Authz struct {
|
||||||
|
Identifier Identifier `json:"identifier"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Expires string `json:"expires"`
|
||||||
|
Challenges []*Challenge `json:"challenges"`
|
||||||
|
Wildcard bool `json:"wildcard"`
|
||||||
|
ID string `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToLog enables response logging.
|
||||||
|
func (a *Authz) ToLog() (interface{}, error) {
|
||||||
|
b, err := json.Marshal(a)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging"))
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID returns the Authz ID.
|
||||||
|
func (a *Authz) GetID() string {
|
||||||
|
return a.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// authz is the interface that the various authz types must implement.
|
||||||
|
type authz interface {
|
||||||
|
save(nosql.DB, authz) error
|
||||||
|
clone() *baseAuthz
|
||||||
|
getID() string
|
||||||
|
getAccountID() string
|
||||||
|
getType() string
|
||||||
|
getIdentifier() Identifier
|
||||||
|
getStatus() string
|
||||||
|
getExpiry() time.Time
|
||||||
|
getWildcard() bool
|
||||||
|
getChallenges() []string
|
||||||
|
getCreated() time.Time
|
||||||
|
updateStatus(db nosql.DB) (authz, error)
|
||||||
|
toACME(nosql.DB, *directory, provisioner.Interface) (*Authz, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// baseAuthz is the base authz type that others build from.
|
||||||
|
type baseAuthz struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
AccountID string `json:"accountID"`
|
||||||
|
Identifier Identifier `json:"identifier"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Expires time.Time `json:"expires"`
|
||||||
|
Challenges []string `json:"challenges"`
|
||||||
|
Wildcard bool `json:"wildcard"`
|
||||||
|
Created time.Time `json:"created"`
|
||||||
|
Error *Error `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) {
|
||||||
|
id, err := randID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := clock.Now()
|
||||||
|
ba := &baseAuthz{
|
||||||
|
ID: id,
|
||||||
|
AccountID: accID,
|
||||||
|
Status: StatusPending,
|
||||||
|
Created: now,
|
||||||
|
Expires: now.Add(defaultExpiryDuration),
|
||||||
|
Identifier: identifier,
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(identifier.Value, "*.") {
|
||||||
|
ba.Wildcard = true
|
||||||
|
ba.Identifier = Identifier{
|
||||||
|
Value: strings.TrimPrefix(identifier.Value, "*."),
|
||||||
|
Type: identifier.Type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ba, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getID returns the ID of the authz.
|
||||||
|
func (ba *baseAuthz) getID() string {
|
||||||
|
return ba.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAccountID returns the Account ID that created the authz.
|
||||||
|
func (ba *baseAuthz) getAccountID() string {
|
||||||
|
return ba.AccountID
|
||||||
|
}
|
||||||
|
|
||||||
|
// getType returns the type of the authz.
|
||||||
|
func (ba *baseAuthz) getType() string {
|
||||||
|
return ba.Identifier.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
// getIdentifier returns the identifier for the authz.
|
||||||
|
func (ba *baseAuthz) getIdentifier() Identifier {
|
||||||
|
return ba.Identifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// getStatus returns the status of the authz.
|
||||||
|
func (ba *baseAuthz) getStatus() string {
|
||||||
|
return ba.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWildcard returns true if the authz identifier has a '*', false otherwise.
|
||||||
|
func (ba *baseAuthz) getWildcard() bool {
|
||||||
|
return ba.Wildcard
|
||||||
|
}
|
||||||
|
|
||||||
|
// getChallenges returns the authz challenge IDs.
|
||||||
|
func (ba *baseAuthz) getChallenges() []string {
|
||||||
|
return ba.Challenges
|
||||||
|
}
|
||||||
|
|
||||||
|
// getExpiry returns the expiration time of the authz.
|
||||||
|
func (ba *baseAuthz) getExpiry() time.Time {
|
||||||
|
return ba.Expires
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCreated returns the created time of the authz.
|
||||||
|
func (ba *baseAuthz) getCreated() time.Time {
|
||||||
|
return ba.Created
|
||||||
|
}
|
||||||
|
|
||||||
|
// toACME converts the internal Authz type into the public acmeAuthz type for
|
||||||
|
// presentation in the ACME protocol.
|
||||||
|
func (ba *baseAuthz) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Authz, error) {
|
||||||
|
var chs = make([]*Challenge, len(ba.Challenges))
|
||||||
|
for i, chID := range ba.Challenges {
|
||||||
|
ch, err := getChallenge(db, chID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
chs[i], err = ch.toACME(db, dir, p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Authz{
|
||||||
|
Identifier: ba.Identifier,
|
||||||
|
Status: ba.getStatus(),
|
||||||
|
Challenges: chs,
|
||||||
|
Wildcard: ba.getWildcard(),
|
||||||
|
Expires: ba.Expires.Format(time.RFC3339),
|
||||||
|
ID: ba.ID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ba *baseAuthz) save(db nosql.DB, old authz) error {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
oldB, newB []byte
|
||||||
|
)
|
||||||
|
if old == nil {
|
||||||
|
oldB = nil
|
||||||
|
} else {
|
||||||
|
if oldB, err = json.Marshal(old); err != nil {
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error marshaling old authz"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newB, err = json.Marshal(ba); err != nil {
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error marshaling new authz"))
|
||||||
|
}
|
||||||
|
_, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return ServerInternalErr(errors.Wrapf(err, "error storing authz"))
|
||||||
|
case !swapped:
|
||||||
|
return ServerInternalErr(errors.Errorf("error storing authz; " +
|
||||||
|
"value has changed since last read"))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ba *baseAuthz) clone() *baseAuthz {
|
||||||
|
u := *ba
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ba *baseAuthz) storeAndReturnError(db nosql.DB, err *Error) error {
|
||||||
|
clone := ba.clone()
|
||||||
|
clone.Error = err
|
||||||
|
clone.save(db, ba)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ba *baseAuthz) parent() authz {
|
||||||
|
return &dnsAuthz{ba}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateStatus attempts to update the status on a baseAuthz and stores the
|
||||||
|
// updating object if necessary.
|
||||||
|
func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) {
|
||||||
|
newAuthz := ba.clone()
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
switch ba.Status {
|
||||||
|
case StatusInvalid:
|
||||||
|
return ba.parent(), nil
|
||||||
|
case StatusValid:
|
||||||
|
return ba.parent(), nil
|
||||||
|
case StatusPending:
|
||||||
|
// check expiry
|
||||||
|
if now.After(ba.Expires) {
|
||||||
|
newAuthz.Status = StatusInvalid
|
||||||
|
newAuthz.Error = MalformedErr(errors.New("authz has expired"))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var isValid = false
|
||||||
|
for _, chID := range ba.Challenges {
|
||||||
|
ch, err := getChallenge(db, chID)
|
||||||
|
if err != nil {
|
||||||
|
return ba, err
|
||||||
|
}
|
||||||
|
if ch.getStatus() == StatusValid {
|
||||||
|
isValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValid {
|
||||||
|
return ba.parent(), nil
|
||||||
|
}
|
||||||
|
newAuthz.Status = StatusValid
|
||||||
|
newAuthz.Error = nil
|
||||||
|
default:
|
||||||
|
return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := newAuthz.save(db, ba); err != nil {
|
||||||
|
return ba, err
|
||||||
|
}
|
||||||
|
return newAuthz.parent(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// unmarshalAuthz unmarshals an authz type into the correct sub-type.
|
||||||
|
func unmarshalAuthz(data []byte) (authz, error) {
|
||||||
|
var getType struct {
|
||||||
|
Identifier Identifier `json:"identifier"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &getType); err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type"))
|
||||||
|
}
|
||||||
|
|
||||||
|
switch getType.Identifier.Type {
|
||||||
|
case "dns":
|
||||||
|
var ba baseAuthz
|
||||||
|
if err := json.Unmarshal(data, &ba); err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz"))
|
||||||
|
}
|
||||||
|
return &dnsAuthz{&ba}, nil
|
||||||
|
default:
|
||||||
|
return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s",
|
||||||
|
getType.Identifier.Type))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsAuthz represents a dns acme authorization.
|
||||||
|
type dnsAuthz struct {
|
||||||
|
*baseAuthz
|
||||||
|
}
|
||||||
|
|
||||||
|
// newAuthz returns a new acme authorization object based on the identifier
|
||||||
|
// type.
|
||||||
|
func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) {
|
||||||
|
switch identifier.Type {
|
||||||
|
case "dns":
|
||||||
|
a, err = newDNSAuthz(db, accID, identifier)
|
||||||
|
default:
|
||||||
|
err = MalformedErr(errors.Errorf("unexpected authz type %s",
|
||||||
|
identifier.Type))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDNSAuthz returns a new dns acme authorization object.
|
||||||
|
func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) {
|
||||||
|
ba, err := newBaseAuthz(accID, identifier)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ba.Challenges = []string{}
|
||||||
|
if !ba.Wildcard {
|
||||||
|
// http challenges are only permitted if the DNS is not a wildcard dns.
|
||||||
|
ch1, err := newHTTP01Challenge(db, ChallengeOptions{
|
||||||
|
AccountID: accID,
|
||||||
|
AuthzID: ba.ID,
|
||||||
|
Identifier: ba.Identifier})
|
||||||
|
if err != nil {
|
||||||
|
return nil, Wrap(err, "error creating http challenge")
|
||||||
|
}
|
||||||
|
ba.Challenges = append(ba.Challenges, ch1.getID())
|
||||||
|
}
|
||||||
|
ch2, err := newDNS01Challenge(db, ChallengeOptions{
|
||||||
|
AccountID: accID,
|
||||||
|
AuthzID: ba.ID,
|
||||||
|
Identifier: identifier})
|
||||||
|
if err != nil {
|
||||||
|
return nil, Wrap(err, "error creating dns challenge")
|
||||||
|
}
|
||||||
|
ba.Challenges = append(ba.Challenges, ch2.getID())
|
||||||
|
|
||||||
|
da := &dnsAuthz{ba}
|
||||||
|
if err := da.save(db, nil); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return da, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAuthz retrieves and unmarshals an ACME authz type from the database.
|
||||||
|
func getAuthz(db nosql.DB, id string) (authz, error) {
|
||||||
|
b, err := db.Get(authzTable, []byte(id))
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id))
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id))
|
||||||
|
}
|
||||||
|
az, err := unmarshalAuthz(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return az, nil
|
||||||
|
}
|
809
acme/authz_test.go
Normal file
809
acme/authz_test.go
Normal file
|
@ -0,0 +1,809 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
"github.com/smallstep/nosql/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAz() (authz, error) {
|
||||||
|
mockdb := &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return []byte("foo"), true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return newAuthz(mockdb, "1234", Identifier{
|
||||||
|
Type: "dns", Value: "acme.example.com",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAuthz(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
id string
|
||||||
|
db nosql.DB
|
||||||
|
az authz
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
id: az.getID(),
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, database.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db-error": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
id: az.getID(),
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-error": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
_az, ok := az.(*dnsAuthz)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
_az.baseAuthz.Identifier.Type = "foo"
|
||||||
|
b, err := json.Marshal(az)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
id: az.getID(),
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, key, []byte(az.getID()))
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(az)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
id: az.getID(),
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, key, []byte(az.getID()))
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if az, err := getAuthz(tc.db, tc.id); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.az.getID(), az.getID())
|
||||||
|
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||||||
|
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||||||
|
assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier())
|
||||||
|
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||||||
|
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||||||
|
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthzClone(t *testing.T) {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
clone := az.clone()
|
||||||
|
|
||||||
|
assert.Equals(t, clone.getID(), az.getID())
|
||||||
|
assert.Equals(t, clone.getAccountID(), az.getAccountID())
|
||||||
|
assert.Equals(t, clone.getStatus(), az.getStatus())
|
||||||
|
assert.Equals(t, clone.getIdentifier(), az.getIdentifier())
|
||||||
|
assert.Equals(t, clone.getExpiry(), az.getExpiry())
|
||||||
|
assert.Equals(t, clone.getCreated(), az.getCreated())
|
||||||
|
assert.Equals(t, clone.getChallenges(), az.getChallenges())
|
||||||
|
|
||||||
|
clone.Status = StatusValid
|
||||||
|
|
||||||
|
assert.NotEquals(t, clone.getStatus(), az.getStatus())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAuthz(t *testing.T) {
|
||||||
|
iden := Identifier{
|
||||||
|
Type: "dns", Value: "acme.example.com",
|
||||||
|
}
|
||||||
|
accID := "1234"
|
||||||
|
type test struct {
|
||||||
|
iden Identifier
|
||||||
|
db nosql.DB
|
||||||
|
err *Error
|
||||||
|
resChs *([]string)
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/unexpected-type": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
iden: Identifier{Type: "foo", Value: "acme.example.com"},
|
||||||
|
err: MalformedErr(errors.New("unexpected authz type foo")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/new-http-chall-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
iden: iden,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/new-dns-chall-error": func(t *testing.T) test {
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
iden: iden,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
if count == 1 {
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/save-authz-error": func(t *testing.T) test {
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
iden: iden,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
if count == 2 {
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
chs := &([]string{})
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
iden: iden,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
if count == 2 {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
az, err := unmarshalAuthz(newval)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
assert.Equals(t, az.getID(), string(key))
|
||||||
|
assert.Equals(t, az.getAccountID(), accID)
|
||||||
|
assert.Equals(t, az.getStatus(), StatusPending)
|
||||||
|
assert.Equals(t, az.getIdentifier(), iden)
|
||||||
|
assert.Equals(t, az.getWildcard(), false)
|
||||||
|
|
||||||
|
*chs = az.getChallenges()
|
||||||
|
|
||||||
|
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||||
|
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||||
|
|
||||||
|
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||||
|
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||||
|
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resChs: chs,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/wildcard": func(t *testing.T) test {
|
||||||
|
chs := &([]string{})
|
||||||
|
count := 0
|
||||||
|
_iden := Identifier{Type: "dns", Value: "*.acme.example.com"}
|
||||||
|
return test{
|
||||||
|
iden: _iden,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
if count == 1 {
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
az, err := unmarshalAuthz(newval)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
assert.Equals(t, az.getID(), string(key))
|
||||||
|
assert.Equals(t, az.getAccountID(), accID)
|
||||||
|
assert.Equals(t, az.getStatus(), StatusPending)
|
||||||
|
assert.Equals(t, az.getIdentifier(), iden)
|
||||||
|
assert.Equals(t, az.getWildcard(), true)
|
||||||
|
|
||||||
|
*chs = az.getChallenges()
|
||||||
|
// Verify that we only have 1 challenge instead of 2.
|
||||||
|
assert.True(t, len(*chs) == 1)
|
||||||
|
|
||||||
|
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||||
|
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||||
|
|
||||||
|
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||||
|
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||||
|
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resChs: chs,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
az, err := newAuthz(tc.db, accID, tc.iden)
|
||||||
|
if err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, az.getAccountID(), accID)
|
||||||
|
assert.Equals(t, az.getType(), "dns")
|
||||||
|
assert.Equals(t, az.getStatus(), StatusPending)
|
||||||
|
|
||||||
|
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||||||
|
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||||||
|
|
||||||
|
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||||||
|
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||||||
|
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||||||
|
|
||||||
|
assert.Equals(t, az.getChallenges(), *(tc.resChs))
|
||||||
|
|
||||||
|
if strings.HasPrefix(tc.iden.Value, "*.") {
|
||||||
|
assert.True(t, az.getWildcard())
|
||||||
|
assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*."))
|
||||||
|
} else {
|
||||||
|
assert.False(t, az.getWildcard())
|
||||||
|
assert.Equals(t, az.getIdentifier().Value, tc.iden.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, az.getID() != "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthzToACME(t *testing.T) {
|
||||||
|
dir := newDirectory("ca.smallstep.com", "acme")
|
||||||
|
|
||||||
|
var (
|
||||||
|
ch1, ch2 challenge
|
||||||
|
ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
mockdb := &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
if count == 0 {
|
||||||
|
*ch1Bytes = newval
|
||||||
|
ch1, err = unmarshalChallenge(newval)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
} else if count == 1 {
|
||||||
|
*ch2Bytes = newval
|
||||||
|
ch2, err = unmarshalChallenge(newval)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return []byte("foo"), true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
iden := Identifier{
|
||||||
|
Type: "dns", Value: "acme.example.com",
|
||||||
|
}
|
||||||
|
az, err := newAuthz(mockdb, "1234", iden)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
prov := newProv()
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/getChallenge1-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/getChallenge2-error": func(t *testing.T) test {
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
if count == 1 {
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return *ch1Bytes, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
count := 0
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
if count == 0 {
|
||||||
|
count++
|
||||||
|
return *ch1Bytes, nil
|
||||||
|
}
|
||||||
|
return *ch2Bytes, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
acmeAz, err := az.toACME(tc.db, dir, prov)
|
||||||
|
if err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, acmeAz.ID, az.getID())
|
||||||
|
assert.Equals(t, acmeAz.Identifier, iden)
|
||||||
|
assert.Equals(t, acmeAz.Status, StatusPending)
|
||||||
|
|
||||||
|
acmeCh1, err := ch1.toACME(nil, dir, prov)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acmeCh2, err := ch2.toACME(nil, dir, prov)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
assert.Equals(t, acmeAz.Challenges[0], acmeCh1)
|
||||||
|
assert.Equals(t, acmeAz.Challenges[1], acmeCh2)
|
||||||
|
|
||||||
|
expiry, err := time.Parse(time.RFC3339, acmeAz.Expires)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, expiry.String(), az.getExpiry().String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthzSave(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
az, old authz
|
||||||
|
db nosql.DB
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
old: nil,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
old: nil,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return []byte("foo"), false, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/old-nil": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(az)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
old: nil,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
assert.Equals(t, b, newval)
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, []byte(az.getID()), key)
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/old-not-nil": func(t *testing.T) test {
|
||||||
|
oldAz, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
oldb, err := json.Marshal(oldAz)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(az)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
old: oldAz,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, old, oldb)
|
||||||
|
assert.Equals(t, b, newval)
|
||||||
|
assert.Equals(t, bucket, authzTable)
|
||||||
|
assert.Equals(t, []byte(az.getID()), key)
|
||||||
|
return []byte("foo"), true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if err := tc.az.save(tc.db, tc.old); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthzUnmarshal(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
az authz
|
||||||
|
azb []byte
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/nil": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
azb: nil,
|
||||||
|
err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unexpected-type": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
_az, ok := az.(*dnsAuthz)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
_az.baseAuthz.Identifier.Type = "foo"
|
||||||
|
b, err := json.Marshal(az)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
azb: b,
|
||||||
|
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/dns": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(az)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
azb: b,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if az, err := unmarshalAuthz(tc.azb); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.az.getID(), az.getID())
|
||||||
|
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||||||
|
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||||||
|
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||||||
|
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||||||
|
assert.Equals(t, tc.az.getWildcard(), az.getWildcard())
|
||||||
|
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthzUpdateStatus(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
az, res authz
|
||||||
|
err *Error
|
||||||
|
db nosql.DB
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/already-invalid": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
_az, ok := az.(*dnsAuthz)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
_az.baseAuthz.Status = StatusInvalid
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
res: az,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/already-valid": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
_az, ok := az.(*dnsAuthz)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
_az.baseAuthz.Status = StatusValid
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
res: az,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unexpected-status": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
_az, ok := az.(*dnsAuthz)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
_az.baseAuthz.Status = StatusReady
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
res: az,
|
||||||
|
err: ServerInternalErr(errors.New("unrecognized authz status: ready")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/save-error": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
_az, ok := az.(*dnsAuthz)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
res: az,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/expired": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
_az, ok := az.(*dnsAuthz)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||||||
|
|
||||||
|
clone := az.clone()
|
||||||
|
clone.Error = MalformedErr(errors.New("authz has expired"))
|
||||||
|
clone.Status = StatusInvalid
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
res: clone.parent(),
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/get-challenge-error": func(t *testing.T) test {
|
||||||
|
az, err := newAz()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
res: az,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error loading challenge")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/valid": func(t *testing.T) test {
|
||||||
|
var (
|
||||||
|
ch2 challenge
|
||||||
|
ch1Bytes = &([]byte{})
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
mockdb := &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
if count == 0 {
|
||||||
|
*ch1Bytes = newval
|
||||||
|
} else if count == 1 {
|
||||||
|
ch2, err = unmarshalChallenge(newval)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
iden := Identifier{
|
||||||
|
Type: "dns", Value: "acme.example.com",
|
||||||
|
}
|
||||||
|
az, err := newAuthz(mockdb, "1234", iden)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
_az, ok := az.(*dnsAuthz)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
_az.baseAuthz.Error = MalformedErr(nil)
|
||||||
|
|
||||||
|
_ch, ok := ch2.(*dns01Challenge)
|
||||||
|
assert.Fatal(t, ok)
|
||||||
|
_ch.baseChallenge.Status = StatusValid
|
||||||
|
chb, err := json.Marshal(ch2)
|
||||||
|
|
||||||
|
clone := az.clone()
|
||||||
|
clone.Status = StatusValid
|
||||||
|
clone.Error = nil
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
res: clone.parent(),
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
if count == 0 {
|
||||||
|
count++
|
||||||
|
return *ch1Bytes, nil
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return chb, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/still-pending": func(t *testing.T) test {
|
||||||
|
var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
mockdb := &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
if count == 0 {
|
||||||
|
*ch1Bytes = newval
|
||||||
|
} else if count == 1 {
|
||||||
|
*ch2Bytes = newval
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
iden := Identifier{
|
||||||
|
Type: "dns", Value: "acme.example.com",
|
||||||
|
}
|
||||||
|
az, err := newAuthz(mockdb, "1234", iden)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
return test{
|
||||||
|
az: az,
|
||||||
|
res: az,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
if count == 0 {
|
||||||
|
count++
|
||||||
|
return *ch1Bytes, nil
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return *ch2Bytes, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
az, err := tc.az.updateStatus(tc.db)
|
||||||
|
if err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
expB, err := json.Marshal(tc.res)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(az)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, expB, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
89
acme/certificate.go
Normal file
89
acme/certificate.go
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
type certificate struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Created time.Time `json:"created"`
|
||||||
|
AccountID string `json:"accountID"`
|
||||||
|
OrderID string `json:"orderID"`
|
||||||
|
Leaf []byte `json:"leaf"`
|
||||||
|
Intermediates []byte `json:"intermediates"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CertOptions options with which to create and store a cert object.
|
||||||
|
type CertOptions struct {
|
||||||
|
AccountID string
|
||||||
|
OrderID string
|
||||||
|
Leaf *x509.Certificate
|
||||||
|
Intermediates []*x509.Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCert(db nosql.DB, ops CertOptions) (*certificate, error) {
|
||||||
|
id, err := randID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
leaf := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: ops.Leaf.Raw,
|
||||||
|
})
|
||||||
|
var intermediates []byte
|
||||||
|
for _, cert := range ops.Intermediates {
|
||||||
|
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: cert.Raw,
|
||||||
|
})...)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := &certificate{
|
||||||
|
ID: id,
|
||||||
|
AccountID: ops.AccountID,
|
||||||
|
OrderID: ops.OrderID,
|
||||||
|
Leaf: leaf,
|
||||||
|
Intermediates: intermediates,
|
||||||
|
Created: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
certB, err := json.Marshal(cert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling certificate"))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, swapped, err := db.CmpAndSwap(certTable, []byte(id), nil, certB)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error storing certificate"))
|
||||||
|
case !swapped:
|
||||||
|
return nil, ServerInternalErr(errors.New("error storing certificate; " +
|
||||||
|
"value has changed since last read"))
|
||||||
|
default:
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificate) toACME(db nosql.DB, dir *directory) ([]byte, error) {
|
||||||
|
return append(c.Leaf, c.Intermediates...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCert(db nosql.DB, id string) (*certificate, error) {
|
||||||
|
b, err := db.Get(certTable, []byte(id))
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id))
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate"))
|
||||||
|
}
|
||||||
|
var cert certificate
|
||||||
|
if err := json.Unmarshal(b, &cert); err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate"))
|
||||||
|
}
|
||||||
|
return &cert, nil
|
||||||
|
}
|
253
acme/certificate_test.go
Normal file
253
acme/certificate_test.go
Normal file
|
@ -0,0 +1,253 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/cli/crypto/pemutil"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
"github.com/smallstep/nosql/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultCertOps() (*CertOptions, error) {
|
||||||
|
crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &CertOptions{
|
||||||
|
AccountID: "accID",
|
||||||
|
OrderID: "ordID",
|
||||||
|
Leaf: crt,
|
||||||
|
Intermediates: []*x509.Certificate{inter, root},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newcert() (*certificate, error) {
|
||||||
|
ops, err := defaultCertOps()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mockdb := &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return newCert(mockdb, *ops)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCert(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
ops CertOptions
|
||||||
|
err *Error
|
||||||
|
id *string
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
ops, err := defaultCertOps()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
ops: *ops,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.Errorf("error storing certificate: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
||||||
|
ops, err := defaultCertOps()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
ops: *ops,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
return nil, false, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
ops, err := defaultCertOps()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
var _id string
|
||||||
|
id := &_id
|
||||||
|
return test{
|
||||||
|
ops: *ops,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
*id = string(key)
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
id: id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if cert, err := newCert(tc.db, tc.ops); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, cert.ID, *tc.id)
|
||||||
|
assert.Equals(t, cert.AccountID, tc.ops.AccountID)
|
||||||
|
assert.Equals(t, cert.OrderID, tc.ops.OrderID)
|
||||||
|
|
||||||
|
leaf := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: tc.ops.Leaf.Raw,
|
||||||
|
})
|
||||||
|
var intermediates []byte
|
||||||
|
for _, cert := range tc.ops.Intermediates {
|
||||||
|
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: cert.Raw,
|
||||||
|
})...)
|
||||||
|
}
|
||||||
|
assert.Equals(t, cert.Leaf, leaf)
|
||||||
|
assert.Equals(t, cert.Intermediates, intermediates)
|
||||||
|
|
||||||
|
assert.True(t, cert.Created.Before(time.Now().Add(time.Minute)))
|
||||||
|
assert.True(t, cert.Created.After(time.Now().Add(-time.Minute)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCert(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
id string
|
||||||
|
db nosql.DB
|
||||||
|
cert *certificate
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
cert, err := newcert()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
cert: cert,
|
||||||
|
id: cert.ID,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, key, []byte(cert.ID))
|
||||||
|
return nil, database.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db-error": func(t *testing.T) test {
|
||||||
|
cert, err := newcert()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
cert: cert,
|
||||||
|
id: cert.ID,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, key, []byte(cert.ID))
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error loading certificate: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-error": func(t *testing.T) test {
|
||||||
|
cert, err := newcert()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
cert: cert,
|
||||||
|
id: cert.ID,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, key, []byte(cert.ID))
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
cert, err := newcert()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(cert)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
cert: cert,
|
||||||
|
id: cert.ID,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, certTable)
|
||||||
|
assert.Equals(t, key, []byte(cert.ID))
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if cert, err := getCert(tc.db, tc.id); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.cert.ID, cert.ID)
|
||||||
|
assert.Equals(t, tc.cert.AccountID, cert.AccountID)
|
||||||
|
assert.Equals(t, tc.cert.OrderID, cert.OrderID)
|
||||||
|
assert.Equals(t, tc.cert.Created, cert.Created)
|
||||||
|
assert.Equals(t, tc.cert.Leaf, cert.Leaf)
|
||||||
|
assert.Equals(t, tc.cert.Intermediates, cert.Intermediates)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertificateToACME(t *testing.T) {
|
||||||
|
cert, err := newcert()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acmeCert, err := cert.toACME(nil, nil)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert)
|
||||||
|
}
|
445
acme/challenge.go
Normal file
445
acme/challenge.go
Normal file
|
@ -0,0 +1,445 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Challenge is a subset of the challenge type containing only those attributes
|
||||||
|
// required for responses in the ACME protocol.
|
||||||
|
type Challenge struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
Validated string `json:"validated,omitempty"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Error *AError `json:"error,omitempty"`
|
||||||
|
ID string `json:"-"`
|
||||||
|
AuthzID string `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToLog enables response logging.
|
||||||
|
func (c *Challenge) ToLog() (interface{}, error) {
|
||||||
|
b, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging"))
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID returns the Challenge ID.
|
||||||
|
func (c *Challenge) GetID() string {
|
||||||
|
return c.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthzID returns the parent Authz ID that owns the Challenge.
|
||||||
|
func (c *Challenge) GetAuthzID() string {
|
||||||
|
return c.AuthzID
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpGetter func(string) (*http.Response, error)
|
||||||
|
type lookupTxt func(string) ([]string, error)
|
||||||
|
|
||||||
|
type validateOptions struct {
|
||||||
|
httpGet httpGetter
|
||||||
|
lookupTxt lookupTxt
|
||||||
|
}
|
||||||
|
|
||||||
|
// challenge is the interface ACME challenege types must implement.
|
||||||
|
type challenge interface {
|
||||||
|
save(db nosql.DB, swap challenge) error
|
||||||
|
validate(nosql.DB, *jose.JSONWebKey, validateOptions) (challenge, error)
|
||||||
|
getType() string
|
||||||
|
getError() *AError
|
||||||
|
getValue() string
|
||||||
|
getStatus() string
|
||||||
|
getID() string
|
||||||
|
getAuthzID() string
|
||||||
|
getToken() string
|
||||||
|
clone() *baseChallenge
|
||||||
|
getAccountID() string
|
||||||
|
getValidated() time.Time
|
||||||
|
getCreated() time.Time
|
||||||
|
toACME(nosql.DB, *directory, provisioner.Interface) (*Challenge, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChallengeOptions is the type used to created a new Challenge.
|
||||||
|
type ChallengeOptions struct {
|
||||||
|
AccountID string
|
||||||
|
AuthzID string
|
||||||
|
Identifier Identifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// baseChallenge is the base Challenge type that others build from.
|
||||||
|
type baseChallenge struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
AccountID string `json:"accountID"`
|
||||||
|
AuthzID string `json:"authzID"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
Validated time.Time `json:"validated"`
|
||||||
|
Created time.Time `json:"created"`
|
||||||
|
Error *AError `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBaseChallenge(accountID, authzID string) (*baseChallenge, error) {
|
||||||
|
id, err := randID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, Wrap(err, "error generating random id for ACME challenge")
|
||||||
|
}
|
||||||
|
token, err := randID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, Wrap(err, "error generating token for ACME challenge")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &baseChallenge{
|
||||||
|
ID: id,
|
||||||
|
AccountID: accountID,
|
||||||
|
AuthzID: authzID,
|
||||||
|
Status: StatusPending,
|
||||||
|
Token: token,
|
||||||
|
Created: clock.Now(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getID returns the id of the baseChallenge.
|
||||||
|
func (bc *baseChallenge) getID() string {
|
||||||
|
return bc.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAuthzID returns the authz ID of the baseChallenge.
|
||||||
|
func (bc *baseChallenge) getAuthzID() string {
|
||||||
|
return bc.AuthzID
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAccountID returns the account id of the baseChallenge.
|
||||||
|
func (bc *baseChallenge) getAccountID() string {
|
||||||
|
return bc.AccountID
|
||||||
|
}
|
||||||
|
|
||||||
|
// getType returns the type of the baseChallenge.
|
||||||
|
func (bc *baseChallenge) getType() string {
|
||||||
|
return bc.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
// getValue returns the type of the baseChallenge.
|
||||||
|
func (bc *baseChallenge) getValue() string {
|
||||||
|
return bc.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// getStatus returns the status of the baseChallenge.
|
||||||
|
func (bc *baseChallenge) getStatus() string {
|
||||||
|
return bc.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
// getToken returns the token of the baseChallenge.
|
||||||
|
func (bc *baseChallenge) getToken() string {
|
||||||
|
return bc.Token
|
||||||
|
}
|
||||||
|
|
||||||
|
// getValidated returns the validated time of the baseChallenge.
|
||||||
|
func (bc *baseChallenge) getValidated() time.Time {
|
||||||
|
return bc.Validated
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCreated returns the created time of the baseChallenge.
|
||||||
|
func (bc *baseChallenge) getCreated() time.Time {
|
||||||
|
return bc.Created
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCreated returns the created time of the baseChallenge.
|
||||||
|
func (bc *baseChallenge) getError() *AError {
|
||||||
|
return bc.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// toACME converts the internal Challenge type into the public acmeChallenge
|
||||||
|
// type for presentation in the ACME protocol.
|
||||||
|
func (bc *baseChallenge) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Challenge, error) {
|
||||||
|
ac := &Challenge{
|
||||||
|
Type: bc.getType(),
|
||||||
|
Status: bc.getStatus(),
|
||||||
|
Token: bc.getToken(),
|
||||||
|
URL: dir.getLink(ChallengeLink, URLSafeProvisionerName(p), true, bc.getID()),
|
||||||
|
ID: bc.getID(),
|
||||||
|
AuthzID: bc.getAuthzID(),
|
||||||
|
}
|
||||||
|
if !bc.Validated.IsZero() {
|
||||||
|
ac.Validated = bc.Validated.Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
if bc.Error != nil {
|
||||||
|
ac.Error = bc.Error
|
||||||
|
}
|
||||||
|
return ac, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// save writes the challenge to disk. For new challenges 'old' should be nil,
|
||||||
|
// otherwise 'old' should be a pointer to the acme challenge as it was at the
|
||||||
|
// start of the request. This method will fail if the value currently found
|
||||||
|
// in the bucket/row does not match the value of 'old'.
|
||||||
|
func (bc *baseChallenge) save(db nosql.DB, old challenge) error {
|
||||||
|
newB, err := json.Marshal(bc)
|
||||||
|
if err != nil {
|
||||||
|
return ServerInternalErr(errors.Wrap(err,
|
||||||
|
"error marshaling new acme challenge"))
|
||||||
|
}
|
||||||
|
var oldB []byte
|
||||||
|
if old == nil {
|
||||||
|
oldB = nil
|
||||||
|
} else {
|
||||||
|
oldB, err = json.Marshal(old)
|
||||||
|
if err != nil {
|
||||||
|
return ServerInternalErr(errors.Wrap(err,
|
||||||
|
"error marshaling old acme challenge"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error saving acme challenge"))
|
||||||
|
case !swapped:
|
||||||
|
return ServerInternalErr(errors.New("error saving acme challenge; " +
|
||||||
|
"acme challenge has changed since last read"))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc *baseChallenge) clone() *baseChallenge {
|
||||||
|
u := *bc
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc *baseChallenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
||||||
|
return nil, ServerInternalErr(errors.New("unimplemented"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc *baseChallenge) storeError(db nosql.DB, err *Error) error {
|
||||||
|
clone := bc.clone()
|
||||||
|
clone.Error = err.ToACME()
|
||||||
|
return clone.save(db, bc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unmarshalChallenge unmarshals a challenge type into the correct sub-type.
|
||||||
|
func unmarshalChallenge(data []byte) (challenge, error) {
|
||||||
|
var getType struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &getType); err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type"))
|
||||||
|
}
|
||||||
|
|
||||||
|
switch getType.Type {
|
||||||
|
case "dns-01":
|
||||||
|
var bc baseChallenge
|
||||||
|
if err := json.Unmarshal(data, &bc); err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
|
||||||
|
"challenge type into dns01Challenge"))
|
||||||
|
}
|
||||||
|
return &dns01Challenge{&bc}, nil
|
||||||
|
case "http-01":
|
||||||
|
var bc baseChallenge
|
||||||
|
if err := json.Unmarshal(data, &bc); err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+
|
||||||
|
"challenge type into http01Challenge"))
|
||||||
|
}
|
||||||
|
return &http01Challenge{&bc}, nil
|
||||||
|
default:
|
||||||
|
return nil, ServerInternalErr(errors.Errorf("unexpected challenge type %s", getType.Type))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// http01Challenge represents an http-01 acme challenge.
|
||||||
|
type http01Challenge struct {
|
||||||
|
*baseChallenge
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHTTP01Challenge returns a new acme http-01 challenge.
|
||||||
|
func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
|
||||||
|
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bc.Type = "http-01"
|
||||||
|
bc.Value = ops.Identifier.Value
|
||||||
|
|
||||||
|
hc := &http01Challenge{bc}
|
||||||
|
if err := hc.save(db, nil); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return hc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate attempts to validate the challenge. If the challenge has been
|
||||||
|
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||||
|
// updated.
|
||||||
|
func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
||||||
|
// If already valid or invalid then return without performing validation.
|
||||||
|
if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid {
|
||||||
|
return hc, nil
|
||||||
|
}
|
||||||
|
url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token)
|
||||||
|
|
||||||
|
resp, err := vo.httpGet(url)
|
||||||
|
if err != nil {
|
||||||
|
if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err,
|
||||||
|
"error doing http GET for url %s", url))); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return hc, nil
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
if err = hc.storeError(db,
|
||||||
|
ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d",
|
||||||
|
url, resp.StatusCode))); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return hc, nil
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+
|
||||||
|
"response body for url %s", url))
|
||||||
|
}
|
||||||
|
keyAuth := strings.Trim(string(body), "\r\n")
|
||||||
|
|
||||||
|
expected, err := KeyAuthorization(hc.Token, jwk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if keyAuth != expected {
|
||||||
|
if err = hc.storeError(db,
|
||||||
|
RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+
|
||||||
|
"expected %s, but got %s", expected, keyAuth))); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return hc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update and store the challenge.
|
||||||
|
upd := &http01Challenge{hc.baseChallenge.clone()}
|
||||||
|
upd.Status = StatusValid
|
||||||
|
upd.Error = nil
|
||||||
|
upd.Validated = clock.Now()
|
||||||
|
|
||||||
|
if err := upd.save(db, hc); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return upd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dns01Challenge represents an dns-01 acme challenge.
|
||||||
|
type dns01Challenge struct {
|
||||||
|
*baseChallenge
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDNS01Challenge returns a new acme dns-01 challenge.
|
||||||
|
func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) {
|
||||||
|
bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bc.Type = "dns-01"
|
||||||
|
bc.Value = ops.Identifier.Value
|
||||||
|
|
||||||
|
dc := &dns01Challenge{bc}
|
||||||
|
if err := dc.save(db, nil); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyAuthorization creates the ACME key authorization value from a token
|
||||||
|
// and a jwk.
|
||||||
|
func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
|
||||||
|
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
|
||||||
|
if err != nil {
|
||||||
|
return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint"))
|
||||||
|
}
|
||||||
|
encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
|
||||||
|
return fmt.Sprintf("%s.%s", token, encPrint), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate attempts to validate the challenge. If the challenge has been
|
||||||
|
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||||
|
// updated.
|
||||||
|
func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) {
|
||||||
|
// If already valid or invalid then return without performing validation.
|
||||||
|
if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid {
|
||||||
|
return dc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
txtRecords, err := vo.lookupTxt("_acme-challenge." + dc.Value)
|
||||||
|
if err != nil {
|
||||||
|
if err = dc.storeError(db,
|
||||||
|
DNSErr(errors.Wrapf(err, "error looking up TXT "+
|
||||||
|
"records for domain %s", dc.Value))); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h := sha256.Sum256([]byte(expectedKeyAuth))
|
||||||
|
expected := base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
var found bool
|
||||||
|
for _, r := range txtRecords {
|
||||||
|
if r == expected {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
if err = dc.storeError(db,
|
||||||
|
RejectedIdentifierErr(errors.Errorf("keyAuthorization "+
|
||||||
|
"does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update and store the challenge.
|
||||||
|
upd := &dns01Challenge{dc.baseChallenge.clone()}
|
||||||
|
upd.Status = StatusValid
|
||||||
|
upd.Error = nil
|
||||||
|
upd.Validated = time.Now().UTC()
|
||||||
|
|
||||||
|
if err := upd.save(db, dc); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return upd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getChallenge retrieves and unmarshals an ACME challenge type from the database.
|
||||||
|
func getChallenge(db nosql.DB, id string) (challenge, error) {
|
||||||
|
b, err := db.Get(challengeTable, []byte(id))
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id))
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id))
|
||||||
|
}
|
||||||
|
ch, err := unmarshalChallenge(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ch, nil
|
||||||
|
}
|
1093
acme/challenge_test.go
Normal file
1093
acme/challenge_test.go
Normal file
File diff suppressed because it is too large
Load diff
76
acme/common.go
Normal file
76
acme/common.go
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SignAuthority is the interface implemented by a CA authority.
|
||||||
|
type SignAuthority interface {
|
||||||
|
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||||
|
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identifier encodes the type that an order pertains to.
|
||||||
|
type Identifier struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
accountTable = []byte("acme-accounts")
|
||||||
|
accountByKeyIDTable = []byte("acme-keyID-accountID-index")
|
||||||
|
authzTable = []byte("acme-authzs")
|
||||||
|
challengeTable = []byte("acme-challenges")
|
||||||
|
nonceTable = []byte("nonce-table")
|
||||||
|
orderTable = []byte("acme-orders")
|
||||||
|
ordersByAccountIDTable = []byte("acme-account-orders-index")
|
||||||
|
certTable = []byte("acme-certs")
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// StatusValid -- valid
|
||||||
|
StatusValid = "valid"
|
||||||
|
// StatusInvalid -- invalid
|
||||||
|
StatusInvalid = "invalid"
|
||||||
|
// StatusPending -- pending; e.g. an Order that is not ready to be finalized.
|
||||||
|
StatusPending = "pending"
|
||||||
|
// StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
|
||||||
|
StatusDeactivated = "deactivated"
|
||||||
|
// StatusReady -- ready; e.g. for an Order that is ready to be finalized.
|
||||||
|
StatusReady = "ready"
|
||||||
|
//statusExpired = "expired"
|
||||||
|
//statusActive = "active"
|
||||||
|
//statusProcessing = "processing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var idLen = 32
|
||||||
|
|
||||||
|
func randID() (val string, err error) {
|
||||||
|
val, err = randutil.Alphanumeric(idLen)
|
||||||
|
if err != nil {
|
||||||
|
return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID"))
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clock that returns time in UTC rounded to seconds.
|
||||||
|
type Clock int
|
||||||
|
|
||||||
|
// Now returns the UTC time rounded to seconds.
|
||||||
|
func (c *Clock) Now() time.Time {
|
||||||
|
return time.Now().UTC().Round(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
var clock = new(Clock)
|
||||||
|
|
||||||
|
// URLSafeProvisionerName returns a path escaped version of the ACME provisioner
|
||||||
|
// ID that is safe to use in URL paths.
|
||||||
|
func URLSafeProvisionerName(p provisioner.Interface) string {
|
||||||
|
return url.PathEscape(p.GetName())
|
||||||
|
}
|
120
acme/directory.go
Normal file
120
acme/directory.go
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Directory represents an ACME directory for configuring clients.
|
||||||
|
type Directory struct {
|
||||||
|
NewNonce string `json:"newNonce,omitempty"`
|
||||||
|
NewAccount string `json:"newAccount,omitempty"`
|
||||||
|
NewOrder string `json:"newOrder,omitempty"`
|
||||||
|
NewAuthz string `json:"newAuthz,omitempty"`
|
||||||
|
RevokeCert string `json:"revokeCert,omitempty"`
|
||||||
|
KeyChange string `json:"keyChange,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToLog enables response logging for the Directory type.
|
||||||
|
func (d *Directory) ToLog() (interface{}, error) {
|
||||||
|
b, err := json.Marshal(d)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging"))
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type directory struct {
|
||||||
|
prefix, dns string
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDirectory returns a new Directory type.
|
||||||
|
func newDirectory(dns, prefix string) *directory {
|
||||||
|
return &directory{prefix: prefix, dns: dns}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Link captures the link type.
|
||||||
|
type Link int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// NewNonceLink new-nonce
|
||||||
|
NewNonceLink Link = iota
|
||||||
|
// NewAccountLink new-account
|
||||||
|
NewAccountLink
|
||||||
|
// AccountLink account
|
||||||
|
AccountLink
|
||||||
|
// OrderLink order
|
||||||
|
OrderLink
|
||||||
|
// NewOrderLink new-order
|
||||||
|
NewOrderLink
|
||||||
|
// OrdersByAccountLink list of orders owned by account
|
||||||
|
OrdersByAccountLink
|
||||||
|
// FinalizeLink finalize order
|
||||||
|
FinalizeLink
|
||||||
|
// NewAuthzLink authz
|
||||||
|
NewAuthzLink
|
||||||
|
// AuthzLink new-authz
|
||||||
|
AuthzLink
|
||||||
|
// ChallengeLink challenge
|
||||||
|
ChallengeLink
|
||||||
|
// CertificateLink certificate
|
||||||
|
CertificateLink
|
||||||
|
// DirectoryLink directory
|
||||||
|
DirectoryLink
|
||||||
|
// RevokeCertLink revoke certificate
|
||||||
|
RevokeCertLink
|
||||||
|
// KeyChangeLink key rollover
|
||||||
|
KeyChangeLink
|
||||||
|
)
|
||||||
|
|
||||||
|
func (l Link) String() string {
|
||||||
|
switch l {
|
||||||
|
case NewNonceLink:
|
||||||
|
return "new-nonce"
|
||||||
|
case NewAccountLink:
|
||||||
|
return "new-account"
|
||||||
|
case AccountLink:
|
||||||
|
return "account"
|
||||||
|
case NewOrderLink:
|
||||||
|
return "new-order"
|
||||||
|
case OrderLink:
|
||||||
|
return "order"
|
||||||
|
case NewAuthzLink:
|
||||||
|
return "new-authz"
|
||||||
|
case AuthzLink:
|
||||||
|
return "authz"
|
||||||
|
case ChallengeLink:
|
||||||
|
return "challenge"
|
||||||
|
case CertificateLink:
|
||||||
|
return "certificate"
|
||||||
|
case DirectoryLink:
|
||||||
|
return "directory"
|
||||||
|
case RevokeCertLink:
|
||||||
|
return "revoke-cert"
|
||||||
|
case KeyChangeLink:
|
||||||
|
return "key-change"
|
||||||
|
default:
|
||||||
|
return "unexpected"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLink returns an absolute or partial path to the given resource.
|
||||||
|
func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs ...string) string {
|
||||||
|
var link string
|
||||||
|
switch typ {
|
||||||
|
case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink:
|
||||||
|
link = fmt.Sprintf("/%s/%s", provisionerName, typ.String())
|
||||||
|
case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink:
|
||||||
|
link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0])
|
||||||
|
case OrdersByAccountLink:
|
||||||
|
link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0])
|
||||||
|
case FinalizeLink:
|
||||||
|
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0])
|
||||||
|
}
|
||||||
|
if abs {
|
||||||
|
return fmt.Sprintf("https://%s/%s%s", d.dns, d.prefix, link)
|
||||||
|
}
|
||||||
|
return link
|
||||||
|
}
|
63
acme/directory_test.go
Normal file
63
acme/directory_test.go
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDirectoryGetLink(t *testing.T) {
|
||||||
|
dns := "ca.smallstep.com"
|
||||||
|
prefix := "acme"
|
||||||
|
dir := newDirectory(dns, prefix)
|
||||||
|
id := "1234"
|
||||||
|
|
||||||
|
prov := newProv()
|
||||||
|
provID := URLSafeProvisionerName(prov)
|
||||||
|
|
||||||
|
type newTest struct {
|
||||||
|
actual, expected string
|
||||||
|
}
|
||||||
|
assert.Equals(t, dir.getLink(NewNonceLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", provID))
|
||||||
|
assert.Equals(t, dir.getLink(NewNonceLink, provID, false), fmt.Sprintf("/%s/new-nonce", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(NewAccountLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID))
|
||||||
|
assert.Equals(t, dir.getLink(NewAccountLink, provID, false), fmt.Sprintf("/%s/new-account", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(AccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provID))
|
||||||
|
assert.Equals(t, dir.getLink(AccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(NewOrderLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", provID))
|
||||||
|
assert.Equals(t, dir.getLink(NewOrderLink, provID, false), fmt.Sprintf("/%s/new-order", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(OrderLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234", provID))
|
||||||
|
assert.Equals(t, dir.getLink(OrderLink, provID, false, id), fmt.Sprintf("/%s/order/1234", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234/orders", provID))
|
||||||
|
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234/orders", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(FinalizeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234/finalize", provID))
|
||||||
|
assert.Equals(t, dir.getLink(FinalizeLink, provID, false, id), fmt.Sprintf("/%s/order/1234/finalize", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(NewAuthzLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-authz", provID))
|
||||||
|
assert.Equals(t, dir.getLink(NewAuthzLink, provID, false), fmt.Sprintf("/%s/new-authz", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(AuthzLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/authz/1234", provID))
|
||||||
|
assert.Equals(t, dir.getLink(AuthzLink, provID, false, id), fmt.Sprintf("/%s/authz/1234", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(DirectoryLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/directory", provID))
|
||||||
|
assert.Equals(t, dir.getLink(DirectoryLink, provID, false), fmt.Sprintf("/%s/directory", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(RevokeCertLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", provID))
|
||||||
|
assert.Equals(t, dir.getLink(RevokeCertLink, provID, false), fmt.Sprintf("/%s/revoke-cert", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(KeyChangeLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", provID))
|
||||||
|
assert.Equals(t, dir.getLink(KeyChangeLink, provID, false), fmt.Sprintf("/%s/key-change", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(ChallengeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/1234", provID))
|
||||||
|
assert.Equals(t, dir.getLink(ChallengeLink, provID, false, id), fmt.Sprintf("/%s/challenge/1234", provID))
|
||||||
|
|
||||||
|
assert.Equals(t, dir.getLink(CertificateLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/1234", provID))
|
||||||
|
assert.Equals(t, dir.getLink(CertificateLink, provID, false, id), fmt.Sprintf("/%s/certificate/1234", provID))
|
||||||
|
}
|
439
acme/errors.go
Normal file
439
acme/errors.go
Normal file
|
@ -0,0 +1,439 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AccountDoesNotExistErr returns a new acme error.
|
||||||
|
func AccountDoesNotExistErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: accountDoesNotExistErr,
|
||||||
|
Detail: "Account does not exist",
|
||||||
|
Status: 404,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AlreadyRevokedErr returns a new acme error.
|
||||||
|
func AlreadyRevokedErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: alreadyRevokedErr,
|
||||||
|
Detail: "Certificate already revoked",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BadCSRErr returns a new acme error.
|
||||||
|
func BadCSRErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: badCSRErr,
|
||||||
|
Detail: "The CSR is unacceptable",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BadNonceErr returns a new acme error.
|
||||||
|
func BadNonceErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: badNonceErr,
|
||||||
|
Detail: "Unacceptable anti-replay nonce",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BadPublicKeyErr returns a new acme error.
|
||||||
|
func BadPublicKeyErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: badPublicKeyErr,
|
||||||
|
Detail: "The jws was signed by a public key the server does not support",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BadRevocationReasonErr returns a new acme error.
|
||||||
|
func BadRevocationReasonErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: badRevocationReasonErr,
|
||||||
|
Detail: "The revocation reason provided is not allowed by the server",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BadSignatureAlgorithmErr returns a new acme error.
|
||||||
|
func BadSignatureAlgorithmErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: badSignatureAlgorithmErr,
|
||||||
|
Detail: "The JWS was signed with an algorithm the server does not support",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CaaErr returns a new acme error.
|
||||||
|
func CaaErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: caaErr,
|
||||||
|
Detail: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompoundErr returns a new acme error.
|
||||||
|
func CompoundErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: compoundErr,
|
||||||
|
Detail: "Specific error conditions are indicated in the “subproblems” array",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionErr returns a new acme error.
|
||||||
|
func ConnectionErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: connectionErr,
|
||||||
|
Detail: "The server could not connect to validation target",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNSErr returns a new acme error.
|
||||||
|
func DNSErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: dnsErr,
|
||||||
|
Detail: "There was a problem with a DNS query during identifier validation",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExternalAccountRequiredErr returns a new acme error.
|
||||||
|
func ExternalAccountRequiredErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: externalAccountRequiredErr,
|
||||||
|
Detail: "The request must include a value for the \"externalAccountBinding\" field",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncorrectResponseErr returns a new acme error.
|
||||||
|
func IncorrectResponseErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: incorrectResponseErr,
|
||||||
|
Detail: "Response received didn't match the challenge's requirements",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidContactErr returns a new acme error.
|
||||||
|
func InvalidContactErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: invalidContactErr,
|
||||||
|
Detail: "A contact URL for an account was invalid",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MalformedErr returns a new acme error.
|
||||||
|
func MalformedErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: malformedErr,
|
||||||
|
Detail: "The request message was malformed",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrderNotReadyErr returns a new acme error.
|
||||||
|
func OrderNotReadyErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: orderNotReadyErr,
|
||||||
|
Detail: "The request attempted to finalize an order that is not ready to be finalized",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimitedErr returns a new acme error.
|
||||||
|
func RateLimitedErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: rateLimitedErr,
|
||||||
|
Detail: "The request exceeds a rate limit",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RejectedIdentifierErr returns a new acme error.
|
||||||
|
func RejectedIdentifierErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: rejectedIdentifierErr,
|
||||||
|
Detail: "The server will not issue certificates for the identifier",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerInternalErr returns a new acme error.
|
||||||
|
func ServerInternalErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: serverInternalErr,
|
||||||
|
Detail: "The server experienced an internal error",
|
||||||
|
Status: 500,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSErr returns a new acme error.
|
||||||
|
func TLSErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: tlsErr,
|
||||||
|
Detail: "The server received a TLS error during validation",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnauthorizedErr returns a new acme error.
|
||||||
|
func UnauthorizedErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: unauthorizedErr,
|
||||||
|
Detail: "The client lacks sufficient authorization",
|
||||||
|
Status: 401,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnsupportedContactErr returns a new acme error.
|
||||||
|
func UnsupportedContactErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: unsupportedContactErr,
|
||||||
|
Detail: "A contact URL for an account used an unsupported protocol scheme",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnsupportedIdentifierErr returns a new acme error.
|
||||||
|
func UnsupportedIdentifierErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: unsupportedIdentifierErr,
|
||||||
|
Detail: "An identifier is of an unsupported type",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserActionRequiredErr returns a new acme error.
|
||||||
|
func UserActionRequiredErr(err error) *Error {
|
||||||
|
return &Error{
|
||||||
|
Type: userActionRequiredErr,
|
||||||
|
Detail: "Visit the “instance” URL and take actions specified there",
|
||||||
|
Status: 400,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProbType is the type of the ACME problem.
|
||||||
|
type ProbType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// The request specified an account that does not exist
|
||||||
|
accountDoesNotExistErr ProbType = iota
|
||||||
|
// The request specified a certificate to be revoked that has already been revoked
|
||||||
|
alreadyRevokedErr
|
||||||
|
// The CSR is unacceptable (e.g., due to a short key)
|
||||||
|
badCSRErr
|
||||||
|
// The client sent an unacceptable anti-replay nonce
|
||||||
|
badNonceErr
|
||||||
|
// The JWS was signed by a public key the server does not support
|
||||||
|
badPublicKeyErr
|
||||||
|
// The revocation reason provided is not allowed by the server
|
||||||
|
badRevocationReasonErr
|
||||||
|
// The JWS was signed with an algorithm the server does not support
|
||||||
|
badSignatureAlgorithmErr
|
||||||
|
// Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate
|
||||||
|
caaErr
|
||||||
|
// Specific error conditions are indicated in the “subproblems” array.
|
||||||
|
compoundErr
|
||||||
|
// The server could not connect to validation target
|
||||||
|
connectionErr
|
||||||
|
// There was a problem with a DNS query during identifier validation
|
||||||
|
dnsErr
|
||||||
|
// The request must include a value for the “externalAccountBinding” field
|
||||||
|
externalAccountRequiredErr
|
||||||
|
// Response received didn’t match the challenge’s requirements
|
||||||
|
incorrectResponseErr
|
||||||
|
// A contact URL for an account was invalid
|
||||||
|
invalidContactErr
|
||||||
|
// The request message was malformed
|
||||||
|
malformedErr
|
||||||
|
// The request attempted to finalize an order that is not ready to be finalized
|
||||||
|
orderNotReadyErr
|
||||||
|
// The request exceeds a rate limit
|
||||||
|
rateLimitedErr
|
||||||
|
// The server will not issue certificates for the identifier
|
||||||
|
rejectedIdentifierErr
|
||||||
|
// The server experienced an internal error
|
||||||
|
serverInternalErr
|
||||||
|
// The server received a TLS error during validation
|
||||||
|
tlsErr
|
||||||
|
// The client lacks sufficient authorization
|
||||||
|
unauthorizedErr
|
||||||
|
// A contact URL for an account used an unsupported protocol scheme
|
||||||
|
unsupportedContactErr
|
||||||
|
// An identifier is of an unsupported type
|
||||||
|
unsupportedIdentifierErr
|
||||||
|
// Visit the “instance” URL and take actions specified there
|
||||||
|
userActionRequiredErr
|
||||||
|
)
|
||||||
|
|
||||||
|
// String returns the string representation of the acme problem type,
|
||||||
|
// fulfilling the Stringer interface.
|
||||||
|
func (ap ProbType) String() string {
|
||||||
|
switch ap {
|
||||||
|
case accountDoesNotExistErr:
|
||||||
|
return "accountDoesNotExist"
|
||||||
|
case alreadyRevokedErr:
|
||||||
|
return "alreadyRevoked"
|
||||||
|
case badCSRErr:
|
||||||
|
return "badCSR"
|
||||||
|
case badNonceErr:
|
||||||
|
return "badNonce"
|
||||||
|
case badPublicKeyErr:
|
||||||
|
return "badPublicKey"
|
||||||
|
case badRevocationReasonErr:
|
||||||
|
return "badRevocationReason"
|
||||||
|
case badSignatureAlgorithmErr:
|
||||||
|
return "badSignatureAlgorithm"
|
||||||
|
case caaErr:
|
||||||
|
return "caa"
|
||||||
|
case compoundErr:
|
||||||
|
return "compound"
|
||||||
|
case connectionErr:
|
||||||
|
return "connection"
|
||||||
|
case dnsErr:
|
||||||
|
return "dns"
|
||||||
|
case externalAccountRequiredErr:
|
||||||
|
return "externalAccountRequired"
|
||||||
|
case incorrectResponseErr:
|
||||||
|
return "incorrectResponse"
|
||||||
|
case invalidContactErr:
|
||||||
|
return "invalidContact"
|
||||||
|
case malformedErr:
|
||||||
|
return "malformed"
|
||||||
|
case orderNotReadyErr:
|
||||||
|
return "orderNotReady"
|
||||||
|
case rateLimitedErr:
|
||||||
|
return "rateLimited"
|
||||||
|
case rejectedIdentifierErr:
|
||||||
|
return "rejectedIdentifier"
|
||||||
|
case serverInternalErr:
|
||||||
|
return "serverInternal"
|
||||||
|
case tlsErr:
|
||||||
|
return "tls"
|
||||||
|
case unauthorizedErr:
|
||||||
|
return "unauthorized"
|
||||||
|
case unsupportedContactErr:
|
||||||
|
return "unsupportedContact"
|
||||||
|
case unsupportedIdentifierErr:
|
||||||
|
return "unsupportedIdentifier"
|
||||||
|
case userActionRequiredErr:
|
||||||
|
return "userActionRequired"
|
||||||
|
default:
|
||||||
|
return "unsupported type"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error is an ACME error type complete with problem document.
|
||||||
|
type Error struct {
|
||||||
|
Type ProbType
|
||||||
|
Detail string
|
||||||
|
Err error
|
||||||
|
Status int
|
||||||
|
Sub []*Error
|
||||||
|
Identifier *Identifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap attempts to wrap the internal error.
|
||||||
|
func Wrap(err error, wrap string) *Error {
|
||||||
|
switch e := err.(type) {
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
case *Error:
|
||||||
|
if e.Err == nil {
|
||||||
|
e.Err = errors.New(wrap + "; " + e.Detail)
|
||||||
|
} else {
|
||||||
|
e.Err = errors.Wrap(e.Err, wrap)
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
default:
|
||||||
|
return ServerInternalErr(errors.Wrap(err, wrap))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements the error interface.
|
||||||
|
func (e *Error) Error() string {
|
||||||
|
if e.Err == nil {
|
||||||
|
return e.Detail
|
||||||
|
}
|
||||||
|
return e.Err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cause returns the internal error and implements the Causer interface.
|
||||||
|
func (e *Error) Cause() error {
|
||||||
|
if e.Err == nil {
|
||||||
|
return errors.New(e.Detail)
|
||||||
|
}
|
||||||
|
return e.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToACME returns an acme representation of the problem type.
|
||||||
|
func (e *Error) ToACME() *AError {
|
||||||
|
ae := &AError{
|
||||||
|
Type: "urn:ietf:params:acme:error:" + e.Type.String(),
|
||||||
|
Detail: e.Error(),
|
||||||
|
Status: e.Status,
|
||||||
|
}
|
||||||
|
if e.Identifier != nil {
|
||||||
|
ae.Identifier = *e.Identifier
|
||||||
|
}
|
||||||
|
for _, p := range e.Sub {
|
||||||
|
ae.Subproblems = append(ae.Subproblems, p.ToACME())
|
||||||
|
}
|
||||||
|
return ae
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusCode returns the status code and implements the StatusCode interface.
|
||||||
|
func (e *Error) StatusCode() int {
|
||||||
|
return e.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
// AError is the error type as seen in acme request/responses.
|
||||||
|
type AError struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Detail string `json:"detail"`
|
||||||
|
Identifier interface{} `json:"identifier,omitempty"`
|
||||||
|
Subproblems []interface{} `json:"subproblems,omitempty"`
|
||||||
|
Status int `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error allows AError to implement the error interface.
|
||||||
|
func (ae *AError) Error() string {
|
||||||
|
return ae.Detail
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusCode returns the status code and implements the StatusCode interface.
|
||||||
|
func (ae *AError) StatusCode() int {
|
||||||
|
return ae.Status
|
||||||
|
}
|
73
acme/nonce.go
Normal file
73
acme/nonce.go
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
"github.com/smallstep/nosql/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
// nonce contains nonce metadata used in the ACME protocol.
|
||||||
|
type nonce struct {
|
||||||
|
ID string
|
||||||
|
Created time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// newNonce creates, stores, and returns an ACME replay-nonce.
|
||||||
|
func newNonce(db nosql.DB) (*nonce, error) {
|
||||||
|
_id, err := randID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id := base64.RawURLEncoding.EncodeToString([]byte(_id))
|
||||||
|
n := &nonce{
|
||||||
|
ID: id,
|
||||||
|
Created: clock.Now(),
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(n)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce"))
|
||||||
|
}
|
||||||
|
_, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce"))
|
||||||
|
case !swapped:
|
||||||
|
return nil, ServerInternalErr(errors.New("error storing nonce; " +
|
||||||
|
"value has changed since last read"))
|
||||||
|
default:
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// useNonce verifies that the nonce is valid (by checking if it exists),
|
||||||
|
// and if so, consumes the nonce resource by deleting it from the database.
|
||||||
|
func useNonce(db nosql.DB, nonce string) error {
|
||||||
|
err := db.Update(&database.Tx{
|
||||||
|
Operations: []*database.TxEntry{
|
||||||
|
{
|
||||||
|
Bucket: nonceTable,
|
||||||
|
Key: []byte(nonce),
|
||||||
|
Cmd: database.Get,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Bucket: nonceTable,
|
||||||
|
Key: []byte(nonce),
|
||||||
|
Cmd: database.Delete,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case nosql.IsErrNotFound(err):
|
||||||
|
return BadNonceErr(nil)
|
||||||
|
case err != nil:
|
||||||
|
return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
163
acme/nonce_test.go
Normal file
163
acme/nonce_test.go
Normal file
|
@ -0,0 +1,163 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
"github.com/smallstep/nosql/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewNonce(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err *Error
|
||||||
|
id *string
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, nonceTable)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.Errorf("error storing nonce: force")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/cmpAndSwap-false": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, nonceTable)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
return nil, false, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
var _id string
|
||||||
|
id := &_id
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, nonceTable)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
*id = string(key)
|
||||||
|
return nil, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
id: id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if n, err := newNonce(tc.db); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, n.ID, *tc.id)
|
||||||
|
|
||||||
|
assert.True(t, n.Created.Before(time.Now().Add(time.Minute)))
|
||||||
|
assert.True(t, n.Created.After(time.Now().Add(-time.Minute)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUseNonce(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
id string
|
||||||
|
db nosql.DB
|
||||||
|
err *Error
|
||||||
|
}
|
||||||
|
tests := map[string]func(t *testing.T) test{
|
||||||
|
"fail/update-not-found": func(t *testing.T) test {
|
||||||
|
id := "foo"
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||||
|
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||||
|
|
||||||
|
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||||
|
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||||
|
return database.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
id: id,
|
||||||
|
err: BadNonceErr(nil),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/update-error": func(t *testing.T) test {
|
||||||
|
id := "foo"
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||||
|
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||||
|
|
||||||
|
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||||
|
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||||
|
return errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
id: id,
|
||||||
|
err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
id := "foo"
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MUpdate: func(tx *database.Tx) error {
|
||||||
|
assert.Equals(t, tx.Operations[0].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[0].Key, []byte(id))
|
||||||
|
assert.Equals(t, tx.Operations[0].Cmd, database.Get)
|
||||||
|
|
||||||
|
assert.Equals(t, tx.Operations[1].Bucket, nonceTable)
|
||||||
|
assert.Equals(t, tx.Operations[1].Key, []byte(id))
|
||||||
|
assert.Equals(t, tx.Operations[1].Cmd, database.Delete)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
id: id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := run(t)
|
||||||
|
if err := useNonce(tc.db, tc.id); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
ae, ok := err.(*Error)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||||||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
342
acme/order.go
Normal file
342
acme/order.go
Normal file
|
@ -0,0 +1,342 @@
|
||||||
|
package acme
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
)
|
||||||
|
|
||||||
|
var defaultOrderExpiry = time.Hour * 24
|
||||||
|
|
||||||
|
// Order contains order metadata for the ACME protocol order type.
|
||||||
|
type Order struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Expires string `json:"expires,omitempty"`
|
||||||
|
Identifiers []Identifier `json:"identifiers"`
|
||||||
|
NotBefore string `json:"notBefore,omitempty"`
|
||||||
|
NotAfter string `json:"notAfter,omitempty"`
|
||||||
|
Error interface{} `json:"error,omitempty"`
|
||||||
|
Authorizations []string `json:"authorizations"`
|
||||||
|
Finalize string `json:"finalize"`
|
||||||
|
Certificate string `json:"certificate,omitempty"`
|
||||||
|
ID string `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToLog enables response logging.
|
||||||
|
func (o *Order) ToLog() (interface{}, error) {
|
||||||
|
b, err := json.Marshal(o)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging"))
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID returns the Order ID.
|
||||||
|
func (o *Order) GetID() string {
|
||||||
|
return o.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrderOptions options with which to create a new Order.
|
||||||
|
type OrderOptions struct {
|
||||||
|
AccountID string `json:"accID"`
|
||||||
|
Identifiers []Identifier `json:"identifiers"`
|
||||||
|
NotBefore time.Time `json:"notBefore"`
|
||||||
|
NotAfter time.Time `json:"notAfter"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type order struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
AccountID string `json:"accountID"`
|
||||||
|
Created time.Time `json:"created"`
|
||||||
|
Expires time.Time `json:"expires,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Identifiers []Identifier `json:"identifiers"`
|
||||||
|
NotBefore time.Time `json:"notBefore,omitempty"`
|
||||||
|
NotAfter time.Time `json:"notAfter,omitempty"`
|
||||||
|
Error *Error `json:"error,omitempty"`
|
||||||
|
Authorizations []string `json:"authorizations"`
|
||||||
|
Certificate string `json:"certificate,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// newOrder returns a new Order type.
|
||||||
|
func newOrder(db nosql.DB, ops OrderOptions) (*order, error) {
|
||||||
|
id, err := randID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
authzs := make([]string, len(ops.Identifiers))
|
||||||
|
for i, identifier := range ops.Identifiers {
|
||||||
|
authz, err := newAuthz(db, ops.AccountID, identifier)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
authzs[i] = authz.getID()
|
||||||
|
}
|
||||||
|
|
||||||
|
now := clock.Now()
|
||||||
|
o := &order{
|
||||||
|
ID: id,
|
||||||
|
AccountID: ops.AccountID,
|
||||||
|
Created: now,
|
||||||
|
Status: StatusPending,
|
||||||
|
Expires: now.Add(defaultOrderExpiry),
|
||||||
|
Identifiers: ops.Identifiers,
|
||||||
|
NotBefore: ops.NotBefore,
|
||||||
|
NotAfter: ops.NotAfter,
|
||||||
|
Authorizations: authzs,
|
||||||
|
}
|
||||||
|
if err := o.save(db, nil); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the "order IDs by account ID" index //
|
||||||
|
oids, err := getOrderIDsByAccount(db, ops.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newOids := append(oids, o.ID)
|
||||||
|
if err = orderIDs(newOids).save(db, oids, o.AccountID); err != nil {
|
||||||
|
db.Del(orderTable, []byte(o.ID))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return o, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type orderIDs []string
|
||||||
|
|
||||||
|
func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
oldb []byte
|
||||||
|
)
|
||||||
|
if len(old) == 0 {
|
||||||
|
oldb = nil
|
||||||
|
} else {
|
||||||
|
oldb, err = json.Marshal(old)
|
||||||
|
if err != nil {
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newb, err := json.Marshal(oids)
|
||||||
|
if err != nil {
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice"))
|
||||||
|
}
|
||||||
|
_, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID))
|
||||||
|
case !swapped:
|
||||||
|
return ServerInternalErr(errors.Errorf("error storing order IDs "+
|
||||||
|
"for account %s; order IDs changed since last read", accID))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *order) save(db nosql.DB, old *order) error {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
oldB []byte
|
||||||
|
)
|
||||||
|
if old == nil {
|
||||||
|
oldB = nil
|
||||||
|
} else {
|
||||||
|
if oldB, err = json.Marshal(old); err != nil {
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newB, err := json.Marshal(o)
|
||||||
|
if err != nil {
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order"))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return ServerInternalErr(errors.Wrap(err, "error storing order"))
|
||||||
|
case !swapped:
|
||||||
|
return ServerInternalErr(errors.New("error storing order; " +
|
||||||
|
"value has changed since last read"))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateStatus updates order status if necessary.
|
||||||
|
func (o *order) updateStatus(db nosql.DB) (*order, error) {
|
||||||
|
_newOrder := *o
|
||||||
|
newOrder := &_newOrder
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
switch o.Status {
|
||||||
|
case StatusInvalid:
|
||||||
|
return o, nil
|
||||||
|
case StatusValid:
|
||||||
|
return o, nil
|
||||||
|
case StatusReady:
|
||||||
|
// check expiry
|
||||||
|
if now.After(o.Expires) {
|
||||||
|
newOrder.Status = StatusInvalid
|
||||||
|
newOrder.Error = MalformedErr(errors.New("order has expired"))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return o, nil
|
||||||
|
case StatusPending:
|
||||||
|
// check expiry
|
||||||
|
if now.After(o.Expires) {
|
||||||
|
newOrder.Status = StatusInvalid
|
||||||
|
newOrder.Error = MalformedErr(errors.New("order has expired"))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var count = map[string]int{
|
||||||
|
StatusValid: 0,
|
||||||
|
StatusInvalid: 0,
|
||||||
|
StatusPending: 0,
|
||||||
|
}
|
||||||
|
for _, azID := range o.Authorizations {
|
||||||
|
authz, err := getAuthz(db, azID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if authz, err = authz.updateStatus(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
st := authz.getStatus()
|
||||||
|
count[st]++
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case count[StatusInvalid] > 0:
|
||||||
|
newOrder.Status = StatusInvalid
|
||||||
|
case count[StatusPending] > 0:
|
||||||
|
break
|
||||||
|
case count[StatusValid] == len(o.Authorizations):
|
||||||
|
newOrder.Status = StatusReady
|
||||||
|
default:
|
||||||
|
return nil, ServerInternalErr(errors.New("unexpected authz status"))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := newOrder.save(db, o); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newOrder, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// finalize signs a certificate if the necessary conditions for Order completion
|
||||||
|
// have been met.
|
||||||
|
func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p provisioner.Interface) (*order, error) {
|
||||||
|
var err error
|
||||||
|
if o, err = o.updateStatus(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch o.Status {
|
||||||
|
case StatusInvalid:
|
||||||
|
return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID))
|
||||||
|
case StatusValid:
|
||||||
|
return o, nil
|
||||||
|
case StatusPending:
|
||||||
|
return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID))
|
||||||
|
case StatusReady:
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate identifier names against CSR alternative names //
|
||||||
|
csrNames := make(map[string]int)
|
||||||
|
for _, n := range csr.DNSNames {
|
||||||
|
csrNames[n] = 1
|
||||||
|
}
|
||||||
|
orderNames := make(map[string]int)
|
||||||
|
for _, n := range o.Identifiers {
|
||||||
|
orderNames[n.Value] = 1
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(csrNames, orderNames) {
|
||||||
|
return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get authorizations from the ACME provisioner.
|
||||||
|
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||||
|
signOps, err := p.AuthorizeSign(ctx, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and store a new certificate.
|
||||||
|
leaf, inter, err := auth.Sign(csr, provisioner.Options{
|
||||||
|
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
|
||||||
|
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
|
||||||
|
}, signOps...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID))
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := newCert(db, CertOptions{
|
||||||
|
AccountID: o.AccountID,
|
||||||
|
OrderID: o.ID,
|
||||||
|
Leaf: leaf,
|
||||||
|
Intermediates: []*x509.Certificate{inter},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_newOrder := *o
|
||||||
|
newOrder := &_newOrder
|
||||||
|
newOrder.Certificate = cert.ID
|
||||||
|
newOrder.Status = StatusValid
|
||||||
|
if err := newOrder.save(db, o); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newOrder, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOrder retrieves and unmarshals an ACME Order type from the database.
|
||||||
|
func getOrder(db nosql.DB, id string) (*order, error) {
|
||||||
|
b, err := db.Get(orderTable, []byte(id))
|
||||||
|
if nosql.IsErrNotFound(err) {
|
||||||
|
return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id))
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id))
|
||||||
|
}
|
||||||
|
var o order
|
||||||
|
if err := json.Unmarshal(b, &o); err != nil {
|
||||||
|
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order"))
|
||||||
|
}
|
||||||
|
return &o, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// toACME converts the internal Order type into the public acmeOrder type for
|
||||||
|
// presentation in the ACME protocol.
|
||||||
|
func (o *order) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Order, error) {
|
||||||
|
azs := make([]string, len(o.Authorizations))
|
||||||
|
for i, aid := range o.Authorizations {
|
||||||
|
azs[i] = dir.getLink(AuthzLink, URLSafeProvisionerName(p), true, aid)
|
||||||
|
}
|
||||||
|
ao := &Order{
|
||||||
|
Status: o.Status,
|
||||||
|
Expires: o.Expires.Format(time.RFC3339),
|
||||||
|
Identifiers: o.Identifiers,
|
||||||
|
NotBefore: o.NotBefore.Format(time.RFC3339),
|
||||||
|
NotAfter: o.NotAfter.Format(time.RFC3339),
|
||||||
|
Authorizations: azs,
|
||||||
|
Finalize: dir.getLink(FinalizeLink, URLSafeProvisionerName(p), true, o.ID),
|
||||||
|
ID: o.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.Certificate != "" {
|
||||||
|
ao.Certificate = dir.getLink(CertificateLink, URLSafeProvisionerName(p), true, o.Certificate)
|
||||||
|
}
|
||||||
|
return ao, nil
|
||||||
|
}
|
1129
acme/order_test.go
Normal file
1129
acme/order_test.go
Normal file
File diff suppressed because it is too large
Load diff
24
api/api.go
24
api/api.go
|
@ -28,8 +28,7 @@ import (
|
||||||
// Authority is the interface implemented by a CA authority.
|
// Authority is the interface implemented by a CA authority.
|
||||||
type Authority interface {
|
type Authority interface {
|
||||||
SSHAuthority
|
SSHAuthority
|
||||||
// NOTE: Authorize will be deprecated in future releases. Please use the
|
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
||||||
// context specific Authorize[Sign|Revoke|etc.] methods.
|
|
||||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||||
GetTLSOptions() *tlsutil.TLSOptions
|
GetTLSOptions() *tlsutil.TLSOptions
|
||||||
|
@ -37,6 +36,7 @@ type Authority interface {
|
||||||
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error)
|
||||||
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||||
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
|
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
|
||||||
|
LoadProvisionerByID(string) (provisioner.Interface, error)
|
||||||
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
|
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
|
||||||
Revoke(*authority.RevokeOptions) error
|
Revoke(*authority.RevokeOptions) error
|
||||||
GetEncryptedKey(kid string) (string, error)
|
GetEncryptedKey(kid string) (string, error)
|
||||||
|
@ -308,13 +308,12 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
|
||||||
logCertificate(w, cert)
|
logCertificate(w, cert)
|
||||||
JSON(w, &SignResponse{
|
JSONStatus(w, &SignResponse{
|
||||||
ServerPEM: Certificate{cert},
|
ServerPEM: Certificate{cert},
|
||||||
CaPEM: Certificate{root},
|
CaPEM: Certificate{root},
|
||||||
TLSOptions: h.Authority.GetTLSOptions(),
|
TLSOptions: h.Authority.GetTLSOptions(),
|
||||||
})
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Renew uses the information of certificate in the TLS connection to create a
|
// Renew uses the information of certificate in the TLS connection to create a
|
||||||
|
@ -331,13 +330,12 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
|
||||||
logCertificate(w, cert)
|
logCertificate(w, cert)
|
||||||
JSON(w, &SignResponse{
|
JSONStatus(w, &SignResponse{
|
||||||
ServerPEM: Certificate{cert},
|
ServerPEM: Certificate{cert},
|
||||||
CaPEM: Certificate{root},
|
CaPEM: Certificate{root},
|
||||||
TLSOptions: h.Authority.GetTLSOptions(),
|
TLSOptions: h.Authority.GetTLSOptions(),
|
||||||
})
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Provisioners returns the list of provisioners configured in the authority.
|
// Provisioners returns the list of provisioners configured in the authority.
|
||||||
|
@ -383,10 +381,9 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||||
certs[i] = Certificate{roots[i]}
|
certs[i] = Certificate{roots[i]}
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
JSONStatus(w, &RootsResponse{
|
||||||
JSON(w, &RootsResponse{
|
|
||||||
Certificates: certs,
|
Certificates: certs,
|
||||||
})
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Federation returns all the public certificates in the federation.
|
// Federation returns all the public certificates in the federation.
|
||||||
|
@ -402,10 +399,9 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||||
certs[i] = Certificate{federated[i]}
|
certs[i] = Certificate{federated[i]}
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
JSONStatus(w, &FederationResponse{
|
||||||
JSON(w, &FederationResponse{
|
|
||||||
Certificates: certs,
|
Certificates: certs,
|
||||||
})
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
var oidStepProvisioner = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}
|
var oidStepProvisioner = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}
|
||||||
|
|
|
@ -506,6 +506,7 @@ type mockAuthority struct {
|
||||||
signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||||
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||||
|
loadProvisionerByID func(provID string) (provisioner.Interface, error)
|
||||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||||
revoke func(*authority.RevokeOptions) error
|
revoke func(*authority.RevokeOptions) error
|
||||||
getEncryptedKey func(kid string) (string, error)
|
getEncryptedKey func(kid string) (string, error)
|
||||||
|
@ -581,6 +582,13 @@ func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (pr
|
||||||
return m.ret1.(provisioner.Interface), m.err
|
return m.ret1.(provisioner.Interface), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) {
|
||||||
|
if m.loadProvisionerByID != nil {
|
||||||
|
return m.loadProvisionerByID(provID)
|
||||||
|
}
|
||||||
|
return m.ret1.(provisioner.Interface), m.err
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) Revoke(opts *authority.RevokeOptions) error {
|
func (m *mockAuthority) Revoke(opts *authority.RevokeOptions) error {
|
||||||
if m.revoke != nil {
|
if m.revoke != nil {
|
||||||
return m.revoke(opts)
|
return m.revoke(opts)
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -109,7 +110,13 @@ func NotFound(err error) error {
|
||||||
|
|
||||||
// WriteError writes to w a JSON representation of the given error.
|
// WriteError writes to w a JSON representation of the given error.
|
||||||
func WriteError(w http.ResponseWriter, err error) {
|
func WriteError(w http.ResponseWriter, err error) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
w.Header().Set("Content-Type", "application/problem+json")
|
||||||
|
err = k.ToACME()
|
||||||
|
default:
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
cause := errors.Cause(err)
|
cause := errors.Cause(err)
|
||||||
if sc, ok := err.(StatusCoder); ok {
|
if sc, ok := err.(StatusCoder); ok {
|
||||||
w.WriteHeader(sc.StatusCode())
|
w.WriteHeader(sc.StatusCode())
|
||||||
|
|
|
@ -87,8 +87,6 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
logRevoke(w, opts)
|
logRevoke(w, opts)
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
JSON(w, &RevokeResponse{Status: "ok"})
|
JSON(w, &RevokeResponse{Status: "ok"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
33
api/utils.go
33
api/utils.go
|
@ -10,6 +10,11 @@ import (
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// EnableLogger is an interface that enables response logging for an object.
|
||||||
|
type EnableLogger interface {
|
||||||
|
ToLog() (interface{}, error)
|
||||||
|
}
|
||||||
|
|
||||||
// LogError adds to the response writer the given error if it implements
|
// LogError adds to the response writer the given error if it implements
|
||||||
// logging.ResponseLogger. If it does not implement it, then writes the error
|
// logging.ResponseLogger. If it does not implement it, then writes the error
|
||||||
// using the log package.
|
// using the log package.
|
||||||
|
@ -23,12 +28,40 @@ func LogError(rw http.ResponseWriter, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogEnabledResponse log the response object if it implements the EnableLogger
|
||||||
|
// interface.
|
||||||
|
func LogEnabledResponse(rw http.ResponseWriter, v interface{}) {
|
||||||
|
if el, ok := v.(EnableLogger); ok {
|
||||||
|
out, err := el.ToLog()
|
||||||
|
if err != nil {
|
||||||
|
LogError(rw, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||||
|
rl.WithFields(map[string]interface{}{
|
||||||
|
"response": out,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
log.Println(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// JSON writes the passed value into the http.ResponseWriter.
|
// JSON writes the passed value into the http.ResponseWriter.
|
||||||
func JSON(w http.ResponseWriter, v interface{}) {
|
func JSON(w http.ResponseWriter, v interface{}) {
|
||||||
|
JSONStatus(w, v, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSONStatus writes the given value into the http.ResponseWriter and the
|
||||||
|
// given status is written as the status code of the response.
|
||||||
|
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||||
LogError(w, err)
|
LogError(w, err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
LogEnabledResponse(w, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadJSON reads JSON from the request body and stores it in the value
|
// ReadJSON reads JSON from the request body and stores it in the value
|
||||||
|
|
|
@ -15,7 +15,9 @@ import (
|
||||||
"github.com/smallstep/cli/crypto/x509util"
|
"github.com/smallstep/cli/crypto/x509util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const legacyAuthority = "step-certificate-authority"
|
const (
|
||||||
|
legacyAuthority = "step-certificate-authority"
|
||||||
|
)
|
||||||
|
|
||||||
// Authority implements the Certificate Authority internal interface.
|
// Authority implements the Certificate Authority internal interface.
|
||||||
type Authority struct {
|
type Authority struct {
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -33,6 +35,12 @@ func (e *apiError) Error() string {
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrorResponse represents an error in JSON format.
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Status int `json:"status"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
// StatusCode returns an http status code indicating the type and severity of
|
// StatusCode returns an http status code indicating the type and severity of
|
||||||
// the error.
|
// the error.
|
||||||
func (e *apiError) StatusCode() int {
|
func (e *apiError) StatusCode() int {
|
||||||
|
@ -41,3 +49,19 @@ func (e *apiError) StatusCode() int {
|
||||||
}
|
}
|
||||||
return e.code
|
return e.code
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaller interface for the Error struct.
|
||||||
|
func (e *apiError) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(&ErrorResponse{Status: e.code, Message: http.StatusText(e.code)})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
|
||||||
|
func (e *apiError) UnmarshalJSON(data []byte) error {
|
||||||
|
var er ErrorResponse
|
||||||
|
if err := json.Unmarshal(data, &er); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
e.code = er.Status
|
||||||
|
e.err = fmt.Errorf(er.Message)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
85
authority/provisioner/acme.go
Normal file
85
authority/provisioner/acme.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ACME is the acme provisioner type, an entity that can authorize the ACME
|
||||||
|
// provisioning flow.
|
||||||
|
type ACME struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
|
claimer *Claimer
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID returns the provisioner unique identifier.
|
||||||
|
func (p ACME) GetID() string {
|
||||||
|
return "acme/" + p.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenID returns the identifier of the token.
|
||||||
|
func (p *ACME) GetTokenID(ott string) (string, error) {
|
||||||
|
return "", errors.New("acme provisioner does not implement GetTokenID")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetName returns the name of the provisioner.
|
||||||
|
func (p *ACME) GetName() string {
|
||||||
|
return p.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType returns the type of provisioner.
|
||||||
|
func (p *ACME) GetType() Type {
|
||||||
|
return TypeACME
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEncryptedKey returns the base provisioner encrypted key if it's defined.
|
||||||
|
func (p *ACME) GetEncryptedKey() (string, string, bool) {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init initializes and validates the fields of a JWK type.
|
||||||
|
func (p *ACME) Init(config Config) (err error) {
|
||||||
|
switch {
|
||||||
|
case p.Type == "":
|
||||||
|
return errors.New("provisioner type cannot be empty")
|
||||||
|
case p.Name == "":
|
||||||
|
return errors.New("provisioner name cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update claims with global ones
|
||||||
|
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeRevoke is not implemented yet for the ACME provisioner.
|
||||||
|
func (p *ACME) AuthorizeRevoke(token string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeSign validates the given token.
|
||||||
|
func (p *ACME) AuthorizeSign(ctx context.Context, _ string) ([]SignOption, error) {
|
||||||
|
if m := MethodFromContext(ctx); m != SignMethod {
|
||||||
|
return nil, errors.Errorf("unexpected method type %d in context", m)
|
||||||
|
}
|
||||||
|
return []SignOption{
|
||||||
|
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||||
|
newProvisionerExtensionOption(TypeACME, p.Name, ""),
|
||||||
|
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||||
|
defaultPublicKeyValidator{},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeRenewal is not implemented for the ACME provisioner.
|
||||||
|
func (p *ACME) AuthorizeRenewal(cert *x509.Certificate) error {
|
||||||
|
if p.claimer.IsDisableRenewal() {
|
||||||
|
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
184
authority/provisioner/acme_test.go
Normal file
184
authority/provisioner/acme_test.go
Normal file
|
@ -0,0 +1,184 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestACME_Getters(t *testing.T) {
|
||||||
|
p, err := generateACME()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
id := "acme/" + p.Name
|
||||||
|
if got := p.GetID(); got != id {
|
||||||
|
t.Errorf("ACME.GetID() = %v, want %v", got, id)
|
||||||
|
}
|
||||||
|
if got := p.GetName(); got != p.Name {
|
||||||
|
t.Errorf("ACME.GetName() = %v, want %v", got, p.Name)
|
||||||
|
}
|
||||||
|
if got := p.GetType(); got != TypeACME {
|
||||||
|
t.Errorf("ACME.GetType() = %v, want %v", got, TypeACME)
|
||||||
|
}
|
||||||
|
kid, key, ok := p.GetEncryptedKey()
|
||||||
|
if kid != "" || key != "" || ok == true {
|
||||||
|
t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||||
|
kid, key, ok, "", "", false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestACME_Init(t *testing.T) {
|
||||||
|
type ProvisionerValidateTest struct {
|
||||||
|
p *ACME
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
tests := map[string]func(*testing.T) ProvisionerValidateTest{
|
||||||
|
"fail-empty": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
return ProvisionerValidateTest{
|
||||||
|
p: &ACME{},
|
||||||
|
err: errors.New("provisioner type cannot be empty"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail-empty-name": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
return ProvisionerValidateTest{
|
||||||
|
p: &ACME{
|
||||||
|
Type: "ACME",
|
||||||
|
},
|
||||||
|
err: errors.New("provisioner name cannot be empty"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail-empty-type": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
return ProvisionerValidateTest{
|
||||||
|
p: &ACME{Name: "foo"},
|
||||||
|
err: errors.New("provisioner type cannot be empty"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
return ProvisionerValidateTest{
|
||||||
|
p: &ACME{Name: "foo", Type: "bar", Claims: &Claims{DefaultTLSDur: &Duration{0}}},
|
||||||
|
err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
return ProvisionerValidateTest{
|
||||||
|
p: &ACME{Name: "foo", Type: "bar"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config := Config{
|
||||||
|
Claims: globalProvisionerClaims,
|
||||||
|
Audiences: testAudiences,
|
||||||
|
}
|
||||||
|
for name, get := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
tc := get(t)
|
||||||
|
err := tc.p.Init(config)
|
||||||
|
if err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.err.Error(), err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestACME_AuthorizeRevoke(t *testing.T) {
|
||||||
|
p, err := generateACME()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Nil(t, p.AuthorizeRevoke(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestACME_AuthorizeRenewal(t *testing.T) {
|
||||||
|
p1, err := generateACME()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p2, err := generateACME()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
// disable renewal
|
||||||
|
disable := true
|
||||||
|
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||||
|
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *ACME
|
||||||
|
args args
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{"ok", p1, args{nil}, nil},
|
||||||
|
{"fail", p2, args{nil}, errors.Errorf("renew is disabled for provisioner %s", p2.GetID())},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.prov.AuthorizeRenewal(tt.args.cert); err != nil {
|
||||||
|
if assert.NotNil(t, tt.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tt.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestACME_AuthorizeSign(t *testing.T) {
|
||||||
|
p1, err := generateACME()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prov *ACME
|
||||||
|
method Method
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{"fail/method", p1, SignSSHMethod, errors.New("unexpected method type 1 in context")},
|
||||||
|
{"ok", p1, SignMethod, nil},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContextWithMethod(context.Background(), tt.method)
|
||||||
|
if got, err := tt.prov.AuthorizeSign(ctx, ""); err != nil {
|
||||||
|
if assert.NotNil(t, tt.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.NotNil(t, got) {
|
||||||
|
assert.Len(t, 4, got)
|
||||||
|
|
||||||
|
_pdd := got[0]
|
||||||
|
pdd, ok := _pdd.(profileDefaultDuration)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equals(t, pdd, profileDefaultDuration(86400000000000))
|
||||||
|
|
||||||
|
_peo := got[1]
|
||||||
|
peo, ok := _peo.(*provisionerExtensionOption)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equals(t, peo.Type, 6)
|
||||||
|
assert.Equals(t, peo.Name, "test@acme-provisioner.com")
|
||||||
|
assert.Equals(t, peo.CredentialID, "")
|
||||||
|
assert.Equals(t, peo.KeyValuePairs, nil)
|
||||||
|
|
||||||
|
_vv := got[2]
|
||||||
|
vv, ok := _vv.(*validityValidator)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equals(t, vv.min, time.Duration(300000000000))
|
||||||
|
assert.Equals(t, vv.max, time.Duration(86400000000000))
|
||||||
|
|
||||||
|
_dpkv := got[3]
|
||||||
|
_, ok = _dpkv.(defaultPublicKeyValidator)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -127,6 +127,8 @@ func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool)
|
||||||
return c.Load("aws/" + string(provisioner.Name))
|
return c.Load("aws/" + string(provisioner.Name))
|
||||||
case TypeGCP:
|
case TypeGCP:
|
||||||
return c.Load("gcp/" + string(provisioner.Name))
|
return c.Load("gcp/" + string(provisioner.Name))
|
||||||
|
case TypeACME:
|
||||||
|
return c.Load("acme/" + string(provisioner.Name))
|
||||||
default:
|
default:
|
||||||
return c.Load(string(provisioner.CredentialID))
|
return c.Load(string(provisioner.CredentialID))
|
||||||
}
|
}
|
||||||
|
@ -152,8 +154,9 @@ func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
|
||||||
// Store adds a provisioner to the collection and enforces the uniqueness of
|
// Store adds a provisioner to the collection and enforces the uniqueness of
|
||||||
// provisioner IDs.
|
// provisioner IDs.
|
||||||
func (c *Collection) Store(p Interface) error {
|
func (c *Collection) Store(p Interface) error {
|
||||||
|
fmt.Printf("p.GetID() = %+v\n", p.GetID())
|
||||||
// Store provisioner always in byID. ID must be unique.
|
// Store provisioner always in byID. ID must be unique.
|
||||||
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true {
|
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded {
|
||||||
return errors.New("cannot add multiple provisioners with the same id")
|
return errors.New("cannot add multiple provisioners with the same id")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -133,15 +133,20 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p2, err := generateOIDC()
|
p2, err := generateOIDC()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
p3, err := generateACME()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
byID := new(sync.Map)
|
byID := new(sync.Map)
|
||||||
byID.Store(p1.GetID(), p1)
|
byID.Store(p1.GetID(), p1)
|
||||||
byID.Store(p2.GetID(), p2)
|
byID.Store(p2.GetID(), p2)
|
||||||
|
byID.Store(p3.GetID(), p3)
|
||||||
|
|
||||||
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
|
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
|
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "")
|
||||||
|
assert.FatalError(t, err)
|
||||||
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
|
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
@ -151,6 +156,9 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
||||||
ok2Cert := &x509.Certificate{
|
ok2Cert := &x509.Certificate{
|
||||||
Extensions: []pkix.Extension{ok2Ext},
|
Extensions: []pkix.Extension{ok2Ext},
|
||||||
}
|
}
|
||||||
|
ok3Cert := &x509.Certificate{
|
||||||
|
Extensions: []pkix.Extension{ok3Ext},
|
||||||
|
}
|
||||||
notFoundCert := &x509.Certificate{
|
notFoundCert := &x509.Certificate{
|
||||||
Extensions: []pkix.Extension{notFoundExt},
|
Extensions: []pkix.Extension{notFoundExt},
|
||||||
}
|
}
|
||||||
|
@ -176,6 +184,7 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
|
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
|
||||||
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
|
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
|
||||||
|
{"ok3", fields{byID, testAudiences}, args{ok3Cert}, p3, true},
|
||||||
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
|
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
|
||||||
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
|
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
|
||||||
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},
|
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},
|
||||||
|
|
|
@ -84,6 +84,8 @@ const (
|
||||||
TypeAWS Type = 4
|
TypeAWS Type = 4
|
||||||
// TypeAzure is used to indicate the Azure provisioners.
|
// TypeAzure is used to indicate the Azure provisioners.
|
||||||
TypeAzure Type = 5
|
TypeAzure Type = 5
|
||||||
|
// TypeACME is used to indicate the ACME provisioners.
|
||||||
|
TypeACME Type = 6
|
||||||
|
|
||||||
// RevokeAudienceKey is the key for the 'revoke' audiences in the audiences map.
|
// RevokeAudienceKey is the key for the 'revoke' audiences in the audiences map.
|
||||||
RevokeAudienceKey = "revoke"
|
RevokeAudienceKey = "revoke"
|
||||||
|
@ -104,6 +106,8 @@ func (t Type) String() string {
|
||||||
return "AWS"
|
return "AWS"
|
||||||
case TypeAzure:
|
case TypeAzure:
|
||||||
return "Azure"
|
return "Azure"
|
||||||
|
case TypeACME:
|
||||||
|
return "ACME"
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -151,6 +155,8 @@ func (l *List) UnmarshalJSON(data []byte) error {
|
||||||
p = &AWS{}
|
p = &AWS{}
|
||||||
case "azure":
|
case "azure":
|
||||||
p = &Azure{}
|
p = &Azure{}
|
||||||
|
case "acme":
|
||||||
|
p = &ACME{}
|
||||||
default:
|
default:
|
||||||
// Skip unsupported provisioners. A client using this method may be
|
// Skip unsupported provisioners. A client using this method may be
|
||||||
// compiled with a version of smallstep/certificates that does not
|
// compiled with a version of smallstep/certificates that does not
|
||||||
|
@ -197,3 +203,93 @@ func SanitizeSSHUserPrincipal(email string) string {
|
||||||
}
|
}
|
||||||
}, strings.ToLower(email))
|
}, strings.ToLower(email))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MockProvisioner for testing
|
||||||
|
type MockProvisioner struct {
|
||||||
|
Mret1, Mret2, Mret3 interface{}
|
||||||
|
Merr error
|
||||||
|
MgetID func() string
|
||||||
|
MgetTokenID func(string) (string, error)
|
||||||
|
MgetName func() string
|
||||||
|
MgetType func() Type
|
||||||
|
MgetEncryptedKey func() (string, string, bool)
|
||||||
|
Minit func(Config) error
|
||||||
|
MauthorizeRevoke func(ott string) error
|
||||||
|
MauthorizeSign func(ctx context.Context, ott string) ([]SignOption, error)
|
||||||
|
MauthorizeRenewal func(*x509.Certificate) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID mock
|
||||||
|
func (m *MockProvisioner) GetID() string {
|
||||||
|
if m.MgetID != nil {
|
||||||
|
return m.MgetID()
|
||||||
|
}
|
||||||
|
return m.Mret1.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenID mock
|
||||||
|
func (m *MockProvisioner) GetTokenID(token string) (string, error) {
|
||||||
|
if m.MgetTokenID != nil {
|
||||||
|
return m.MgetTokenID(token)
|
||||||
|
}
|
||||||
|
if m.Mret1 == nil {
|
||||||
|
return "", m.Merr
|
||||||
|
}
|
||||||
|
return m.Mret1.(string), m.Merr
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetName mock
|
||||||
|
func (m *MockProvisioner) GetName() string {
|
||||||
|
if m.MgetName != nil {
|
||||||
|
return m.MgetName()
|
||||||
|
}
|
||||||
|
return m.Mret1.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType mock
|
||||||
|
func (m *MockProvisioner) GetType() Type {
|
||||||
|
if m.MgetType != nil {
|
||||||
|
return m.MgetType()
|
||||||
|
}
|
||||||
|
return m.Mret1.(Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEncryptedKey mock
|
||||||
|
func (m *MockProvisioner) GetEncryptedKey() (string, string, bool) {
|
||||||
|
if m.MgetEncryptedKey != nil {
|
||||||
|
return m.MgetEncryptedKey()
|
||||||
|
}
|
||||||
|
return m.Mret1.(string), m.Mret2.(string), m.Mret3.(bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init mock
|
||||||
|
func (m *MockProvisioner) Init(c Config) error {
|
||||||
|
if m.Minit != nil {
|
||||||
|
return m.Minit(c)
|
||||||
|
}
|
||||||
|
return m.Merr
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeRevoke mock
|
||||||
|
func (m *MockProvisioner) AuthorizeRevoke(ott string) error {
|
||||||
|
if m.MauthorizeRevoke != nil {
|
||||||
|
return m.MauthorizeRevoke(ott)
|
||||||
|
}
|
||||||
|
return m.Merr
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeSign mock
|
||||||
|
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]SignOption, error) {
|
||||||
|
if m.MauthorizeSign != nil {
|
||||||
|
return m.MauthorizeSign(ctx, ott)
|
||||||
|
}
|
||||||
|
return m.Mret1.([]SignOption), m.Merr
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeRenewal mock
|
||||||
|
func (m *MockProvisioner) AuthorizeRenewal(c *x509.Certificate) error {
|
||||||
|
if m.MauthorizeRenewal != nil {
|
||||||
|
return m.MauthorizeRenewal(c)
|
||||||
|
}
|
||||||
|
return m.Merr
|
||||||
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/crypto/keys"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -298,8 +299,9 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if key.Size() < 256 {
|
if key.Size() < keys.MinRSAKeyBytes {
|
||||||
return errors.New("ssh certificate key must be at least 2048 bits (256 bytes)")
|
return errors.Errorf("ssh certificate key must be at least %d bits (%d bytes)",
|
||||||
|
8*keys.MinRSAKeyBytes, keys.MinRSAKeyBytes)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case ssh.KeyAlgoDSA:
|
case ssh.KeyAlgoDSA:
|
||||||
|
|
|
@ -730,3 +730,15 @@ func generateJWKServer(n int) *httptest.Server {
|
||||||
srv.Start()
|
srv.Start()
|
||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateACME() (*ACME, error) {
|
||||||
|
// Initialize provisioners
|
||||||
|
p := &ACME{
|
||||||
|
Type: "ACME",
|
||||||
|
Name: "test@acme-provisioner.com",
|
||||||
|
}
|
||||||
|
if err := p.Init(Config{Claims: globalProvisionerClaims}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
|
@ -35,3 +35,13 @@ func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisi
|
||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadProvisionerByID returns an interface to the provisioner with the given ID.
|
||||||
|
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||||
|
p, ok := a.provisioners.Load(id)
|
||||||
|
if !ok {
|
||||||
|
return nil, &apiError{errors.Errorf("provisioner not found"),
|
||||||
|
http.StatusNotFound, apiCtx{}}
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
8
authority/testdata/certs/badsig.csr
vendored
Normal file
8
authority/testdata/certs/badsig.csr
vendored
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
-----BEGIN CERTIFICATE REQUEST-----
|
||||||
|
MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI
|
||||||
|
zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI
|
||||||
|
cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ
|
||||||
|
DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ
|
||||||
|
ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5
|
||||||
|
OI+cWOIc/IGwqZul/zEF5dani5ihOL7UwA==
|
||||||
|
-----END CERTIFICATE REQUEST-----
|
8
authority/testdata/certs/foo.csr
vendored
Normal file
8
authority/testdata/certs/foo.csr
vendored
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
-----BEGIN CERTIFICATE REQUEST-----
|
||||||
|
MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI
|
||||||
|
zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI
|
||||||
|
cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ
|
||||||
|
DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ
|
||||||
|
ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5
|
||||||
|
OI+cWOIc/IGwqZul/zEF5dani5ihOR7UwA==
|
||||||
|
-----END CERTIFICATE REQUEST-----
|
354
ca/acmeClient.go
Normal file
354
ca/acmeClient.go
Normal file
|
@ -0,0 +1,354 @@
|
||||||
|
package ca
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
acmeAPI "github.com/smallstep/certificates/acme/api"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ACMEClient implements an HTTP client to an ACME API.
|
||||||
|
type ACMEClient struct {
|
||||||
|
client *http.Client
|
||||||
|
dirLoc string
|
||||||
|
dir *acme.Directory
|
||||||
|
acc *acme.Account
|
||||||
|
Key *jose.JSONWebKey
|
||||||
|
kid string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewACMEClient initializes a new ACMEClient.
|
||||||
|
func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*ACMEClient, error) {
|
||||||
|
// Retrieve transport from options.
|
||||||
|
o := new(clientOptions)
|
||||||
|
if err := o.apply(opts); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tr, err := o.getTransport(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ac := &ACMEClient{
|
||||||
|
client: &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
},
|
||||||
|
dirLoc: endpoint,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ac.client.Get(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "client GET %s failed", endpoint)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
var dir acme.Directory
|
||||||
|
if err := readJSON(resp.Body, &dir); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
ac.dir = &dir
|
||||||
|
|
||||||
|
ac.Key, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nar := &acmeAPI.NewAccountRequest{
|
||||||
|
Contact: contact,
|
||||||
|
TermsOfServiceAgreed: true,
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(nar)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error marshaling new account request")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = ac.post(payload, ac.dir.NewAccount, withJWK(ac))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
var acc acme.Account
|
||||||
|
if err := readJSON(resp.Body, &acc); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", dir.NewAccount)
|
||||||
|
}
|
||||||
|
ac.acc = &acc
|
||||||
|
ac.kid = resp.Header.Get("Location")
|
||||||
|
|
||||||
|
return ac, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDirectory makes a directory request to the ACME api and returns an
|
||||||
|
// ACME directory object.
|
||||||
|
func (c *ACMEClient) GetDirectory() (*acme.Directory, error) {
|
||||||
|
return c.dir, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNonce makes a nonce request to the ACME api and returns an
|
||||||
|
// ACME directory object.
|
||||||
|
func (c *ACMEClient) GetNonce() (string, error) {
|
||||||
|
resp, err := c.client.Get(c.dir.NewNonce)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return "", readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
return resp.Header.Get("Replay-Nonce"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type withHeaderOption func(so *jose.SignerOptions)
|
||||||
|
|
||||||
|
func withJWK(c *ACMEClient) withHeaderOption {
|
||||||
|
return func(so *jose.SignerOptions) {
|
||||||
|
so.WithHeader("jwk", c.Key.Public())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func withKid(c *ACMEClient) withHeaderOption {
|
||||||
|
return func(so *jose.SignerOptions) {
|
||||||
|
so.WithHeader("kid", c.kid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// serialize serializes a json web signature and doesn't omit empty fields.
|
||||||
|
func serialize(obj *jose.JSONWebSignature) (string, error) {
|
||||||
|
raw, err := obj.CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "error serializing JWS")
|
||||||
|
}
|
||||||
|
parts := strings.Split(raw, ".")
|
||||||
|
msg := struct {
|
||||||
|
Protected string `json:"protected"`
|
||||||
|
Payload string `json:"payload"`
|
||||||
|
Signature string `json:"signature"`
|
||||||
|
}{Protected: parts[0], Payload: parts[1], Signature: parts[2]}
|
||||||
|
b, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "error marshaling jws message")
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ACMEClient) post(payload []byte, url string, headerOps ...withHeaderOption) (*http.Response, error) {
|
||||||
|
if c.Key == nil {
|
||||||
|
return nil, errors.New("acme client not configured with account")
|
||||||
|
}
|
||||||
|
nonce, err := c.GetNonce()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
so := new(jose.SignerOptions)
|
||||||
|
so.WithHeader("nonce", nonce)
|
||||||
|
so.WithHeader("url", url)
|
||||||
|
for _, hop := range headerOps {
|
||||||
|
hop(so)
|
||||||
|
}
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{
|
||||||
|
Algorithm: jose.SignatureAlgorithm(c.Key.Algorithm),
|
||||||
|
Key: c.Key.Key,
|
||||||
|
}, so)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error creating JWS signer")
|
||||||
|
}
|
||||||
|
signed, err := signer.Sign(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Errorf("error signing payload: %s", strings.TrimPrefix(err.Error(), "square/go-jose: "))
|
||||||
|
}
|
||||||
|
raw, err := serialize(signed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := c.client.Post(url, "application/jose+json", strings.NewReader(raw))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "client GET %s failed", c.dir.NewOrder)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOrder creates and returns the information for a new ACME order.
|
||||||
|
func (c *ACMEClient) NewOrder(payload []byte) (*acme.Order, error) {
|
||||||
|
resp, err := c.post(payload, c.dir.NewOrder, withKid(c))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var o acme.Order
|
||||||
|
if err := readJSON(resp.Body, &o); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", c.dir.NewOrder)
|
||||||
|
}
|
||||||
|
o.ID = resp.Header.Get("Location")
|
||||||
|
return &o, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChallenge returns the Challenge at the given path.
|
||||||
|
// With the validate parameter set to True this method will attempt to validate the
|
||||||
|
// challenge before returning it.
|
||||||
|
func (c *ACMEClient) GetChallenge(url string) (*acme.Challenge, error) {
|
||||||
|
resp, err := c.post(nil, url, withKid(c))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ch acme.Challenge
|
||||||
|
if err := readJSON(resp.Body, &ch); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", url)
|
||||||
|
}
|
||||||
|
return &ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateChallenge returns the Challenge at the given path.
|
||||||
|
// With the validate parameter set to True this method will attempt to validate the
|
||||||
|
// challenge before returning it.
|
||||||
|
func (c *ACMEClient) ValidateChallenge(url string) error {
|
||||||
|
resp, err := c.post([]byte("{}"), url, withKid(c))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthz returns the Authz at the given path.
|
||||||
|
func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) {
|
||||||
|
resp, err := c.post(nil, url, withKid(c))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var az acme.Authz
|
||||||
|
if err := readJSON(resp.Body, &az); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", url)
|
||||||
|
}
|
||||||
|
return &az, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrder returns the Order at the given path.
|
||||||
|
func (c *ACMEClient) GetOrder(url string) (*acme.Order, error) {
|
||||||
|
resp, err := c.post(nil, url, withKid(c))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var o acme.Order
|
||||||
|
if err := readJSON(resp.Body, &o); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", url)
|
||||||
|
}
|
||||||
|
return &o, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinalizeOrder makes a finalize request to the ACME api.
|
||||||
|
func (c *ACMEClient) FinalizeOrder(url string, csr *x509.CertificateRequest) error {
|
||||||
|
payload, err := json.Marshal(acmeAPI.FinalizeRequest{
|
||||||
|
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error marshaling finalize request")
|
||||||
|
}
|
||||||
|
resp, err := c.post(payload, url, withKid(c))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCertificate retrieves the certificate along with all intermediates.
|
||||||
|
func (c *ACMEClient) GetCertificate(url string) (*x509.Certificate, []*x509.Certificate, error) {
|
||||||
|
resp, err := c.post(nil, url, withKid(c))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, nil, readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
bodyBytes, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "error reading GET certificate response")
|
||||||
|
}
|
||||||
|
|
||||||
|
var certs []*x509.Certificate
|
||||||
|
|
||||||
|
block, rest := pem.Decode(bodyBytes)
|
||||||
|
if block == nil {
|
||||||
|
return nil, nil, errors.New("failed to parse any certificates from response")
|
||||||
|
}
|
||||||
|
for block != nil {
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "error parsing certificate pem response")
|
||||||
|
}
|
||||||
|
certs = append(certs, cert)
|
||||||
|
block, rest = pem.Decode(rest)
|
||||||
|
}
|
||||||
|
|
||||||
|
return certs[0], certs[1:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountOrders retrieves the orders belonging to the given account.
|
||||||
|
func (c *ACMEClient) GetAccountOrders() ([]string, error) {
|
||||||
|
if c.acc == nil {
|
||||||
|
return nil, errors.New("acme client not configured with account")
|
||||||
|
}
|
||||||
|
resp, err := c.post(nil, c.acc.Orders, withKid(c))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readACMEError(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var orders []string
|
||||||
|
if err := readJSON(resp.Body, &orders); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", c.acc.Orders)
|
||||||
|
}
|
||||||
|
|
||||||
|
return orders, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readACMEError(r io.ReadCloser) error {
|
||||||
|
defer r.Close()
|
||||||
|
b, err := ioutil.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error reading from body")
|
||||||
|
}
|
||||||
|
ae := new(acme.AError)
|
||||||
|
err = json.Unmarshal(b, &ae)
|
||||||
|
// If we successfully marshaled to an ACMEError then return the ACMEError.
|
||||||
|
if err != nil || len(ae.Error()) == 0 {
|
||||||
|
fmt.Printf("b = %s\n", b)
|
||||||
|
// Throw up our hands.
|
||||||
|
return errors.Errorf("%s", b)
|
||||||
|
}
|
||||||
|
return ae
|
||||||
|
}
|
1358
ca/acmeClient_test.go
Normal file
1358
ca/acmeClient_test.go
Normal file
File diff suppressed because it is too large
Load diff
52
ca/ca.go
52
ca/ca.go
|
@ -3,18 +3,23 @@ package ca
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
acmeAPI "github.com/smallstep/certificates/acme/api"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"github.com/smallstep/certificates/monitoring"
|
"github.com/smallstep/certificates/monitoring"
|
||||||
"github.com/smallstep/certificates/server"
|
"github.com/smallstep/certificates/server"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
)
|
)
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
|
@ -58,11 +63,12 @@ func WithDatabase(db db.AuthDB) Option {
|
||||||
// CA is the type used to build the complete certificate authority. It builds
|
// CA is the type used to build the complete certificate authority. It builds
|
||||||
// the HTTP server, set ups the middlewares and the HTTP handlers.
|
// the HTTP server, set ups the middlewares and the HTTP handlers.
|
||||||
type CA struct {
|
type CA struct {
|
||||||
auth *authority.Authority
|
auth *authority.Authority
|
||||||
config *authority.Config
|
acmeAuth *acme.Authority
|
||||||
srv *server.Server
|
config *authority.Config
|
||||||
opts *options
|
srv *server.Server
|
||||||
renewer *TLSRenewer
|
opts *options
|
||||||
|
renewer *TLSRenewer
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates and initializes the CA with the given configuration and options.
|
// New creates and initializes the CA with the given configuration and options.
|
||||||
|
@ -100,13 +106,47 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) {
|
||||||
mux := chi.NewRouter()
|
mux := chi.NewRouter()
|
||||||
handler := http.Handler(mux)
|
handler := http.Handler(mux)
|
||||||
|
|
||||||
// Add api endpoints in / and /1.0
|
// Add regular CA api endpoints in / and /1.0
|
||||||
routerHandler := api.New(auth)
|
routerHandler := api.New(auth)
|
||||||
routerHandler.Route(mux)
|
routerHandler.Route(mux)
|
||||||
mux.Route("/1.0", func(r chi.Router) {
|
mux.Route("/1.0", func(r chi.Router) {
|
||||||
routerHandler.Route(r)
|
routerHandler.Route(r)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
//Add ACME api endpoints in /acme and /1.0/acme
|
||||||
|
dns := config.DNSNames[0]
|
||||||
|
u, err := url.Parse("https://" + config.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
port := u.Port()
|
||||||
|
if port != "" && port != "443" {
|
||||||
|
dns = fmt.Sprintf("%s:%s", dns, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := "acme"
|
||||||
|
acmeAuth := acme.NewAuthority(auth.GetDatabase().(nosql.DB), dns, prefix, auth)
|
||||||
|
acmeRouterHandler := acmeAPI.New(acmeAuth)
|
||||||
|
mux.Route("/"+prefix, func(r chi.Router) {
|
||||||
|
acmeRouterHandler.Route(r)
|
||||||
|
})
|
||||||
|
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
|
||||||
|
// of the ACME spec.
|
||||||
|
mux.Route("/2.0/"+prefix, func(r chi.Router) {
|
||||||
|
acmeRouterHandler.Route(r)
|
||||||
|
})
|
||||||
|
|
||||||
|
/*
|
||||||
|
// helpful routine for logging all routes //
|
||||||
|
walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
|
||||||
|
fmt.Printf("%s %s\n", method, route)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := chi.Walk(mux, walkFunc); err != nil {
|
||||||
|
fmt.Printf("Logging err: %s\n", err.Error())
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
// Add monitoring if configured
|
// Add monitoring if configured
|
||||||
if len(config.Monitoring) > 0 {
|
if len(config.Monitoring) > 0 {
|
||||||
m, err := monitoring.New(config.Monitoring)
|
m, err := monitoring.New(config.Monitoring)
|
||||||
|
|
|
@ -163,8 +163,7 @@ func TestClient_Health(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
w.WriteHeader(tt.responseCode)
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
api.JSON(w, tt.response)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Health()
|
got, err := c.Health()
|
||||||
|
@ -224,8 +223,7 @@ func TestClient_Root(t *testing.T) {
|
||||||
if req.RequestURI != expected {
|
if req.RequestURI != expected {
|
||||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
||||||
}
|
}
|
||||||
w.WriteHeader(tt.responseCode)
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
api.JSON(w, tt.response)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Root(tt.shasum)
|
got, err := c.Root(tt.shasum)
|
||||||
|
@ -303,8 +301,7 @@ func TestClient_Sign(t *testing.T) {
|
||||||
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
|
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.WriteHeader(tt.responseCode)
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
api.JSON(w, tt.response)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Sign(tt.request)
|
got, err := c.Sign(tt.request)
|
||||||
|
@ -378,8 +375,7 @@ func TestClient_Revoke(t *testing.T) {
|
||||||
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
|
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.WriteHeader(tt.responseCode)
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
api.JSON(w, tt.response)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Revoke(tt.request, nil)
|
got, err := c.Revoke(tt.request, nil)
|
||||||
|
@ -438,8 +434,7 @@ func TestClient_Renew(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
w.WriteHeader(tt.responseCode)
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
api.JSON(w, tt.response)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Renew(nil)
|
got, err := c.Renew(nil)
|
||||||
|
@ -502,8 +497,7 @@ func TestClient_Provisioners(t *testing.T) {
|
||||||
if req.RequestURI != tt.expectedURI {
|
if req.RequestURI != tt.expectedURI {
|
||||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
|
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
|
||||||
}
|
}
|
||||||
w.WriteHeader(tt.responseCode)
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
api.JSON(w, tt.response)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Provisioners(tt.args...)
|
got, err := c.Provisioners(tt.args...)
|
||||||
|
@ -562,8 +556,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
||||||
if req.RequestURI != expected {
|
if req.RequestURI != expected {
|
||||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
||||||
}
|
}
|
||||||
w.WriteHeader(tt.responseCode)
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
api.JSON(w, tt.response)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.ProvisionerKey(tt.kid)
|
got, err := c.ProvisionerKey(tt.kid)
|
||||||
|
@ -622,8 +615,7 @@ func TestClient_Roots(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
w.WriteHeader(tt.responseCode)
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
api.JSON(w, tt.response)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Roots()
|
got, err := c.Roots()
|
||||||
|
@ -683,8 +675,7 @@ func TestClient_Federation(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
w.WriteHeader(tt.responseCode)
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
api.JSON(w, tt.response)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Federation()
|
got, err := c.Federation()
|
||||||
|
@ -783,8 +774,7 @@ func TestClient_RootFingerprint(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
w.WriteHeader(tt.responseCode)
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
api.JSON(w, tt.response)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.RootFingerprint()
|
got, err := c.RootFingerprint()
|
||||||
|
|
129
db/db.go
129
db/db.go
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
|
"github.com/smallstep/nosql/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -102,22 +103,20 @@ func (db *DB) IsRevoked(sn string) (bool, error) {
|
||||||
|
|
||||||
// Revoke adds a certificate to the revocation table.
|
// Revoke adds a certificate to the revocation table.
|
||||||
func (db *DB) Revoke(rci *RevokedCertificateInfo) error {
|
func (db *DB) Revoke(rci *RevokedCertificateInfo) error {
|
||||||
isRvkd, err := db.IsRevoked(rci.Serial)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if isRvkd {
|
|
||||||
return ErrAlreadyExists
|
|
||||||
}
|
|
||||||
rcib, err := json.Marshal(rci)
|
rcib, err := json.Marshal(rci)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "error marshaling revoked certificate info")
|
return errors.Wrap(err, "error marshaling revoked certificate info")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = db.Set(revokedCertsTable, []byte(rci.Serial), rcib); err != nil {
|
_, swapped, err := db.CmpAndSwap(revokedCertsTable, []byte(rci.Serial), nil, rcib)
|
||||||
return errors.Wrap(err, "database Set error")
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return errors.Wrap(err, "error AuthDB CmpAndSwap")
|
||||||
|
case !swapped:
|
||||||
|
return ErrAlreadyExists
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StoreCertificate stores a certificate PEM.
|
// StoreCertificate stores a certificate PEM.
|
||||||
|
@ -132,15 +131,11 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
|
||||||
// for the first time, false otherwise.
|
// for the first time, false otherwise.
|
||||||
func (db *DB) UseToken(id, tok string) (bool, error) {
|
func (db *DB) UseToken(id, tok string) (bool, error) {
|
||||||
_, swapped, err := db.CmpAndSwap(usedOTTTable, []byte(id), nil, []byte(tok))
|
_, swapped, err := db.CmpAndSwap(usedOTTTable, []byte(id), nil, []byte(tok))
|
||||||
switch {
|
if err != nil {
|
||||||
case err != nil:
|
|
||||||
return false, errors.Wrapf(err, "error storing used token %s/%s",
|
return false, errors.Wrapf(err, "error storing used token %s/%s",
|
||||||
string(usedOTTTable), id)
|
string(usedOTTTable), id)
|
||||||
case !swapped:
|
|
||||||
return false, nil
|
|
||||||
default:
|
|
||||||
return true, nil
|
|
||||||
}
|
}
|
||||||
|
return swapped, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown sends a shutdown message to the database.
|
// Shutdown sends a shutdown message to the database.
|
||||||
|
@ -153,3 +148,105 @@ func (db *DB) Shutdown() error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MockNoSQLDB //
|
||||||
|
type MockNoSQLDB struct {
|
||||||
|
Err error
|
||||||
|
Ret1, Ret2 interface{}
|
||||||
|
MGet func(bucket, key []byte) ([]byte, error)
|
||||||
|
MSet func(bucket, key, value []byte) error
|
||||||
|
MOpen func(dataSourceName string, opt ...database.Option) error
|
||||||
|
MClose func() error
|
||||||
|
MCreateTable func(bucket []byte) error
|
||||||
|
MDeleteTable func(bucket []byte) error
|
||||||
|
MDel func(bucket, key []byte) error
|
||||||
|
MList func(bucket []byte) ([]*database.Entry, error)
|
||||||
|
MUpdate func(tx *database.Tx) error
|
||||||
|
MCmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CmpAndSwap mock
|
||||||
|
func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
|
if m.MCmpAndSwap != nil {
|
||||||
|
return m.MCmpAndSwap(bucket, key, old, newval)
|
||||||
|
}
|
||||||
|
if m.Ret1 == nil {
|
||||||
|
return nil, false, m.Err
|
||||||
|
}
|
||||||
|
return m.Ret1.([]byte), m.Ret2.(bool), m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get mock
|
||||||
|
func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) {
|
||||||
|
if m.MGet != nil {
|
||||||
|
return m.MGet(bucket, key)
|
||||||
|
}
|
||||||
|
if m.Ret1 == nil {
|
||||||
|
return nil, m.Err
|
||||||
|
}
|
||||||
|
return m.Ret1.([]byte), m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set mock
|
||||||
|
func (m *MockNoSQLDB) Set(bucket, key, value []byte) error {
|
||||||
|
if m.MSet != nil {
|
||||||
|
return m.MSet(bucket, key, value)
|
||||||
|
}
|
||||||
|
return m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open mock
|
||||||
|
func (m *MockNoSQLDB) Open(dataSourceName string, opt ...database.Option) error {
|
||||||
|
if m.MOpen != nil {
|
||||||
|
return m.MOpen(dataSourceName, opt...)
|
||||||
|
}
|
||||||
|
return m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close mock
|
||||||
|
func (m *MockNoSQLDB) Close() error {
|
||||||
|
if m.MClose != nil {
|
||||||
|
return m.MClose()
|
||||||
|
}
|
||||||
|
return m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTable mock
|
||||||
|
func (m *MockNoSQLDB) CreateTable(bucket []byte) error {
|
||||||
|
if m.MCreateTable != nil {
|
||||||
|
return m.MCreateTable(bucket)
|
||||||
|
}
|
||||||
|
return m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteTable mock
|
||||||
|
func (m *MockNoSQLDB) DeleteTable(bucket []byte) error {
|
||||||
|
if m.MDeleteTable != nil {
|
||||||
|
return m.MDeleteTable(bucket)
|
||||||
|
}
|
||||||
|
return m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Del mock
|
||||||
|
func (m *MockNoSQLDB) Del(bucket, key []byte) error {
|
||||||
|
if m.MDel != nil {
|
||||||
|
return m.MDel(bucket, key)
|
||||||
|
}
|
||||||
|
return m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// List mock
|
||||||
|
func (m *MockNoSQLDB) List(bucket []byte) ([]*database.Entry, error) {
|
||||||
|
if m.MList != nil {
|
||||||
|
return m.MList(bucket)
|
||||||
|
}
|
||||||
|
return m.Ret1.([]*database.Entry), m.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update mock
|
||||||
|
func (m *MockNoSQLDB) Update(tx *database.Tx) error {
|
||||||
|
if m.MUpdate != nil {
|
||||||
|
return m.MUpdate(tx)
|
||||||
|
}
|
||||||
|
return m.Err
|
||||||
|
}
|
||||||
|
|
132
db/db_test.go
132
db/db_test.go
|
@ -8,97 +8,6 @@ import (
|
||||||
"github.com/smallstep/nosql/database"
|
"github.com/smallstep/nosql/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MockNoSQLDB struct {
|
|
||||||
err error
|
|
||||||
ret1, ret2 interface{}
|
|
||||||
get func(bucket, key []byte) ([]byte, error)
|
|
||||||
set func(bucket, key, value []byte) error
|
|
||||||
open func(dataSourceName string, opt ...database.Option) error
|
|
||||||
close func() error
|
|
||||||
createTable func(bucket []byte) error
|
|
||||||
deleteTable func(bucket []byte) error
|
|
||||||
del func(bucket, key []byte) error
|
|
||||||
list func(bucket []byte) ([]*database.Entry, error)
|
|
||||||
update func(tx *database.Tx) error
|
|
||||||
cmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
|
||||||
if m.cmpAndSwap != nil {
|
|
||||||
return m.cmpAndSwap(bucket, key, old, newval)
|
|
||||||
}
|
|
||||||
if m.ret1 == nil {
|
|
||||||
return nil, false, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.([]byte), m.ret2.(bool), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) {
|
|
||||||
if m.get != nil {
|
|
||||||
return m.get(bucket, key)
|
|
||||||
}
|
|
||||||
if m.ret1 == nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.ret1.([]byte), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockNoSQLDB) Set(bucket, key, value []byte) error {
|
|
||||||
if m.set != nil {
|
|
||||||
return m.set(bucket, key, value)
|
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockNoSQLDB) Open(dataSourceName string, opt ...database.Option) error {
|
|
||||||
if m.open != nil {
|
|
||||||
return m.open(dataSourceName, opt...)
|
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockNoSQLDB) Close() error {
|
|
||||||
if m.close != nil {
|
|
||||||
return m.close()
|
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockNoSQLDB) CreateTable(bucket []byte) error {
|
|
||||||
if m.createTable != nil {
|
|
||||||
return m.createTable(bucket)
|
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockNoSQLDB) DeleteTable(bucket []byte) error {
|
|
||||||
if m.deleteTable != nil {
|
|
||||||
return m.deleteTable(bucket)
|
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockNoSQLDB) Del(bucket, key []byte) error {
|
|
||||||
if m.del != nil {
|
|
||||||
return m.del(bucket, key)
|
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockNoSQLDB) List(bucket []byte) ([]*database.Entry, error) {
|
|
||||||
if m.list != nil {
|
|
||||||
return m.list(bucket)
|
|
||||||
}
|
|
||||||
return m.ret1.([]*database.Entry), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockNoSQLDB) Update(tx *database.Tx) error {
|
|
||||||
if m.update != nil {
|
|
||||||
return m.update(tx)
|
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsRevoked(t *testing.T) {
|
func TestIsRevoked(t *testing.T) {
|
||||||
tests := map[string]struct {
|
tests := map[string]struct {
|
||||||
key string
|
key string
|
||||||
|
@ -111,16 +20,16 @@ func TestIsRevoked(t *testing.T) {
|
||||||
},
|
},
|
||||||
"false/ErrNotFound": {
|
"false/ErrNotFound": {
|
||||||
key: "sn",
|
key: "sn",
|
||||||
db: &DB{&MockNoSQLDB{err: database.ErrNotFound, ret1: nil}, true},
|
db: &DB{&MockNoSQLDB{Err: database.ErrNotFound, Ret1: nil}, true},
|
||||||
},
|
},
|
||||||
"error/checking bucket": {
|
"error/checking bucket": {
|
||||||
key: "sn",
|
key: "sn",
|
||||||
db: &DB{&MockNoSQLDB{err: errors.New("force"), ret1: nil}, true},
|
db: &DB{&MockNoSQLDB{Err: errors.New("force"), Ret1: nil}, true},
|
||||||
err: errors.New("error checking revocation bucket: force"),
|
err: errors.New("error checking revocation bucket: force"),
|
||||||
},
|
},
|
||||||
"true": {
|
"true": {
|
||||||
key: "sn",
|
key: "sn",
|
||||||
db: &DB{&MockNoSQLDB{ret1: []byte("value")}, true},
|
db: &DB{&MockNoSQLDB{Ret1: []byte("value")}, true},
|
||||||
isRevoked: true,
|
isRevoked: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -148,41 +57,26 @@ func TestRevoke(t *testing.T) {
|
||||||
"error/force isRevoked": {
|
"error/force isRevoked": {
|
||||||
rci: &RevokedCertificateInfo{Serial: "sn"},
|
rci: &RevokedCertificateInfo{Serial: "sn"},
|
||||||
db: &DB{&MockNoSQLDB{
|
db: &DB{&MockNoSQLDB{
|
||||||
get: func(bucket []byte, sn []byte) ([]byte, error) {
|
MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
|
||||||
return nil, errors.New("force IsRevoked")
|
return nil, false, errors.New("force")
|
||||||
},
|
},
|
||||||
}, true},
|
}, true},
|
||||||
err: errors.New("error checking revocation bucket: force IsRevoked"),
|
err: errors.New("error AuthDB CmpAndSwap: force"),
|
||||||
},
|
},
|
||||||
"error/was already revoked": {
|
"error/was already revoked": {
|
||||||
rci: &RevokedCertificateInfo{Serial: "sn"},
|
rci: &RevokedCertificateInfo{Serial: "sn"},
|
||||||
db: &DB{&MockNoSQLDB{
|
db: &DB{&MockNoSQLDB{
|
||||||
get: func(bucket []byte, sn []byte) ([]byte, error) {
|
MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
|
||||||
return nil, nil
|
return []byte("foo"), false, nil
|
||||||
},
|
},
|
||||||
}, true},
|
}, true},
|
||||||
err: ErrAlreadyExists,
|
err: ErrAlreadyExists,
|
||||||
},
|
},
|
||||||
"error/database set": {
|
|
||||||
rci: &RevokedCertificateInfo{Serial: "sn"},
|
|
||||||
db: &DB{&MockNoSQLDB{
|
|
||||||
get: func(bucket []byte, sn []byte) ([]byte, error) {
|
|
||||||
return nil, database.ErrNotFound
|
|
||||||
},
|
|
||||||
set: func(bucket []byte, key []byte, value []byte) error {
|
|
||||||
return errors.New("force")
|
|
||||||
},
|
|
||||||
}, true},
|
|
||||||
err: errors.New("database Set error: force"),
|
|
||||||
},
|
|
||||||
"ok": {
|
"ok": {
|
||||||
rci: &RevokedCertificateInfo{Serial: "sn"},
|
rci: &RevokedCertificateInfo{Serial: "sn"},
|
||||||
db: &DB{&MockNoSQLDB{
|
db: &DB{&MockNoSQLDB{
|
||||||
get: func(bucket []byte, sn []byte) ([]byte, error) {
|
MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
|
||||||
return nil, database.ErrNotFound
|
return []byte("foo"), true, nil
|
||||||
},
|
|
||||||
set: func(bucket []byte, key []byte, value []byte) error {
|
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
}, true},
|
}, true},
|
||||||
},
|
},
|
||||||
|
@ -214,7 +108,7 @@ func TestUseToken(t *testing.T) {
|
||||||
id: "id",
|
id: "id",
|
||||||
tok: "token",
|
tok: "token",
|
||||||
db: &DB{&MockNoSQLDB{
|
db: &DB{&MockNoSQLDB{
|
||||||
cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
return nil, false, errors.New("force")
|
return nil, false, errors.New("force")
|
||||||
},
|
},
|
||||||
}, true},
|
}, true},
|
||||||
|
@ -227,7 +121,7 @@ func TestUseToken(t *testing.T) {
|
||||||
id: "id",
|
id: "id",
|
||||||
tok: "token",
|
tok: "token",
|
||||||
db: &DB{&MockNoSQLDB{
|
db: &DB{&MockNoSQLDB{
|
||||||
cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
return []byte("foo"), false, nil
|
return []byte("foo"), false, nil
|
||||||
},
|
},
|
||||||
}, true},
|
}, true},
|
||||||
|
@ -239,7 +133,7 @@ func TestUseToken(t *testing.T) {
|
||||||
id: "id",
|
id: "id",
|
||||||
tok: "token",
|
tok: "token",
|
||||||
db: &DB{&MockNoSQLDB{
|
db: &DB{&MockNoSQLDB{
|
||||||
cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||||||
return []byte("bar"), true, nil
|
return []byte("bar"), true, nil
|
||||||
},
|
},
|
||||||
}, true},
|
}, true},
|
||||||
|
|
55
db/simple.go
55
db/simple.go
|
@ -6,6 +6,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/nosql/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrNotImplemented is an error returned when an operation is Not Implemented.
|
// ErrNotImplemented is an error returned when an operation is Not Implemented.
|
||||||
|
@ -61,3 +62,57 @@ func (s *SimpleDB) UseToken(id, tok string) (bool, error) {
|
||||||
func (s *SimpleDB) Shutdown() error {
|
func (s *SimpleDB) Shutdown() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nosql.DB interface implementation //
|
||||||
|
|
||||||
|
// Open opens the database available with the given options.
|
||||||
|
func (s *SimpleDB) Open(dataSourceName string, opt ...database.Option) error {
|
||||||
|
return ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the current database.
|
||||||
|
func (s *SimpleDB) Close() error {
|
||||||
|
return ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the value stored in the given table/bucket and key.
|
||||||
|
func (s *SimpleDB) Get(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets the given value in the given table/bucket and key.
|
||||||
|
func (s *SimpleDB) Set(bucket, key, value []byte) error {
|
||||||
|
return ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// CmpAndSwap swaps the value at the given bucket and key if the current
|
||||||
|
// value is equivalent to the oldValue input. Returns 'true' if the
|
||||||
|
// swap was successful and 'false' otherwise.
|
||||||
|
func (s *SimpleDB) CmpAndSwap(bucket, key, oldValue, newValue []byte) ([]byte, bool, error) {
|
||||||
|
return nil, false, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// Del deletes the data in the given table/bucket and key.
|
||||||
|
func (s *SimpleDB) Del(bucket, key []byte) error {
|
||||||
|
return ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns a list of all the entries in a given table/bucket.
|
||||||
|
func (s *SimpleDB) List(bucket []byte) ([]*database.Entry, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update performs a transaction with multiple read-write commands.
|
||||||
|
func (s *SimpleDB) Update(tx *database.Tx) error {
|
||||||
|
return ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTable creates a table or a bucket in the database.
|
||||||
|
func (s *SimpleDB) CreateTable(bucket []byte) error {
|
||||||
|
return ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteTable deletes a table or a bucket in the database.
|
||||||
|
func (s *SimpleDB) DeleteTable(bucket []byte) error {
|
||||||
|
return ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
160
docs/acme.md
Normal file
160
docs/acme.md
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
# Using ACME with `step-ca `
|
||||||
|
|
||||||
|
Let’s assume you’ve [installed
|
||||||
|
`step-ca`](https://smallstep.com/docs/getting-started/#1-installing-step-and-step-ca)
|
||||||
|
(e.g., using `brew install step`), have it running at `https://ca.internal`,
|
||||||
|
and you’ve [bootstrapped your ACME client
|
||||||
|
system(s)](https://smallstep.com/docs/getting-started/#bootstrapping) (or at
|
||||||
|
least [installed your root
|
||||||
|
certificate](https://smallstep.com/docs/cli/ca/root/) at
|
||||||
|
`~/.step/certs/root_ca.crt`).
|
||||||
|
|
||||||
|
## Enabling ACME
|
||||||
|
|
||||||
|
To enable ACME, simply [add an ACME provisioner](https://smallstep.com/docs/cli/ca/provisioner/add/) to your `step-ca` configuration
|
||||||
|
by running:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ step ca provisioner add my-acme-provisioner --type ACME
|
||||||
|
```
|
||||||
|
|
||||||
|
> NOTE: The above command will add a new provisioner of type `ACME` and name
|
||||||
|
> `my-acme-provisioner`. The name is used to identify the provisioner
|
||||||
|
> (e.g. you cannot have two `ACME` provisioners with the same name).
|
||||||
|
|
||||||
|
Now restart or SIGHUP `step-ca` to pick up the new configuration.
|
||||||
|
|
||||||
|
That’s it.
|
||||||
|
|
||||||
|
## Configuring Clients
|
||||||
|
|
||||||
|
To configure an ACME client to connect to `step-ca` you need to:
|
||||||
|
|
||||||
|
1. Point the client at the right ACME directory URL
|
||||||
|
2. Tell the client to trust your CA’s root certificate
|
||||||
|
|
||||||
|
Once certificates are issued, you’ll also need to ensure they’re renewed before
|
||||||
|
they expire.
|
||||||
|
|
||||||
|
### Pointing Clients at the right ACME Directory URL
|
||||||
|
|
||||||
|
Most ACME clients connect to Let’s Encrypt by default. To connect to `step-ca`
|
||||||
|
you need to point the client at the right [ACME directory
|
||||||
|
URL](https://tools.ietf.org/html/rfc8555#section-7.1.1).
|
||||||
|
|
||||||
|
A single instance of `step-ca` can have multiple ACME provisioners, each with
|
||||||
|
their own ACME directory URL that looks like:
|
||||||
|
|
||||||
|
```
|
||||||
|
https://{ca-host}/acme/{provisioner-name}/directory
|
||||||
|
```
|
||||||
|
|
||||||
|
We just added an ACME provisioner named “acme”. Its ACME directory URL is:
|
||||||
|
|
||||||
|
```
|
||||||
|
https://ca.internal/acme/acme/directory
|
||||||
|
```
|
||||||
|
|
||||||
|
### Telling clients to trust your CA’s root certificate
|
||||||
|
|
||||||
|
Communication between an ACME client and server [always uses
|
||||||
|
HTTPS](https://tools.ietf.org/html/rfc8555#section-6.1). By default, client’s
|
||||||
|
will validate the server’s HTTPS certificate using the public root certificates
|
||||||
|
in your system’s [default
|
||||||
|
trust](https://smallstep.com/blog/everything-pki.html#trust-stores) store.
|
||||||
|
That’s fine when you’re connecting to Let’s Encrypt: it’s a public CA and its
|
||||||
|
root certificate is in your system’s default trust store already. Your internal
|
||||||
|
root certificate isn’t, so HTTPS connections from ACME clients to `step-ca` will
|
||||||
|
fail.
|
||||||
|
|
||||||
|
There are two ways to address this problem:
|
||||||
|
|
||||||
|
1. Explicitly configure your ACME client to trust `step-ca`'s root certificate, or
|
||||||
|
2. Add `step-ca`'s root certificate to your system’s default trust store (e.g.,
|
||||||
|
using `[step certificate
|
||||||
|
install](https://smallstep.com/docs/cli/certificate/install/)`)
|
||||||
|
|
||||||
|
If you’re using your CA for TLS in production, explicitly configuring your ACME
|
||||||
|
client to only trust your root certificate is a better option. We’ll
|
||||||
|
demonstrate this method with several clients below.
|
||||||
|
|
||||||
|
If you’re simulating Let’s Encrypt in pre-production, installing your root
|
||||||
|
certificate is a more faithful simulation of production. Once your root
|
||||||
|
certificate is installed, no additional client configuration is necessary.
|
||||||
|
|
||||||
|
> Caution: adding a root certificate to your system’s trust store is a global
|
||||||
|
> operation. Certificates issued by your CA will be trusted everywhere,
|
||||||
|
> including in web browsers.
|
||||||
|
|
||||||
|
### Example using [`certbot`](https://certbot.eff.org/)
|
||||||
|
|
||||||
|
[`certbot`](https://certbot.eff.org/) is the grandaddy of ACME clients. Built
|
||||||
|
and supported by [the EFF](https://www.eff.org/), it’s the standard-bearer for
|
||||||
|
production-grade command-line ACME.
|
||||||
|
|
||||||
|
To get a certificate from `step-ca` using `certbot` you need to:
|
||||||
|
|
||||||
|
1. Point `certbot` at your ACME directory URL using the `--`server flag.
|
||||||
|
2. Tell `certbot` to trust your root certificate using the `REQUESTS_CA_BUNDLE` environment variable.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ sudo REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt \
|
||||||
|
certbot certonly -n --standalone -d foo.internal \
|
||||||
|
--server https://ca.internal/acme/acme/directory
|
||||||
|
```
|
||||||
|
|
||||||
|
`sudo` is required in `certbot`'s [*standalone*
|
||||||
|
mode](https://certbot.eff.org/docs/using.html#standalone) so it can listen on
|
||||||
|
port 80 to complete the `http-01` challenge. If you already have a webserver
|
||||||
|
running you can use [*webroot*
|
||||||
|
mode](https://certbot.eff.org/docs/using.html#webroot) instead. With the
|
||||||
|
[appropriate plugin](https://certbot.eff.org/docs/using.html#dns-plugins)
|
||||||
|
`certbot` also supports the `dns-01` challenge for most popular DNS providers.
|
||||||
|
Deeper integrations with [nginx](https://certbot.eff.org/docs/using.html#nginx)
|
||||||
|
and [apache](https://certbot.eff.org/docs/using.html#apache) can even configure
|
||||||
|
your server to use HTTPS automatically (we'll set this up ourselves later). All
|
||||||
|
of this works with `step-ca`.
|
||||||
|
|
||||||
|
You can renew all of the certificates you've installed using `cerbot` by running:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ sudo REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt certbot renew
|
||||||
|
```
|
||||||
|
|
||||||
|
You can automate renewal with a simple `cron` entry:
|
||||||
|
|
||||||
|
```
|
||||||
|
*/15 * * * * root REQUESTS_CA_BUNDLE=$(step path)/certs/root_ca.crt certbot -q renew
|
||||||
|
```
|
||||||
|
|
||||||
|
The `certbot` packages for some Linux distributions will create a `cron` entry
|
||||||
|
or [systemd
|
||||||
|
timer](https://stevenwestmoreland.com/2017/11/renewing-certbot-certificates-using-a-systemd-timer.html)
|
||||||
|
like this for you. This entry won't work with `step-ca` because it [doesn't set
|
||||||
|
the `REQUESTS_CA_BUNDLE` environment
|
||||||
|
variable](https://github.com/certbot/certbot/issues/7170). You'll need to
|
||||||
|
manually tweak it to do so.
|
||||||
|
|
||||||
|
More subtly, `certbot`'s default renewal job is tuned for Let's Encrypt's 90
|
||||||
|
day certificate lifetimes: it's run every 12 hours, with actual renewals
|
||||||
|
occurring for certificates within 30 days of expiry. By default, `step-ca`
|
||||||
|
issues certificates with *much shorter* 24 hour lifetimes. The `cron` entry
|
||||||
|
above accounts for this by running `certbot renew` every 15 minutes. You'll
|
||||||
|
also want to configure your domain to only renew certificates when they're
|
||||||
|
within a few hours of expiry by adding a line like:
|
||||||
|
|
||||||
|
```
|
||||||
|
renew_before_expiry = 8 hours
|
||||||
|
```
|
||||||
|
|
||||||
|
to the top of your renewal configuration (e.g., in `/etc/letsencrypt/renewal/foo.internal.conf`).
|
||||||
|
|
||||||
|
## Feedback
|
||||||
|
|
||||||
|
`step-ca` should work with any ACMEv2
|
||||||
|
([RFC8555](https://tools.ietf.org/html/rfc8555)) compliant client that supports
|
||||||
|
the http-01 or dns-01 challenge. If you run into any issues please let us know
|
||||||
|
[on gitter](https://gitter.im/smallstep/community) or [in an
|
||||||
|
issue](https://github.com/smallstep/certificates/issues/new?template=bug_report.md).
|
Loading…
Reference in a new issue