lib/kv: add unit tests, fix races #5587
After testing concurrent calling of `kv.Start` and `db.Stop` I had to restrict more parts of these under mutex to make results deterministic without Sleep's in the test body. It's more safe but has potential to lock Start for up to 2 seconds due to `db.open`.
This commit is contained in:
parent
50df8cec9c
commit
57c7fde864
2 changed files with 112 additions and 38 deletions
|
@ -44,7 +44,8 @@ type DB struct {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
dbMap = map[string]*DB{}
|
dbMap = map[string]*DB{}
|
||||||
dbMut = sync.Mutex{}
|
dbMut sync.Mutex
|
||||||
|
atExit bool
|
||||||
)
|
)
|
||||||
|
|
||||||
// Supported returns true on supported OSes
|
// Supported returns true on supported OSes
|
||||||
|
@ -66,7 +67,9 @@ func makeName(facility string, f fs.Fs) string {
|
||||||
|
|
||||||
// Start a new key-value database
|
// Start a new key-value database
|
||||||
func Start(ctx context.Context, facility string, f fs.Fs) (*DB, error) {
|
func Start(ctx context.Context, facility string, f fs.Fs) (*DB, error) {
|
||||||
if db := Get(facility, f); db != nil {
|
dbMut.Lock()
|
||||||
|
defer dbMut.Unlock()
|
||||||
|
if db := lockedGet(facility, f); db != nil {
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,44 +104,29 @@ func Start(ctx context.Context, facility string, f fs.Fs) (*DB, error) {
|
||||||
return nil, errors.Wrapf(err, "cannot open db: %s", db.path)
|
return nil, errors.Wrapf(err, "cannot open db: %s", db.path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialization above was performed without locks..
|
dbMap[name] = db
|
||||||
dbMut.Lock()
|
go db.loop()
|
||||||
defer dbMut.Unlock()
|
|
||||||
if dbOther := dbMap[name]; dbOther != nil {
|
|
||||||
// Races between concurrent Start's are rare but possible, the 1st one wins.
|
|
||||||
_ = db.close()
|
|
||||||
return dbOther, nil
|
|
||||||
}
|
|
||||||
go db.loop() // Start queue handling
|
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get returns database record for given filesystem and facility
|
// Get returns database record for given filesystem and facility
|
||||||
func Get(facility string, f fs.Fs) *DB {
|
func Get(facility string, f fs.Fs) *DB {
|
||||||
name := makeName(facility, f)
|
|
||||||
dbMut.Lock()
|
dbMut.Lock()
|
||||||
|
defer dbMut.Unlock()
|
||||||
|
return lockedGet(facility, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func lockedGet(facility string, f fs.Fs) *DB {
|
||||||
|
name := makeName(facility, f)
|
||||||
db := dbMap[name]
|
db := dbMap[name]
|
||||||
if db != nil {
|
if db != nil {
|
||||||
db.mu.Lock()
|
db.mu.Lock()
|
||||||
db.refs++
|
db.refs++
|
||||||
db.mu.Unlock()
|
db.mu.Unlock()
|
||||||
}
|
}
|
||||||
dbMut.Unlock()
|
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// free database record
|
|
||||||
func (db *DB) free() {
|
|
||||||
dbMut.Lock()
|
|
||||||
db.mu.Lock()
|
|
||||||
db.refs--
|
|
||||||
if db.refs <= 0 {
|
|
||||||
delete(dbMap, db.name)
|
|
||||||
}
|
|
||||||
db.mu.Unlock()
|
|
||||||
dbMut.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Path returns database path
|
// Path returns database path
|
||||||
func (db *DB) Path() string { return db.path }
|
func (db *DB) Path() string { return db.path }
|
||||||
|
|
||||||
|
@ -201,18 +189,28 @@ func (db *DB) close() (err error) {
|
||||||
// loop over database operations sequentially
|
// loop over database operations sequentially
|
||||||
func (db *DB) loop() {
|
func (db *DB) loop() {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
for db.queue != nil {
|
var req *request
|
||||||
|
quit := false
|
||||||
|
for !quit {
|
||||||
select {
|
select {
|
||||||
case req := <-db.queue:
|
case req = <-db.queue:
|
||||||
req.handle(ctx, db)
|
if quit = req.handle(ctx, db); !quit {
|
||||||
|
req.wg.Done()
|
||||||
_ = db.idleTimer.Reset(db.idleTime)
|
_ = db.idleTimer.Reset(db.idleTime)
|
||||||
|
}
|
||||||
case <-db.idleTimer.C:
|
case <-db.idleTimer.C:
|
||||||
_ = db.close()
|
_ = db.close()
|
||||||
case <-db.lockTimer.C:
|
case <-db.lockTimer.C:
|
||||||
_ = db.close()
|
_ = db.close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
db.free()
|
db.queue = nil
|
||||||
|
if !atExit {
|
||||||
|
dbMut.Lock()
|
||||||
|
delete(dbMap, db.name)
|
||||||
|
dbMut.Unlock()
|
||||||
|
}
|
||||||
|
req.wg.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do a key-value operation and return error when done
|
// Do a key-value operation and return error when done
|
||||||
|
@ -239,8 +237,10 @@ type request struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle a key-value request with given DB
|
// handle a key-value request with given DB
|
||||||
func (r *request) handle(ctx context.Context, db *DB) {
|
// returns true as a signal to quit the loop
|
||||||
|
func (r *request) handle(ctx context.Context, db *DB) bool {
|
||||||
db.mu.Lock()
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
if op, stop := r.op.(*opStop); stop {
|
if op, stop := r.op.(*opStop); stop {
|
||||||
r.err = db.close()
|
r.err = db.close()
|
||||||
if op.remove {
|
if op.remove {
|
||||||
|
@ -248,12 +248,11 @@ func (r *request) handle(ctx context.Context, db *DB) {
|
||||||
r.err = err
|
r.err = err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
db.queue = nil
|
db.refs--
|
||||||
} else {
|
return db.refs <= 0
|
||||||
r.err = db.execute(ctx, r.op, r.wr)
|
|
||||||
}
|
}
|
||||||
db.mu.Unlock()
|
r.err = db.execute(ctx, r.op, r.wr)
|
||||||
r.wg.Done()
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute a key-value DB operation
|
// execute a key-value DB operation
|
||||||
|
@ -302,11 +301,15 @@ func (*opStop) Do(context.Context, Bucket) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exit stops all databases
|
// Exit immediately stops all databases
|
||||||
func Exit() {
|
func Exit() {
|
||||||
dbMut.Lock()
|
dbMut.Lock()
|
||||||
|
atExit = true
|
||||||
for _, s := range dbMap {
|
for _, s := range dbMap {
|
||||||
|
s.refs = 0
|
||||||
_ = s.Stop(false)
|
_ = s.Stop(false)
|
||||||
}
|
}
|
||||||
|
dbMap = map[string]*DB{}
|
||||||
|
atExit = false
|
||||||
dbMut.Unlock()
|
dbMut.Unlock()
|
||||||
}
|
}
|
||||||
|
|
71
lib/kv/internal_test.go
Normal file
71
lib/kv/internal_test.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
//go:build !plan9 && !js
|
||||||
|
// +build !plan9,!js
|
||||||
|
|
||||||
|
package kv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKvConcurrency(t *testing.T) {
|
||||||
|
require.Equal(t, 0, len(dbMap), "no databases can be started initially")
|
||||||
|
|
||||||
|
const threadNum = 5
|
||||||
|
const facility = "test"
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
ctx := context.Background()
|
||||||
|
results := make([]*DB, threadNum)
|
||||||
|
wg.Add(threadNum)
|
||||||
|
for i := 0; i < threadNum; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
db, err := Start(ctx, "test", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, db)
|
||||||
|
results[i] = db
|
||||||
|
wg.Done()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// must have a single multi-referenced db
|
||||||
|
db := results[0]
|
||||||
|
assert.Equal(t, 1, len(dbMap))
|
||||||
|
assert.Equal(t, threadNum, db.refs)
|
||||||
|
for i := 0; i < threadNum; i++ {
|
||||||
|
assert.Equal(t, db, results[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < threadNum; i++ {
|
||||||
|
assert.Equal(t, 1, len(dbMap))
|
||||||
|
err := db.Stop(false)
|
||||||
|
assert.NoError(t, err, "unexpected error %v at retry %d", err, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 0, len(dbMap), "must be closed in the end")
|
||||||
|
err := db.Stop(false)
|
||||||
|
assert.ErrorIs(t, err, ErrInactive, "missing expected stop indication")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKvExit(t *testing.T) {
|
||||||
|
require.Equal(t, 0, len(dbMap), "no databases can be started initially")
|
||||||
|
const dbNum = 5
|
||||||
|
const openNum = 2
|
||||||
|
ctx := context.Background()
|
||||||
|
for i := 0; i < dbNum; i++ {
|
||||||
|
facility := fmt.Sprintf("test-%d", i)
|
||||||
|
for j := 0; j <= i; j++ {
|
||||||
|
db, err := Start(ctx, facility, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, db)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, dbNum, len(dbMap))
|
||||||
|
Exit()
|
||||||
|
assert.Equal(t, 0, len(dbMap))
|
||||||
|
}
|
Loading…
Reference in a new issue