network: re-add addresses to the pool on UnregisterConnectedAddr

That's what we do anyway, but this way we can be a bit more efficient.
This commit is contained in:
Roman Khimov 2022-10-13 15:47:55 +03:00
parent 631f166709
commit c1ef326183
3 changed files with 8 additions and 7 deletions

View file

@ -83,6 +83,11 @@ func newDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) Disco
// the pool with the given addresses. // the pool with the given addresses.
func (d *DefaultDiscovery) BackFill(addrs ...string) { func (d *DefaultDiscovery) BackFill(addrs ...string) {
d.lock.Lock() d.lock.Lock()
d.backfill(addrs...)
d.lock.Unlock()
}
func (d *DefaultDiscovery) backfill(addrs ...string) {
for _, addr := range addrs { for _, addr := range addrs {
if d.badAddrs[addr] || d.connectedAddrs[addr] || if d.badAddrs[addr] || d.connectedAddrs[addr] ||
d.unconnectedAddrs[addr] > 0 { d.unconnectedAddrs[addr] > 0 {
@ -92,7 +97,6 @@ func (d *DefaultDiscovery) BackFill(addrs ...string) {
d.pushToPoolOrDrop(addr) d.pushToPoolOrDrop(addr)
} }
d.updateNetSize() d.updateNetSize()
d.lock.Unlock()
} }
// PoolCount returns the number of the available node addresses. // PoolCount returns the number of the available node addresses.
@ -187,7 +191,7 @@ func (d *DefaultDiscovery) RegisterGoodAddr(s string, c capability.Capabilities)
func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) { func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) {
d.lock.Lock() d.lock.Lock()
delete(d.connectedAddrs, s) delete(d.connectedAddrs, s)
d.updateNetSize() d.backfill(s)
d.lock.Unlock() d.lock.Unlock()
} }

View file

@ -132,14 +132,13 @@ func TestDefaultDiscoverer(t *testing.T) {
for _, addr := range set1 { for _, addr := range set1 {
d.UnregisterConnectedAddr(addr) d.UnregisterConnectedAddr(addr)
} }
assert.Equal(t, 0, len(d.UnconnectedPeers())) assert.Equal(t, 2, len(d.UnconnectedPeers())) // They're re-added automatically.
assert.Equal(t, 0, len(d.BadPeers())) assert.Equal(t, 0, len(d.BadPeers()))
assert.Equal(t, len(set1), len(d.GoodPeers())) assert.Equal(t, len(set1), len(d.GoodPeers()))
require.Equal(t, 0, d.PoolCount()) require.Equal(t, 2, d.PoolCount())
// Now make Dial() fail and wait to see addresses in the bad list. // Now make Dial() fail and wait to see addresses in the bad list.
atomic.StoreInt32(&ts.retFalse, 1) atomic.StoreInt32(&ts.retFalse, 1)
d.BackFill(set1...)
assert.Equal(t, len(set1), d.PoolCount()) assert.Equal(t, len(set1), d.PoolCount())
set1D := d.UnconnectedPeers() set1D := d.UnconnectedPeers()
sort.Strings(set1D) sort.Strings(set1D)

View file

@ -439,11 +439,9 @@ func (s *Server) run() {
s.lock.RUnlock() s.lock.RUnlock()
if !stillConnected { if !stillConnected {
s.discovery.UnregisterConnectedAddr(addr) s.discovery.UnregisterConnectedAddr(addr)
s.discovery.BackFill(addr)
} }
} else { } else {
s.discovery.UnregisterConnectedAddr(addr) s.discovery.UnregisterConnectedAddr(addr)
s.discovery.BackFill(addr)
} }
updatePeersConnectedMetric(s.PeerCount()) updatePeersConnectedMetric(s.PeerCount())
} else { } else {