From 88e6f0034742ac1ba98bbd5bb48487fdcff35e81 Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 22 Mar 2021 14:46:05 -0700 Subject: [PATCH] nosql account db unit tests --- acme/db/nosql/account.go | 108 +++-- acme/db/nosql/account_test.go | 752 ++++++++++++++++++++++++++++++++++ 2 files changed, 804 insertions(+), 56 deletions(-) create mode 100644 acme/db/nosql/account_test.go diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index 0e0a7c4b..3115e8ab 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -26,6 +26,58 @@ func (dba *dbAccount) clone() *dbAccount { return &nu } +func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) { + id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) + if err != nil { + if nosqlDB.IsErrNotFound(err) { + return "", acme.NewError(acme.ErrorMalformedType, "account with key-id %s not found", kid) + } + return "", errors.Wrapf(err, "error loading key-account index for key %s", kid) + } + return string(id), nil +} + +// getDBAccount retrieves and unmarshals dbAccount. +func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) { + data, err := db.db.Get(accountTable, []byte(id)) + if err != nil { + if nosqlDB.IsErrNotFound(err) { + return nil, acme.NewError(acme.ErrorMalformedType, "account %s not found", id) + } + return nil, errors.Wrapf(err, "error loading account %s", id) + } + + dbacc := new(dbAccount) + if err = json.Unmarshal(data, dbacc); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling account %s into dbAccount", id) + } + return dbacc, nil +} + +// GetAccount retrieves an ACME account by ID. +func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) { + dbacc, err := db.getDBAccount(ctx, id) + if err != nil { + return nil, err + } + + return &acme.Account{ + Status: dbacc.Status, + Contact: dbacc.Contact, + Key: dbacc.Key, + ID: dbacc.ID, + }, nil +} + +// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK). +func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) { + id, err := db.getAccountIDByKeyID(ctx, kid) + if err != nil { + return nil, err + } + return db.GetAccount(ctx, id) +} + // CreateAccount imlements the AcmeDB.CreateAccount interface. func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error { var err error @@ -64,36 +116,8 @@ func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error { } } -// GetAccount retrieves an ACME account by ID. -func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) { - dbacc, err := db.getDBAccount(ctx, id) - if err != nil { - return nil, err - } - - return &acme.Account{ - Status: dbacc.Status, - Contact: dbacc.Contact, - Key: dbacc.Key, - ID: dbacc.ID, - }, nil -} - -// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK). -func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) { - id, err := db.getAccountIDByKeyID(ctx, kid) - if err != nil { - return nil, err - } - return db.GetAccount(ctx, id) -} - // UpdateAccount imlements the AcmeDB.UpdateAccount interface. func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error { - if len(acc.ID) == 0 { - return errors.New("id cannot be empty") - } - old, err := db.getDBAccount(ctx, acc.ID) if err != nil { return err @@ -110,31 +134,3 @@ func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error { return db.save(ctx, old.ID, nu, old, "account", accountTable) } - -func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) { - id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) - if err != nil { - if nosqlDB.IsErrNotFound(err) { - return "", errors.Wrapf(err, "account with key id %s not found", kid) - } - return "", errors.Wrapf(err, "error loading key-account index") - } - return string(id), nil -} - -// getDBAccount retrieves and unmarshals dbAccount. -func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) { - data, err := db.db.Get(accountTable, []byte(id)) - if err != nil { - if nosqlDB.IsErrNotFound(err) { - return nil, errors.Wrapf(err, "account %s not found", id) - } - return nil, errors.Wrapf(err, "error loading account %s", id) - } - - dbacc := new(dbAccount) - if err = json.Unmarshal(data, dbacc); err != nil { - return nil, errors.Wrap(err, "error unmarshaling account") - } - return dbacc, nil -} diff --git a/acme/db/nosql/account_test.go b/acme/db/nosql/account_test.go new file mode 100644 index 00000000..9f889e64 --- /dev/null +++ b/acme/db/nosql/account_test.go @@ -0,0 +1,752 @@ +package nosql + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" + "go.step.sm/crypto/jose" +) + +func TestDB_getDBAccount(t *testing.T) { + accID := "accID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbacc *dbAccount + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling account accID into dbAccount"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return b, nil + }, + }, + dbacc: dbacc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, dbacc.ID, tc.dbacc.ID) + assert.Equals(t, dbacc.Status, tc.dbacc.Status) + assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt) + assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt) + assert.Equals(t, dbacc.Contact, tc.dbacc.Contact) + assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID) + } + } + }) + } +} + +func TestDB_getAccountIDByKeyID(t *testing.T) { + accID := "accID" + kid := "kid" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), kid) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), kid) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading key-account index for key kid: force"), + } + }, + "ok": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), kid) + + return []byte(accID), nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, retAccID, accID) + } + } + }) + } +} + +func TestDB_GetAccount(t *testing.T) { + accID := "accID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbacc *dbAccount + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "fail/forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + return b, nil + }, + }, + dbacc: dbacc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if acc, err := db.GetAccount(context.Background(), accID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acc.ID, tc.dbacc.ID) + assert.Equals(t, acc.Status, tc.dbacc.Status) + assert.Equals(t, acc.Contact, tc.dbacc.Contact) + assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) + } + } + }) + } +} + +func TestDB_GetAccountByKeyID(t *testing.T) { + accID := "accID" + kid := "kid" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbacc *dbAccount + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.getAccountIDByKeyID-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(accountByKeyIDTable)) + assert.Equals(t, string(key), kid) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading key-account index for key kid: force"), + } + }, + "fail/db.getAccountIDByKeyID-forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(accountByKeyIDTable)) + assert.Equals(t, string(key), kid) + + return nil, nosqldb.ErrNotFound + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"), + } + }, + "fail/db.GetAccount-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), kid) + return []byte(accID), nil + case string(accountTable): + assert.Equals(t, string(key), accID) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "fail/db.GetAccount-forward-acme-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), kid) + return []byte(accID), nil + case string(accountTable): + assert.Equals(t, string(key), accID) + return nil, nosqldb.ErrNotFound + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), kid) + return []byte(accID), nil + case string(accountTable): + assert.Equals(t, string(key), accID) + return b, nil + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + dbacc: dbacc, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, acc.ID, tc.dbacc.ID) + assert.Equals(t, acc.Status, tc.dbacc.Status) + assert.Equals(t, acc.Contact, tc.dbacc.Contact) + assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) + } + } + }) + } +} + +func TestDB_CreateAccount(t *testing.T) { + type test struct { + db nosql.DB + acc *acme.Account + err error + _id *string + } + var tests = map[string]func(t *testing.T) test{ + "fail/keyID-cmpAndSwap-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + + assert.Equals(t, nu, []byte(acc.ID)) + return nil, false, errors.New("force") + }, + }, + acc: acc, + err: errors.New("error storing keyID to accountID index: force"), + } + }, + "fail/keyID-cmpAndSwap-false": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountByKeyIDTable) + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + + assert.Equals(t, nu, []byte(acc.ID)) + return nil, false, nil + }, + }, + acc: acc, + err: errors.New("key-id to account-id index already exists"), + } + }, + "fail/account-save-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + return nu, true, nil + case string(accountTable): + assert.Equals(t, string(key), acc.ID) + assert.Equals(t, old, nil) + + dbacc := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbacc)) + assert.Equals(t, dbacc.ID, string(key)) + assert.Equals(t, dbacc.Contact, acc.Contact) + assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) + assert.True(t, dbacc.DeactivatedAt.IsZero()) + return nil, false, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + acc: acc, + err: errors.New("error saving acme account: force"), + } + }, + "ok": func(t *testing.T) test { + var ( + id string + idPtr = &id + ) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + id = string(key) + switch string(bucket) { + case string(accountByKeyIDTable): + assert.Equals(t, string(key), jwk.KeyID) + assert.Equals(t, old, nil) + return nu, true, nil + case string(accountTable): + assert.Equals(t, string(key), acc.ID) + assert.Equals(t, old, nil) + + dbacc := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbacc)) + assert.Equals(t, dbacc.ID, string(key)) + assert.Equals(t, dbacc.Contact, acc.Contact) + assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) + assert.True(t, dbacc.DeactivatedAt.IsZero()) + return nu, true, nil + default: + assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + acc: acc, + _id: idPtr, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.CreateAccount(context.Background(), tc.acc); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.acc.ID, *tc._id) + } + } + }) + } +} + +func TestDB_UpdateAccount(t *testing.T) { + accID := "accID" + now := clock.Now() + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + dbacc := &dbAccount{ + ID: accID, + Status: acme.StatusDeactivated, + CreatedAt: now, + DeactivatedAt: now, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + b, err := json.Marshal(dbacc) + assert.FatalError(t, err) + type test struct { + db nosql.DB + acc *acme.Account + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + acc: &acme.Account{ + ID: accID, + }, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading account accID: force"), + } + }, + "fail/already-deactivated": func(t *testing.T) test { + clone := dbacc.clone() + clone.Status = acme.StatusDeactivated + clone.DeactivatedAt = now + dbaccb, err := json.Marshal(clone) + assert.FatalError(t, err) + acc := &acme.Account{ + ID: accID, + Status: acme.StatusDeactivated, + Contact: []string{"foo", "bar"}, + } + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return dbaccb, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, old, b) + + dbNew := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, clone.ID) + assert.Equals(t, dbNew.Status, clone.Status) + assert.Equals(t, dbNew.Contact, clone.Contact) + assert.Equals(t, dbNew.Key.KeyID, clone.Key.KeyID) + assert.Equals(t, dbNew.CreatedAt, clone.CreatedAt) + assert.Equals(t, dbNew.DeactivatedAt, clone.DeactivatedAt) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme account: force"), + } + }, + "fail/db.CmpAndSwap-error": func(t *testing.T) test { + acc := &acme.Account{ + ID: accID, + Status: acme.StatusDeactivated, + Contact: []string{"foo", "bar"}, + } + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, old, b) + + dbNew := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbacc.ID) + assert.Equals(t, dbNew.Status, acc.Status) + assert.Equals(t, dbNew.Contact, dbacc.Contact) + assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID) + assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt) + assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now)) + return nil, false, errors.New("force") + }, + }, + err: errors.New("error saving acme account: force"), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ + ID: accID, + Status: acme.StatusDeactivated, + Contact: []string{"foo", "bar"}, + Key: jwk, + } + return test{ + acc: acc, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, string(key), accID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, accountTable) + assert.Equals(t, old, b) + + dbNew := new(dbAccount) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbacc.ID) + assert.Equals(t, dbNew.Status, acc.Status) + assert.Equals(t, dbNew.Contact, dbacc.Contact) + assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID) + assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt) + assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now)) + return nu, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + if err := db.UpdateAccount(context.Background(), tc.acc); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.acc.ID, dbacc.ID) + assert.Equals(t, tc.acc.Status, dbacc.Status) + assert.Equals(t, tc.acc.Contact, dbacc.Contact) + assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID) + } + } + }) + } +}