nosql account db unit tests
This commit is contained in:
parent
ce13d09dcb
commit
88e6f00347
2 changed files with 804 additions and 56 deletions
|
@ -26,6 +26,58 @@ func (dba *dbAccount) clone() *dbAccount {
|
||||||
return &nu
|
return &nu
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) {
|
||||||
|
id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
|
||||||
|
if err != nil {
|
||||||
|
if nosqlDB.IsErrNotFound(err) {
|
||||||
|
return "", acme.NewError(acme.ErrorMalformedType, "account with key-id %s not found", kid)
|
||||||
|
}
|
||||||
|
return "", errors.Wrapf(err, "error loading key-account index for key %s", kid)
|
||||||
|
}
|
||||||
|
return string(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDBAccount retrieves and unmarshals dbAccount.
|
||||||
|
func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) {
|
||||||
|
data, err := db.db.Get(accountTable, []byte(id))
|
||||||
|
if err != nil {
|
||||||
|
if nosqlDB.IsErrNotFound(err) {
|
||||||
|
return nil, acme.NewError(acme.ErrorMalformedType, "account %s not found", id)
|
||||||
|
}
|
||||||
|
return nil, errors.Wrapf(err, "error loading account %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbacc := new(dbAccount)
|
||||||
|
if err = json.Unmarshal(data, dbacc); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error unmarshaling account %s into dbAccount", id)
|
||||||
|
}
|
||||||
|
return dbacc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccount retrieves an ACME account by ID.
|
||||||
|
func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
|
||||||
|
dbacc, err := db.getDBAccount(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &acme.Account{
|
||||||
|
Status: dbacc.Status,
|
||||||
|
Contact: dbacc.Contact,
|
||||||
|
Key: dbacc.Key,
|
||||||
|
ID: dbacc.ID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK).
|
||||||
|
func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) {
|
||||||
|
id, err := db.getAccountIDByKeyID(ctx, kid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return db.GetAccount(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
// CreateAccount imlements the AcmeDB.CreateAccount interface.
|
// CreateAccount imlements the AcmeDB.CreateAccount interface.
|
||||||
func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
|
func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
|
||||||
var err error
|
var err error
|
||||||
|
@ -64,36 +116,8 @@ func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccount retrieves an ACME account by ID.
|
|
||||||
func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
|
|
||||||
dbacc, err := db.getDBAccount(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &acme.Account{
|
|
||||||
Status: dbacc.Status,
|
|
||||||
Contact: dbacc.Contact,
|
|
||||||
Key: dbacc.Key,
|
|
||||||
ID: dbacc.ID,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK).
|
|
||||||
func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) {
|
|
||||||
id, err := db.getAccountIDByKeyID(ctx, kid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return db.GetAccount(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateAccount imlements the AcmeDB.UpdateAccount interface.
|
// UpdateAccount imlements the AcmeDB.UpdateAccount interface.
|
||||||
func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error {
|
func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error {
|
||||||
if len(acc.ID) == 0 {
|
|
||||||
return errors.New("id cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
old, err := db.getDBAccount(ctx, acc.ID)
|
old, err := db.getDBAccount(ctx, acc.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -110,31 +134,3 @@ func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error {
|
||||||
|
|
||||||
return db.save(ctx, old.ID, nu, old, "account", accountTable)
|
return db.save(ctx, old.ID, nu, old, "account", accountTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) {
|
|
||||||
id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
|
|
||||||
if err != nil {
|
|
||||||
if nosqlDB.IsErrNotFound(err) {
|
|
||||||
return "", errors.Wrapf(err, "account with key id %s not found", kid)
|
|
||||||
}
|
|
||||||
return "", errors.Wrapf(err, "error loading key-account index")
|
|
||||||
}
|
|
||||||
return string(id), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getDBAccount retrieves and unmarshals dbAccount.
|
|
||||||
func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) {
|
|
||||||
data, err := db.db.Get(accountTable, []byte(id))
|
|
||||||
if err != nil {
|
|
||||||
if nosqlDB.IsErrNotFound(err) {
|
|
||||||
return nil, errors.Wrapf(err, "account %s not found", id)
|
|
||||||
}
|
|
||||||
return nil, errors.Wrapf(err, "error loading account %s", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
dbacc := new(dbAccount)
|
|
||||||
if err = json.Unmarshal(data, dbacc); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error unmarshaling account")
|
|
||||||
}
|
|
||||||
return dbacc, nil
|
|
||||||
}
|
|
||||||
|
|
752
acme/db/nosql/account_test.go
Normal file
752
acme/db/nosql/account_test.go
Normal file
|
@ -0,0 +1,752 @@
|
||||||
|
package nosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/nosql"
|
||||||
|
nosqldb "github.com/smallstep/nosql/database"
|
||||||
|
"go.step.sm/crypto/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDB_getDBAccount(t *testing.T) {
|
||||||
|
accID := "accID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
dbacc *dbAccount
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading account accID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/unmarshal-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return []byte("foo"), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error unmarshaling account accID into dbAccount"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
dbacc := &dbAccount{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
CreatedAt: now,
|
||||||
|
DeactivatedAt: now,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbacc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dbacc: dbacc,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, dbacc.ID, tc.dbacc.ID)
|
||||||
|
assert.Equals(t, dbacc.Status, tc.dbacc.Status)
|
||||||
|
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
|
||||||
|
assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
|
||||||
|
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
|
||||||
|
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_getAccountIDByKeyID(t *testing.T) {
|
||||||
|
accID := "accID"
|
||||||
|
kid := "kid"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/not-found": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading key-account index for key kid: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
|
||||||
|
return []byte(accID), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, retAccID, accID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_GetAccount(t *testing.T) {
|
||||||
|
accID := "accID"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
dbacc *dbAccount
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading account accID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/forward-acme-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
dbacc := &dbAccount{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
CreatedAt: now,
|
||||||
|
DeactivatedAt: now,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbacc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dbacc: dbacc,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if acc, err := db.GetAccount(context.Background(), accID); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||||
|
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||||
|
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||||
|
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||||
|
accID := "accID"
|
||||||
|
kid := "kid"
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
err error
|
||||||
|
acmeErr *acme.Error
|
||||||
|
dbacc *dbAccount
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.getAccountIDByKeyID-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, string(bucket), string(accountByKeyIDTable))
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading key-account index for key kid: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.getAccountIDByKeyID-forward-acme-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, string(bucket), string(accountByKeyIDTable))
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.GetAccount-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(accountByKeyIDTable):
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
return []byte(accID), nil
|
||||||
|
case string(accountTable):
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
return nil, errors.New("force")
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading account accID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.GetAccount-forward-acme-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(accountByKeyIDTable):
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
return []byte(accID), nil
|
||||||
|
case string(accountTable):
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
return nil, nosqldb.ErrNotFound
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
now := clock.Now()
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
dbacc := &dbAccount{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
CreatedAt: now,
|
||||||
|
DeactivatedAt: now,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbacc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(accountByKeyIDTable):
|
||||||
|
assert.Equals(t, string(key), kid)
|
||||||
|
return []byte(accID), nil
|
||||||
|
case string(accountTable):
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
return b, nil
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
dbacc: dbacc,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil {
|
||||||
|
switch k := err.(type) {
|
||||||
|
case *acme.Error:
|
||||||
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||||
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||||
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||||
|
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||||
|
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||||
|
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_CreateAccount(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
acc *acme.Account
|
||||||
|
err error
|
||||||
|
_id *string
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/keyID-cmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, string(key), jwk.KeyID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
assert.Equals(t, nu, []byte(acc.ID))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acc: acc,
|
||||||
|
err: errors.New("error storing keyID to accountID index: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/keyID-cmpAndSwap-false": func(t *testing.T) test {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountByKeyIDTable)
|
||||||
|
assert.Equals(t, string(key), jwk.KeyID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
assert.Equals(t, nu, []byte(acc.ID))
|
||||||
|
return nil, false, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acc: acc,
|
||||||
|
err: errors.New("key-id to account-id index already exists"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/account-save-error": func(t *testing.T) test {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(accountByKeyIDTable):
|
||||||
|
assert.Equals(t, string(key), jwk.KeyID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
return nu, true, nil
|
||||||
|
case string(accountTable):
|
||||||
|
assert.Equals(t, string(key), acc.ID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbacc := new(dbAccount)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||||
|
assert.Equals(t, dbacc.ID, string(key))
|
||||||
|
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||||
|
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||||
|
assert.True(t, dbacc.DeactivatedAt.IsZero())
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acc: acc,
|
||||||
|
err: errors.New("error saving acme account: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
var (
|
||||||
|
id string
|
||||||
|
idPtr = &id
|
||||||
|
)
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{
|
||||||
|
Status: acme.StatusValid,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
id = string(key)
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(accountByKeyIDTable):
|
||||||
|
assert.Equals(t, string(key), jwk.KeyID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
return nu, true, nil
|
||||||
|
case string(accountTable):
|
||||||
|
assert.Equals(t, string(key), acc.ID)
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
|
||||||
|
dbacc := new(dbAccount)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||||
|
assert.Equals(t, dbacc.ID, string(key))
|
||||||
|
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||||
|
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||||
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||||
|
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||||
|
assert.True(t, dbacc.DeactivatedAt.IsZero())
|
||||||
|
return nu, true, nil
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
acc: acc,
|
||||||
|
_id: idPtr,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if err := db.CreateAccount(context.Background(), tc.acc); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.acc.ID, *tc._id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB_UpdateAccount(t *testing.T) {
|
||||||
|
accID := "accID"
|
||||||
|
now := clock.Now()
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
dbacc := &dbAccount{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
CreatedAt: now,
|
||||||
|
DeactivatedAt: now,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dbacc)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
type test struct {
|
||||||
|
db nosql.DB
|
||||||
|
acc *acme.Account
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
acc: &acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
},
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error loading account accID: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/already-deactivated": func(t *testing.T) test {
|
||||||
|
clone := dbacc.clone()
|
||||||
|
clone.Status = acme.StatusDeactivated
|
||||||
|
clone.DeactivatedAt = now
|
||||||
|
dbaccb, err := json.Marshal(clone)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return dbaccb, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, old, b)
|
||||||
|
|
||||||
|
dbNew := new(dbAccount)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
|
assert.Equals(t, dbNew.ID, clone.ID)
|
||||||
|
assert.Equals(t, dbNew.Status, clone.Status)
|
||||||
|
assert.Equals(t, dbNew.Contact, clone.Contact)
|
||||||
|
assert.Equals(t, dbNew.Key.KeyID, clone.Key.KeyID)
|
||||||
|
assert.Equals(t, dbNew.CreatedAt, clone.CreatedAt)
|
||||||
|
assert.Equals(t, dbNew.DeactivatedAt, clone.DeactivatedAt)
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error saving acme account: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, old, b)
|
||||||
|
|
||||||
|
dbNew := new(dbAccount)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
|
assert.Equals(t, dbNew.ID, dbacc.ID)
|
||||||
|
assert.Equals(t, dbNew.Status, acc.Status)
|
||||||
|
assert.Equals(t, dbNew.Contact, dbacc.Contact)
|
||||||
|
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
||||||
|
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
||||||
|
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
||||||
|
assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errors.New("error saving acme account: force"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{
|
||||||
|
ID: accID,
|
||||||
|
Status: acme.StatusDeactivated,
|
||||||
|
Contact: []string{"foo", "bar"},
|
||||||
|
Key: jwk,
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
acc: acc,
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, string(key), accID)
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
assert.Equals(t, bucket, accountTable)
|
||||||
|
assert.Equals(t, old, b)
|
||||||
|
|
||||||
|
dbNew := new(dbAccount)
|
||||||
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
|
assert.Equals(t, dbNew.ID, dbacc.ID)
|
||||||
|
assert.Equals(t, dbNew.Status, acc.Status)
|
||||||
|
assert.Equals(t, dbNew.Contact, dbacc.Contact)
|
||||||
|
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
||||||
|
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
||||||
|
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
||||||
|
assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now))
|
||||||
|
return nu, true, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
db := DB{db: tc.db}
|
||||||
|
if err := db.UpdateAccount(context.Background(), tc.acc); err != nil {
|
||||||
|
if assert.NotNil(t, tc.err) {
|
||||||
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if assert.Nil(t, tc.err) {
|
||||||
|
assert.Equals(t, tc.acc.ID, dbacc.ID)
|
||||||
|
assert.Equals(t, tc.acc.Status, dbacc.Status)
|
||||||
|
assert.Equals(t, tc.acc.Contact, dbacc.Contact)
|
||||||
|
assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue