diff --git a/pkg/network/discovery.go b/pkg/network/discovery.go index bb9606240..43bf56fa8 100644 --- a/pkg/network/discovery.go +++ b/pkg/network/discovery.go @@ -14,6 +14,7 @@ const ( // a healthy connection pool. type Discoverer interface { BackFill(...string) + Close() PoolCount() int RequestRemote(int) RegisterBadAddr(string) @@ -34,6 +35,7 @@ type DefaultDiscovery struct { connectedAddrs map[string]bool goodAddrs map[string]bool unconnectedAddrs map[string]int + isDead bool requestCh chan int pool chan string } @@ -88,7 +90,11 @@ func (d *DefaultDiscovery) pushToPoolOrDrop(addr string) { // RequestRemote tries to establish a connection with n nodes. func (d *DefaultDiscovery) RequestRemote(n int) { - d.requestCh <- n + d.lock.RLock() + if !d.isDead { + d.requestCh <- n + } + d.lock.RUnlock() } // RegisterBadAddr registers the given address as a bad address. @@ -171,15 +177,28 @@ func (d *DefaultDiscovery) tryAddress(addr string) { } } +// Close stops discoverer pool processing making discoverer almost useless. +func (d *DefaultDiscovery) Close() { + d.lock.Lock() + d.isDead = true + d.lock.Unlock() + select { + case <-d.requestCh: // Drain the channel if there is anything there. + default: + } + close(d.requestCh) +} + // run is a goroutine that makes DefaultDiscovery process its queue to connect // to other nodes. func (d *DefaultDiscovery) run() { - var requested int + var requested, r int + var ok bool for { - for requested = <-d.requestCh; requested > 0; requested-- { + for requested, ok = <-d.requestCh; ok && requested > 0; requested-- { select { - case r := <-d.requestCh: + case r, ok = <-d.requestCh: if requested <= r { requested = r + 1 } @@ -193,5 +212,8 @@ func (d *DefaultDiscovery) run() { } } } + if !ok { + return + } } } diff --git a/pkg/network/discovery_test.go b/pkg/network/discovery_test.go index 0662370b9..9105cdc0b 100644 --- a/pkg/network/discovery_test.go +++ b/pkg/network/discovery_test.go @@ -143,4 +143,8 @@ func TestDefaultDiscoverer(t *testing.T) { assert.Equal(t, len(set1), len(d.BadPeers())) assert.Equal(t, len(set1), len(d.GoodPeers())) require.Equal(t, 0, d.PoolCount()) + + // Close should work and subsequent RequestRemote is a no-op. + d.Close() + d.RequestRemote(42) } diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 2759c37df..29227b149 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -140,6 +140,7 @@ func (chain testChain) VerifyTx(*transaction.Transaction, *block.Block) error { type testDiscovery struct{} func (d testDiscovery) BackFill(addrs ...string) {} +func (d testDiscovery) Close() {} func (d testDiscovery) PoolCount() int { return 0 } func (d testDiscovery) RegisterBadAddr(string) {} func (d testDiscovery) RegisterGoodAddr(string) {} diff --git a/pkg/network/server.go b/pkg/network/server.go index 3e0eb6575..b6cbaaf7a 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -184,6 +184,7 @@ func (s *Server) Start(errChan chan error) { func (s *Server) Shutdown() { s.log.Info("shutting down server", zap.Int("peers", s.PeerCount())) s.bQueue.discard() + s.discovery.Close() close(s.quit) }