diff --git a/pkg/network/discovery.go b/pkg/network/discovery.go index adc8d0c9a..7696ab77b 100644 --- a/pkg/network/discovery.go +++ b/pkg/network/discovery.go @@ -1,6 +1,7 @@ package network import ( + "sync" "time" ) @@ -26,18 +27,14 @@ type Discoverer interface { // DefaultDiscovery default implementation of the Discoverer interface. type DefaultDiscovery struct { transport Transporter + lock sync.RWMutex dialTimeout time.Duration badAddrs map[string]bool connectedAddrs map[string]bool goodAddrs map[string]bool unconnectedAddrs map[string]int requestCh chan int - connectedCh chan string - backFill chan string - badAddrCh chan string pool chan string - goodCh chan string - unconnectedCh chan string } // NewDefaultDiscovery returns a new DefaultDiscovery. @@ -50,11 +47,6 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery { goodAddrs: make(map[string]bool), unconnectedAddrs: make(map[string]int), requestCh: make(chan int), - connectedCh: make(chan string), - goodCh: make(chan string), - unconnectedCh: make(chan string), - backFill: make(chan string), - badAddrCh: make(chan string), pool: make(chan string, maxPoolSize), } go d.run() @@ -64,9 +56,16 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery { // BackFill implements the Discoverer interface and will backfill the // the pool with the given addresses. func (d *DefaultDiscovery) BackFill(addrs ...string) { + d.lock.Lock() for _, addr := range addrs { - d.backFill <- addr + if d.badAddrs[addr] || d.connectedAddrs[addr] || + d.unconnectedAddrs[addr] > 0 { + continue + } + d.unconnectedAddrs[addr] = connRetries + d.pushToPoolOrDrop(addr) } + d.lock.Unlock() } // PoolCount returns the number of available node addresses. @@ -92,105 +91,104 @@ func (d *DefaultDiscovery) RequestRemote(n int) { // RegisterBadAddr registers the given address as a bad address. func (d *DefaultDiscovery) RegisterBadAddr(addr string) { - d.badAddrCh <- addr - d.RequestRemote(1) + d.lock.Lock() + d.unconnectedAddrs[addr]-- + if d.unconnectedAddrs[addr] > 0 { + d.pushToPoolOrDrop(addr) + } else { + d.badAddrs[addr] = true + delete(d.unconnectedAddrs, addr) + } + d.lock.Unlock() } // UnconnectedPeers returns all addresses of unconnected addrs. func (d *DefaultDiscovery) UnconnectedPeers() []string { + d.lock.RLock() addrs := make([]string, 0, len(d.unconnectedAddrs)) for addr := range d.unconnectedAddrs { addrs = append(addrs, addr) } + d.lock.RUnlock() return addrs } // BadPeers returns all addresses of bad addrs. func (d *DefaultDiscovery) BadPeers() []string { + d.lock.RLock() addrs := make([]string, 0, len(d.badAddrs)) for addr := range d.badAddrs { addrs = append(addrs, addr) } + d.lock.RUnlock() return addrs } // GoodPeers returns all addresses of known good peers (that at least once // succeeded handshaking with us). func (d *DefaultDiscovery) GoodPeers() []string { + d.lock.RLock() addrs := make([]string, 0, len(d.goodAddrs)) for addr := range d.goodAddrs { addrs = append(addrs, addr) } + d.lock.RUnlock() return addrs } // RegisterGoodAddr registers good known connected address that passed // handshake successfully. func (d *DefaultDiscovery) RegisterGoodAddr(s string) { - d.goodCh <- s + d.lock.Lock() + d.goodAddrs[s] = true + d.lock.Unlock() } // UnregisterConnectedAddr tells discoverer that this address is no longer // connected, but it still is considered as good one. func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) { - d.unconnectedCh <- s + d.lock.Lock() + delete(d.connectedAddrs, s) + d.lock.Unlock() +} + +// registerConnectedAddr tells discoverer that given address is now connected. +func (d *DefaultDiscovery) registerConnectedAddr(addr string) { + d.lock.Lock() + delete(d.unconnectedAddrs, addr) + d.connectedAddrs[addr] = true + d.lock.Unlock() } func (d *DefaultDiscovery) tryAddress(addr string) { if err := d.transport.Dial(addr, d.dialTimeout); err != nil { - d.badAddrCh <- addr + d.RegisterBadAddr(addr) + d.RequestRemote(1) } else { - d.connectedCh <- addr + d.registerConnectedAddr(addr) } } -func (d *DefaultDiscovery) requestToWork() { +// run is a goroutine that makes DefaultDiscovery process its queue to connect +// to other nodes. +func (d *DefaultDiscovery) run() { var requested int for { for requested = <-d.requestCh; requested > 0; requested-- { select { case r := <-d.requestCh: - if requested < r { - requested = r + if requested <= r { + requested = r + 1 } case addr := <-d.pool: - if !d.connectedAddrs[addr] { + d.lock.RLock() + addrIsConnected := d.connectedAddrs[addr] + d.lock.RUnlock() + if !addrIsConnected { go d.tryAddress(addr) } } } } } - -func (d *DefaultDiscovery) run() { - go d.requestToWork() - for { - select { - case addr := <-d.backFill: - if d.badAddrs[addr] || d.connectedAddrs[addr] || - d.unconnectedAddrs[addr] > 0 { - break - } - d.unconnectedAddrs[addr] = connRetries - d.pushToPoolOrDrop(addr) - case addr := <-d.badAddrCh: - d.unconnectedAddrs[addr]-- - if d.unconnectedAddrs[addr] > 0 { - d.pushToPoolOrDrop(addr) - } else { - d.badAddrs[addr] = true - delete(d.unconnectedAddrs, addr) - } - d.RequestRemote(1) - - case addr := <-d.connectedCh: - delete(d.unconnectedAddrs, addr) - d.connectedAddrs[addr] = true - case addr := <-d.goodCh: - d.goodAddrs[addr] = true - case addr := <-d.unconnectedCh: - delete(d.connectedAddrs, addr) - } - } -} diff --git a/pkg/network/discovery_test.go b/pkg/network/discovery_test.go new file mode 100644 index 000000000..0662370b9 --- /dev/null +++ b/pkg/network/discovery_test.go @@ -0,0 +1,146 @@ +package network + +import ( + "errors" + "sort" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakeTransp struct { + retFalse int32 + dialCh chan string +} + +func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error { + ft.dialCh <- addr + + if atomic.LoadInt32(&ft.retFalse) > 0 { + return errors.New("smth bad happened") + } + return nil +} +func (ft *fakeTransp) Accept() { +} +func (ft *fakeTransp) Proto() string { + return "" +} +func (ft *fakeTransp) Close() { +} +func TestDefaultDiscoverer(t *testing.T) { + ts := &fakeTransp{} + ts.dialCh = make(chan string) + d := NewDefaultDiscovery(time.Second, ts) + + var set1 = []string{"1.1.1.1:10333", "2.2.2.2:10333"} + sort.Strings(set1) + + // Added addresses should end up in the pool and in the unconnected set. + // Done twice to check re-adding unconnected addresses, which should be + // a no-op. + for i := 0; i < 2; i++ { + d.BackFill(set1...) + assert.Equal(t, len(set1), d.PoolCount()) + set1D := d.UnconnectedPeers() + sort.Strings(set1D) + assert.Equal(t, 0, len(d.GoodPeers())) + assert.Equal(t, 0, len(d.BadPeers())) + require.Equal(t, set1, set1D) + } + + // Request should make goroutines dial our addresses draining the pool. + d.RequestRemote(len(set1)) + dialled := make([]string, 0) + for i := 0; i < len(set1); i++ { + select { + case a := <-ts.dialCh: + dialled = append(dialled, a) + case <-time.After(time.Second): + t.Fatalf("timeout expecting for transport dial") + } + } + // Updated asynchronously. + if len(d.UnconnectedPeers()) != 0 { + time.Sleep(time.Second) + } + sort.Strings(dialled) + assert.Equal(t, 0, d.PoolCount()) + assert.Equal(t, 0, len(d.UnconnectedPeers())) + assert.Equal(t, 0, len(d.BadPeers())) + assert.Equal(t, 0, len(d.GoodPeers())) + require.Equal(t, set1, dialled) + + // Registered good addresses should end up in appropriate set. + for _, addr := range set1 { + d.RegisterGoodAddr(addr) + } + gAddrs := d.GoodPeers() + sort.Strings(gAddrs) + assert.Equal(t, 0, d.PoolCount()) + assert.Equal(t, 0, len(d.UnconnectedPeers())) + assert.Equal(t, 0, len(d.BadPeers())) + require.Equal(t, set1, gAddrs) + + // Re-adding connected addresses should be no-op. + d.BackFill(set1...) + assert.Equal(t, 0, len(d.UnconnectedPeers())) + assert.Equal(t, 0, len(d.BadPeers())) + assert.Equal(t, len(set1), len(d.GoodPeers())) + require.Equal(t, 0, d.PoolCount()) + + // Unregistering connected should work. + for _, addr := range set1 { + d.UnregisterConnectedAddr(addr) + } + assert.Equal(t, 0, len(d.UnconnectedPeers())) + assert.Equal(t, 0, len(d.BadPeers())) + assert.Equal(t, len(set1), len(d.GoodPeers())) + require.Equal(t, 0, d.PoolCount()) + + // Now make Dial() fail and wait to see addresses in the bad list. + atomic.StoreInt32(&ts.retFalse, 1) + d.BackFill(set1...) + assert.Equal(t, len(set1), d.PoolCount()) + set1D := d.UnconnectedPeers() + sort.Strings(set1D) + assert.Equal(t, 0, len(d.BadPeers())) + require.Equal(t, set1, set1D) + + dialledBad := make([]string, 0) + d.RequestRemote(len(set1)) + for i := 0; i < connRetries; i++ { + for j := 0; j < len(set1); j++ { + select { + case a := <-ts.dialCh: + dialledBad = append(dialledBad, a) + case <-time.After(time.Second): + t.Fatalf("timeout expecting for transport dial; i: %d, j: %d", i, j) + } + } + } + require.Equal(t, 0, d.PoolCount()) + sort.Strings(dialledBad) + for i := 0; i < len(set1); i++ { + for j := 0; j < connRetries; j++ { + assert.Equal(t, set1[i], dialledBad[i*connRetries+j]) + } + } + // Updated asynchronously. + if len(d.BadPeers()) != len(set1) { + time.Sleep(time.Second) + } + assert.Equal(t, len(set1), len(d.BadPeers())) + assert.Equal(t, len(set1), len(d.GoodPeers())) + assert.Equal(t, 0, len(d.UnconnectedPeers())) + + // Re-adding bad addresses is a no-op. + d.BackFill(set1...) + assert.Equal(t, 0, len(d.UnconnectedPeers())) + assert.Equal(t, len(set1), len(d.BadPeers())) + assert.Equal(t, len(set1), len(d.GoodPeers())) + require.Equal(t, 0, d.PoolCount()) +}