Merge pull request #281 from smallstep/max/acmeOrders

Only retain `pending' orders in the `acme_account_orders_index`
This commit is contained in:
Max 2020-06-01 13:16:05 -07:00 committed by GitHub
commit 619f6f6ce0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 329 additions and 79 deletions

View file

@ -198,17 +198,46 @@ func getAccountByKeyID(db nosql.DB, kid string) (*account, error) {
// getOrderIDsByAccount retrieves a list of Order IDs that were created by the // getOrderIDsByAccount retrieves a list of Order IDs that were created by the
// account. // account.
func getOrderIDsByAccount(db nosql.DB, id string) ([]string, error) { func getOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) {
b, err := db.Get(ordersByAccountIDTable, []byte(id)) b, err := db.Get(ordersByAccountIDTable, []byte(accID))
if err != nil { if err != nil {
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return []string{}, nil return []string{}, nil
} }
return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", id)) return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID))
} }
var orderIDs []string var oids []string
if err := json.Unmarshal(b, &orderIDs); err != nil { if err := json.Unmarshal(b, &oids); err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", id)) return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID))
} }
return orderIDs, nil
// Remove any order that is not in PENDING state and update the stored list
// before returning.
//
// According to RFC 8555:
// The server SHOULD include pending orders and SHOULD NOT include orders
// that are invalid in the array of URLs.
pendOids := []string{}
for _, oid := range oids {
o, err := getOrder(db, oid)
if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID))
}
if o, err = o.updateStatus(db); err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID))
}
if o.Status == StatusPending {
pendOids = append(pendOids, oid)
}
}
// If the number of pending orders is less than the number of orders in the
// list, then update the pending order list.
if len(pendOids) != len(oids) {
if err = orderIDs(pendOids).save(db, oids, accID); err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+
"len(orderIDs) = %d", len(pendOids)))
}
}
return pendOids, nil
} }

View file

@ -251,7 +251,7 @@ func TestGetAccountByKeyID(t *testing.T) {
} }
} }
func TestGetAccountIDsByAccount(t *testing.T) { func Test_getOrderIDsByAccount(t *testing.T) {
type test struct { type test struct {
id string id string
db nosql.DB db nosql.DB
@ -294,22 +294,236 @@ func TestGetAccountIDsByAccount(t *testing.T) {
err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")), err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")),
} }
}, },
"ok": func(t *testing.T) test { "fail/error-loading-order-from-order-IDs": func(t *testing.T) test {
oids := []string{"foo", "bar", "baz"} oids := []string{"o1", "o2", "o3"}
b, err := json.Marshal(oids) boids, err := json.Marshal(oids)
assert.FatalError(t, err) assert.FatalError(t, err)
dbHit := 0
return test{ return test{
id: "foo", id: "foo",
db: &db.MockNoSQLDB{ db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) { MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, ordersByAccountIDTable) dbHit++
assert.Equals(t, key, []byte("foo")) switch dbHit {
return b, nil case 1:
assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte("foo"))
return boids, nil
case 2:
assert.Equals(t, bucket, orderTable)
assert.Equals(t, key, []byte("o1"))
return nil, errors.New("force")
default:
assert.FatalError(t, errors.New("should not be here"))
return nil, nil
}
},
},
err: ServerInternalErr(errors.New("error loading order o1 for account foo: error loading order o1: force")),
}
},
"fail/error-updating-order-from-order-IDs": func(t *testing.T) test {
oids := []string{"o1", "o2", "o3"}
boids, err := json.Marshal(oids)
assert.FatalError(t, err)
o, err := newO()
assert.FatalError(t, err)
bo, err := json.Marshal(o)
assert.FatalError(t, err)
dbHit := 0
return test{
id: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
dbHit++
switch dbHit {
case 1:
assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte("foo"))
return boids, nil
case 2:
assert.Equals(t, bucket, orderTable)
assert.Equals(t, key, []byte("o1"))
return bo, nil
case 3:
assert.Equals(t, bucket, authzTable)
assert.Equals(t, key, []byte(o.Authorizations[0]))
return nil, errors.New("force")
default:
assert.FatalError(t, errors.New("should not be here"))
return nil, nil
}
},
},
err: ServerInternalErr(errors.Errorf("error updating order o1 for account foo: error loading authz %s: force", o.Authorizations[0])),
}
},
"ok/no-change-to-pending-orders": func(t *testing.T) test {
oids := []string{"o1", "o2", "o3"}
boids, err := json.Marshal(oids)
assert.FatalError(t, err)
o, err := newO()
assert.FatalError(t, err)
bo, err := json.Marshal(o)
assert.FatalError(t, err)
az, err := newAz()
assert.FatalError(t, err)
baz, err := json.Marshal(az)
assert.FatalError(t, err)
ch, err := newDNSCh()
assert.FatalError(t, err)
bch, err := json.Marshal(ch)
assert.FatalError(t, err)
return test{
id: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
switch string(bucket) {
case string(ordersByAccountIDTable):
assert.Equals(t, key, []byte("foo"))
return boids, nil
case string(orderTable):
return bo, nil
case string(authzTable):
return baz, nil
case string(challengeTable):
return bch, nil
default:
assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket))
return nil, nil
}
},
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("should not be attempting to store anything")
}, },
}, },
res: oids, res: oids,
} }
}, },
"fail/error-storing-new-oids": func(t *testing.T) test {
oids := []string{"o1", "o2", "o3"}
boids, err := json.Marshal(oids)
assert.FatalError(t, err)
o, err := newO()
assert.FatalError(t, err)
bo, err := json.Marshal(o)
assert.FatalError(t, err)
invalidOrder, err := newO()
assert.FatalError(t, err)
invalidOrder.Status = StatusInvalid
binvalidOrder, err := json.Marshal(invalidOrder)
assert.FatalError(t, err)
az, err := newAz()
assert.FatalError(t, err)
baz, err := json.Marshal(az)
assert.FatalError(t, err)
ch, err := newDNSCh()
assert.FatalError(t, err)
bch, err := json.Marshal(ch)
assert.FatalError(t, err)
dbGetOrder := 0
return test{
id: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
switch string(bucket) {
case string(ordersByAccountIDTable):
assert.Equals(t, key, []byte("foo"))
return boids, nil
case string(orderTable):
dbGetOrder++
if dbGetOrder == 1 {
return binvalidOrder, nil
}
return bo, nil
case string(authzTable):
return baz, nil
case string(challengeTable):
return bch, nil
default:
assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket))
return nil, nil
}
},
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte("foo"))
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing orderIDs as part of getOrderIDsByAccount logic: len(orderIDs) = 2: error storing order IDs for account foo: force")),
}
},
"ok": func(t *testing.T) test {
oids := []string{"o1", "o2", "o3", "o4"}
boids, err := json.Marshal(oids)
assert.FatalError(t, err)
o, err := newO()
assert.FatalError(t, err)
bo, err := json.Marshal(o)
assert.FatalError(t, err)
invalidOrder, err := newO()
assert.FatalError(t, err)
invalidOrder.Status = StatusInvalid
binvalidOrder, err := json.Marshal(invalidOrder)
assert.FatalError(t, err)
az, err := newAz()
assert.FatalError(t, err)
baz, err := json.Marshal(az)
assert.FatalError(t, err)
ch, err := newDNSCh()
assert.FatalError(t, err)
bch, err := json.Marshal(ch)
assert.FatalError(t, err)
dbGetOrder := 0
return test{
id: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
switch string(bucket) {
case string(ordersByAccountIDTable):
assert.Equals(t, key, []byte("foo"))
return boids, nil
case string(orderTable):
dbGetOrder++
if dbGetOrder == 1 || dbGetOrder == 3 {
return binvalidOrder, nil
}
return bo, nil
case string(authzTable):
return baz, nil
case string(challengeTable):
return bch, nil
default:
assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket))
return nil, nil
}
},
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte("foo"))
return nil, true, nil
},
},
res: []string{"o2", "o4"},
}
},
} }
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {

View file

@ -240,14 +240,7 @@ func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string
var ret = []string{} var ret = []string{}
for _, oid := range oids { for _, oid := range oids {
o, err := getOrder(a.db, oid) ret = append(ret, a.dir.getLink(ctx, OrderLink, true, oid))
if err != nil {
return nil, ServerInternalErr(err)
}
if o.Status == StatusInvalid {
continue
}
ret = append(ret, a.dir.getLink(ctx, OrderLink, true, o.ID))
} }
return ret, nil return ret, nil
} }

View file

@ -1092,56 +1092,78 @@ func TestAuthorityGetOrdersByAccount(t *testing.T) {
return test{ return test{
auth: auth, auth: auth,
id: id, id: id,
err: ServerInternalErr(errors.New("error loading order foo: force")), err: ServerInternalErr(errors.New("error loading order foo for account zap: error loading order foo: force")),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
var ( accID := "zap"
id = "zap"
count = 0
err error
)
foo, err := newO()
bar, err := newO()
baz, err := newO()
bar.Status = StatusInvalid
foo, err := newO()
assert.FatalError(t, err)
bfoo, err := json.Marshal(foo)
assert.FatalError(t, err)
bar, err := newO()
assert.FatalError(t, err)
bar.Status = StatusInvalid
bbar, err := json.Marshal(bar)
assert.FatalError(t, err)
zap, err := newO()
assert.FatalError(t, err)
bzap, err := json.Marshal(zap)
assert.FatalError(t, err)
az, err := newAz()
assert.FatalError(t, err)
baz, err := json.Marshal(az)
assert.FatalError(t, err)
ch, err := newDNSCh()
assert.FatalError(t, err)
bch, err := json.Marshal(ch)
assert.FatalError(t, err)
dbGetOrder := 0
auth, err := NewAuthority(&db.MockNoSQLDB{ auth, err := NewAuthority(&db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) { MGet: func(bucket, key []byte) ([]byte, error) {
var ret []byte switch string(bucket) {
switch count { case string(orderTable):
case 0: dbGetOrder++
switch dbGetOrder {
case 1:
return bfoo, nil
case 2:
return bbar, nil
case 3:
return bzap, nil
}
case string(ordersByAccountIDTable):
assert.Equals(t, bucket, ordersByAccountIDTable) assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte(id)) assert.Equals(t, key, []byte(accID))
ret, err = json.Marshal([]string{foo.ID, bar.ID, baz.ID}) ret, err := json.Marshal([]string{foo.ID, bar.ID, zap.ID})
assert.FatalError(t, err)
case 1:
assert.Equals(t, bucket, orderTable)
assert.Equals(t, key, []byte(foo.ID))
ret, err = json.Marshal(foo)
assert.FatalError(t, err)
case 2:
assert.Equals(t, bucket, orderTable)
assert.Equals(t, key, []byte(bar.ID))
ret, err = json.Marshal(bar)
assert.FatalError(t, err)
case 3:
assert.Equals(t, bucket, orderTable)
assert.Equals(t, key, []byte(baz.ID))
ret, err = json.Marshal(baz)
assert.FatalError(t, err) assert.FatalError(t, err)
return ret, nil
case string(challengeTable):
return bch, nil
case string(authzTable):
return baz, nil
} }
count++ return nil, errors.Errorf("should not be query db table %s", bucket)
return ret, nil },
MCmpAndSwap: func(bucket, key, old, newVal []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, string(key), accID)
return nil, true, nil
}, },
}, "ca.smallstep.com", "acme", nil) }, "ca.smallstep.com", "acme", nil)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
auth: auth, auth: auth,
id: id, id: accID,
res: []string{ res: []string{
fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, foo.ID), fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, foo.ID),
fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, baz.ID), fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, zap.ID),
}, },
} }
}, },

View file

@ -230,10 +230,15 @@ func (o *order) updateStatus(db nosql.DB) (*order, error) {
switch { switch {
case count[StatusInvalid] > 0: case count[StatusInvalid] > 0:
newOrder.Status = StatusInvalid newOrder.Status = StatusInvalid
// No change in the order status, so just return the order as is -
// without writing any changes.
case count[StatusPending] > 0: case count[StatusPending] > 0:
break return newOrder, nil
case count[StatusValid] == len(o.Authorizations): case count[StatusValid] == len(o.Authorizations):
newOrder.Status = StatusReady newOrder.Status = StatusReady
default: default:
return nil, ServerInternalErr(errors.New("unexpected authz status")) return nil, ServerInternalErr(errors.New("unexpected authz status"))
} }

View file

@ -38,17 +38,13 @@ func newO() (*order, error) {
return []byte("foo"), true, nil return []byte("foo"), true, nil
}, },
MGet: func(bucket, key []byte) ([]byte, error) { MGet: func(bucket, key []byte) ([]byte, error) {
b, err := json.Marshal([]string{"1", "2"}) return nil, database.ErrNotFound
if err != nil {
return nil, err
}
return b, nil
}, },
} }
return newOrder(mockdb, defaultOrderOps()) return newOrder(mockdb, defaultOrderOps())
} }
func TestGetOrder(t *testing.T) { func Test_getOrder(t *testing.T) {
type test struct { type test struct {
id string id string
db nosql.DB db nosql.DB
@ -363,9 +359,6 @@ func Test_newOrder(t *testing.T) {
}, },
"fail/save-orderIDs-error": func(t *testing.T) test { "fail/save-orderIDs-error": func(t *testing.T) test {
count := 0 count := 0
oids := []string{"1", "2", "3"}
oidsB, err := json.Marshal(oids)
assert.FatalError(t, err)
var ( var (
_oid = "" _oid = ""
oid = &_oid oid = &_oid
@ -386,7 +379,7 @@ func Test_newOrder(t *testing.T) {
return nil, true, nil return nil, true, nil
}, },
MGet: func(bucket, key []byte) ([]byte, error) { MGet: func(bucket, key []byte) ([]byte, error) {
return oidsB, nil return nil, database.ErrNotFound
}, },
MDel: func(bucket, key []byte) error { MDel: func(bucket, key []byte) error {
assert.Equals(t, bucket, orderTable) assert.Equals(t, bucket, orderTable)
@ -399,9 +392,6 @@ func Test_newOrder(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
count := 0 count := 0
oids := []string{"1", "2", "3"}
oidsB, err := json.Marshal(oids)
assert.FatalError(t, err)
authzs := &([]string{}) authzs := &([]string{})
var ( var (
_oid = "" _oid = ""
@ -415,8 +405,8 @@ func Test_newOrder(t *testing.T) {
if count >= 9 { if count >= 9 {
assert.Equals(t, bucket, ordersByAccountIDTable) assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte(ops.AccountID)) assert.Equals(t, key, []byte(ops.AccountID))
assert.Equals(t, old, oidsB) assert.Equals(t, old, nil)
newB, err := json.Marshal(append(oids, *oid)) newB, err := json.Marshal([]string{*oid})
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, newval, newB) assert.Equals(t, newval, newB)
} else if count == 8 { } else if count == 8 {
@ -430,7 +420,7 @@ func Test_newOrder(t *testing.T) {
return nil, true, nil return nil, true, nil
}, },
MGet: func(bucket, key []byte) ([]byte, error) { MGet: func(bucket, key []byte) ([]byte, error) {
return oidsB, nil return nil, database.ErrNotFound
}, },
}, },
authzs: authzs, authzs: authzs,
@ -438,9 +428,6 @@ func Test_newOrder(t *testing.T) {
}, },
"ok/validity-bounds-not-set": func(t *testing.T) test { "ok/validity-bounds-not-set": func(t *testing.T) test {
count := 0 count := 0
oids := []string{"1", "2", "3"}
oidsB, err := json.Marshal(oids)
assert.FatalError(t, err)
authzs := &([]string{}) authzs := &([]string{})
var ( var (
_oid = "" _oid = ""
@ -458,8 +445,8 @@ func Test_newOrder(t *testing.T) {
if count >= 9 { if count >= 9 {
assert.Equals(t, bucket, ordersByAccountIDTable) assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte(ops.AccountID)) assert.Equals(t, key, []byte(ops.AccountID))
assert.Equals(t, old, oidsB) assert.Equals(t, old, nil)
newB, err := json.Marshal(append(oids, *oid)) newB, err := json.Marshal([]string{*oid})
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, newval, newB) assert.Equals(t, newval, newB)
} else if count == 8 { } else if count == 8 {
@ -473,7 +460,7 @@ func Test_newOrder(t *testing.T) {
return nil, true, nil return nil, true, nil
}, },
MGet: func(bucket, key []byte) ([]byte, error) { MGet: func(bucket, key []byte) ([]byte, error) {
return oidsB, nil return nil, database.ErrNotFound
}, },
}, },
authzs: authzs, authzs: authzs,