certificates/acme/db/nosql/account.go

137 lines
3.5 KiB
Go
Raw Permalink Normal View History

2021-02-25 18:24:24 +00:00
package nosql
import (
"context"
"encoding/json"
"time"
"github.com/pkg/errors"
2021-03-01 06:49:20 +00:00
"github.com/smallstep/certificates/acme"
2021-02-25 18:24:24 +00:00
nosqlDB "github.com/smallstep/nosql"
"go.step.sm/crypto/jose"
)
// dbAccount represents an ACME account.
type dbAccount struct {
ID string `json:"id"`
Key *jose.JSONWebKey `json:"key"`
Contact []string `json:"contact,omitempty"`
Status acme.Status `json:"status"`
2021-03-29 19:04:14 +00:00
CreatedAt time.Time `json:"createdAt"`
DeactivatedAt time.Time `json:"deactivatedAt"`
2021-02-25 18:24:24 +00:00
}
// clone returns a shallow copy of dba. The returned struct is a distinct
// value, but its slice and pointer fields (Contact, Key) still share their
// underlying data with dba.
func (dba *dbAccount) clone() *dbAccount {
	dup := new(dbAccount)
	*dup = *dba
	return dup
}
2023-05-10 06:47:28 +00:00
func (db *DB) getAccountIDByKeyID(_ context.Context, kid string) (string, error) {
2021-03-22 21:46:05 +00:00
id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
if err != nil {
if nosqlDB.IsErrNotFound(err) {
return "", acme.ErrNotFound
2021-03-22 21:46:05 +00:00
}
return "", errors.Wrapf(err, "error loading key-account index for key %s", kid)
}
return string(id), nil
}
// getDBAccount retrieves and unmarshals dbAccount.
2023-05-10 06:47:28 +00:00
func (db *DB) getDBAccount(_ context.Context, id string) (*dbAccount, error) {
2021-03-22 21:46:05 +00:00
data, err := db.db.Get(accountTable, []byte(id))
if err != nil {
if nosqlDB.IsErrNotFound(err) {
return nil, acme.ErrNotFound
2021-03-22 21:46:05 +00:00
}
return nil, errors.Wrapf(err, "error loading account %s", id)
}
dbacc := new(dbAccount)
if err = json.Unmarshal(data, dbacc); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling account %s into dbAccount", id)
}
return dbacc, nil
}
// GetAccount retrieves an ACME account by ID.
//
// The stored record is projected onto an acme.Account carrying ID, Key,
// Contact and Status; the CreatedAt/DeactivatedAt timestamps are not copied.
// Lookup errors, including acme.ErrNotFound, are returned unchanged.
func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
	stored, err := db.getDBAccount(ctx, id)
	if err != nil {
		return nil, err
	}

	acc := &acme.Account{
		ID:      stored.ID,
		Key:     stored.Key,
		Contact: stored.Contact,
		Status:  stored.Status,
	}
	return acc, nil
}
// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the
// Account Key -- JWK).
//
// The key ID is first resolved to an account ID through the key-id index, and
// the account is then loaded with GetAccount. Errors from either step are
// returned unchanged.
func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) {
	accountID, err := db.getAccountIDByKeyID(ctx, kid)
	if err == nil {
		return db.GetAccount(ctx, accountID)
	}
	return nil, err
}
2021-02-25 18:24:24 +00:00
// CreateAccount imlements the AcmeDB.CreateAccount interface.
2021-03-01 06:49:20 +00:00
func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
var err error
2021-02-28 01:05:37 +00:00
acc.ID, err = randID()
2021-02-25 18:24:24 +00:00
if err != nil {
2021-03-01 06:49:20 +00:00
return err
2021-02-25 18:24:24 +00:00
}
dba := &dbAccount{
ID: acc.ID,
Key: acc.Key,
Contact: acc.Contact,
Status: acc.Status,
CreatedAt: clock.Now(),
2021-02-25 18:24:24 +00:00
}
2021-03-01 06:49:20 +00:00
kid, err := acme.KeyToID(dba.Key)
2021-02-25 18:24:24 +00:00
if err != nil {
return err
}
kidB := []byte(kid)
// Set the jwkID -> acme account ID index
2021-03-01 06:49:20 +00:00
_, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(acc.ID))
2021-02-25 18:24:24 +00:00
switch {
case err != nil:
2021-03-01 06:49:20 +00:00
return errors.Wrap(err, "error storing keyID to accountID index")
2021-02-25 18:24:24 +00:00
case !swapped:
2021-03-01 06:49:20 +00:00
return errors.Errorf("key-id to account-id index already exists")
2021-02-25 18:24:24 +00:00
default:
2021-02-28 01:05:37 +00:00
if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil {
2021-02-25 18:24:24 +00:00
db.db.Del(accountByKeyIDTable, kidB)
return err
}
return nil
}
}
// UpdateAccount imlements the AcmeDB.UpdateAccount interface.
2021-03-01 06:49:20 +00:00
func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error {
2021-02-28 01:05:37 +00:00
old, err := db.getDBAccount(ctx, acc.ID)
2021-02-25 18:24:24 +00:00
if err != nil {
return err
}
nu := old.clone()
2021-03-01 06:49:20 +00:00
nu.Contact = acc.Contact
nu.Status = acc.Status
2021-02-25 18:24:24 +00:00
// If the status has changed to 'deactivated', then set deactivatedAt timestamp.
2021-03-01 06:49:20 +00:00
if acc.Status == acme.StatusDeactivated && old.Status != acme.StatusDeactivated {
nu.DeactivatedAt = clock.Now()
2021-02-25 18:24:24 +00:00
}
2021-03-01 06:49:20 +00:00
return db.save(ctx, old.ID, nu, old, "account", accountTable)
2021-02-25 18:24:24 +00:00
}