[acme db interface] compiles!

This commit is contained in:
max furman 2021-03-06 13:06:43 -08:00
parent 116869ebc5
commit fc395f4d69
8 changed files with 86 additions and 114 deletions

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/logging"
@ -181,24 +182,26 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account.
func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
/*
acc, err := acme.AccountFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
return
}
accID := chi.URLParam(r, "accID")
if acc.ID != accID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param"))
return
}
orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID())
if err != nil {
api.WriteError(w, err)
return
}
api.JSON(w, orders)
logOrdersByAccount(w, orders)
*/
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
accID := chi.URLParam(r, "accID")
if acc.ID != accID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return
}
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil {
api.WriteError(w, err)
return
}
h.linker.LinkOrdersByAccountID(ctx, orders)
api.JSON(w, orders)
logOrdersByAccount(w, orders)
return
}

View file

@ -151,14 +151,21 @@ func (l *Linker) LinkAccount(ctx context.Context, acc *acme.Account) {
acc.Orders = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID)
}
// LinkChallenge sets the ACME links required by an ACME account.
// LinkChallenge sets the ACME links required by an ACME challenge.
func (l *Linker) LinkChallenge(ctx context.Context, ch *acme.Challenge) {
ch.URL = l.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID)
}
// LinkAuthorization sets the ACME links required by an ACME account.
// LinkAuthorization sets the ACME links required by an ACME authorization.
func (l *Linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
for _, ch := range az.Challenges {
l.LinkChallenge(ctx, ch)
}
}
// LinkOrdersByAccountID converts each order ID to an ACME link.
func (l *Linker) LinkOrdersByAccountID(ctx context.Context, orders []string) {
for i, id := range orders {
orders[i] = l.GetLink(ctx, OrderLinkType, true, id)
}
}

View file

@ -24,8 +24,7 @@ type DB interface {
UpdateChallenge(ctx context.Context, ch *Challenge) error
CreateOrder(ctx context.Context, o *Order) error
DeleteOrder(ctx context.Context, id string) error
GetOrder(ctx context.Context, id string) (*Order, error)
GetOrdersByAccountID(ctx context.Context, accountID string) error
GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error)
UpdateOrder(ctx context.Context, o *Order) error
}

View file

