loadOrStore -> cmpAndSwap

This commit is contained in:
max furman 2019-06-10 13:21:06 -07:00
parent 578beec25d
commit 599fc1058c
3 changed files with 22 additions and 22 deletions

4
Gopkg.lock generated
View file

@ -363,7 +363,7 @@
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:5e778214d472b6d2ad4d544d293d1478d9b222db8ffc6079623fbe3e58e1841e" digest = "1:9c1b7052fa8f2c918efd60ed5ae3c70ccbba08967c58ec71067535449a3ba220"
name = "github.com/smallstep/nosql" name = "github.com/smallstep/nosql"
packages = [ packages = [
".", ".",
@ -373,7 +373,7 @@
"mysql", "mysql",
] ]
pruneopts = "UT" pruneopts = "UT"
revision = "b66b34823456721912ba037126e92414690c07d6" revision = "a0934e12468769d8cbede3ed316c47a4b88de4ca"
[[projects]] [[projects]]
branch = "master" branch = "master"

View file

@ -131,14 +131,12 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
// UseToken returns true if we were able to successfully store the token for // UseToken returns true if we were able to successfully store the token for
// for the first time, false otherwise. // for the first time, false otherwise.
func (db *DB) UseToken(id, tok string) (bool, error) { func (db *DB) UseToken(id, tok string) (bool, error) {
// If the error is `Not Found` then the certificate has not been revoked. _, swapped, err := db.CmpAndSwap(usedOTTTable, []byte(id), nil, []byte(tok))
// Any other error should be propagated to the caller.
_, found, err := db.LoadOrStore(usedOTTTable, []byte(id), []byte(tok))
switch { switch {
case err != nil: case err != nil:
return false, errors.Wrapf(err, "error LoadOrStore-ing token %s/%s", return false, errors.Wrapf(err, "error storing used token %s/%s",
string(usedOTTTable), id) string(usedOTTTable), id)
case found: case !swapped:
return false, nil return false, nil
default: default:
return true, nil return true, nil

View file

@ -20,12 +20,12 @@ type MockNoSQLDB struct {
del func(bucket, key []byte) error del func(bucket, key []byte) error
list func(bucket []byte) ([]*database.Entry, error) list func(bucket []byte) ([]*database.Entry, error)
update func(tx *database.Tx) error update func(tx *database.Tx) error
loadOrStore func(bucket, key, value []byte) ([]byte, bool, error) cmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error)
} }
func (m *MockNoSQLDB) LoadOrStore(bucket, key, value []byte) ([]byte, bool, error) { func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) {
if m.get != nil { if m.cmpAndSwap != nil {
return m.loadOrStore(bucket, key, value) return m.cmpAndSwap(bucket, key, old, newval)
} }
if m.ret1 == nil { if m.ret1 == nil {
return nil, false, m.err return nil, false, m.err
@ -210,37 +210,37 @@ func TestUseToken(t *testing.T) {
db *DB db *DB
want result want result
}{ }{
"fail/force-LoadOrStore-error": { "fail/force-CmpAndSwap-error": {
id: "id", id: "id",
tok: "token", tok: "token",
db: &DB{&MockNoSQLDB{ db: &DB{&MockNoSQLDB{
loadOrStore: func(bucket, key, value []byte) ([]byte, bool, error) { cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force") return nil, false, errors.New("force")
}, },
}, true}, }, true},
want: result{ want: result{
ok: false, ok: false,
err: errors.New("error LoadOrStore-ing token id/token"), err: errors.New("error storing used token used_ott/id"),
}, },
}, },
"fail/LoadOrStore-found": { "fail/CmpAndSwap-already-exists": {
id: "id", id: "id",
tok: "token", tok: "token",
db: &DB{&MockNoSQLDB{ db: &DB{&MockNoSQLDB{
loadOrStore: func(bucket, key, value []byte) ([]byte, bool, error) { cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return []byte("foo"), true, nil return []byte("foo"), false, nil
}, },
}, true}, }, true},
want: result{ want: result{
ok: false, ok: false,
}, },
}, },
"ok/LoadOrStore-not-found": { "ok/cmpAndSwap-success": {
id: "id", id: "id",
tok: "token", tok: "token",
db: &DB{&MockNoSQLDB{ db: &DB{&MockNoSQLDB{
loadOrStore: func(bucket, key, value []byte) ([]byte, bool, error) { cmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, nil return []byte("bar"), true, nil
}, },
}, true}, }, true},
want: result{ want: result{
@ -253,11 +253,13 @@ func TestUseToken(t *testing.T) {
ok, err := tc.db.UseToken(tc.id, tc.tok) ok, err := tc.db.UseToken(tc.id, tc.tok)
if err != nil { if err != nil {
if assert.NotNil(t, tc.want.err) { if assert.NotNil(t, tc.want.err) {
assert.HasPrefix(t, tc.want.err.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tc.want.err.Error())
} }
assert.False(t, ok) assert.False(t, ok)
} else if ok {
assert.True(t, tc.want.ok)
} else { } else {
assert.True(t, ok) assert.False(t, tc.want.ok)
} }
}) })
} }