network: rework discovery with rwmutex, add test
Keeping run() as the owner of all maps would mean adding at least three more channels to keep address getters with thread-safety. But then there also is a race between requestToWork() and run() which is way harder to solve with channels because there are lots of possibilities for deadlocks. So rework all of this with good old mutexes. While at it, fix `requestCh` handling in the inner select of run, it will waste one loop to handle it, so we should add one to the `requested`. Fixes #445.
This commit is contained in:
parent
77a50d6dc6
commit
006337b1f8
2 changed files with 197 additions and 53 deletions
|
@ -1,6 +1,7 @@
|
||||||
package network
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,18 +27,14 @@ type Discoverer interface {
|
||||||
// DefaultDiscovery default implementation of the Discoverer interface.
|
// DefaultDiscovery default implementation of the Discoverer interface.
|
||||||
type DefaultDiscovery struct {
|
type DefaultDiscovery struct {
|
||||||
transport Transporter
|
transport Transporter
|
||||||
|
lock sync.RWMutex
|
||||||
dialTimeout time.Duration
|
dialTimeout time.Duration
|
||||||
badAddrs map[string]bool
|
badAddrs map[string]bool
|
||||||
connectedAddrs map[string]bool
|
connectedAddrs map[string]bool
|
||||||
goodAddrs map[string]bool
|
goodAddrs map[string]bool
|
||||||
unconnectedAddrs map[string]int
|
unconnectedAddrs map[string]int
|
||||||
requestCh chan int
|
requestCh chan int
|
||||||
connectedCh chan string
|
|
||||||
backFill chan string
|
|
||||||
badAddrCh chan string
|
|
||||||
pool chan string
|
pool chan string
|
||||||
goodCh chan string
|
|
||||||
unconnectedCh chan string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultDiscovery returns a new DefaultDiscovery.
|
// NewDefaultDiscovery returns a new DefaultDiscovery.
|
||||||
|
@ -50,11 +47,6 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
|
||||||
goodAddrs: make(map[string]bool),
|
goodAddrs: make(map[string]bool),
|
||||||
unconnectedAddrs: make(map[string]int),
|
unconnectedAddrs: make(map[string]int),
|
||||||
requestCh: make(chan int),
|
requestCh: make(chan int),
|
||||||
connectedCh: make(chan string),
|
|
||||||
goodCh: make(chan string),
|
|
||||||
unconnectedCh: make(chan string),
|
|
||||||
backFill: make(chan string),
|
|
||||||
badAddrCh: make(chan string),
|
|
||||||
pool: make(chan string, maxPoolSize),
|
pool: make(chan string, maxPoolSize),
|
||||||
}
|
}
|
||||||
go d.run()
|
go d.run()
|
||||||
|
@ -64,9 +56,16 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
|
||||||
// BackFill implements the Discoverer interface and will backfill the
|
// BackFill implements the Discoverer interface and will backfill the
|
||||||
// the pool with the given addresses.
|
// the pool with the given addresses.
|
||||||
func (d *DefaultDiscovery) BackFill(addrs ...string) {
|
func (d *DefaultDiscovery) BackFill(addrs ...string) {
|
||||||
|
d.lock.Lock()
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
d.backFill <- addr
|
if d.badAddrs[addr] || d.connectedAddrs[addr] ||
|
||||||
|
d.unconnectedAddrs[addr] > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
d.unconnectedAddrs[addr] = connRetries
|
||||||
|
d.pushToPoolOrDrop(addr)
|
||||||
}
|
}
|
||||||
|
d.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// PoolCount returns the number of available node addresses.
|
// PoolCount returns the number of available node addresses.
|
||||||
|
@ -92,105 +91,104 @@ func (d *DefaultDiscovery) RequestRemote(n int) {
|
||||||
|
|
||||||
// RegisterBadAddr registers the given address as a bad address.
|
// RegisterBadAddr registers the given address as a bad address.
|
||||||
func (d *DefaultDiscovery) RegisterBadAddr(addr string) {
|
func (d *DefaultDiscovery) RegisterBadAddr(addr string) {
|
||||||
d.badAddrCh <- addr
|
d.lock.Lock()
|
||||||
d.RequestRemote(1)
|
d.unconnectedAddrs[addr]--
|
||||||
|
if d.unconnectedAddrs[addr] > 0 {
|
||||||
|
d.pushToPoolOrDrop(addr)
|
||||||
|
} else {
|
||||||
|
d.badAddrs[addr] = true
|
||||||
|
delete(d.unconnectedAddrs, addr)
|
||||||
|
}
|
||||||
|
d.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnconnectedPeers returns all addresses of unconnected addrs.
|
// UnconnectedPeers returns all addresses of unconnected addrs.
|
||||||
func (d *DefaultDiscovery) UnconnectedPeers() []string {
|
func (d *DefaultDiscovery) UnconnectedPeers() []string {
|
||||||
|
d.lock.RLock()
|
||||||
addrs := make([]string, 0, len(d.unconnectedAddrs))
|
addrs := make([]string, 0, len(d.unconnectedAddrs))
|
||||||
for addr := range d.unconnectedAddrs {
|
for addr := range d.unconnectedAddrs {
|
||||||
addrs = append(addrs, addr)
|
addrs = append(addrs, addr)
|
||||||
}
|
}
|
||||||
|
d.lock.RUnlock()
|
||||||
return addrs
|
return addrs
|
||||||
}
|
}
|
||||||
|
|
||||||
// BadPeers returns all addresses of bad addrs.
|
// BadPeers returns all addresses of bad addrs.
|
||||||
func (d *DefaultDiscovery) BadPeers() []string {
|
func (d *DefaultDiscovery) BadPeers() []string {
|
||||||
|
d.lock.RLock()
|
||||||
addrs := make([]string, 0, len(d.badAddrs))
|
addrs := make([]string, 0, len(d.badAddrs))
|
||||||
for addr := range d.badAddrs {
|
for addr := range d.badAddrs {
|
||||||
addrs = append(addrs, addr)
|
addrs = append(addrs, addr)
|
||||||
}
|
}
|
||||||
|
d.lock.RUnlock()
|
||||||
return addrs
|
return addrs
|
||||||
}
|
}
|
||||||
|
|
||||||
// GoodPeers returns all addresses of known good peers (that at least once
|
// GoodPeers returns all addresses of known good peers (that at least once
|
||||||
// succeeded handshaking with us).
|
// succeeded handshaking with us).
|
||||||
func (d *DefaultDiscovery) GoodPeers() []string {
|
func (d *DefaultDiscovery) GoodPeers() []string {
|
||||||
|
d.lock.RLock()
|
||||||
addrs := make([]string, 0, len(d.goodAddrs))
|
addrs := make([]string, 0, len(d.goodAddrs))
|
||||||
for addr := range d.goodAddrs {
|
for addr := range d.goodAddrs {
|
||||||
addrs = append(addrs, addr)
|
addrs = append(addrs, addr)
|
||||||
}
|
}
|
||||||
|
d.lock.RUnlock()
|
||||||
return addrs
|
return addrs
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterGoodAddr registers good known connected address that passed
|
// RegisterGoodAddr registers good known connected address that passed
|
||||||
// handshake successfully.
|
// handshake successfully.
|
||||||
func (d *DefaultDiscovery) RegisterGoodAddr(s string) {
|
func (d *DefaultDiscovery) RegisterGoodAddr(s string) {
|
||||||
d.goodCh <- s
|
d.lock.Lock()
|
||||||
|
d.goodAddrs[s] = true
|
||||||
|
d.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnregisterConnectedAddr tells discoverer that this address is no longer
|
// UnregisterConnectedAddr tells discoverer that this address is no longer
|
||||||
// connected, but it still is considered as good one.
|
// connected, but it still is considered as good one.
|
||||||
func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) {
|
func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) {
|
||||||
d.unconnectedCh <- s
|
d.lock.Lock()
|
||||||
|
delete(d.connectedAddrs, s)
|
||||||
|
d.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerConnectedAddr tells discoverer that given address is now connected.
|
||||||
|
func (d *DefaultDiscovery) registerConnectedAddr(addr string) {
|
||||||
|
d.lock.Lock()
|
||||||
|
delete(d.unconnectedAddrs, addr)
|
||||||
|
d.connectedAddrs[addr] = true
|
||||||
|
d.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultDiscovery) tryAddress(addr string) {
|
func (d *DefaultDiscovery) tryAddress(addr string) {
|
||||||
if err := d.transport.Dial(addr, d.dialTimeout); err != nil {
|
if err := d.transport.Dial(addr, d.dialTimeout); err != nil {
|
||||||
d.badAddrCh <- addr
|
d.RegisterBadAddr(addr)
|
||||||
|
d.RequestRemote(1)
|
||||||
} else {
|
} else {
|
||||||
d.connectedCh <- addr
|
d.registerConnectedAddr(addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultDiscovery) requestToWork() {
|
// run is a goroutine that makes DefaultDiscovery process its queue to connect
|
||||||
|
// to other nodes.
|
||||||
|
func (d *DefaultDiscovery) run() {
|
||||||
var requested int
|
var requested int
|
||||||
|
|
||||||
for {
|
for {
|
||||||
for requested = <-d.requestCh; requested > 0; requested-- {
|
for requested = <-d.requestCh; requested > 0; requested-- {
|
||||||
select {
|
select {
|
||||||
case r := <-d.requestCh:
|
case r := <-d.requestCh:
|
||||||
if requested < r {
|
if requested <= r {
|
||||||
requested = r
|
requested = r + 1
|
||||||
}
|
}
|
||||||
case addr := <-d.pool:
|
case addr := <-d.pool:
|
||||||
if !d.connectedAddrs[addr] {
|
d.lock.RLock()
|
||||||
|
addrIsConnected := d.connectedAddrs[addr]
|
||||||
|
d.lock.RUnlock()
|
||||||
|
if !addrIsConnected {
|
||||||
go d.tryAddress(addr)
|
go d.tryAddress(addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultDiscovery) run() {
|
|
||||||
go d.requestToWork()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case addr := <-d.backFill:
|
|
||||||
if d.badAddrs[addr] || d.connectedAddrs[addr] ||
|
|
||||||
d.unconnectedAddrs[addr] > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
d.unconnectedAddrs[addr] = connRetries
|
|
||||||
d.pushToPoolOrDrop(addr)
|
|
||||||
case addr := <-d.badAddrCh:
|
|
||||||
d.unconnectedAddrs[addr]--
|
|
||||||
if d.unconnectedAddrs[addr] > 0 {
|
|
||||||
d.pushToPoolOrDrop(addr)
|
|
||||||
} else {
|
|
||||||
d.badAddrs[addr] = true
|
|
||||||
delete(d.unconnectedAddrs, addr)
|
|
||||||
}
|
|
||||||
d.RequestRemote(1)
|
|
||||||
|
|
||||||
case addr := <-d.connectedCh:
|
|
||||||
delete(d.unconnectedAddrs, addr)
|
|
||||||
d.connectedAddrs[addr] = true
|
|
||||||
case addr := <-d.goodCh:
|
|
||||||
d.goodAddrs[addr] = true
|
|
||||||
case addr := <-d.unconnectedCh:
|
|
||||||
delete(d.connectedAddrs, addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
146
pkg/network/discovery_test.go
Normal file
146
pkg/network/discovery_test.go
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sort"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeTransp struct {
|
||||||
|
retFalse int32
|
||||||
|
dialCh chan string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error {
|
||||||
|
ft.dialCh <- addr
|
||||||
|
|
||||||
|
if atomic.LoadInt32(&ft.retFalse) > 0 {
|
||||||
|
return errors.New("smth bad happened")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (ft *fakeTransp) Accept() {
|
||||||
|
}
|
||||||
|
func (ft *fakeTransp) Proto() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
func (ft *fakeTransp) Close() {
|
||||||
|
}
|
||||||
|
func TestDefaultDiscoverer(t *testing.T) {
|
||||||
|
ts := &fakeTransp{}
|
||||||
|
ts.dialCh = make(chan string)
|
||||||
|
d := NewDefaultDiscovery(time.Second, ts)
|
||||||
|
|
||||||
|
var set1 = []string{"1.1.1.1:10333", "2.2.2.2:10333"}
|
||||||
|
sort.Strings(set1)
|
||||||
|
|
||||||
|
// Added addresses should end up in the pool and in the unconnected set.
|
||||||
|
// Done twice to check re-adding unconnected addresses, which should be
|
||||||
|
// a no-op.
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
d.BackFill(set1...)
|
||||||
|
assert.Equal(t, len(set1), d.PoolCount())
|
||||||
|
set1D := d.UnconnectedPeers()
|
||||||
|
sort.Strings(set1D)
|
||||||
|
assert.Equal(t, 0, len(d.GoodPeers()))
|
||||||
|
assert.Equal(t, 0, len(d.BadPeers()))
|
||||||
|
require.Equal(t, set1, set1D)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request should make goroutines dial our addresses draining the pool.
|
||||||
|
d.RequestRemote(len(set1))
|
||||||
|
dialled := make([]string, 0)
|
||||||
|
for i := 0; i < len(set1); i++ {
|
||||||
|
select {
|
||||||
|
case a := <-ts.dialCh:
|
||||||
|
dialled = append(dialled, a)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("timeout expecting for transport dial")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Updated asynchronously.
|
||||||
|
if len(d.UnconnectedPeers()) != 0 {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
sort.Strings(dialled)
|
||||||
|
assert.Equal(t, 0, d.PoolCount())
|
||||||
|
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||||
|
assert.Equal(t, 0, len(d.BadPeers()))
|
||||||
|
assert.Equal(t, 0, len(d.GoodPeers()))
|
||||||
|
require.Equal(t, set1, dialled)
|
||||||
|
|
||||||
|
// Registered good addresses should end up in appropriate set.
|
||||||
|
for _, addr := range set1 {
|
||||||
|
d.RegisterGoodAddr(addr)
|
||||||
|
}
|
||||||
|
gAddrs := d.GoodPeers()
|
||||||
|
sort.Strings(gAddrs)
|
||||||
|
assert.Equal(t, 0, d.PoolCount())
|
||||||
|
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||||
|
assert.Equal(t, 0, len(d.BadPeers()))
|
||||||
|
require.Equal(t, set1, gAddrs)
|
||||||
|
|
||||||
|
// Re-adding connected addresses should be no-op.
|
||||||
|
d.BackFill(set1...)
|
||||||
|
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||||
|
assert.Equal(t, 0, len(d.BadPeers()))
|
||||||
|
assert.Equal(t, len(set1), len(d.GoodPeers()))
|
||||||
|
require.Equal(t, 0, d.PoolCount())
|
||||||
|
|
||||||
|
// Unregistering connected should work.
|
||||||
|
for _, addr := range set1 {
|
||||||
|
d.UnregisterConnectedAddr(addr)
|
||||||
|
}
|
||||||
|
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||||
|
assert.Equal(t, 0, len(d.BadPeers()))
|
||||||
|
assert.Equal(t, len(set1), len(d.GoodPeers()))
|
||||||
|
require.Equal(t, 0, d.PoolCount())
|
||||||
|
|
||||||
|
// Now make Dial() fail and wait to see addresses in the bad list.
|
||||||
|
atomic.StoreInt32(&ts.retFalse, 1)
|
||||||
|
d.BackFill(set1...)
|
||||||
|
assert.Equal(t, len(set1), d.PoolCount())
|
||||||
|
set1D := d.UnconnectedPeers()
|
||||||
|
sort.Strings(set1D)
|
||||||
|
assert.Equal(t, 0, len(d.BadPeers()))
|
||||||
|
require.Equal(t, set1, set1D)
|
||||||
|
|
||||||
|
dialledBad := make([]string, 0)
|
||||||
|
d.RequestRemote(len(set1))
|
||||||
|
for i := 0; i < connRetries; i++ {
|
||||||
|
for j := 0; j < len(set1); j++ {
|
||||||
|
select {
|
||||||
|
case a := <-ts.dialCh:
|
||||||
|
dialledBad = append(dialledBad, a)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("timeout expecting for transport dial; i: %d, j: %d", i, j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Equal(t, 0, d.PoolCount())
|
||||||
|
sort.Strings(dialledBad)
|
||||||
|
for i := 0; i < len(set1); i++ {
|
||||||
|
for j := 0; j < connRetries; j++ {
|
||||||
|
assert.Equal(t, set1[i], dialledBad[i*connRetries+j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Updated asynchronously.
|
||||||
|
if len(d.BadPeers()) != len(set1) {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
assert.Equal(t, len(set1), len(d.BadPeers()))
|
||||||
|
assert.Equal(t, len(set1), len(d.GoodPeers()))
|
||||||
|
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||||
|
|
||||||
|
// Re-adding bad addresses is a no-op.
|
||||||
|
d.BackFill(set1...)
|
||||||
|
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||||
|
assert.Equal(t, len(set1), len(d.BadPeers()))
|
||||||
|
assert.Equal(t, len(set1), len(d.GoodPeers()))
|
||||||
|
require.Equal(t, 0, d.PoolCount())
|
||||||
|
}
|
Loading…
Reference in a new issue