Fix possible data race in pkg/stall (#163)

fix #162
This commit is contained in:
Evgeniy Kulikov 2019-02-26 20:28:38 +03:00 committed by decentralisedkev
parent 1d1f81e168
commit 926dd20792
2 changed files with 28 additions and 18 deletions

View file

@ -17,7 +17,7 @@ type Detector struct {
responseTime time.Duration responseTime time.Duration
tickInterval time.Duration tickInterval time.Duration
lock sync.Mutex lock *sync.RWMutex
responses map[command.Type]time.Time responses map[command.Type]time.Time
// The detector is embedded into a peer and the peer watches this quit chan // The detector is embedded into a peer and the peer watches this quit chan
@ -35,7 +35,7 @@ func NewDetector(rTime time.Duration, tickerInterval time.Duration) *Detector {
d := &Detector{ d := &Detector{
responseTime: rTime, responseTime: rTime,
tickInterval: tickerInterval, tickInterval: tickerInterval,
lock: sync.Mutex{}, lock: new(sync.RWMutex),
responses: map[command.Type]time.Time{}, responses: map[command.Type]time.Time{},
Quitch: make(chan struct{}), Quitch: make(chan struct{}),
} }
@ -46,24 +46,27 @@ func NewDetector(rTime time.Duration, tickerInterval time.Duration) *Detector {
func (d *Detector) loop() { func (d *Detector) loop() {
ticker := time.NewTicker(d.tickInterval) ticker := time.NewTicker(d.tickInterval)
loop: defer func() {
d.Quit()
d.DeleteAll()
ticker.Stop()
}()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
now := time.Now() now := time.Now()
for _, deadline := range d.responses { d.lock.RLock()
resp := d.responses
d.lock.RUnlock()
for _, deadline := range resp {
if now.After(deadline) { if now.After(deadline) {
fmt.Println("Deadline passed") fmt.Println("Deadline passed")
ticker.Stop() return
break loop
} }
} }
} }
} }
d.Quit()
d.DeleteAll()
ticker.Stop()
} }
// Quit is a concurrent safe way to call the Quit channel // Quit is a concurrent safe way to call the Quit channel
@ -114,17 +117,16 @@ func (d *Detector) DeleteAll() {
// and their deadlines // and their deadlines
func (d *Detector) GetMessages() map[command.Type]time.Time { func (d *Detector) GetMessages() map[command.Type]time.Time {
var resp map[command.Type]time.Time var resp map[command.Type]time.Time
d.lock.Lock() d.lock.RLock()
resp = d.responses resp = d.responses
d.lock.Unlock() d.lock.RUnlock()
return resp return resp
} }
// when a message is added, we will add a deadline for // when a message is added, we will add a deadline for
// expected response // expected response
func (d *Detector) addMessage(cmd command.Type) []command.Type { func (d *Detector) addMessage(cmd command.Type) []command.Type {
var cmds []command.Type
cmds := []command.Type{}
switch cmd { switch cmd {
case command.GetHeaders: case command.GetHeaders:
@ -151,8 +153,7 @@ func (d *Detector) addMessage(cmd command.Type) []command.Type {
// if receive a message, we will delete it from pending // if receive a message, we will delete it from pending
func (d *Detector) removeMessage(cmd command.Type) []command.Type { func (d *Detector) removeMessage(cmd command.Type) []command.Type {
var cmds []command.Type
cmds := []command.Type{}
switch cmd { switch cmd {
case command.Block: case command.Block:

View file

@ -1,6 +1,7 @@
package stall package stall
import ( import (
"sync"
"testing" "testing"
"time" "time"
@ -29,6 +30,7 @@ func TestAddRemoveMessage(t *testing.T) {
} }
type mockPeer struct { type mockPeer struct {
lock *sync.RWMutex
online bool online bool
detector *Detector detector *Detector
} }
@ -43,7 +45,9 @@ loop:
} }
} }
// cleanup // cleanup
mp.lock.Lock()
mp.online = false mp.online = false
mp.lock.Unlock()
} }
func TestDeadlineWorks(t *testing.T) { func TestDeadlineWorks(t *testing.T) {
@ -51,16 +55,19 @@ func TestDeadlineWorks(t *testing.T) {
tickerInterval := 1 * time.Second tickerInterval := 1 * time.Second
d := NewDetector(responseTime, tickerInterval) d := NewDetector(responseTime, tickerInterval)
mp := mockPeer{online: true, detector: d} mp := mockPeer{online: true, detector: d, lock: new(sync.RWMutex)}
go mp.loop() go mp.loop()
d.AddMessage(command.GetAddr) d.AddMessage(command.GetAddr)
time.Sleep(responseTime + 1*time.Second) time.Sleep(responseTime + 1*time.Second)
k := make(map[command.Type]time.Time) k := make(map[command.Type]time.Time)
d.lock.RLock()
assert.Equal(t, k, d.responses) assert.Equal(t, k, d.responses)
d.lock.RUnlock()
mp.lock.RLock()
assert.Equal(t, false, mp.online) assert.Equal(t, false, mp.online)
mp.lock.RUnlock()
} }
func TestDeadlineShouldNotBeEmpty(t *testing.T) { func TestDeadlineShouldNotBeEmpty(t *testing.T) {
responseTime := 10 * time.Second responseTime := 10 * time.Second
@ -71,5 +78,7 @@ func TestDeadlineShouldNotBeEmpty(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
k := make(map[command.Type]time.Time) k := make(map[command.Type]time.Time)
d.lock.RLock()
assert.NotEqual(t, k, d.responses) assert.NotEqual(t, k, d.responses)
d.lock.RUnlock()
} }