diff --git a/pkg/network/discovery.go b/pkg/network/discovery.go index bb9606240..43bf56fa8 100644 --- a/pkg/network/discovery.go +++ b/pkg/network/discovery.go @@ -14,6 +14,7 @@ const ( // a healthy connection pool. type Discoverer interface { BackFill(...string) + Close() PoolCount() int RequestRemote(int) RegisterBadAddr(string) @@ -34,6 +35,7 @@ type DefaultDiscovery struct { connectedAddrs map[string]bool goodAddrs map[string]bool unconnectedAddrs map[string]int + isDead bool requestCh chan int pool chan string } @@ -88,7 +90,11 @@ func (d *DefaultDiscovery) pushToPoolOrDrop(addr string) { // RequestRemote tries to establish a connection with n nodes. func (d *DefaultDiscovery) RequestRemote(n int) { - d.requestCh <- n + d.lock.RLock() + if !d.isDead { + d.requestCh <- n + } + d.lock.RUnlock() } // RegisterBadAddr registers the given address as a bad address. @@ -171,15 +177,28 @@ func (d *DefaultDiscovery) tryAddress(addr string) { } } +// Close stops discoverer pool processing making discoverer almost useless. +func (d *DefaultDiscovery) Close() { + d.lock.Lock() + d.isDead = true + d.lock.Unlock() + select { + case <-d.requestCh: // Drain the channel if there is anything there. + default: + } + close(d.requestCh) +} + // run is a goroutine that makes DefaultDiscovery process its queue to connect // to other nodes. func (d *DefaultDiscovery) run() { - var requested int + var requested, r int + var ok bool for { - for requested = <-d.requestCh; requested > 0; requested-- { + for requested, ok = <-d.requestCh; ok && requested > 0; requested-- { select { - case r := <-d.requestCh: + case r, ok = <-d.requestCh: if requested <= r { requested = r + 1 } @@ -193,5 +212,8 @@ func (d *DefaultDiscovery) run() { } } } + if !ok { + return + } } } diff --git a/pkg/network/discovery_test.go b/pkg/network/discovery_test.go index 0662370b9..9105cdc0b 100644 --- a/pkg/network/discovery_test.go +++ b/pkg/network/discovery_test.go @@ -143,4 +143,8 @@ func TestDefaultDiscoverer(t *testing.T) { assert.Equal(t, len(set1), len(d.BadPeers())) assert.Equal(t, len(set1), len(d.GoodPeers())) require.Equal(t, 0, d.PoolCount()) + + // Close should work and subsequent RequestRemote is a no-op. + d.Close() + d.RequestRemote(42) } diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index d21eb2ae6..ddab230bb 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -140,6 +140,7 @@ func (chain testChain) VerifyTx(*transaction.Transaction, *block.Block) error { type testDiscovery struct{} func (d testDiscovery) BackFill(addrs ...string) {} +func (d testDiscovery) Close() {} func (d testDiscovery) PoolCount() int { return 0 } func (d testDiscovery) RegisterBadAddr(string) {} func (d testDiscovery) RegisterGoodAddr(string) {} diff --git a/pkg/network/server.go b/pkg/network/server.go index cfe241d9f..79bf78ec2 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -183,6 +183,11 @@ func (s *Server) Start(errChan chan error) { // Shutdown disconnects all peers and stops listening. func (s *Server) Shutdown() { s.log.Info("shutting down server", zap.Int("peers", s.PeerCount())) + s.transport.Close() + s.discovery.Close() + for p := range s.peers { + p.Disconnect(errServerShutdown) + } s.bQueue.discard() close(s.quit) } @@ -224,10 +229,6 @@ func (s *Server) run() { } select { case <-s.quit: - s.transport.Close() - for p := range s.peers { - p.Disconnect(errServerShutdown) - } return case p := <-s.register: s.lock.Lock() @@ -239,7 +240,8 @@ func (s *Server) run() { s.lock.RLock() // Pick a random peer and drop connection to it. for peer := range s.peers { - peer.Disconnect(errMaxPeers) + // It will send us unregister signal. + go peer.Disconnect(errMaxPeers) break } s.lock.RUnlock()