forked from TrueCloudLab/neoneo-go
Merge pull request #459 from nspcc-dev/network-fix-445
Fix discoverer races (#445)
This commit is contained in:
commit
0ea7568caa
2 changed files with 197 additions and 57 deletions
|
@ -1,6 +1,7 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -26,18 +27,14 @@ type Discoverer interface {
|
|||
// DefaultDiscovery is the default implementation of the Discoverer interface.
|
||||
type DefaultDiscovery struct {
|
||||
transport Transporter
|
||||
lock sync.RWMutex
|
||||
dialTimeout time.Duration
|
||||
badAddrs map[string]bool
|
||||
connectedAddrs map[string]bool
|
||||
goodAddrs map[string]bool
|
||||
unconnectedAddrs map[string]int
|
||||
requestCh chan int
|
||||
connectedCh chan string
|
||||
backFill chan string
|
||||
badAddrCh chan string
|
||||
pool chan string
|
||||
goodCh chan string
|
||||
unconnectedCh chan string
|
||||
}
|
||||
|
||||
// NewDefaultDiscovery returns a new DefaultDiscovery.
|
||||
|
@ -50,11 +47,6 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
|
|||
goodAddrs: make(map[string]bool),
|
||||
unconnectedAddrs: make(map[string]int),
|
||||
requestCh: make(chan int),
|
||||
connectedCh: make(chan string),
|
||||
goodCh: make(chan string),
|
||||
unconnectedCh: make(chan string),
|
||||
backFill: make(chan string),
|
||||
badAddrCh: make(chan string),
|
||||
pool: make(chan string, maxPoolSize),
|
||||
}
|
||||
go d.run()
|
||||
|
@ -64,9 +56,16 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
|
|||
// BackFill implements the Discoverer interface and will backfill the
|
||||
// the pool with the given addresses.
|
||||
func (d *DefaultDiscovery) BackFill(addrs ...string) {
|
||||
d.lock.Lock()
|
||||
for _, addr := range addrs {
|
||||
d.backFill <- addr
|
||||
if d.badAddrs[addr] || d.connectedAddrs[addr] ||
|
||||
d.unconnectedAddrs[addr] > 0 {
|
||||
continue
|
||||
}
|
||||
d.unconnectedAddrs[addr] = connRetries
|
||||
d.pushToPoolOrDrop(addr)
|
||||
}
|
||||
d.lock.Unlock()
|
||||
}
|
||||
|
||||
// PoolCount returns the number of available node addresses.
|
||||
|
@ -92,109 +91,104 @@ func (d *DefaultDiscovery) RequestRemote(n int) {
|
|||
|
||||
// RegisterBadAddr registers the given address as a bad address.
|
||||
func (d *DefaultDiscovery) RegisterBadAddr(addr string) {
|
||||
d.badAddrCh <- addr
|
||||
d.RequestRemote(1)
|
||||
d.lock.Lock()
|
||||
d.unconnectedAddrs[addr]--
|
||||
if d.unconnectedAddrs[addr] > 0 {
|
||||
d.pushToPoolOrDrop(addr)
|
||||
} else {
|
||||
d.badAddrs[addr] = true
|
||||
delete(d.unconnectedAddrs, addr)
|
||||
}
|
||||
d.lock.Unlock()
|
||||
}
|
||||
|
||||
// UnconnectedPeers returns all addresses of unconnected addrs.
|
||||
func (d *DefaultDiscovery) UnconnectedPeers() []string {
|
||||
d.lock.RLock()
|
||||
addrs := make([]string, 0, len(d.unconnectedAddrs))
|
||||
for addr := range d.unconnectedAddrs {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
d.lock.RUnlock()
|
||||
return addrs
|
||||
}
|
||||
|
||||
// BadPeers returns all addresses of bad addrs.
|
||||
func (d *DefaultDiscovery) BadPeers() []string {
|
||||
d.lock.RLock()
|
||||
addrs := make([]string, 0, len(d.badAddrs))
|
||||
for addr := range d.badAddrs {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
d.lock.RUnlock()
|
||||
return addrs
|
||||
}
|
||||
|
||||
// GoodPeers returns all addresses of known good peers (that at least once
|
||||
// succeeded handshaking with us).
|
||||
func (d *DefaultDiscovery) GoodPeers() []string {
|
||||
d.lock.RLock()
|
||||
addrs := make([]string, 0, len(d.goodAddrs))
|
||||
for addr := range d.goodAddrs {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
d.lock.RUnlock()
|
||||
return addrs
|
||||
}
|
||||
|
||||
// RegisterGoodAddr registers good known connected address that passed
|
||||
// handshake successfully.
|
||||
func (d *DefaultDiscovery) RegisterGoodAddr(s string) {
|
||||
d.goodCh <- s
|
||||
d.lock.Lock()
|
||||
d.goodAddrs[s] = true
|
||||
d.lock.Unlock()
|
||||
}
|
||||
|
||||
// UnregisterConnectedAddr tells discoverer that this address is no longer
|
||||
// connected, but it still is considered as good one.
|
||||
func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) {
|
||||
d.unconnectedCh <- s
|
||||
d.lock.Lock()
|
||||
delete(d.connectedAddrs, s)
|
||||
d.lock.Unlock()
|
||||
}
|
||||
|
||||
// registerConnectedAddr tells discoverer that given address is now connected.
|
||||
func (d *DefaultDiscovery) registerConnectedAddr(addr string) {
|
||||
d.lock.Lock()
|
||||
delete(d.unconnectedAddrs, addr)
|
||||
d.connectedAddrs[addr] = true
|
||||
d.lock.Unlock()
|
||||
}
|
||||
|
||||
func (d *DefaultDiscovery) tryAddress(addr string) {
|
||||
if err := d.transport.Dial(addr, d.dialTimeout); err != nil {
|
||||
d.badAddrCh <- addr
|
||||
d.RegisterBadAddr(addr)
|
||||
d.RequestRemote(1)
|
||||
} else {
|
||||
d.connectedCh <- addr
|
||||
d.registerConnectedAddr(addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultDiscovery) requestToWork() {
|
||||
// run is a goroutine that makes DefaultDiscovery process its queue to connect
|
||||
// to other nodes.
|
||||
func (d *DefaultDiscovery) run() {
|
||||
var requested int
|
||||
|
||||
for {
|
||||
for requested = <-d.requestCh; requested > 0; requested-- {
|
||||
select {
|
||||
case r := <-d.requestCh:
|
||||
if requested < r {
|
||||
requested = r
|
||||
if requested <= r {
|
||||
requested = r + 1
|
||||
}
|
||||
case addr := <-d.pool:
|
||||
if !d.connectedAddrs[addr] {
|
||||
d.lock.RLock()
|
||||
addrIsConnected := d.connectedAddrs[addr]
|
||||
d.lock.RUnlock()
|
||||
if !addrIsConnected {
|
||||
go d.tryAddress(addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultDiscovery) run() {
|
||||
go d.requestToWork()
|
||||
for {
|
||||
select {
|
||||
case addr := <-d.backFill:
|
||||
if d.badAddrs[addr] || d.connectedAddrs[addr] ||
|
||||
d.unconnectedAddrs[addr] > 0 {
|
||||
break
|
||||
}
|
||||
d.unconnectedAddrs[addr] = connRetries
|
||||
d.pushToPoolOrDrop(addr)
|
||||
case addr := <-d.badAddrCh:
|
||||
d.unconnectedAddrs[addr]--
|
||||
if d.unconnectedAddrs[addr] > 0 {
|
||||
d.pushToPoolOrDrop(addr)
|
||||
} else {
|
||||
d.badAddrs[addr] = true
|
||||
delete(d.unconnectedAddrs, addr)
|
||||
}
|
||||
d.RequestRemote(1)
|
||||
|
||||
case addr := <-d.connectedCh:
|
||||
delete(d.unconnectedAddrs, addr)
|
||||
if !d.connectedAddrs[addr] {
|
||||
d.connectedAddrs[addr] = true
|
||||
}
|
||||
case addr := <-d.goodCh:
|
||||
if !d.goodAddrs[addr] {
|
||||
d.goodAddrs[addr] = true
|
||||
}
|
||||
case addr := <-d.unconnectedCh:
|
||||
delete(d.connectedAddrs, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
146
pkg/network/discovery_test.go
Normal file
146
pkg/network/discovery_test.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeTransp is a transport stub for discovery tests. It performs no
// networking: every Dial call is reported on dialCh and its result is
// steered by retFalse.
type fakeTransp struct {
	// retFalse makes Dial fail while it is positive; Dial reads it
	// atomically.
	retFalse int32
	// dialCh receives the address of every Dial call.
	dialCh chan string
}

// Dial publishes addr on dialCh (the send blocks until it is received when
// the channel is unbuffered) and then returns an error if retFalse is
// positive, nil otherwise. The timeout argument is ignored.
func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error {
	ft.dialCh <- addr

	if atomic.LoadInt32(&ft.retFalse) <= 0 {
		return nil
	}
	return errors.New("smth bad happened")
}

// Accept does nothing in the fake transport.
func (ft *fakeTransp) Accept() {}

// Proto always returns an empty protocol name.
func (ft *fakeTransp) Proto() string {
	var proto string
	return proto
}

// Close does nothing in the fake transport.
func (ft *fakeTransp) Close() {}
|
||||
func TestDefaultDiscoverer(t *testing.T) {
|
||||
ts := &fakeTransp{}
|
||||
ts.dialCh = make(chan string)
|
||||
d := NewDefaultDiscovery(time.Second, ts)
|
||||
|
||||
var set1 = []string{"1.1.1.1:10333", "2.2.2.2:10333"}
|
||||
sort.Strings(set1)
|
||||
|
||||
// Added addresses should end up in the pool and in the unconnected set.
|
||||
// Done twice to check re-adding unconnected addresses, which should be
|
||||
// a no-op.
|
||||
for i := 0; i < 2; i++ {
|
||||
d.BackFill(set1...)
|
||||
assert.Equal(t, len(set1), d.PoolCount())
|
||||
set1D := d.UnconnectedPeers()
|
||||
sort.Strings(set1D)
|
||||
assert.Equal(t, 0, len(d.GoodPeers()))
|
||||
assert.Equal(t, 0, len(d.BadPeers()))
|
||||
require.Equal(t, set1, set1D)
|
||||
}
|
||||
|
||||
// Request should make goroutines dial our addresses draining the pool.
|
||||
d.RequestRemote(len(set1))
|
||||
dialled := make([]string, 0)
|
||||
for i := 0; i < len(set1); i++ {
|
||||
select {
|
||||
case a := <-ts.dialCh:
|
||||
dialled = append(dialled, a)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout expecting for transport dial")
|
||||
}
|
||||
}
|
||||
// Updated asynchronously.
|
||||
if len(d.UnconnectedPeers()) != 0 {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
sort.Strings(dialled)
|
||||
assert.Equal(t, 0, d.PoolCount())
|
||||
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||
assert.Equal(t, 0, len(d.BadPeers()))
|
||||
assert.Equal(t, 0, len(d.GoodPeers()))
|
||||
require.Equal(t, set1, dialled)
|
||||
|
||||
// Registered good addresses should end up in appropriate set.
|
||||
for _, addr := range set1 {
|
||||
d.RegisterGoodAddr(addr)
|
||||
}
|
||||
gAddrs := d.GoodPeers()
|
||||
sort.Strings(gAddrs)
|
||||
assert.Equal(t, 0, d.PoolCount())
|
||||
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||
assert.Equal(t, 0, len(d.BadPeers()))
|
||||
require.Equal(t, set1, gAddrs)
|
||||
|
||||
// Re-adding connected addresses should be no-op.
|
||||
d.BackFill(set1...)
|
||||
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||
assert.Equal(t, 0, len(d.BadPeers()))
|
||||
assert.Equal(t, len(set1), len(d.GoodPeers()))
|
||||
require.Equal(t, 0, d.PoolCount())
|
||||
|
||||
// Unregistering connected should work.
|
||||
for _, addr := range set1 {
|
||||
d.UnregisterConnectedAddr(addr)
|
||||
}
|
||||
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||
assert.Equal(t, 0, len(d.BadPeers()))
|
||||
assert.Equal(t, len(set1), len(d.GoodPeers()))
|
||||
require.Equal(t, 0, d.PoolCount())
|
||||
|
||||
// Now make Dial() fail and wait to see addresses in the bad list.
|
||||
atomic.StoreInt32(&ts.retFalse, 1)
|
||||
d.BackFill(set1...)
|
||||
assert.Equal(t, len(set1), d.PoolCount())
|
||||
set1D := d.UnconnectedPeers()
|
||||
sort.Strings(set1D)
|
||||
assert.Equal(t, 0, len(d.BadPeers()))
|
||||
require.Equal(t, set1, set1D)
|
||||
|
||||
dialledBad := make([]string, 0)
|
||||
d.RequestRemote(len(set1))
|
||||
for i := 0; i < connRetries; i++ {
|
||||
for j := 0; j < len(set1); j++ {
|
||||
select {
|
||||
case a := <-ts.dialCh:
|
||||
dialledBad = append(dialledBad, a)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout expecting for transport dial; i: %d, j: %d", i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 0, d.PoolCount())
|
||||
sort.Strings(dialledBad)
|
||||
for i := 0; i < len(set1); i++ {
|
||||
for j := 0; j < connRetries; j++ {
|
||||
assert.Equal(t, set1[i], dialledBad[i*connRetries+j])
|
||||
}
|
||||
}
|
||||
// Updated asynchronously.
|
||||
if len(d.BadPeers()) != len(set1) {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
assert.Equal(t, len(set1), len(d.BadPeers()))
|
||||
assert.Equal(t, len(set1), len(d.GoodPeers()))
|
||||
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||
|
||||
// Re-adding bad addresses is a no-op.
|
||||
d.BackFill(set1...)
|
||||
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
||||
assert.Equal(t, len(set1), len(d.BadPeers()))
|
||||
assert.Equal(t, len(set1), len(d.GoodPeers()))
|
||||
require.Equal(t, 0, d.PoolCount())
|
||||
}
|
Loading…
Reference in a new issue