lib/kv: add unit tests, fix races #5587

After testing concurrent calling of `kv.Start` and `db.Stop` I had to restrict
more parts of these under mutex to make results deterministic without Sleep's
in the test body. It's more safe but has potential to lock Start for up to
2 seconds due to `db.open`.
This commit is contained in:
Ivan Andreev 2021-10-13 23:13:27 +03:00
parent 50df8cec9c
commit 57c7fde864
2 changed files with 112 additions and 38 deletions

View file

@ -44,7 +44,8 @@ type DB struct {
var ( var (
dbMap = map[string]*DB{} dbMap = map[string]*DB{}
dbMut = sync.Mutex{} dbMut sync.Mutex
atExit bool
) )
// Supported returns true on supported OSes // Supported returns true on supported OSes
@ -66,7 +67,9 @@ func makeName(facility string, f fs.Fs) string {
// Start a new key-value database // Start a new key-value database
func Start(ctx context.Context, facility string, f fs.Fs) (*DB, error) { func Start(ctx context.Context, facility string, f fs.Fs) (*DB, error) {
if db := Get(facility, f); db != nil { dbMut.Lock()
defer dbMut.Unlock()
if db := lockedGet(facility, f); db != nil {
return db, nil return db, nil
} }
@ -101,44 +104,29 @@ func Start(ctx context.Context, facility string, f fs.Fs) (*DB, error) {
return nil, errors.Wrapf(err, "cannot open db: %s", db.path) return nil, errors.Wrapf(err, "cannot open db: %s", db.path)
} }
// Initialization above was performed without locks.. dbMap[name] = db
dbMut.Lock() go db.loop()
defer dbMut.Unlock()
if dbOther := dbMap[name]; dbOther != nil {
// Races between concurrent Start's are rare but possible, the 1st one wins.
_ = db.close()
return dbOther, nil
}
go db.loop() // Start queue handling
return db, nil return db, nil
} }
// Get returns database record for given filesystem and facility // Get returns database record for given filesystem and facility
func Get(facility string, f fs.Fs) *DB { func Get(facility string, f fs.Fs) *DB {
name := makeName(facility, f)
dbMut.Lock() dbMut.Lock()
defer dbMut.Unlock()
return lockedGet(facility, f)
}
func lockedGet(facility string, f fs.Fs) *DB {
name := makeName(facility, f)
db := dbMap[name] db := dbMap[name]
if db != nil { if db != nil {
db.mu.Lock() db.mu.Lock()
db.refs++ db.refs++
db.mu.Unlock() db.mu.Unlock()
} }
dbMut.Unlock()
return db return db
} }
// free database record
func (db *DB) free() {
dbMut.Lock()
db.mu.Lock()
db.refs--
if db.refs <= 0 {
delete(dbMap, db.name)
}
db.mu.Unlock()
dbMut.Unlock()
}
// Path returns database path // Path returns database path
func (db *DB) Path() string { return db.path } func (db *DB) Path() string { return db.path }
@ -201,18 +189,28 @@ func (db *DB) close() (err error) {
// loop over database operations sequentially // loop over database operations sequentially
func (db *DB) loop() { func (db *DB) loop() {
ctx := context.Background() ctx := context.Background()
for db.queue != nil { var req *request
quit := false
for !quit {
select { select {
case req := <-db.queue: case req = <-db.queue:
req.handle(ctx, db) if quit = req.handle(ctx, db); !quit {
req.wg.Done()
_ = db.idleTimer.Reset(db.idleTime) _ = db.idleTimer.Reset(db.idleTime)
}
case <-db.idleTimer.C: case <-db.idleTimer.C:
_ = db.close() _ = db.close()
case <-db.lockTimer.C: case <-db.lockTimer.C:
_ = db.close() _ = db.close()
} }
} }
db.free() db.queue = nil
if !atExit {
dbMut.Lock()
delete(dbMap, db.name)
dbMut.Unlock()
}
req.wg.Done()
} }
// Do a key-value operation and return error when done // Do a key-value operation and return error when done
@ -239,8 +237,10 @@ type request struct {
} }
// handle a key-value request with given DB // handle a key-value request with given DB
func (r *request) handle(ctx context.Context, db *DB) { // returns true as a signal to quit the loop
func (r *request) handle(ctx context.Context, db *DB) bool {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock()
if op, stop := r.op.(*opStop); stop { if op, stop := r.op.(*opStop); stop {
r.err = db.close() r.err = db.close()
if op.remove { if op.remove {
@ -248,12 +248,11 @@ func (r *request) handle(ctx context.Context, db *DB) {
r.err = err r.err = err
} }
} }
db.queue = nil db.refs--
} else { return db.refs <= 0
r.err = db.execute(ctx, r.op, r.wr)
} }
db.mu.Unlock() r.err = db.execute(ctx, r.op, r.wr)
r.wg.Done() return false
} }
// execute a key-value DB operation // execute a key-value DB operation
@ -302,11 +301,15 @@ func (*opStop) Do(context.Context, Bucket) error {
return nil return nil
} }
// Exit stops all databases // Exit immediately stops all databases
func Exit() { func Exit() {
dbMut.Lock() dbMut.Lock()
atExit = true
for _, s := range dbMap { for _, s := range dbMap {
s.refs = 0
_ = s.Stop(false) _ = s.Stop(false)
} }
dbMap = map[string]*DB{}
atExit = false
dbMut.Unlock() dbMut.Unlock()
} }

71
lib/kv/internal_test.go Normal file
View file

@ -0,0 +1,71 @@
//go:build !plan9 && !js
// +build !plan9,!js
package kv
import (
"context"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestKvConcurrency(t *testing.T) {
require.Equal(t, 0, len(dbMap), "no databases can be started initially")
const threadNum = 5
const facility = "test"
var wg sync.WaitGroup
ctx := context.Background()
results := make([]*DB, threadNum)
wg.Add(threadNum)
for i := 0; i < threadNum; i++ {
go func(i int) {
db, err := Start(ctx, "test", nil)
require.NoError(t, err)
require.NotNil(t, db)
results[i] = db
wg.Done()
}(i)
}
wg.Wait()
// must have a single multi-referenced db
db := results[0]
assert.Equal(t, 1, len(dbMap))
assert.Equal(t, threadNum, db.refs)
for i := 0; i < threadNum; i++ {
assert.Equal(t, db, results[i])
}
for i := 0; i < threadNum; i++ {
assert.Equal(t, 1, len(dbMap))
err := db.Stop(false)
assert.NoError(t, err, "unexpected error %v at retry %d", err, i)
}
assert.Equal(t, 0, len(dbMap), "must be closed in the end")
err := db.Stop(false)
assert.ErrorIs(t, err, ErrInactive, "missing expected stop indication")
}
func TestKvExit(t *testing.T) {
require.Equal(t, 0, len(dbMap), "no databases can be started initially")
const dbNum = 5
const openNum = 2
ctx := context.Background()
for i := 0; i < dbNum; i++ {
facility := fmt.Sprintf("test-%d", i)
for j := 0; j <= i; j++ {
db, err := Start(ctx, facility, nil)
require.NoError(t, err)
require.NotNil(t, db)
}
}
assert.Equal(t, dbNum, len(dbMap))
Exit()
assert.Equal(t, 0, len(dbMap))
}