diff --git a/acme/account.go b/acme/account.go index 1c5870d5..197a3400 100644 --- a/acme/account.go +++ b/acme/account.go @@ -1,197 +1,42 @@ package acme import ( - "context" + "crypto" + "encoding/base64" "encoding/json" - "time" - "github.com/pkg/errors" - "github.com/smallstep/nosql" "go.step.sm/crypto/jose" ) // Account is a subset of the internal account type containing only those // attributes required for responses in the ACME protocol. type Account struct { - Contact []string `json:"contact,omitempty"` - Status string `json:"status"` - Orders string `json:"orders"` - ID string `json:"-"` - Key *jose.JSONWebKey `json:"-"` + ID string `json:"-"` + Key *jose.JSONWebKey `json:"-"` + Contact []string `json:"contact,omitempty"` + Status Status `json:"status"` + OrdersURL string `json:"orders"` } // ToLog enables response logging. func (a *Account) ToLog() (interface{}, error) { b, err := json.Marshal(a) if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging")) + return nil, WrapErrorISE(err, "error marshaling account for logging") } return string(b), nil } -// GetID returns the account ID. -func (a *Account) GetID() string { - return a.ID -} - -// GetKey returns the JWK associated with the account. -func (a *Account) GetKey() *jose.JSONWebKey { - return a.Key -} - // IsValid returns true if the Account is valid. func (a *Account) IsValid() bool { - return a.Status == StatusValid + return Status(a.Status) == StatusValid } -// AccountOptions are the options needed to create a new ACME account. -type AccountOptions struct { - Key *jose.JSONWebKey - Contact []string -} - -// account represents an ACME account. -type account struct { - ID string `json:"id"` - Created time.Time `json:"created"` - Deactivated time.Time `json:"deactivated"` - Key *jose.JSONWebKey `json:"key"` - Contact []string `json:"contact,omitempty"` - Status string `json:"status"` -} - -// newAccount returns a new acme account type. -func newAccount(db nosql.DB, ops AccountOptions) (*account, error) { - id, err := randID() +// KeyToID converts a JWK to a thumbprint. +func KeyToID(jwk *jose.JSONWebKey) (string, error) { + kid, err := jwk.Thumbprint(crypto.SHA256) if err != nil { - return nil, err + return "", WrapErrorISE(err, "error generating jwk thumbprint") } - - a := &account{ - ID: id, - Key: ops.Key, - Contact: ops.Contact, - Status: "valid", - Created: clock.Now(), - } - return a, a.saveNew(db) -} - -// toACME converts the internal Account type into the public acmeAccount -// type for presentation in the ACME protocol. -func (a *account) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Account, error) { - return &Account{ - Status: a.Status, - Contact: a.Contact, - Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID), - Key: a.Key, - ID: a.ID, - }, nil -} - -// save writes the Account to the DB. -// If the account is new then the necessary indices will be created. -// Else, the account in the DB will be updated. -func (a *account) saveNew(db nosql.DB) error { - kid, err := keyToID(a.Key) - if err != nil { - return err - } - kidB := []byte(kid) - - // Set the jwkID -> acme account ID index - _, swapped, err := db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID)) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index")) - case !swapped: - return ServerInternalErr(errors.Errorf("key-id to account-id index already exists")) - default: - if err = a.save(db, nil); err != nil { - db.Del(accountByKeyIDTable, kidB) - return err - } - return nil - } -} - -func (a *account) save(db nosql.DB, old *account) error { - var ( - err error - oldB []byte - ) - if old == nil { - oldB = nil - } else { - if oldB, err = json.Marshal(old); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order")) - } - } - - b, err := json.Marshal(*a) - if err != nil { - return errors.Wrap(err, "error marshaling new account object") - } - // Set the Account - _, swapped, err := db.CmpAndSwap(accountTable, []byte(a.ID), oldB, b) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error storing account")) - case !swapped: - return ServerInternalErr(errors.New("error storing account; " + - "value has changed since last read")) - default: - return nil - } -} - -// update updates the acme account object stored in the database if, -// and only if, the account has not changed since the last read. -func (a *account) update(db nosql.DB, contact []string) (*account, error) { - b := *a - b.Contact = contact - if err := b.save(db, a); err != nil { - return nil, err - } - return &b, nil -} - -// deactivate deactivates the acme account. -func (a *account) deactivate(db nosql.DB) (*account, error) { - b := *a - b.Status = StatusDeactivated - b.Deactivated = clock.Now() - if err := b.save(db, a); err != nil { - return nil, err - } - return &b, nil -} - -// getAccountByID retrieves the account with the given ID. -func getAccountByID(db nosql.DB, id string) (*account, error) { - ab, err := db.Get(accountTable, []byte(id)) - if err != nil { - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id)) - } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id)) - } - - a := new(account) - if err = json.Unmarshal(ab, a); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account")) - } - return a, nil -} - -// getAccountByKeyID retrieves Id associated with the given Kid. -func getAccountByKeyID(db nosql.DB, kid string) (*account, error) { - id, err := db.Get(accountByKeyIDTable, []byte(kid)) - if err != nil { - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid)) - } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index")) - } - return getAccountByID(db, string(id)) + return base64.RawURLEncoding.EncodeToString(kid), nil } diff --git a/acme/account_test.go b/acme/account_test.go index 2e072af5..5625c3dc 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -1,770 +1,81 @@ package acme import ( - "context" - "encoding/json" - "fmt" - "net/url" + "crypto" + "encoding/base64" "testing" - "time" "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" "go.step.sm/crypto/jose" ) -var ( - defaultDisableRenewal = false - globalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, - } -) - -func newProv() Provisioner { - // Initialize provisioners - p := &provisioner.ACME{ - Type: "ACME", - Name: "test@acme-provisioner.com", - } - if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { - fmt.Printf("%v", err) - } - return p -} - -func newAcc() (*account, error) { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - if err != nil { - return nil, err - } - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - } - return newAccount(mockdb, AccountOptions{ - Key: jwk, Contact: []string{"foo", "bar"}, - }) -} - -func TestGetAccountByID(t *testing.T) { +func TestKeyToID(t *testing.T) { type test struct { - id string - db nosql.DB - acc *account + jwk *jose.JSONWebKey + exp string err *Error } tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - acc, err := newAcc() + "fail/error-generating-thumbprint": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) + jwk.Key = "foo" return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("account %s not found: not found", acc.ID)), - } - }, - "fail/db-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error loading account %s: force", acc.ID)), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - return nil, nil - }, - }, - err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")), + jwk: jwk, + err: NewErrorISE("error generating jwk thumbprint: square/go-jose: unknown key type 'string'"), } }, "ok": func(t *testing.T) test { - acc, err := newAcc() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - b, err := json.Marshal(acc) + + kid, err := jwk.Thumbprint(crypto.SHA256) assert.FatalError(t, err) + return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - return b, nil - }, - }, + jwk: jwk, + exp: base64.RawURLEncoding.EncodeToString(kid), } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acc, err := getAccountByID(tc.db, tc.id); err != nil { + if id, err := KeyToID(tc.jwk); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, tc.acc.ID, acc.ID) - assert.Equals(t, tc.acc.Status, acc.Status) - assert.Equals(t, tc.acc.Created, acc.Created) - assert.Equals(t, tc.acc.Deactivated, acc.Deactivated) - assert.Equals(t, tc.acc.Contact, acc.Contact) - assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID) + assert.Equals(t, id, tc.exp) } } }) } } -func TestGetAccountByKeyID(t *testing.T) { +func TestAccount_IsValid(t *testing.T) { type test struct { - kid string - db nosql.DB - acc *account - err *Error + acc *Account + exp bool } - tests := map[string]func(t *testing.T) test{ - "fail/kid-not-found": func(t *testing.T) test { - return test{ - kid: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("account with key id foo not found: not found")), - } - }, - "fail/db-error": func(t *testing.T) test { - return test{ - kid: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading key-account index: force")), - } - }, - "fail/getAccount-error": func(t *testing.T) test { - count := 0 - return test{ - kid: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte("foo")) - count++ - return []byte("bar"), nil - } - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading account bar: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - return test{ - kid: acc.Key.KeyID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(acc.Key.KeyID)) - ret = []byte(acc.ID) - case 1: - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - ret = b - } - count++ - return ret, nil - }, - }, - acc: acc, - } - }, + tests := map[string]test{ + "valid": {acc: &Account{Status: StatusValid}, exp: true}, + "invalid": {acc: &Account{Status: StatusInvalid}, exp: false}, } - for name, run := range tests { + for name, tc := range tests { t.Run(name, func(t *testing.T) { - tc := run(t) - if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.acc.ID, acc.ID) - assert.Equals(t, tc.acc.Status, acc.Status) - assert.Equals(t, tc.acc.Created, acc.Created) - assert.Equals(t, tc.acc.Deactivated, acc.Deactivated) - assert.Equals(t, tc.acc.Contact, acc.Contact) - assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID) - } - } - }) - } -} - -func TestAccountToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - - type test struct { - acc *account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{acc: acc} - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - acmeAccount, err := tc.acc.toACME(ctx, nil, dir) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acmeAccount.ID, tc.acc.ID) - assert.Equals(t, acmeAccount.Status, tc.acc.Status) - assert.Equals(t, acmeAccount.Contact, tc.acc.Contact) - assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID) - assert.Equals(t, acmeAccount.Orders, - fmt.Sprintf("%s/acme/%s/account/%s/orders", baseURL.String(), provName, tc.acc.ID)) - } - } - }) - } -} - -func TestAccountSave(t *testing.T) { - type test struct { - acc, old *account - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - "fail/old-nil/swap-false": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.New("error storing account; value has changed since last read")), - } - }, - "ok/old-nil": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - return test{ - acc: acc, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, accountTable) - assert.Equals(t, []byte(acc.ID), key) - return nil, true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldAcc, err := newAcc() - assert.FatalError(t, err) - acc, err := newAcc() - assert.FatalError(t, err) - - oldb, err := json.Marshal(oldAcc) - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - return test{ - acc: acc, - old: oldAcc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - assert.Equals(t, bucket, accountTable) - assert.Equals(t, []byte(acc.ID), key) - return []byte("foo"), true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.acc.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAccountSaveNew(t *testing.T) { - type test struct { - acc *account - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/keyToID-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - acc.Key.Key = "foo" - return test{ - acc: acc, - err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")), - } - }, - "fail/swap-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), - } - }, - "fail/swap-false": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - return nil, false, nil - }, - }, - err: ServerInternalErr(errors.New("key-id to account-id index already exists")), - } - }, - "fail/save-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - count++ - return nil, true, nil - } - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, b) - return nil, false, errors.New("force") - }, - MDel: func(bucket, key []byte) error { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - return nil - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - count++ - return nil, true, nil - } - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, b) - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.acc.saveNew(tc.db); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAccountUpdate(t *testing.T) { - type test struct { - acc *account - contact []string - db nosql.DB - res []byte - err *Error - } - contact := []string{"foo", "bar"} - tests := map[string]func(t *testing.T) test{ - "fail/save-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) - assert.FatalError(t, err) - - _acc := *acc - clone := &_acc - clone.Contact = contact - b, err := json.Marshal(clone) - assert.FatalError(t, err) - return test{ - acc: acc, - contact: contact, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) - assert.FatalError(t, err) - - _acc := *acc - clone := &_acc - clone.Contact = contact - b, err := json.Marshal(clone) - assert.FatalError(t, err) - return test{ - acc: acc, - contact: contact, - res: b, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - acc, err := tc.acc.update(tc.db, tc.contact) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - b, err := json.Marshal(acc) - assert.FatalError(t, err) - assert.Equals(t, b, tc.res) - } - } - }) - } -} - -func TestAccountDeactivate(t *testing.T) { - type test struct { - acc *account - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/save-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) - assert.FatalError(t, err) - - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) - assert.FatalError(t, err) - - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - acc, err := tc.acc.deactivate(tc.db) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acc.ID, tc.acc.ID) - assert.Equals(t, acc.Contact, tc.acc.Contact) - assert.Equals(t, acc.Status, StatusDeactivated) - assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID) - assert.Equals(t, acc.Created, tc.acc.Created) - - assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute))) - assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute))) - } - } - }) - } -} - -func TestNewAccount(t *testing.T) { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - kid, err := keyToID(jwk) - assert.FatalError(t, err) - ops := AccountOptions{ - Key: jwk, - Contact: []string{"foo", "bar"}, - } - type test struct { - ops AccountOptions - db nosql.DB - err *Error - id *string - } - tests := map[string]func(t *testing.T) test{ - "fail/store-error": func(t *testing.T) test { - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), - } - }, - "ok": func(t *testing.T) test { - var _id string - id := &_id - count := 0 - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - switch count { - case 0: - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - case 1: - assert.Equals(t, bucket, accountTable) - *id = string(key) - } - count++ - return nil, true, nil - }, - }, - id: id, - } - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - acc, err := newAccount(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acc.ID, *tc.id) - assert.Equals(t, acc.Status, StatusValid) - assert.Equals(t, acc.Contact, ops.Contact) - assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID) - - assert.True(t, acc.Deactivated.IsZero()) - - assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute))) - } - } + assert.Equals(t, tc.acc.IsValid(), tc.exp) }) } } diff --git a/acme/api/account.go b/acme/api/account.go index 93f46651..92c5dbfc 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -5,7 +5,6 @@ import ( "net/http" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/logging" @@ -21,7 +20,7 @@ type NewAccountRequest struct { func validateContacts(cs []string) error { for _, c := range cs { if len(c) == 0 { - return acme.MalformedErr(errors.New("contact cannot be empty string")) + return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string") } } return nil @@ -30,29 +29,23 @@ func validateContacts(cs []string) error { // Validate validates a new-account request body. func (n *NewAccountRequest) Validate() error { if n.OnlyReturnExisting && len(n.Contact) > 0 { - return acme.MalformedErr(errors.New("incompatible input; onlyReturnExisting must be alone")) + return acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone") } return validateContacts(n.Contact) } // UpdateAccountRequest represents an update-account request. type UpdateAccountRequest struct { - Contact []string `json:"contact"` - Status string `json:"status"` -} - -// IsDeactivateRequest returns true if the update request is a deactivation -// request, false otherwise. -func (u *UpdateAccountRequest) IsDeactivateRequest() bool { - return u.Status == acme.StatusDeactivated + Contact []string `json:"contact"` + Status acme.Status `json:"status"` } // Validate validates a update-account request body. func (u *UpdateAccountRequest) Validate() error { switch { case len(u.Status) > 0 && len(u.Contact) > 0: - return acme.MalformedErr(errors.New("incompatible input; contact and " + - "status updates are mutually exclusive")) + return acme.NewError(acme.ErrorMalformedType, "incompatible input; contact and "+ + "status updates are mutually exclusive") case len(u.Contact) > 0: if err := validateContacts(u.Contact); err != nil { return err @@ -60,8 +53,8 @@ func (u *UpdateAccountRequest) Validate() error { return nil case len(u.Status) > 0: if u.Status != acme.StatusDeactivated { - return acme.MalformedErr(errors.Errorf("cannot update account "+ - "status to %s, only deactivated", u.Status)) + return acme.NewError(acme.ErrorMalformedType, "cannot update account "+ + "status to %s, only deactivated", u.Status) } return nil default: @@ -73,15 +66,16 @@ func (u *UpdateAccountRequest) Validate() error { // NewAccount is the handler resource for creating new ACME accounts. func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { - payload, err := payloadFromContext(r.Context()) + ctx := r.Context() + payload, err := payloadFromContext(ctx) if err != nil { api.WriteError(w, err) return } var nar NewAccountRequest if err := json.Unmarshal(payload.value, &nar); err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, - "failed to unmarshal new-account request payload"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + "failed to unmarshal new-account request payload")) return } if err := nar.Validate(); err != nil { @@ -90,7 +84,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { } httpStatus := http.StatusCreated - acc, err := acme.AccountFromContext(r.Context()) + acc, err := accountFromContext(r.Context()) if err != nil { acmeErr, ok := err.(*acme.Error) if !ok || acmeErr.Status != http.StatusBadRequest { @@ -101,20 +95,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { // Account does not exist // if nar.OnlyReturnExisting { - api.WriteError(w, acme.AccountDoesNotExistErr(nil)) + api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, + "account does not exist")) return } - jwk, err := acme.JwkFromContext(r.Context()) + jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return } - if acc, err = h.Auth.NewAccount(r.Context(), acme.AccountOptions{ + acc = &acme.Account{ Key: jwk, Contact: nar.Contact, - }); err != nil { - api.WriteError(w, err) + Status: acme.StatusValid, + } + if err := h.db.CreateAccount(ctx, acc); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) return } } else { @@ -122,19 +119,22 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { httpStatus = http.StatusOK } - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink, - true, acc.GetID())) + h.linker.LinkAccount(ctx, acc) + + w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, + true, acc.ID)) api.JSONStatus(w, acc, httpStatus) } -// GetUpdateAccount is the api for updating an ACME account. -func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) +// GetOrUpdateAccount is the api for updating an ACME account. +func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - payload, err := payloadFromContext(r.Context()) + payload, err := payloadFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -145,29 +145,31 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal new-account request payload"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + "failed to unmarshal new-account request payload")) return } if err := uar.Validate(); err != nil { api.WriteError(w, err) return } - var err error - // If neither the status nor the contacts are being updated then ignore - // the updates and return 200. This conforms with the behavior detailed - // in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2). - if uar.IsDeactivateRequest() { - acc, err = h.Auth.DeactivateAccount(r.Context(), acc.GetID()) - } else if len(uar.Contact) > 0 { - acc, err = h.Auth.UpdateAccount(r.Context(), acc.GetID(), uar.Contact) - } - if err != nil { - api.WriteError(w, err) - return + if len(uar.Status) > 0 || len(uar.Contact) > 0 { + if len(uar.Status) > 0 { + acc.Status = uar.Status + } else if len(uar.Contact) > 0 { + acc.Contact = uar.Contact + } + + if err := h.db.UpdateAccount(ctx, acc); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) + return + } } } - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink, - true, acc.GetID())) + + h.linker.LinkAccount(ctx, acc) + + w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, true, acc.ID)) api.JSON(w, acc) } @@ -180,23 +182,27 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { } } -// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account. -func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) +// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. +func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } accID := chi.URLParam(r, "accID") if acc.ID != accID { - api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param"))) + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } - orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID()) + orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { api.WriteError(w, err) return } + + h.linker.LinkOrdersByAccountID(ctx, orders) + api.JSON(w, orders) logOrdersByAccount(w, orders) } diff --git a/acme/api/account_test.go b/acme/api/account_test.go index bdd61c59..c4d7a812 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -12,7 +12,6 @@ import ( "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner" @@ -29,11 +28,11 @@ var ( } ) -func newProv() provisioner.Interface { +func newProv() acme.Provisioner { // Initialize provisioners p := &provisioner.ACME{ Type: "ACME", - Name: "test@acme-provisioner.com", + Name: "test@acme-provisioner.com", } if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { fmt.Printf("%v", err) @@ -41,7 +40,7 @@ func newProv() provisioner.Interface { return p } -func TestNewAccountRequestValidate(t *testing.T) { +func TestNewAccountRequest_Validate(t *testing.T) { type test struct { nar *NewAccountRequest err *acme.Error @@ -53,7 +52,7 @@ func TestNewAccountRequestValidate(t *testing.T) { OnlyReturnExisting: true, Contact: []string{"foo", "bar"}, }, - err: acme.MalformedErr(errors.Errorf("incompatible input; onlyReturnExisting must be alone")), + err: acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone"), } }, "fail/bad-contact": func(t *testing.T) test { @@ -61,7 +60,7 @@ func TestNewAccountRequestValidate(t *testing.T) { nar: &NewAccountRequest{ Contact: []string{"foo", ""}, }, - err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "ok": func(t *testing.T) test { @@ -97,7 +96,7 @@ func TestNewAccountRequestValidate(t *testing.T) { } } -func TestUpdateAccountRequestValidate(t *testing.T) { +func TestUpdateAccountRequest_Validate(t *testing.T) { type test struct { uar *UpdateAccountRequest err *acme.Error @@ -109,8 +108,8 @@ func TestUpdateAccountRequestValidate(t *testing.T) { Contact: []string{"foo", "bar"}, Status: "foo", }, - err: acme.MalformedErr(errors.Errorf("incompatible input; " + - "contact and status updates are mutually exclusive")), + err: acme.NewError(acme.ErrorMalformedType, "incompatible input; "+ + "contact and status updates are mutually exclusive"), } }, "fail/bad-contact": func(t *testing.T) test { @@ -118,7 +117,7 @@ func TestUpdateAccountRequestValidate(t *testing.T) { uar: &UpdateAccountRequest{ Contact: []string{"foo", ""}, }, - err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "fail/bad-status": func(t *testing.T) test { @@ -126,8 +125,8 @@ func TestUpdateAccountRequestValidate(t *testing.T) { uar: &UpdateAccountRequest{ Status: "foo", }, - err: acme.MalformedErr(errors.Errorf("cannot update account " + - "status to foo, only deactivated")), + err: acme.NewError(acme.ErrorMalformedType, "cannot update account "+ + "status to foo, only deactivated"), } }, "ok/contact": func(t *testing.T) test { @@ -168,81 +167,81 @@ func TestUpdateAccountRequestValidate(t *testing.T) { } } -func TestHandlerGetOrdersByAccount(t *testing.T) { - oids := []string{ - "https://ca.smallstep.com/acme/order/foo", - "https://ca.smallstep.com/acme/order/bar", - } +func TestHandler_GetOrdersByAccountID(t *testing.T) { accID := "account-id" - prov := newProv() // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("accID", accID) - url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID) + + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + + url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID) + + oids := []string{"foo", "bar"} + oidURLs := []string{ + fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), provName), + fmt.Sprintf("%s/acme/%s/order/bar", baseURL.String(), provName), + } type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, - ctx: ctx, + db: &acme.MockDB{}, + ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "foo"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 401, - problem: acme.UnauthorizedErr(errors.New("account ID does not match url param")), + err: acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param"), } }, - "fail/getOrdersByAccount-error": func(t *testing.T) test { + "fail/db.GetOrdersByAccountID-error": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) return test{ - auth: &mockAcmeAuthority{ - getOrdersByAccount: func(ctx context.Context, id string) ([]string, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) + db: &acme.MockDB{ + MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { assert.Equals(t, id, acc.ID) return oids, nil }, @@ -255,11 +254,11 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetOrdersByAccount(w, req) + h.GetOrdersByAccountID(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -268,18 +267,17 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { - expB, err := json.Marshal(oids) + expB, err := json.Marshal(oidURLs) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) @@ -288,47 +286,41 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { } } -func TestHandlerNewAccount(t *testing.T) { - accID := "accountID" - acc := acme.Account{ - ID: accID, - Status: "valid", - Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), - } +func TestHandler_NewAccount(t *testing.T) { prov := newProv() - provName := url.PathEscape(prov.GetName()) + escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { - auth acme.Interface + db acme.DB + acc *acme.Account ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-payload": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")), + err: acme.NewError(acme.ErrorMalformedType, "failed to "+ + "unmarshal new-account request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { @@ -337,12 +329,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "fail/no-existing-account": func(t *testing.T) test { @@ -351,12 +342,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-jwk": func(t *testing.T) test { @@ -365,12 +355,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { @@ -379,16 +368,15 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.JwkContextKey, nil) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, - "fail/NewAccount-error": func(t *testing.T) test { + "fail/db.CreateAccount-error": func(t *testing.T) test { nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, } @@ -396,23 +384,19 @@ func TestHandlerNewAccount(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) return test{ - auth: &mockAcmeAuthority{ - newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.Contact, nar.Contact) - assert.Equals(t, ops.Key, jwk) - return nil, acme.ServerInternalErr(errors.New("force")) + db: &acme.MockDB{ + MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { + assert.Equals(t, acc.Contact, nar.Contact) + assert.Equals(t, acc.Key, jwk) + return acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok/new-account": func(t *testing.T) test { @@ -423,29 +407,26 @@ func TestHandlerNewAccount(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) return test{ - auth: &mockAcmeAuthority{ - newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.Contact, nar.Contact) - assert.Equals(t, ops.Key, jwk) - return &acc, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.True(t, abs) - assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) + db: &acme.MockDB{ + MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { + acc.ID = "accountID" + assert.Equals(t, acc.Contact, nar.Contact) + assert.Equals(t, acc.Key, jwk) + return nil }, }, + acc: &acme.Account{ + ID: "accountID", + Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName), + }, ctx: ctx, statusCode: 201, } @@ -456,22 +437,21 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + ID: "accountID", + Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + } + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) - }, - }, ctx: ctx, + acc: acc, statusCode: 200, } }, @@ -479,7 +459,7 @@ func TestHandlerNewAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -492,90 +472,85 @@ func TestHandlerNewAccount(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { - expB, err := json.Marshal(acc) + expB, err := json.Marshal(tc.acc) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), - provName, accID)}) + escProvName, "accountID")}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } -func TestHandlerGetUpdateAccount(t *testing.T) { +func TestHandler_GetOrUpdateAccount(t *testing.T) { accID := "accountID" acc := acme.Account{ - ID: accID, - Status: "valid", - Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), + ID: accID, + Status: "valid", + OrdersURL: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), } prov := newProv() - provName := url.PathEscape(prov.GetName()) + escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx := context.WithValue(context.Background(), accContextKey, &acc) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")), + err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { @@ -584,62 +559,33 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, - "fail/Deactivate-error": func(t *testing.T) test { + "fail/db.UpdateAccount-error": func(t *testing.T) test { uar := &UpdateAccountRequest{ Status: "deactivated", } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ - auth: &mockAcmeAuthority{ - deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - return nil, acme.ServerInternalErr(errors.New("force")) + db: &acme.MockDB{ + MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { + assert.Equals(t, upd.Status, acme.StatusDeactivated) + assert.Equals(t, upd.ID, acc.ID) + return acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), - } - }, - "fail/UpdateAccount-error": func(t *testing.T) test { - uar := &UpdateAccountRequest{ - Contact: []string{"foo", "bar"}, - } - b, err := json.Marshal(uar) - assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - return test{ - auth: &mockAcmeAuthority{ - updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - assert.Equals(t, contacts, uar.Contact) - return nil, acme.ServerInternalErr(errors.New("force")) - }, - }, - ctx: ctx, - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok/deactivate": func(t *testing.T) test { @@ -648,26 +594,16 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - return &acc, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) + db: &acme.MockDB{ + MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { + assert.Equals(t, upd.Status, acme.StatusDeactivated) + assert.Equals(t, upd.ID, acc.ID) + return nil }, }, ctx: ctx, @@ -678,21 +614,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) { uar := &UpdateAccountRequest{} b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) - }, - }, ctx: ctx, statusCode: 200, } @@ -703,27 +629,16 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - assert.Equals(t, contacts, uar.Contact) - return &acc, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) + db: &acme.MockDB{ + MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { + assert.Equals(t, upd.Contact, uar.Contact) + assert.Equals(t, upd.ID, acc.ID) + return nil }, }, ctx: ctx, @@ -731,21 +646,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } }, "ok/post-as-get": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL, provName, accID) - }, - }, ctx: ctx, statusCode: 200, } @@ -754,11 +659,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetUpdateAccount(w, req) + h.GetOrUpdateAccount(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -767,15 +672,14 @@ func TestHandlerGetUpdateAccount(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(acc) @@ -783,7 +687,7 @@ func TestHandlerGetUpdateAccount(t *testing.T) { assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), - provName, accID)}) + escProvName, accID)}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/acme/api/handler.go b/acme/api/handler.go index 921e614e..7d02861e 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -1,56 +1,98 @@ package api import ( - "context" + "crypto/tls" "crypto/x509" + "encoding/json" "encoding/pem" "fmt" + "net" "net/http" + "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/provisioner" ) func link(url, typ string) string { return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ) } +// Clock that returns time in UTC rounded to seconds. +type Clock struct{} + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Truncate(time.Second) +} + +var clock Clock + type payloadInfo struct { value []byte isPostAsGet bool isEmptyJSON bool } -// payloadFromContext searches the context for a payload. Returns the payload -// or an error. -func payloadFromContext(ctx context.Context) (*payloadInfo, error) { - val, ok := ctx.Value(acme.PayloadContextKey).(*payloadInfo) - if !ok || val == nil { - return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context")) - } - return val, nil -} - -// New returns a new ACME API router. -func New(acmeAuth acme.Interface) api.RouterHandler { - return &Handler{acmeAuth} -} - -// Handler is the ACME request handler. +// Handler is the ACME API request handler. type Handler struct { - Auth acme.Interface + db acme.DB + backdate provisioner.Duration + ca acme.CertificateAuthority + linker Linker + validateChallengeOptions *acme.ValidateChallengeOptions +} + +// HandlerOptions required to create a new ACME API request handler. +type HandlerOptions struct { + Backdate provisioner.Duration + // DB storage backend that impements the acme.DB interface. + DB acme.DB + // DNS the host used to generate accurate ACME links. By default the authority + // will use the Host from the request, so this value will only be used if + // request.Host is empty. + DNS string + // Prefix is a URL path prefix under which the ACME api is served. This + // prefix is required to generate accurate ACME links. + // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- + // "acme" is the prefix from which the ACME api is accessed. + Prefix string + CA acme.CertificateAuthority +} + +// NewHandler returns a new ACME API handler. +func NewHandler(ops HandlerOptions) api.RouterHandler { + client := http.Client{ + Timeout: 30 * time.Second, + } + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + } + return &Handler{ + ca: ops.CA, + db: ops.DB, + backdate: ops.Backdate, + linker: NewLinker(ops.DNS, ops.Prefix), + validateChallengeOptions: &acme.ValidateChallengeOptions{ + HTTPGet: client.Get, + LookupTxt: net.LookupTXT, + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(dialer, network, addr, config) + }, + }, + } } // Route traffic and implement the Router interface. func (h *Handler) Route(r api.Router) { - getLink := h.Auth.GetLinkExplicit + getLink := h.linker.GetLinkExplicit // Standard ACME API - r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) - r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) - r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) - r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) + r.MethodFunc("GET", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) + r.MethodFunc("HEAD", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) + r.MethodFunc("GET", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) + r.MethodFunc("HEAD", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) extractPayloadByJWK := func(next nextHTTP) nextHTTP { return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) @@ -59,16 +101,16 @@ func (h *Handler) Route(r api.Router) { return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))) } - r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) - r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) - r.MethodFunc("POST", getLink(acme.KeyChangeLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) - r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) - r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) - r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) - r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, nil, "{chID}"), extractPayloadByKid(h.GetChallenge)) - r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) + r.MethodFunc("POST", getLink(NewAccountLinkType, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) + r.MethodFunc("POST", getLink(AccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount)) + r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) + r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) + r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) + r.MethodFunc("POST", getLink(OrdersByAccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) + r.MethodFunc("POST", getLink(FinalizeLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) + r.MethodFunc("POST", getLink(AuthzLinkType, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) + r.MethodFunc("POST", getLink(ChallengeLinkType, "{provisionerID}", false, nil, "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) + r.MethodFunc("POST", getLink(CertificateLinkType, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) } // GetNonce just sets the right header since a Nonce is added to each response @@ -81,101 +123,153 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { } } +// Directory represents an ACME directory for configuring clients. +type Directory struct { + NewNonce string `json:"newNonce"` + NewAccount string `json:"newAccount"` + NewOrder string `json:"newOrder"` + RevokeCert string `json:"revokeCert"` + KeyChange string `json:"keyChange"` +} + +// ToLog enables response logging for the Directory type. +func (d *Directory) ToLog() (interface{}, error) { + b, err := json.Marshal(d) + if err != nil { + return nil, acme.WrapErrorISE(err, "error marshaling directory for logging") + } + return string(b), nil +} + // GetDirectory is the ACME resource for returning a directory configuration // for client configuration. func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { - dir, err := h.Auth.GetDirectory(r.Context()) - if err != nil { - api.WriteError(w, err) - return - } - api.JSON(w, dir) + ctx := r.Context() + api.JSON(w, &Directory{ + NewNonce: h.linker.GetLink(ctx, NewNonceLinkType, true), + NewAccount: h.linker.GetLink(ctx, NewAccountLinkType, true), + NewOrder: h.linker.GetLink(ctx, NewOrderLinkType, true), + RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType, true), + KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType, true), + }) } // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, acme.NotImplemented(nil).ToACME()) + api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } -// GetAuthz ACME api for retrieving an Authz. -func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) +// GetAuthorization ACME api for retrieving an Authz. +func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - authz, err := h.Auth.GetAuthz(r.Context(), acc.GetID(), chi.URLParam(r, "authzID")) + az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { - api.WriteError(w, err) + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization")) + return + } + if acc.ID != az.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own authorization '%s'", acc.ID, az.ID)) + return + } + if err = az.UpdateStatus(ctx, h.db); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) return } - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AuthzLink, true, authz.GetID())) - api.JSON(w, authz) + h.linker.LinkAuthorization(ctx, az) + + w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, true, az.ID)) + api.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } // Just verify that the payload was set, since we're not strictly adhering // to ACME V2 spec for reasons specified below. - _, err = payloadFromContext(r.Context()) + _, err = payloadFromContext(ctx) if err != nil { api.WriteError(w, err) return } - // NOTE: We should be checking that the request is either a POST-as-GET, or + // NOTE: We should be checking ^^^ that the request is either a POST-as-GET, or // that the payload is an empty JSON block ({}). However, older ACME clients // still send a vestigial body (rather than an empty JSON block) and // strict enforcement would render these clients broken. For the time being // we'll just ignore the body. - var ( - ch *acme.Challenge - chID = chi.URLParam(r, "chID") - ) - ch, err = h.Auth.ValidateChallenge(r.Context(), acc.GetID(), chID, acc.GetKey()) + + azID := chi.URLParam(r, "authzID") + ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) + return + } + ch.AuthorizationID = azID + if acc.ID != ch.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) + return + } + jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return } + if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) + return + } - w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), acme.AuthzLink, true, ch.GetAuthzID()), "up")) - w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.ChallengeLink, true, ch.GetID())) + h.linker.LinkChallenge(ctx, ch, azID) + + w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, true, azID), "up")) + w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, true, azID, ch.ID)) api.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { - acc, err := acme.AccountFromContext(r.Context()) + ctx := r.Context() + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } certID := chi.URLParam(r, "certID") - certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID) + + cert, err := h.db.GetCertificate(ctx, certID) if err != nil { - api.WriteError(w, err) + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate")) + return + } + if cert.AccountID != acc.ID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own certificate '%s'", acc.ID, certID)) return } - block, _ := pem.Decode(certBytes) - if block == nil { - api.WriteError(w, acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes"))) - return - } - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - api.WriteError(w, acme.Wrap(err, "failed to parse generated leaf certificate")) - return + var certBytes []byte + for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) { + certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: c.Raw, + })...) } - api.LogCertificate(w, cert) + api.LogCertificate(w, cert.Leaf) w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8") w.Write(certBytes) } diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 7e19ea75..5501479d 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -8,6 +8,7 @@ import ( "encoding/pem" "fmt" "io/ioutil" + "net/http" "net/http/httptest" "net/url" "testing" @@ -17,206 +18,11 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) -type mockAcmeAuthority struct { - getLink func(ctx context.Context, link acme.Link, absPath bool, ins ...string) string - getLinkExplicit func(acme.Link, string, bool, *url.URL, ...string) string - - deactivateAccount func(ctx context.Context, accID string) (*acme.Account, error) - getAccount func(ctx context.Context, accID string) (*acme.Account, error) - getAccountByKey func(ctx context.Context, key *jose.JSONWebKey) (*acme.Account, error) - newAccount func(ctx context.Context, ao acme.AccountOptions) (*acme.Account, error) - updateAccount func(context.Context, string, []string) (*acme.Account, error) - - getChallenge func(ctx context.Context, accID string, chID string) (*acme.Challenge, error) - validateChallenge func(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*acme.Challenge, error) - getAuthz func(ctx context.Context, accID string, authzID string) (*acme.Authz, error) - getDirectory func(ctx context.Context) (*acme.Directory, error) - getCertificate func(string, string) ([]byte, error) - - finalizeOrder func(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*acme.Order, error) - getOrder func(ctx context.Context, accID string, orderID string) (*acme.Order, error) - getOrdersByAccount func(ctx context.Context, accID string) ([]string, error) - newOrder func(ctx context.Context, oo acme.OrderOptions) (*acme.Order, error) - - loadProvisionerByID func(string) (provisioner.Interface, error) - newNonce func() (string, error) - useNonce func(string) error - ret1 interface{} - err error -} - -func (m *mockAcmeAuthority) DeactivateAccount(ctx context.Context, id string) (*acme.Account, error) { - if m.deactivateAccount != nil { - return m.deactivateAccount(ctx, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) FinalizeOrder(ctx context.Context, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) { - if m.finalizeOrder != nil { - return m.finalizeOrder(ctx, accID, id, csr) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Order), m.err -} - -func (m *mockAcmeAuthority) GetAccount(ctx context.Context, id string) (*acme.Account, error) { - if m.getAccount != nil { - return m.getAccount(ctx, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - if m.getAccountByKey != nil { - return m.getAccountByKey(ctx, jwk) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) GetAuthz(ctx context.Context, accID, id string) (*acme.Authz, error) { - if m.getAuthz != nil { - return m.getAuthz(ctx, accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Authz), m.err -} - -func (m *mockAcmeAuthority) GetCertificate(accID string, id string) ([]byte, error) { - if m.getCertificate != nil { - return m.getCertificate(accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.([]byte), m.err -} - -func (m *mockAcmeAuthority) GetChallenge(ctx context.Context, accID, id string) (*acme.Challenge, error) { - if m.getChallenge != nil { - return m.getChallenge(ctx, accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Challenge), m.err -} - -func (m *mockAcmeAuthority) GetDirectory(ctx context.Context) (*acme.Directory, error) { - if m.getDirectory != nil { - return m.getDirectory(ctx) - } - return m.ret1.(*acme.Directory), m.err -} - -func (m *mockAcmeAuthority) GetLink(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - if m.getLink != nil { - return m.getLink(ctx, typ, abs, ins...) - } - return m.ret1.(string) -} - -func (m *mockAcmeAuthority) GetLinkExplicit(typ acme.Link, provID string, abs bool, baseURL *url.URL, ins ...string) string { - if m.getLinkExplicit != nil { - return m.getLinkExplicit(typ, provID, abs, baseURL, ins...) - } - return m.ret1.(string) -} - -func (m *mockAcmeAuthority) GetOrder(ctx context.Context, accID, id string) (*acme.Order, error) { - if m.getOrder != nil { - return m.getOrder(ctx, accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Order), m.err -} - -func (m *mockAcmeAuthority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { - if m.getOrdersByAccount != nil { - return m.getOrdersByAccount(ctx, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.([]string), m.err -} - -func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) { - if m.loadProvisionerByID != nil { - return m.loadProvisionerByID(provID) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *mockAcmeAuthority) NewAccount(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { - if m.newAccount != nil { - return m.newAccount(ctx, ops) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) NewNonce() (string, error) { - if m.newNonce != nil { - return m.newNonce() - } else if m.err != nil { - return "", m.err - } - return m.ret1.(string), m.err -} - -func (m *mockAcmeAuthority) NewOrder(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { - if m.newOrder != nil { - return m.newOrder(ctx, ops) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Order), m.err -} - -func (m *mockAcmeAuthority) UpdateAccount(ctx context.Context, id string, contact []string) (*acme.Account, error) { - if m.updateAccount != nil { - return m.updateAccount(ctx, id, contact) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) UseNonce(nonce string) error { - if m.useNonce != nil { - return m.useNonce(nonce) - } - return m.err -} - -func (m *mockAcmeAuthority) ValidateChallenge(ctx context.Context, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { - switch { - case m.validateChallenge != nil: - return m.validateChallenge(ctx, accID, id, jwk) - case m.err != nil: - return nil, m.err - default: - return m.ret1.(*acme.Challenge), m.err - } -} - -func TestHandlerGetNonce(t *testing.T) { +func TestHandler_GetNonce(t *testing.T) { tests := []struct { name string statusCode int @@ -230,7 +36,7 @@ func TestHandlerGetNonce(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} w := httptest.NewRecorder() req.Method = tt.name h.GetNonce(w, req) @@ -243,21 +49,16 @@ func TestHandlerGetNonce(t *testing.T) { } } -func TestHandlerGetDirectory(t *testing.T) { - auth, err := acme.New(nil, acme.AuthorityOptions{ - DB: new(db.MockNoSQLDB), - DNS: "ca.smallstep.com", - Prefix: "acme", - }) - assert.FatalError(t, err) +func TestHandler_GetDirectory(t *testing.T) { + linker := NewLinker("ca.smallstep.com", "acme") prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - expDir := acme.Directory{ + expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), @@ -267,7 +68,7 @@ func TestHandlerGetDirectory(t *testing.T) { type test struct { statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { @@ -279,7 +80,7 @@ func TestHandlerGetDirectory(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(auth).(*Handler) + h := &Handler{linker: linker} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(ctx) w := httptest.NewRecorder() @@ -292,18 +93,17 @@ func TestHandlerGetDirectory(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { - var dir acme.Directory + var dir Directory json.Unmarshal(bytes.TrimSpace(body), &dir) assert.Equals(t, dir, expDir) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) @@ -312,33 +112,32 @@ func TestHandlerGetDirectory(t *testing.T) { } } -func TestHandlerGetAuthz(t *testing.T) { +func TestHandler_GetAuthorization(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) - az := acme.Authz{ - ID: "authzID", + az := acme.Authorization{ + ID: "authzID", + AccountID: "accID", Identifier: acme.Identifier{ Type: "dns", Value: "example.com", }, - Status: "pending", - Expires: expiry.Format(time.RFC3339), - Wildcard: false, + Status: "pending", + ExpiresAt: expiry, + Wildcard: false, Challenges: []*acme.Challenge{ { - Type: "http-01", - Status: "pending", - Token: "tok2", - URL: "https://ca.smallstep.com/acme/challenge/chHTTPID", - ID: "chHTTP01ID", - AuthzID: "authzID", + Type: "http-01", + Status: "pending", + Token: "tok2", + URL: "https://ca.smallstep.com/acme/challenge/chHTTPID", + ID: "chHTTP01ID", }, { - Type: "dns-01", - Status: "pending", - Token: "tok2", - URL: "https://ca.smallstep.com/acme/challenge/chDNSID", - ID: "chDNSID", - AuthzID: "authzID", + Type: "dns-01", + Status: "pending", + Token: "tok2", + URL: "https://ca.smallstep.com/acme/challenge/chDNSID", + ID: "chDNSID", }, }, } @@ -349,71 +148,101 @@ func TestHandlerGetAuthz(t *testing.T) { // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("authzID", az.ID) - url := fmt.Sprintf("%s/acme/%s/challenge/%s", + url := fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, az.ID) type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/getAuthz-error": func(t *testing.T) test { + "fail/db.GetAuthorization-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), + } + }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { + assert.Equals(t, id, az.ID) + return &acme.Authorization{ + AccountID: "foo", + }, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), + } + }, + "fail/db.UpdateAuthorization-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { + assert.Equals(t, id, az.ID) + return &acme.Authorization{ + AccountID: "accID", + Status: acme.StatusPending, + ExpiresAt: time.Now().Add(-1 * time.Hour), + }, nil + }, + MockUpdateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + assert.Equals(t, az.Status, acme.StatusInvalid) + return acme.NewErrorISE("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getAuthz: func(ctx context.Context, accID, id string) (*acme.Authz, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { assert.Equals(t, id, az.ID) return &az, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AuthzLink) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.True(t, abs) - assert.Equals(t, in, []string{az.ID}) - return url - }, }, ctx: ctx, statusCode: 200, @@ -423,11 +252,11 @@ func TestHandlerGetAuthz(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAuthz(w, req) + h.GetAuthorization(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -436,15 +265,14 @@ func TestHandlerGetAuthz(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { //var gotAz acme.Authz @@ -459,7 +287,7 @@ func TestHandlerGetAuthz(t *testing.T) { } } -func TestHandlerGetCertificate(t *testing.T) { +func TestHandler_GetCertificate(t *testing.T) { leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt") assert.FatalError(t, err) inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt") @@ -490,89 +318,73 @@ func TestHandlerGetCertificate(t *testing.T) { baseURL.String(), provName, certID) type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/getCertificate-error": func(t *testing.T) test { + "fail/db.GetCertificate-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, - "fail/decode-leaf-for-loggger": func(t *testing.T) test { + "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - getCertificate: func(accID, id string) ([]byte, error) { - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) - return []byte("foo"), nil + return &acme.Certificate{AccountID: "foo"}, nil }, }, ctx: ctx, - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes")), - } - }, - "fail/parse-x509-leaf-for-logger": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - return test{ - auth: &mockAcmeAuthority{ - getCertificate: func(accID, id string) ([]byte, error) { - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, certID) - return pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE REQUEST", - Bytes: []byte("foo"), - }), nil - }, - }, - ctx: ctx, - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("failed to parse generated leaf certificate")), + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - getCertificate: func(accID, id string) ([]byte, error) { - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) - return certBytes, nil + return &acme.Certificate{ + AccountID: "accID", + OrderID: "ordID", + Leaf: leaf, + Intermediates: []*x509.Certificate{inter, root}, + ID: id, + }, nil }, }, ctx: ctx, @@ -583,7 +395,7 @@ func TestHandlerGetCertificate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -596,15 +408,14 @@ func TestHandlerGetCertificate(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.HasPrefix(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.HasPrefix(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes)) @@ -614,152 +425,233 @@ func TestHandlerGetCertificate(t *testing.T) { } } -func ch() acme.Challenge { - return acme.Challenge{ - Type: "http-01", - Status: "pending", - Token: "tok2", - URL: "https://ca.smallstep.com/acme/challenge/chID", - ID: "chID", - AuthzID: "authzID", - } -} - -func TestHandlerGetChallenge(t *testing.T) { +func TestHandler_GetChallenge(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("chID", "chID") + chiCtx.URLParams.Add("authzID", "authzID") prov := newProv() provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/challenge/%s", baseURL, "chID") + + url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s", + baseURL.String(), provName, "authzID", "chID") type test struct { - auth acme.Interface + db acme.DB + vco *acme.ValidateChallengeOptions ctx context.Context statusCode int - ch acme.Challenge - problem *acme.Error + ch *acme.Challenge + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ - ctx: ctx, + ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), + } + }, + "fail/db.GetChallenge-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return nil, acme.NewErrorISE("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return &acme.Challenge{AccountID: "foo"}, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "accout id mismatch"), + } + }, + "fail/no-jwk": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return &acme.Challenge{AccountID: "accID"}, nil + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("missing jwk"), + } + }, + "fail/nil-jwk": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, jwkContextKey, nil) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return &acme.Challenge{AccountID: "accID"}, nil + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("nil jwk"), } }, "fail/validate-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - return test{ - auth: &mockAcmeAuthority{ - err: acme.UnauthorizedErr(nil), - }, - ctx: ctx, - statusCode: 401, - problem: acme.UnauthorizedErr(nil), - } - }, - "fail/get-challenge-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - return test{ - auth: &mockAcmeAuthority{ - err: acme.UnauthorizedErr(nil), - }, - ctx: ctx, - statusCode: 401, - problem: acme.UnauthorizedErr(nil), - } - }, - "ok/validate-challenge": func(t *testing.T) test { - key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - acc := &acme.Account{ID: "accID", Key: key} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true}) + _pub := _jwk.Public() + ctx = context.WithValue(ctx, jwkContextKey, &_pub) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) - ch := ch() - ch.Status = "valid" - ch.Validated = time.Now().UTC().Format(time.RFC3339) - count := 0 return test{ - auth: &mockAcmeAuthority{ - validateChallenge: func(ctx context.Context, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, ch.ID) - assert.Equals(t, jwk.KeyID, key.KeyID) - return &ch, nil + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return &acme.Challenge{ + Status: acme.StatusPending, + Type: "http-01", + AccountID: "accID", + }, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - var ret string - switch count { - case 0: - assert.Equals(t, typ, acme.AuthzLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.AuthzID}) - ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) - case 1: - assert.Equals(t, typ, acme.ChallengeLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.ID}) - ret = url - } - count++ - return ret + MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.AccountID, "accID") + assert.Equals(t, ch.AuthorizationID, "authzID") + assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String()) + return acme.NewErrorISE("force") + }, + }, + vco: &acme.ValidateChallengeOptions{ + HTTPGet: func(string) (*http.Response, error) { + return nil, errors.New("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + _pub := _jwk.Public() + ctx = context.WithValue(ctx, jwkContextKey, &_pub) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, "chID") + assert.Equals(t, azID, "authzID") + return &acme.Challenge{ + ID: "chID", + Status: acme.StatusPending, + Type: "http-01", + AccountID: "accID", + }, nil + }, + MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.AccountID, "accID") + assert.Equals(t, ch.AuthorizationID, "authzID") + assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String()) + return nil + }, + }, + ch: &acme.Challenge{ + ID: "chID", + Status: acme.StatusPending, + AuthorizationID: "authzID", + Type: "http-01", + AccountID: "accID", + URL: url, + Error: acme.NewError(acme.ErrorConnectionType, "force"), + }, + vco: &acme.ValidateChallengeOptions{ + HTTPGet: func(string) (*http.Response, error) { + return nil, errors.New("force") }, }, ctx: ctx, statusCode: 200, - ch: ch, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -772,21 +664,20 @@ func TestHandlerGetChallenge(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(tc.ch) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) - assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, tc.ch.AuthzID)}) + assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")}) assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } diff --git a/acme/api/linker.go b/acme/api/linker.go new file mode 100644 index 00000000..702f7433 --- /dev/null +++ b/acme/api/linker.go @@ -0,0 +1,182 @@ +package api + +import ( + "context" + "fmt" + "net/url" + + "github.com/smallstep/certificates/acme" +) + +// NewLinker returns a new Directory type. +func NewLinker(dns, prefix string) Linker { + return &linker{prefix: prefix, dns: dns} +} + +// Linker interface for generating links for ACME resources. +type Linker interface { + GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string + GetLinkExplicit(typ LinkType, provName string, abs bool, baseURL *url.URL, inputs ...string) string + + LinkOrder(ctx context.Context, o *acme.Order) + LinkAccount(ctx context.Context, o *acme.Account) + LinkChallenge(ctx context.Context, o *acme.Challenge, azID string) + LinkAuthorization(ctx context.Context, o *acme.Authorization) + LinkOrdersByAccountID(ctx context.Context, orders []string) +} + +// linker generates ACME links. +type linker struct { + prefix string + dns string +} + +// GetLink is a helper for GetLinkExplicit +func (l *linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string { + var provName string + if p, err := provisionerFromContext(ctx); err == nil && p != nil { + provName = p.GetName() + } + return l.GetLinkExplicit(typ, provName, abs, baseURLFromContext(ctx), inputs...) +} + +// GetLinkExplicit returns an absolute or partial path to the given resource and a base +// URL dynamically obtained from the request for which the link is being +// calculated. +func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { + var u = url.URL{} + // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 + if baseURL != nil { + u = *baseURL + } + + switch typ { + case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: + u.Path = fmt.Sprintf("/%s/%s", provisionerName, typ) + case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: + u.Path = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) + case ChallengeLinkType: + u.Path = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) + case OrdersByAccountLinkType: + u.Path = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) + case FinalizeLinkType: + u.Path = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) + } + + if abs { + // If no Scheme is set, then default to https. + if u.Scheme == "" { + u.Scheme = "https" + } + + // If no Host is set, then use the default (first DNS attr in the ca.json). + if u.Host == "" { + u.Host = l.dns + } + + u.Path = l.prefix + u.Path + return u.String() + } + return u.EscapedPath() +} + +// LinkType captures the link type. +type LinkType int + +const ( + // NewNonceLinkType new-nonce + NewNonceLinkType LinkType = iota + // NewAccountLinkType new-account + NewAccountLinkType + // AccountLinkType account + AccountLinkType + // OrderLinkType order + OrderLinkType + // NewOrderLinkType new-order + NewOrderLinkType + // OrdersByAccountLinkType list of orders owned by account + OrdersByAccountLinkType + // FinalizeLinkType finalize order + FinalizeLinkType + // NewAuthzLinkType authz + NewAuthzLinkType + // AuthzLinkType new-authz + AuthzLinkType + // ChallengeLinkType challenge + ChallengeLinkType + // CertificateLinkType certificate + CertificateLinkType + // DirectoryLinkType directory + DirectoryLinkType + // RevokeCertLinkType revoke certificate + RevokeCertLinkType + // KeyChangeLinkType key rollover + KeyChangeLinkType +) + +func (l LinkType) String() string { + switch l { + case NewNonceLinkType: + return "new-nonce" + case NewAccountLinkType: + return "new-account" + case AccountLinkType: + return "account" + case NewOrderLinkType: + return "new-order" + case OrderLinkType: + return "order" + case NewAuthzLinkType: + return "new-authz" + case AuthzLinkType: + return "authz" + case ChallengeLinkType: + return "challenge" + case CertificateLinkType: + return "certificate" + case DirectoryLinkType: + return "directory" + case RevokeCertLinkType: + return "revoke-cert" + case KeyChangeLinkType: + return "key-change" + default: + return fmt.Sprintf("unexpected LinkType '%d'", int(l)) + } +} + +// LinkOrder sets the ACME links required by an ACME order. +func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { + o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) + for i, azID := range o.AuthorizationIDs { + o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, true, azID) + } + o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, true, o.ID) + if o.CertificateID != "" { + o.CertificateURL = l.GetLink(ctx, CertificateLinkType, true, o.CertificateID) + } +} + +// LinkAccount sets the ACME links required by an ACME account. +func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { + acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) +} + +// LinkChallenge sets the ACME links required by an ACME challenge. +func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { + ch.URL = l.GetLink(ctx, ChallengeLinkType, true, azID, ch.ID) +} + +// LinkAuthorization sets the ACME links required by an ACME authorization. +func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { + for _, ch := range az.Challenges { + l.LinkChallenge(ctx, ch, az.ID) + } +} + +// LinkOrdersByAccountID converts each order ID to an ACME link. +func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) { + for i, id := range orders { + orders[i] = l.GetLink(ctx, OrderLinkType, true, id) + } +} diff --git a/acme/api/linker_test.go b/acme/api/linker_test.go new file mode 100644 index 00000000..6bb1f739 --- /dev/null +++ b/acme/api/linker_test.go @@ -0,0 +1,302 @@ +package api + +import ( + "context" + "fmt" + "net/url" + "testing" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" +) + +func TestLinker_GetLink(t *testing.T) { + dns := "ca.smallstep.com" + prefix := "acme" + linker := NewLinker(dns, prefix) + id := "1234" + + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + + assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, true), + fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName)) + assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName)) + + // No provisioner + ctxNoProv := context.WithValue(context.Background(), baseURLContextKey, baseURL) + assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, true), + fmt.Sprintf("%s/acme//new-nonce", baseURL.String())) + assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, false), "//new-nonce") + + // No baseURL + ctxNoBaseURL := context.WithValue(context.Background(), provisionerContextKey, prov) + assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, true), + fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName)) + assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName)) + + assert.Equals(t, linker.GetLink(ctx, OrderLinkType, true, id), + fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName)) + assert.Equals(t, linker.GetLink(ctx, OrderLinkType, false, id), fmt.Sprintf("/%s/order/1234", provName)) +} + +func TestLinker_GetLinkExplicit(t *testing.T) { + dns := "ca.smallstep.com" + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prefix := "acme" + linker := NewLinker(dns, prefix) + id := "1234" + + prov := newProv() + provName := prov.GetName() + escProvName := url.PathEscape(provName) + + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-nonce", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-account", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-order", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-authz", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, false, baseURL), fmt.Sprintf("/%s/directory", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, false, baseURL), fmt.Sprintf("/%s/revoke-cert", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, false, baseURL), fmt.Sprintf("/%s/key-change", escProvName)) + + assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, true, baseURL, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id)) + assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, false, baseURL, id, id), fmt.Sprintf("/%s/challenge/%s/%s", escProvName, id, id)) + + assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName)) + assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", escProvName)) +} + +func TestLinker_LinkOrder(t *testing.T) { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prov := newProv() + provName := url.PathEscape(prov.GetName()) + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + + oid := "orderID" + certID := "certID" + linkerPrefix := "acme" + l := NewLinker("dns", linkerPrefix) + type test struct { + o *acme.Order + validate func(o *acme.Order) + } + var tests = map[string]test{ + "no-authz-and-no-cert": { + o: &acme.Order{ + ID: oid, + }, + validate: func(o *acme.Order) { + assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) + assert.Equals(t, o.AuthorizationURLs, []string{}) + assert.Equals(t, o.CertificateURL, "") + }, + }, + "one-authz-and-cert": { + o: &acme.Order{ + ID: oid, + CertificateID: certID, + AuthorizationIDs: []string{"foo"}, + }, + validate: func(o *acme.Order) { + assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) + assert.Equals(t, o.AuthorizationURLs, []string{ + fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), + }) + assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID)) + }, + }, + "many-authz": { + o: &acme.Order{ + ID: oid, + CertificateID: certID, + AuthorizationIDs: []string{"foo", "bar", "zap"}, + }, + validate: func(o *acme.Order) { + assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) + assert.Equals(t, o.AuthorizationURLs, []string{ + fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), + fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "bar"), + fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "zap"), + }) + assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID)) + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + l.LinkOrder(ctx, tc.o) + tc.validate(tc.o) + }) + } +} + +func TestLinker_LinkAccount(t *testing.T) { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prov := newProv() + provName := url.PathEscape(prov.GetName()) + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + + accID := "accountID" + linkerPrefix := "acme" + l := NewLinker("dns", linkerPrefix) + type test struct { + a *acme.Account + validate func(o *acme.Account) + } + var tests = map[string]test{ + "ok": { + a: &acme.Account{ + ID: accID, + }, + validate: func(a *acme.Account) { + assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID)) + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + l.LinkAccount(ctx, tc.a) + tc.validate(tc.a) + }) + } +} + +func TestLinker_LinkChallenge(t *testing.T) { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prov := newProv() + provName := url.PathEscape(prov.GetName()) + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + + chID := "chID" + azID := "azID" + linkerPrefix := "acme" + l := NewLinker("dns", linkerPrefix) + type test struct { + ch *acme.Challenge + validate func(o *acme.Challenge) + } + var tests = map[string]test{ + "ok": { + ch: &acme.Challenge{ + ID: chID, + }, + validate: func(ch *acme.Challenge) { + assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID)) + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + l.LinkChallenge(ctx, tc.ch, azID) + tc.validate(tc.ch) + }) + } +} + +func TestLinker_LinkAuthorization(t *testing.T) { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prov := newProv() + provName := url.PathEscape(prov.GetName()) + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + + chID0 := "chID-0" + chID1 := "chID-1" + chID2 := "chID-2" + azID := "azID" + linkerPrefix := "acme" + l := NewLinker("dns", linkerPrefix) + type test struct { + az *acme.Authorization + validate func(o *acme.Authorization) + } + var tests = map[string]test{ + "ok": { + az: &acme.Authorization{ + ID: azID, + Challenges: []*acme.Challenge{ + {ID: chID0}, + {ID: chID1}, + {ID: chID2}, + }, + }, + validate: func(az *acme.Authorization) { + assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0)) + assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1)) + assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2)) + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + l.LinkAuthorization(ctx, tc.az) + tc.validate(tc.az) + }) + } +} + +func TestLinker_LinkOrdersByAccountID(t *testing.T) { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prov := newProv() + provName := url.PathEscape(prov.GetName()) + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + + linkerPrefix := "acme" + l := NewLinker("dns", linkerPrefix) + type test struct { + oids []string + } + var tests = map[string]test{ + "ok": { + oids: []string{"foo", "bar", "baz"}, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + l.LinkOrdersByAccountID(ctx, tc.oids) + assert.Equals(t, tc.oids, []string{ + fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "foo"), + fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "bar"), + fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "baz"), + }) + }) + } +} diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 3bf5f89a..861876a9 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -3,13 +3,13 @@ package api import ( "context" "crypto/rsa" + "errors" "io/ioutil" "net/http" "net/url" "strings" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/provisioner" @@ -54,7 +54,7 @@ func baseURLFromRequest(r *http.Request) *url.URL { // E.g. https://ca.smallstep.com/ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), acme.BaseURLContextKey, baseURLFromRequest(r)) + ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r)) next(w, r.WithContext(ctx)) } } @@ -62,14 +62,14 @@ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { // addNonce is a middleware that adds a nonce to the response header. func (h *Handler) addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - nonce, err := h.Auth.NewNonce() + nonce, err := h.db.CreateNonce(r.Context()) if err != nil { api.WriteError(w, err) return } - w.Header().Set("Replay-Nonce", nonce) + w.Header().Set("Replay-Nonce", string(nonce)) w.Header().Set("Cache-Control", "no-store") - logNonce(w, nonce) + logNonce(w, string(nonce)) next(w, r) } } @@ -78,8 +78,8 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { // directory index url. func (h *Handler) addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), - acme.DirectoryLink, true), "index")) + w.Header().Add("Link", link(h.linker.GetLink(r.Context(), + DirectoryLinkType, true), "index")) next(w, r) } } @@ -90,7 +90,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ct := r.Header.Get("Content-Type") var expected []string - if strings.Contains(r.URL.Path, h.Auth.GetLink(r.Context(), acme.CertificateLink, false, "")) { + if strings.Contains(r.URL.String(), h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} } else { @@ -103,8 +103,8 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return } } - api.WriteError(w, acme.MalformedErr(errors.Errorf( - "expected content-type to be in %s, but got %s", expected, ct))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "expected content-type to be in %s, but got %s", expected, ct)) } } @@ -113,15 +113,15 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body"))) + api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) return } jws, err := jose.ParseJWS(string(body)) if err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) return } - ctx := context.WithValue(r.Context(), acme.JwsContextKey, jws) + ctx := context.WithValue(r.Context(), jwsContextKey, jws) next(w, r.WithContext(ctx)) } } @@ -143,17 +143,18 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below func (h *Handler) validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - jws, err := acme.JwsFromContext(r.Context()) + ctx := r.Context() + jws, err := jwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } if len(jws.Signatures) == 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) return } if len(jws.Signatures) > 1 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) return } @@ -164,35 +165,36 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { len(uh.Algorithm) > 0 || len(uh.Nonce) > 0 || len(uh.ExtraHeaders) > 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) return } hdr := sig.Protected switch hdr.Algorithm { - case jose.RS256, jose.RS384, jose.RS512: + case jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512: if hdr.JSONWebKey != nil { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { - api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+ - "keys must be at least %d bits (%d bytes) in size", - 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "rsa keys must be at least %d bits (%d bytes) in size", + 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) return } default: - api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "jws key type and algorithm do not match")) return } } case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: - api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm))) + api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) return } // Check the validity/freshness of the Nonce. - if err := h.Auth.UseNonce(hdr.Nonce); err != nil { + if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { api.WriteError(w, err) return } @@ -200,21 +202,22 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { // Check that the JWS url matches the requested url. jwsURL, ok := hdr.ExtraHeaders["url"].(string) if !ok { - api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) return } reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} if jwsURL != reqURL.String() { - api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) return } if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 { - api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } next(w, r) @@ -227,24 +230,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { func (h *Handler) extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := acme.JwsFromContext(r.Context()) + jws, err := jwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } jwk := jws.Signatures[0].Protected.JSONWebKey if jwk == nil { - api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) return } if !jwk.Valid() { - api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) return } - ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) - acc, err := h.Auth.GetAccountByKey(ctx, jwk) + + // Overwrite KeyID with the JWK thumbprint. + jwk.KeyID, err = acme.KeyToID(jwk) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) + return + } + + // Store the JWK in the context. + ctx = context.WithValue(ctx, jwkContextKey, jwk) + + // Get Account or continue to generate a new one. + acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) switch { - case nosql.IsErrNotFound(err): + case errors.Is(err, acme.ErrNotFound): // For NewAccount requests ... break case err != nil: @@ -252,10 +266,10 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { return default: if !acc.IsValid() { - api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, accContextKey, acc) } next(w, r.WithContext(ctx)) } @@ -270,20 +284,20 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { name := chi.URLParam(r, "provisionerID") provID, err := url.PathUnescape(name) if err != nil { - api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name))) + api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner id '%s'", name)) return } - p, err := h.Auth.LoadProvisionerByID("acme/" + provID) + p, err := h.ca.LoadProvisionerByID("acme/" + provID) if err != nil { api.WriteError(w, err) return } acmeProv, ok := p.(*provisioner.ACME) if !ok { - api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME"))) + api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) return } - ctx = context.WithValue(ctx, acme.ProvisionerContextKey, acme.Provisioner(acmeProv)) + ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) next(w, r.WithContext(ctx)) } } @@ -294,36 +308,37 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := acme.JwsFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { api.WriteError(w, err) return } - kidPrefix := h.Auth.GetLink(ctx, acme.AccountLink, true, "") + kidPrefix := h.linker.GetLink(ctx, AccountLinkType, true, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { - api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+ - "required prefix; expected %s, but got %s", kidPrefix, kid))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + "kid does not have required prefix; expected %s, but got %s", + kidPrefix, kid)) return } accID := strings.TrimPrefix(kid, kidPrefix) - acc, err := h.Auth.GetAccount(r.Context(), accID) + acc, err := h.db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): - api.WriteError(w, acme.AccountDoesNotExistErr(nil)) + api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: api.WriteError(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.JwkContextKey, acc.Key) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, jwkContextKey, acc.Key) next(w, r.WithContext(ctx)) return } @@ -334,26 +349,27 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { // Make sure to parse and validate the JWS before running this middleware. func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - jws, err := acme.JwsFromContext(r.Context()) + ctx := r.Context() + jws, err := jwsFromContext(ctx) if err != nil { api.WriteError(w, err) return } - jwk, err := acme.JwkFromContext(r.Context()) + jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return } if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { - api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } payload, err := jws.Verify(jwk) if err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) return } - ctx := context.WithValue(r.Context(), acme.PayloadContextKey, &payloadInfo{ + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ value: payload, isPostAsGet: string(payload) == "", isEmptyJSON: string(payload) == "{}", @@ -371,9 +387,89 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { return } if !payload.isPostAsGet { - api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET"))) + api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) return } next(w, r) } } + +// ContextKey is the key type for storing and searching for ACME request +// essentials in the context of a request. +type ContextKey string + +const ( + // accContextKey account key + accContextKey = ContextKey("acc") + // baseURLContextKey baseURL key + baseURLContextKey = ContextKey("baseURL") + // jwsContextKey jws key + jwsContextKey = ContextKey("jws") + // jwkContextKey jwk key + jwkContextKey = ContextKey("jwk") + // payloadContextKey payload key + payloadContextKey = ContextKey("payload") + // provisionerContextKey provisioner key + provisionerContextKey = ContextKey("provisioner") +) + +// accountFromContext searches the context for an ACME account. Returns the +// account or an error. +func accountFromContext(ctx context.Context) (*acme.Account, error) { + val, ok := ctx.Value(accContextKey).(*acme.Account) + if !ok || val == nil { + return nil, acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context") + } + return val, nil +} + +// baseURLFromContext returns the baseURL if one is stored in the context. +func baseURLFromContext(ctx context.Context) *url.URL { + val, ok := ctx.Value(baseURLContextKey).(*url.URL) + if !ok || val == nil { + return nil + } + return val +} + +// jwkFromContext searches the context for a JWK. Returns the JWK or an error. +func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { + val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) + if !ok || val == nil { + return nil, acme.NewErrorISE("jwk expected in request context") + } + return val, nil +} + +// jwsFromContext searches the context for a JWS. Returns the JWS or an error. +func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { + val, ok := ctx.Value(jwsContextKey).(*jose.JSONWebSignature) + if !ok || val == nil { + return nil, acme.NewErrorISE("jws expected in request context") + } + return val, nil +} + +// provisionerFromContext searches the context for a provisioner. Returns the +// provisioner or an error. +func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { + val := ctx.Value(provisionerContextKey) + if val == nil { + return nil, acme.NewErrorISE("provisioner expected in request context") + } + pval, ok := val.(acme.Provisioner) + if !ok || pval == nil { + return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") + } + return pval, nil +} + +// payloadFromContext searches the context for a payload. Returns the payload +// or an error. +func payloadFromContext(ctx context.Context) (*payloadInfo, error) { + val, ok := ctx.Value(payloadContextKey).(*payloadInfo) + if !ok || val == nil { + return nil, acme.NewErrorISE("payload expected in request context") + } + return val, nil +} diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index d2a9cdc0..4c316910 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -81,14 +81,14 @@ func Test_baseURLFromRequest(t *testing.T) { } } -func TestHandlerBaseURLFromRequest(t *testing.T) { - h := New(&mockAcmeAuthority{}).(*Handler) +func TestHandler_baseURLFromRequest(t *testing.T) { + h := &Handler{} req := httptest.NewRequest("GET", "/foo", nil) req.Host = "test.ca.smallstep.com:8080" w := httptest.NewRecorder() next := func(w http.ResponseWriter, r *http.Request) { - bu := acme.BaseURLFromContext(r.Context()) + bu := baseURLFromContext(r.Context()) if assert.NotNil(t, bu) { assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080") assert.Equals(t, bu.Scheme, "https") @@ -101,35 +101,35 @@ func TestHandlerBaseURLFromRequest(t *testing.T) { req.Host = "" next = func(w http.ResponseWriter, r *http.Request) { - assert.Equals(t, acme.BaseURLFromContext(r.Context()), nil) + assert.Equals(t, baseURLFromContext(r.Context()), nil) } h.baseURLFromRequest(next)(w, req) } -func TestHandlerAddNonce(t *testing.T) { +func TestHandler_addNonce(t *testing.T) { url := "https://ca.smallstep.com/acme/new-nonce" type test struct { - auth acme.Interface - problem *acme.Error + db acme.DB + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/AddNonce-error": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{ - newNonce: func() (string, error) { - return "", acme.ServerInternalErr(errors.New("force")) + db: &acme.MockDB{ + MockCreateNonce: func(ctx context.Context) (acme.Nonce, error) { + return acme.Nonce(""), acme.NewErrorISE("force") }, }, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{ - newNonce: func() (string, error) { + db: &acme.MockDB{ + MockCreateNonce: func(ctx context.Context) (acme.Nonce, error) { return "bar", nil }, }, @@ -140,7 +140,7 @@ func TestHandlerAddNonce(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) w := httptest.NewRecorder() h.addNonce(testNext)(w, req) @@ -152,15 +152,14 @@ func TestHandlerAddNonce(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, res.Header["Replay-Nonce"], []string{"bar"}) @@ -171,28 +170,23 @@ func TestHandlerAddNonce(t *testing.T) { } } -func TestHandlerAddDirLink(t *testing.T) { +func TestHandler_addDirLink(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { - auth acme.Interface link string + linker Linker statusCode int ctx context.Context - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - return fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName) - }, - }, + linker: NewLinker("dns", "acme"), ctx: ctx, link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), statusCode: 200, @@ -202,7 +196,7 @@ func TestHandlerAddDirLink(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{linker: tc.linker} req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -215,15 +209,14 @@ func TestHandlerAddDirLink(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s>;rel=\"index\"", tc.link)}) @@ -233,16 +226,16 @@ func TestHandlerAddDirLink(t *testing.T) { } } -func TestHandlerVerifyContentType(t *testing.T) { +func TestHandler_verifyContentType(t *testing.T) { prov := newProv() - provName := prov.GetName() + escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), provName) + url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) type test struct { h Handler ctx context.Context contentType string - problem *acme.Error + err *acme.Error statusCode int url string } @@ -250,53 +243,32 @@ func TestHandlerVerifyContentType(t *testing.T) { "fail/general-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("/acme/%s/certificate/", provName) - }, - }, + linker: NewLinker("dns", "acme"), }, - url: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + url: url, + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "foo", statusCode: 400, - problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json], but got foo")), + err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), } }, "fail/certificate-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return "/certificate/" - }, - }, + linker: NewLinker("dns", "acme"), }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "foo", statusCode: 400, - problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo")), + err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), } }, "ok": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return "/certificate/" - }, - }, + linker: NewLinker("dns", "acme"), }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/jose+json", statusCode: 200, } @@ -304,16 +276,9 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/pkix-cert": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return "/certificate/" - }, - }, + linker: NewLinker("dns", "acme"), }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/pkix-cert", statusCode: 200, } @@ -321,16 +286,9 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/jose+json": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return "/certificate/" - }, - }, + linker: NewLinker("dns", "acme"), }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/jose+json", statusCode: 200, } @@ -338,16 +296,9 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/pkcs7-mime": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, abs, false) - assert.Equals(t, in, []string{""}) - return "/certificate/" - }, - }, + linker: NewLinker("dns", "acme"), }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/pkcs7-mime", statusCode: 200, } @@ -373,15 +324,14 @@ func TestHandlerVerifyContentType(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -390,11 +340,11 @@ func TestHandlerVerifyContentType(t *testing.T) { } } -func TestHandlerIsPostAsGet(t *testing.T) { +func TestHandler_isPostAsGet(t *testing.T) { url := "https://ca.smallstep.com/acme/new-account" type test struct { ctx context.Context - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -402,26 +352,26 @@ func TestHandlerIsPostAsGet(t *testing.T) { return test{ ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.PayloadContextKey, nil), + ctx: context.WithValue(context.Background(), payloadContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/not-post-as-get": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{}), + ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}), statusCode: 400, - problem: acme.MalformedErr(errors.New("expected POST-as-GET")), + err: acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"), } }, "ok": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}), + ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{isPostAsGet: true}), statusCode: 200, } }, @@ -429,7 +379,7 @@ func TestHandlerIsPostAsGet(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -442,15 +392,14 @@ func TestHandlerIsPostAsGet(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -468,12 +417,12 @@ func (errReader) Close() error { return nil } -func TestHandlerParseJWS(t *testing.T) { +func TestHandler_parseJWS(t *testing.T) { url := "https://ca.smallstep.com/acme/new-account" type test struct { next nextHTTP body io.Reader - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -481,14 +430,14 @@ func TestHandlerParseJWS(t *testing.T) { return test{ body: errReader(0), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("failed to read request body: force")), + err: acme.NewErrorISE("failed to read request body: force"), } }, "fail/parse-jws-error": func(t *testing.T) test { return test{ body: strings.NewReader("foo"), statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to parse JWS from request body: square/go-jose: compact JWS format must have three parts")), + err: acme.NewError(acme.ErrorMalformedType, "failed to parse JWS from request body: square/go-jose: compact JWS format must have three parts"), } }, "ok": func(t *testing.T) test { @@ -507,7 +456,7 @@ func TestHandlerParseJWS(t *testing.T) { return test{ body: strings.NewReader(expRaw), next: func(w http.ResponseWriter, r *http.Request) { - jws, err := acme.JwsFromContext(r.Context()) + jws, err := jwsFromContext(r.Context()) assert.FatalError(t, err) gotRaw, err := jws.CompactSerialize() assert.FatalError(t, err) @@ -521,7 +470,7 @@ func TestHandlerParseJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", url, tc.body) w := httptest.NewRecorder() h.parseJWS(tc.next)(w, req) @@ -533,15 +482,14 @@ func TestHandlerParseJWS(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -550,7 +498,7 @@ func TestHandlerParseJWS(t *testing.T) { } } -func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { +func TestHandler_verifyAndExtractJWSPayload(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := jwk.Public() @@ -572,7 +520,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { type test struct { ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -580,58 +528,58 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { return test{ ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil), + ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-jwk": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) return test{ - ctx: context.WithValue(ctx, acme.JwkContextKey, nil), + ctx: context.WithValue(ctx, jwsContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/verify-jws-failure": func(t *testing.T) test { _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, &_pub) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwkContextKey, &_pub) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("error verifying jws: square/go-jose: error in cryptographic primitive")), + err: acme.NewError(acme.ErrorMalformedType, "error verifying jws: square/go-jose: error in cryptographic primitive"), } }, "fail/algorithm-mismatch": func(t *testing.T) test { _pub := *pub clone := &_pub clone.Algorithm = jose.HS256 - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, clone) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwkContextKey, clone) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("verifier and signature algorithm do not match")), + err: acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"), } }, "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -651,8 +599,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { _pub := *pub clone := &_pub clone.Algorithm = "" - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -675,8 +623,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) + ctx = context.WithValue(ctx, jwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -699,8 +647,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) + ctx = context.WithValue(ctx, jwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -720,7 +668,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -733,15 +681,14 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -750,7 +697,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { } } -func TestHandlerLookupJWK(t *testing.T) { +func TestHandler_lookupJWK(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} @@ -775,27 +722,28 @@ func TestHandlerLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) type test struct { - auth acme.Interface + linker Linker + db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-kid": func(t *testing.T) test { @@ -806,21 +754,14 @@ func TestHandlerLookupJWK(t *testing.T) { assert.FatalError(t, err) _jws, err := _signer.Sign([]byte("baz")) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return prefix - }, - }, + linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("kid does not have required prefix; expected %s, but got ", prefix)), + err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), } }, "fail/bad-kid-prefix": func(t *testing.T) test { @@ -837,126 +778,87 @@ func TestHandlerLookupJWK(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _parsed) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _parsed) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) - }, - }, + linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("kid does not have required prefix; expected %s, but got foo", prefix)), + err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), } }, "fail/account-not-found": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) + linker: NewLinker("dns", "acme"), + db: &acme.MockDB{ + MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, accID) return nil, database.ErrNotFound }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) - }, }, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/GetAccount-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, accID) - return nil, acme.ServerInternalErr(errors.New("force")) - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) + linker: NewLinker("dns", "acme"), + db: &acme.MockDB{ + MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { + assert.Equals(t, id, accID) + return nil, acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, accID) + linker: NewLinker("dns", "acme"), + db: &acme.MockDB{ + MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { + assert.Equals(t, id, accID) return acc, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) - }, }, ctx: ctx, statusCode: 401, - problem: acme.UnauthorizedErr(errors.New("account is not active")), + err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid", Key: jwk} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, accID) + linker: NewLinker("dns", "acme"), + db: &acme.MockDB{ + MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { + assert.Equals(t, id, accID) return acc, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, in, []string{""}) - return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) - }, }, ctx: ctx, next: func(w http.ResponseWriter, r *http.Request) { - _acc, err := acme.AccountFromContext(r.Context()) + _acc, err := accountFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _acc, acc) - _jwk, err := acme.JwkFromContext(r.Context()) + _jwk, err := jwkFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _jwk, jwk) w.Write(testBody) @@ -968,7 +870,7 @@ func TestHandlerLookupJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db, linker: tc.linker} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -981,15 +883,14 @@ func TestHandlerLookupJWK(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -998,7 +899,7 @@ func TestHandlerLookupJWK(t *testing.T) { } } -func TestHandlerExtractJWK(t *testing.T) { +func TestHandler_extractJWK(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -1024,27 +925,27 @@ func TestHandlerExtractJWK(t *testing.T) { url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provName) type test struct { - auth acme.Interface + db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { @@ -1057,12 +958,12 @@ func TestHandlerExtractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("jwk expected in protected header")), + err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"), } }, "fail/invalid-jwk": func(t *testing.T) test { @@ -1075,71 +976,62 @@ func TestHandlerExtractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("invalid jwk in protected header")), + err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"), } }, "fail/GetAccountByKey-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ - getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, jwk.KeyID, pub.KeyID) - return nil, acme.ServerInternalErr(errors.New("force")) + db: &acme.MockDB{ + MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { + assert.Equals(t, kid, pub.KeyID) + return nil, acme.NewErrorISE("force") }, }, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ - getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, jwk.KeyID, pub.KeyID) + db: &acme.MockDB{ + MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { + assert.Equals(t, kid, pub.KeyID) return acc, nil }, }, statusCode: 401, - problem: acme.UnauthorizedErr(errors.New("account is not active")), + err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ - getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, jwk.KeyID, pub.KeyID) + db: &acme.MockDB{ + MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { + assert.Equals(t, kid, pub.KeyID) return acc, nil }, }, next: func(w http.ResponseWriter, r *http.Request) { - _acc, err := acme.AccountFromContext(r.Context()) + _acc, err := accountFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _acc, acc) - _jwk, err := acme.JwkFromContext(r.Context()) + _jwk, err := jwkFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _jwk.KeyID, pub.KeyID) w.Write(testBody) @@ -1148,24 +1040,21 @@ func TestHandlerExtractJWK(t *testing.T) { } }, "ok/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ - getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, jwk.KeyID, pub.KeyID) - return nil, database.ErrNotFound + db: &acme.MockDB{ + MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { + assert.Equals(t, kid, pub.KeyID) + return nil, acme.ErrNotFound }, }, next: func(w http.ResponseWriter, r *http.Request) { - _acc, err := acme.AccountFromContext(r.Context()) + _acc, err := accountFromContext(r.Context()) assert.NotNil(t, err) assert.Nil(t, _acc) - _jwk, err := acme.JwkFromContext(r.Context()) + _jwk, err := jwkFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _jwk.KeyID, pub.KeyID) w.Write(testBody) @@ -1177,7 +1066,7 @@ func TestHandlerExtractJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -1190,15 +1079,14 @@ func TestHandlerExtractJWK(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -1207,13 +1095,13 @@ func TestHandlerExtractJWK(t *testing.T) { } } -func TestHandlerValidateJWS(t *testing.T) { +func TestHandler_validateJWS(t *testing.T) { url := "https://ca.smallstep.com/acme/account/1234" type test struct { - auth acme.Interface + db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -1221,21 +1109,21 @@ func TestHandlerValidateJWS(t *testing.T) { return test{ ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil), + ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-signature": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, &jose.JSONWebSignature{}), + ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), statusCode: 400, - problem: acme.MalformedErr(errors.New("request body does not contain a signature")), + err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"), } }, "fail/more-than-one-signature": func(t *testing.T) test { @@ -1246,9 +1134,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("request body contains more than one signature")), + err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"), } }, "fail/unprotected-header-not-empty": func(t *testing.T) test { @@ -1258,9 +1146,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("unprotected header must not be used")), + err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"), } }, "fail/unsuitable-algorithm-none": func(t *testing.T) test { @@ -1270,9 +1158,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("unsuitable algorithm: none")), + err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"), } }, "fail/unsuitable-algorithm-mac": func(t *testing.T) test { @@ -1282,9 +1170,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)), + err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256), } }, "fail/rsa-key-&-alg-mismatch": func(t *testing.T) test { @@ -1305,14 +1193,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ - useNonce: func(n string) error { + db: &acme.MockDB{ + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")), + err: acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match"), } }, "fail/rsa-key-too-small": func(t *testing.T) test { @@ -1333,14 +1221,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ - useNonce: func(n string) error { + db: &acme.MockDB{ + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")), + err: acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least 2048 bits (256 bytes) in size"), } }, "fail/UseNonce-error": func(t *testing.T) test { @@ -1350,14 +1238,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ - useNonce: func(n string) error { - return acme.ServerInternalErr(errors.New("force")) + db: &acme.MockDB{ + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { + return acme.NewErrorISE("force") }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/no-url-header": func(t *testing.T) test { @@ -1367,14 +1255,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ - useNonce: func(n string) error { + db: &acme.MockDB{ + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("jws missing url protected header")), + err: acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"), } }, "fail/url-mismatch": func(t *testing.T) test { @@ -1391,14 +1279,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ - useNonce: func(n string) error { + db: &acme.MockDB{ + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)), + err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", url), } }, "fail/both-jwk-kid": func(t *testing.T) test { @@ -1420,14 +1308,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ - useNonce: func(n string) error { + db: &acme.MockDB{ + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")), + err: acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"), } }, "fail/no-jwk-kid": func(t *testing.T) test { @@ -1444,14 +1332,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ - useNonce: func(n string) error { + db: &acme.MockDB{ + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")), + err: acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"), } }, "ok/kid": func(t *testing.T) test { @@ -1469,12 +1357,12 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ - useNonce: func(n string) error { + db: &acme.MockDB{ + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1499,12 +1387,12 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ - useNonce: func(n string) error { + db: &acme.MockDB{ + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1529,12 +1417,12 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ - useNonce: func(n string) error { + db: &acme.MockDB{ + MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1545,7 +1433,7 @@ func TestHandlerValidateJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -1558,15 +1446,14 @@ func TestHandlerValidateJWS(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) diff --git a/acme/api/order.go b/acme/api/order.go index 5c62cb52..e7a913ab 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -1,16 +1,18 @@ package api import ( + "context" "crypto/x509" "encoding/base64" "encoding/json" "net/http" + "strings" "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "go.step.sm/crypto/randutil" ) // NewOrderRequest represents the body for a NewOrder request. @@ -23,11 +25,11 @@ type NewOrderRequest struct { // Validate validates a new-order request body. func (n *NewOrderRequest) Validate() error { if len(n.Identifiers) == 0 { - return acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")) + return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty") } for _, id := range n.Identifiers { if id.Type != "dns" { - return acme.MalformedErr(errors.Errorf("identifier type unsupported: %s", id.Type)) + return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type) } } return nil @@ -44,22 +46,30 @@ func (f *FinalizeRequest) Validate() error { var err error csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR) if err != nil { - return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr")) + return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr") } f.csr, err = x509.ParseCertificateRequest(csrBytes) if err != nil { - return acme.MalformedErr(errors.Wrap(err, "unable to parse csr")) + return acme.WrapError(acme.ErrorMalformedType, err, "unable to parse csr") } if err = f.csr.CheckSignature(); err != nil { - return acme.MalformedErr(errors.Wrap(err, "csr failed signature check")) + return acme.WrapError(acme.ErrorMalformedType, err, "csr failed signature check") } return nil } +var defaultOrderExpiry = time.Hour * 24 +var defaultOrderBackdate = time.Minute + // NewOrder ACME api for creating a new order. func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - acc, err := acme.AccountFromContext(ctx) + acc, err := accountFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + prov, err := provisionerFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -71,8 +81,8 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { } var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, - "failed to unmarshal new-order request payload"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + "failed to unmarshal new-order request payload")) return } if err := nor.Validate(); err != nil { @@ -80,44 +90,146 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } - o, err := h.Auth.NewOrder(ctx, acme.OrderOptions{ - AccountID: acc.GetID(), - Identifiers: nor.Identifiers, - NotBefore: nor.NotBefore, - NotAfter: nor.NotAfter, - }) - if err != nil { - api.WriteError(w, err) + now := clock.Now() + // New order. + o := &acme.Order{ + AccountID: acc.ID, + ProvisionerID: prov.GetID(), + Status: acme.StatusPending, + Identifiers: nor.Identifiers, + ExpiresAt: now.Add(defaultOrderExpiry), + AuthorizationIDs: make([]string, len(nor.Identifiers)), + NotBefore: nor.NotBefore, + NotAfter: nor.NotAfter, + } + + for i, identifier := range o.Identifiers { + az := &acme.Authorization{ + AccountID: acc.ID, + Identifier: identifier, + ExpiresAt: o.ExpiresAt, + Status: acme.StatusPending, + } + if err := h.newAuthorization(ctx, az); err != nil { + api.WriteError(w, err) + return + } + o.AuthorizationIDs[i] = az.ID + } + + if o.NotBefore.IsZero() { + o.NotBefore = now + } + if o.NotAfter.IsZero() { + o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration()) + } + // If request NotBefore was empty then backdate the order.NotBefore (now) + // to avoid timing issues. + if nor.NotBefore.IsZero() { + o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) + } + + if err := h.db.CreateOrder(ctx, o); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) return } - w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID())) + h.linker.LinkOrder(ctx, o) + + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) api.JSONStatus(w, o, http.StatusCreated) } +func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { + if strings.HasPrefix(az.Identifier.Value, "*.") { + az.Wildcard = true + az.Identifier = acme.Identifier{ + Value: strings.TrimPrefix(az.Identifier.Value, "*."), + Type: az.Identifier.Type, + } + } + + var ( + err error + chTypes = []string{"dns-01"} + ) + // HTTP and TLS challenges can only be used for identifiers without wildcards. + if !az.Wildcard { + chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...) + } + + az.Token, err = randutil.Alphanumeric(32) + if err != nil { + return acme.WrapErrorISE(err, "error generating random alphanumeric ID") + } + az.Challenges = make([]*acme.Challenge, len(chTypes)) + for i, typ := range chTypes { + ch := &acme.Challenge{ + AccountID: az.AccountID, + Value: az.Identifier.Value, + Type: typ, + Token: az.Token, + Status: acme.StatusPending, + } + if err := h.db.CreateChallenge(ctx, ch); err != nil { + return acme.WrapErrorISE(err, "error creating challenge") + } + az.Challenges[i] = ch + } + if err = h.db.CreateAuthorization(ctx, az); err != nil { + return acme.WrapErrorISE(err, "error creating authorization") + } + return nil +} + // GetOrder ACME api for retrieving an order. func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - acc, err := acme.AccountFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - oid := chi.URLParam(r, "ordID") - o, err := h.Auth.GetOrder(ctx, acc.GetID(), oid) + prov, err := provisionerFromContext(ctx) if err != nil { api.WriteError(w, err) return } + o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + return + } + if acc.ID != o.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own order '%s'", acc.ID, o.ID)) + return + } + if prov.GetID() != o.ProvisionerID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) + return + } + if err = o.UpdateStatus(ctx, h.db); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating order status")) + return + } - w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID())) + h.linker.LinkOrder(ctx, o) + + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) api.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - acc, err := acme.AccountFromContext(ctx) + acc, err := accountFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + prov, err := provisionerFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -129,7 +241,8 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { } var fr FinalizeRequest if err := json.Unmarshal(payload.value, &fr); err != nil { - api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload"))) + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + "failed to unmarshal finalize-order request payload")) return } if err := fr.Validate(); err != nil { @@ -137,13 +250,28 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - oid := chi.URLParam(r, "ordID") - o, err := h.Auth.FinalizeOrder(ctx, acc.GetID(), oid, fr.csr) + o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, err) + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + return + } + if acc.ID != o.AccountID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "account '%s' does not own order '%s'", acc.ID, o.ID)) + return + } + if prov.GetID() != o.ProvisionerID { + api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) + return + } + if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order")) return } - w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.ID)) + h.linker.LinkOrder(ctx, o) + + w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) api.JSON(w, o) } diff --git a/acme/api/order_test.go b/acme/api/order_test.go index a1c8fef7..300aa61b 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -20,7 +20,7 @@ import ( "go.step.sm/crypto/pemutil" ) -func TestNewOrderRequestValidate(t *testing.T) { +func TestNewOrderRequest_Validate(t *testing.T) { type test struct { nor *NewOrderRequest nbf, naf time.Time @@ -30,7 +30,7 @@ func TestNewOrderRequestValidate(t *testing.T) { "fail/no-identifiers": func(t *testing.T) test { return test{ nor: &NewOrderRequest{}, - err: acme.MalformedErr(errors.Errorf("identifiers list cannot be empty")), + err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), } }, "fail/bad-identifier": func(t *testing.T) test { @@ -41,7 +41,7 @@ func TestNewOrderRequestValidate(t *testing.T) { {Type: "foo", Value: "bar.com"}, }, }, - err: acme.MalformedErr(errors.Errorf("identifier type unsupported: foo")), + err: acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: foo"), } }, "ok": func(t *testing.T) test { @@ -105,7 +105,7 @@ func TestFinalizeRequestValidate(t *testing.T) { "fail/parse-csr-error": func(t *testing.T) test { return test{ fr: &FinalizeRequest{}, - err: acme.MalformedErr(errors.Errorf("unable to parse csr: asn1: syntax error: sequence truncated")), + err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), } }, "fail/invalid-csr-signature": func(t *testing.T) test { @@ -117,7 +117,7 @@ func TestFinalizeRequestValidate(t *testing.T) { fr: &FinalizeRequest{ CSR: base64.RawURLEncoding.EncodeToString(c.Raw), }, - err: acme.MalformedErr(errors.Errorf("csr failed signature check: x509: ECDSA verification failure")), + err: acme.NewError(acme.ErrorMalformedType, "csr failed signature check: x509: ECDSA verification failure"), } }, "ok": func(t *testing.T) test { @@ -148,15 +148,19 @@ func TestFinalizeRequestValidate(t *testing.T) { } } -func TestHandlerGetOrder(t *testing.T) { - expiry := time.Now().UTC().Add(6 * time.Hour) - nbf := time.Now().UTC() - naf := time.Now().UTC().Add(24 * time.Hour) +func TestHandler_GetOrder(t *testing.T) { + prov := newProv() + escProvName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + + now := clock.Now() + nbf := now + naf := now.Add(24 * time.Hour) + expiry := now.Add(-time.Hour) o := acme.Order{ ID: "orderID", - Expires: expiry.Format(time.RFC3339), - NotBefore: nbf.Format(time.RFC3339), - NotAfter: naf.Format(time.RFC3339), + NotBefore: nbf, + NotAfter: naf, Identifiers: []acme.Identifier{ { Type: "dns", @@ -167,79 +171,167 @@ func TestHandlerGetOrder(t *testing.T) { Value: "*.smallstep.com", }, }, - Status: "pending", - Authorizations: []string{"foo", "bar"}, + ExpiresAt: expiry, + Status: acme.StatusInvalid, + Error: acme.NewError(acme.ErrorMalformedType, "order has expired"), + AuthorizationURLs: []string{ + fmt.Sprintf("%s/acme/%s/authz/foo", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/bar", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/baz", baseURL.String(), escProvName), + }, + FinalizeURL: fmt.Sprintf("%s/acme/%s/order/orderID/finalize", baseURL.String(), escProvName), } // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", o.ID) - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} url := fmt.Sprintf("%s/acme/%s/order/%s", - baseURL.String(), provName, o.ID) + baseURL.String(), escProvName, o.ID) type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/getOrder-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + "fail/no-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx = context.WithValue(ctx, accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/db.GetOrder-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), + } + }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "foo"}, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), + } + }, + "fail/provisioner-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "accountID", ProvisionerID: "bar"}, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "provisioner id mismatch"), + } + }, + "fail/order-update-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{ + AccountID: "accountID", + ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), + ExpiresAt: clock.Now().Add(-time.Hour), + Status: acme.StatusReady, + }, nil + }, + MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { + return acme.NewErrorISE("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getOrder: func(ctx context.Context, accID, id string) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, o.ID) - return &o, nil + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{ + ID: "orderID", + AccountID: "accountID", + ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), + ExpiresAt: expiry, + Status: acme.StatusReady, + AuthorizationIDs: []string{"foo", "bar", "baz"}, + NotBefore: nbf, + NotAfter: naf, + Identifiers: []acme.Identifier{ + { + Type: "dns", + Value: "example.com", + }, + { + Type: "dns", + Value: "*.smallstep.com", + }, + }, + }, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.OrderLink) - assert.True(t, abs) - assert.Equals(t, in, []string{o.ID}) - return url + MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { + return nil }, }, ctx: ctx, @@ -250,7 +342,7 @@ func TestHandlerGetOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -263,19 +355,19 @@ func TestHandlerGetOrder(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(o) assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) @@ -284,209 +376,886 @@ func TestHandlerGetOrder(t *testing.T) { } } -func TestHandlerNewOrder(t *testing.T) { - expiry := time.Now().UTC().Add(6 * time.Hour) - nbf := time.Now().UTC().Add(5 * time.Hour) - naf := nbf.Add(17 * time.Hour) - o := acme.Order{ - ID: "orderID", - Expires: expiry.Format(time.RFC3339), - NotBefore: nbf.Format(time.RFC3339), - NotAfter: naf.Format(time.RFC3339), - Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, - }, - Status: "pending", - Authorizations: []string{"foo", "bar"}, +func TestHandler_newAuthorization(t *testing.T) { + type test struct { + az *acme.Authorization + db acme.DB + err *acme.Error } + var tests = map[string]func(t *testing.T) test{ + "fail/error-db.CreateChallenge": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "zap.internal", + }, + } + return test{ + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + assert.Equals(t, ch.AccountID, az.AccountID) + assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Token, az.Token) + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, az.Identifier.Value) + return errors.New("force") + }, + }, + az: az, + err: acme.NewErrorISE("error creating challenge: force"), + } + }, + "fail/error-db.CreateAuthorization": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "zap.internal", + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + count := 0 + var ch1, ch2, ch3 **acme.Challenge + return test{ + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, az.AccountID) + assert.Equals(t, ch.Token, az.Token) + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, az.Identifier.Value) + return nil + }, + MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { + assert.Equals(t, _az.AccountID, az.AccountID) + assert.Equals(t, _az.Token, az.Token) + assert.Equals(t, _az.Status, acme.StatusPending) + assert.Equals(t, _az.Identifier, az.Identifier) + assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) + assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, _az.Wildcard, false) + return errors.New("force") + }, + }, + az: az, + err: acme.NewErrorISE("error creating authorization: force"), + } + }, + "ok/no-wildcard": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "zap.internal", + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + count := 0 + var ch1, ch2, ch3 **acme.Challenge + return test{ + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, az.AccountID) + assert.Equals(t, ch.Token, az.Token) + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, az.Identifier.Value) + return nil + }, + MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { + assert.Equals(t, _az.AccountID, az.AccountID) + assert.Equals(t, _az.Token, az.Token) + assert.Equals(t, _az.Status, acme.StatusPending) + assert.Equals(t, _az.Identifier, az.Identifier) + assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) + assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, _az.Wildcard, false) + return nil + }, + }, + az: az, + } + }, + "ok/wildcard": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "*.zap.internal", + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + var ch1 **acme.Challenge + return test{ + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.AccountID, az.AccountID) + assert.Equals(t, ch.Token, az.Token) + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + ch1 = &ch + return nil + }, + MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { + assert.Equals(t, _az.AccountID, az.AccountID) + assert.Equals(t, _az.Token, az.Token) + assert.Equals(t, _az.Status, acme.StatusPending) + assert.Equals(t, _az.Identifier, acme.Identifier{ + Type: "dns", + Value: "zap.internal", + }) + assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) + assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1}) + assert.Equals(t, _az.Wildcard, true) + return nil + }, + }, + az: az, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + h := &Handler{db: tc.db} + if err := h.newAuthorization(context.Background(), tc.az); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *acme.Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestHandler_NewOrder(t *testing.T) { + // Request with chi context prov := newProv() - provName := url.PathEscape(prov.GetName()) + escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/new-order", - baseURL.String(), provName) + url := fmt.Sprintf("%s/acme/%s/order/ordID", + baseURL.String(), escProvName) type test struct { - auth acme.Interface + db acme.DB ctx context.Context + nor *NewOrderRequest statusCode int - problem *acme.Error + vr func(t *testing.T, o *acme.Order) + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) return test{ ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), + } + }, + "fail/no-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx = context.WithValue(ctx, accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/no-payload": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, provisionerContextKey, prov) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload does not exist"), } }, "fail/nil-payload": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("paylod does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to unmarshal new-order request payload: unexpected end of JSON input")), + err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - nor := &NewOrderRequest{} - b, err := json.Marshal(nor) + fr := &NewOrderRequest{} + b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("identifiers list cannot be empty")), + err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), } }, - "fail/NewOrder-error": func(t *testing.T) test { + "fail/error-h.newAuthorization": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - nor := &NewOrderRequest{ + fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, + {Type: "dns", Value: "zap.internal"}, }, } - b, err := json.Marshal(nor) + b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ - auth: &mockAcmeAuthority{ - newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.AccountID, acc.ID) - assert.Equals(t, ops.Identifiers, nor.Identifiers) - return nil, acme.MalformedErr(errors.New("force")) - }, - }, ctx: ctx, - statusCode: 400, - problem: acme.MalformedErr(errors.New("force")), + statusCode: 500, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + assert.Equals(t, ch.AccountID, "accID") + assert.Equals(t, ch.Type, "dns-01") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return errors.New("force") + }, + }, + err: acme.NewErrorISE("error creating challenge: force"), } }, - "ok": func(t *testing.T) test { + "fail/error-db.CreateOrder": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + fr := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + } + b, err := json.Marshal(fr) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) + return test{ + ctx: ctx, + statusCode: 500, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, fr.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, fr.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return errors.New("force") + }, + }, + err: acme.NewErrorISE("error creating order: force"), + } + }, + "ok/multiple-authz": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, + {Type: "dns", Value: "zap.internal"}, + {Type: "dns", Value: "*.zar.internal"}, }, - NotBefore: nbf, - NotAfter: naf, } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3, ch4 **acme.Challenge + az1ID, az2ID *string + chCount, azCount = 0, 0 + ) return test{ - auth: &mockAcmeAuthority{ - newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.AccountID, acc.ID) - assert.Equals(t, ops.Identifiers, nor.Identifiers) - assert.Equals(t, ops.NotBefore, nbf) - assert.Equals(t, ops.NotAfter, naf) - return &o, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.OrderLink) - assert.True(t, abs) - assert.Equals(t, in, []string{o.ID}) - return fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID) - }, - }, ctx: ctx, statusCode: 201, + nor: nor, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch chCount { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Value, "zap.internal") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + assert.Equals(t, ch.Value, "zap.internal") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + assert.Equals(t, ch.Value, "zap.internal") + ch3 = &ch + case 3: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + assert.Equals(t, ch.Value, "zar.internal") + ch4 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + chCount++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + switch azCount { + case 0: + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Wildcard, false) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + case 1: + az.ID = "az2ID" + az2ID = &az.ID + assert.Equals(t, az.Identifier, acme.Identifier{ + Type: "dns", + Value: "zar.internal", + }) + assert.Equals(t, az.Wildcard, true) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch4}) + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + azCount++ + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID}) + return nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + now := clock.Now() + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + expNbf := now.Add(-defaultOrderBackdate) + expNaf := now.Add(prov.DefaultTLSCertDuration()) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{ + fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/az2ID", baseURL.String(), escProvName), + }) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, } }, "ok/default-naf-nbf": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, + {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) return test{ - auth: &mockAcmeAuthority{ - newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.AccountID, acc.ID) - assert.Equals(t, ops.Identifiers, nor.Identifiers) - - assert.True(t, ops.NotBefore.IsZero()) - assert.True(t, ops.NotAfter.IsZero()) - return &o, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.OrderLink) - assert.True(t, abs) - assert.Equals(t, in, []string{o.ID}) - return fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID) - }, - }, ctx: ctx, statusCode: 201, + nor: nor, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + now := clock.Now() + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + expNbf := now.Add(-defaultOrderBackdate) + expNaf := now.Add(prov.DefaultTLSCertDuration()) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, + } + }, + "ok/nbf-no-naf": func(t *testing.T) test { + now := clock.Now() + expNbf := now.Add(10 * time.Minute) + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + NotBefore: expNbf, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) + return test{ + ctx: ctx, + statusCode: 201, + nor: nor, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + now := clock.Now() + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + expNaf := expNbf.Add(prov.DefaultTLSCertDuration()) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, + } + }, + "ok/naf-no-nbf": func(t *testing.T) test { + now := clock.Now() + expNaf := now.Add(15 * time.Minute) + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + NotAfter: expNaf, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) + return test{ + ctx: ctx, + statusCode: 201, + nor: nor, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + expNbf := now.Add(-defaultOrderBackdate) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, + } + }, + "ok/naf-nbf": func(t *testing.T) test { + now := clock.Now() + expNbf := now.Add(5 * time.Minute) + expNaf := now.Add(15 * time.Minute) + acc := &acme.Account{ID: "accID"} + nor := &NewOrderRequest{ + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "zap.internal"}, + }, + NotBefore: expNbf, + NotAfter: expNaf, + } + b, err := json.Marshal(nor) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + var ( + ch1, ch2, ch3 **acme.Challenge + az1ID *string + count = 0 + ) + return test{ + ctx: ctx, + statusCode: 201, + nor: nor, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "dns" + assert.Equals(t, ch.Type, "dns-01") + ch1 = &ch + case 1: + ch.ID = "http" + assert.Equals(t, ch.Type, "http-01") + ch2 = &ch + case 2: + ch.ID = "tls" + assert.Equals(t, ch.Type, "tls-alpn-01") + ch3 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, "accID") + assert.NotEquals(t, ch.Token, "") + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, "zap.internal") + return nil + }, + MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.AccountID, "accID") + assert.NotEquals(t, az.Token, "") + assert.Equals(t, az.Status, acme.StatusPending) + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) + assert.Equals(t, az.Wildcard, false) + return nil + }, + MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "ordID" + assert.Equals(t, o.AccountID, "accID") + assert.Equals(t, o.ProvisionerID, prov.GetID()) + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + return nil + }, + }, + vr: func(t *testing.T, o *acme.Order) { + testBufferDur := 5 * time.Second + orderExpiry := now.Add(defaultOrderExpiry) + + assert.Equals(t, o.ID, "ordID") + assert.Equals(t, o.Status, acme.StatusPending) + assert.Equals(t, o.Identifiers, nor.Identifiers) + assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) + assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) + assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) + assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) + assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) + assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) + assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) + }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -499,115 +1268,151 @@ func TestHandlerNewOrder(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { - expB, err := json.Marshal(o) - assert.FatalError(t, err) - assert.Equals(t, bytes.TrimSpace(body), expB) - assert.Equals(t, res.Header["Location"], - []string{fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), - provName, o.ID)}) + ro := new(acme.Order) + assert.FatalError(t, json.Unmarshal(body, ro)) + if tc.vr != nil { + tc.vr(t, ro) + } + + assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } -func TestHandlerFinalizeOrder(t *testing.T) { - expiry := time.Now().UTC().Add(6 * time.Hour) - nbf := time.Now().UTC().Add(5 * time.Hour) - naf := nbf.Add(17 * time.Hour) +func TestHandler_FinalizeOrder(t *testing.T) { + prov := newProv() + escProvName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + + now := clock.Now() + nbf := now + naf := now.Add(24 * time.Hour) o := acme.Order{ ID: "orderID", - Expires: expiry.Format(time.RFC3339), - NotBefore: nbf.Format(time.RFC3339), - NotAfter: naf.Format(time.RFC3339), + NotBefore: nbf, + NotAfter: naf, Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, + { + Type: "dns", + Value: "example.com", + }, + { + Type: "dns", + Value: "*.smallstep.com", + }, }, - Status: "valid", - Authorizations: []string{"foo", "bar"}, - Certificate: "https://ca.smallstep.com/acme/certificate/certID", + ExpiresAt: naf, + Status: acme.StatusValid, + AuthorizationURLs: []string{ + fmt.Sprintf("%s/acme/%s/authz/foo", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/bar", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/baz", baseURL.String(), escProvName), + }, + FinalizeURL: fmt.Sprintf("%s/acme/%s/order/orderID/finalize", baseURL.String(), escProvName), + CertificateURL: fmt.Sprintf("%s/acme/%s/certificate/certID", baseURL.String(), escProvName), } + + // Request with chi context + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("ordID", o.ID) + url := fmt.Sprintf("%s/acme/%s/order/%s", + baseURL.String(), escProvName, o.ID) + _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") assert.FatalError(t, err) csr, ok := _csr.(*x509.CertificateRequest) assert.Fatal(t, ok) - // Request with chi context - chiCtx := chi.NewRouteContext() - chiCtx.URLParams.Add("ordID", o.ID) - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/order/%s/finalize", - baseURL.String(), provName, o.ID) + nor := &FinalizeRequest{ + CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + } + payloadBytes, err := json.Marshal(nor) + assert.FatalError(t, err) type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), + } + }, + "fail/no-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx = context.WithValue(ctx, accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/no-payload": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, provisionerContextKey, prov) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload does not exist"), } }, "fail/nil-payload": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("paylod does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to unmarshal finalize-order request payload: unexpected end of JSON input")), + err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { @@ -615,72 +1420,121 @@ func TestHandlerFinalizeOrder(t *testing.T) { fr := &FinalizeRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("unable to parse csr: asn1: syntax error: sequence truncated")), + err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), } }, - "fail/FinalizeOrder-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - nor := &FinalizeRequest{ - CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), - } - b, err := json.Marshal(nor) - assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + "fail/db.GetOrder-error": func(t *testing.T) test { + + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - finalizeOrder: func(ctx context.Context, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, o.ID) - assert.Equals(t, incsr.Raw, csr.Raw) - return nil, acme.MalformedErr(errors.New("force")) + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "foo"}, nil }, }, ctx: ctx, - statusCode: 400, - problem: acme.MalformedErr(errors.New("force")), + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), + } + }, + "fail/provisioner-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "accountID", ProvisionerID: "bar"}, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "provisioner id mismatch"), + } + }, + "fail/order-finalize-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{ + AccountID: "accountID", + ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), + ExpiresAt: clock.Now().Add(-time.Hour), + Status: acme.StatusReady, + }, nil + }, + MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { + return acme.NewErrorISE("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - nor := &FinalizeRequest{ - CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), - } - b, err := json.Marshal(nor) - assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - finalizeOrder: func(ctx context.Context, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, o.ID) - assert.Equals(t, incsr.Raw, csr.Raw) - return &o, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.OrderLink) - assert.True(t, abs) - assert.Equals(t, in, []string{o.ID}) - return fmt.Sprintf("%s/acme/%s/order/%s", - baseURL.String(), provName, o.ID) + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{ + ID: "orderID", + AccountID: "accountID", + ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), + ExpiresAt: naf, + Status: acme.StatusValid, + AuthorizationIDs: []string{"foo", "bar", "baz"}, + NotBefore: nbf, + NotAfter: naf, + Identifiers: []acme.Identifier{ + { + Type: "dns", + Value: "example.com", + }, + { + Type: "dns", + Value: "*.smallstep.com", + }, + }, + CertificateID: "certID", + }, nil }, }, ctx: ctx, @@ -691,7 +1545,7 @@ func TestHandlerFinalizeOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -704,23 +1558,24 @@ func TestHandlerFinalizeOrder(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(o) assert.FatalError(t, err) + + ro := new(acme.Order) + assert.FatalError(t, json.Unmarshal(body, ro)) + assert.Equals(t, bytes.TrimSpace(body), expB) - assert.Equals(t, res.Header["Location"], - []string{fmt.Sprintf("%s/acme/%s/order/%s", - baseURL, provName, o.ID)}) + assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/acme/authority.go b/acme/authority.go deleted file mode 100644 index 0f5f2c9f..00000000 --- a/acme/authority.go +++ /dev/null @@ -1,342 +0,0 @@ -package acme - -import ( - "context" - "crypto" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "net" - "net/http" - "net/url" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/certificates/authority/provisioner" - database "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "go.step.sm/crypto/jose" -) - -// Interface is the acme authority interface. -type Interface interface { - GetDirectory(ctx context.Context) (*Directory, error) - NewNonce() (string, error) - UseNonce(string) error - - DeactivateAccount(ctx context.Context, accID string) (*Account, error) - GetAccount(ctx context.Context, accID string) (*Account, error) - GetAccountByKey(ctx context.Context, key *jose.JSONWebKey) (*Account, error) - NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) - UpdateAccount(context.Context, string, []string) (*Account, error) - - GetAuthz(ctx context.Context, accID string, authzID string) (*Authz, error) - ValidateChallenge(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*Challenge, error) - - FinalizeOrder(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*Order, error) - GetOrder(ctx context.Context, accID string, orderID string) (*Order, error) - GetOrdersByAccount(ctx context.Context, accID string) ([]string, error) - NewOrder(ctx context.Context, oo OrderOptions) (*Order, error) - - GetCertificate(string, string) ([]byte, error) - - LoadProvisionerByID(string) (provisioner.Interface, error) - GetLink(ctx context.Context, linkType Link, absoluteLink bool, inputs ...string) string - GetLinkExplicit(linkType Link, provName string, absoluteLink bool, baseURL *url.URL, inputs ...string) string -} - -// Authority is the layer that handles all ACME interactions. -type Authority struct { - backdate provisioner.Duration - db nosql.DB - dir *directory - signAuth SignAuthority -} - -// AuthorityOptions required to create a new ACME Authority. -type AuthorityOptions struct { - Backdate provisioner.Duration - // DB is the database used by nosql. - DB nosql.DB - // DNS the host used to generate accurate ACME links. By default the authority - // will use the Host from the request, so this value will only be used if - // request.Host is empty. - DNS string - // Prefix is a URL path prefix under which the ACME api is served. This - // prefix is required to generate accurate ACME links. - // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- - // "acme" is the prefix from which the ACME api is accessed. - Prefix string -} - -var ( - accountTable = []byte("acme_accounts") - accountByKeyIDTable = []byte("acme_keyID_accountID_index") - authzTable = []byte("acme_authzs") - challengeTable = []byte("acme_challenges") - nonceTable = []byte("nonces") - orderTable = []byte("acme_orders") - ordersByAccountIDTable = []byte("acme_account_orders_index") - certTable = []byte("acme_certs") -) - -// NewAuthority returns a new Authority that implements the ACME interface. -// -// Deprecated: NewAuthority exists for hitorical compatibility and should not -// be used. Use acme.New() instead. -func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) { - return New(signAuth, AuthorityOptions{ - DB: db, - DNS: dns, - Prefix: prefix, - }) -} - -// New returns a new Autohrity that implements the ACME interface. -func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) { - if _, ok := ops.DB.(*database.SimpleDB); !ok { - // If it's not a SimpleDB then go ahead and bootstrap the DB with the - // necessary ACME tables. SimpleDB should ONLY be used for testing. - tables := [][]byte{accountTable, accountByKeyIDTable, authzTable, - challengeTable, nonceTable, orderTable, ordersByAccountIDTable, - certTable} - for _, b := range tables { - if err := ops.DB.CreateTable(b); err != nil { - return nil, errors.Wrapf(err, "error creating table %s", - string(b)) - } - } - } - return &Authority{ - backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth, - }, nil -} - -// GetLink returns the requested link from the directory. -func (a *Authority) GetLink(ctx context.Context, typ Link, abs bool, inputs ...string) string { - return a.dir.getLink(ctx, typ, abs, inputs...) -} - -// GetLinkExplicit returns the requested link from the directory. -func (a *Authority) GetLinkExplicit(typ Link, provName string, abs bool, baseURL *url.URL, inputs ...string) string { - return a.dir.getLinkExplicit(typ, provName, abs, baseURL, inputs...) -} - -// GetDirectory returns the ACME directory object. -func (a *Authority) GetDirectory(ctx context.Context) (*Directory, error) { - return &Directory{ - NewNonce: a.dir.getLink(ctx, NewNonceLink, true), - NewAccount: a.dir.getLink(ctx, NewAccountLink, true), - NewOrder: a.dir.getLink(ctx, NewOrderLink, true), - RevokeCert: a.dir.getLink(ctx, RevokeCertLink, true), - KeyChange: a.dir.getLink(ctx, KeyChangeLink, true), - }, nil -} - -// LoadProvisionerByID calls out to the SignAuthority interface to load a -// provisioner by ID. -func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) { - return a.signAuth.LoadProvisionerByID(id) -} - -// NewNonce generates, stores, and returns a new ACME nonce. -func (a *Authority) NewNonce() (string, error) { - n, err := newNonce(a.db) - if err != nil { - return "", err - } - return n.ID, nil -} - -// UseNonce consumes the given nonce if it is valid, returns error otherwise. -func (a *Authority) UseNonce(nonce string) error { - return useNonce(a.db, nonce) -} - -// NewAccount creates, stores, and returns a new ACME account. -func (a *Authority) NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) { - acc, err := newAccount(a.db, ao) - if err != nil { - return nil, err - } - return acc.toACME(ctx, a.db, a.dir) -} - -// UpdateAccount updates an ACME account. -func (a *Authority) UpdateAccount(ctx context.Context, id string, contact []string) (*Account, error) { - acc, err := getAccountByID(a.db, id) - if err != nil { - return nil, ServerInternalErr(err) - } - if acc, err = acc.update(a.db, contact); err != nil { - return nil, err - } - return acc.toACME(ctx, a.db, a.dir) -} - -// GetAccount returns an ACME account. -func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) { - acc, err := getAccountByID(a.db, id) - if err != nil { - return nil, err - } - return acc.toACME(ctx, a.db, a.dir) -} - -// DeactivateAccount deactivates an ACME account. -func (a *Authority) DeactivateAccount(ctx context.Context, id string) (*Account, error) { - acc, err := getAccountByID(a.db, id) - if err != nil { - return nil, err - } - if acc, err = acc.deactivate(a.db); err != nil { - return nil, err - } - return acc.toACME(ctx, a.db, a.dir) -} - -func keyToID(jwk *jose.JSONWebKey) (string, error) { - kid, err := jwk.Thumbprint(crypto.SHA256) - if err != nil { - return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint")) - } - return base64.RawURLEncoding.EncodeToString(kid), nil -} - -// GetAccountByKey returns the ACME associated with the jwk id. -func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) { - kid, err := keyToID(jwk) - if err != nil { - return nil, err - } - acc, err := getAccountByKeyID(a.db, kid) - if err != nil { - return nil, err - } - return acc.toACME(ctx, a.db, a.dir) -} - -// GetOrder returns an ACME order. -func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order, error) { - o, err := getOrder(a.db, orderID) - if err != nil { - return nil, err - } - if accID != o.AccountID { - return nil, UnauthorizedErr(errors.New("account does not own order")) - } - if o, err = o.updateStatus(a.db); err != nil { - return nil, err - } - return o.toACME(ctx, a.db, a.dir) -} - -// GetOrdersByAccount returns the list of order urls owned by the account. -func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { - ordersByAccountMux.Lock() - defer ordersByAccountMux.Unlock() - - var oiba = orderIDsByAccount{} - oids, err := oiba.unsafeGetOrderIDsByAccount(a.db, id) - if err != nil { - return nil, err - } - - var ret = []string{} - for _, oid := range oids { - ret = append(ret, a.dir.getLink(ctx, OrderLink, true, oid)) - } - return ret, nil -} - -// NewOrder generates, stores, and returns a new ACME order. -func (a *Authority) NewOrder(ctx context.Context, ops OrderOptions) (*Order, error) { - prov, err := ProvisionerFromContext(ctx) - if err != nil { - return nil, err - } - ops.backdate = a.backdate.Duration - ops.defaultDuration = prov.DefaultTLSCertDuration() - order, err := newOrder(a.db, ops) - if err != nil { - return nil, Wrap(err, "error creating order") - } - return order.toACME(ctx, a.db, a.dir) -} - -// FinalizeOrder attempts to finalize an order and generate a new certificate. -func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) { - prov, err := ProvisionerFromContext(ctx) - if err != nil { - return nil, err - } - o, err := getOrder(a.db, orderID) - if err != nil { - return nil, err - } - if accID != o.AccountID { - return nil, UnauthorizedErr(errors.New("account does not own order")) - } - o, err = o.finalize(a.db, csr, a.signAuth, prov) - if err != nil { - return nil, Wrap(err, "error finalizing order") - } - return o.toACME(ctx, a.db, a.dir) -} - -// GetAuthz retrieves and attempts to update the status on an ACME authz -// before returning. -func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authz, error) { - az, err := getAuthz(a.db, authzID) - if err != nil { - return nil, err - } - if accID != az.getAccountID() { - return nil, UnauthorizedErr(errors.New("account does not own authz")) - } - az, err = az.updateStatus(a.db) - if err != nil { - return nil, Wrap(err, "error updating authz status") - } - return az.toACME(ctx, a.db, a.dir) -} - -// ValidateChallenge attempts to validate the challenge. -func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) { - ch, err := getChallenge(a.db, chID) - if err != nil { - return nil, err - } - if accID != ch.getAccountID() { - return nil, UnauthorizedErr(errors.New("account does not own challenge")) - } - client := http.Client{ - Timeout: time.Duration(30 * time.Second), - } - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - } - ch, err = ch.validate(a.db, jwk, validateOptions{ - httpGet: client.Get, - lookupTxt: net.LookupTXT, - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(dialer, network, addr, config) - }, - }) - if err != nil { - return nil, Wrap(err, "error attempting challenge validation") - } - return ch.toACME(ctx, a.db, a.dir) -} - -// GetCertificate retrieves the Certificate by ID. -func (a *Authority) GetCertificate(accID, certID string) ([]byte, error) { - cert, err := getCert(a.db, certID) - if err != nil { - return nil, err - } - if accID != cert.AccountID { - return nil, UnauthorizedErr(errors.New("account does not own certificate")) - } - return cert.toACME(a.db, a.dir) -} diff --git a/acme/authority_test.go b/acme/authority_test.go deleted file mode 100644 index 8861c15e..00000000 --- a/acme/authority_test.go +++ /dev/null @@ -1,1739 +0,0 @@ -package acme - -import ( - "context" - "crypto" - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql/database" - "go.step.sm/crypto/jose" -) - -func TestAuthorityGetLink(t *testing.T) { - auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - type test struct { - auth *Authority - typ Link - abs bool - inputs []string - res string - } - tests := map[string]func(t *testing.T) test{ - "ok/new-account/abs": func(t *testing.T) test { - return test{ - auth: auth, - typ: NewAccountLink, - abs: true, - res: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), - } - }, - "ok/new-account/no-abs": func(t *testing.T) test { - return test{ - auth: auth, - typ: NewAccountLink, - abs: false, - res: fmt.Sprintf("/%s/new-account", provName), - } - }, - "ok/order/abs": func(t *testing.T) test { - return test{ - auth: auth, - typ: OrderLink, - abs: true, - inputs: []string{"foo"}, - res: fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), provName), - } - }, - "ok/order/no-abs": func(t *testing.T) test { - return test{ - auth: auth, - typ: OrderLink, - abs: false, - inputs: []string{"foo"}, - res: fmt.Sprintf("/%s/order/foo", provName), - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - link := tc.auth.GetLink(ctx, tc.typ, tc.abs, tc.inputs...) - assert.Equals(t, tc.res, link) - }) - } -} - -func TestAuthorityGetDirectory(t *testing.T) { - auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - - prov := newProv() - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - - type test struct { - ctx context.Context - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/empty-provisioner": func(t *testing.T) test { - return test{ - ctx: context.Background(), - } - }, - "ok/no-baseURL": func(t *testing.T) test { - return test{ - ctx: context.WithValue(context.Background(), ProvisionerContextKey, prov), - } - }, - "ok/baseURL": func(t *testing.T) test { - return test{ - ctx: ctx, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if dir, err := auth.GetDirectory(tc.ctx); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - bu := BaseURLFromContext(tc.ctx) - if bu == nil { - bu = &url.URL{Scheme: "https", Host: "ca.smallstep.com"} - } - - var provName string - prov, err := ProvisionerFromContext(tc.ctx) - if err != nil { - provName = "" - } else { - provName = url.PathEscape(prov.GetName()) - } - - assert.Equals(t, dir.NewNonce, fmt.Sprintf("%s/acme/%s/new-nonce", bu.String(), provName)) - assert.Equals(t, dir.NewAccount, fmt.Sprintf("%s/acme/%s/new-account", bu.String(), provName)) - assert.Equals(t, dir.NewOrder, fmt.Sprintf("%s/acme/%s/new-order", bu.String(), provName)) - assert.Equals(t, dir.RevokeCert, fmt.Sprintf("%s/acme/%s/revoke-cert", bu.String(), provName)) - assert.Equals(t, dir.KeyChange, fmt.Sprintf("%s/acme/%s/key-change", bu.String(), provName)) - } - } - }) - } -} - -func TestAuthorityNewNonce(t *testing.T) { - type test struct { - auth *Authority - res *string - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/newNonce-error": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - res: nil, - err: ServerInternalErr(errors.New("error storing nonce: force")), - } - }, - "ok": func(t *testing.T) test { - var _res string - res := &_res - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - *res = string(key) - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - res: res, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if nonce, err := tc.auth.NewNonce(); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, nonce, *tc.res) - } - } - }) - } -} - -func TestAuthorityUseNonce(t *testing.T) { - type test struct { - auth *Authority - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/newNonce-error": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{ - MUpdate: func(tx *database.Tx) error { - return errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - err: ServerInternalErr(errors.New("error deleting nonce foo: force")), - } - }, - "ok": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{ - MUpdate: func(tx *database.Tx) error { - return nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.auth.UseNonce("foo"); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAuthorityNewAccount(t *testing.T) { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - ops := AccountOptions{ - Key: jwk, Contact: []string{"foo", "bar"}, - } - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - ops AccountOptions - err *Error - acc **Account - } - tests := map[string]func(t *testing.T) test{ - "fail/newAccount-error": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - ops: ops, - err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), - } - }, - "ok": func(t *testing.T) test { - var ( - _acmeacc = &Account{} - acmeacc = &_acmeacc - count = 0 - dir = newDirectory("ca.smallstep.com", "acme") - ) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 1 { - var acc *account - assert.FatalError(t, json.Unmarshal(newval, &acc)) - *acmeacc, err = acc.toACME(ctx, nil, dir) - return nil, true, nil - } - count++ - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - ops: ops, - acc: acmeacc, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAcc, err := tc.auth.NewAccount(ctx, tc.ops); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAcc) - assert.FatalError(t, err) - expb, err := json.Marshal(*tc.acc) - assert.FatalError(t, err) - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetAccount(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id string - err *Error - acc *account - } - tests := map[string]func(t *testing.T) test{ - "fail/getAccount-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: acc.ID, - acc: acc, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAcc, err := tc.auth.GetAccount(ctx, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAcc) - assert.FatalError(t, err) - - acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetAccountByKey(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - jwk *jose.JSONWebKey - err *Error - acc *account - } - tests := map[string]func(t *testing.T) test{ - "fail/generate-thumbprint-error": func(t *testing.T) test { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - jwk.Key = "foo" - auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - jwk: jwk, - err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")), - } - }, - "fail/getAccount-error": func(t *testing.T) test { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - kid, err := keyToID(jwk) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - jwk: jwk, - err: ServerInternalErr(errors.New("error loading key-account index: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch { - case count == 0: - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - ret = []byte(acc.ID) - case count == 1: - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - ret = b - } - count++ - return ret, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - jwk: acc.Key, - acc: acc, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAcc, err := tc.auth.GetAccountByKey(ctx, tc.jwk); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAcc) - assert.FatalError(t, err) - - acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetOrder(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id, accID string - err *Error - o *order - } - tests := map[string]func(t *testing.T) test{ - "fail/getOrder-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.New("error loading order foo: force")), - } - }, - "fail/order-not-owned-by-account": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: "foo", - err: UnauthorizedErr(errors.New("account does not own order")), - } - }, - "fail/updateStatus-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - i := 0 - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch { - case i == 0: - i++ - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - default: - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(o.Authorizations[0])) - return nil, ServerInternalErr(errors.New("force")) - } - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: o.AccountID, - err: ServerInternalErr(errors.Errorf("error loading authz %s: force", o.Authorizations[0])), - } - }, - "ok": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = "valid" - b, err := json.Marshal(o) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: o.AccountID, - o: o, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeO, err := tc.auth.GetOrder(ctx, tc.accID, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeO) - assert.FatalError(t, err) - - acmeExp, err := tc.o.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetCertificate(t *testing.T) { - type test struct { - auth *Authority - id, accID string - err *Error - cert *certificate - } - tests := map[string]func(t *testing.T) test{ - "fail/getCertificate-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.New("error loading certificate: force")), - } - }, - "fail/certificate-not-owned-by-account": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - b, err := json.Marshal(cert) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: cert.ID, - accID: "foo", - err: UnauthorizedErr(errors.New("account does not own certificate")), - } - }, - "ok": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - b, err := json.Marshal(cert) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: cert.ID, - accID: cert.AccountID, - cert: cert, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeCert, err := tc.auth.GetCertificate(tc.accID, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeCert) - assert.FatalError(t, err) - - acmeExp, err := tc.cert.toACME(nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetAuthz(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id, accID string - err *Error - acmeAz *Authz - } - tests := map[string]func(t *testing.T) test{ - "fail/getAuthz-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.Errorf("error loading authz %s: force", id)), - } - }, - "fail/authz-not-owned-by-account": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: az.getID(), - accID: "foo", - err: UnauthorizedErr(errors.New("account does not own authz")), - } - }, - "fail/update-status-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - count := 0 - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - ret = b - case 1: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(az.getChallenges()[0])) - return nil, errors.New("force") - } - count++ - return ret, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: az.getID(), - accID: az.getAccountID(), - err: ServerInternalErr(errors.New("error updating authz status: error loading challenge")), - } - }, - "ok": func(t *testing.T) test { - var ch1B, ch2B, ch3B = &[]byte{}, &[]byte{}, &[]byte{} - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - switch count { - case 0: - *ch1B = newval - case 1: - *ch2B = newval - case 2: - *ch3B = newval - } - count++ - return nil, true, nil - }, - } - az, err := newAuthz(mockdb, "1234", Identifier{ - Type: "dns", Value: "acme.example.com", - }) - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusValid - b, err := json.Marshal(az) - assert.FatalError(t, err) - - ch1, err := unmarshalChallenge(*ch1B) - assert.FatalError(t, err) - ch2, err := unmarshalChallenge(*ch2B) - assert.FatalError(t, err) - ch3, err := unmarshalChallenge(*ch3B) - assert.FatalError(t, err) - count = 0 - mockdb = &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch1.getID())) - ret = *ch1B - case 1: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch2.getID())) - ret = *ch2B - case 2: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch3.getID())) - ret = *ch3B - } - count++ - return ret, nil - }, - } - acmeAz, err := az.toACME(ctx, mockdb, newDirectory("ca.smallstep.com", "acme")) - assert.FatalError(t, err) - - count = 0 - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - ret = b - case 1: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch1.getID())) - ret = *ch1B - case 2: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch2.getID())) - ret = *ch2B - case 3: - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch3.getID())) - ret = *ch3B - } - count++ - return ret, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: az.getID(), - accID: az.getAccountID(), - acmeAz: acmeAz, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAz, err := tc.auth.GetAuthz(ctx, tc.accID, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAz) - assert.FatalError(t, err) - - expb, err := json.Marshal(tc.acmeAz) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityNewOrder(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - ops OrderOptions - ctx context.Context - err *Error - o **Order - } - tests := map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{}, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - ops: defaultOrderOps(), - ctx: context.Background(), - err: ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/newOrder-error": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - ops: defaultOrderOps(), - ctx: ctx, - err: ServerInternalErr(errors.New("error creating order: error creating http challenge: error saving acme challenge: force")), - } - }, - "ok": func(t *testing.T) test { - var ( - _acmeO = &Order{} - acmeO = &_acmeO - count = 0 - dir = newDirectory("ca.smallstep.com", "acme") - err error - _accID string - accID = &_accID - ) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - switch count { - case 0: - assert.Equals(t, bucket, challengeTable) - case 1: - assert.Equals(t, bucket, challengeTable) - case 2: - assert.Equals(t, bucket, challengeTable) - case 3: - assert.Equals(t, bucket, authzTable) - case 4: - assert.Equals(t, bucket, challengeTable) - case 5: - assert.Equals(t, bucket, challengeTable) - case 6: - assert.Equals(t, bucket, challengeTable) - case 7: - assert.Equals(t, bucket, authzTable) - case 8: - assert.Equals(t, bucket, orderTable) - var o order - assert.FatalError(t, json.Unmarshal(newval, &o)) - *acmeO, err = o.toACME(ctx, nil, dir) - assert.FatalError(t, err) - *accID = o.AccountID - case 9: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, string(key), *accID) - } - count++ - return nil, true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - ops: defaultOrderOps(), - ctx: ctx, - o: acmeO, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeO, err := tc.auth.NewOrder(tc.ctx, tc.ops); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeO) - assert.FatalError(t, err) - expb, err := json.Marshal(*tc.o) - assert.FatalError(t, err) - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityGetOrdersByAccount(t *testing.T) { - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - type test struct { - auth *Authority - id string - err *Error - res []string - } - tests := map[string]func(t *testing.T) test{ - "fail/getOrderIDsByAccount-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")), - } - }, - "fail/getOrder-error": func(t *testing.T) test { - var ( - id = "zap" - oids = []string{"foo", "bar"} - count = 0 - err error - ) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(id)) - ret, err = json.Marshal(oids) - assert.FatalError(t, err) - case 1: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(oids[0])) - return nil, errors.New("force") - } - count++ - return ret, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.New("error loading order foo for account zap: error loading order foo: force")), - } - }, - "ok": func(t *testing.T) test { - accID := "zap" - - foo, err := newO() - assert.FatalError(t, err) - bfoo, err := json.Marshal(foo) - assert.FatalError(t, err) - - bar, err := newO() - assert.FatalError(t, err) - bar.Status = StatusInvalid - bbar, err := json.Marshal(bar) - assert.FatalError(t, err) - - zap, err := newO() - assert.FatalError(t, err) - bzap, err := json.Marshal(zap) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - dbGetOrder := 0 - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(orderTable): - dbGetOrder++ - switch dbGetOrder { - case 1: - return bfoo, nil - case 2: - return bbar, nil - case 3: - return bzap, nil - } - case string(ordersByAccountIDTable): - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(accID)) - ret, err := json.Marshal([]string{foo.ID, bar.ID, zap.ID}) - assert.FatalError(t, err) - return ret, nil - case string(challengeTable): - return bch, nil - case string(authzTable): - return baz, nil - } - return nil, errors.Errorf("should not be query db table %s", bucket) - }, - MCmpAndSwap: func(bucket, key, old, newVal []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, string(key), accID) - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: accID, - res: []string{ - fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, foo.ID), - fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, zap.ID), - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if orderLinks, err := tc.auth.GetOrdersByAccount(ctx, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res, orderLinks) - } - } - }) - } -} - -func TestAuthorityFinalizeOrder(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id, accID string - ctx context.Context - err *Error - o *order - } - tests := map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - auth, err := NewAuthority(&db.MockNoSQLDB{}, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: "foo", - ctx: context.Background(), - err: ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/getOrder-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - ctx: ctx, - err: ServerInternalErr(errors.New("error loading order foo: force")), - } - }, - "fail/order-not-owned-by-account": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: "foo", - ctx: ctx, - err: UnauthorizedErr(errors.New("account does not own order")), - } - }, - "fail/finalize-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Expires = time.Now().Add(-time.Minute) - b, err := json.Marshal(o) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: o.AccountID, - ctx: ctx, - err: ServerInternalErr(errors.New("error finalizing order: error storing order: force")), - } - }, - "ok": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusValid - o.Certificate = "certID" - b, err := json.Marshal(o) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: o.ID, - accID: o.AccountID, - ctx: ctx, - o: o, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeO, err := tc.auth.FinalizeOrder(tc.ctx, tc.accID, tc.id, nil); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeO) - assert.FatalError(t, err) - - acmeExp, err := tc.o.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityValidateChallenge(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - - type test struct { - auth *Authority - id, accID string - err *Error - ch challenge - jwk *jose.JSONWebKey - server *httptest.Server - } - tests := map[string]func(t *testing.T) test{ - "fail/getChallenge-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", id)), - } - }, - "fail/challenge-not-owned-by-account": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - b, err := json.Marshal(ch) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: ch.getID(), - accID: "foo", - err: UnauthorizedErr(errors.New("account does not own challenge")), - } - }, - "fail/validate-error": func(t *testing.T) test { - keyauth := "temp" - keyauthp := &keyauth - // Create test server that returns challenge auth - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "%s\r\n", *keyauthp) - })) - t.Cleanup(func() { ts.Close() }) - - ch, err := newHTTPChWithServer(strings.TrimPrefix(ts.URL, "http://")) - assert.FatalError(t, err) - - jwk, _, err := jose.GenerateDefaultKeyPair([]byte("pass")) - assert.FatalError(t, err) - - thumbprint, err := jwk.Thumbprint(crypto.SHA256) - assert.FatalError(t, err) - encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) - *keyauthp = fmt.Sprintf("%s.%s", ch.getToken(), encPrint) - - b, err := json.Marshal(ch) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: ch.getID(), - accID: ch.getAccountID(), - jwk: jwk, - server: ts, - err: ServerInternalErr(errors.New("error attempting challenge validation: error saving acme challenge: force")), - } - }, - "ok/already-valid": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - _ch.baseChallenge.Validated = clock.Now() - b, err := json.Marshal(ch) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return b, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: ch.getID(), - accID: ch.getAccountID(), - ch: ch, - } - }, - "ok": func(t *testing.T) test { - keyauth := "temp" - keyauthp := &keyauth - // Create test server that returns challenge auth - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "%s\r\n", *keyauthp) - })) - t.Cleanup(func() { ts.Close() }) - - ch, err := newHTTPChWithServer(strings.TrimPrefix(ts.URL, "http://")) - assert.FatalError(t, err) - - jwk, _, err := jose.GenerateDefaultKeyPair([]byte("pass")) - assert.FatalError(t, err) - - thumbprint, err := jwk.Thumbprint(crypto.SHA256) - assert.FatalError(t, err) - encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) - *keyauthp = fmt.Sprintf("%s.%s", ch.getToken(), encPrint) - - b, err := json.Marshal(ch) - assert.FatalError(t, err) - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: ch.getID(), - accID: ch.getAccountID(), - jwk: jwk, - server: ts, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeCh, err := tc.auth.ValidateChallenge(ctx, tc.accID, tc.id, tc.jwk); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeCh) - assert.FatalError(t, err) - - if tc.ch != nil { - acmeExp, err := tc.ch.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - } - }) - } -} - -func TestAuthorityUpdateAccount(t *testing.T) { - contact := []string{"baz", "zap"} - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id string - contact []string - acc *account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/getAccount-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - contact: contact, - err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)), - } - }, - "fail/update-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: acc.ID, - contact: contact, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - - _acc := *acc - clone := &_acc - clone.Contact = contact - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: acc.ID, - contact: contact, - acc: clone, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAcc, err := tc.auth.UpdateAccount(ctx, tc.id, tc.contact); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAcc) - assert.FatalError(t, err) - - acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} - -func TestAuthorityDeactivateAccount(t *testing.T) { - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - type test struct { - auth *Authority - id string - acc *account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/getAccount-error": func(t *testing.T) test { - id := "foo" - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(id)) - return nil, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: id, - err: ServerInternalErr(errors.Errorf("error loading account %s: force", id)), - } - }, - "fail/deactivate-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: acc.ID, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - - _acc := *acc - clone := &_acc - clone.Status = StatusDeactivated - clone.Deactivated = clock.Now() - auth, err := NewAuthority(&db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return b, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - return nil, true, nil - }, - }, "ca.smallstep.com", "acme", nil) - assert.FatalError(t, err) - return test{ - auth: auth, - id: acc.ID, - acc: clone, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acmeAcc, err := tc.auth.DeactivateAccount(ctx, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - gotb, err := json.Marshal(acmeAcc) - assert.FatalError(t, err) - - acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) - assert.FatalError(t, err) - expb, err := json.Marshal(acmeExp) - assert.FatalError(t, err) - - assert.Equals(t, expb, gotb) - } - } - }) - } -} diff --git a/acme/authorization.go b/acme/authorization.go new file mode 100644 index 00000000..d2df5ea5 --- /dev/null +++ b/acme/authorization.go @@ -0,0 +1,69 @@ +package acme + +import ( + "context" + "encoding/json" + "time" +) + +// Authorization representst an ACME Authorization. +type Authorization struct { + ID string `json:"-"` + AccountID string `json:"-"` + Token string `json:"-"` + Identifier Identifier `json:"identifier"` + Status Status `json:"status"` + Challenges []*Challenge `json:"challenges"` + Wildcard bool `json:"wildcard"` + ExpiresAt time.Time `json:"expires"` + Error *Error `json:"error,omitempty"` +} + +// ToLog enables response logging. +func (az *Authorization) ToLog() (interface{}, error) { + b, err := json.Marshal(az) + if err != nil { + return nil, WrapErrorISE(err, "error marshaling authz for logging") + } + return string(b), nil +} + +// UpdateStatus updates the ACME Authorization Status if necessary. +// Changes to the Authorization are saved using the database interface. +func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { + now := clock.Now() + + switch az.Status { + case StatusInvalid: + return nil + case StatusValid: + return nil + case StatusPending: + // check expiry + if now.After(az.ExpiresAt) { + az.Status = StatusInvalid + break + } + + var isValid = false + for _, ch := range az.Challenges { + if ch.Status == StatusValid { + isValid = true + break + } + } + + if !isValid { + return nil + } + az.Status = StatusValid + az.Error = nil + default: + return NewErrorISE("unrecognized authorization status: %s", az.Status) + } + + if err := db.UpdateAuthorization(ctx, az); err != nil { + return WrapErrorISE(err, "error updating authorization") + } + return nil +} diff --git a/acme/authorization_test.go b/acme/authorization_test.go new file mode 100644 index 00000000..00b35b99 --- /dev/null +++ b/acme/authorization_test.go @@ -0,0 +1,150 @@ +package acme + +import ( + "context" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" +) + +func TestAuthorization_UpdateStatus(t *testing.T) { + type test struct { + az *Authorization + err *Error + db DB + } + tests := map[string]func(t *testing.T) test{ + "ok/already-invalid": func(t *testing.T) test { + az := &Authorization{ + Status: StatusInvalid, + } + return test{ + az: az, + } + }, + "ok/already-valid": func(t *testing.T) test { + az := &Authorization{ + Status: StatusInvalid, + } + return test{ + az: az, + } + }, + "fail/error-unexpected-status": func(t *testing.T) test { + az := &Authorization{ + Status: "foo", + } + return test{ + az: az, + err: NewErrorISE("unrecognized authorization status: %s", az.Status), + } + }, + "ok/expired": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + az: az, + db: &MockDB{ + MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { + assert.Equals(t, updaz.ID, az.ID) + assert.Equals(t, updaz.AccountID, az.AccountID) + assert.Equals(t, updaz.Status, StatusInvalid) + assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) + return nil + }, + }, + } + }, + "fail/db.UpdateAuthorization-error": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + az: az, + db: &MockDB{ + MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { + assert.Equals(t, updaz.ID, az.ID) + assert.Equals(t, updaz.AccountID, az.AccountID) + assert.Equals(t, updaz.Status, StatusInvalid) + assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) + return errors.New("force") + }, + }, + err: NewErrorISE("error updating authorization: force"), + } + }, + "ok/no-valid-challenges": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []*Challenge{ + {Status: StatusPending}, {Status: StatusPending}, {Status: StatusPending}, + }, + } + return test{ + az: az, + } + }, + "ok/valid": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []*Challenge{ + {Status: StatusPending}, {Status: StatusPending}, {Status: StatusValid}, + }, + } + return test{ + az: az, + db: &MockDB{ + MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { + assert.Equals(t, updaz.ID, az.ID) + assert.Equals(t, updaz.AccountID, az.AccountID) + assert.Equals(t, updaz.Status, StatusValid) + assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) + assert.Equals(t, updaz.Error, nil) + return nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.az.UpdateStatus(context.Background(), tc.db); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + + } +} diff --git a/acme/authz.go b/acme/authz.go deleted file mode 100644 index 8c45bce0..00000000 --- a/acme/authz.go +++ /dev/null @@ -1,347 +0,0 @@ -package acme - -import ( - "context" - "encoding/json" - "strings" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/nosql" -) - -var defaultExpiryDuration = time.Hour * 24 - -// Authz is a subset of the Authz type containing only those attributes -// required for responses in the ACME protocol. -type Authz struct { - Identifier Identifier `json:"identifier"` - Status string `json:"status"` - Expires string `json:"expires"` - Challenges []*Challenge `json:"challenges"` - Wildcard bool `json:"wildcard"` - ID string `json:"-"` -} - -// ToLog enables response logging. -func (a *Authz) ToLog() (interface{}, error) { - b, err := json.Marshal(a) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling authz for logging")) - } - return string(b), nil -} - -// GetID returns the Authz ID. -func (a *Authz) GetID() string { - return a.ID -} - -// authz is the interface that the various authz types must implement. -type authz interface { - save(nosql.DB, authz) error - clone() *baseAuthz - getID() string - getAccountID() string - getType() string - getIdentifier() Identifier - getStatus() string - getExpiry() time.Time - getWildcard() bool - getChallenges() []string - getCreated() time.Time - updateStatus(db nosql.DB) (authz, error) - toACME(context.Context, nosql.DB, *directory) (*Authz, error) -} - -// baseAuthz is the base authz type that others build from. -type baseAuthz struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Identifier Identifier `json:"identifier"` - Status string `json:"status"` - Expires time.Time `json:"expires"` - Challenges []string `json:"challenges"` - Wildcard bool `json:"wildcard"` - Created time.Time `json:"created"` - Error *Error `json:"error"` -} - -func newBaseAuthz(accID string, identifier Identifier) (*baseAuthz, error) { - id, err := randID() - if err != nil { - return nil, err - } - - now := clock.Now() - ba := &baseAuthz{ - ID: id, - AccountID: accID, - Status: StatusPending, - Created: now, - Expires: now.Add(defaultExpiryDuration), - Identifier: identifier, - } - - if strings.HasPrefix(identifier.Value, "*.") { - ba.Wildcard = true - ba.Identifier = Identifier{ - Value: strings.TrimPrefix(identifier.Value, "*."), - Type: identifier.Type, - } - } - - return ba, nil -} - -// getID returns the ID of the authz. -func (ba *baseAuthz) getID() string { - return ba.ID -} - -// getAccountID returns the Account ID that created the authz. -func (ba *baseAuthz) getAccountID() string { - return ba.AccountID -} - -// getType returns the type of the authz. -func (ba *baseAuthz) getType() string { - return ba.Identifier.Type -} - -// getIdentifier returns the identifier for the authz. -func (ba *baseAuthz) getIdentifier() Identifier { - return ba.Identifier -} - -// getStatus returns the status of the authz. -func (ba *baseAuthz) getStatus() string { - return ba.Status -} - -// getWildcard returns true if the authz identifier has a '*', false otherwise. -func (ba *baseAuthz) getWildcard() bool { - return ba.Wildcard -} - -// getChallenges returns the authz challenge IDs. -func (ba *baseAuthz) getChallenges() []string { - return ba.Challenges -} - -// getExpiry returns the expiration time of the authz. -func (ba *baseAuthz) getExpiry() time.Time { - return ba.Expires -} - -// getCreated returns the created time of the authz. -func (ba *baseAuthz) getCreated() time.Time { - return ba.Created -} - -// toACME converts the internal Authz type into the public acmeAuthz type for -// presentation in the ACME protocol. -func (ba *baseAuthz) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Authz, error) { - var chs = make([]*Challenge, len(ba.Challenges)) - for i, chID := range ba.Challenges { - ch, err := getChallenge(db, chID) - if err != nil { - return nil, err - } - chs[i], err = ch.toACME(ctx, db, dir) - if err != nil { - return nil, err - } - } - return &Authz{ - Identifier: ba.Identifier, - Status: ba.getStatus(), - Challenges: chs, - Wildcard: ba.getWildcard(), - Expires: ba.Expires.Format(time.RFC3339), - ID: ba.ID, - }, nil -} - -func (ba *baseAuthz) save(db nosql.DB, old authz) error { - var ( - err error - oldB, newB []byte - ) - if old == nil { - oldB = nil - } else { - if oldB, err = json.Marshal(old); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old authz")) - } - } - if newB, err = json.Marshal(ba); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling new authz")) - } - _, swapped, err := db.CmpAndSwap(authzTable, []byte(ba.ID), oldB, newB) - switch { - case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error storing authz")) - case !swapped: - return ServerInternalErr(errors.Errorf("error storing authz; " + - "value has changed since last read")) - default: - return nil - } -} - -func (ba *baseAuthz) clone() *baseAuthz { - u := *ba - return &u -} - -func (ba *baseAuthz) parent() authz { - return &dnsAuthz{ba} -} - -// updateStatus attempts to update the status on a baseAuthz and stores the -// updating object if necessary. -func (ba *baseAuthz) updateStatus(db nosql.DB) (authz, error) { - newAuthz := ba.clone() - - now := time.Now().UTC() - switch ba.Status { - case StatusInvalid: - return ba.parent(), nil - case StatusValid: - return ba.parent(), nil - case StatusPending: - // check expiry - if now.After(ba.Expires) { - newAuthz.Status = StatusInvalid - newAuthz.Error = MalformedErr(errors.New("authz has expired")) - break - } - - var isValid = false - for _, chID := range ba.Challenges { - ch, err := getChallenge(db, chID) - if err != nil { - return ba, err - } - if ch.getStatus() == StatusValid { - isValid = true - break - } - } - - if !isValid { - return ba.parent(), nil - } - newAuthz.Status = StatusValid - newAuthz.Error = nil - default: - return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status)) - } - - if err := newAuthz.save(db, ba); err != nil { - return ba, err - } - return newAuthz.parent(), nil -} - -// unmarshalAuthz unmarshals an authz type into the correct sub-type. -func unmarshalAuthz(data []byte) (authz, error) { - var getType struct { - Identifier Identifier `json:"identifier"` - } - if err := json.Unmarshal(data, &getType); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type")) - } - - switch getType.Identifier.Type { - case "dns": - var ba baseAuthz - if err := json.Unmarshal(data, &ba); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dnsAuthz")) - } - return &dnsAuthz{&ba}, nil - default: - return nil, ServerInternalErr(errors.Errorf("unexpected authz type %s", - getType.Identifier.Type)) - } -} - -// dnsAuthz represents a dns acme authorization. -type dnsAuthz struct { - *baseAuthz -} - -// newAuthz returns a new acme authorization object based on the identifier -// type. -func newAuthz(db nosql.DB, accID string, identifier Identifier) (a authz, err error) { - switch identifier.Type { - case "dns": - a, err = newDNSAuthz(db, accID, identifier) - default: - err = MalformedErr(errors.Errorf("unexpected authz type %s", - identifier.Type)) - } - return -} - -// newDNSAuthz returns a new dns acme authorization object. -func newDNSAuthz(db nosql.DB, accID string, identifier Identifier) (authz, error) { - ba, err := newBaseAuthz(accID, identifier) - if err != nil { - return nil, err - } - - ba.Challenges = []string{} - if !ba.Wildcard { - // http and alpn challenges are only permitted if the DNS is not a wildcard dns. - ch1, err := newHTTP01Challenge(db, ChallengeOptions{ - AccountID: accID, - AuthzID: ba.ID, - Identifier: ba.Identifier}) - if err != nil { - return nil, Wrap(err, "error creating http challenge") - } - ba.Challenges = append(ba.Challenges, ch1.getID()) - - ch2, err := newTLSALPN01Challenge(db, ChallengeOptions{ - AccountID: accID, - AuthzID: ba.ID, - Identifier: ba.Identifier, - }) - if err != nil { - return nil, Wrap(err, "error creating alpn challenge") - } - ba.Challenges = append(ba.Challenges, ch2.getID()) - } - ch3, err := newDNS01Challenge(db, ChallengeOptions{ - AccountID: accID, - AuthzID: ba.ID, - Identifier: identifier}) - if err != nil { - return nil, Wrap(err, "error creating dns challenge") - } - ba.Challenges = append(ba.Challenges, ch3.getID()) - - da := &dnsAuthz{ba} - if err := da.save(db, nil); err != nil { - return nil, err - } - - return da, nil -} - -// getAuthz retrieves and unmarshals an ACME authz type from the database. -func getAuthz(db nosql.DB, id string) (authz, error) { - b, err := db.Get(authzTable, []byte(id)) - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id)) - } else if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id)) - } - az, err := unmarshalAuthz(b) - if err != nil { - return nil, err - } - return az, nil -} diff --git a/acme/authz_test.go b/acme/authz_test.go deleted file mode 100644 index 31e6bb58..00000000 --- a/acme/authz_test.go +++ /dev/null @@ -1,836 +0,0 @@ -package acme - -import ( - "context" - "encoding/json" - "strings" - "testing" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" -) - -func newAz() (authz, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newAuthz(mockdb, "1234", Identifier{ - Type: "dns", Value: "acme.example.com", - }) -} - -func TestGetAuthz(t *testing.T) { - type test struct { - id string - db nosql.DB - az authz - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())), - } - }, - "fail/db-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Identifier.Type = "foo" - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - return b, nil - }, - }, - err: ServerInternalErr(errors.New("unexpected authz type foo")), - } - }, - "ok": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - return b, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if az, err := getAuthz(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.az.getID(), az.getID()) - assert.Equals(t, tc.az.getAccountID(), az.getAccountID()) - assert.Equals(t, tc.az.getStatus(), az.getStatus()) - assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier()) - assert.Equals(t, tc.az.getCreated(), az.getCreated()) - assert.Equals(t, tc.az.getExpiry(), az.getExpiry()) - assert.Equals(t, tc.az.getChallenges(), az.getChallenges()) - } - } - }) - } -} - -func TestAuthzClone(t *testing.T) { - az, err := newAz() - assert.FatalError(t, err) - - clone := az.clone() - - assert.Equals(t, clone.getID(), az.getID()) - assert.Equals(t, clone.getAccountID(), az.getAccountID()) - assert.Equals(t, clone.getStatus(), az.getStatus()) - assert.Equals(t, clone.getIdentifier(), az.getIdentifier()) - assert.Equals(t, clone.getExpiry(), az.getExpiry()) - assert.Equals(t, clone.getCreated(), az.getCreated()) - assert.Equals(t, clone.getChallenges(), az.getChallenges()) - - clone.Status = StatusValid - - assert.NotEquals(t, clone.getStatus(), az.getStatus()) -} - -func TestNewAuthz(t *testing.T) { - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - accID := "1234" - type test struct { - iden Identifier - db nosql.DB - err *Error - resChs *([]string) - } - tests := map[string]func(t *testing.T) test{ - "fail/unexpected-type": func(t *testing.T) test { - return test{ - iden: Identifier{Type: "foo", Value: "acme.example.com"}, - err: MalformedErr(errors.New("unexpected authz type foo")), - } - }, - "fail/new-http-chall-error": func(t *testing.T) test { - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")), - } - }, - "fail/new-tls-alpn-chall-error": func(t *testing.T) test { - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 1 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error creating alpn challenge: error saving acme challenge: force")), - } - }, - "fail/new-dns-chall-error": func(t *testing.T) test { - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 2 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")), - } - }, - "fail/save-authz-error": func(t *testing.T) test { - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 3 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error storing authz: force")), - } - }, - "ok": func(t *testing.T) test { - chs := &([]string{}) - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 3 { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, old, nil) - - az, err := unmarshalAuthz(newval) - assert.FatalError(t, err) - - assert.Equals(t, az.getID(), string(key)) - assert.Equals(t, az.getAccountID(), accID) - assert.Equals(t, az.getStatus(), StatusPending) - assert.Equals(t, az.getIdentifier(), iden) - assert.Equals(t, az.getWildcard(), false) - - *chs = az.getChallenges() - - assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := az.getCreated().Add(defaultExpiryDuration) - assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) - assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) - } - count++ - return nil, true, nil - }, - }, - resChs: chs, - } - }, - "ok/wildcard": func(t *testing.T) test { - chs := &([]string{}) - count := 0 - _iden := Identifier{Type: "dns", Value: "*.acme.example.com"} - return test{ - iden: _iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 1 { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, old, nil) - - az, err := unmarshalAuthz(newval) - assert.FatalError(t, err) - - assert.Equals(t, az.getID(), string(key)) - assert.Equals(t, az.getAccountID(), accID) - assert.Equals(t, az.getStatus(), StatusPending) - assert.Equals(t, az.getIdentifier(), iden) - assert.Equals(t, az.getWildcard(), true) - - *chs = az.getChallenges() - // Verify that we only have 1 challenge instead of 2. - assert.True(t, len(*chs) == 1) - - assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := az.getCreated().Add(defaultExpiryDuration) - assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) - assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) - } - count++ - return nil, true, nil - }, - }, - resChs: chs, - } - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - az, err := newAuthz(tc.db, accID, tc.iden) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, az.getAccountID(), accID) - assert.Equals(t, az.getType(), "dns") - assert.Equals(t, az.getStatus(), StatusPending) - - assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := az.getCreated().Add(defaultExpiryDuration) - assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) - assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) - - assert.Equals(t, az.getChallenges(), *(tc.resChs)) - - if strings.HasPrefix(tc.iden.Value, "*.") { - assert.True(t, az.getWildcard()) - assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*.")) - } else { - assert.False(t, az.getWildcard()) - assert.Equals(t, az.getIdentifier().Value, tc.iden.Value) - } - - assert.True(t, az.getID() != "") - } - } - }) - } -} - -func TestAuthzToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - - var ( - ch1, ch2 challenge - ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{}) - err error - ) - - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - *ch1Bytes = newval - ch1, err = unmarshalChallenge(newval) - assert.FatalError(t, err) - } else if count == 1 { - *ch2Bytes = newval - ch2, err = unmarshalChallenge(newval) - assert.FatalError(t, err) - } - count++ - return []byte("foo"), true, nil - }, - } - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - az, err := newAuthz(mockdb, "1234", iden) - assert.FatalError(t, err) - - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - - type test struct { - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/getChallenge1-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading challenge")), - } - }, - "fail/getChallenge2-error": func(t *testing.T) test { - count := 0 - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 1 { - return nil, errors.New("force") - } - count++ - return *ch1Bytes, nil - }, - }, - err: ServerInternalErr(errors.New("error loading challenge")), - } - }, - "ok": func(t *testing.T) test { - count := 0 - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - count++ - return *ch1Bytes, nil - } - return *ch2Bytes, nil - }, - }, - } - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - acmeAz, err := az.toACME(ctx, tc.db, dir) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acmeAz.ID, az.getID()) - assert.Equals(t, acmeAz.Identifier, iden) - assert.Equals(t, acmeAz.Status, StatusPending) - - acmeCh1, err := ch1.toACME(ctx, nil, dir) - assert.FatalError(t, err) - acmeCh2, err := ch2.toACME(ctx, nil, dir) - assert.FatalError(t, err) - - assert.Equals(t, acmeAz.Challenges[0], acmeCh1) - assert.Equals(t, acmeAz.Challenges[1], acmeCh2) - - expiry, err := time.Parse(time.RFC3339, acmeAz.Expires) - assert.FatalError(t, err) - assert.Equals(t, expiry.String(), az.getExpiry().String()) - } - } - }) - } -} - -func TestAuthzSave(t *testing.T) { - type test struct { - az, old authz - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing authz: force")), - } - }, - "fail/old-nil/swap-false": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")), - } - }, - "ok/old-nil": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, authzTable) - assert.Equals(t, []byte(az.getID()), key) - return nil, true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldAz, err := newAz() - assert.FatalError(t, err) - az, err := newAz() - assert.FatalError(t, err) - - oldb, err := json.Marshal(oldAz) - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - old: oldAz, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, authzTable) - assert.Equals(t, []byte(az.getID()), key) - return []byte("foo"), true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.az.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAuthzUnmarshal(t *testing.T) { - type test struct { - az authz - azb []byte - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/nil": func(t *testing.T) test { - return test{ - azb: nil, - err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")), - } - }, - "fail/unexpected-type": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Identifier.Type = "foo" - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - azb: b, - err: ServerInternalErr(errors.New("unexpected authz type foo")), - } - }, - "ok/dns": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - azb: b, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if az, err := unmarshalAuthz(tc.azb); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.az.getID(), az.getID()) - assert.Equals(t, tc.az.getAccountID(), az.getAccountID()) - assert.Equals(t, tc.az.getStatus(), az.getStatus()) - assert.Equals(t, tc.az.getCreated(), az.getCreated()) - assert.Equals(t, tc.az.getExpiry(), az.getExpiry()) - assert.Equals(t, tc.az.getWildcard(), az.getWildcard()) - assert.Equals(t, tc.az.getChallenges(), az.getChallenges()) - } - } - }) - } -} - -func TestAuthzUpdateStatus(t *testing.T) { - type test struct { - az, res authz - err *Error - db nosql.DB - } - tests := map[string]func(t *testing.T) test{ - "fail/already-invalid": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusInvalid - return test{ - az: az, - res: az, - } - }, - "fail/already-valid": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusValid - return test{ - az: az, - res: az, - } - }, - "fail/unexpected-status": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusReady - return test{ - az: az, - res: az, - err: ServerInternalErr(errors.New("unrecognized authz status: ready")), - } - }, - "fail/save-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute) - return test{ - az: az, - res: az, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing authz: force")), - } - }, - "ok/expired": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute) - - clone := az.clone() - clone.Error = MalformedErr(errors.New("authz has expired")) - clone.Status = StatusInvalid - return test{ - az: az, - res: clone.parent(), - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - }, - } - }, - "fail/get-challenge-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - - return test{ - az: az, - res: az, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading challenge")), - } - }, - "ok/valid": func(t *testing.T) test { - var ( - ch3 challenge - ch2Bytes = &([]byte{}) - ch1Bytes = &([]byte{}) - err error - ) - - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - *ch1Bytes = newval - } else if count == 1 { - *ch2Bytes = newval - } else if count == 2 { - ch3, err = unmarshalChallenge(newval) - assert.FatalError(t, err) - } - count++ - return nil, true, nil - }, - } - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - az, err := newAuthz(mockdb, "1234", iden) - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Error = MalformedErr(nil) - - _ch, ok := ch3.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - chb, err := json.Marshal(ch3) - - clone := az.clone() - clone.Status = StatusValid - clone.Error = nil - - count = 0 - return test{ - az: az, - res: clone.parent(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - count++ - return *ch1Bytes, nil - } - if count == 1 { - count++ - return *ch2Bytes, nil - } - count++ - return chb, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - }, - } - }, - "ok/still-pending": func(t *testing.T) test { - var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{}) - - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - *ch1Bytes = newval - } else if count == 1 { - *ch2Bytes = newval - } - count++ - return nil, true, nil - }, - } - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - az, err := newAuthz(mockdb, "1234", iden) - assert.FatalError(t, err) - - count = 0 - return test{ - az: az, - res: az, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - count++ - return *ch1Bytes, nil - } - count++ - return *ch2Bytes, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - az, err := tc.az.updateStatus(tc.db) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - expB, err := json.Marshal(tc.res) - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - assert.Equals(t, expB, b) - } - } - }) - } -} diff --git a/acme/certificate.go b/acme/certificate.go index 6a31c880..d46d1a08 100644 --- a/acme/certificate.go +++ b/acme/certificate.go @@ -2,88 +2,13 @@ package acme import ( "crypto/x509" - "encoding/json" - "encoding/pem" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/nosql" ) -type certificate struct { - ID string `json:"id"` - Created time.Time `json:"created"` - AccountID string `json:"accountID"` - OrderID string `json:"orderID"` - Leaf []byte `json:"leaf"` - Intermediates []byte `json:"intermediates"` -} - -// CertOptions options with which to create and store a cert object. -type CertOptions struct { +// Certificate options with which to create and store a cert object. +type Certificate struct { + ID string AccountID string OrderID string Leaf *x509.Certificate Intermediates []*x509.Certificate } - -func newCert(db nosql.DB, ops CertOptions) (*certificate, error) { - id, err := randID() - if err != nil { - return nil, err - } - - leaf := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: ops.Leaf.Raw, - }) - var intermediates []byte - for _, cert := range ops.Intermediates { - intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - })...) - } - - cert := &certificate{ - ID: id, - AccountID: ops.AccountID, - OrderID: ops.OrderID, - Leaf: leaf, - Intermediates: intermediates, - Created: time.Now().UTC(), - } - certB, err := json.Marshal(cert) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling certificate")) - } - - _, swapped, err := db.CmpAndSwap(certTable, []byte(id), nil, certB) - switch { - case err != nil: - return nil, ServerInternalErr(errors.Wrap(err, "error storing certificate")) - case !swapped: - return nil, ServerInternalErr(errors.New("error storing certificate; " + - "value has changed since last read")) - default: - return cert, nil - } -} - -func (c *certificate) toACME(db nosql.DB, dir *directory) ([]byte, error) { - return append(c.Leaf, c.Intermediates...), nil -} - -func getCert(db nosql.DB, id string) (*certificate, error) { - b, err := db.Get(certTable, []byte(id)) - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id)) - } else if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate")) - } - var cert certificate - if err := json.Unmarshal(b, &cert); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate")) - } - return &cert, nil -} diff --git a/acme/certificate_test.go b/acme/certificate_test.go deleted file mode 100644 index a4b8f91a..00000000 --- a/acme/certificate_test.go +++ /dev/null @@ -1,253 +0,0 @@ -package acme - -import ( - "crypto/x509" - "encoding/json" - "encoding/pem" - "testing" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" - "go.step.sm/crypto/pemutil" -) - -func defaultCertOps() (*CertOptions, error) { - crt, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt") - if err != nil { - return nil, err - } - inter, err := pemutil.ReadCertificate("../authority/testdata/certs/intermediate_ca.crt") - if err != nil { - return nil, err - } - root, err := pemutil.ReadCertificate("../authority/testdata/certs/root_ca.crt") - if err != nil { - return nil, err - } - return &CertOptions{ - AccountID: "accID", - OrderID: "ordID", - Leaf: crt, - Intermediates: []*x509.Certificate{inter, root}, - }, nil -} - -func newcert() (*certificate, error) { - ops, err := defaultCertOps() - if err != nil { - return nil, err - } - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - } - return newCert(mockdb, *ops) -} - -func TestNewCert(t *testing.T) { - type test struct { - db nosql.DB - ops CertOptions - err *Error - id *string - } - tests := map[string]func(t *testing.T) test{ - "fail/cmpAndSwap-error": func(t *testing.T) test { - ops, err := defaultCertOps() - assert.FatalError(t, err) - return test{ - ops: *ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, old, nil) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error storing certificate: force")), - } - }, - "fail/cmpAndSwap-false": func(t *testing.T) test { - ops, err := defaultCertOps() - assert.FatalError(t, err) - return test{ - ops: *ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, old, nil) - return nil, false, nil - }, - }, - err: ServerInternalErr(errors.Errorf("error storing certificate; value has changed since last read")), - } - }, - "ok": func(t *testing.T) test { - ops, err := defaultCertOps() - assert.FatalError(t, err) - var _id string - id := &_id - return test{ - ops: *ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, old, nil) - *id = string(key) - return nil, true, nil - }, - }, - id: id, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if cert, err := newCert(tc.db, tc.ops); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, cert.ID, *tc.id) - assert.Equals(t, cert.AccountID, tc.ops.AccountID) - assert.Equals(t, cert.OrderID, tc.ops.OrderID) - - leaf := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: tc.ops.Leaf.Raw, - }) - var intermediates []byte - for _, cert := range tc.ops.Intermediates { - intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - })...) - } - assert.Equals(t, cert.Leaf, leaf) - assert.Equals(t, cert.Intermediates, intermediates) - - assert.True(t, cert.Created.Before(time.Now().Add(time.Minute))) - assert.True(t, cert.Created.After(time.Now().Add(-time.Minute))) - } - } - }) - } -} - -func TestGetCert(t *testing.T) { - type test struct { - id string - db nosql.DB - cert *certificate - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - return test{ - cert: cert, - id: cert.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("certificate %s not found: not found", cert.ID)), - } - }, - "fail/db-error": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - return test{ - cert: cert, - id: cert.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading certificate: force")), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - return test{ - cert: cert, - id: cert.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return nil, nil - }, - }, - err: ServerInternalErr(errors.New("error unmarshaling certificate: unexpected end of JSON input")), - } - }, - "ok": func(t *testing.T) test { - cert, err := newcert() - assert.FatalError(t, err) - b, err := json.Marshal(cert) - assert.FatalError(t, err) - return test{ - cert: cert, - id: cert.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - return b, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if cert, err := getCert(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.cert.ID, cert.ID) - assert.Equals(t, tc.cert.AccountID, cert.AccountID) - assert.Equals(t, tc.cert.OrderID, cert.OrderID) - assert.Equals(t, tc.cert.Created, cert.Created) - assert.Equals(t, tc.cert.Leaf, cert.Leaf) - assert.Equals(t, tc.cert.Intermediates, cert.Intermediates) - } - } - }) - } -} - -func TestCertificateToACME(t *testing.T) { - cert, err := newcert() - assert.FatalError(t, err) - acmeCert, err := cert.toACME(nil, nil) - assert.FatalError(t, err) - assert.Equals(t, append(cert.Leaf, cert.Intermediates...), acmeCert) -} diff --git a/acme/challenge.go b/acme/challenge.go index 6d2d13d1..1059e437 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -14,394 +14,115 @@ import ( "io/ioutil" "net" "net/http" + "net/url" "strings" "time" - "github.com/pkg/errors" - "github.com/smallstep/nosql" "go.step.sm/crypto/jose" ) -// Challenge is a subset of the challenge type containing only those attributes -// required for responses in the ACME protocol. +// Challenge represents an ACME response Challenge type. type Challenge struct { - Type string `json:"type"` - Status string `json:"status"` - Token string `json:"token"` - Validated string `json:"validated,omitempty"` - URL string `json:"url"` - Error *AError `json:"error,omitempty"` - ID string `json:"-"` - AuthzID string `json:"-"` + ID string `json:"-"` + AccountID string `json:"-"` + AuthorizationID string `json:"-"` + Value string `json:"-"` + Type string `json:"type"` + Status Status `json:"status"` + Token string `json:"token"` + ValidatedAt string `json:"validated,omitempty"` + URL string `json:"url"` + Error *Error `json:"error,omitempty"` } // ToLog enables response logging. -func (c *Challenge) ToLog() (interface{}, error) { - b, err := json.Marshal(c) +func (ch *Challenge) ToLog() (interface{}, error) { + b, err := json.Marshal(ch) if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling challenge for logging")) + return nil, WrapErrorISE(err, "error marshaling challenge for logging") } return string(b), nil } -// GetID returns the Challenge ID. -func (c *Challenge) GetID() string { - return c.ID -} - -// GetAuthzID returns the parent Authz ID that owns the Challenge. -func (c *Challenge) GetAuthzID() string { - return c.AuthzID -} - -type httpGetter func(string) (*http.Response, error) -type lookupTxt func(string) ([]string, error) -type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) - -type validateOptions struct { - httpGet httpGetter - lookupTxt lookupTxt - tlsDial tlsDialer -} - -// challenge is the interface ACME challenege types must implement. -type challenge interface { - save(db nosql.DB, swap challenge) error - validate(nosql.DB, *jose.JSONWebKey, validateOptions) (challenge, error) - getType() string - getError() *AError - getValue() string - getStatus() string - getID() string - getAuthzID() string - getToken() string - clone() *baseChallenge - getAccountID() string - getValidated() time.Time - getCreated() time.Time - toACME(context.Context, nosql.DB, *directory) (*Challenge, error) -} - -// ChallengeOptions is the type used to created a new Challenge. -type ChallengeOptions struct { - AccountID string - AuthzID string - Identifier Identifier -} - -// baseChallenge is the base Challenge type that others build from. -type baseChallenge struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - AuthzID string `json:"authzID"` - Type string `json:"type"` - Status string `json:"status"` - Token string `json:"token"` - Value string `json:"value"` - Validated time.Time `json:"validated"` - Created time.Time `json:"created"` - Error *AError `json:"error"` -} - -func newBaseChallenge(accountID, authzID string) (*baseChallenge, error) { - id, err := randID() - if err != nil { - return nil, Wrap(err, "error generating random id for ACME challenge") - } - token, err := randID() - if err != nil { - return nil, Wrap(err, "error generating token for ACME challenge") - } - - return &baseChallenge{ - ID: id, - AccountID: accountID, - AuthzID: authzID, - Status: StatusPending, - Token: token, - Created: clock.Now(), - }, nil -} - -// getID returns the id of the baseChallenge. -func (bc *baseChallenge) getID() string { - return bc.ID -} - -// getAuthzID returns the authz ID of the baseChallenge. -func (bc *baseChallenge) getAuthzID() string { - return bc.AuthzID -} - -// getAccountID returns the account id of the baseChallenge. -func (bc *baseChallenge) getAccountID() string { - return bc.AccountID -} - -// getType returns the type of the baseChallenge. -func (bc *baseChallenge) getType() string { - return bc.Type -} - -// getValue returns the type of the baseChallenge. -func (bc *baseChallenge) getValue() string { - return bc.Value -} - -// getStatus returns the status of the baseChallenge. -func (bc *baseChallenge) getStatus() string { - return bc.Status -} - -// getToken returns the token of the baseChallenge. -func (bc *baseChallenge) getToken() string { - return bc.Token -} - -// getValidated returns the validated time of the baseChallenge. -func (bc *baseChallenge) getValidated() time.Time { - return bc.Validated -} - -// getCreated returns the created time of the baseChallenge. -func (bc *baseChallenge) getCreated() time.Time { - return bc.Created -} - -// getCreated returns the created time of the baseChallenge. -func (bc *baseChallenge) getError() *AError { - return bc.Error -} - -// toACME converts the internal Challenge type into the public acmeChallenge -// type for presentation in the ACME protocol. -func (bc *baseChallenge) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Challenge, error) { - ac := &Challenge{ - Type: bc.getType(), - Status: bc.getStatus(), - Token: bc.getToken(), - URL: dir.getLink(ctx, ChallengeLink, true, bc.getID()), - ID: bc.getID(), - AuthzID: bc.getAuthzID(), - } - if !bc.Validated.IsZero() { - ac.Validated = bc.Validated.Format(time.RFC3339) - } - if bc.Error != nil { - ac.Error = bc.Error - } - return ac, nil -} - -// save writes the challenge to disk. For new challenges 'old' should be nil, -// otherwise 'old' should be a pointer to the acme challenge as it was at the -// start of the request. This method will fail if the value currently found -// in the bucket/row does not match the value of 'old'. -func (bc *baseChallenge) save(db nosql.DB, old challenge) error { - newB, err := json.Marshal(bc) - if err != nil { - return ServerInternalErr(errors.Wrap(err, - "error marshaling new acme challenge")) - } - var oldB []byte - if old == nil { - oldB = nil - } else { - oldB, err = json.Marshal(old) - if err != nil { - return ServerInternalErr(errors.Wrap(err, - "error marshaling old acme challenge")) - } - } - - _, swapped, err := db.CmpAndSwap(challengeTable, []byte(bc.ID), oldB, newB) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error saving acme challenge")) - case !swapped: - return ServerInternalErr(errors.New("error saving acme challenge; " + - "acme challenge has changed since last read")) - default: +// Validate attempts to validate the challenge. Stores changes to the Challenge +// type using the DB interface. +// satisfactorily validated, the 'status' and 'validated' attributes are +// updated. +func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { + // If already valid or invalid then return without performing validation. + if ch.Status != StatusPending { return nil } + switch ch.Type { + case "http-01": + return http01Validate(ctx, ch, db, jwk, vo) + case "dns-01": + return dns01Validate(ctx, ch, db, jwk, vo) + case "tls-alpn-01": + return tlsalpn01Validate(ctx, ch, db, jwk, vo) + default: + return NewErrorISE("unexpected challenge type '%s'", ch.Type) + } } -func (bc *baseChallenge) clone() *baseChallenge { - u := *bc - return &u -} +func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { + url := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} -func (bc *baseChallenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { - return nil, ServerInternalErr(errors.New("unimplemented")) -} + resp, err := vo.HTTPGet(url.String()) + if err != nil { + return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, + "error doing http GET for url %s", url)) + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return storeError(ctx, db, ch, false, NewError(ErrorConnectionType, + "error doing http GET for url %s with status code %d", url, resp.StatusCode)) + } -func (bc *baseChallenge) storeError(db nosql.DB, err *Error) error { - clone := bc.clone() - clone.Error = err.ToACME() - if err := clone.save(db, bc); err != nil { - return ServerInternalErr(errors.Wrap(err, "failure saving error to acme challenge")) + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return WrapErrorISE(err, "error reading "+ + "response body for url %s", url) + } + keyAuth := strings.TrimSpace(string(body)) + + expected, err := KeyAuthorization(ch.Token, jwk) + if err != nil { + return err + } + if keyAuth != expected { + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "keyAuthorization does not match; expected %s, but got %s", expected, keyAuth)) + } + + // Update and store the challenge. + ch.Status = StatusValid + ch.Error = nil + ch.ValidatedAt = clock.Now().Format(time.RFC3339) + + if err = db.UpdateChallenge(ctx, ch); err != nil { + return WrapErrorISE(err, "error updating challenge") } return nil } -// unmarshalChallenge unmarshals a challenge type into the correct sub-type. -func unmarshalChallenge(data []byte) (challenge, error) { - var getType struct { - Type string `json:"type"` - } - if err := json.Unmarshal(data, &getType); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling challenge type")) - } - - switch getType.Type { - case "dns-01": - var bc baseChallenge - if err := json.Unmarshal(data, &bc); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ - "challenge type into dns01Challenge")) - } - return &dns01Challenge{&bc}, nil - case "http-01": - var bc baseChallenge - if err := json.Unmarshal(data, &bc); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ - "challenge type into http01Challenge")) - } - return &http01Challenge{&bc}, nil - case "tls-alpn-01": - var bc baseChallenge - if err := json.Unmarshal(data, &bc); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ - "challenge type into tlsALPN01Challenge")) - } - return &tlsALPN01Challenge{&bc}, nil - default: - return nil, ServerInternalErr(errors.Errorf("unexpected challenge type %s", getType.Type)) - } -} - -// http01Challenge represents an http-01 acme challenge. -type http01Challenge struct { - *baseChallenge -} - -// newHTTP01Challenge returns a new acme http-01 challenge. -func newHTTP01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { - bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) - if err != nil { - return nil, err - } - bc.Type = "http-01" - bc.Value = ops.Identifier.Value - - hc := &http01Challenge{bc} - if err := hc.save(db, nil); err != nil { - return nil, err - } - return hc, nil -} - -// Validate attempts to validate the challenge. If the challenge has been -// satisfactorily validated, the 'status' and 'validated' attributes are -// updated. -func (hc *http01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { - // If already valid or invalid then return without performing validation. - if hc.getStatus() == StatusValid || hc.getStatus() == StatusInvalid { - return hc, nil - } - url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", hc.Value, hc.Token) - - resp, err := vo.httpGet(url) - if err != nil { - if err = hc.storeError(db, ConnectionErr(errors.Wrapf(err, - "error doing http GET for url %s", url))); err != nil { - return nil, err - } - return hc, nil - } - if resp.StatusCode >= 400 { - if err = hc.storeError(db, - ConnectionErr(errors.Errorf("error doing http GET for url %s with status code %d", - url, resp.StatusCode))); err != nil { - return nil, err - } - return hc, nil - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error reading "+ - "response body for url %s", url)) - } - keyAuth := strings.Trim(string(body), "\r\n") - - expected, err := KeyAuthorization(hc.Token, jwk) - if err != nil { - return nil, err - } - if keyAuth != expected { - if err = hc.storeError(db, - RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ - "expected %s, but got %s", expected, keyAuth))); err != nil { - return nil, err - } - return hc, nil - } - - // Update and store the challenge. - upd := &http01Challenge{hc.baseChallenge.clone()} - upd.Status = StatusValid - upd.Error = nil - upd.Validated = clock.Now() - - if err := upd.save(db, hc); err != nil { - return nil, err - } - return upd, nil -} - -type tlsALPN01Challenge struct { - *baseChallenge -} - -// newTLSALPN01Challenge returns a new acme tls-alpn-01 challenge. -func newTLSALPN01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { - bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) - if err != nil { - return nil, err - } - bc.Type = "tls-alpn-01" - bc.Value = ops.Identifier.Value - - hc := &tlsALPN01Challenge{bc} - if err := hc.save(db, nil); err != nil { - return nil, err - } - return hc, nil -} - -func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { - // If already valid or invalid then return without performing validation. - if tc.getStatus() == StatusValid || tc.getStatus() == StatusInvalid { - return tc, nil - } - +func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { config := &tls.Config{ - NextProtos: []string{"acme-tls/1"}, - ServerName: tc.Value, + NextProtos: []string{"acme-tls/1"}, + // https://tools.ietf.org/html/rfc8737#section-4 + // ACME servers that implement "acme-tls/1" MUST only negotiate TLS 1.2 + // [RFC5246] or higher when connecting to clients for validation. + MinVersion: tls.VersionTLS12, + ServerName: ch.Value, InsecureSkipVerify: true, // we expect a self-signed challenge certificate } - hostPort := net.JoinHostPort(tc.Value, "443") + hostPort := net.JoinHostPort(ch.Value, "443") - conn, err := vo.tlsDial("tcp", hostPort, config) + conn, err := vo.TLSDial("tcp", hostPort, config) if err != nil { - if err = tc.storeError(db, - ConnectionErr(errors.Wrapf(err, "error doing TLS dial for %s", hostPort))); err != nil { - return nil, err - } - return tc, nil + return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, + "error doing TLS dial for %s", hostPort)) } defer conn.Close() @@ -409,86 +130,62 @@ func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo val certs := cs.PeerCertificates if len(certs) == 0 { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("%s challenge for %s resulted in no certificates", - tc.Type, tc.Value))); err != nil { - return nil, err - } - return tc, nil + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "%s challenge for %s resulted in no certificates", ch.Type, ch.Value)) } - if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("cannot negotiate ALPN acme-tls/1 protocol for "+ - "tls-alpn-01 challenge"))); err != nil { - return nil, err - } - return tc, nil + if cs.NegotiatedProtocol != "acme-tls/1" { + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) } leafCert := certs[0] - if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], tc.Value) { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "leaf certificate must contain a single DNS name, %v", tc.Value))); err != nil { - return nil, err - } - return tc, nil + if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) { + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value)) } idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} foundIDPeAcmeIdentifierV1Obsolete := false - keyAuth, err := KeyAuthorization(tc.Token, jwk) + keyAuth, err := KeyAuthorization(ch.Token, jwk) if err != nil { - return nil, err + return err } hashedKeyAuth := sha256.Sum256([]byte(keyAuth)) for _, ext := range leafCert.Extensions { if idPeAcmeIdentifier.Equal(ext.Id) { if !ext.Critical { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "acmeValidationV1 extension not critical"))); err != nil { - return nil, err - } - return tc, nil + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical")) } var extValue []byte rest, err := asn1.Unmarshal(ext.Value, &extValue) if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "malformed acmeValidationV1 extension value"))); err != nil { - return nil, err - } - return tc, nil + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value")) } if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: "+ "expected acmeValidationV1 extension value %s for this challenge but got %s", - hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue)))); err != nil { - return nil, err - } - return tc, nil + hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue))) } - upd := &tlsALPN01Challenge{tc.baseChallenge.clone()} - upd.Status = StatusValid - upd.Error = nil - upd.Validated = clock.Now() + ch.Status = StatusValid + ch.Error = nil + ch.ValidatedAt = clock.Now().Format(time.RFC3339) - if err := upd.save(db, tc); err != nil { - return nil, err + if err = db.UpdateChallenge(ctx, ch); err != nil { + return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge") } - return upd, nil + return nil } if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) { @@ -497,82 +194,30 @@ func (tc *tlsALPN01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo val } if foundIDPeAcmeIdentifierV1Obsolete { - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension"))); err != nil { - return nil, err - } - return tc, nil + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: obsolete id-pe-acmeIdentifier in acmeValidationV1 extension")) } - if err = tc.storeError(db, - RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "missing acmeValidationV1 extension"))); err != nil { - return nil, err - } - return tc, nil + return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, + "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } -// dns01Challenge represents an dns-01 acme challenge. -type dns01Challenge struct { - *baseChallenge -} - -// newDNS01Challenge returns a new acme dns-01 challenge. -func newDNS01Challenge(db nosql.DB, ops ChallengeOptions) (challenge, error) { - bc, err := newBaseChallenge(ops.AccountID, ops.AuthzID) - if err != nil { - return nil, err - } - bc.Type = "dns-01" - bc.Value = ops.Identifier.Value - - dc := &dns01Challenge{bc} - if err := dc.save(db, nil); err != nil { - return nil, err - } - return dc, nil -} - -// KeyAuthorization creates the ACME key authorization value from a token -// and a jwk. -func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { - thumbprint, err := jwk.Thumbprint(crypto.SHA256) - if err != nil { - return "", ServerInternalErr(errors.Wrap(err, "error generating JWK thumbprint")) - } - encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) - return fmt.Sprintf("%s.%s", token, encPrint), nil -} - -// validate attempts to validate the challenge. If the challenge has been -// satisfactorily validated, the 'status' and 'validated' attributes are -// updated. -func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) (challenge, error) { - // If already valid or invalid then return without performing validation. - if dc.getStatus() == StatusValid || dc.getStatus() == StatusInvalid { - return dc, nil - } - +func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { // Normalize domain for wildcard DNS names // This is done to avoid making TXT lookups for domains like // _acme-challenge.*.example.com // Instead perform txt lookup for _acme-challenge.example.com - domain := strings.TrimPrefix(dc.Value, "*.") + domain := strings.TrimPrefix(ch.Value, "*.") - txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) + txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) if err != nil { - if err = dc.storeError(db, - DNSErr(errors.Wrapf(err, "error looking up TXT "+ - "records for domain %s", domain))); err != nil { - return nil, err - } - return dc, nil + return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err, + "error looking up TXT records for domain %s", domain)) } - expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk) + expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk) if err != nil { - return nil, err + return err } h := sha256.Sum256([]byte(expectedKeyAuth)) expected := base64.RawURLEncoding.EncodeToString(h[:]) @@ -584,37 +229,51 @@ func (dc *dns01Challenge) validate(db nosql.DB, jwk *jose.JSONWebKey, vo validat } } if !found { - if err = dc.storeError(db, - RejectedIdentifierErr(errors.Errorf("keyAuthorization "+ - "does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))); err != nil { - return nil, err - } - return dc, nil + return storeError(ctx, db, ch, false, NewError(ErrorRejectedIdentifierType, + "keyAuthorization does not match; expected %s, but got %s", expectedKeyAuth, txtRecords)) } // Update and store the challenge. - upd := &dns01Challenge{dc.baseChallenge.clone()} - upd.Status = StatusValid - upd.Error = nil - upd.Validated = time.Now().UTC() + ch.Status = StatusValid + ch.Error = nil + ch.ValidatedAt = clock.Now().Format(time.RFC3339) - if err := upd.save(db, dc); err != nil { - return nil, err + if err = db.UpdateChallenge(ctx, ch); err != nil { + return WrapErrorISE(err, "error updating challenge") } - return upd, nil + return nil } -// getChallenge retrieves and unmarshals an ACME challenge type from the database. -func getChallenge(db nosql.DB, id string) (challenge, error) { - b, err := db.Get(challengeTable, []byte(id)) - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id)) - } else if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id)) - } - ch, err := unmarshalChallenge(b) +// KeyAuthorization creates the ACME key authorization value from a token +// and a jwk. +func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { + thumbprint, err := jwk.Thumbprint(crypto.SHA256) if err != nil { - return nil, err + return "", WrapErrorISE(err, "error generating JWK thumbprint") } - return ch, nil + encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) + return fmt.Sprintf("%s.%s", token, encPrint), nil +} + +// storeError the given error to an ACME error and saves using the DB interface. +func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err *Error) error { + ch.Error = err + if markInvalid { + ch.Status = StatusInvalid + } + if err := db.UpdateChallenge(ctx, ch); err != nil { + return WrapErrorISE(err, "failure saving error to acme challenge") + } + return nil +} + +type httpGetter func(string) (*http.Response, error) +type lookupTxt func(string) ([]string, error) +type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) + +// ValidateChallengeOptions are ACME challenge validator functions. +type ValidateChallengeOptions struct { + HTTPGet httpGetter + LookupTxt lookupTxt + TLSDial tlsDialer } diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 87ec0c4c..14287945 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -13,7 +13,6 @@ import ( "encoding/asn1" "encoding/base64" "encoding/hex" - "encoding/json" "fmt" "io" "io/ioutil" @@ -21,644 +20,150 @@ import ( "net" "net/http" "net/http/httptest" - "net/url" + "strings" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" "go.step.sm/crypto/jose" ) -var testOps = ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "", // will get set correctly depending on the "new.." method. - Value: "zap.internal", - }, -} - -func newDNSCh() (challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newDNS01Challenge(mockdb, testOps) -} - -func newTLSALPNCh() (challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newTLSALPN01Challenge(mockdb, testOps) -} - -func newHTTPCh() (challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newHTTP01Challenge(mockdb, testOps) -} - -func newHTTPChWithServer(host string) (challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newHTTP01Challenge(mockdb, ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "", // will get set correctly depending on the "new.." method. - Value: host, - }, - }) -} - -func TestNewHTTP01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "http", - Value: "zap.internal", - }, - } +func Test_storeError(t *testing.T) { type test struct { - ops ChallengeOptions - db nosql.DB - err *Error - } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - }, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newHTTP01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "http-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } - } - }) - } -} - -func TestNewTLSALPN01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "http", - Value: "zap.internal", - }, - } - type test struct { - ops ChallengeOptions - db nosql.DB - err *Error - } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - }, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newTLSALPN01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "tls-alpn-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } - } - }) - } -} - -func TestNewDNS01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "dns", - Value: "zap.internal", - }, - } - type test struct { - ops ChallengeOptions - db nosql.DB - err *Error - } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - }, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newDNS01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "dns-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } - } - }) - } -} - -func TestChallengeToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Validated = clock.Now() - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - tests := map[string]challenge{ - "dns": dnsCh, - "http": httpCh, - "tls-alpn": tlsALPNCh, - } - for name, ch := range tests { - t.Run(name, func(t *testing.T) { - ach, err := ch.toACME(ctx, nil, dir) - assert.FatalError(t, err) - - assert.Equals(t, ach.Type, ch.getType()) - assert.Equals(t, ach.Status, ch.getStatus()) - assert.Equals(t, ach.Token, ch.getToken()) - assert.Equals(t, ach.URL, - fmt.Sprintf("%s/acme/%s/challenge/%s", - baseURL.String(), provName, ch.getID())) - assert.Equals(t, ach.ID, ch.getID()) - assert.Equals(t, ach.AuthzID, ch.getAuthzID()) - - if ach.Type == "http-01" { - v, err := time.Parse(time.RFC3339, ach.Validated) - assert.FatalError(t, err) - assert.Equals(t, v.String(), _httpCh.baseChallenge.Validated.String()) - } else { - assert.Equals(t, ach.Validated, "") - } - }) - } -} - -func TestChallengeSave(t *testing.T) { - type test struct { - ch challenge - old challenge - db nosql.DB - err *Error + ch *Challenge + db DB + markInvalid bool + err *Error } + err := NewError(ErrorMalformedType, "foo") tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) + "fail/db.UpdateChallenge-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusValid, + } return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") + ch: ch, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "fail/old-nil/swap-false": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) + "fail/db.UpdateChallenge-acme-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusValid, + } return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil + ch: ch, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return NewError(ErrorMalformedType, "bar") }, }, - err: ServerInternalErr(errors.New("error saving acme challenge; acme challenge has changed since last read")), - } - }, - "ok/old-nil": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, []byte(httpCh.getID()), key) - return []byte("foo"), true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldHTTPCh, err := newHTTPCh() - assert.FatalError(t, err) - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - - oldb, err := json.Marshal(oldHTTPCh) - assert.FatalError(t, err) - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - old: oldHTTPCh, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, []byte(httpCh.getID()), key) - return []byte("foo"), true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.ch.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestChallengeClone(t *testing.T) { - ch, err := newHTTPCh() - assert.FatalError(t, err) - - clone := ch.clone() - - assert.Equals(t, clone.getID(), ch.getID()) - assert.Equals(t, clone.getAccountID(), ch.getAccountID()) - assert.Equals(t, clone.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, clone.getStatus(), ch.getStatus()) - assert.Equals(t, clone.getToken(), ch.getToken()) - assert.Equals(t, clone.getCreated(), ch.getCreated()) - assert.Equals(t, clone.getValidated(), ch.getValidated()) - - clone.Status = StatusValid - - assert.NotEquals(t, clone.getStatus(), ch.getStatus()) -} - -func TestChallengeUnmarshal(t *testing.T) { - type test struct { - ch challenge - chb []byte - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/nil": func(t *testing.T) test { - return test{ - chb: nil, - err: ServerInternalErr(errors.New("error unmarshaling challenge type: unexpected end of JSON input")), - } - }, - "fail/unexpected-type-http": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Type = "foo" - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "fail/unexpected-type-alpn": func(t *testing.T) test { - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - _tlsALPNCh, ok := tlsALPNCh.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _tlsALPNCh.baseChallenge.Type = "foo" - b, err := json.Marshal(tlsALPNCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "fail/unexpected-type-dns": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - _dnsCh, ok := dnsCh.(*dns01Challenge) - assert.Fatal(t, ok) - _dnsCh.baseChallenge.Type = "foo" - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "ok/dns": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - ch: dnsCh, - chb: b, - } - }, - "ok/http": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - chb: b, - } - }, - "ok/alpn": func(t *testing.T) test { - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - b, err := json.Marshal(tlsALPNCh) - assert.FatalError(t, err) - return test{ - ch: tlsALPNCh, - chb: b, - } - }, - "ok/err": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Error = ServerInternalErr(errors.New("force")).ToACME() - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - chb: b, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if ch, err := unmarshalChallenge(tc.chb); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.ch.getID(), ch.getID()) - assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) - assert.Equals(t, tc.ch.getToken(), ch.getToken()) - assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) - assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) - } - } - }) - } -} -func TestGetChallenge(t *testing.T) { - type test struct { - id string - db nosql.DB - ch challenge - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("challenge %s not found: not found", dnsCh.getID())), - } - }, - "fail/db-error": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", dnsCh.getID())), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - _dnsCh, ok := dnsCh.(*dns01Challenge) - assert.Fatal(t, ok) - _dnsCh.baseChallenge.Type = "foo" - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(dnsCh.getID())) - return b, nil - }, - }, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), + err: NewError(ErrorMalformedType, "failure saving error to acme challenge: bar"), } }, "ok": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusValid, + } return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(dnsCh.getID())) - return b, nil + ch: ch, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, } }, + "ok/mark-invalid": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusValid, + } + return test{ + ch: ch, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusInvalid) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + markInvalid: true, + } + }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if ch, err := getChallenge(tc.db, tc.id); err != nil { + if err := storeError(context.Background(), tc.db, tc.ch, tc.markInvalid, err); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.ch.getID(), ch.getID()) - assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) - assert.Equals(t, tc.ch.getToken(), ch.getToken()) - assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) - assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) - } + assert.Nil(t, tc.err) } }) } @@ -679,7 +184,7 @@ func TestKeyAuthorization(t *testing.T) { return test{ token: "1234", jwk: jwk, - err: ServerInternalErr(errors.Errorf("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), + err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), } }, "ok": func(t *testing.T) test { @@ -701,11 +206,16 @@ func TestKeyAuthorization(t *testing.T) { tc := run(t) if ka, err := KeyAuthorization(tc.token, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { if assert.Nil(t, tc.err) { @@ -716,6 +226,293 @@ func TestKeyAuthorization(t *testing.T) { } } +func TestChallenge_Validate(t *testing.T) { + type test struct { + ch *Challenge + vo *ValidateChallengeOptions + jwk *jose.JSONWebKey + db DB + srv *httptest.Server + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok/already-valid": func(t *testing.T) test { + ch := &Challenge{ + Status: StatusValid, + } + return test{ + ch: ch, + } + }, + "fail/already-invalid": func(t *testing.T) test { + ch := &Challenge{ + Status: StatusInvalid, + } + return test{ + ch: ch, + } + }, + "fail/unexpected-type": func(t *testing.T) test { + ch := &Challenge{ + Status: StatusPending, + Type: "foo", + } + return test{ + ch: ch, + err: NewErrorISE("unexpected challenge type 'foo'"), + } + }, + "fail/http-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Status: StatusPending, + Type: "http-01", + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/http-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Status: StatusPending, + Type: "http-01", + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "fail/dns-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Type: "dns-01", + Status: StatusPending, + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/dns-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Type: "dns-01", + Status: StatusPending, + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "fail/tls-alpn-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Type: "tls-alpn-01", + Status: StatusPending, + Value: "zap.internal", + } + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: force", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/tls-alpn-01": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Type: "tls-alpn-01", + Status: StatusPending, + Value: "zap.internal", + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Error, nil) + return nil + }, + }, + srv: srv, + jwk: jwk, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + + if tc.srv != nil { + defer tc.srv.Close() + } + + if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + type errReader int func (errReader) Read(p []byte) (n int, err error) { @@ -727,258 +524,361 @@ func (errReader) Close() error { func TestHTTP01Validate(t *testing.T) { type test struct { - vo validateOptions - ch challenge - res challenge + vo *ValidateChallengeOptions + ch *Challenge jwk *jose.JSONWebKey - db nosql.DB + db DB err *Error } tests := map[string]func(t *testing.T) test{ - "ok/status-already-valid": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - return test{ - ch: ch, - res: ch, + "fail/http-get-error-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, } - }, - "ok/status-already-invalid": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusInvalid - return test{ - ch: ch, - res: ch, - } - }, - "ok/http-get-error": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - expErr := ConnectionErr(errors.Errorf("error doing http GET for url "+ - "http://zap.internal/.well-known/acme-challenge/%s: force", ch.getToken())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - res: ch, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok/http-get->=400": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "ok/http-get-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, + } - expErr := ConnectionErr(errors.Errorf("error doing http GET for url "+ - "http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.getToken())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "fail/http-get->=400-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, + Body: errReader(0), }, nil }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - res: ch, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "fail/read-body": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - jwk.Key = "foo" + "ok/http-get->=400": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, + } return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadRequest, + Body: errReader(0), + }, nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "fail/read-body": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ Body: errReader(0), }, nil }, }, - jwk: jwk, - err: ServerInternalErr(errors.Errorf("error reading response "+ - "body for url http://zap.internal/.well-known/acme-challenge/%s: force", - ch.getToken())), + err: NewErrorISE("error reading response body for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token), } }, - "fail/key-authorization-gen-error": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) + "fail/key-auth-gen-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, + } + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) jwk.Key = "foo" return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ Body: ioutil.NopCloser(bytes.NewBufferString("foo")), }, nil }, }, jwk: jwk, - err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), + err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), } }, "ok/key-auth-mismatch": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, + } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ - "expected %s, but got foo", expKeyAuth)) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ Body: ioutil.NopCloser(bytes.NewBufferString("foo")), }, nil }, }, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusInvalid) + + err := NewError(ErrorRejectedIdentifierType, + "keyAuthorization does not match; expected %s, but got foo", expKeyAuth) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, } }, - "fail/save-error": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) + "fail/key-auth-mismatch-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, + } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ - Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), + Body: ioutil.NopCloser(bytes.NewBufferString("foo")), }, nil }, }, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusInvalid) + + err := NewError(ErrorRejectedIdentifierType, + "keyAuthorization does not match; expected %s, but got foo", expKeyAuth) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Error = MalformedErr(nil).ToACME() - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "fail/update-challenge-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, + } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) - - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &http01Challenge{baseClone} - return test{ - ch: ch, - res: newCh, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil }, }, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - - httpCh, err := unmarshalChallenge(newval) + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Error, nil) + va, err := time.Parse(time.RFC3339, updch.ValidatedAt) assert.FatalError(t, err) - assert.Equals(t, httpCh.getStatus(), StatusValid) - assert.True(t, httpCh.getValidated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, httpCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) + now := clock.Now() + assert.True(t, va.Add(-time.Minute).Before(now)) + assert.True(t, va.Add(time.Minute).After(now)) - baseClone.Validated = httpCh.getValidated() + return errors.New("force") + }, + }, + err: NewErrorISE("error updating challenge: force"), + } + }, + "ok": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: "zap.internal", + Status: StatusPending, + } - return nil, true, nil + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), + }, nil + }, + }, + jwk: jwk, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Error, nil) + va, err := time.Parse(time.RFC3339, updch.ValidatedAt) + assert.FatalError(t, err) + now := clock.Now() + assert.True(t, va.Add(-time.Minute).Before(now)) + assert.True(t, va.Add(time.Minute).After(now)) + return nil }, }, } @@ -987,648 +887,320 @@ func TestHTTP01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { + if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res.getID(), ch.getID()) - assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.res.getStatus(), ch.getStatus()) - assert.Equals(t, tc.res.getToken(), ch.getToken()) - assert.Equals(t, tc.res.getCreated(), ch.getCreated()) - assert.Equals(t, tc.res.getValidated(), ch.getValidated()) - assert.Equals(t, tc.res.getError(), ch.getError()) - } + assert.Nil(t, tc.err) } }) } } -func TestTLSALPN01Validate(t *testing.T) { +func TestDNS01Validate(t *testing.T) { + fulldomain := "*.zap.internal" + domain := strings.TrimPrefix(fulldomain, "*.") type test struct { - srv *httptest.Server - vo validateOptions - ch challenge - res challenge + vo *ValidateChallengeOptions + ch *Challenge jwk *jose.JSONWebKey - db nosql.DB + db DB err *Error } tests := map[string]func(t *testing.T) test{ - "ok/status-already-valid": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - _ch, ok := ch.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - - return test{ - ch: ch, - res: ch, + "fail/lookupTXT-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, } - }, - "ok/status-already-invalid": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - _ch, ok := ch.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusInvalid - - return test{ - ch: ch, - res: ch, - } - }, - "ok/tls-dial-error": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := ConnectionErr(errors.Errorf("error doing TLS dial for %v:443: force", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) return test{ ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) + + err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - res: ch, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok/timeout": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := ConnectionErr(errors.Errorf("error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(nil) - // srv.Start() - do not start server to cause timeout + "ok/lookupTXT-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, + } return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) + + err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, } }, - "ok/no-certificates": func(t *testing.T) test { - ch, err := newTLSALPNCh() + "fail/key-auth-gen-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - oldb, err := json.Marshal(ch) + jwk.Key = "foo" + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return []string{"foo"}, nil + }, + }, + jwk: jwk, + err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), + } + }, + "fail/key-auth-mismatch-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, + } + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expErr := RejectedIdentifierErr(errors.Errorf("tls-alpn-01 challenge for %v resulted in no certificates", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) return test{ ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.Client(&noopConn{}, config), nil + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return []string{"foo", "bar"}, nil }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) + + err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - res: ch, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok/no-names": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) + "ok/key-auth-mismatch-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, + } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true) - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return []string{"foo", "bar"}, nil }, }, - res: ch, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusPending) + + err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + jwk: jwk, } }, - "ok/too-many-names": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) + "fail/update-challenge-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, + } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.getValue(), "other.internal") - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() + h := sha256.Sum256([]byte(expKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return []string{"foo", expected}, nil }, }, - res: ch, - } - }, - "ok/wrong-name": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Error, nil) + va, err := time.Parse(time.RFC3339, updch.ValidatedAt) + assert.FatalError(t, err) + now := clock.Now() + assert.True(t, va.Add(-time.Minute).Before(now)) + assert.True(t, va.Add(time.Minute).After(now)) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, "other.internal") - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() - - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil - }, - }, - res: ch, - } - }, - "ok/no-extension": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - cert, err := newTLSALPNValidationCert(nil, false, true, ch.getValue()) - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() - - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil - }, - }, - res: ch, - } - }, - "ok/extension-not-critical": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.getValue()) - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() - - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil - }, - }, - res: ch, - } - }, - "ok/extension-malformed": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.getValue()) - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() - - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil - }, - }, - res: ch, - } - }, - "ok/no-protocol": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.New("cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - srv := httptest.NewTLSServer(nil) - - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) + return errors.New("force") }, }, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil - }, - }, - res: ch, - } - }, - "ok/mismatched-token": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - incorrectTokenHash := sha256.Sum256([]byte("mismatched")) - - expErr := RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ - "expected acmeValidationV1 extension value %s for this challenge but got %s", - hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:]))) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.getValue()) - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() - - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil - }, - }, - res: ch, - } - }, - "ok/obsolete-oid": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.New("incorrect certificate for tls-alpn-01 challenge: " + - "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension")) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &tlsALPN01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.getValue()) - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() - - return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: tlsDial, - }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, string(newval), string(newb)) - return nil, true, nil - }, - }, - res: ch, + err: NewErrorISE("error updating challenge: force"), } }, "ok": func(t *testing.T) test { - ch, err := newTLSALPNCh() - assert.FatalError(t, err) - _ch, ok := ch.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Error = MalformedErr(nil).ToACME() - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &tlsALPN01Challenge{baseClone} + ch := &Challenge{ + ID: "chID", + Token: "token", + Value: fulldomain, + Status: StatusPending, + } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) - expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - - cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.getValue()) - assert.FatalError(t, err) - - srv, tlsDial := newTestTLSALPNServer(cert) - srv.Start() + h := sha256.Sum256([]byte(expKeyAuth)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) return test{ - srv: srv, - ch: ch, - vo: validateOptions{ - tlsDial: func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) { - assert.Equals(t, network, "tcp") - assert.Equals(t, addr, net.JoinHostPort(newCh.getValue(), "443")) - assert.Equals(t, config.NextProtos, []string{"acme-tls/1"}) - assert.Equals(t, config.ServerName, newCh.getValue()) - assert.True(t, config.InsecureSkipVerify) + ch: ch, + vo: &ValidateChallengeOptions{ + LookupTxt: func(url string) ([]string, error) { + return []string{"foo", expected}, nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) - return tlsDial(network, addr, config) + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Error, nil) + va, err := time.Parse(time.RFC3339, updch.ValidatedAt) + assert.FatalError(t, err) + now := clock.Now() + assert.True(t, va.Add(-time.Minute).Before(now)) + assert.True(t, va.Add(time.Minute).After(now)) + + return nil }, }, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - - alpnCh, err := unmarshalChallenge(newval) - assert.FatalError(t, err) - assert.Equals(t, alpnCh.getStatus(), StatusValid) - assert.True(t, alpnCh.getValidated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, alpnCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) - - baseClone.Validated = alpnCh.getValidated() - - return nil, true, nil - }, - }, - res: newCh, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - - if tc.srv != nil { - defer tc.srv.Close() - } - - if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { + if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res.getID(), ch.getID()) - assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.res.getStatus(), ch.getStatus()) - assert.Equals(t, tc.res.getToken(), ch.getToken()) - assert.Equals(t, tc.res.getCreated(), ch.getCreated()) - assert.Equals(t, tc.res.getValidated(), ch.getValidated()) - assert.Equals(t, tc.res.getError(), ch.getError()) - } + assert.Nil(t, tc.err) } }) } @@ -1726,268 +1298,939 @@ func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, na }, nil } -func TestDNS01Validate(t *testing.T) { +func TestTLSALPN01Validate(t *testing.T) { + makeTLSCh := func() *Challenge { + return &Challenge{ + ID: "chID", + Token: "token", + Type: "tls-alpn-01", + Status: StatusPending, + Value: "zap.internal", + } + } type test struct { - vo validateOptions - ch challenge - res challenge + vo *ValidateChallengeOptions + ch *Challenge jwk *jose.JSONWebKey - db nosql.DB + db DB + srv *httptest.Server err *Error } tests := map[string]func(t *testing.T) test{ - "ok/status-already-valid": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - return test{ - ch: ch, - res: ch, - } - }, - "ok/status-already-invalid": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusInvalid - return test{ - ch: ch, - res: ch, - } - }, - "ok/lookup-txt-error": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - - expErr := DNSErr(errors.Errorf("error looking up TXT records for "+ - "domain %s: force", ch.getValue())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &dns01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) + "fail/tlsDial-store-error": func(t *testing.T) test { + ch := makeTLSCh() return test{ ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: force", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - res: ch, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok/lookup-txt-wildcard": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Value = "*.zap.internal" + "ok/tlsDial-error": func(t *testing.T) test { + ch := makeTLSCh() + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) + err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: force", ch.Value) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) - assert.FatalError(t, err) - h := sha256.Sum256([]byte(expKeyAuth)) - expected := base64.RawURLEncoding.EncodeToString(h[:]) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "ok/tlsDial-timeout": func(t *testing.T) test { + ch := makeTLSCh() - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &dns01Challenge{baseClone} + srv, tlsDial := newTestTLSALPNServer(nil) + // srv.Start() - do not start server to cause timeout return test{ - ch: ch, - res: newCh, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - assert.Equals(t, url, "_acme-challenge.zap.internal") - return []string{"foo", expected}, nil + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, ch.Status) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - dnsCh, err := unmarshalChallenge(newval) - assert.FatalError(t, err) - assert.Equals(t, dnsCh.getStatus(), StatusValid) - baseClone.Validated = dnsCh.getValidated() - return nil, true, nil + srv: srv, + } + }, + "ok/no-certificates-error": func(t *testing.T) test { + ch := makeTLSCh() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.Client(&noopConn{}, config), nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "tls-alpn-01 challenge for %v resulted in no certificates", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, } }, - "fail/key-authorization-gen-error": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) + "fail/no-certificates-store-error": func(t *testing.T) test { + ch := makeTLSCh() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.Client(&noopConn{}, config), nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "tls-alpn-01 challenge for %v resulted in no certificates", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/error-no-protocol": func(t *testing.T) test { + ch := makeTLSCh() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) + + srv := httptest.NewTLSServer(nil) + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + srv: srv, + jwk: jwk, + } + }, + "fail/no-protocol-store-error": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + srv := httptest.NewTLSServer(nil) + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/no-names-error": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + srv: srv, + jwk: jwk, + } + }, + "fail/no-names-store-error": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/too-many-names-error": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value, "other.internal") + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + srv: srv, + jwk: jwk, + } + }, + "ok/wrong-name": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, "other.internal") + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single DNS name, %v", ch.Value) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + srv: srv, + jwk: jwk, + } + }, + "fail/key-auth-gen-error": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) jwk.Key = "foo" + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + return test{ ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", "bar"}, nil - }, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, }, + srv: srv, jwk: jwk, - err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), + err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), } }, - "ok/key-auth-mismatch": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "ok/error-no-extension": func(t *testing.T) test { + ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + cert, err := newTLSALPNValidationCert(nil, false, true, ch.Value) assert.FatalError(t, err) - expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ - "expected %s, but got %s", expKeyAuth, []string{"foo", "bar"})) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() return test{ ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", "bar"}, nil + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, + srv: srv, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil - }, - }, - res: ch, } }, - "fail/save-error": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) + "fail/no-extension-store-error": func(t *testing.T) test { + ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + cert, err := newTLSALPNValidationCert(nil, false, true, ch.Value) assert.FatalError(t, err) - h := sha256.Sum256([]byte(expKeyAuth)) - expected := base64.RawURLEncoding.EncodeToString(h[:]) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + return test{ ch: ch, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", expected}, nil + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, + srv: srv, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/error-extension-not-critical": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), + srv: srv, + jwk: jwk, + } + }, + "fail/extension-not-critical-store-error": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/error-malformed-extension": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + srv: srv, + jwk: jwk, + } + }, + "fail/malformed-extension-store-error": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/error-keyauth-mismatch": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + incorrectTokenHash := sha256.Sum256([]byte("mismatched")) + + cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ + "expected acmeValidationV1 extension value %s for this challenge but got %s", + hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:])) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + srv: srv, + jwk: jwk, + } + }, + "fail/keyauth-mismatch-store-error": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + incorrectTokenHash := sha256.Sum256([]byte("mismatched")) + + cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ + "expected acmeValidationV1 extension value %s for this challenge but got %s", + hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:])) + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), + } + }, + "ok/error-obsolete-oid": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ + "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + srv: srv, + jwk: jwk, + } + }, + "fail/obsolete-oid-store-error": func(t *testing.T) test { + ch := makeTLSCh() + + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) + + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusInvalid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ + "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension") + + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") + }, + }, + srv: srv, + jwk: jwk, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok": func(t *testing.T) test { - ch, err := newDNSCh() - assert.FatalError(t, err) - _ch, ok := ch.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Error = MalformedErr(nil).ToACME() - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) - h := sha256.Sum256([]byte(expKeyAuth)) - expected := base64.RawURLEncoding.EncodeToString(h[:]) + expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &dns01Challenge{baseClone} + cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) + assert.FatalError(t, err) + + srv, tlsDial := newTestTLSALPNServer(cert) + srv.Start() return test{ - ch: ch, - res: newCh, - vo: validateOptions{ - lookupTxt: func(url string) ([]string, error) { - return []string{"foo", expected}, nil + ch: ch, + vo: &ValidateChallengeOptions{ + TLSDial: tlsDial, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Type, ch.Type) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Error, nil) + return nil }, }, + srv: srv, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - - dnsCh, err := unmarshalChallenge(newval) - assert.FatalError(t, err) - assert.Equals(t, dnsCh.getStatus(), StatusValid) - assert.True(t, dnsCh.getValidated().Before(time.Now().UTC())) - assert.True(t, dnsCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) - - baseClone.Validated = dnsCh.getValidated() - - return nil, true, nil - }, - }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { + + if tc.srv != nil { + defer tc.srv.Close() + } + + if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res.getID(), ch.getID()) - assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.res.getStatus(), ch.getStatus()) - assert.Equals(t, tc.res.getToken(), ch.getToken()) - assert.Equals(t, tc.res.getCreated(), ch.getCreated()) - assert.Equals(t, tc.res.getValidated(), ch.getValidated()) - assert.Equals(t, tc.res.getError(), ch.getError()) - } + assert.Nil(t, tc.err) } }) } diff --git a/acme/common.go b/acme/common.go index fec47b94..26552c61 100644 --- a/acme/common.go +++ b/acme/common.go @@ -3,19 +3,32 @@ package acme import ( "context" "crypto/x509" - "net/url" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/randutil" ) +// CertificateAuthority is the interface implemented by a CA authority. +type CertificateAuthority interface { + Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + LoadProvisionerByID(string) (provisioner.Interface, error) +} + +// Clock that returns time in UTC rounded to seconds. +type Clock struct{} + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Truncate(time.Second) +} + +var clock Clock + // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. type Provisioner interface { AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) + GetID() string GetName() string DefaultTLSCertDuration() time.Duration GetOptions() *provisioner.Options @@ -25,6 +38,7 @@ type Provisioner interface { type MockProvisioner struct { Mret1 interface{} Merr error + MgetID func() string MgetName func() string MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) MdefaultTLSCertDuration func() time.Duration @@ -55,6 +69,7 @@ func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration { return m.Mret1.(time.Duration) } +// GetOptions mock func (m *MockProvisioner) GetOptions() *provisioner.Options { if m.MgetOptions != nil { return m.MgetOptions() @@ -62,120 +77,10 @@ func (m *MockProvisioner) GetOptions() *provisioner.Options { return m.Mret1.(*provisioner.Options) } -// ContextKey is the key type for storing and searching for ACME request -// essentials in the context of a request. -type ContextKey string - -const ( - // AccContextKey account key - AccContextKey = ContextKey("acc") - // BaseURLContextKey baseURL key - BaseURLContextKey = ContextKey("baseURL") - // JwsContextKey jws key - JwsContextKey = ContextKey("jws") - // JwkContextKey jwk key - JwkContextKey = ContextKey("jwk") - // PayloadContextKey payload key - PayloadContextKey = ContextKey("payload") - // ProvisionerContextKey provisioner key - ProvisionerContextKey = ContextKey("provisioner") -) - -// AccountFromContext searches the context for an ACME account. Returns the -// account or an error. -func AccountFromContext(ctx context.Context) (*Account, error) { - val, ok := ctx.Value(AccContextKey).(*Account) - if !ok || val == nil { - return nil, AccountDoesNotExistErr(nil) +// GetID mock +func (m *MockProvisioner) GetID() string { + if m.MgetID != nil { + return m.MgetID() } - return val, nil + return m.Mret1.(string) } - -// BaseURLFromContext returns the baseURL if one is stored in the context. -func BaseURLFromContext(ctx context.Context) *url.URL { - val, ok := ctx.Value(BaseURLContextKey).(*url.URL) - if !ok || val == nil { - return nil - } - return val -} - -// JwkFromContext searches the context for a JWK. Returns the JWK or an error. -func JwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { - val, ok := ctx.Value(JwkContextKey).(*jose.JSONWebKey) - if !ok || val == nil { - return nil, ServerInternalErr(errors.Errorf("jwk expected in request context")) - } - return val, nil -} - -// JwsFromContext searches the context for a JWS. Returns the JWS or an error. -func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { - val, ok := ctx.Value(JwsContextKey).(*jose.JSONWebSignature) - if !ok || val == nil { - return nil, ServerInternalErr(errors.Errorf("jws expected in request context")) - } - return val, nil -} - -// ProvisionerFromContext searches the context for a provisioner. Returns the -// provisioner or an error. -func ProvisionerFromContext(ctx context.Context) (Provisioner, error) { - val := ctx.Value(ProvisionerContextKey) - if val == nil { - return nil, ServerInternalErr(errors.Errorf("provisioner expected in request context")) - } - pval, ok := val.(Provisioner) - if !ok || pval == nil { - return nil, ServerInternalErr(errors.Errorf("provisioner in context is not an ACME provisioner")) - } - return pval, nil -} - -// SignAuthority is the interface implemented by a CA authority. -type SignAuthority interface { - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - LoadProvisionerByID(string) (provisioner.Interface, error) -} - -// Identifier encodes the type that an order pertains to. -type Identifier struct { - Type string `json:"type"` - Value string `json:"value"` -} - -var ( - // StatusValid -- valid - StatusValid = "valid" - // StatusInvalid -- invalid - StatusInvalid = "invalid" - // StatusPending -- pending; e.g. an Order that is not ready to be finalized. - StatusPending = "pending" - // StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid. - StatusDeactivated = "deactivated" - // StatusReady -- ready; e.g. for an Order that is ready to be finalized. - StatusReady = "ready" - //statusExpired = "expired" - //statusActive = "active" - //statusProcessing = "processing" -) - -var idLen = 32 - -func randID() (val string, err error) { - val, err = randutil.Alphanumeric(idLen) - if err != nil { - return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID")) - } - return val, nil -} - -// Clock that returns time in UTC rounded to seconds. -type Clock int - -// Now returns the UTC time rounded to seconds. -func (c *Clock) Now() time.Time { - return time.Now().UTC().Round(time.Second) -} - -var clock = new(Clock) diff --git a/acme/db.go b/acme/db.go new file mode 100644 index 00000000..d678fef4 --- /dev/null +++ b/acme/db.go @@ -0,0 +1,251 @@ +package acme + +import ( + "context" + + "github.com/pkg/errors" +) + +// ErrNotFound is an error that should be used by the acme.DB interface to +// indicate that an entity does not exist. For example, in the new-account +// endpoint, if GetAccountByKeyID returns ErrNotFound we will create the new +// account. +var ErrNotFound = errors.New("not found") + +// DB is the DB interface expected by the step-ca ACME API. +type DB interface { + CreateAccount(ctx context.Context, acc *Account) error + GetAccount(ctx context.Context, id string) (*Account, error) + GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) + UpdateAccount(ctx context.Context, acc *Account) error + + CreateNonce(ctx context.Context) (Nonce, error) + DeleteNonce(ctx context.Context, nonce Nonce) error + + CreateAuthorization(ctx context.Context, az *Authorization) error + GetAuthorization(ctx context.Context, id string) (*Authorization, error) + UpdateAuthorization(ctx context.Context, az *Authorization) error + + CreateCertificate(ctx context.Context, cert *Certificate) error + GetCertificate(ctx context.Context, id string) (*Certificate, error) + + CreateChallenge(ctx context.Context, ch *Challenge) error + GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error) + UpdateChallenge(ctx context.Context, ch *Challenge) error + + CreateOrder(ctx context.Context, o *Order) error + GetOrder(ctx context.Context, id string) (*Order, error) + GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) + UpdateOrder(ctx context.Context, o *Order) error +} + +// MockDB is an implementation of the DB interface that should only be used as +// a mock in tests. +type MockDB struct { + MockCreateAccount func(ctx context.Context, acc *Account) error + MockGetAccount func(ctx context.Context, id string) (*Account, error) + MockGetAccountByKeyID func(ctx context.Context, kid string) (*Account, error) + MockUpdateAccount func(ctx context.Context, acc *Account) error + + MockCreateNonce func(ctx context.Context) (Nonce, error) + MockDeleteNonce func(ctx context.Context, nonce Nonce) error + + MockCreateAuthorization func(ctx context.Context, az *Authorization) error + MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error) + MockUpdateAuthorization func(ctx context.Context, az *Authorization) error + + MockCreateCertificate func(ctx context.Context, cert *Certificate) error + MockGetCertificate func(ctx context.Context, id string) (*Certificate, error) + + MockCreateChallenge func(ctx context.Context, ch *Challenge) error + MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error) + MockUpdateChallenge func(ctx context.Context, ch *Challenge) error + + MockCreateOrder func(ctx context.Context, o *Order) error + MockGetOrder func(ctx context.Context, id string) (*Order, error) + MockGetOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error) + MockUpdateOrder func(ctx context.Context, o *Order) error + + MockRet1 interface{} + MockError error +} + +// CreateAccount mock. +func (m *MockDB) CreateAccount(ctx context.Context, acc *Account) error { + if m.MockCreateAccount != nil { + return m.MockCreateAccount(ctx, acc) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetAccount mock. +func (m *MockDB) GetAccount(ctx context.Context, id string) (*Account, error) { + if m.MockGetAccount != nil { + return m.MockGetAccount(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Account), m.MockError +} + +// GetAccountByKeyID mock +func (m *MockDB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) { + if m.MockGetAccountByKeyID != nil { + return m.MockGetAccountByKeyID(ctx, kid) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Account), m.MockError +} + +// UpdateAccount mock +func (m *MockDB) UpdateAccount(ctx context.Context, acc *Account) error { + if m.MockUpdateAccount != nil { + return m.MockUpdateAccount(ctx, acc) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateNonce mock +func (m *MockDB) CreateNonce(ctx context.Context) (Nonce, error) { + if m.MockCreateNonce != nil { + return m.MockCreateNonce(ctx) + } else if m.MockError != nil { + return Nonce(""), m.MockError + } + return m.MockRet1.(Nonce), m.MockError +} + +// DeleteNonce mock +func (m *MockDB) DeleteNonce(ctx context.Context, nonce Nonce) error { + if m.MockDeleteNonce != nil { + return m.MockDeleteNonce(ctx, nonce) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateAuthorization mock +func (m *MockDB) CreateAuthorization(ctx context.Context, az *Authorization) error { + if m.MockCreateAuthorization != nil { + return m.MockCreateAuthorization(ctx, az) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetAuthorization mock +func (m *MockDB) GetAuthorization(ctx context.Context, id string) (*Authorization, error) { + if m.MockGetAuthorization != nil { + return m.MockGetAuthorization(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Authorization), m.MockError +} + +// UpdateAuthorization mock +func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) error { + if m.MockUpdateAuthorization != nil { + return m.MockUpdateAuthorization(ctx, az) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateCertificate mock +func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error { + if m.MockCreateCertificate != nil { + return m.MockCreateCertificate(ctx, cert) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetCertificate mock +func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, error) { + if m.MockGetCertificate != nil { + return m.MockGetCertificate(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Certificate), m.MockError +} + +// CreateChallenge mock +func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error { + if m.MockCreateChallenge != nil { + return m.MockCreateChallenge(ctx, ch) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetChallenge mock +func (m *MockDB) GetChallenge(ctx context.Context, chID, azID string) (*Challenge, error) { + if m.MockGetChallenge != nil { + return m.MockGetChallenge(ctx, chID, azID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Challenge), m.MockError +} + +// UpdateChallenge mock +func (m *MockDB) UpdateChallenge(ctx context.Context, ch *Challenge) error { + if m.MockUpdateChallenge != nil { + return m.MockUpdateChallenge(ctx, ch) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateOrder mock +func (m *MockDB) CreateOrder(ctx context.Context, o *Order) error { + if m.MockCreateOrder != nil { + return m.MockCreateOrder(ctx, o) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetOrder mock +func (m *MockDB) GetOrder(ctx context.Context, id string) (*Order, error) { + if m.MockGetOrder != nil { + return m.MockGetOrder(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Order), m.MockError +} + +// UpdateOrder mock +func (m *MockDB) UpdateOrder(ctx context.Context, o *Order) error { + if m.MockUpdateOrder != nil { + return m.MockUpdateOrder(ctx, o) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetOrdersByAccountID mock +func (m *MockDB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) { + if m.MockGetOrdersByAccountID != nil { + return m.MockGetOrdersByAccountID(ctx, accID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.([]string), m.MockError +} diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go new file mode 100644 index 00000000..1c3bec5d --- /dev/null +++ b/acme/db/nosql/account.go @@ -0,0 +1,136 @@ +package nosql + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + nosqlDB "github.com/smallstep/nosql" + "go.step.sm/crypto/jose" +) + +// dbAccount represents an ACME account. +type dbAccount struct { + ID string `json:"id"` + Key *jose.JSONWebKey `json:"key"` + Contact []string `json:"contact,omitempty"` + Status acme.Status `json:"status"` + CreatedAt time.Time `json:"createdAt"` + DeactivatedAt time.Time `json:"deactivatedAt"` +} + +func (dba *dbAccount) clone() *dbAccount { + nu := *dba + return &nu +} + +func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) { + id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) + if err != nil { + if nosqlDB.IsErrNotFound(err) { + return "", acme.ErrNotFound + } + return "", errors.Wrapf(err, "error loading key-account index for key %s", kid) + } + return string(id), nil +} + +// getDBAccount retrieves and unmarshals dbAccount. +func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) { + data, err := db.db.Get(accountTable, []byte(id)) + if err != nil { + if nosqlDB.IsErrNotFound(err) { + return nil, acme.ErrNotFound + } + return nil, errors.Wrapf(err, "error loading account %s", id) + } + + dbacc := new(dbAccount) + if err = json.Unmarshal(data, dbacc); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling account %s into dbAccount", id) + } + return dbacc, nil +} + +// GetAccount retrieves an ACME account by ID. +func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) { + dbacc, err := db.getDBAccount(ctx, id) + if err != nil { + return nil, err + } + + return &acme.Account{ + Status: dbacc.Status, + Contact: dbacc.Contact, + Key: dbacc.Key, + ID: dbacc.ID, + }, nil +} + +// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK). +func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) { + id, err := db.getAccountIDByKeyID(ctx, kid) + if err != nil { + return nil, err + } + return db.GetAccount(ctx, id) +} + +// CreateAccount imlements the AcmeDB.CreateAccount interface. +func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error { + var err error + acc.ID, err = randID() + if err != nil { + return err + } + + dba := &dbAccount{ + ID: acc.ID, + Key: acc.Key, + Contact: acc.Contact, + Status: acc.Status, + CreatedAt: clock.Now(), + } + + kid, err := acme.KeyToID(dba.Key) + if err != nil { + return err + } + kidB := []byte(kid) + + // Set the jwkID -> acme account ID index + _, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(acc.ID)) + switch { + case err != nil: + return errors.Wrap(err, "error storing keyID to accountID index") + case !swapped: + return errors.Errorf("key-id to account-id index already exists") + default: + if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil { + db.db.Del(accountByKeyIDTable, kidB) + return err + } + return nil + } +} + +// UpdateAccount imlements the AcmeDB.UpdateAccount interface. +func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error { + old, err := db.getDBAccount(ctx, acc.ID) + if err != nil { + return err + } + + nu := old.clone() + nu.Contact = acc.Contact + nu.Status = acc.Status + + // If the status has changed to 'deactivated', then set deactivatedAt timestamp. + if acc.Status == acme.StatusDeactivated && old.Status != acme.StatusDeactivated { + nu.DeactivatedAt = clock.Now() + } + + return db.save(ctx, old.ID, nu, old, "account", accountTable) +} diff --git a/acme/db/nosql/account_test.go b/acme/db/nosql/account_test.go new file mode 100644 index 00000000..5ba99a73 --- /dev/null +++ b/acme/db/nosql/account_test.go @@ -0,0 +1,706 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" + "go.step.sm/crypto/jose" +) + +func TestDB_getDBAccount(t *testing.T) { + accID := "accID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbacc *dbAccount + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, nosqldb.ErrNotFound + }, + }, + err: acme.ErrNotFound, + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling account accID into dbAccount"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return b, nil + }, + }, + dbacc: dbacc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, dbacc.ID, tc.dbacc.ID) + assert.Equals(t, dbacc.Status, tc.dbacc.Status) + assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt) + assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt) + assert.Equals(t, dbacc.Contact, tc.dbacc.Contact) + assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID) + } + } + }) + } +} + +func TestDB_getAccountIDByKeyID(t *testing.T) { + accID := "accID" + kid := "kid" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), kid) + + return nil, nosqldb.ErrNotFound + }, + }, + err: acme.ErrNotFound, + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), kid) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading key-account index for key kid: force"), + } + }, + "ok": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), kid) + + return []byte(accID), nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, retAccID, accID) + } + } + }) + } +} + +func TestDB_GetAccount(t *testing.T) { + accID := "accID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbacc *dbAccount + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + return b, nil + }, + }, + dbacc: dbacc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if acc, err := db.GetAccount(context.Background(), accID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acc.ID, tc.dbacc.ID) + assert.Equals(t, acc.Status, tc.dbacc.Status) + assert.Equals(t, acc.Contact, tc.dbacc.Contact) + assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) + } + } + }) + } +} + +func TestDB_GetAccountByKeyID(t *testing.T) { + accID := "accID" + kid := "kid" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbacc *dbAccount + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.getAccountIDByKeyID-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(accountByKeyIDTable)) + assert.Equals(t, string(key), kid) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading key-account index for key kid: force"), + } + }, + "fail/db.GetAccount-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), kid) + return []byte(accID), nil + case string(accountTable): + assert.Equals(t, string(key), accID) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), kid) + return []byte(accID), nil + case string(accountTable): + assert.Equals(t, string(key), accID) + return b, nil + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + dbacc: dbacc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acc.ID, tc.dbacc.ID) + assert.Equals(t, acc.Status, tc.dbacc.Status) + assert.Equals(t, acc.Contact, tc.dbacc.Contact) + assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) + } + } + }) + } +} + +func TestDB_CreateAccount(t *testing.T) { + type test struct { + db nosql.DB + acc *acme.Account + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/keyID-cmpAndSwap-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + + assert.Equals(t, nu, []byte(acc.ID)) + return nil, false, errors.New("force") + }, + }, + acc: acc, + err: errors.New("error storing keyID to accountID index: force"), + } + }, + "fail/keyID-cmpAndSwap-false": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + + assert.Equals(t, nu, []byte(acc.ID)) + return nil, false, nil + }, + }, + acc: acc, + err: errors.New("key-id to account-id index already exists"), + } + }, + "fail/account-save-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + return nu, true, nil + case string(accountTable): + assert.Equals(t, string(key), acc.ID) + assert.Equals(t, old, nil) + + dbacc := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbacc)) + assert.Equals(t, dbacc.ID, string(key)) + assert.Equals(t, dbacc.Contact, acc.Contact) + assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) + assert.True(t, dbacc.DeactivatedAt.IsZero()) + return nil, false, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + acc: acc, + err: errors.New("error saving acme account: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idPtr = &id + ) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + id = string(key) + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + return nu, true, nil + case string(accountTable): + assert.Equals(t, string(key), acc.ID) + assert.Equals(t, old, nil) + + dbacc := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbacc)) + assert.Equals(t, dbacc.ID, string(key)) + assert.Equals(t, dbacc.Contact, acc.Contact) + assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) + assert.True(t, dbacc.DeactivatedAt.IsZero()) + return nu, true, nil + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + acc: acc, + _id: idPtr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateAccount(context.Background(), tc.acc); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.acc.ID, *tc._id) + } + } + }) + } +} + +func TestDB_UpdateAccount(t *testing.T) { + accID := "accID" + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + type test struct { + db nosql.DB + acc *acme.Account + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + acc: &acme.Account{ + ID: accID, + }, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "fail/already-deactivated": func(t *testing.T) test { + clone := dbacc.clone() + clone.Status = acme.StatusDeactivated + clone.DeactivatedAt = now + dbaccb, err := json.Marshal(clone) + assert.FatalError(t, err) + acc := &acme.Account{ + ID: accID, + Status: acme.StatusDeactivated, + Contact: []string{"foo", "bar"}, + } + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return dbaccb, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, old, b) + + dbNew := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, clone.ID) + assert.Equals(t, dbNew.Status, clone.Status) + assert.Equals(t, dbNew.Contact, clone.Contact) + assert.Equals(t, dbNew.Key.KeyID, clone.Key.KeyID) + assert.Equals(t, dbNew.CreatedAt, clone.CreatedAt) + assert.Equals(t, dbNew.DeactivatedAt, clone.DeactivatedAt) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme account: force"), + } + }, + "fail/db.CmpAndSwap-error": func(t *testing.T) test { + acc := &acme.Account{ + ID: accID, + Status: acme.StatusDeactivated, + Contact: []string{"foo", "bar"}, + } + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, old, b) + + dbNew := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbacc.ID) + assert.Equals(t, dbNew.Status, acc.Status) + assert.Equals(t, dbNew.Contact, dbacc.Contact) + assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID) + assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt) + assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now)) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme account: force"), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ + ID: accID, + Status: acme.StatusDeactivated, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, old, b) + + dbNew := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbacc.ID) + assert.Equals(t, dbNew.Status, acc.Status) + assert.Equals(t, dbNew.Contact, dbacc.Contact) + assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID) + assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt) + assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now)) + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.UpdateAccount(context.Background(), tc.acc); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.acc.ID, dbacc.ID) + assert.Equals(t, tc.acc.Status, dbacc.Status) + assert.Equals(t, tc.acc.Contact, dbacc.Contact) + assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID) + } + } + }) + } +} diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go new file mode 100644 index 00000000..6decbe4f --- /dev/null +++ b/acme/db/nosql/authz.go @@ -0,0 +1,118 @@ +package nosql + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/nosql" +) + +// dbAuthz is the base authz type that others build from. +type dbAuthz struct { + ID string `json:"id"` + AccountID string `json:"accountID"` + Identifier acme.Identifier `json:"identifier"` + Status acme.Status `json:"status"` + Token string `json:"token"` + ChallengeIDs []string `json:"challengeIDs"` + Wildcard bool `json:"wildcard"` + CreatedAt time.Time `json:"createdAt"` + ExpiresAt time.Time `json:"expiresAt"` + Error *acme.Error `json:"error"` +} + +func (ba *dbAuthz) clone() *dbAuthz { + u := *ba + return &u +} + +// getDBAuthz retrieves and unmarshals a database representation of the +// ACME Authorization type. +func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) { + data, err := db.db.Get(authzTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, acme.NewError(acme.ErrorMalformedType, "authz %s not found", id) + } else if err != nil { + return nil, errors.Wrapf(err, "error loading authz %s", id) + } + + var dbaz dbAuthz + if err = json.Unmarshal(data, &dbaz); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling authz %s into dbAuthz", id) + } + return &dbaz, nil +} + +// GetAuthorization retrieves and unmarshals an ACME authz type from the database. +// Implements acme.DB GetAuthorization interface. +func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorization, error) { + dbaz, err := db.getDBAuthz(ctx, id) + if err != nil { + return nil, err + } + var chs = make([]*acme.Challenge, len(dbaz.ChallengeIDs)) + for i, chID := range dbaz.ChallengeIDs { + chs[i], err = db.GetChallenge(ctx, chID, id) + if err != nil { + return nil, err + } + } + return &acme.Authorization{ + ID: dbaz.ID, + AccountID: dbaz.AccountID, + Identifier: dbaz.Identifier, + Status: dbaz.Status, + Challenges: chs, + Wildcard: dbaz.Wildcard, + ExpiresAt: dbaz.ExpiresAt, + Token: dbaz.Token, + Error: dbaz.Error, + }, nil +} + +// CreateAuthorization creates an entry in the database for the Authorization. +// Implements the acme.DB.CreateAuthorization interface. +func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) error { + var err error + az.ID, err = randID() + if err != nil { + return err + } + + chIDs := make([]string, len(az.Challenges)) + for i, ch := range az.Challenges { + chIDs[i] = ch.ID + } + + now := clock.Now() + dbaz := &dbAuthz{ + ID: az.ID, + AccountID: az.AccountID, + Status: az.Status, + CreatedAt: now, + ExpiresAt: az.ExpiresAt, + Identifier: az.Identifier, + ChallengeIDs: chIDs, + Token: az.Token, + Wildcard: az.Wildcard, + } + + return db.save(ctx, az.ID, dbaz, nil, "authz", authzTable) +} + +// UpdateAuthorization saves an updated ACME Authorization to the database. +func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) error { + old, err := db.getDBAuthz(ctx, az.ID) + if err != nil { + return err + } + + nu := old.clone() + + nu.Status = az.Status + nu.Error = az.Error + return db.save(ctx, old.ID, nu, old, "authz", authzTable) +} diff --git a/acme/db/nosql/authz_test.go b/acme/db/nosql/authz_test.go new file mode 100644 index 00000000..0c2cec50 --- /dev/null +++ b/acme/db/nosql/authz_test.go @@ -0,0 +1,620 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" +) + +func TestDB_getDBAuthz(t *testing.T) { + azID := "azID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbaz *dbAuthz + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authz azID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling authz azID into dbAuthz"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return b, nil + }, + }, + dbaz: dbaz, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, dbaz.ID, tc.dbaz.ID) + assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID) + assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier) + assert.Equals(t, dbaz.Status, tc.dbaz.Status) + assert.Equals(t, dbaz.Token, tc.dbaz.Token) + assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt) + assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt) + assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error()) + assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard) + } + } + }) + } +} + +func TestDB_GetAuthorization(t *testing.T) { + azID := "azID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbaz *dbAuthz + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authz azID: force"), + } + }, + "fail/forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"), + } + }, + "fail/db.GetChallenge-error": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(authzTable): + assert.Equals(t, string(key), azID) + return b, nil + case string(challengeTable): + assert.Equals(t, string(key), "foo") + return nil, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket))) + return nil, errors.New("force") + } + }, + }, + err: errors.New("error loading acme challenge foo: force"), + } + }, + "fail/db.GetChallenge-not-found": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(authzTable): + assert.Equals(t, string(key), azID) + return b, nil + case string(challengeTable): + assert.Equals(t, string(key), "foo") + return nil, nosqldb.ErrNotFound + default: + assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket))) + return nil, errors.New("force") + } + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge foo not found"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + chCount := 0 + fooChb, err := json.Marshal(&dbChallenge{ID: "foo"}) + assert.FatalError(t, err) + barChb, err := json.Marshal(&dbChallenge{ID: "bar"}) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(authzTable): + assert.Equals(t, string(key), azID) + return b, nil + case string(challengeTable): + if chCount == 0 { + chCount++ + assert.Equals(t, string(key), "foo") + return fooChb, nil + } + assert.Equals(t, string(key), "bar") + return barChb, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket))) + return nil, errors.New("force") + } + }, + }, + dbaz: dbaz, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if az, err := db.GetAuthorization(context.Background(), azID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, az.ID, tc.dbaz.ID) + assert.Equals(t, az.AccountID, tc.dbaz.AccountID) + assert.Equals(t, az.Identifier, tc.dbaz.Identifier) + assert.Equals(t, az.Status, tc.dbaz.Status) + assert.Equals(t, az.Token, tc.dbaz.Token) + assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard) + assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt) + assert.Equals(t, az.Challenges, []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }) + assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error()) + } + } + }) + } +} + +func TestDB_CreateAuthorization(t *testing.T) { + azID := "azID" + type test struct { + db nosql.DB + az *acme.Authorization + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/cmpAndSwap-error": func(t *testing.T) test { + now := clock.Now() + az := &acme.Authorization{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }, + Wildcard: true, + Error: acme.NewErrorISE("force"), + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), az.ID) + assert.Equals(t, old, nil) + + dbaz := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(nu, dbaz)) + assert.Equals(t, dbaz.ID, string(key)) + assert.Equals(t, dbaz.AccountID, az.AccountID) + assert.Equals(t, dbaz.Identifier, acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }) + assert.Equals(t, dbaz.Status, az.Status) + assert.Equals(t, dbaz.Token, az.Token) + assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"}) + assert.Equals(t, dbaz.Wildcard, az.Wildcard) + assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt) + assert.Nil(t, dbaz.Error) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt)) + return nil, false, errors.New("force") + }, + }, + az: az, + err: errors.New("error saving acme authz: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idPtr = &id + now = clock.Now() + az = &acme.Authorization{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }, + Wildcard: true, + Error: acme.NewErrorISE("force"), + } + ) + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + *idPtr = string(key) + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), az.ID) + assert.Equals(t, old, nil) + + dbaz := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(nu, dbaz)) + assert.Equals(t, dbaz.ID, string(key)) + assert.Equals(t, dbaz.AccountID, az.AccountID) + assert.Equals(t, dbaz.Identifier, acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }) + assert.Equals(t, dbaz.Status, az.Status) + assert.Equals(t, dbaz.Token, az.Token) + assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"}) + assert.Equals(t, dbaz.Wildcard, az.Wildcard) + assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt) + assert.Nil(t, dbaz.Error) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt)) + return nu, true, nil + }, + }, + az: az, + _id: idPtr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateAuthorization(context.Background(), tc.az); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.az.ID, *tc._id) + } + } + }) + } +} + +func TestDB_UpdateAuthorization(t *testing.T) { + azID := "azID" + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "accountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + type test struct { + db nosql.DB + az *acme.Authorization + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + az: &acme.Authorization{ + ID: azID, + }, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authz azID: force"), + } + }, + "fail/db.CmpAndSwap-error": func(t *testing.T) test { + updAz := &acme.Authorization{ + ID: azID, + Status: acme.StatusValid, + Error: acme.NewError(acme.ErrorMalformedType, "malformed"), + } + return test{ + az: updAz, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, old, b) + + dbOld := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(old, dbOld)) + assert.Equals(t, dbaz, dbOld) + + dbNew := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbaz.ID) + assert.Equals(t, dbNew.AccountID, dbaz.AccountID) + assert.Equals(t, dbNew.Identifier, dbaz.Identifier) + assert.Equals(t, dbNew.Status, acme.StatusValid) + assert.Equals(t, dbNew.Token, dbaz.Token) + assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs) + assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) + assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) + assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) + assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme authz: force"), + } + }, + "ok": func(t *testing.T) test { + updAz := &acme.Authorization{ + ID: azID, + AccountID: dbaz.AccountID, + Status: acme.StatusValid, + Identifier: dbaz.Identifier, + Challenges: []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }, + Token: dbaz.Token, + Wildcard: dbaz.Wildcard, + ExpiresAt: dbaz.ExpiresAt, + Error: acme.NewError(acme.ErrorMalformedType, "malformed"), + } + return test{ + az: updAz, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, string(key), azID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authzTable) + assert.Equals(t, old, b) + + dbOld := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(old, dbOld)) + assert.Equals(t, dbaz, dbOld) + + dbNew := new(dbAuthz) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbaz.ID) + assert.Equals(t, dbNew.AccountID, dbaz.AccountID) + assert.Equals(t, dbNew.Identifier, dbaz.Identifier) + assert.Equals(t, dbNew.Status, acme.StatusValid) + assert.Equals(t, dbNew.Token, dbaz.Token) + assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs) + assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) + assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) + assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) + assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.az.ID, dbaz.ID) + assert.Equals(t, tc.az.AccountID, dbaz.AccountID) + assert.Equals(t, tc.az.Identifier, dbaz.Identifier) + assert.Equals(t, tc.az.Status, acme.StatusValid) + assert.Equals(t, tc.az.Wildcard, dbaz.Wildcard) + assert.Equals(t, tc.az.Token, dbaz.Token) + assert.Equals(t, tc.az.ExpiresAt, dbaz.ExpiresAt) + assert.Equals(t, tc.az.Challenges, []*acme.Challenge{ + {ID: "foo"}, + {ID: "bar"}, + }) + assert.Equals(t, tc.az.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) + } + } + }) + } +} diff --git a/acme/db/nosql/certificate.go b/acme/db/nosql/certificate.go new file mode 100644 index 00000000..d3e15833 --- /dev/null +++ b/acme/db/nosql/certificate.go @@ -0,0 +1,109 @@ +package nosql + +import ( + "context" + "crypto/x509" + "encoding/json" + "encoding/pem" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/nosql" +) + +type dbCert struct { + ID string `json:"id"` + CreatedAt time.Time `json:"createdAt"` + AccountID string `json:"accountID"` + OrderID string `json:"orderID"` + Leaf []byte `json:"leaf"` + Intermediates []byte `json:"intermediates"` +} + +// CreateCertificate creates and stores an ACME certificate type. +func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error { + var err error + cert.ID, err = randID() + if err != nil { + return err + } + + leaf := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Leaf.Raw, + }) + var intermediates []byte + for _, cert := range cert.Intermediates { + intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + })...) + } + + dbch := &dbCert{ + ID: cert.ID, + AccountID: cert.AccountID, + OrderID: cert.OrderID, + Leaf: leaf, + Intermediates: intermediates, + CreatedAt: time.Now().UTC(), + } + return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable) +} + +// GetCertificate retrieves and unmarshals an ACME certificate type from the +// datastore. +func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) { + b, err := db.db.Get(certTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id) + } else if err != nil { + return nil, errors.Wrapf(err, "error loading certificate %s", id) + } + dbC := new(dbCert) + if err := json.Unmarshal(b, dbC); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling certificate %s", id) + } + + certs, err := parseBundle(append(dbC.Leaf, dbC.Intermediates...)) + if err != nil { + return nil, errors.Wrapf(err, "error parsing certificate chain for ACME certificate with ID %s", id) + } + + return &acme.Certificate{ + ID: dbC.ID, + AccountID: dbC.AccountID, + OrderID: dbC.OrderID, + Leaf: certs[0], + Intermediates: certs[1:], + }, nil +} + +func parseBundle(b []byte) ([]*x509.Certificate, error) { + var ( + err error + block *pem.Block + bundle []*x509.Certificate + ) + for len(b) > 0 { + block, b = pem.Decode(b) + if block == nil { + break + } + if block.Type != "CERTIFICATE" { + return nil, errors.New("error decoding PEM: data contains block that is not a certificate") + } + var crt *x509.Certificate + crt, err = x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, errors.Wrapf(err, "error parsing x509 certificate") + } + bundle = append(bundle, crt) + } + if len(b) > 0 { + return nil, errors.New("error decoding PEM: unexpected data") + } + return bundle, nil + +} diff --git a/acme/db/nosql/certificate_test.go b/acme/db/nosql/certificate_test.go new file mode 100644 index 00000000..4ec4589e --- /dev/null +++ b/acme/db/nosql/certificate_test.go @@ -0,0 +1,321 @@ +package nosql + +import ( + "context" + "crypto/x509" + "encoding/json" + "encoding/pem" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" + + "go.step.sm/crypto/pemutil" +) + +func TestDB_CreateCertificate(t *testing.T) { + leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") + assert.FatalError(t, err) + inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") + assert.FatalError(t, err) + type test struct { + db nosql.DB + cert *acme.Certificate + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/cmpAndSwap-error": func(t *testing.T) test { + cert := &acme.Certificate{ + AccountID: "accountID", + OrderID: "orderID", + Leaf: leaf, + Intermediates: []*x509.Certificate{inter, root}, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + assert.Equals(t, old, nil) + + dbc := new(dbCert) + assert.FatalError(t, json.Unmarshal(nu, dbc)) + assert.Equals(t, dbc.ID, string(key)) + assert.Equals(t, dbc.ID, cert.ID) + assert.Equals(t, dbc.AccountID, cert.AccountID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) + return nil, false, errors.New("force") + }, + }, + cert: cert, + err: errors.New("error saving acme certificate: force"), + } + }, + "ok": func(t *testing.T) test { + cert := &acme.Certificate{ + AccountID: "accountID", + OrderID: "orderID", + Leaf: leaf, + Intermediates: []*x509.Certificate{inter, root}, + } + var ( + id string + idPtr = &id + ) + + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + *idPtr = string(key) + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + assert.Equals(t, old, nil) + + dbc := new(dbCert) + assert.FatalError(t, json.Unmarshal(nu, dbc)) + assert.Equals(t, dbc.ID, string(key)) + assert.Equals(t, dbc.ID, cert.ID) + assert.Equals(t, dbc.AccountID, cert.AccountID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) + return nil, true, nil + }, + }, + _id: idPtr, + cert: cert, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateCertificate(context.Background(), tc.cert); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.cert.ID, *tc._id) + } + } + }) + } +} + +func TestDB_GetCertificate(t *testing.T) { + leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") + assert.FatalError(t, err) + inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") + assert.FatalError(t, err) + + certID := "certID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, string(key), certID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate certID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, string(key), certID) + + return nil, errors.Errorf("force") + }, + }, + err: errors.New("error loading certificate certID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, string(key), certID) + + return []byte("foobar"), nil + }, + }, + err: errors.New("error unmarshaling certificate certID"), + } + }, + "fail/parseBundle-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, string(key), certID) + + cert := dbCert{ + ID: certID, + AccountID: "accountID", + OrderID: "orderID", + Leaf: pem.EncodeToMemory(&pem.Block{ + Type: "Public Key", + Bytes: leaf.Raw, + }), + CreatedAt: clock.Now(), + } + b, err := json.Marshal(cert) + assert.FatalError(t, err) + + return b, nil + }, + }, + err: errors.Errorf("error parsing certificate chain for ACME certificate with ID certID"), + } + }, + "ok": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, certTable) + assert.Equals(t, string(key), certID) + + cert := dbCert{ + ID: certID, + AccountID: "accountID", + OrderID: "orderID", + Leaf: pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: leaf.Raw, + }), + Intermediates: append(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: inter.Raw, + }), pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: root.Raw, + })...), + CreatedAt: clock.Now(), + } + b, err := json.Marshal(cert) + assert.FatalError(t, err) + + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + cert, err := db.GetCertificate(context.Background(), certID) + if err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, cert.ID, certID) + assert.Equals(t, cert.AccountID, "accountID") + assert.Equals(t, cert.OrderID, "orderID") + assert.Equals(t, cert.Leaf, leaf) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) + } + } + }) + } +} + +func Test_parseBundle(t *testing.T) { + leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") + assert.FatalError(t, err) + inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") + assert.FatalError(t, err) + + var certs []byte + for _, cert := range []*x509.Certificate{leaf, inter, root} { + certs = append(certs, pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + })...) + } + + type test struct { + b []byte + err error + } + var tests = map[string]test{ + "fail/bad-type-error": { + b: pem.EncodeToMemory(&pem.Block{ + Type: "Public Key", + Bytes: leaf.Raw, + }), + err: errors.Errorf("error decoding PEM: data contains block that is not a certificate"), + }, + "fail/bad-pem-error": { + b: pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: []byte("foo"), + }), + err: errors.Errorf("error parsing x509 certificate"), + }, + "fail/unexpected-data": { + b: append(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: leaf.Raw, + }), []byte("foo")...), + err: errors.Errorf("error decoding PEM: unexpected data"), + }, + "ok": { + b: certs, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ret, err := parseBundle(tc.b) + if err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ret, []*x509.Certificate{leaf, inter, root}) + } + } + }) + } +} diff --git a/acme/db/nosql/challenge.go b/acme/db/nosql/challenge.go new file mode 100644 index 00000000..f3a3cfca --- /dev/null +++ b/acme/db/nosql/challenge.go @@ -0,0 +1,103 @@ +package nosql + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/nosql" +) + +type dbChallenge struct { + ID string `json:"id"` + AccountID string `json:"accountID"` + Type string `json:"type"` + Status acme.Status `json:"status"` + Token string `json:"token"` + Value string `json:"value"` + ValidatedAt string `json:"validatedAt"` + CreatedAt time.Time `json:"createdAt"` + Error *acme.Error `json:"error"` +} + +func (dbc *dbChallenge) clone() *dbChallenge { + u := *dbc + return &u +} + +func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) { + data, err := db.db.Get(challengeTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id) + } else if err != nil { + return nil, errors.Wrapf(err, "error loading acme challenge %s", id) + } + + dbch := new(dbChallenge) + if err := json.Unmarshal(data, dbch); err != nil { + return nil, errors.Wrap(err, "error unmarshaling dbChallenge") + } + return dbch, nil +} + +// CreateChallenge creates a new ACME challenge data structure in the database. +// Implements acme.DB.CreateChallenge interface. +func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error { + var err error + ch.ID, err = randID() + if err != nil { + return errors.Wrap(err, "error generating random id for ACME challenge") + } + + dbch := &dbChallenge{ + ID: ch.ID, + AccountID: ch.AccountID, + Value: ch.Value, + Status: acme.StatusPending, + Token: ch.Token, + CreatedAt: clock.Now(), + Type: ch.Type, + } + + return db.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable) +} + +// GetChallenge retrieves and unmarshals an ACME challenge type from the database. +// Implements the acme.DB GetChallenge interface. +func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) { + dbch, err := db.getDBChallenge(ctx, id) + if err != nil { + return nil, err + } + + ch := &acme.Challenge{ + ID: dbch.ID, + AccountID: dbch.AccountID, + Type: dbch.Type, + Value: dbch.Value, + Status: dbch.Status, + Token: dbch.Token, + Error: dbch.Error, + ValidatedAt: dbch.ValidatedAt, + } + return ch, nil +} + +// UpdateChallenge updates an ACME challenge type in the database. +func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error { + old, err := db.getDBChallenge(ctx, ch.ID) + if err != nil { + return err + } + + nu := old.clone() + + // These should be the only values changing in an Update request. + nu.Status = ch.Status + nu.Error = ch.Error + nu.ValidatedAt = ch.ValidatedAt + + return db.save(ctx, old.ID, nu, old, "challenge", challengeTable) +} diff --git a/acme/db/nosql/challenge_test.go b/acme/db/nosql/challenge_test.go new file mode 100644 index 00000000..b39395e8 --- /dev/null +++ b/acme/db/nosql/challenge_test.go @@ -0,0 +1,464 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" +) + +func TestDB_getDBChallenge(t *testing.T) { + chID := "chID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbc *dbChallenge + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading acme challenge chID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling dbChallenge"), + } + }, + "ok": func(t *testing.T) test { + dbc := &dbChallenge{ + ID: chID, + AccountID: "accountID", + Type: "dns-01", + Status: acme.StatusPending, + Token: "token", + Value: "test.ca.smallstep.com", + CreatedAt: clock.Now(), + ValidatedAt: "foobar", + Error: acme.NewErrorISE("force"), + } + b, err := json.Marshal(dbc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return b, nil + }, + }, + dbc: dbc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if ch, err := db.getDBChallenge(context.Background(), chID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.ID, tc.dbc.ID) + assert.Equals(t, ch.AccountID, tc.dbc.AccountID) + assert.Equals(t, ch.Type, tc.dbc.Type) + assert.Equals(t, ch.Status, tc.dbc.Status) + assert.Equals(t, ch.Token, tc.dbc.Token) + assert.Equals(t, ch.Value, tc.dbc.Value) + assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) + assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error()) + } + } + }) + } +} + +func TestDB_CreateChallenge(t *testing.T) { + type test struct { + db nosql.DB + ch *acme.Challenge + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/cmpAndSwap-error": func(t *testing.T) test { + ch := &acme.Challenge{ + AccountID: "accountID", + Type: "dns-01", + Status: acme.StatusPending, + Token: "token", + Value: "test.ca.smallstep.com", + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), ch.ID) + assert.Equals(t, old, nil) + + dbc := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(nu, dbc)) + assert.Equals(t, dbc.ID, string(key)) + assert.Equals(t, dbc.AccountID, ch.AccountID) + assert.Equals(t, dbc.Type, ch.Type) + assert.Equals(t, dbc.Status, ch.Status) + assert.Equals(t, dbc.Token, ch.Token) + assert.Equals(t, dbc.Value, ch.Value) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) + return nil, false, errors.New("force") + }, + }, + ch: ch, + err: errors.New("error saving acme challenge: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idPtr = &id + ch = &acme.Challenge{ + AccountID: "accountID", + Type: "dns-01", + Status: acme.StatusPending, + Token: "token", + Value: "test.ca.smallstep.com", + } + ) + + return test{ + ch: ch, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + *idPtr = string(key) + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), ch.ID) + assert.Equals(t, old, nil) + + dbc := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(nu, dbc)) + assert.Equals(t, dbc.ID, string(key)) + assert.Equals(t, dbc.AccountID, ch.AccountID) + assert.Equals(t, dbc.Type, ch.Type) + assert.Equals(t, dbc.Status, ch.Status) + assert.Equals(t, dbc.Token, ch.Token) + assert.Equals(t, dbc.Value, ch.Value) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) + return nil, true, nil + }, + }, + _id: idPtr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateChallenge(context.Background(), tc.ch); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.ch.ID, *tc._id) + } + } + }) + } +} + +func TestDB_GetChallenge(t *testing.T) { + chID := "chID" + azID := "azID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbc *dbChallenge + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading acme challenge chID: force"), + } + }, + "fail/forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"), + } + }, + "ok": func(t *testing.T) test { + dbc := &dbChallenge{ + ID: chID, + AccountID: "accountID", + Type: "dns-01", + Status: acme.StatusPending, + Token: "token", + Value: "test.ca.smallstep.com", + CreatedAt: clock.Now(), + ValidatedAt: "foobar", + Error: acme.NewErrorISE("force"), + } + b, err := json.Marshal(dbc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return b, nil + }, + }, + dbc: dbc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.ID, tc.dbc.ID) + assert.Equals(t, ch.AccountID, tc.dbc.AccountID) + assert.Equals(t, ch.Type, tc.dbc.Type) + assert.Equals(t, ch.Status, tc.dbc.Status) + assert.Equals(t, ch.Token, tc.dbc.Token) + assert.Equals(t, ch.Value, tc.dbc.Value) + assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) + assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error()) + } + } + }) + } +} + +func TestDB_UpdateChallenge(t *testing.T) { + chID := "chID" + dbc := &dbChallenge{ + ID: chID, + AccountID: "accountID", + Type: "dns-01", + Status: acme.StatusPending, + Token: "token", + Value: "test.ca.smallstep.com", + CreatedAt: clock.Now(), + } + b, err := json.Marshal(dbc) + assert.FatalError(t, err) + type test struct { + db nosql.DB + ch *acme.Challenge + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ch: &acme.Challenge{ + ID: chID, + }, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading acme challenge chID: force"), + } + }, + "fail/db.CmpAndSwap-error": func(t *testing.T) test { + updCh := &acme.Challenge{ + ID: chID, + Status: acme.StatusValid, + ValidatedAt: "foobar", + Error: acme.NewError(acme.ErrorMalformedType, "malformed"), + } + return test{ + ch: updCh, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, old, b) + + dbOld := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(old, dbOld)) + assert.Equals(t, dbc, dbOld) + + dbNew := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbc.ID) + assert.Equals(t, dbNew.AccountID, dbc.AccountID) + assert.Equals(t, dbNew.Type, dbc.Type) + assert.Equals(t, dbNew.Status, updCh.Status) + assert.Equals(t, dbNew.Token, dbc.Token) + assert.Equals(t, dbNew.Value, dbc.Value) + assert.Equals(t, dbNew.Error.Error(), updCh.Error.Error()) + assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt) + assert.Equals(t, dbNew.ValidatedAt, updCh.ValidatedAt) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme challenge: force"), + } + }, + "ok": func(t *testing.T) test { + updCh := &acme.Challenge{ + ID: dbc.ID, + AccountID: dbc.AccountID, + Type: dbc.Type, + Token: dbc.Token, + Value: dbc.Value, + Status: acme.StatusValid, + ValidatedAt: "foobar", + Error: acme.NewError(acme.ErrorMalformedType, "malformed"), + } + return test{ + ch: updCh, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), chID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, old, b) + + dbOld := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(old, dbOld)) + assert.Equals(t, dbc, dbOld) + + dbNew := new(dbChallenge) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbc.ID) + assert.Equals(t, dbNew.AccountID, dbc.AccountID) + assert.Equals(t, dbNew.Type, dbc.Type) + assert.Equals(t, dbNew.Token, dbc.Token) + assert.Equals(t, dbNew.Value, dbc.Value) + assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt) + assert.Equals(t, dbNew.Status, acme.StatusValid) + assert.Equals(t, dbNew.ValidatedAt, "foobar") + assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.ch.ID, dbc.ID) + assert.Equals(t, tc.ch.AccountID, dbc.AccountID) + assert.Equals(t, tc.ch.Type, dbc.Type) + assert.Equals(t, tc.ch.Token, dbc.Token) + assert.Equals(t, tc.ch.Value, dbc.Value) + assert.Equals(t, tc.ch.ValidatedAt, "foobar") + assert.Equals(t, tc.ch.Status, acme.StatusValid) + assert.Equals(t, tc.ch.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) + } + } + }) + } +} diff --git a/acme/db/nosql/nonce.go b/acme/db/nosql/nonce.go new file mode 100644 index 00000000..9badae87 --- /dev/null +++ b/acme/db/nosql/nonce.go @@ -0,0 +1,66 @@ +package nosql + +import ( + "context" + "encoding/base64" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" +) + +// dbNonce contains nonce metadata used in the ACME protocol. +type dbNonce struct { + ID string + CreatedAt time.Time + DeletedAt time.Time +} + +// CreateNonce creates, stores, and returns an ACME replay-nonce. +// Implements the acme.DB interface. +func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) { + _id, err := randID() + if err != nil { + return "", err + } + + id := base64.RawURLEncoding.EncodeToString([]byte(_id)) + n := &dbNonce{ + ID: id, + CreatedAt: clock.Now(), + } + if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil { + return "", err + } + return acme.Nonce(id), nil +} + +// DeleteNonce verifies that the nonce is valid (by checking if it exists), +// and if so, consumes the nonce resource by deleting it from the database. +func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error { + err := db.db.Update(&database.Tx{ + Operations: []*database.TxEntry{ + { + Bucket: nonceTable, + Key: []byte(nonce), + Cmd: database.Get, + }, + { + Bucket: nonceTable, + Key: []byte(nonce), + Cmd: database.Delete, + }, + }, + }) + + switch { + case nosql.IsErrNotFound(err): + return acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", string(nonce)) + case err != nil: + return errors.Wrapf(err, "error deleting nonce %s", string(nonce)) + default: + return nil + } +} diff --git a/acme/db/nosql/nonce_test.go b/acme/db/nosql/nonce_test.go new file mode 100644 index 00000000..05d73d52 --- /dev/null +++ b/acme/db/nosql/nonce_test.go @@ -0,0 +1,168 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" +) + +func TestDB_CreateNonce(t *testing.T) { + type test struct { + db nosql.DB + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/cmpAndSwap-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, old, nil) + + dbn := new(dbNonce) + assert.FatalError(t, json.Unmarshal(nu, dbn)) + assert.Equals(t, dbn.ID, string(key)) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt)) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme nonce: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idPtr = &id + ) + + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + *idPtr = string(key) + assert.Equals(t, bucket, nonceTable) + assert.Equals(t, old, nil) + + dbn := new(dbNonce) + assert.FatalError(t, json.Unmarshal(nu, dbn)) + assert.Equals(t, dbn.ID, string(key)) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt)) + return nil, true, nil + }, + }, + _id: idPtr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if n, err := db.CreateNonce(context.Background()); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, string(n), *tc._id) + } + } + }) + } +} + +func TestDB_DeleteNonce(t *testing.T) { + + nonceID := "nonceID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + assert.Equals(t, tx.Operations[0].Bucket, nonceTable) + assert.Equals(t, tx.Operations[0].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[0].Cmd, database.Get) + + assert.Equals(t, tx.Operations[1].Bucket, nonceTable) + assert.Equals(t, tx.Operations[1].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[1].Cmd, database.Delete) + return database.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", nonceID), + } + }, + "fail/db.Update-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + assert.Equals(t, tx.Operations[0].Bucket, nonceTable) + assert.Equals(t, tx.Operations[0].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[0].Cmd, database.Get) + + assert.Equals(t, tx.Operations[1].Bucket, nonceTable) + assert.Equals(t, tx.Operations[1].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[1].Cmd, database.Delete) + return errors.New("force") + }, + }, + err: errors.New("error deleting nonce nonceID: force"), + } + }, + "ok": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MUpdate: func(tx *database.Tx) error { + assert.Equals(t, tx.Operations[0].Bucket, nonceTable) + assert.Equals(t, tx.Operations[0].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[0].Cmd, database.Get) + + assert.Equals(t, tx.Operations[1].Bucket, nonceTable) + assert.Equals(t, tx.Operations[1].Key, []byte(nonceID)) + assert.Equals(t, tx.Operations[1].Cmd, database.Delete) + return nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go new file mode 100644 index 00000000..052f5729 --- /dev/null +++ b/acme/db/nosql/nosql.go @@ -0,0 +1,96 @@ +package nosql + +import ( + "context" + "encoding/json" + "time" + + "github.com/pkg/errors" + nosqlDB "github.com/smallstep/nosql" + "go.step.sm/crypto/randutil" +) + +var ( + accountTable = []byte("acme_accounts") + accountByKeyIDTable = []byte("acme_keyID_accountID_index") + authzTable = []byte("acme_authzs") + challengeTable = []byte("acme_challenges") + nonceTable = []byte("nonces") + orderTable = []byte("acme_orders") + ordersByAccountIDTable = []byte("acme_account_orders_index") + certTable = []byte("acme_certs") +) + +// DB is a struct that implements the AcmeDB interface. +type DB struct { + db nosqlDB.DB +} + +// New configures and returns a new ACME DB backend implemented using a nosql DB. +func New(db nosqlDB.DB) (*DB, error) { + tables := [][]byte{accountTable, accountByKeyIDTable, authzTable, + challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable} + for _, b := range tables { + if err := db.CreateTable(b); err != nil { + return nil, errors.Wrapf(err, "error creating table %s", + string(b)) + } + } + return &DB{db}, nil +} + +// save writes the new data to the database, overwriting the old data if it +// existed. +func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { + var ( + err error + newB []byte + ) + if nu == nil { + newB = nil + } else { + newB, err = json.Marshal(nu) + if err != nil { + return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu) + } + } + var oldB []byte + if old == nil { + oldB = nil + } else { + oldB, err = json.Marshal(old) + if err != nil { + return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old) + } + } + + _, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB) + switch { + case err != nil: + return errors.Wrapf(err, "error saving acme %s", typ) + case !swapped: + return errors.Errorf("error saving acme %s; changed since last read", typ) + default: + return nil + } +} + +var idLen = 32 + +func randID() (val string, err error) { + val, err = randutil.Alphanumeric(idLen) + if err != nil { + return "", errors.Wrap(err, "error generating random alphanumeric ID") + } + return val, nil +} + +// Clock that returns time in UTC rounded to seconds. +type Clock struct{} + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Truncate(time.Second) +} + +var clock = new(Clock) diff --git a/acme/db/nosql/nosql_test.go b/acme/db/nosql/nosql_test.go new file mode 100644 index 00000000..4396acc8 --- /dev/null +++ b/acme/db/nosql/nosql_test.go @@ -0,0 +1,139 @@ +package nosql + +import ( + "context" + "testing" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" +) + +func TestNew(t *testing.T) { + type test struct { + db nosql.DB + err error + } + var tests = map[string]test{ + "fail/db.CreateTable-error": { + db: &db.MockNoSQLDB{ + MCreateTable: func(bucket []byte) error { + assert.Equals(t, string(bucket), string(accountTable)) + return errors.New("force") + }, + }, + err: errors.Errorf("error creating table %s: force", string(accountTable)), + }, + "ok": { + db: &db.MockNoSQLDB{ + MCreateTable: func(bucket []byte) error { + return nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + if _, err := New(tc.db); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +type errorThrower string + +func (et errorThrower) MarshalJSON() ([]byte, error) { + return nil, errors.New("force") +} + +func TestDB_save(t *testing.T) { + type test struct { + db nosql.DB + nu interface{} + old interface{} + err error + } + var tests = map[string]test{ + "fail/error-marshaling-new": { + nu: errorThrower("foo"), + err: errors.New("error marshaling acme type: challenge"), + }, + "fail/error-marshaling-old": { + nu: "new", + old: errorThrower("foo"), + err: errors.New("error marshaling acme type: challenge"), + }, + "fail/db.CmpAndSwap-error": { + nu: "new", + old: "old", + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), "id") + assert.Equals(t, string(old), "\"old\"") + assert.Equals(t, string(nu), "\"new\"") + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme challenge: force"), + }, + "fail/db.CmpAndSwap-false-marshaling-old": { + nu: "new", + old: "old", + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), "id") + assert.Equals(t, string(old), "\"old\"") + assert.Equals(t, string(nu), "\"new\"") + return nil, false, nil + }, + }, + err: errors.New("error saving acme challenge; changed since last read"), + }, + "ok": { + nu: "new", + old: "old", + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), "id") + assert.Equals(t, string(old), "\"old\"") + assert.Equals(t, string(nu), "\"new\"") + return nu, true, nil + }, + }, + }, + "ok/nils": { + nu: nil, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), "id") + assert.Equals(t, old, nil) + assert.Equals(t, nu, nil) + return nu, true, nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + db := &DB{db: tc.db} + if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go new file mode 100644 index 00000000..ba3934af --- /dev/null +++ b/acme/db/nosql/order.go @@ -0,0 +1,189 @@ +package nosql + +import ( + "context" + "encoding/json" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/nosql" +) + +// Mutex for locking ordersByAccount index operations. +var ordersByAccountMux sync.Mutex + +type dbOrder struct { + ID string `json:"id"` + AccountID string `json:"accountID"` + ProvisionerID string `json:"provisionerID"` + Identifiers []acme.Identifier `json:"identifiers"` + AuthorizationIDs []string `json:"authorizationIDs"` + Status acme.Status `json:"status"` + NotBefore time.Time `json:"notBefore,omitempty"` + NotAfter time.Time `json:"notAfter,omitempty"` + CreatedAt time.Time `json:"createdAt"` + ExpiresAt time.Time `json:"expiresAt,omitempty"` + CertificateID string `json:"certificate,omitempty"` + Error *acme.Error `json:"error,omitempty"` +} + +func (a *dbOrder) clone() *dbOrder { + b := *a + return &b +} + +// getDBOrder retrieves and unmarshals an ACME Order type from the database. +func (db *DB) getDBOrder(ctx context.Context, id string) (*dbOrder, error) { + b, err := db.db.Get(orderTable, []byte(id)) + if nosql.IsErrNotFound(err) { + return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id) + } else if err != nil { + return nil, errors.Wrapf(err, "error loading order %s", id) + } + o := new(dbOrder) + if err := json.Unmarshal(b, &o); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling order %s into dbOrder", id) + } + return o, nil +} + +// GetOrder retrieves an ACME Order from the database. +func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) { + dbo, err := db.getDBOrder(ctx, id) + if err != nil { + return nil, err + } + + o := &acme.Order{ + ID: dbo.ID, + AccountID: dbo.AccountID, + ProvisionerID: dbo.ProvisionerID, + CertificateID: dbo.CertificateID, + Status: dbo.Status, + ExpiresAt: dbo.ExpiresAt, + Identifiers: dbo.Identifiers, + NotBefore: dbo.NotBefore, + NotAfter: dbo.NotAfter, + AuthorizationIDs: dbo.AuthorizationIDs, + Error: dbo.Error, + } + + return o, nil +} + +// CreateOrder creates ACME Order resources and saves them to the DB. +func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { + var err error + o.ID, err = randID() + if err != nil { + return err + } + + now := clock.Now() + dbo := &dbOrder{ + ID: o.ID, + AccountID: o.AccountID, + ProvisionerID: o.ProvisionerID, + Status: o.Status, + CreatedAt: now, + ExpiresAt: o.ExpiresAt, + Identifiers: o.Identifiers, + NotBefore: o.NotBefore, + NotAfter: o.NotAfter, + AuthorizationIDs: o.AuthorizationIDs, + } + if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil { + return err + } + + _, err = db.updateAddOrderIDs(ctx, o.AccountID, o.ID) + if err != nil { + return err + } + return nil +} + +// UpdateOrder saves an updated ACME Order to the database. +func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error { + old, err := db.getDBOrder(ctx, o.ID) + if err != nil { + return err + } + + nu := old.clone() + + nu.Status = o.Status + nu.Error = o.Error + nu.CertificateID = o.CertificateID + return db.save(ctx, old.ID, nu, old, "order", orderTable) +} + +func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) { + ordersByAccountMux.Lock() + defer ordersByAccountMux.Unlock() + + b, err := db.db.Get(ordersByAccountIDTable, []byte(accID)) + var ( + oldOids []string + ) + if err != nil { + if !nosql.IsErrNotFound(err) { + return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID) + } + } else { + if err := json.Unmarshal(b, &oldOids); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID) + } + } + + // Remove any order that is not in PENDING state and update the stored list + // before returning. + // + // According to RFC 8555: + // The server SHOULD include pending orders and SHOULD NOT include orders + // that are invalid in the array of URLs. + pendOids := []string{} + for _, oid := range oldOids { + o, err := db.GetOrder(ctx, oid) + if err != nil { + return nil, acme.WrapErrorISE(err, "error loading order %s for account %s", oid, accID) + } + if err = o.UpdateStatus(ctx, db); err != nil { + return nil, acme.WrapErrorISE(err, "error updating order %s for account %s", oid, accID) + } + if o.Status == acme.StatusPending { + pendOids = append(pendOids, oid) + } + } + pendOids = append(pendOids, addOids...) + var ( + _old interface{} = oldOids + _new interface{} = pendOids + ) + switch { + case len(oldOids) == 0 && len(pendOids) == 0: + // If list has not changed from empty, then no need to write the DB. + return []string{}, nil + case len(oldOids) == 0: + _old = nil + case len(pendOids) == 0: + _new = nil + } + if err = db.save(ctx, accID, _new, _old, "orderIDsByAccountID", ordersByAccountIDTable); err != nil { + // Delete all orders that may have been previously stored if orderIDsByAccountID update fails. + for _, oid := range addOids { + // Ignore error from delete -- we tried our best. + // TODO when we have logging w/ request ID tracking, logging this error. + db.db.Del(orderTable, []byte(oid)) + } + return nil, errors.Wrapf(err, "error saving orderIDs index for account %s", accID) + } + return pendOids, nil +} + +// GetOrdersByAccountID returns a list of order IDs owned by the account. +func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) { + return db.updateAddOrderIDs(ctx, accID) +} diff --git a/acme/db/nosql/order_test.go b/acme/db/nosql/order_test.go new file mode 100644 index 00000000..7248700f --- /dev/null +++ b/acme/db/nosql/order_test.go @@ -0,0 +1,1003 @@ +package nosql + +import ( + "context" + "encoding/json" + "reflect" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" +) + +func TestDB_getDBOrder(t *testing.T) { + orderID := "orderID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbo *dbOrder + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading order orderID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling order orderID into dbOrder"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbo := &dbOrder{ + ID: orderID, + AccountID: "accID", + ProvisionerID: "provID", + CertificateID: "certID", + Status: acme.StatusValid, + ExpiresAt: now, + CreatedAt: now, + NotBefore: now, + NotAfter: now, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + Error: acme.NewError(acme.ErrorMalformedType, "force"), + } + b, err := json.Marshal(dbo) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return b, nil + }, + }, + dbo: dbo, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if dbo, err := db.getDBOrder(context.Background(), orderID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, dbo.ID, tc.dbo.ID) + assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID) + assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID) + assert.Equals(t, dbo.Status, tc.dbo.Status) + assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt) + assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt) + assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore) + assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter) + assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers) + assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs) + assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error()) + } + } + }) + } +} + +func TestDB_GetOrder(t *testing.T) { + orderID := "orderID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbo *dbOrder + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading order orderID: force"), + } + }, + "fail/forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbo := &dbOrder{ + ID: orderID, + AccountID: "accID", + ProvisionerID: "provID", + CertificateID: "certID", + Status: acme.StatusValid, + ExpiresAt: now, + CreatedAt: now, + NotBefore: now, + NotAfter: now, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + Error: acme.NewError(acme.ErrorMalformedType, "force"), + } + b, err := json.Marshal(dbo) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + return b, nil + }, + }, + dbo: dbo, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if o, err := db.GetOrder(context.Background(), orderID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, o.ID, tc.dbo.ID) + assert.Equals(t, o.AccountID, tc.dbo.AccountID) + assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID) + assert.Equals(t, o.CertificateID, tc.dbo.CertificateID) + assert.Equals(t, o.Status, tc.dbo.Status) + assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt) + assert.Equals(t, o.NotBefore, tc.dbo.NotBefore) + assert.Equals(t, o.NotAfter, tc.dbo.NotAfter) + assert.Equals(t, o.Identifiers, tc.dbo.Identifiers) + assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs) + assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error()) + } + } + }) + } +} + +func TestDB_UpdateOrder(t *testing.T) { + orderID := "orderID" + now := clock.Now() + dbo := &dbOrder{ + ID: orderID, + AccountID: "accID", + ProvisionerID: "provID", + Status: acme.StatusPending, + ExpiresAt: now, + CreatedAt: now, + NotBefore: now, + NotAfter: now, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + } + b, err := json.Marshal(dbo) + assert.FatalError(t, err) + type test struct { + db nosql.DB + o *acme.Order + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + o: &acme.Order{ + ID: orderID, + }, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading order orderID: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + o := &acme.Order{ + ID: orderID, + Status: acme.StatusValid, + CertificateID: "certID", + Error: acme.NewError(acme.ErrorMalformedType, "force"), + } + return test{ + o: o, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, old, b) + + dbNew := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbo.ID) + assert.Equals(t, dbNew.AccountID, dbo.AccountID) + assert.Equals(t, dbNew.ProvisionerID, dbo.ProvisionerID) + assert.Equals(t, dbNew.CertificateID, o.CertificateID) + assert.Equals(t, dbNew.Status, o.Status) + assert.Equals(t, dbNew.CreatedAt, dbo.CreatedAt) + assert.Equals(t, dbNew.ExpiresAt, dbo.ExpiresAt) + assert.Equals(t, dbNew.NotBefore, dbo.NotBefore) + assert.Equals(t, dbNew.NotAfter, dbo.NotAfter) + assert.Equals(t, dbNew.AuthorizationIDs, dbo.AuthorizationIDs) + assert.Equals(t, dbNew.Identifiers, dbo.Identifiers) + assert.Equals(t, dbNew.Error.Error(), o.Error.Error()) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme order: force"), + } + }, + "ok": func(t *testing.T) test { + o := &acme.Order{ + ID: orderID, + Status: acme.StatusValid, + CertificateID: "certID", + Error: acme.NewError(acme.ErrorMalformedType, "force"), + } + return test{ + o: o, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, string(key), orderID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, old, b) + + dbNew := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbo.ID) + assert.Equals(t, dbNew.AccountID, dbo.AccountID) + assert.Equals(t, dbNew.ProvisionerID, dbo.ProvisionerID) + assert.Equals(t, dbNew.CertificateID, o.CertificateID) + assert.Equals(t, dbNew.Status, o.Status) + assert.Equals(t, dbNew.CreatedAt, dbo.CreatedAt) + assert.Equals(t, dbNew.ExpiresAt, dbo.ExpiresAt) + assert.Equals(t, dbNew.NotBefore, dbo.NotBefore) + assert.Equals(t, dbNew.NotAfter, dbo.NotAfter) + assert.Equals(t, dbNew.AuthorizationIDs, dbo.AuthorizationIDs) + assert.Equals(t, dbNew.Identifiers, dbo.Identifiers) + assert.Equals(t, dbNew.Error.Error(), o.Error.Error()) + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.UpdateOrder(context.Background(), tc.o); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.o.ID, dbo.ID) + assert.Equals(t, tc.o.CertificateID, "certID") + assert.Equals(t, tc.o.Status, acme.StatusValid) + assert.Equals(t, tc.o.Error.Error(), acme.NewError(acme.ErrorMalformedType, "force").Error()) + } + } + }) + } +} + +func TestDB_CreateOrder(t *testing.T) { + now := clock.Now() + nbf := now.Add(5 * time.Minute) + naf := now.Add(15 * time.Minute) + type test struct { + db nosql.DB + o *acme.Order + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/order-save-error": func(t *testing.T) test { + o := &acme.Order{ + AccountID: "accID", + ProvisionerID: "provID", + CertificateID: "certID", + Status: acme.StatusValid, + ExpiresAt: now, + NotBefore: nbf, + NotAfter: naf, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, string(bucket), string(orderTable)) + assert.Equals(t, string(key), o.ID) + assert.Equals(t, old, nil) + + dbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbo)) + assert.Equals(t, dbo.ID, o.ID) + assert.Equals(t, dbo.AccountID, o.AccountID) + assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID) + assert.Equals(t, dbo.CertificateID, "") + assert.Equals(t, dbo.Status, o.Status) + assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now)) + assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, dbo.NotBefore, o.NotBefore) + assert.Equals(t, dbo.NotAfter, o.NotAfter) + assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, dbo.Identifiers, o.Identifiers) + assert.Equals(t, dbo.Error, nil) + return nil, false, errors.New("force") + }, + }, + o: o, + err: errors.New("error saving acme order: force"), + } + }, + "fail/orderIDsByOrderUpdate-error": func(t *testing.T) test { + o := &acme.Order{ + AccountID: "accID", + ProvisionerID: "provID", + CertificateID: "certID", + Status: acme.StatusValid, + ExpiresAt: now, + NotBefore: nbf, + NotAfter: naf, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + } + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(ordersByAccountIDTable)) + assert.Equals(t, string(key), o.AccountID) + return nil, errors.New("force") + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, string(bucket), string(orderTable)) + assert.Equals(t, string(key), o.ID) + assert.Equals(t, old, nil) + + dbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbo)) + assert.Equals(t, dbo.ID, o.ID) + assert.Equals(t, dbo.AccountID, o.AccountID) + assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID) + assert.Equals(t, dbo.CertificateID, "") + assert.Equals(t, dbo.Status, o.Status) + assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now)) + assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, dbo.NotBefore, o.NotBefore) + assert.Equals(t, dbo.NotAfter, o.NotAfter) + assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, dbo.Identifiers, o.Identifiers) + assert.Equals(t, dbo.Error, nil) + return nu, true, nil + }, + }, + o: o, + err: errors.New("error loading orderIDs for account accID: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idptr = &id + ) + + o := &acme.Order{ + AccountID: "accID", + ProvisionerID: "provID", + Status: acme.StatusValid, + ExpiresAt: now, + NotBefore: nbf, + NotAfter: naf, + Identifiers: []acme.Identifier{ + {Type: "dns", Value: "test.ca.smallstep.com"}, + {Type: "dns", Value: "example.foo.com"}, + }, + AuthorizationIDs: []string{"foo", "bar"}, + } + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(ordersByAccountIDTable)) + assert.Equals(t, string(key), o.AccountID) + return nil, nosqldb.ErrNotFound + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + b, err := json.Marshal([]string{o.ID}) + assert.FatalError(t, err) + assert.Equals(t, string(key), "accID") + assert.Equals(t, old, nil) + assert.Equals(t, nu, b) + return nu, true, nil + case string(orderTable): + *idptr = string(key) + assert.Equals(t, string(key), o.ID) + assert.Equals(t, old, nil) + + dbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbo)) + assert.Equals(t, dbo.ID, o.ID) + assert.Equals(t, dbo.AccountID, o.AccountID) + assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID) + assert.Equals(t, dbo.CertificateID, "") + assert.Equals(t, dbo.Status, o.Status) + assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now)) + assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, dbo.NotBefore, o.NotBefore) + assert.Equals(t, dbo.NotAfter, o.NotAfter) + assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, dbo.Identifiers, o.Identifiers) + assert.Equals(t, dbo.Error, nil) + return nu, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + o: o, + _id: idptr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateOrder(context.Background(), tc.o); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.o.ID, *tc._id) + } + } + }) + } +} + +func TestDB_updateAddOrderIDs(t *testing.T) { + accID := "accID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + addOids []string + res []string + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return nil, errors.New("force") + }, + }, + err: errors.Errorf("error loading orderIDs for account %s", accID), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return []byte("foo"), nil + }, + }, + err: errors.Errorf("error unmarshaling orderIDs for account %s", accID), + } + }, + "fail/db.Get-order-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + b, err := json.Marshal([]string{"foo", "bar"}) + assert.FatalError(t, err) + return b, nil + case string(orderTable): + assert.Equals(t, key, []byte("foo")) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + acmeErr: acme.NewErrorISE("error loading order foo for account accID: error loading order foo: force"), + } + }, + "fail/update-order-status-error": func(t *testing.T) test { + expiry := clock.Now().Add(-5 * time.Minute) + ofoo := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bfoo, err := json.Marshal(ofoo) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + b, err := json.Marshal([]string{"foo", "bar"}) + assert.FatalError(t, err) + return b, nil + case string(orderTable): + assert.Equals(t, key, []byte("foo")) + return bfoo, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("foo")) + assert.Equals(t, old, bfoo) + + newdbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, newdbo)) + assert.Equals(t, newdbo.ID, "foo") + assert.Equals(t, newdbo.Status, acme.StatusInvalid) + assert.Equals(t, newdbo.ExpiresAt, expiry) + assert.Equals(t, newdbo.Error.Error(), acme.NewError(acme.ErrorMalformedType, "order has expired").Error()) + return nil, false, errors.New("force") + }, + }, + acmeErr: acme.NewErrorISE("error updating order foo for account accID: error updating order: error saving acme order: force"), + } + }, + "fail/db.save-order-error": func(t *testing.T) test { + addOids := []string{"foo", "bar"} + b, err := json.Marshal(addOids) + assert.FatalError(t, err) + delCount := 0 + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return nil, nosqldb.ErrNotFound + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, nil) + assert.Equals(t, nu, b) + return nil, false, errors.New("force") + }, + MDel: func(bucket, key []byte) error { + delCount++ + switch delCount { + case 1: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("foo")) + return nil + case 2: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("bar")) + return nil + default: + assert.FatalError(t, errors.New("delete should only be called twice")) + return errors.New("force") + } + }, + }, + addOids: addOids, + err: errors.Errorf("error saving orderIDs index for account %s", accID), + } + }, + "ok/all-old-not-pending": func(t *testing.T) test { + oldOids := []string{"foo", "bar"} + bOldOids, err := json.Marshal(oldOids) + assert.FatalError(t, err) + expiry := clock.Now().Add(-5 * time.Minute) + ofoo := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bfoo, err := json.Marshal(ofoo) + assert.FatalError(t, err) + obar := &dbOrder{ + ID: "bar", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bbar, err := json.Marshal(obar) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + return bOldOids, nil + case string(orderTable): + switch string(key) { + case "foo": + assert.Equals(t, key, []byte("foo")) + return bfoo, nil + case "bar": + assert.Equals(t, key, []byte("bar")) + return bbar, nil + default: + assert.FatalError(t, errors.Errorf("unexpected key %s", string(key))) + return nil, errors.New("force") + } + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(orderTable): + return nil, true, nil + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, bOldOids) + assert.Equals(t, nu, nil) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + res: []string{}, + } + }, + "ok/old-and-new": func(t *testing.T) test { + oldOids := []string{"foo", "bar"} + bOldOids, err := json.Marshal(oldOids) + assert.FatalError(t, err) + addOids := []string{"zap", "zar"} + bAddOids, err := json.Marshal(addOids) + assert.FatalError(t, err) + expiry := clock.Now().Add(-5 * time.Minute) + ofoo := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bfoo, err := json.Marshal(ofoo) + assert.FatalError(t, err) + obar := &dbOrder{ + ID: "bar", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bbar, err := json.Marshal(obar) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + return bOldOids, nil + case string(orderTable): + switch string(key) { + case "foo": + assert.Equals(t, key, []byte("foo")) + return bfoo, nil + case "bar": + assert.Equals(t, key, []byte("bar")) + return bbar, nil + default: + assert.FatalError(t, errors.Errorf("unexpected key %s", string(key))) + return nil, errors.New("force") + } + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(orderTable): + return nil, true, nil + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, bOldOids) + assert.Equals(t, nu, bAddOids) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + addOids: addOids, + res: addOids, + } + }, + "ok/old-and-new-2": func(t *testing.T) test { + oldOids := []string{"foo", "bar", "baz"} + bOldOids, err := json.Marshal(oldOids) + assert.FatalError(t, err) + addOids := []string{"zap", "zar"} + now := clock.Now() + min5 := now.Add(5 * time.Minute) + expiry := now.Add(-5 * time.Minute) + + o1 := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: min5, + AuthorizationIDs: []string{"a"}, + } + bo1, err := json.Marshal(o1) + assert.FatalError(t, err) + o2 := &dbOrder{ + ID: "bar", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bo2, err := json.Marshal(o2) + assert.FatalError(t, err) + o3 := &dbOrder{ + ID: "baz", + Status: acme.StatusPending, + ExpiresAt: min5, + AuthorizationIDs: []string{"b"}, + } + bo3, err := json.Marshal(o3) + assert.FatalError(t, err) + + az1 := &dbAuthz{ + ID: "a", + Status: acme.StatusPending, + ExpiresAt: min5, + ChallengeIDs: []string{"aa"}, + } + baz1, err := json.Marshal(az1) + assert.FatalError(t, err) + az2 := &dbAuthz{ + ID: "b", + Status: acme.StatusPending, + ExpiresAt: min5, + ChallengeIDs: []string{"bb"}, + } + baz2, err := json.Marshal(az2) + assert.FatalError(t, err) + + ch1 := &dbChallenge{ + ID: "aa", + Status: acme.StatusPending, + } + bch1, err := json.Marshal(ch1) + assert.FatalError(t, err) + ch2 := &dbChallenge{ + ID: "bb", + Status: acme.StatusPending, + } + bch2, err := json.Marshal(ch2) + assert.FatalError(t, err) + + newOids := append([]string{"foo", "baz"}, addOids...) + bNewOids, err := json.Marshal(newOids) + assert.FatalError(t, err) + + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(authzTable): + switch string(key) { + case "a": + return baz1, nil + case "b": + return baz2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", string(key))) + return nil, errors.New("force") + } + case string(challengeTable): + switch string(key) { + case "aa": + return bch1, nil + case "bb": + return bch2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected challenge key %s", string(key))) + return nil, errors.New("force") + } + case string(ordersByAccountIDTable): + return bOldOids, nil + case string(orderTable): + switch string(key) { + case "foo": + return bo1, nil + case "bar": + return bo2, nil + case "baz": + return bo3, nil + default: + assert.FatalError(t, errors.Errorf("unexpected key %s", string(key))) + return nil, errors.New("force") + } + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(orderTable): + return nil, true, nil + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, bOldOids) + assert.Equals(t, nu, bNewOids) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + addOids: addOids, + res: newOids, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + var ( + res []string + err error + ) + if tc.addOids == nil { + res, err = db.updateAddOrderIDs(context.Background(), accID) + } else { + res, err = db.updateAddOrderIDs(context.Background(), accID, tc.addOids...) + } + + if err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.True(t, reflect.DeepEqual(res, tc.res)) + } + } + }) + } +} diff --git a/acme/directory.go b/acme/directory.go deleted file mode 100644 index d5681b73..00000000 --- a/acme/directory.go +++ /dev/null @@ -1,150 +0,0 @@ -package acme - -import ( - "context" - "encoding/json" - "fmt" - "net/url" - - "github.com/pkg/errors" -) - -// Directory represents an ACME directory for configuring clients. -type Directory struct { - NewNonce string `json:"newNonce,omitempty"` - NewAccount string `json:"newAccount,omitempty"` - NewOrder string `json:"newOrder,omitempty"` - NewAuthz string `json:"newAuthz,omitempty"` - RevokeCert string `json:"revokeCert,omitempty"` - KeyChange string `json:"keyChange,omitempty"` -} - -// ToLog enables response logging for the Directory type. -func (d *Directory) ToLog() (interface{}, error) { - b, err := json.Marshal(d) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling directory for logging")) - } - return string(b), nil -} - -type directory struct { - prefix, dns string -} - -// newDirectory returns a new Directory type. -func newDirectory(dns, prefix string) *directory { - return &directory{prefix: prefix, dns: dns} -} - -// Link captures the link type. -type Link int - -const ( - // NewNonceLink new-nonce - NewNonceLink Link = iota - // NewAccountLink new-account - NewAccountLink - // AccountLink account - AccountLink - // OrderLink order - OrderLink - // NewOrderLink new-order - NewOrderLink - // OrdersByAccountLink list of orders owned by account - OrdersByAccountLink - // FinalizeLink finalize order - FinalizeLink - // NewAuthzLink authz - NewAuthzLink - // AuthzLink new-authz - AuthzLink - // ChallengeLink challenge - ChallengeLink - // CertificateLink certificate - CertificateLink - // DirectoryLink directory - DirectoryLink - // RevokeCertLink revoke certificate - RevokeCertLink - // KeyChangeLink key rollover - KeyChangeLink -) - -func (l Link) String() string { - switch l { - case NewNonceLink: - return "new-nonce" - case NewAccountLink: - return "new-account" - case AccountLink: - return "account" - case NewOrderLink: - return "new-order" - case OrderLink: - return "order" - case NewAuthzLink: - return "new-authz" - case AuthzLink: - return "authz" - case ChallengeLink: - return "challenge" - case CertificateLink: - return "certificate" - case DirectoryLink: - return "directory" - case RevokeCertLink: - return "revoke-cert" - case KeyChangeLink: - return "key-change" - default: - return "unexpected" - } -} - -func (d *directory) getLink(ctx context.Context, typ Link, abs bool, inputs ...string) string { - var provName string - if p, err := ProvisionerFromContext(ctx); err == nil && p != nil { - provName = p.GetName() - } - return d.getLinkExplicit(typ, provName, abs, BaseURLFromContext(ctx), inputs...) -} - -// getLinkExplicit returns an absolute or partial path to the given resource and a base -// URL dynamically obtained from the request for which the link is being -// calculated. -func (d *directory) getLinkExplicit(typ Link, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { - var link string - switch typ { - case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink: - link = fmt.Sprintf("/%s/%s", provisionerName, typ.String()) - case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink: - link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0]) - case OrdersByAccountLink: - link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0]) - case FinalizeLink: - link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0]) - } - - if abs { - // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 - u := url.URL{} - if baseURL != nil { - u = *baseURL - } - - // If no Scheme is set, then default to https. - if u.Scheme == "" { - u.Scheme = "https" - } - - // If no Host is set, then use the default (first DNS attr in the ca.json). - if u.Host == "" { - u.Host = d.dns - } - - u.Path = d.prefix + link - return u.String() - } - return link -} diff --git a/acme/directory_test.go b/acme/directory_test.go deleted file mode 100644 index dd4c534c..00000000 --- a/acme/directory_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package acme - -import ( - "context" - "fmt" - "net/url" - "testing" - - "github.com/smallstep/assert" -) - -func TestDirectoryGetLink(t *testing.T) { - dns := "ca.smallstep.com" - prefix := "acme" - dir := newDirectory(dns, prefix) - id := "1234" - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - - assert.Equals(t, dir.getLink(ctx, NewNonceLink, true), - fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName)) - assert.Equals(t, dir.getLink(ctx, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName)) - - // No provisioner - ctxNoProv := context.WithValue(context.Background(), BaseURLContextKey, baseURL) - assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, true), - fmt.Sprintf("%s/acme//new-nonce", baseURL.String())) - assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, false), "//new-nonce") - - // No baseURL - ctxNoBaseURL := context.WithValue(context.Background(), ProvisionerContextKey, prov) - assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, true), - fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName)) - assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName)) - - assert.Equals(t, dir.getLink(ctx, OrderLink, true, id), - fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName)) - assert.Equals(t, dir.getLink(ctx, OrderLink, false, id), fmt.Sprintf("/%s/order/1234", provName)) -} - -func TestDirectoryGetLinkExplicit(t *testing.T) { - dns := "ca.smallstep.com" - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prefix := "acme" - dir := newDirectory(dns, prefix) - id := "1234" - - prov := newProv() - provID := url.PathEscape(prov.GetName()) - - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID)) - - assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID)) - - assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID)) - - assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID)) - - assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) - - assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID)) - - assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID)) - - assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID)) - - assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID)) - - assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID)) -} diff --git a/acme/errors.go b/acme/errors.go index a4dd8159..6ecf0912 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -1,407 +1,339 @@ package acme import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/logging" ) -// AccountDoesNotExistErr returns a new acme error. -func AccountDoesNotExistErr(err error) *Error { - return &Error{ - Type: accountDoesNotExistErr, - Detail: "Account does not exist", - Status: 400, - Err: err, - } -} - -// AlreadyRevokedErr returns a new acme error. -func AlreadyRevokedErr(err error) *Error { - return &Error{ - Type: alreadyRevokedErr, - Detail: "Certificate already revoked", - Status: 400, - Err: err, - } -} - -// BadCSRErr returns a new acme error. -func BadCSRErr(err error) *Error { - return &Error{ - Type: badCSRErr, - Detail: "The CSR is unacceptable", - Status: 400, - Err: err, - } -} - -// BadNonceErr returns a new acme error. -func BadNonceErr(err error) *Error { - return &Error{ - Type: badNonceErr, - Detail: "Unacceptable anti-replay nonce", - Status: 400, - Err: err, - } -} - -// BadPublicKeyErr returns a new acme error. -func BadPublicKeyErr(err error) *Error { - return &Error{ - Type: badPublicKeyErr, - Detail: "The jws was signed by a public key the server does not support", - Status: 400, - Err: err, - } -} - -// BadRevocationReasonErr returns a new acme error. -func BadRevocationReasonErr(err error) *Error { - return &Error{ - Type: badRevocationReasonErr, - Detail: "The revocation reason provided is not allowed by the server", - Status: 400, - Err: err, - } -} - -// BadSignatureAlgorithmErr returns a new acme error. -func BadSignatureAlgorithmErr(err error) *Error { - return &Error{ - Type: badSignatureAlgorithmErr, - Detail: "The JWS was signed with an algorithm the server does not support", - Status: 400, - Err: err, - } -} - -// CaaErr returns a new acme error. -func CaaErr(err error) *Error { - return &Error{ - Type: caaErr, - Detail: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate", - Status: 400, - Err: err, - } -} - -// CompoundErr returns a new acme error. -func CompoundErr(err error) *Error { - return &Error{ - Type: compoundErr, - Detail: "Specific error conditions are indicated in the “subproblems” array", - Status: 400, - Err: err, - } -} - -// ConnectionErr returns a new acme error. -func ConnectionErr(err error) *Error { - return &Error{ - Type: connectionErr, - Detail: "The server could not connect to validation target", - Status: 400, - Err: err, - } -} - -// DNSErr returns a new acme error. -func DNSErr(err error) *Error { - return &Error{ - Type: dnsErr, - Detail: "There was a problem with a DNS query during identifier validation", - Status: 400, - Err: err, - } -} - -// ExternalAccountRequiredErr returns a new acme error. -func ExternalAccountRequiredErr(err error) *Error { - return &Error{ - Type: externalAccountRequiredErr, - Detail: "The request must include a value for the \"externalAccountBinding\" field", - Status: 400, - Err: err, - } -} - -// IncorrectResponseErr returns a new acme error. -func IncorrectResponseErr(err error) *Error { - return &Error{ - Type: incorrectResponseErr, - Detail: "Response received didn't match the challenge's requirements", - Status: 400, - Err: err, - } -} - -// InvalidContactErr returns a new acme error. -func InvalidContactErr(err error) *Error { - return &Error{ - Type: invalidContactErr, - Detail: "A contact URL for an account was invalid", - Status: 400, - Err: err, - } -} - -// MalformedErr returns a new acme error. -func MalformedErr(err error) *Error { - return &Error{ - Type: malformedErr, - Detail: "The request message was malformed", - Status: 400, - Err: err, - } -} - -// OrderNotReadyErr returns a new acme error. -func OrderNotReadyErr(err error) *Error { - return &Error{ - Type: orderNotReadyErr, - Detail: "The request attempted to finalize an order that is not ready to be finalized", - Status: 400, - Err: err, - } -} - -// RateLimitedErr returns a new acme error. -func RateLimitedErr(err error) *Error { - return &Error{ - Type: rateLimitedErr, - Detail: "The request exceeds a rate limit", - Status: 400, - Err: err, - } -} - -// RejectedIdentifierErr returns a new acme error. -func RejectedIdentifierErr(err error) *Error { - return &Error{ - Type: rejectedIdentifierErr, - Detail: "The server will not issue certificates for the identifier", - Status: 400, - Err: err, - } -} - -// ServerInternalErr returns a new acme error. -func ServerInternalErr(err error) *Error { - return &Error{ - Type: serverInternalErr, - Detail: "The server experienced an internal error", - Status: 500, - Err: err, - } -} - -// NotImplemented returns a new acme error. -func NotImplemented(err error) *Error { - return &Error{ - Type: notImplemented, - Detail: "The requested operation is not implemented", - Status: 501, - Err: err, - } -} - -// TLSErr returns a new acme error. -func TLSErr(err error) *Error { - return &Error{ - Type: tlsErr, - Detail: "The server received a TLS error during validation", - Status: 400, - Err: err, - } -} - -// UnauthorizedErr returns a new acme error. -func UnauthorizedErr(err error) *Error { - return &Error{ - Type: unauthorizedErr, - Detail: "The client lacks sufficient authorization", - Status: 401, - Err: err, - } -} - -// UnsupportedContactErr returns a new acme error. -func UnsupportedContactErr(err error) *Error { - return &Error{ - Type: unsupportedContactErr, - Detail: "A contact URL for an account used an unsupported protocol scheme", - Status: 400, - Err: err, - } -} - -// UnsupportedIdentifierErr returns a new acme error. -func UnsupportedIdentifierErr(err error) *Error { - return &Error{ - Type: unsupportedIdentifierErr, - Detail: "An identifier is of an unsupported type", - Status: 400, - Err: err, - } -} - -// UserActionRequiredErr returns a new acme error. -func UserActionRequiredErr(err error) *Error { - return &Error{ - Type: userActionRequiredErr, - Detail: "Visit the “instance” URL and take actions specified there", - Status: 400, - Err: err, - } -} - -// ProbType is the type of the ACME problem. -type ProbType int +// ProblemType is the type of the ACME problem. +type ProblemType int const ( - // The request specified an account that does not exist - accountDoesNotExistErr ProbType = iota - // The request specified a certificate to be revoked that has already been revoked - alreadyRevokedErr - // The CSR is unacceptable (e.g., due to a short key) - badCSRErr - // The client sent an unacceptable anti-replay nonce - badNonceErr - // The JWS was signed by a public key the server does not support - badPublicKeyErr - // The revocation reason provided is not allowed by the server - badRevocationReasonErr - // The JWS was signed with an algorithm the server does not support - badSignatureAlgorithmErr - // Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate - caaErr - // Specific error conditions are indicated in the “subproblems” array. - compoundErr - // The server could not connect to validation target - connectionErr - // There was a problem with a DNS query during identifier validation - dnsErr - // The request must include a value for the “externalAccountBinding” field - externalAccountRequiredErr - // Response received didn’t match the challenge’s requirements - incorrectResponseErr - // A contact URL for an account was invalid - invalidContactErr - // The request message was malformed - malformedErr - // The request attempted to finalize an order that is not ready to be finalized - orderNotReadyErr - // The request exceeds a rate limit - rateLimitedErr - // The server will not issue certificates for the identifier - rejectedIdentifierErr - // The server experienced an internal error - serverInternalErr - // The server received a TLS error during validation - tlsErr - // The client lacks sufficient authorization - unauthorizedErr - // A contact URL for an account used an unsupported protocol scheme - unsupportedContactErr - // An identifier is of an unsupported type - unsupportedIdentifierErr - // Visit the “instance” URL and take actions specified there - userActionRequiredErr - // The operation is not implemented - notImplemented + // ErrorAccountDoesNotExistType request specified an account that does not exist + ErrorAccountDoesNotExistType ProblemType = iota + // ErrorAlreadyRevokedType request specified a certificate to be revoked that has already been revoked + ErrorAlreadyRevokedType + // ErrorBadCSRType CSR is unacceptable (e.g., due to a short key) + ErrorBadCSRType + // ErrorBadNonceType client sent an unacceptable anti-replay nonce + ErrorBadNonceType + // ErrorBadPublicKeyType JWS was signed by a public key the server does not support + ErrorBadPublicKeyType + // ErrorBadRevocationReasonType revocation reason provided is not allowed by the server + ErrorBadRevocationReasonType + // ErrorBadSignatureAlgorithmType JWS was signed with an algorithm the server does not support + ErrorBadSignatureAlgorithmType + // ErrorCaaType Authority Authorization (CAA) records forbid the CA from issuing a certificate + ErrorCaaType + // ErrorCompoundType error conditions are indicated in the “subproblems” array. + ErrorCompoundType + // ErrorConnectionType server could not connect to validation target + ErrorConnectionType + // ErrorDNSType was a problem with a DNS query during identifier validation + ErrorDNSType + // ErrorExternalAccountRequiredType request must include a value for the “externalAccountBinding” field + ErrorExternalAccountRequiredType + // ErrorIncorrectResponseType received didn’t match the challenge’s requirements + ErrorIncorrectResponseType + // ErrorInvalidContactType URL for an account was invalid + ErrorInvalidContactType + // ErrorMalformedType request message was malformed + ErrorMalformedType + // ErrorOrderNotReadyType request attempted to finalize an order that is not ready to be finalized + ErrorOrderNotReadyType + // ErrorRateLimitedType request exceeds a rate limit + ErrorRateLimitedType + // ErrorRejectedIdentifierType server will not issue certificates for the identifier + ErrorRejectedIdentifierType + // ErrorServerInternalType server experienced an internal error + ErrorServerInternalType + // ErrorTLSType server received a TLS error during validation + ErrorTLSType + // ErrorUnauthorizedType client lacks sufficient authorization + ErrorUnauthorizedType + // ErrorUnsupportedContactType URL for an account used an unsupported protocol scheme + ErrorUnsupportedContactType + // ErrorUnsupportedIdentifierType identifier is of an unsupported type + ErrorUnsupportedIdentifierType + // ErrorUserActionRequiredType the “instance” URL and take actions specified there + ErrorUserActionRequiredType + // ErrorNotImplementedType operation is not implemented + ErrorNotImplementedType ) // String returns the string representation of the acme problem type, // fulfilling the Stringer interface. -func (ap ProbType) String() string { +func (ap ProblemType) String() string { switch ap { - case accountDoesNotExistErr: + case ErrorAccountDoesNotExistType: return "accountDoesNotExist" - case alreadyRevokedErr: + case ErrorAlreadyRevokedType: return "alreadyRevoked" - case badCSRErr: + case ErrorBadCSRType: return "badCSR" - case badNonceErr: + case ErrorBadNonceType: return "badNonce" - case badPublicKeyErr: + case ErrorBadPublicKeyType: return "badPublicKey" - case badRevocationReasonErr: + case ErrorBadRevocationReasonType: return "badRevocationReason" - case badSignatureAlgorithmErr: + case ErrorBadSignatureAlgorithmType: return "badSignatureAlgorithm" - case caaErr: + case ErrorCaaType: return "caa" - case compoundErr: + case ErrorCompoundType: return "compound" - case connectionErr: + case ErrorConnectionType: return "connection" - case dnsErr: + case ErrorDNSType: return "dns" - case externalAccountRequiredErr: + case ErrorExternalAccountRequiredType: return "externalAccountRequired" - case incorrectResponseErr: + case ErrorInvalidContactType: return "incorrectResponse" - case invalidContactErr: - return "invalidContact" - case malformedErr: + case ErrorMalformedType: return "malformed" - case orderNotReadyErr: + case ErrorOrderNotReadyType: return "orderNotReady" - case rateLimitedErr: + case ErrorRateLimitedType: return "rateLimited" - case rejectedIdentifierErr: + case ErrorRejectedIdentifierType: return "rejectedIdentifier" - case serverInternalErr: + case ErrorServerInternalType: return "serverInternal" - case tlsErr: + case ErrorTLSType: return "tls" - case unauthorizedErr: + case ErrorUnauthorizedType: return "unauthorized" - case unsupportedContactErr: + case ErrorUnsupportedContactType: return "unsupportedContact" - case unsupportedIdentifierErr: + case ErrorUnsupportedIdentifierType: return "unsupportedIdentifier" - case userActionRequiredErr: + case ErrorUserActionRequiredType: return "userActionRequired" - case notImplemented: + case ErrorNotImplementedType: return "notImplemented" default: - return "unsupported type" + return fmt.Sprintf("unsupported type ACME error type '%d'", int(ap)) } } -// Error is an ACME error type complete with problem document. -type Error struct { - Type ProbType - Detail string - Err error - Status int - Sub []*Error - Identifier *Identifier +type errorMetadata struct { + details string + status int + typ string + String string } -// Wrap attempts to wrap the internal error. -func Wrap(err error, wrap string) *Error { +var ( + officialACMEPrefix = "urn:ietf:params:acme:error:" + errorServerInternalMetadata = errorMetadata{ + typ: officialACMEPrefix + ErrorServerInternalType.String(), + details: "The server experienced an internal error", + status: 500, + } + errorMap = map[ProblemType]errorMetadata{ + ErrorAccountDoesNotExistType: { + typ: officialACMEPrefix + ErrorAccountDoesNotExistType.String(), + details: "Account does not exist", + status: 400, + }, + ErrorAlreadyRevokedType: { + typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(), + details: "Certificate already Revoked", + status: 400, + }, + ErrorBadCSRType: { + typ: officialACMEPrefix + ErrorBadCSRType.String(), + details: "The CSR is unacceptable", + status: 400, + }, + ErrorBadNonceType: { + typ: officialACMEPrefix + ErrorBadNonceType.String(), + details: "Unacceptable anti-replay nonce", + status: 400, + }, + ErrorBadPublicKeyType: { + typ: officialACMEPrefix + ErrorBadPublicKeyType.String(), + details: "The jws was signed by a public key the server does not support", + status: 400, + }, + ErrorBadRevocationReasonType: { + typ: officialACMEPrefix + ErrorBadRevocationReasonType.String(), + details: "The revocation reason provided is not allowed by the server", + status: 400, + }, + ErrorBadSignatureAlgorithmType: { + typ: officialACMEPrefix + ErrorBadSignatureAlgorithmType.String(), + details: "The JWS was signed with an algorithm the server does not support", + status: 400, + }, + ErrorCaaType: { + typ: officialACMEPrefix + ErrorCaaType.String(), + details: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate", + status: 400, + }, + ErrorCompoundType: { + typ: officialACMEPrefix + ErrorCompoundType.String(), + details: "Specific error conditions are indicated in the “subproblems” array", + status: 400, + }, + ErrorConnectionType: { + typ: officialACMEPrefix + ErrorConnectionType.String(), + details: "The server could not connect to validation target", + status: 400, + }, + ErrorDNSType: { + typ: officialACMEPrefix + ErrorDNSType.String(), + details: "There was a problem with a DNS query during identifier validation", + status: 400, + }, + ErrorExternalAccountRequiredType: { + typ: officialACMEPrefix + ErrorExternalAccountRequiredType.String(), + details: "The request must include a value for the \"externalAccountBinding\" field", + status: 400, + }, + ErrorIncorrectResponseType: { + typ: officialACMEPrefix + ErrorIncorrectResponseType.String(), + details: "Response received didn't match the challenge's requirements", + status: 400, + }, + ErrorInvalidContactType: { + typ: officialACMEPrefix + ErrorInvalidContactType.String(), + details: "A contact URL for an account was invalid", + status: 400, + }, + ErrorMalformedType: { + typ: officialACMEPrefix + ErrorMalformedType.String(), + details: "The request message was malformed", + status: 400, + }, + ErrorOrderNotReadyType: { + typ: officialACMEPrefix + ErrorOrderNotReadyType.String(), + details: "The request attempted to finalize an order that is not ready to be finalized", + status: 400, + }, + ErrorRateLimitedType: { + typ: officialACMEPrefix + ErrorRateLimitedType.String(), + details: "The request exceeds a rate limit", + status: 400, + }, + ErrorRejectedIdentifierType: { + typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(), + details: "The server will not issue certificates for the identifier", + status: 400, + }, + ErrorNotImplementedType: { + typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(), + details: "The requested operation is not implemented", + status: 501, + }, + ErrorTLSType: { + typ: officialACMEPrefix + ErrorTLSType.String(), + details: "The server received a TLS error during validation", + status: 400, + }, + ErrorUnauthorizedType: { + typ: officialACMEPrefix + ErrorUnauthorizedType.String(), + details: "The client lacks sufficient authorization", + status: 401, + }, + ErrorUnsupportedContactType: { + typ: officialACMEPrefix + ErrorUnsupportedContactType.String(), + details: "A contact URL for an account used an unsupported protocol scheme", + status: 400, + }, + ErrorUnsupportedIdentifierType: { + typ: officialACMEPrefix + ErrorUnsupportedIdentifierType.String(), + details: "An identifier is of an unsupported type", + status: 400, + }, + ErrorUserActionRequiredType: { + typ: officialACMEPrefix + ErrorUserActionRequiredType.String(), + details: "Visit the “instance” URL and take actions specified there", + status: 400, + }, + ErrorServerInternalType: errorServerInternalMetadata, + } +) + +// Error represents an ACME +type Error struct { + Type string `json:"type"` + Detail string `json:"detail"` + Subproblems []interface{} `json:"subproblems,omitempty"` + Identifier interface{} `json:"identifier,omitempty"` + Err error `json:"-"` + Status int `json:"-"` +} + +// NewError creates a new Error type. +func NewError(pt ProblemType, msg string, args ...interface{}) *Error { + return newError(pt, errors.Errorf(msg, args...)) +} + +func newError(pt ProblemType, err error) *Error { + meta, ok := errorMap[pt] + if !ok { + meta = errorServerInternalMetadata + return &Error{ + Type: meta.typ, + Detail: meta.details, + Status: meta.status, + Err: err, + } + } + + return &Error{ + Type: meta.typ, + Detail: meta.details, + Status: meta.status, + Err: err, + } +} + +// NewErrorISE creates a new ErrorServerInternalType Error. +func NewErrorISE(msg string, args ...interface{}) *Error { + return NewError(ErrorServerInternalType, msg, args...) +} + +// WrapError attempts to wrap the internal error. +func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error { switch e := err.(type) { case nil: return nil case *Error: if e.Err == nil { - e.Err = errors.New(wrap + "; " + e.Detail) + e.Err = errors.Errorf(msg+"; "+e.Detail, args...) } else { - e.Err = errors.Wrap(e.Err, wrap) + e.Err = errors.Wrapf(e.Err, msg, args...) } return e default: - return ServerInternalErr(errors.Wrap(err, wrap)) + return newError(typ, errors.Wrapf(err, msg, args...)) } } -// Error implements the error interface. +// WrapErrorISE shortcut to wrap an internal server error type. +func WrapErrorISE(err error, msg string, args ...interface{}) *Error { + return WrapError(ErrorServerInternalType, err, msg, args...) +} + +// StatusCode returns the status code and implements the StatusCoder interface. +func (e *Error) StatusCode() int { + return e.Status +} + +// Error allows AError to implement the error interface. func (e *Error) Error() string { - if e.Err == nil { - return e.Detail - } - return e.Err.Error() + return e.Detail } // Cause returns the internal error and implements the Causer interface. @@ -412,70 +344,35 @@ func (e *Error) Cause() error { return e.Err } -// Official returns true if this error's type is listed in §6.7 of RFC 8555. -// Error types in §6.7 are registered under IETF urn namespace: -// -// "urn:ietf:params:acme:error:" -// -// and should include the namespace as a prefix when appearing as a problem -// document. -// -// RFC 8555 also says: -// -// This list is not exhaustive. The server MAY return errors whose -// "type" field is set to a URI other than those defined above. Servers -// MUST NOT use the ACME URN namespace for errors not listed in the -// appropriate IANA registry (see Section 9.6). Clients SHOULD display -// the "detail" field of all errors. -// -// In this case Official returns `false` so that a different namespace can -// be used. -func (e *Error) Official() bool { - return e.Type != notImplemented -} - -// ToACME returns an acme representation of the problem type. -// For official errors, the IETF ACME namespace is prepended to the error type. -// For our own errors, we use an (yet) unregistered smallstep acme namespace. -func (e *Error) ToACME() *AError { - prefix := "urn:step:acme:error" - if e.Official() { - prefix = "urn:ietf:params:acme:error:" +// ToLog implements the EnableLogger interface. +func (e *Error) ToLog() (interface{}, error) { + b, err := json.Marshal(e) + if err != nil { + return nil, WrapErrorISE(err, "error marshaling acme.Error for logging") } - ae := &AError{ - Type: prefix + e.Type.String(), - Detail: e.Error(), - Status: e.Status, + return string(b), nil +} + +// WriteError writes to w a JSON representation of the given error. +func WriteError(w http.ResponseWriter, err *Error) { + w.Header().Set("Content-Type", "application/problem+json") + w.WriteHeader(err.StatusCode()) + + // Write errors in the response writer + if rl, ok := w.(logging.ResponseLogger); ok { + rl.WithFields(map[string]interface{}{ + "error": err.Err, + }) + if os.Getenv("STEPDEBUG") == "1" { + if e, ok := err.Err.(errs.StackTracer); ok { + rl.WithFields(map[string]interface{}{ + "stack-trace": fmt.Sprintf("%+v", e), + }) + } + } } - if e.Identifier != nil { - ae.Identifier = *e.Identifier + + if err := json.NewEncoder(w).Encode(err); err != nil { + log.Println(err) } - for _, p := range e.Sub { - ae.Subproblems = append(ae.Subproblems, p.ToACME()) - } - return ae -} - -// StatusCode returns the status code and implements the StatusCode interface. -func (e *Error) StatusCode() int { - return e.Status -} - -// AError is the error type as seen in acme request/responses. -type AError struct { - Type string `json:"type"` - Detail string `json:"detail"` - Identifier interface{} `json:"identifier,omitempty"` - Subproblems []interface{} `json:"subproblems,omitempty"` - Status int `json:"-"` -} - -// Error allows AError to implement the error interface. -func (ae *AError) Error() string { - return ae.Detail -} - -// StatusCode returns the status code and implements the StatusCode interface. -func (ae *AError) StatusCode() int { - return ae.Status } diff --git a/acme/nonce.go b/acme/nonce.go index db680f08..25c86360 100644 --- a/acme/nonce.go +++ b/acme/nonce.go @@ -1,73 +1,9 @@ package acme -import ( - "encoding/base64" - "encoding/json" - "time" +// Nonce represents an ACME nonce type. +type Nonce string - "github.com/pkg/errors" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" -) - -// nonce contains nonce metadata used in the ACME protocol. -type nonce struct { - ID string - Created time.Time -} - -// newNonce creates, stores, and returns an ACME replay-nonce. -func newNonce(db nosql.DB) (*nonce, error) { - _id, err := randID() - if err != nil { - return nil, err - } - - id := base64.RawURLEncoding.EncodeToString([]byte(_id)) - n := &nonce{ - ID: id, - Created: clock.Now(), - } - b, err := json.Marshal(n) - if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce")) - } - _, swapped, err := db.CmpAndSwap(nonceTable, []byte(id), nil, b) - switch { - case err != nil: - return nil, ServerInternalErr(errors.Wrap(err, "error storing nonce")) - case !swapped: - return nil, ServerInternalErr(errors.New("error storing nonce; " + - "value has changed since last read")) - default: - return n, nil - } -} - -// useNonce verifies that the nonce is valid (by checking if it exists), -// and if so, consumes the nonce resource by deleting it from the database. -func useNonce(db nosql.DB, nonce string) error { - err := db.Update(&database.Tx{ - Operations: []*database.TxEntry{ - { - Bucket: nonceTable, - Key: []byte(nonce), - Cmd: database.Get, - }, - { - Bucket: nonceTable, - Key: []byte(nonce), - Cmd: database.Delete, - }, - }, - }) - - switch { - case nosql.IsErrNotFound(err): - return BadNonceErr(nil) - case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce)) - default: - return nil - } +// String implements the ToString interface. +func (n Nonce) String() string { + return string(n) } diff --git a/acme/nonce_test.go b/acme/nonce_test.go deleted file mode 100644 index 6aa467a0..00000000 --- a/acme/nonce_test.go +++ /dev/null @@ -1,163 +0,0 @@ -package acme - -import ( - "testing" - "time" - - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" -) - -func TestNewNonce(t *testing.T) { - type test struct { - db nosql.DB - err *Error - id *string - } - tests := map[string]func(t *testing.T) test{ - "fail/cmpAndSwap-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, old, nil) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error storing nonce: force")), - } - }, - "fail/cmpAndSwap-false": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, old, nil) - return nil, false, nil - }, - }, - err: ServerInternalErr(errors.Errorf("error storing nonce; value has changed since last read")), - } - }, - "ok": func(t *testing.T) test { - var _id string - id := &_id - return test{ - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, nonceTable) - assert.Equals(t, old, nil) - *id = string(key) - return nil, true, nil - }, - }, - id: id, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if n, err := newNonce(tc.db); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, n.ID, *tc.id) - - assert.True(t, n.Created.Before(time.Now().Add(time.Minute))) - assert.True(t, n.Created.After(time.Now().Add(-time.Minute))) - } - } - }) - } -} - -func TestUseNonce(t *testing.T) { - type test struct { - id string - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/update-not-found": func(t *testing.T) test { - id := "foo" - return test{ - db: &db.MockNoSQLDB{ - MUpdate: func(tx *database.Tx) error { - assert.Equals(t, tx.Operations[0].Bucket, nonceTable) - assert.Equals(t, tx.Operations[0].Key, []byte(id)) - assert.Equals(t, tx.Operations[0].Cmd, database.Get) - - assert.Equals(t, tx.Operations[1].Bucket, nonceTable) - assert.Equals(t, tx.Operations[1].Key, []byte(id)) - assert.Equals(t, tx.Operations[1].Cmd, database.Delete) - return database.ErrNotFound - }, - }, - id: id, - err: BadNonceErr(nil), - } - }, - "fail/update-error": func(t *testing.T) test { - id := "foo" - return test{ - db: &db.MockNoSQLDB{ - MUpdate: func(tx *database.Tx) error { - assert.Equals(t, tx.Operations[0].Bucket, nonceTable) - assert.Equals(t, tx.Operations[0].Key, []byte(id)) - assert.Equals(t, tx.Operations[0].Cmd, database.Get) - - assert.Equals(t, tx.Operations[1].Bucket, nonceTable) - assert.Equals(t, tx.Operations[1].Key, []byte(id)) - assert.Equals(t, tx.Operations[1].Cmd, database.Delete) - return errors.New("force") - }, - }, - id: id, - err: ServerInternalErr(errors.Errorf("error deleting nonce %s: force", id)), - } - }, - "ok": func(t *testing.T) test { - id := "foo" - return test{ - db: &db.MockNoSQLDB{ - MUpdate: func(tx *database.Tx) error { - assert.Equals(t, tx.Operations[0].Bucket, nonceTable) - assert.Equals(t, tx.Operations[0].Key, []byte(id)) - assert.Equals(t, tx.Operations[0].Cmd, database.Get) - - assert.Equals(t, tx.Operations[1].Bucket, nonceTable) - assert.Equals(t, tx.Operations[1].Key, []byte(id)) - assert.Equals(t, tx.Operations[1].Cmd, database.Delete) - - return nil - }, - }, - id: id, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := useNonce(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } - }) - } -} diff --git a/acme/order.go b/acme/order.go index 574477ca..a003fe9a 100644 --- a/acme/order.go +++ b/acme/order.go @@ -6,351 +6,129 @@ import ( "encoding/json" "sort" "strings" - "sync" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/nosql" "go.step.sm/crypto/x509util" ) -var defaultOrderExpiry = time.Hour * 24 - -// Mutex for locking ordersByAccount index operations. -var ordersByAccountMux sync.Mutex +// Identifier encodes the type that an order pertains to. +type Identifier struct { + Type string `json:"type"` + Value string `json:"value"` +} // Order contains order metadata for the ACME protocol order type. type Order struct { - Status string `json:"status"` - Expires string `json:"expires,omitempty"` - Identifiers []Identifier `json:"identifiers"` - NotBefore string `json:"notBefore,omitempty"` - NotAfter string `json:"notAfter,omitempty"` - Error interface{} `json:"error,omitempty"` - Authorizations []string `json:"authorizations"` - Finalize string `json:"finalize"` - Certificate string `json:"certificate,omitempty"` - ID string `json:"-"` + ID string `json:"id"` + AccountID string `json:"-"` + ProvisionerID string `json:"-"` + Status Status `json:"status"` + ExpiresAt time.Time `json:"expires"` + Identifiers []Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore"` + NotAfter time.Time `json:"notAfter"` + Error *Error `json:"error,omitempty"` + AuthorizationIDs []string `json:"-"` + AuthorizationURLs []string `json:"authorizations"` + FinalizeURL string `json:"finalize"` + CertificateID string `json:"-"` + CertificateURL string `json:"certificate,omitempty"` } // ToLog enables response logging. func (o *Order) ToLog() (interface{}, error) { b, err := json.Marshal(o) if err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error marshaling order for logging")) + return nil, WrapErrorISE(err, "error marshaling order for logging") } return string(b), nil } -// GetID returns the Order ID. -func (o *Order) GetID() string { - return o.ID -} - -// OrderOptions options with which to create a new Order. -type OrderOptions struct { - AccountID string `json:"accID"` - Identifiers []Identifier `json:"identifiers"` - NotBefore time.Time `json:"notBefore"` - NotAfter time.Time `json:"notAfter"` - backdate time.Duration - defaultDuration time.Duration -} - -type order struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Created time.Time `json:"created"` - Expires time.Time `json:"expires,omitempty"` - Status string `json:"status"` - Identifiers []Identifier `json:"identifiers"` - NotBefore time.Time `json:"notBefore,omitempty"` - NotAfter time.Time `json:"notAfter,omitempty"` - Error *Error `json:"error,omitempty"` - Authorizations []string `json:"authorizations"` - Certificate string `json:"certificate,omitempty"` -} - -// newOrder returns a new Order type. -func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { - id, err := randID() - if err != nil { - return nil, err - } - - authzs := make([]string, len(ops.Identifiers)) - for i, identifier := range ops.Identifiers { - az, err := newAuthz(db, ops.AccountID, identifier) - if err != nil { - return nil, err - } - authzs[i] = az.getID() - } - +// UpdateStatus updates the ACME Order Status if necessary. +// Changes to the order are saved using the database interface. +func (o *Order) UpdateStatus(ctx context.Context, db DB) error { now := clock.Now() - var backdate time.Duration - nbf := ops.NotBefore - if nbf.IsZero() { - nbf = now - backdate = -1 * ops.backdate - } - naf := ops.NotAfter - if naf.IsZero() { - naf = nbf.Add(ops.defaultDuration) - } - o := &order{ - ID: id, - AccountID: ops.AccountID, - Created: now, - Status: StatusPending, - Expires: now.Add(defaultOrderExpiry), - Identifiers: ops.Identifiers, - NotBefore: nbf.Add(backdate), - NotAfter: naf, - Authorizations: authzs, - } - if err := o.save(db, nil); err != nil { - return nil, err - } - - var oidHelper = orderIDsByAccount{} - _, err = oidHelper.addOrderID(db, ops.AccountID, o.ID) - if err != nil { - return nil, err - } - return o, nil -} - -type orderIDsByAccount struct{} - -// addOrderID adds an order ID to a users index of in progress order IDs. -// This method will also cull any orders that are no longer in the `pending` -// state from the index before returning it. -func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) ([]string, error) { - ordersByAccountMux.Lock() - defer ordersByAccountMux.Unlock() - - // Update the "order IDs by account ID" index - oids, err := oiba.unsafeGetOrderIDsByAccount(db, accID) - if err != nil { - return nil, err - } - newOids := append(oids, oid) - if err = orderIDs(newOids).save(db, oids, accID); err != nil { - // Delete the entire order if storing the index fails. - db.Del(orderTable, []byte(oid)) - return nil, err - } - return newOids, nil -} - -// unsafeGetOrderIDsByAccount retrieves a list of Order IDs that were created by the -// account. -func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) { - b, err := db.Get(ordersByAccountIDTable, []byte(accID)) - if err != nil { - if nosql.IsErrNotFound(err) { - return []string{}, nil - } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) - } - var oids []string - if err := json.Unmarshal(b, &oids); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) - } - - // Remove any order that is not in PENDING state and update the stored list - // before returning. - // - // According to RFC 8555: - // The server SHOULD include pending orders and SHOULD NOT include orders - // that are invalid in the array of URLs. - pendOids := []string{} - for _, oid := range oids { - o, err := getOrder(db, oid) - if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) - } - if o, err = o.updateStatus(db); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) - } - if o.Status == StatusPending { - pendOids = append(pendOids, oid) - } - } - // If the number of pending orders is less than the number of orders in the - // list, then update the pending order list. - if len(pendOids) != len(oids) { - if err = orderIDs(pendOids).save(db, oids, accID); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ - "len(orderIDs) = %d", len(pendOids))) - } - } - - return pendOids, nil -} - -type orderIDs []string - -// save is used to update the list of orderIDs keyed by ACME account ID -// stored in the database. -// -// This method always converts empty lists to 'nil' when storing to the DB. We -// do this to avoid any confusion between an empty list and a nil value in the -// db. -func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error { - var ( - err error - oldb []byte - newb []byte - ) - if len(old) == 0 { - oldb = nil - } else { - oldb, err = json.Marshal(old) - if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice")) - } - } - if len(oids) == 0 { - newb = nil - } else { - newb, err = json.Marshal(oids) - if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice")) - } - } - _, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb) - switch { - case err != nil: - return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID)) - case !swapped: - return ServerInternalErr(errors.Errorf("error storing order IDs "+ - "for account %s; order IDs changed since last read", accID)) - default: - return nil - } -} - -func (o *order) save(db nosql.DB, old *order) error { - var ( - err error - oldB []byte - ) - if old == nil { - oldB = nil - } else { - if oldB, err = json.Marshal(old); err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling old acme order")) - } - } - - newB, err := json.Marshal(o) - if err != nil { - return ServerInternalErr(errors.Wrap(err, "error marshaling new acme order")) - } - - _, swapped, err := db.CmpAndSwap(orderTable, []byte(o.ID), oldB, newB) - switch { - case err != nil: - return ServerInternalErr(errors.Wrap(err, "error storing order")) - case !swapped: - return ServerInternalErr(errors.New("error storing order; " + - "value has changed since last read")) - default: - return nil - } -} - -// updateStatus updates order status if necessary. -func (o *order) updateStatus(db nosql.DB) (*order, error) { - _newOrder := *o - newOrder := &_newOrder - - now := time.Now().UTC() switch o.Status { case StatusInvalid: - return o, nil + return nil case StatusValid: - return o, nil + return nil case StatusReady: - // check expiry - if now.After(o.Expires) { - newOrder.Status = StatusInvalid - newOrder.Error = MalformedErr(errors.New("order has expired")) + // Check expiry + if now.After(o.ExpiresAt) { + o.Status = StatusInvalid + o.Error = NewError(ErrorMalformedType, "order has expired") break } - return o, nil + return nil case StatusPending: - // check expiry - if now.After(o.Expires) { - newOrder.Status = StatusInvalid - newOrder.Error = MalformedErr(errors.New("order has expired")) + // Check expiry + if now.After(o.ExpiresAt) { + o.Status = StatusInvalid + o.Error = NewError(ErrorMalformedType, "order has expired") break } - var count = map[string]int{ + var count = map[Status]int{ StatusValid: 0, StatusInvalid: 0, StatusPending: 0, } - for _, azID := range o.Authorizations { - az, err := getAuthz(db, azID) + for _, azID := range o.AuthorizationIDs { + az, err := db.GetAuthorization(ctx, azID) if err != nil { - return nil, err + return WrapErrorISE(err, "error getting authorization ID %s", azID) } - if az, err = az.updateStatus(db); err != nil { - return nil, err + if err = az.UpdateStatus(ctx, db); err != nil { + return WrapErrorISE(err, "error updating authorization ID %s", azID) } - st := az.getStatus() + st := az.Status count[st]++ } switch { case count[StatusInvalid] > 0: - newOrder.Status = StatusInvalid + o.Status = StatusInvalid // No change in the order status, so just return the order as is - // without writing any changes. case count[StatusPending] > 0: - return newOrder, nil + return nil - case count[StatusValid] == len(o.Authorizations): - newOrder.Status = StatusReady + case count[StatusValid] == len(o.AuthorizationIDs): + o.Status = StatusReady default: - return nil, ServerInternalErr(errors.New("unexpected authz status")) + return NewErrorISE("unexpected authz status") } default: - return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status)) + return NewErrorISE("unrecognized order status: %s", o.Status) } - - if err := newOrder.save(db, o); err != nil { - return nil, err + if err := db.UpdateOrder(ctx, o); err != nil { + return WrapErrorISE(err, "error updating order") } - return newOrder, nil + return nil } -// finalize signs a certificate if the necessary conditions for Order completion +// Finalize signs a certificate if the necessary conditions for Order completion // have been met. -func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) (*order, error) { - var err error - if o, err = o.updateStatus(db); err != nil { - return nil, err +func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth CertificateAuthority, p Provisioner) error { + if err := o.UpdateStatus(ctx, db); err != nil { + return err } switch o.Status { case StatusInvalid: - return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)) + return NewError(ErrorOrderNotReadyType, "order %s has been abandoned", o.ID) case StatusValid: - return o, nil + return nil case StatusPending: - return nil, OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)) + return NewError(ErrorOrderNotReadyType, "order %s is not ready", o.ID) case StatusReady: break default: - return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID)) + return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID) } // RFC8555: The CSR MUST indicate the exact same set of requested @@ -361,12 +139,12 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut if csr.Subject.CommonName != "" { csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) } - csr.DNSNames = uniqueLowerNames(csr.DNSNames) + csr.DNSNames = uniqueSortedLowerNames(csr.DNSNames) orderNames := make([]string, len(o.Identifiers)) for i, n := range o.Identifiers { orderNames[i] = n.Value } - orderNames = uniqueLowerNames(orderNames) + orderNames = uniqueSortedLowerNames(orderNames) // Validate identifier names against CSR alternative names. // @@ -374,13 +152,15 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut // absence of other SANs as they will only be set if the templates allows // them. if len(csr.DNSNames) != len(orderNames) { - return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) + return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", csr.DNSNames, orderNames) } sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames)) for i := range csr.DNSNames { if csr.DNSNames[i] != orderNames[i] { - return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) + return NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", csr.DNSNames, orderNames) } sans[i] = x509util.SubjectAlternativeName{ Type: x509util.DNSType, @@ -389,10 +169,10 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut } // Get authorizations from the ACME provisioner. - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) + ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOps, err := p.AuthorizeSign(ctx, "") if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner")) + return WrapErrorISE(err, "error retrieving authorization options from ACME provisioner") } // Template data @@ -402,82 +182,41 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error creating template options from ACME provisioner")) + return WrapErrorISE(err, "error creating template options from ACME provisioner") } signOps = append(signOps, templateOptions) - // Create and store a new certificate. + // Sign a new certificate. certChain, err := auth.Sign(csr, provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(o.NotBefore), NotAfter: provisioner.NewTimeDuration(o.NotAfter), }, signOps...) if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID)) + return WrapErrorISE(err, "error signing certificate for order %s", o.ID) } - cert, err := newCert(db, CertOptions{ + cert := &Certificate{ AccountID: o.AccountID, OrderID: o.ID, Leaf: certChain[0], Intermediates: certChain[1:], - }) - if err != nil { - return nil, err + } + if err := db.CreateCertificate(ctx, cert); err != nil { + return WrapErrorISE(err, "error creating certificate for order %s", o.ID) } - _newOrder := *o - newOrder := &_newOrder - newOrder.Certificate = cert.ID - newOrder.Status = StatusValid - if err := newOrder.save(db, o); err != nil { - return nil, err + o.CertificateID = cert.ID + o.Status = StatusValid + if err = db.UpdateOrder(ctx, o); err != nil { + return WrapErrorISE(err, "error updating order %s", o.ID) } - return newOrder, nil + return nil } -// getOrder retrieves and unmarshals an ACME Order type from the database. -func getOrder(db nosql.DB, id string) (*order, error) { - b, err := db.Get(orderTable, []byte(id)) - if nosql.IsErrNotFound(err) { - return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id)) - } else if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id)) - } - var o order - if err := json.Unmarshal(b, &o); err != nil { - return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order")) - } - return &o, nil -} - -// toACME converts the internal Order type into the public acmeOrder type for -// presentation in the ACME protocol. -func (o *order) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Order, error) { - azs := make([]string, len(o.Authorizations)) - for i, aid := range o.Authorizations { - azs[i] = dir.getLink(ctx, AuthzLink, true, aid) - } - ao := &Order{ - Status: o.Status, - Expires: o.Expires.Format(time.RFC3339), - Identifiers: o.Identifiers, - NotBefore: o.NotBefore.Format(time.RFC3339), - NotAfter: o.NotAfter.Format(time.RFC3339), - Authorizations: azs, - Finalize: dir.getLink(ctx, FinalizeLink, true, o.ID), - ID: o.ID, - } - - if o.Certificate != "" { - ao.Certificate = dir.getLink(ctx, CertificateLink, true, o.Certificate) - } - return ao, nil -} - -// uniqueLowerNames returns the set of all unique names in the input after all +// uniqueSortedLowerNames returns the set of all unique names in the input after all // of them are lowercased. The returned names will be in their lowercased form // and sorted alphabetically. -func uniqueLowerNames(names []string) (unique []string) { +func uniqueSortedLowerNames(names []string) (unique []string) { nameMap := make(map[string]int, len(names)) for _, name := range names { nameMap[strings.ToLower(name)] = 1 diff --git a/acme/order_test.go b/acme/order_test.go index e6a8f057..993a92f2 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -5,865 +5,232 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/json" - "fmt" - "net" - "net/url" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" ) -var certDuration = 6 * time.Hour - -func defaultOrderOps() OrderOptions { - return OrderOptions{ - AccountID: "accID", - Identifiers: []Identifier{ - {Type: "dns", Value: "acme.example.com"}, - {Type: "dns", Value: "step.example.com"}, - }, - NotBefore: clock.Now(), - NotAfter: clock.Now().Add(certDuration), - } -} - -func newO() (*order, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - } - return newOrder(mockdb, defaultOrderOps()) -} - -func Test_getOrder(t *testing.T) { +func TestOrder_UpdateStatus(t *testing.T) { type test struct { - id string - db nosql.DB - o *order + o *Order err *Error + db DB } tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) + "ok/already-invalid": func(t *testing.T) test { + o := &Order{ + Status: StatusInvalid, + } return test{ - o: o, - id: o.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("order %s not found: not found", o.ID)), + o: o, } }, - "fail/db-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) + "ok/already-valid": func(t *testing.T) test { + o := &Order{ + Status: StatusInvalid, + } return test{ - o: o, - id: o.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error loading order %s: force", o.ID)), + o: o, } }, - "fail/unmarshal-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{ - o: o, - id: o.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return nil, nil - }, - }, - err: ServerInternalErr(errors.New("error unmarshaling order: unexpected end of JSON input")), + "fail/error-unexpected-status": func(t *testing.T) test { + o := &Order{ + Status: "foo", } - }, - "ok": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - return test{ - o: o, - id: o.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(o.ID)) - return b, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if o, err := getOrder(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.o.ID, o.ID) - assert.Equals(t, tc.o.AccountID, o.AccountID) - assert.Equals(t, tc.o.Status, o.Status) - assert.Equals(t, tc.o.Identifiers, o.Identifiers) - assert.Equals(t, tc.o.Created, o.Created) - assert.Equals(t, tc.o.Expires, o.Expires) - assert.Equals(t, tc.o.Authorizations, o.Authorizations) - assert.Equals(t, tc.o.NotBefore, o.NotBefore) - assert.Equals(t, tc.o.NotAfter, o.NotAfter) - assert.Equals(t, tc.o.Certificate, o.Certificate) - assert.Equals(t, tc.o.Error, o.Error) - } - } - }) - } -} - -func TestOrderToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - - type test struct { - o *order - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/no-cert": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{o: o} - }, - "ok/cert": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusValid - o.Certificate = "cert-id" - return test{o: o} - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - acmeOrder, err := tc.o.toACME(ctx, nil, dir) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acmeOrder.ID, tc.o.ID) - assert.Equals(t, acmeOrder.Status, tc.o.Status) - assert.Equals(t, acmeOrder.Identifiers, tc.o.Identifiers) - assert.Equals(t, acmeOrder.Finalize, - fmt.Sprintf("%s/acme/%s/order/%s/finalize", baseURL.String(), provName, tc.o.ID)) - if tc.o.Certificate != "" { - assert.Equals(t, acmeOrder.Certificate, fmt.Sprintf("%s/acme/%s/certificate/%s", baseURL.String(), provName, tc.o.Certificate)) - } - - expiry, err := time.Parse(time.RFC3339, acmeOrder.Expires) - assert.FatalError(t, err) - assert.Equals(t, expiry.String(), tc.o.Expires.String()) - nbf, err := time.Parse(time.RFC3339, acmeOrder.NotBefore) - assert.FatalError(t, err) - assert.Equals(t, nbf.String(), tc.o.NotBefore.String()) - naf, err := time.Parse(time.RFC3339, acmeOrder.NotAfter) - assert.FatalError(t, err) - assert.Equals(t, naf.String(), tc.o.NotAfter.String()) - } - } - }) - } -} - -func TestOrderSave(t *testing.T) { - type test struct { - o, old *order - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) return test{ o: o, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing order: force")), + err: NewErrorISE("unrecognized order status: %s", o.Status), } }, - "fail/old-nil/swap-false": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{ - o: o, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.New("error storing order; value has changed since last read")), + "ok/ready-expired": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(-5 * time.Minute), } - }, - "ok/old-nil": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) return test{ - o: o, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, orderTable) - assert.Equals(t, []byte(o.ID), key) - return nil, true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldo, err := newO() - assert.FatalError(t, err) - o, err := newO() - assert.FatalError(t, err) - - oldb, err := json.Marshal(oldo) - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - return test{ - o: o, - old: oldo, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, orderTable) - assert.Equals(t, []byte(o.ID), key) - return []byte("foo"), true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.o.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func Test_newOrder(t *testing.T) { - type test struct { - ops OrderOptions - db nosql.DB - err *Error - authzs *([]string) - } - tests := map[string]func(t *testing.T) test{ - "fail/unexpected-identifier-type": func(t *testing.T) test { - ops := defaultOrderOps() - ops.Identifiers[0].Type = "foo" - return test{ - ops: ops, - err: MalformedErr(errors.New("unexpected authz type foo")), - } - }, - "fail/save-order-error": func(t *testing.T) test { - count := 0 - return test{ - ops: defaultOrderOps(), - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count >= 8 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error storing order: force")), - } - }, - "fail/get-orderIDs-error": func(t *testing.T) test { - count := 0 - ops := defaultOrderOps() - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count >= 9 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error loading orderIDs for account %s: force", ops.AccountID)), - } - }, - "fail/save-orderIDs-error": func(t *testing.T) test { - count := 0 - var ( - _oid = "" - oid = &_oid - ) - ops := defaultOrderOps() - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count >= 9 { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(ops.AccountID)) - return nil, false, errors.New("force") - } else if count == 8 { - *oid = string(key) - } - count++ - return nil, true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - MDel: func(bucket, key []byte) error { - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(*oid)) + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) return nil }, }, - err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s: force", ops.AccountID)), } }, - "ok": func(t *testing.T) test { - count := 0 - authzs := &([]string{}) - var ( - _oid = "" - oid = &_oid - ) - ops := defaultOrderOps() + "fail/ready-expired-db.UpdateOrder-error": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(-5 * time.Minute), + } return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count >= 9 { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(ops.AccountID)) - assert.Equals(t, old, nil) - newB, err := json.Marshal([]string{*oid}) - assert.FatalError(t, err) - assert.Equals(t, newval, newB) - } else if count == 8 { - *oid = string(key) - } else if count == 7 { - *authzs = append(*authzs, string(key)) - } else if count == 3 { - *authzs = []string{string(key)} - } - count++ - return nil, true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return errors.New("force") }, }, - authzs: authzs, + err: NewErrorISE("error updating order: force"), } }, - "ok/validity-bounds-not-set": func(t *testing.T) test { - count := 0 - authzs := &([]string{}) - var ( - _oid = "" - oid = &_oid - ) - ops := defaultOrderOps() - ops.backdate = time.Minute - ops.defaultDuration = 12 * time.Hour - ops.NotBefore = time.Time{} - ops.NotAfter = time.Time{} + "ok/pending-expired": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(-5 * time.Minute), + } return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count >= 9 { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(ops.AccountID)) - assert.Equals(t, old, nil) - newB, err := json.Marshal([]string{*oid}) - assert.FatalError(t, err) - assert.Equals(t, newval, newB) - } else if count == 8 { - *oid = string(key) - } else if count == 7 { - *authzs = append(*authzs, string(key)) - } else if count == 3 { - *authzs = []string{string(key)} - } - count++ - return nil, true, nil - }, - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - authzs: authzs, - } - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - o, err := newOrder(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, o.AccountID, tc.ops.AccountID) - assert.Equals(t, o.Status, StatusPending) - assert.Equals(t, o.Identifiers, tc.ops.Identifiers) - assert.Equals(t, o.Error, nil) - assert.Equals(t, o.Certificate, "") - assert.Equals(t, o.Authorizations, *tc.authzs) + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) - assert.True(t, o.Created.Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, o.Created.After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := o.Created.Add(defaultExpiryDuration) - assert.True(t, o.Expires.Before(expiry.Add(time.Minute))) - assert.True(t, o.Expires.After(expiry.Add(-1*time.Minute))) - - nbf := tc.ops.NotBefore - now := time.Now().UTC() - if !tc.ops.NotBefore.IsZero() { - assert.Equals(t, o.NotBefore, tc.ops.NotBefore) - } else { - nbf = o.NotBefore.Add(tc.ops.backdate) - assert.True(t, o.NotBefore.Before(now.Add(-tc.ops.backdate+time.Second))) - assert.True(t, o.NotBefore.Add(tc.ops.backdate+2*time.Second).After(now)) - } - if !tc.ops.NotAfter.IsZero() { - assert.Equals(t, o.NotAfter, tc.ops.NotAfter) - } else { - naf := nbf.Add(tc.ops.defaultDuration) - assert.Equals(t, o.NotAfter, naf) - } - } - } - }) - } -} - -func TestOrderIDs_save(t *testing.T) { - accID := "acc-id" - newOids := func() orderIDs { - return []string{"1", "2"} - } - type test struct { - oids, old orderIDs - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - return test{ - oids: newOids(), - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s: force", accID)), - } - }, - "fail/old-nil/swap-false": func(t *testing.T) test { - return test{ - oids: newOids(), - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s; order IDs changed since last read", accID)), - } - }, - "ok/old-nil": func(t *testing.T) test { - oids := newOids() - b, err := json.Marshal(oids) - assert.FatalError(t, err) - return test{ - oids: oids, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(accID)) - return nil, true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldOids := newOids() - oids := append(oldOids, "3") - - oldb, err := json.Marshal(oldOids) - assert.FatalError(t, err) - b, err := json.Marshal(oids) - assert.FatalError(t, err) - return test{ - oids: oids, - old: oldOids, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(accID)) - return nil, true, nil - }, - }, - } - }, - "ok/new-empty-saved-as-nil": func(t *testing.T) test { - oldOids := newOids() - oids := []string{} - - oldb, err := json.Marshal(oldOids) - assert.FatalError(t, err) - return test{ - oids: oids, - old: oldOids, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, newval, nil) - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(accID)) - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.oids.save(tc.db, tc.old, accID); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestOrderUpdateStatus(t *testing.T) { - type test struct { - o, res *order - err *Error - db nosql.DB - } - tests := map[string]func(t *testing.T) test{ - "fail/already-invalid": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusInvalid - return test{ - o: o, - res: o, - } - }, - "fail/already-valid": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusValid - return test{ - o: o, - res: o, - } - }, - "fail/unexpected-status": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusDeactivated - return test{ - o: o, - res: o, - err: ServerInternalErr(errors.New("unrecognized order status: deactivated")), - } - }, - "fail/save-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Expires = time.Now().UTC().Add(-time.Minute) - return test{ - o: o, - res: o, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing order: force")), - } - }, - "ok/expired": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Expires = time.Now().UTC().Add(-time.Minute) - - _o := *o - clone := &_o - clone.Error = MalformedErr(errors.New("order has expired")) - clone.Status = StatusInvalid - return test{ - o: o, - res: clone, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - }, - } - }, - "fail/get-authz-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - return test{ - o: o, - res: o, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading authz")), - } - }, - "ok/still-pending": func(t *testing.T) test { - az1, err := newAz() - assert.FatalError(t, err) - az2, err := newAz() - assert.FatalError(t, err) - az3, err := newAz() - assert.FatalError(t, err) - - ch1, err := newHTTPCh() - assert.FatalError(t, err) - ch2, err := newTLSALPNCh() - assert.FatalError(t, err) - ch3, err := newDNSCh() - assert.FatalError(t, err) - - ch1b, err := json.Marshal(ch1) - assert.FatalError(t, err) - ch2b, err := json.Marshal(ch2) - assert.FatalError(t, err) - ch3b, err := json.Marshal(ch3) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()} - - _az3, ok := az3.(*dnsAuthz) - assert.Fatal(t, ok) - _az3.baseAuthz.Status = StatusValid - - b1, err := json.Marshal(az1) - assert.FatalError(t, err) - b2, err := json.Marshal(az2) - assert.FatalError(t, err) - b3, err := json.Marshal(az3) - assert.FatalError(t, err) - - count := 0 - return test{ - o: o, - res: o, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - ret = b1 - case 1: - ret = ch1b - case 2: - ret = ch2b - case 3: - ret = ch3b - case 4: - ret = b2 - case 5: - ret = ch1b - case 6: - ret = ch2b - case 7: - ret = ch3b - case 8: - ret = b3 - default: - return nil, errors.New("unexpected count") - } - count++ - return ret, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil + err := NewError(ErrorMalformedType, "order has expired") + assert.HasPrefix(t, updo.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updo.Error.Type, err.Type) + assert.Equals(t, updo.Error.Detail, err.Detail) + assert.Equals(t, updo.Error.Status, err.Status) + assert.Equals(t, updo.Error.Detail, err.Detail) + return nil }, }, } }, "ok/invalid": func(t *testing.T) test { - az1, err := newAz() - assert.FatalError(t, err) - az2, err := newAz() - assert.FatalError(t, err) - az3, err := newAz() - assert.FatalError(t, err) + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusInvalid, + } - ch1, err := newHTTPCh() - assert.FatalError(t, err) - ch2, err := newTLSALPNCh() - assert.FatalError(t, err) - ch3, err := newDNSCh() - assert.FatalError(t, err) - - ch1b, err := json.Marshal(ch1) - assert.FatalError(t, err) - ch2b, err := json.Marshal(ch2) - assert.FatalError(t, err) - ch3b, err := json.Marshal(ch3) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()} - - _az3, ok := az3.(*dnsAuthz) - assert.Fatal(t, ok) - _az3.baseAuthz.Status = StatusInvalid - - b1, err := json.Marshal(az1) - assert.FatalError(t, err) - b2, err := json.Marshal(az2) - assert.FatalError(t, err) - b3, err := json.Marshal(az3) - assert.FatalError(t, err) - - _o := *o - clone := &_o - clone.Status = StatusInvalid - - count := 0 return test{ - o: o, - res: clone, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - ret = b1 - case 1: - ret = ch1b - case 2: - ret = ch2b - case 3: - ret = ch3b - case 4: - ret = b2 - case 5: - ret = ch1b - case 6: - ret = ch2b - case 7: - ret = ch3b - case 8: - ret = b3 - default: - return nil, errors.New("unexpected count") - } - count++ - return ret, nil + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return nil }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") + } + }, + }, + } + }, + "ok/still-pending": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusPending, + } + + return test{ + o: o, + db: &MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") + } + }, + }, + } + }, + "ok/valid": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusValid, + } + + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusReady) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return nil + }, + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") + } }, }, } @@ -872,25 +239,24 @@ func TestOrderUpdateStatus(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - o, err := tc.o.updateStatus(tc.db) - if err != nil { + if err := tc.o.UpdateStatus(context.Background(), tc.db); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { - if assert.Nil(t, tc.err) { - expB, err := json.Marshal(tc.res) - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - assert.Equals(t, expB, b) - } + assert.Nil(t, tc.err) } }) + } } @@ -917,820 +283,456 @@ func (m *mockSignAuth) LoadProvisionerByID(id string) (provisioner.Interface, er return m.ret1.(provisioner.Interface), m.err } -func TestOrderFinalize(t *testing.T) { - prov := newProv() +func TestOrder_Finalize(t *testing.T) { type test struct { - o, res *order - err *Error - db nosql.DB - csr *x509.CertificateRequest - sa SignAuthority - prov Provisioner + o *Order + err *Error + db DB + ca CertificateAuthority + csr *x509.CertificateRequest + prov Provisioner } tests := map[string]func(t *testing.T) test{ - "fail/already-invalid": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusInvalid + "fail/invalid": func(t *testing.T) test { + o := &Order{ + ID: "oid", + Status: StatusInvalid, + } return test{ o: o, - err: OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)), + err: NewError(ErrorOrderNotReadyType, "order %s has been abandoned", o.ID), + } + }, + "fail/pending": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + } + + return test{ + o: o, + db: &MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") + } + }, + }, + err: NewError(ErrorOrderNotReadyType, "order %s is not ready", o.ID), } }, "ok/already-valid": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusValid - o.Certificate = "cert-id" + o := &Order{ + ID: "oid", + Status: StatusValid, + } return test{ - o: o, - res: o, + o: o, } }, - "fail/still-pending": func(t *testing.T) test { - az1, err := newAz() - assert.FatalError(t, err) - az2, err := newAz() - assert.FatalError(t, err) - az3, err := newAz() - assert.FatalError(t, err) + "fail/error-unexpected-status": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: "foo", + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } - ch1, err := newHTTPCh() - assert.FatalError(t, err) - ch2, err := newTLSALPNCh() - assert.FatalError(t, err) - ch3, err := newDNSCh() - assert.FatalError(t, err) - - ch1b, err := json.Marshal(ch1) - assert.FatalError(t, err) - ch2b, err := json.Marshal(ch2) - assert.FatalError(t, err) - ch3b, err := json.Marshal(ch3) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()} - - _az3, ok := az3.(*dnsAuthz) - assert.Fatal(t, ok) - _az3.baseAuthz.Status = StatusValid - - b1, err := json.Marshal(az1) - assert.FatalError(t, err) - b2, err := json.Marshal(az2) - assert.FatalError(t, err) - b3, err := json.Marshal(az3) - assert.FatalError(t, err) - - count := 0 return test{ o: o, - res: o, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - ret = b1 - case 1: - ret = ch1b - case 2: - ret = ch2b - case 3: - ret = ch3b - case 4: - ret = b2 - case 5: - ret = ch1b - case 6: - ret = ch2b - case 7: - ret = ch3b - case 8: - ret = b3 - default: - return nil, errors.New("unexpected count") - } - count++ - return ret, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, + err: NewErrorISE("unrecognized order status: %s", o.Status), + } + }, + "fail/error-names-length-mismatch": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, }, - err: OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)), } - }, - "fail/ready/csr-names-match-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - + orderNames := []string{"bar.internal", "foo.internal"} csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "acme.example.com", + CommonName: "foo.internal", }, - DNSNames: []string{"acme.example.com", "fail.smallstep.com"}, } + return test{ o: o, csr: csr, - err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), + err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", []string{"foo.internal"}, orderNames), } }, - "fail/ready/csr-names-match-error-2": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - + "fail/error-names-mismatch": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + }, + } + orderNames := []string{"bar.internal", "foo.internal"} csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "", + CommonName: "foo.internal", }, - DNSNames: []string{"acme.example.com"}, + DNSNames: []string{"zap.internal"}, } + return test{ o: o, csr: csr, - err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), + err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ + "CSR names = %v, Order names = %v", []string{"foo.internal", "zap.internal"}, orderNames), } }, - "fail/ready/no-ipAddresses": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - + "fail/error-provisioner-auth": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + }, + } csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "", + CommonName: "foo.internal", }, - // DNSNames: []string{"acme.example.com", "step.example.com"}, - IPAddresses: []net.IP{net.ParseIP("1.1.1.1")}, + DNSNames: []string{"bar.internal"}, } + return test{ o: o, csr: csr, - err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), - } - }, - "fail/ready/no-emailAddresses": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "", - }, - // DNSNames: []string{"acme.example.com", "step.example.com"}, - EmailAddresses: []string{"max@smallstep.com", "mariano@smallstep.com"}, - } - return test{ - o: o, - csr: csr, - err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), - } - }, - "fail/ready/no-URIs": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - u, err := url.Parse("https://google.com") - assert.FatalError(t, err) - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "", - }, - // DNSNames: []string{"acme.example.com", "step.example.com"}, - URIs: []*url.URL{u}, - } - return test{ - o: o, - csr: csr, - err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")), - } - }, - "fail/ready/provisioner-auth-sign-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - DNSNames: []string{"step.example.com", "acme.example.com"}, - } - return test{ - o: o, - csr: csr, - err: ServerInternalErr(errors.New("error retrieving authorization options from ACME provisioner: force")), prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") return nil, errors.New("force") }, }, + err: NewErrorISE("error retrieving authorization options from ACME provisioner: force"), } }, - "fail/ready/sign-cert-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - + "fail/error-template-options": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + }, + } csr := &x509.CertificateRequest{ Subject: pkix.Name{ - CommonName: "acme.example.com", + CommonName: "foo.internal", }, - DNSNames: []string{"step.example.com", "acme.example.com"}, + DNSNames: []string{"bar.internal"}, } + return test{ o: o, csr: csr, - err: ServerInternalErr(errors.Errorf("error generating certificate for order %s: force", o.ID)), - sa: &mockSignAuth{ - err: errors.New("force"), - }, - } - }, - "fail/ready/store-cert-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - DNSNames: []string{"step.example.com", "acme.example.com"}, - } - crt := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - } - inter := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "intermediate", - }, - } - return test{ - o: o, - csr: csr, - err: ServerInternalErr(errors.Errorf("error storing certificate: force")), - sa: &mockSignAuth{ - ret1: crt, ret2: inter, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - } - }, - "fail/ready/store-order-error": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - DNSNames: []string{"acme.example.com", "step.example.com"}, - } - crt := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - } - inter := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "intermediate", - }, - } - count := 0 - return test{ - o: o, - csr: csr, - err: ServerInternalErr(errors.Errorf("error storing order: force")), - sa: &mockSignAuth{ - ret1: crt, ret2: inter, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 1 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - } - }, - "ok/ready/sign": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - DNSNames: []string{"acme.example.com", "step.example.com"}, - } - crt := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - } - inter := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "intermediate", - }, - } - - _o := *o - clone := &_o - clone.Status = StatusValid - - count := 0 - return test{ - o: o, - res: clone, - csr: csr, - sa: &mockSignAuth{ - sign: func(csr *x509.CertificateRequest, pops provisioner.SignOptions, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) { - assert.Equals(t, len(signOps), 6) - return []*x509.Certificate{crt, inter}, nil - }, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - clone.Certificate = string(key) - } - count++ - return nil, true, nil - }, - }, - } - }, - "ok/ready/no-sans": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - o.Identifiers = []Identifier{ - {Type: "dns", Value: "step.example.com"}, - } - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "step.example.com", - }, - } - crt := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "step.example.com", - }, - DNSNames: []string{"step.example.com"}, - } - inter := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "intermediate", - }, - } - - clone := *o - clone.Status = StatusValid - count := 0 - return test{ - o: o, - res: &clone, - csr: csr, - sa: &mockSignAuth{ - sign: func(csr *x509.CertificateRequest, pops provisioner.SignOptions, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) { - assert.Equals(t, len(signOps), 6) - return []*x509.Certificate{crt, inter}, nil - }, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - clone.Certificate = string(key) - } - count++ - return nil, true, nil - }, - }, - } - }, - "ok/ready/sans-and-name": func(t *testing.T) test { - o, err := newO() - assert.FatalError(t, err) - o.Status = StatusReady - - csr := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - DNSNames: []string{"step.example.com"}, - } - crt := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "acme.example.com", - }, - DNSNames: []string{"acme.example.com", "step.example.com"}, - } - inter := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "intermediate", - }, - } - - clone := *o - clone.Status = StatusValid - count := 0 - return test{ - o: o, - res: &clone, - csr: csr, - sa: &mockSignAuth{ - sign: func(csr *x509.CertificateRequest, pops provisioner.SignOptions, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) { - assert.Equals(t, len(signOps), 6) - return []*x509.Certificate{crt, inter}, nil - }, - }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - clone.Certificate = string(key) - } - count++ - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - p := tc.prov - if p == nil { - p = prov - } - o, err := tc.o.finalize(tc.db, tc.csr, tc.sa, p) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - expB, err := json.Marshal(tc.res) - assert.FatalError(t, err) - b, err := json.Marshal(o) - assert.FatalError(t, err) - assert.Equals(t, expB, b) - } - } - }) - } -} - -func Test_getOrderIDsByAccount(t *testing.T) { - type test struct { - id string - db nosql.DB - res []string - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/not-found": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - res: []string{}, - } - }, - "fail/db-error": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") return nil, nil }, - }, - err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")), - } - }, - "fail/error-loading-order-from-order-IDs": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - dbHit := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - dbHit++ - switch dbHit { - case 1: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return boids, nil - case 2: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte("o1")) - return nil, errors.New("force") - default: - assert.FatalError(t, errors.New("should not be here")) - return nil, nil + MgetOptions: func() *provisioner.Options { + return &provisioner.Options{ + X509: &provisioner.X509Options{ + TemplateData: json.RawMessage([]byte("fo{o")), + }, } }, }, - err: ServerInternalErr(errors.New("error loading order o1 for account foo: error loading order o1: force")), + err: NewErrorISE("error creating template options from ACME provisioner: error unmarshaling template data: invalid character 'o' in literal false (expecting 'a')"), } }, - "fail/error-updating-order-from-order-IDs": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) + "fail/error-ca-sign": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + }, + } + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo.internal", + }, + DNSNames: []string{"bar.internal"}, + } - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - dbHit := 0 return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - dbHit++ - switch dbHit { - case 1: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return boids, nil - case 2: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte("o1")) - return bo, nil - case 3: - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(o.Authorizations[0])) - return nil, errors.New("force") - default: - assert.FatalError(t, errors.New("should not be here")) - return nil, nil - } + o: o, + csr: csr, + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil + }, + MgetOptions: func() *provisioner.Options { + return nil }, }, - err: ServerInternalErr(errors.Errorf("error updating order o1 for account foo: error loading authz %s: force", o.Authorizations[0])), + ca: &mockSignAuth{ + sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + assert.Equals(t, _csr, csr) + return nil, errors.New("force") + }, + }, + err: NewErrorISE("error signing certificate for order oID: force"), } }, - "ok/no-change-to-pending-orders": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) + "fail/error-db.CreateCertificate": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + }, + } + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo.internal", + }, + DNSNames: []string{"bar.internal"}, + } - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) + foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} + bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} + baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } + o: o, + csr: csr, + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("should not be attempting to store anything") + MgetOptions: func() *provisioner.Options { + return nil }, }, - res: oids, + ca: &mockSignAuth{ + sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + assert.Equals(t, _csr, csr) + return []*x509.Certificate{foo, bar, baz}, nil + }, + }, + db: &MockDB{ + MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { + assert.Equals(t, cert.AccountID, o.AccountID) + assert.Equals(t, cert.OrderID, o.ID) + assert.Equals(t, cert.Leaf, foo) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) + return errors.New("force") + }, + }, + err: NewErrorISE("error creating certificate for order oID: force"), } }, - "fail/error-storing-new-oids": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) + "fail/error-db.UpdateOrder": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, + }, + } + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo.internal", + }, + DNSNames: []string{"bar.internal"}, + } - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) + foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} + bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} + baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - dbGetOrder := 0 return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - dbGetOrder++ - if dbGetOrder == 1 { - return binvalidOrder, nil - } - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } + o: o, + csr: csr, + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, false, errors.New("force") + MgetOptions: func() *provisioner.Options { + return nil }, }, - err: ServerInternalErr(errors.New("error storing orderIDs as part of getOrderIDsByAccount logic: len(orderIDs) = 2: error storing order IDs for account foo: force")), + ca: &mockSignAuth{ + sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + assert.Equals(t, _csr, csr) + return []*x509.Certificate{foo, bar, baz}, nil + }, + }, + db: &MockDB{ + MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { + cert.ID = "certID" + assert.Equals(t, cert.AccountID, o.AccountID) + assert.Equals(t, cert.OrderID, o.ID) + assert.Equals(t, cert.Leaf, foo) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) + return nil + }, + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.CertificateID, "certID") + assert.Equals(t, updo.Status, StatusValid) + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, updo.Identifiers, o.Identifiers) + return errors.New("force") + }, + }, + err: NewErrorISE("error updating order oID: force"), } }, - "ok": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3", "o4"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - dbGetOrder := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - dbGetOrder++ - if dbGetOrder == 1 || dbGetOrder == 3 { - return binvalidOrder, nil - } - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, true, nil - }, + "ok/new-cert": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + Identifiers: []Identifier{ + {Type: "dns", Value: "foo.internal"}, + {Type: "dns", Value: "bar.internal"}, }, - res: []string{"o2", "o4"}, } - }, - "ok/no-pending-orders": func(t *testing.T) test { - oids := []string{"o1"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) + csr := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "foo.internal", + }, + DNSNames: []string{"bar.internal"}, + } - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) + foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} + bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} + baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - return binvalidOrder, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } + o: o, + csr: csr, + prov: &MockProvisioner{ + MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { + assert.Equals(t, token, "") + return nil, nil }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - assert.Equals(t, old, boids) - assert.Nil(t, newval) - return nil, true, nil + MgetOptions: func() *provisioner.Options { + return nil + }, + }, + ca: &mockSignAuth{ + sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + assert.Equals(t, _csr, csr) + return []*x509.Certificate{foo, bar, baz}, nil + }, + }, + db: &MockDB{ + MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { + cert.ID = "certID" + assert.Equals(t, cert.AccountID, o.AccountID) + assert.Equals(t, cert.OrderID, o.ID) + assert.Equals(t, cert.Leaf, foo) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) + return nil + }, + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.CertificateID, "certID") + assert.Equals(t, updo.Status, StatusValid) + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, updo.Identifiers, o.Identifiers) + return nil }, }, - res: []string{}, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - var oiba = orderIDsByAccount{} - if oids, err := oiba.unsafeGetOrderIDsByAccount(tc.db, tc.id); err != nil { + if err := tc.o.Finalize(context.Background(), tc.db, tc.csr, tc.ca, tc.prov); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res, oids) - } + assert.Nil(t, tc.err) } }) } diff --git a/acme/status.go b/acme/status.go new file mode 100644 index 00000000..d9aae82d --- /dev/null +++ b/acme/status.go @@ -0,0 +1,20 @@ +package acme + +// Status represents an ACME status. +type Status string + +var ( + // StatusValid -- valid + StatusValid = Status("valid") + // StatusInvalid -- invalid + StatusInvalid = Status("invalid") + // StatusPending -- pending; e.g. an Order that is not ready to be finalized. + StatusPending = Status("pending") + // StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid. + StatusDeactivated = Status("deactivated") + // StatusReady -- ready; e.g. for an Order that is ready to be finalized. + StatusReady = Status("ready") + //statusExpired = "expired" + //statusActive = "active" + //statusProcessing = "processing" +) diff --git a/api/errors.go b/api/errors.go index 93057ed2..fa2d6a06 100644 --- a/api/errors.go +++ b/api/errors.go @@ -16,11 +16,12 @@ import ( func WriteError(w http.ResponseWriter, err error) { switch k := err.(type) { case *acme.Error: - w.Header().Set("Content-Type", "application/problem+json") - err = k.ToACME() + acme.WriteError(w, k) + return default: w.Header().Set("Content-Type", "application/json") } + cause := errors.Cause(err) if sc, ok := err.(errs.StatusCoder); ok { w.WriteHeader(sc.StatusCode()) diff --git a/authority/provisioner/method.go b/authority/provisioner/method.go index 775ed96f..f5cd5221 100644 --- a/authority/provisioner/method.go +++ b/authority/provisioner/method.go @@ -56,8 +56,7 @@ func NewContextWithMethod(ctx context.Context, method Method) context.Context { return context.WithValue(ctx, methodKey{}, method) } -// MethodFromContext returns the Method saved in ctx. Returns Sign if the given -// context has no Method associated with it. +// MethodFromContext returns the Method saved in ctx. func MethodFromContext(ctx context.Context) Method { m, _ := ctx.Value(methodKey{}).(Method) return m diff --git a/ca/acmeClient.go b/ca/acmeClient.go index deb8a3a2..5633dac5 100644 --- a/ca/acmeClient.go +++ b/ca/acmeClient.go @@ -21,7 +21,7 @@ import ( type ACMEClient struct { client *http.Client dirLoc string - dir *acme.Directory + dir *acmeAPI.Directory acc *acme.Account Key *jose.JSONWebKey kid string @@ -53,7 +53,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC if resp.StatusCode >= 400 { return nil, readACMEError(resp.Body) } - var dir acme.Directory + var dir acmeAPI.Directory if err := readJSON(resp.Body, &dir); err != nil { return nil, errors.Wrapf(err, "error reading %s", endpoint) } @@ -93,7 +93,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC // GetDirectory makes a directory request to the ACME api and returns an // ACME directory object. -func (c *ACMEClient) GetDirectory() (*acme.Directory, error) { +func (c *ACMEClient) GetDirectory() (*acmeAPI.Directory, error) { return c.dir, nil } @@ -231,7 +231,7 @@ func (c *ACMEClient) ValidateChallenge(url string) error { } // GetAuthz returns the Authz at the given path. -func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) { +func (c *ACMEClient) GetAuthz(url string) (*acme.Authorization, error) { resp, err := c.post(nil, url, withKid(c)) if err != nil { return nil, err @@ -240,7 +240,7 @@ func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) { return nil, readACMEError(resp.Body) } - var az acme.Authz + var az acme.Authorization if err := readJSON(resp.Body, &az); err != nil { return nil, errors.Wrapf(err, "error reading %s", url) } @@ -320,7 +320,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) { if c.acc == nil { return nil, errors.New("acme client not configured with account") } - resp, err := c.post(nil, c.acc.Orders, withKid(c)) + resp, err := c.post(nil, c.acc.OrdersURL, withKid(c)) if err != nil { return nil, err } @@ -330,7 +330,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) { var orders []string if err := readJSON(resp.Body, &orders); err != nil { - return nil, errors.Wrapf(err, "error reading %s", c.acc.Orders) + return nil, errors.Wrapf(err, "error reading %s", c.acc.OrdersURL) } return orders, nil @@ -342,7 +342,7 @@ func readACMEError(r io.ReadCloser) error { if err != nil { return errors.Wrap(err, "error reading from body") } - ae := new(acme.AError) + ae := new(acme.Error) err = json.Unmarshal(b, &ae) // If we successfully marshaled to an ACMEError then return the ACMEError. if err != nil || len(ae.Error()) == 0 { diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index 25d74b9d..f5963de4 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -31,18 +31,17 @@ func TestNewACMEClient(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", NewAccount: srv.URL + "/bar", NewOrder: srv.URL + "/baz", - NewAuthz: srv.URL + "/zap", RevokeCert: srv.URL + "/zip", KeyChange: srv.URL + "/blorp", } acc := acme.Account{ - Contact: []string{"max", "mariano"}, - Status: "valid", - Orders: "orders-url", + Contact: []string{"max", "mariano"}, + Status: "valid", + OrdersURL: "orders-url", } tests := map[string]func(t *testing.T) test{ "fail/client-option-error": func(t *testing.T) test { @@ -58,7 +57,7 @@ func TestNewACMEClient(t *testing.T) { "fail/get-directory": func(t *testing.T) test { return test{ ops: []ClientOption{WithTransport(http.DefaultTransport)}, - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -76,7 +75,7 @@ func TestNewACMEClient(t *testing.T) { ops: []ClientOption{WithTransport(http.DefaultTransport)}, r1: dir, rc1: 200, - r2: acme.AccountDoesNotExistErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), rc2: 400, err: errors.New("Account does not exist"), } @@ -142,11 +141,10 @@ func TestNewACMEClient(t *testing.T) { func TestACMEClient_GetDirectory(t *testing.T) { c := &ACMEClient{ - dir: &acme.Directory{ + dir: &acmeAPI.Directory{ NewNonce: "/foo", NewAccount: "/bar", NewOrder: "/baz", - NewAuthz: "/zap", RevokeCert: "/zip", KeyChange: "/blorp", }, @@ -166,7 +164,7 @@ func TestACMEClient_GetNonce(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -185,7 +183,7 @@ func TestACMEClient_GetNonce(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/GET-nonce": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -237,7 +235,7 @@ func TestACMEClient_post(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -248,9 +246,9 @@ func TestACMEClient_post(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := acme.Account{ - Contact: []string{"max", "mariano"}, - Status: "valid", - Orders: "orders-url", + Contact: []string{"max", "mariano"}, + Status: "valid", + OrdersURL: "orders-url", } ac := &ACMEClient{ client: &http.Client{ @@ -266,7 +264,7 @@ func TestACMEClient_post(t *testing.T) { "fail/account-not-configured": func(t *testing.T) test { return test{ client: &ACMEClient{}, - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("acme client not configured with account"), } @@ -274,7 +272,7 @@ func TestACMEClient_post(t *testing.T) { "fail/GET-nonce": func(t *testing.T) test { return test{ client: ac, - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -365,7 +363,7 @@ func TestACMEClient_NewOrder(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", NewOrder: srv.URL + "/bar", } @@ -376,20 +374,21 @@ func TestACMEClient_NewOrder(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) + now := time.Now().UTC().Round(time.Second) nor := acmeAPI.NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "example.com"}, {Type: "dns", Value: "acme.example.com"}, }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Minute), + NotBefore: now, + NotAfter: now.Add(time.Minute), } norb, err := json.Marshal(nor) assert.FatalError(t, err) ord := acme.Order{ - Status: "valid", - Expires: "soon", - Finalize: "finalize-url", + Status: "valid", + ExpiresAt: now, // "soon" + FinalizeURL: "finalize-url", } ac := &ACMEClient{ client: &http.Client{ @@ -404,7 +403,7 @@ func TestACMEClient_NewOrder(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -413,7 +412,7 @@ func TestACMEClient_NewOrder(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, ops: []withHeaderOption{withKid(ac)}, err: errors.New("The request message was malformed"), @@ -498,7 +497,7 @@ func TestACMEClient_GetOrder(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -509,9 +508,9 @@ func TestACMEClient_GetOrder(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ord := acme.Order{ - Status: "valid", - Expires: "soon", - Finalize: "finalize-url", + Status: "valid", + ExpiresAt: time.Now().UTC().Round(time.Second), // "soon" + FinalizeURL: "finalize-url", } ac := &ACMEClient{ client: &http.Client{ @@ -526,7 +525,7 @@ func TestACMEClient_GetOrder(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -535,7 +534,7 @@ func TestACMEClient_GetOrder(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -618,7 +617,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -628,9 +627,9 @@ func TestACMEClient_GetAuthz(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - az := acme.Authz{ + az := acme.Authorization{ Status: "valid", - Expires: "soon", + ExpiresAt: time.Now().UTC().Round(time.Second), Identifier: acme.Identifier{Type: "dns", Value: "example.com"}, } ac := &ACMEClient{ @@ -646,7 +645,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -655,7 +654,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -738,7 +737,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -766,7 +765,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -775,7 +774,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -859,7 +858,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -887,7 +886,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -896,7 +895,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -976,7 +975,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -987,10 +986,10 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ord := acme.Order{ - Status: "valid", - Expires: "soon", - Finalize: "finalize-url", - Certificate: "cert-url", + Status: "valid", + ExpiresAt: time.Now(), // "soon" + FinalizeURL: "finalize-url", + CertificateURL: "cert-url", } _csr, err := pemutil.Read("../authority/testdata/certs/foo.csr") assert.FatalError(t, err) @@ -1012,7 +1011,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -1021,7 +1020,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -1101,7 +1100,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -1121,9 +1120,9 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { Key: jwk, kid: "foobar", acc: &acme.Account{ - Contact: []string{"max", "mariano"}, - Status: "valid", - Orders: srv.URL + "/orders-url", + Contact: []string{"max", "mariano"}, + Status: "valid", + OrdersURL: srv.URL + "/orders-url", }, } @@ -1137,7 +1136,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { "fail/client-post": func(t *testing.T) test { return test{ client: ac, - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -1147,7 +1146,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { client: ac, r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } @@ -1198,7 +1197,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) - assert.Equals(t, jwsURL, ac.acc.Orders) + assert.Equals(t, jwsURL, ac.acc.OrdersURL) assert.Equals(t, hdr.KeyID, ac.kid) payload, err := jws.Verify(ac.Key.Public()) @@ -1232,7 +1231,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { srv := httptest.NewServer(nil) defer srv.Close() - dir := acme.Directory{ + dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. @@ -1259,16 +1258,16 @@ func TestACMEClient_GetCertificate(t *testing.T) { Key: jwk, kid: "foobar", acc: &acme.Account{ - Contact: []string{"max", "mariano"}, - Status: "valid", - Orders: srv.URL + "/orders-url", + Contact: []string{"max", "mariano"}, + Status: "valid", + OrdersURL: srv.URL + "/orders-url", }, } tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ - r1: acme.MalformedErr(nil).ToACME(), + r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } @@ -1277,7 +1276,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { return test{ r1: []byte{}, rc1: 200, - r2: acme.MalformedErr(nil).ToACME(), + r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } diff --git a/ca/ca.go b/ca/ca.go index c4e79268..7c723c73 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" + acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/db" @@ -141,23 +142,29 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { } prefix := "acme" - acmeAuth, err := acme.New(auth, acme.AuthorityOptions{ + var acmeDB acme.DB + if config.DB == nil { + acmeDB = nil + } else { + acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) + if err != nil { + return nil, errors.Wrap(err, "error configuring ACME DB interface") + } + } + acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ Backdate: *config.AuthorityConfig.Backdate, - DB: auth.GetDatabase().(nosql.DB), + DB: acmeDB, DNS: dns, Prefix: prefix, + CA: auth, }) - if err != nil { - return nil, errors.Wrap(err, "error creating ACME authority") - } - acmeRouterHandler := acmeAPI.New(acmeAuth) mux.Route("/"+prefix, func(r chi.Router) { - acmeRouterHandler.Route(r) + acmeHandler.Route(r) }) // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 // of the ACME spec. mux.Route("/2.0/"+prefix, func(r chi.Router) { - acmeRouterHandler.Route(r) + acmeHandler.Route(r) }) /*