diff --git a/acme/account.go b/acme/account.go index ea0e7fdc..1c5870d5 100644 --- a/acme/account.go +++ b/acme/account.go @@ -195,49 +195,3 @@ func getAccountByKeyID(db nosql.DB, kid string) (*account, error) { } return getAccountByID(db, string(id)) } - -// getOrderIDsByAccount retrieves a list of Order IDs that were created by the -// account. -func getOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) { - b, err := db.Get(ordersByAccountIDTable, []byte(accID)) - if err != nil { - if nosql.IsErrNotFound(err) { - return []string{}, nil - } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) - } - var oids []string - if err := json.Unmarshal(b, &oids); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) - } - - // Remove any order that is not in PENDING state and update the stored list - // before returning. - // - // According to RFC 8555: - // The server SHOULD include pending orders and SHOULD NOT include orders - // that are invalid in the array of URLs. - pendOids := []string{} - for _, oid := range oids { - o, err := getOrder(db, oid) - if err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) - } - if o, err = o.updateStatus(db); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) - } - if o.Status == StatusPending { - pendOids = append(pendOids, oid) - } - } - // If the number of pending orders is less than the number of orders in the - // list, then update the pending order list. - if len(pendOids) != len(oids) { - if err = orderIDs(pendOids).save(db, oids, accID); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ - "len(orderIDs) = %d", len(pendOids))) - } - } - - return pendOids, nil -} diff --git a/acme/account_test.go b/acme/account_test.go index 0008551a..2e072af5 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -251,337 +251,6 @@ func TestGetAccountByKeyID(t *testing.T) { } } -func Test_getOrderIDsByAccount(t *testing.T) { - type test struct { - id string - db nosql.DB - res []string - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/not-found": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - res: []string{}, - } - }, - "fail/db-error": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, nil - }, - }, - err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")), - } - }, - "fail/error-loading-order-from-order-IDs": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - dbHit := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - dbHit++ - switch dbHit { - case 1: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return boids, nil - case 2: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte("o1")) - return nil, errors.New("force") - default: - assert.FatalError(t, errors.New("should not be here")) - return nil, nil - } - }, - }, - err: ServerInternalErr(errors.New("error loading order o1 for account foo: error loading order o1: force")), - } - }, - "fail/error-updating-order-from-order-IDs": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - dbHit := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - dbHit++ - switch dbHit { - case 1: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return boids, nil - case 2: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte("o1")) - return bo, nil - case 3: - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(o.Authorizations[0])) - return nil, errors.New("force") - default: - assert.FatalError(t, errors.New("should not be here")) - return nil, nil - } - }, - }, - err: ServerInternalErr(errors.Errorf("error updating order o1 for account foo: error loading authz %s: force", o.Authorizations[0])), - } - }, - "ok/no-change-to-pending-orders": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("should not be attempting to store anything") - }, - }, - res: oids, - } - }, - "fail/error-storing-new-oids": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - dbGetOrder := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - dbGetOrder++ - if dbGetOrder == 1 { - return binvalidOrder, nil - } - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing orderIDs as part of getOrderIDsByAccount logic: len(orderIDs) = 2: error storing order IDs for account foo: force")), - } - }, - "ok": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3", "o4"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - dbGetOrder := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - dbGetOrder++ - if dbGetOrder == 1 || dbGetOrder == 3 { - return binvalidOrder, nil - } - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, true, nil - }, - }, - res: []string{"o2", "o4"}, - } - }, - "ok/no-pending-orders": func(t *testing.T) test { - oids := []string{"o1"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - return binvalidOrder, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - assert.Equals(t, old, boids) - assert.Nil(t, newval) - return nil, true, nil - }, - }, - res: []string{}, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if oids, err := getOrderIDsByAccount(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res, oids) - } - } - }) - } -} - func TestAccountToACME(t *testing.T) { dir := newDirectory("ca.smallstep.com", "acme") prov := newProv() diff --git a/acme/authority.go b/acme/authority.go index 959dc9c4..0f5f2c9f 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -233,7 +233,11 @@ func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order // GetOrdersByAccount returns the list of order urls owned by the account. func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { - oids, err := getOrderIDsByAccount(a.db, id) + ordersByAccountMux.Lock() + defer ordersByAccountMux.Unlock() + + var oiba = orderIDsByAccount{} + oids, err := oiba.unsafeGetOrderIDsByAccount(a.db, id) if err != nil { return nil, err } diff --git a/acme/order.go b/acme/order.go index 57168419..574477ca 100644 --- a/acme/order.go +++ b/acme/order.go @@ -6,6 +6,7 @@ import ( "encoding/json" "sort" "strings" + "sync" "time" "github.com/pkg/errors" @@ -16,6 +17,9 @@ import ( var defaultOrderExpiry = time.Hour * 24 +// Mutex for locking ordersByAccount index operations. +var ordersByAccountMux sync.Mutex + // Order contains order metadata for the ACME protocol order type. type Order struct { Status string `json:"status"` @@ -111,17 +115,81 @@ func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { return nil, err } - // Update the "order IDs by account ID" index // - oids, err := getOrderIDsByAccount(db, ops.AccountID) + var oidHelper = orderIDsByAccount{} + _, err = oidHelper.addOrderID(db, ops.AccountID, o.ID) if err != nil { return nil, err } - newOids := append(oids, o.ID) - if err = orderIDs(newOids).save(db, oids, o.AccountID); err != nil { - db.Del(orderTable, []byte(o.ID)) + return o, nil +} + +type orderIDsByAccount struct{} + +// addOrderID adds an order ID to a users index of in progress order IDs. +// This method will also cull any orders that are no longer in the `pending` +// state from the index before returning it. +func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) ([]string, error) { + ordersByAccountMux.Lock() + defer ordersByAccountMux.Unlock() + + // Update the "order IDs by account ID" index + oids, err := oiba.unsafeGetOrderIDsByAccount(db, accID) + if err != nil { return nil, err } - return o, nil + newOids := append(oids, oid) + if err = orderIDs(newOids).save(db, oids, accID); err != nil { + // Delete the entire order if storing the index fails. + db.Del(orderTable, []byte(oid)) + return nil, err + } + return newOids, nil +} + +// unsafeGetOrderIDsByAccount retrieves a list of Order IDs that were created by the +// account. +func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) { + b, err := db.Get(ordersByAccountIDTable, []byte(accID)) + if err != nil { + if nosql.IsErrNotFound(err) { + return []string{}, nil + } + return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) + } + var oids []string + if err := json.Unmarshal(b, &oids); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) + } + + // Remove any order that is not in PENDING state and update the stored list + // before returning. + // + // According to RFC 8555: + // The server SHOULD include pending orders and SHOULD NOT include orders + // that are invalid in the array of URLs. + pendOids := []string{} + for _, oid := range oids { + o, err := getOrder(db, oid) + if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) + } + if o, err = o.updateStatus(db); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) + } + if o.Status == StatusPending { + pendOids = append(pendOids, oid) + } + } + // If the number of pending orders is less than the number of orders in the + // list, then update the pending order list. + if len(pendOids) != len(oids) { + if err = orderIDs(pendOids).save(db, oids, accID); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ + "len(orderIDs) = %d", len(pendOids))) + } + } + + return pendOids, nil } type orderIDs []string @@ -271,6 +339,7 @@ func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAut if o, err = o.updateStatus(db); err != nil { return nil, err } + switch o.Status { case StatusInvalid: return nil, OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)) diff --git a/acme/order_test.go b/acme/order_test.go index 785b24c4..e6a8f057 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -1403,3 +1403,335 @@ func TestOrderFinalize(t *testing.T) { }) } } + +func Test_getOrderIDsByAccount(t *testing.T) { + type test struct { + id string + db nosql.DB + res []string + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok/not-found": func(t *testing.T) test { + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, + res: []string{}, + } + }, + "fail/db-error": func(t *testing.T) test { + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return nil, nil + }, + }, + err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")), + } + }, + "fail/error-loading-order-from-order-IDs": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + dbHit := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + dbHit++ + switch dbHit { + case 1: + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return boids, nil + case 2: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("o1")) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.New("should not be here")) + return nil, nil + } + }, + }, + err: ServerInternalErr(errors.New("error loading order o1 for account foo: error loading order o1: force")), + } + }, + "fail/error-updating-order-from-order-IDs": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + dbHit := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + dbHit++ + switch dbHit { + case 1: + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return boids, nil + case 2: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("o1")) + return bo, nil + case 3: + assert.Equals(t, bucket, authzTable) + assert.Equals(t, key, []byte(o.Authorizations[0])) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.New("should not be here")) + return nil, nil + } + }, + }, + err: ServerInternalErr(errors.Errorf("error updating order o1 for account foo: error loading authz %s: force", o.Authorizations[0])), + } + }, + "ok/no-change-to-pending-orders": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + az, err := newAz() + assert.FatalError(t, err) + baz, err := json.Marshal(az) + assert.FatalError(t, err) + + ch, err := newDNSCh() + assert.FatalError(t, err) + bch, err := json.Marshal(ch) + assert.FatalError(t, err) + + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + return bo, nil + case string(authzTable): + return baz, nil + case string(challengeTable): + return bch, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("should not be attempting to store anything") + }, + }, + res: oids, + } + }, + "fail/error-storing-new-oids": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + invalidOrder, err := newO() + assert.FatalError(t, err) + invalidOrder.Status = StatusInvalid + binvalidOrder, err := json.Marshal(invalidOrder) + assert.FatalError(t, err) + + az, err := newAz() + assert.FatalError(t, err) + baz, err := json.Marshal(az) + assert.FatalError(t, err) + + ch, err := newDNSCh() + assert.FatalError(t, err) + bch, err := json.Marshal(ch) + assert.FatalError(t, err) + + dbGetOrder := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + dbGetOrder++ + if dbGetOrder == 1 { + return binvalidOrder, nil + } + return bo, nil + case string(authzTable): + return baz, nil + case string(challengeTable): + return bch, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error storing orderIDs as part of getOrderIDsByAccount logic: len(orderIDs) = 2: error storing order IDs for account foo: force")), + } + }, + "ok": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3", "o4"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + invalidOrder, err := newO() + assert.FatalError(t, err) + invalidOrder.Status = StatusInvalid + binvalidOrder, err := json.Marshal(invalidOrder) + assert.FatalError(t, err) + + az, err := newAz() + assert.FatalError(t, err) + baz, err := json.Marshal(az) + assert.FatalError(t, err) + + ch, err := newDNSCh() + assert.FatalError(t, err) + bch, err := json.Marshal(ch) + assert.FatalError(t, err) + + dbGetOrder := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + dbGetOrder++ + if dbGetOrder == 1 || dbGetOrder == 3 { + return binvalidOrder, nil + } + return bo, nil + case string(authzTable): + return baz, nil + case string(challengeTable): + return bch, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return nil, true, nil + }, + }, + res: []string{"o2", "o4"}, + } + }, + "ok/no-pending-orders": func(t *testing.T) test { + oids := []string{"o1"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + invalidOrder, err := newO() + assert.FatalError(t, err) + invalidOrder.Status = StatusInvalid + binvalidOrder, err := json.Marshal(invalidOrder) + assert.FatalError(t, err) + + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + return binvalidOrder, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + assert.Equals(t, old, boids) + assert.Nil(t, newval) + return nil, true, nil + }, + }, + res: []string{}, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + var oiba = orderIDsByAccount{} + if oids, err := oiba.unsafeGetOrderIDsByAccount(tc.db, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.res, oids) + } + } + }) + } +}