diff --git a/pkg/network/discovery.go b/pkg/network/discovery.go index 25204a0f9..08e59f2ee 100644 --- a/pkg/network/discovery.go +++ b/pkg/network/discovery.go @@ -83,6 +83,11 @@ func newDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) Disco // the pool with the given addresses. func (d *DefaultDiscovery) BackFill(addrs ...string) { d.lock.Lock() + d.backfill(addrs...) + d.lock.Unlock() +} + +func (d *DefaultDiscovery) backfill(addrs ...string) { for _, addr := range addrs { if d.badAddrs[addr] || d.connectedAddrs[addr] || d.unconnectedAddrs[addr] > 0 { @@ -92,7 +97,6 @@ func (d *DefaultDiscovery) BackFill(addrs ...string) { d.pushToPoolOrDrop(addr) } d.updateNetSize() - d.lock.Unlock() } // PoolCount returns the number of the available node addresses. @@ -187,7 +191,7 @@ func (d *DefaultDiscovery) RegisterGoodAddr(s string, c capability.Capabilities) func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) { d.lock.Lock() delete(d.connectedAddrs, s) - d.updateNetSize() + d.backfill(s) d.lock.Unlock() } diff --git a/pkg/network/discovery_test.go b/pkg/network/discovery_test.go index c1d871bed..047aceb9b 100644 --- a/pkg/network/discovery_test.go +++ b/pkg/network/discovery_test.go @@ -132,14 +132,13 @@ func TestDefaultDiscoverer(t *testing.T) { for _, addr := range set1 { d.UnregisterConnectedAddr(addr) } - assert.Equal(t, 0, len(d.UnconnectedPeers())) + assert.Equal(t, 2, len(d.UnconnectedPeers())) // They're re-added automatically. assert.Equal(t, 0, len(d.BadPeers())) assert.Equal(t, len(set1), len(d.GoodPeers())) - require.Equal(t, 0, d.PoolCount()) + require.Equal(t, 2, d.PoolCount()) // Now make Dial() fail and wait to see addresses in the bad list. atomic.StoreInt32(&ts.retFalse, 1) - d.BackFill(set1...) assert.Equal(t, len(set1), d.PoolCount()) set1D := d.UnconnectedPeers() sort.Strings(set1D) diff --git a/pkg/network/server.go b/pkg/network/server.go index 0e505db16..b33901543 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -439,11 +439,9 @@ func (s *Server) run() { s.lock.RUnlock() if !stillConnected { s.discovery.UnregisterConnectedAddr(addr) - s.discovery.BackFill(addr) } } else { s.discovery.UnregisterConnectedAddr(addr) - s.discovery.BackFill(addr) } updatePeersConnectedMetric(s.PeerCount()) } else {