diff --git a/acme/account.go b/acme/account.go index e340bfa8..eeac09b9 100644 --- a/acme/account.go +++ b/acme/account.go @@ -198,17 +198,46 @@ func getAccountByKeyID(db nosql.DB, kid string) (*account, error) { // getOrderIDsByAccount retrieves a list of Order IDs that were created by the // account. -func getOrderIDsByAccount(db nosql.DB, id string) ([]string, error) { - b, err := db.Get(ordersByAccountIDTable, []byte(id)) +func getOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) { + b, err := db.Get(ordersByAccountIDTable, []byte(accID)) if err != nil { if nosql.IsErrNotFound(err) { return []string{}, nil } - return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", id)) + return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) } - var orderIDs []string - if err := json.Unmarshal(b, &orderIDs); err != nil { - return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", id)) + var oids []string + if err := json.Unmarshal(b, &oids); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) } - return orderIDs, nil + + // Remove any order that is not in PENDING state and update the stored list + // before returning. + // + // According to RFC 8555: + // The server SHOULD include pending orders and SHOULD NOT include orders + // that are invalid in the array of URLs. + pendOids := []string{} + for _, oid := range oids { + o, err := getOrder(db, oid) + if err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) + } + if o, err = o.updateStatus(db); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) + } + if o.Status == StatusPending { + pendOids = append(pendOids, oid) + } + } + // If the number of pending orders is less than the number of orders in the + // list, then update the pending order list. + if len(pendOids) != len(oids) { + if err = orderIDs(pendOids).save(db, oids, accID); err != nil { + return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ + "len(orderIDs) = %d", len(pendOids))) + } + } + + return pendOids, nil } diff --git a/acme/account_test.go b/acme/account_test.go index 91327080..f5993d70 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -251,7 +251,7 @@ func TestGetAccountByKeyID(t *testing.T) { } } -func TestGetAccountIDsByAccount(t *testing.T) { +func Test_getOrderIDsByAccount(t *testing.T) { type test struct { id string db nosql.DB @@ -294,22 +294,236 @@ func TestGetAccountIDsByAccount(t *testing.T) { err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")), } }, - "ok": func(t *testing.T) test { - oids := []string{"foo", "bar", "baz"} - b, err := json.Marshal(oids) + "fail/error-loading-order-from-order-IDs": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) assert.FatalError(t, err) + dbHit := 0 return test{ id: "foo", db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return b, nil + dbHit++ + switch dbHit { + case 1: + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return boids, nil + case 2: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("o1")) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.New("should not be here")) + return nil, nil + } + }, + }, + err: ServerInternalErr(errors.New("error loading order o1 for account foo: error loading order o1: force")), + } + }, + "fail/error-updating-order-from-order-IDs": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + dbHit := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + dbHit++ + switch dbHit { + case 1: + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return boids, nil + case 2: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("o1")) + return bo, nil + case 3: + assert.Equals(t, bucket, authzTable) + assert.Equals(t, key, []byte(o.Authorizations[0])) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.New("should not be here")) + return nil, nil + } + }, + }, + err: ServerInternalErr(errors.Errorf("error updating order o1 for account foo: error loading authz %s: force", o.Authorizations[0])), + } + }, + "ok/no-change-to-pending-orders": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + az, err := newAz() + assert.FatalError(t, err) + baz, err := json.Marshal(az) + assert.FatalError(t, err) + + ch, err := newDNSCh() + assert.FatalError(t, err) + bch, err := json.Marshal(ch) + assert.FatalError(t, err) + + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + return bo, nil + case string(authzTable): + return baz, nil + case string(challengeTable): + return bch, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("should not be attempting to store anything") }, }, res: oids, } }, + "fail/error-storing-new-oids": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + invalidOrder, err := newO() + assert.FatalError(t, err) + invalidOrder.Status = StatusInvalid + binvalidOrder, err := json.Marshal(invalidOrder) + assert.FatalError(t, err) + + az, err := newAz() + assert.FatalError(t, err) + baz, err := json.Marshal(az) + assert.FatalError(t, err) + + ch, err := newDNSCh() + assert.FatalError(t, err) + bch, err := json.Marshal(ch) + assert.FatalError(t, err) + + dbGetOrder := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + dbGetOrder++ + if dbGetOrder == 1 { + return binvalidOrder, nil + } + return bo, nil + case string(authzTable): + return baz, nil + case string(challengeTable): + return bch, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error storing orderIDs as part of getOrderIDsByAccount logic: len(orderIDs) = 2: error storing order IDs for account foo: force")), + } + }, + "ok": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3", "o4"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + invalidOrder, err := newO() + assert.FatalError(t, err) + invalidOrder.Status = StatusInvalid + binvalidOrder, err := json.Marshal(invalidOrder) + assert.FatalError(t, err) + + az, err := newAz() + assert.FatalError(t, err) + baz, err := json.Marshal(az) + assert.FatalError(t, err) + + ch, err := newDNSCh() + assert.FatalError(t, err) + bch, err := json.Marshal(ch) + assert.FatalError(t, err) + + dbGetOrder := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + dbGetOrder++ + if dbGetOrder == 1 || dbGetOrder == 3 { + return binvalidOrder, nil + } + return bo, nil + case string(authzTable): + return baz, nil + case string(challengeTable): + return bch, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return nil, true, nil + }, + }, + res: []string{"o2", "o4"}, + } + }, } for name, run := range tests { t.Run(name, func(t *testing.T) { diff --git a/acme/authority.go b/acme/authority.go index 66bc1e00..e37835f6 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -240,14 +240,7 @@ func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string var ret = []string{} for _, oid := range oids { - o, err := getOrder(a.db, oid) - if err != nil { - return nil, ServerInternalErr(err) - } - if o.Status == StatusInvalid { - continue - } - ret = append(ret, a.dir.getLink(ctx, OrderLink, true, o.ID)) + ret = append(ret, a.dir.getLink(ctx, OrderLink, true, oid)) } return ret, nil } diff --git a/acme/authority_test.go b/acme/authority_test.go index e11b91db..19b42cb6 100644 --- a/acme/authority_test.go +++ b/acme/authority_test.go @@ -1092,56 +1092,78 @@ func TestAuthorityGetOrdersByAccount(t *testing.T) { return test{ auth: auth, id: id, - err: ServerInternalErr(errors.New("error loading order foo: force")), + err: ServerInternalErr(errors.New("error loading order foo for account zap: error loading order foo: force")), } }, "ok": func(t *testing.T) test { - var ( - id = "zap" - count = 0 - err error - ) - foo, err := newO() - bar, err := newO() - baz, err := newO() - bar.Status = StatusInvalid + accID := "zap" + foo, err := newO() + assert.FatalError(t, err) + bfoo, err := json.Marshal(foo) + assert.FatalError(t, err) + + bar, err := newO() + assert.FatalError(t, err) + bar.Status = StatusInvalid + bbar, err := json.Marshal(bar) + assert.FatalError(t, err) + + zap, err := newO() + assert.FatalError(t, err) + bzap, err := json.Marshal(zap) + assert.FatalError(t, err) + + az, err := newAz() + assert.FatalError(t, err) + baz, err := json.Marshal(az) + assert.FatalError(t, err) + + ch, err := newDNSCh() + assert.FatalError(t, err) + bch, err := json.Marshal(ch) + assert.FatalError(t, err) + + dbGetOrder := 0 auth, err := NewAuthority(&db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: + switch string(bucket) { + case string(orderTable): + dbGetOrder++ + switch dbGetOrder { + case 1: + return bfoo, nil + case 2: + return bbar, nil + case 3: + return bzap, nil + } + case string(ordersByAccountIDTable): assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte(id)) - ret, err = json.Marshal([]string{foo.ID, bar.ID, baz.ID}) - assert.FatalError(t, err) - case 1: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(foo.ID)) - ret, err = json.Marshal(foo) - assert.FatalError(t, err) - case 2: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(bar.ID)) - ret, err = json.Marshal(bar) - assert.FatalError(t, err) - case 3: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte(baz.ID)) - ret, err = json.Marshal(baz) + assert.Equals(t, key, []byte(accID)) + ret, err := json.Marshal([]string{foo.ID, bar.ID, zap.ID}) assert.FatalError(t, err) + return ret, nil + case string(challengeTable): + return bch, nil + case string(authzTable): + return baz, nil } - count++ - return ret, nil + return nil, errors.Errorf("should not be query db table %s", bucket) + }, + MCmpAndSwap: func(bucket, key, old, newVal []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, string(key), accID) + return nil, true, nil }, }, "ca.smallstep.com", "acme", nil) assert.FatalError(t, err) return test{ auth: auth, - id: id, + id: accID, res: []string{ fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, foo.ID), - fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, baz.ID), + fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, zap.ID), }, } }, diff --git a/acme/order.go b/acme/order.go index 839af337..e5b410af 100644 --- a/acme/order.go +++ b/acme/order.go @@ -230,10 +230,15 @@ func (o *order) updateStatus(db nosql.DB) (*order, error) { switch { case count[StatusInvalid] > 0: newOrder.Status = StatusInvalid + + // No change in the order status, so just return the order as is - + // without writing any changes. case count[StatusPending] > 0: - break + return newOrder, nil + case count[StatusValid] == len(o.Authorizations): newOrder.Status = StatusReady + default: return nil, ServerInternalErr(errors.New("unexpected authz status")) } diff --git a/acme/order_test.go b/acme/order_test.go index 86a4eb32..07caa50f 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -38,17 +38,13 @@ func newO() (*order, error) { return []byte("foo"), true, nil }, MGet: func(bucket, key []byte) ([]byte, error) { - b, err := json.Marshal([]string{"1", "2"}) - if err != nil { - return nil, err - } - return b, nil + return nil, database.ErrNotFound }, } return newOrder(mockdb, defaultOrderOps()) } -func TestGetOrder(t *testing.T) { +func Test_getOrder(t *testing.T) { type test struct { id string db nosql.DB @@ -363,9 +359,6 @@ func Test_newOrder(t *testing.T) { }, "fail/save-orderIDs-error": func(t *testing.T) test { count := 0 - oids := []string{"1", "2", "3"} - oidsB, err := json.Marshal(oids) - assert.FatalError(t, err) var ( _oid = "" oid = &_oid @@ -386,7 +379,7 @@ func Test_newOrder(t *testing.T) { return nil, true, nil }, MGet: func(bucket, key []byte) ([]byte, error) { - return oidsB, nil + return nil, database.ErrNotFound }, MDel: func(bucket, key []byte) error { assert.Equals(t, bucket, orderTable) @@ -399,9 +392,6 @@ func Test_newOrder(t *testing.T) { }, "ok": func(t *testing.T) test { count := 0 - oids := []string{"1", "2", "3"} - oidsB, err := json.Marshal(oids) - assert.FatalError(t, err) authzs := &([]string{}) var ( _oid = "" @@ -415,8 +405,8 @@ func Test_newOrder(t *testing.T) { if count >= 9 { assert.Equals(t, bucket, ordersByAccountIDTable) assert.Equals(t, key, []byte(ops.AccountID)) - assert.Equals(t, old, oidsB) - newB, err := json.Marshal(append(oids, *oid)) + assert.Equals(t, old, nil) + newB, err := json.Marshal([]string{*oid}) assert.FatalError(t, err) assert.Equals(t, newval, newB) } else if count == 8 { @@ -430,7 +420,7 @@ func Test_newOrder(t *testing.T) { return nil, true, nil }, MGet: func(bucket, key []byte) ([]byte, error) { - return oidsB, nil + return nil, database.ErrNotFound }, }, authzs: authzs, @@ -438,9 +428,6 @@ func Test_newOrder(t *testing.T) { }, "ok/validity-bounds-not-set": func(t *testing.T) test { count := 0 - oids := []string{"1", "2", "3"} - oidsB, err := json.Marshal(oids) - assert.FatalError(t, err) authzs := &([]string{}) var ( _oid = "" @@ -458,8 +445,8 @@ func Test_newOrder(t *testing.T) { if count >= 9 { assert.Equals(t, bucket, ordersByAccountIDTable) assert.Equals(t, key, []byte(ops.AccountID)) - assert.Equals(t, old, oidsB) - newB, err := json.Marshal(append(oids, *oid)) + assert.Equals(t, old, nil) + newB, err := json.Marshal([]string{*oid}) assert.FatalError(t, err) assert.Equals(t, newval, newB) } else if count == 8 { @@ -473,7 +460,7 @@ func Test_newOrder(t *testing.T) { return nil, true, nil }, MGet: func(bucket, key []byte) ([]byte, error) { - return oidsB, nil + return nil, database.ErrNotFound }, }, authzs: authzs,