@ -43,7 +43,7 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
// DeleteNonce verifies that the nonce is valid (by checking if it exists),
// and if so, consumes the nonce resource by deleting it from the database.
func (db *DB) DeleteNonce(nonce string) error {
func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error {
err := db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{

View file

@ -44,8 +44,7 @@ func New(db nosqlDB.DB) (*DB, error) {
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error {
newB, err := json.Marshal(nu)
if err != nil {
return errors.Wrapf(err,
"error marshaling new acme %s", typ)
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu)
}
var oldB []byte
if old == nil {
@ -53,8 +52,7 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface
} else {
oldB, err = json.Marshal(old)
if err != nil {
return errors.Wrapf(err,
"error marshaling old acme %s", typ)
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old)
}
}
@ -63,8 +61,7 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface
case err != nil:
return errors.Wrapf(err, "error saving acme %s", typ)
case !swapped:
return errors.Errorf("error saving acme %s; "+
"changed since last read", typ)
return errors.Errorf("error saving acme %s; changed since last read", typ)
default:
return nil
}

View file

@ -29,8 +29,13 @@ type dbOrder struct {
CertificateID string `json:"certificate,omitempty"`
}
func (a *dbOrder) clone() *dbOrder {
b := *a
return &b
}
// getDBOrder retrieves and unmarshals an ACME Order type from the database.
func (db *DB) getDBOrder(id string) (*dbOrder, error) {
func (db *DB) getDBOrder(ctx context.Context, id string) (*dbOrder, error) {
b, err := db.db.Get(orderTable, []byte(id))
if nosql.IsErrNotFound(err) {
return nil, acme.WrapError(acme.ErrorMalformedType, err, "order %s not found", id)
@ -46,7 +51,7 @@ func (db *DB) getDBOrder(id string) (*dbOrder, error) {
// GetOrder retrieves an ACME Order from the database.
func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) {
dbo, err := db.getDBOrder(id)
dbo, err := db.getDBOrder(ctx, id)
if err != nil {
return nil, err
}
@ -91,8 +96,7 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error {
return err
}
var oidHelper = orderIDsByAccount{}
_, err = oidHelper.addOrderID(db, o.AccountID, o.ID)
_, err = db.updateAddOrderIDs(ctx, o.AccountID, o.ID)
if err != nil {
return err
}
@ -104,28 +108,11 @@ type orderIDsByAccount struct{}
// addOrderID adds an order ID to a users index of in progress order IDs.
// This method will also cull any orders that are no longer in the `pending`
// state from the index before returning it.
func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) ([]string, error) {
func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...string) ([]string, error) {
ordersByAccountMux.Lock()
defer ordersByAccountMux.Unlock()
// Update the "order IDs by account ID" index
oids, err := oiba.unsafeGetOrderIDsByAccount(db, accID)
if err != nil {
return nil, err
}
newOids := append(oids, oid)
if err = orderIDs(newOids).save(db, oids, accID); err != nil {
// Delete the entire order if storing the index fails.
db.Del(orderTable, []byte(oid))
return nil, err
}
return newOids, nil
}
// unsafeGetOrderIDsByAccount retrieves a list of Order IDs that were created by the
// account.
func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) {
b, err := db.Get(ordersByAccountIDTable, []byte(accID))
b, err := db.db.Get(ordersByAccountIDTable, []byte(accID))
if err != nil {
if nosql.IsErrNotFound(err) {
return []string{}, nil
@ -145,67 +132,46 @@ func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID stri
// that are invalid in the array of URLs.
pendOids := []string{}
for _, oid := range oids {
o, err := getOrder(db, oid)
o, err := db.GetOrder(ctx, oid)
if err != nil {
return nil, errors.Wrapf(err, "error loading order %s for account %s", oid, accID)
return nil, acme.WrapErrorISE(err, "error loading order %s for account %s", oid, accID)
}
if o, err = o.UpdateStatus(db); err != nil {
return nil, errors.Wrapf(err, "error updating order %s for account %s", oid, accID)
if err = o.UpdateStatus(ctx, db); err != nil {
return nil, acme.WrapErrorISE(err, "error updating order %s for account %s", oid, accID)
}
if o.Status == acme.StatusPending {
pendOids = append(pendOids, oid)
}
}
// If the number of pending orders is less than the number of orders in the
// list, then update the pending order list.
if len(pendOids) != len(oids) {
if err = orderIDs(pendOids).save(db, oids, accID); err != nil {
return nil, errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+
"len(orderIDs) = %d", len(pendOids))
}
pendOids = append(pendOids, addOids...)
if len(oids) == 0 {
oids = nil
}
if err = db.save(ctx, accID, pendOids, oids, "orderIDsByAccountID", ordersByAccountIDTable); err != nil {
// Delete all orders that may have been previously stored if orderIDsByAccountID update fails.
for _, oid := range addOids {
db.db.Del(orderTable, []byte(oid))
}
return nil, errors.Wrap(err, "error saving OrderIDsByAccountID index")
}
return pendOids, nil
}
type orderIDs []string
// save is used to update the list of orderIDs keyed by ACME account ID
// stored in the database.
//
// This method always converts empty lists to 'nil' when storing to the DB. We
// do this to avoid any confusion between an empty list and a nil value in the
// db.
func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error {
var (
err error
oldb []byte
newb []byte
)
if len(old) == 0 {
oldb = nil
} else {
oldb, err = json.Marshal(old)
if err != nil {
return errors.Wrap(err, "error marshaling old order IDs slice")
}
}
if len(oids) == 0 {
newb = nil
} else {
newb, err = json.Marshal(oids)
if err != nil {
return errors.Wrap(err, "error marshaling new order IDs slice")
}
}
_, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb)
switch {
case err != nil:
return errors.Wrapf(err, "error storing order IDs for account %s", accID)
case !swapped:
return errors.Errorf("error storing order IDs "+
"for account %s; order IDs changed since last read", accID)
default:
return nil
}
func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
return db.updateAddOrderIDs(ctx, accID)
}
// UpdateOrder saves an updated ACME Order to the database.
func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error {
old, err := db.getDBOrder(ctx, o.ID)
if err != nil {
return err
}
nu := old.clone()
nu.Status = o.Status
nu.Error = o.Error
nu.CertificateID = o.CertificateID
return db.save(ctx, old.ID, nu, old, "order", orderTable)
}

View file

@ -25,7 +25,7 @@ type Order struct {
Identifiers []Identifier `json:"identifiers"`
NotBefore time.Time `json:"notBefore,omitempty"`
NotAfter time.Time `json:"notAfter,omitempty"`
Error interface{} `json:"error,omitempty"`
Error *Error `json:"error,omitempty"`
AuthorizationIDs []string `json:"-"`
AuthorizationURLs []string `json:"authorizations"`
FinalizeURL string `json:"finalize"`

View file

@ -124,24 +124,24 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) {
}
prefix := "acme"
acmeAuth, err := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
acmeDB, err := acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
if err != nil {
return nil, errors.Wrap(err, "error configuring ACME DB interface")
}
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
Backdate: *config.AuthorityConfig.Backdate,
DB: acmeNoSQL.New(auth.GetDatabase().(nosql.DB)),
DB: acmeDB,
DNS: dns,
Prefix: prefix,
CA: auth,
})
if err != nil {
return nil, errors.Wrap(err, "error creating ACME authority")
}
acmeRouterHandler := acmeAPI.New(acmeAuth)
mux.Route("/"+prefix, func(r chi.Router) {
acmeRouterHandler.Route(r)
acmeHandler.Route(r)
})
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
// of the ACME spec.
mux.Route("/2.0/"+prefix, func(r chi.Router) {
acmeRouterHandler.Route(r)
acmeHandler.Route(r)
})
/*