Merge pull request #2811 from nspcc-dev/network-connections-fix

Network connections fix
This commit is contained in:
Roman Khimov 2022-11-18 14:26:25 +07:00 committed by GitHub
commit bb9f17108e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 213 additions and 112 deletions

View file

@ -2,6 +2,7 @@ package network
import ( import (
"math" "math"
"math/rand"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -14,6 +15,11 @@ const (
connRetries = 3 connRetries = 3
) )
var (
// Maximum waiting time before connection attempt.
tryMaxWait = time.Second / 2
)
// Discoverer is an interface that is responsible for maintaining // Discoverer is an interface that is responsible for maintaining
// a healthy connection pool. // a healthy connection pool.
type Discoverer interface { type Discoverer interface {
@ -22,10 +28,10 @@ type Discoverer interface {
NetworkSize() int NetworkSize() int
PoolCount() int PoolCount() int
RequestRemote(int) RequestRemote(int)
RegisterBadAddr(string) RegisterSelf(AddressablePeer)
RegisterGoodAddr(string, capability.Capabilities) RegisterGood(AddressablePeer)
RegisterConnectedAddr(string) RegisterConnected(AddressablePeer)
UnregisterConnectedAddr(string) UnregisterConnected(AddressablePeer, bool)
UnconnectedPeers() []string UnconnectedPeers() []string
BadPeers() []string BadPeers() []string
GoodPeers() []AddressWithCapabilities GoodPeers() []AddressWithCapabilities
@ -39,15 +45,17 @@ type AddressWithCapabilities struct {
// DefaultDiscovery default implementation of the Discoverer interface. // DefaultDiscovery default implementation of the Discoverer interface.
type DefaultDiscovery struct { type DefaultDiscovery struct {
seeds []string seeds map[string]string
transport Transporter transport Transporter
lock sync.RWMutex lock sync.RWMutex
dialTimeout time.Duration dialTimeout time.Duration
badAddrs map[string]bool badAddrs map[string]bool
connectedAddrs map[string]bool connectedAddrs map[string]bool
handshakedAddrs map[string]bool
goodAddrs map[string]capability.Capabilities goodAddrs map[string]capability.Capabilities
unconnectedAddrs map[string]int unconnectedAddrs map[string]int
attempted map[string]bool attempted map[string]bool
outstanding int32
optimalFanOut int32 optimalFanOut int32
networkSize int32 networkSize int32
requestCh chan int requestCh chan int
@ -55,12 +63,17 @@ type DefaultDiscovery struct {
// NewDefaultDiscovery returns a new DefaultDiscovery. // NewDefaultDiscovery returns a new DefaultDiscovery.
func NewDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) *DefaultDiscovery { func NewDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) *DefaultDiscovery {
var seeds = make(map[string]string)
for i := range addrs {
seeds[addrs[i]] = ""
}
d := &DefaultDiscovery{ d := &DefaultDiscovery{
seeds: addrs, seeds: seeds,
transport: ts, transport: ts,
dialTimeout: dt, dialTimeout: dt,
badAddrs: make(map[string]bool), badAddrs: make(map[string]bool),
connectedAddrs: make(map[string]bool), connectedAddrs: make(map[string]bool),
handshakedAddrs: make(map[string]bool),
goodAddrs: make(map[string]capability.Capabilities), goodAddrs: make(map[string]capability.Capabilities),
unconnectedAddrs: make(map[string]int), unconnectedAddrs: make(map[string]int),
attempted: make(map[string]bool), attempted: make(map[string]bool),
@ -83,7 +96,7 @@ func (d *DefaultDiscovery) BackFill(addrs ...string) {
func (d *DefaultDiscovery) backfill(addrs ...string) { func (d *DefaultDiscovery) backfill(addrs ...string) {
for _, addr := range addrs { for _, addr := range addrs {
if d.badAddrs[addr] || d.connectedAddrs[addr] || if d.badAddrs[addr] || d.connectedAddrs[addr] || d.handshakedAddrs[addr] ||
d.unconnectedAddrs[addr] > 0 { d.unconnectedAddrs[addr] > 0 {
continue continue
} }
@ -113,11 +126,13 @@ func (d *DefaultDiscovery) pushToPoolOrDrop(addr string) {
// RequestRemote tries to establish a connection with n nodes. // RequestRemote tries to establish a connection with n nodes.
func (d *DefaultDiscovery) RequestRemote(requested int) { func (d *DefaultDiscovery) RequestRemote(requested int) {
outstanding := int(atomic.LoadInt32(&d.outstanding))
requested -= outstanding
for ; requested > 0; requested-- { for ; requested > 0; requested-- {
var nextAddr string var nextAddr string
d.lock.Lock() d.lock.Lock()
for addr := range d.unconnectedAddrs { for addr := range d.unconnectedAddrs {
if !d.connectedAddrs[addr] && !d.attempted[addr] { if !d.connectedAddrs[addr] && !d.handshakedAddrs[addr] && !d.attempted[addr] {
nextAddr = addr nextAddr = addr
break break
} }
@ -125,8 +140,8 @@ func (d *DefaultDiscovery) RequestRemote(requested int) {
if nextAddr == "" { if nextAddr == "" {
// Empty pool, try seeds. // Empty pool, try seeds.
for _, addr := range d.seeds { for addr, ip := range d.seeds {
if !d.connectedAddrs[addr] && !d.attempted[addr] { if ip == "" && !d.attempted[addr] {
nextAddr = addr nextAddr = addr
break break
} }
@ -140,30 +155,38 @@ func (d *DefaultDiscovery) RequestRemote(requested int) {
} }
d.attempted[nextAddr] = true d.attempted[nextAddr] = true
d.lock.Unlock() d.lock.Unlock()
atomic.AddInt32(&d.outstanding, 1)
go d.tryAddress(nextAddr) go d.tryAddress(nextAddr)
} }
} }
// RegisterBadAddr registers the given address as a bad address. // RegisterSelf registers the given Peer as a bad one, because it's our own node.
func (d *DefaultDiscovery) RegisterBadAddr(addr string) { func (d *DefaultDiscovery) RegisterSelf(p AddressablePeer) {
var isSeed bool var connaddr = p.ConnectionAddr()
d.lock.Lock() d.lock.Lock()
for _, seed := range d.seeds { delete(d.connectedAddrs, connaddr)
if addr == seed { d.registerBad(connaddr, true)
isSeed = true d.registerBad(p.PeerAddr().String(), true)
break d.lock.Unlock()
}
func (d *DefaultDiscovery) registerBad(addr string, force bool) {
_, isSeed := d.seeds[addr]
if isSeed {
if !force {
d.seeds[addr] = ""
} else {
d.seeds[addr] = "forever" // That's our own address, so never try connecting to it.
} }
} } else {
if !isSeed {
d.unconnectedAddrs[addr]-- d.unconnectedAddrs[addr]--
if d.unconnectedAddrs[addr] <= 0 { if d.unconnectedAddrs[addr] <= 0 || force {
d.badAddrs[addr] = true d.badAddrs[addr] = true
delete(d.unconnectedAddrs, addr) delete(d.unconnectedAddrs, addr)
delete(d.goodAddrs, addr) delete(d.goodAddrs, addr)
} }
} }
d.updateNetSize() d.updateNetSize()
d.lock.Unlock()
} }
// UnconnectedPeers returns all addresses of unconnected addrs. // UnconnectedPeers returns all addresses of unconnected addrs.
@ -203,31 +226,53 @@ func (d *DefaultDiscovery) GoodPeers() []AddressWithCapabilities {
return addrs return addrs
} }
// RegisterGoodAddr registers a known good connected address that has passed // RegisterGood registers a known good connected peer that has passed
// handshake successfully. // handshake successfully.
func (d *DefaultDiscovery) RegisterGoodAddr(s string, c capability.Capabilities) { func (d *DefaultDiscovery) RegisterGood(p AddressablePeer) {
var s = p.PeerAddr().String()
d.lock.Lock() d.lock.Lock()
d.goodAddrs[s] = c d.handshakedAddrs[s] = true
d.goodAddrs[s] = p.Version().Capabilities
delete(d.badAddrs, s) delete(d.badAddrs, s)
d.lock.Unlock() d.lock.Unlock()
} }
// UnregisterConnectedAddr tells the discoverer that this address is no longer // UnregisterConnected tells the discoverer that this peer is no longer
// connected, but it is still considered a good one. // connected, but it is still considered a good one.
func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) { func (d *DefaultDiscovery) UnregisterConnected(p AddressablePeer, duplicate bool) {
var (
peeraddr = p.PeerAddr().String()
connaddr = p.ConnectionAddr()
)
d.lock.Lock() d.lock.Lock()
delete(d.connectedAddrs, s) delete(d.connectedAddrs, connaddr)
d.backfill(s) if !duplicate {
for addr, ip := range d.seeds {
if ip == peeraddr {
d.seeds[addr] = ""
break
}
}
delete(d.handshakedAddrs, peeraddr)
if _, ok := d.goodAddrs[peeraddr]; ok {
d.backfill(peeraddr)
}
}
d.lock.Unlock() d.lock.Unlock()
} }
// RegisterConnectedAddr tells discoverer that the given address is now connected. // RegisterConnected tells discoverer that the given peer is now connected.
func (d *DefaultDiscovery) RegisterConnectedAddr(addr string) { func (d *DefaultDiscovery) RegisterConnected(p AddressablePeer) {
var addr = p.ConnectionAddr()
d.lock.Lock() d.lock.Lock()
d.registerConnected(addr)
d.lock.Unlock()
}
func (d *DefaultDiscovery) registerConnected(addr string) {
delete(d.unconnectedAddrs, addr) delete(d.unconnectedAddrs, addr)
d.connectedAddrs[addr] = true d.connectedAddrs[addr] = true
d.updateNetSize() d.updateNetSize()
d.lock.Unlock()
} }
// GetFanOut returns the optimal number of nodes to broadcast packets to. // GetFanOut returns the optimal number of nodes to broadcast packets to.
@ -242,9 +287,9 @@ func (d *DefaultDiscovery) NetworkSize() int {
// updateNetSize updates network size estimation metric. Must be called under read lock. // updateNetSize updates network size estimation metric. Must be called under read lock.
func (d *DefaultDiscovery) updateNetSize() { func (d *DefaultDiscovery) updateNetSize() {
var netsize = len(d.connectedAddrs) + len(d.unconnectedAddrs) + 1 // 1 for the node itself. var netsize = len(d.handshakedAddrs) + len(d.unconnectedAddrs) + 1 // 1 for the node itself.
var fanOut = 2.5 * math.Log(float64(netsize-1)) // -1 for the number of potential peers. var fanOut = 2.5 * math.Log(float64(netsize-1)) // -1 for the number of potential peers.
if netsize == 2 { // log(1) == 0. if netsize == 2 { // log(1) == 0.
fanOut = 1 // But we still want to push messages to the peer. fanOut = 1 // But we still want to push messages to the peer.
} }
@ -255,12 +300,22 @@ func (d *DefaultDiscovery) updateNetSize() {
} }
func (d *DefaultDiscovery) tryAddress(addr string) { func (d *DefaultDiscovery) tryAddress(addr string) {
err := d.transport.Dial(addr, d.dialTimeout) var tout = rand.Int63n(int64(tryMaxWait))
time.Sleep(time.Duration(tout)) // Have a sleep before working hard.
p, err := d.transport.Dial(addr, d.dialTimeout)
atomic.AddInt32(&d.outstanding, -1)
d.lock.Lock() d.lock.Lock()
delete(d.attempted, addr) delete(d.attempted, addr)
if err == nil {
if _, ok := d.seeds[addr]; ok {
d.seeds[addr] = p.PeerAddr().String()
}
d.registerConnected(addr)
} else {
d.registerBad(addr, false)
}
d.lock.Unlock() d.lock.Unlock()
if err != nil { if err != nil {
d.RegisterBadAddr(addr)
time.Sleep(d.dialTimeout) time.Sleep(d.dialTimeout)
d.RequestRemote(1) d.RequestRemote(1)
} }

View file

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
atomic2 "go.uber.org/atomic" atomic2 "go.uber.org/atomic"
@ -22,18 +23,40 @@ type fakeTransp struct {
addr string addr string
} }
type fakeAPeer struct {
addr string
peer string
version *payload.Version
}
func (f *fakeAPeer) ConnectionAddr() string {
return f.addr
}
func (f *fakeAPeer) PeerAddr() net.Addr {
tcpAddr, err := net.ResolveTCPAddr("tcp", f.peer)
if err != nil {
panic(err)
}
return tcpAddr
}
func (f *fakeAPeer) Version() *payload.Version {
return f.version
}
func newFakeTransp(s *Server) Transporter { func newFakeTransp(s *Server) Transporter {
return &fakeTransp{} return &fakeTransp{}
} }
func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error { func (ft *fakeTransp) Dial(addr string, timeout time.Duration) (AddressablePeer, error) {
var ret error var ret error
if atomic.LoadInt32(&ft.retFalse) > 0 { if atomic.LoadInt32(&ft.retFalse) > 0 {
ret = errors.New("smth bad happened") ret = errors.New("smth bad happened")
} }
ft.dialCh <- addr ft.dialCh <- addr
return ret return &fakeAPeer{addr: addr, peer: addr}, ret
} }
func (ft *fakeTransp) Accept() { func (ft *fakeTransp) Accept() {
if ft.started.Load() { if ft.started.Load() {
@ -59,6 +82,7 @@ func TestDefaultDiscoverer(t *testing.T) {
ts.dialCh = make(chan string) ts.dialCh = make(chan string)
d := NewDefaultDiscovery(nil, time.Second/16, ts) d := NewDefaultDiscovery(nil, time.Second/16, ts)
tryMaxWait = 1 // Don't waste time.
var set1 = []string{"1.1.1.1:10333", "2.2.2.2:10333"} var set1 = []string{"1.1.1.1:10333", "2.2.2.2:10333"}
sort.Strings(set1) sort.Strings(set1)
@ -83,7 +107,7 @@ func TestDefaultDiscoverer(t *testing.T) {
select { select {
case a := <-ts.dialCh: case a := <-ts.dialCh:
dialled = append(dialled, a) dialled = append(dialled, a)
d.RegisterConnectedAddr(a) d.RegisterConnected(&fakeAPeer{addr: a, peer: a})
case <-time.After(time.Second): case <-time.After(time.Second):
t.Fatalf("timeout expecting for transport dial") t.Fatalf("timeout expecting for transport dial")
} }
@ -97,10 +121,14 @@ func TestDefaultDiscoverer(t *testing.T) {
// Registered good addresses should end up in appropriate set. // Registered good addresses should end up in appropriate set.
for _, addr := range set1 { for _, addr := range set1 {
d.RegisterGoodAddr(addr, capability.Capabilities{ d.RegisterGood(&fakeAPeer{
{ addr: addr,
Type: capability.FullNode, peer: addr,
Data: &capability.Node{StartHeight: 123}, version: &payload.Version{
Capabilities: capability.Capabilities{{
Type: capability.FullNode,
Data: &capability.Node{StartHeight: 123},
}},
}, },
}) })
} }
@ -130,7 +158,7 @@ func TestDefaultDiscoverer(t *testing.T) {
// Unregistering connected should work. // Unregistering connected should work.
for _, addr := range set1 { for _, addr := range set1 {
d.UnregisterConnectedAddr(addr) d.UnregisterConnected(&fakeAPeer{addr: addr, peer: addr}, false)
} }
assert.Equal(t, 2, len(d.UnconnectedPeers())) // They're re-added automatically. assert.Equal(t, 2, len(d.UnconnectedPeers())) // They're re-added automatically.
assert.Equal(t, 0, len(d.BadPeers())) assert.Equal(t, 0, len(d.BadPeers()))
@ -184,6 +212,7 @@ func TestSeedDiscovery(t *testing.T) {
sort.Strings(seeds) sort.Strings(seeds)
d := NewDefaultDiscovery(seeds, time.Second/10, ts) d := NewDefaultDiscovery(seeds, time.Second/10, ts)
tryMaxWait = 1 // Don't waste time.
d.RequestRemote(len(seeds)) d.RequestRemote(len(seeds))
for i := 0; i < connRetries*2; i++ { for i := 0; i < connRetries*2; i++ {

View file

@ -13,7 +13,6 @@ import (
"github.com/nspcc-dev/neo-go/internal/fakechain" "github.com/nspcc-dev/neo-go/internal/fakechain"
"github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
@ -35,10 +34,10 @@ func (d *testDiscovery) BackFill(addrs ...string) {
d.backfill = append(d.backfill, addrs...) d.backfill = append(d.backfill, addrs...)
} }
func (d *testDiscovery) PoolCount() int { return 0 } func (d *testDiscovery) PoolCount() int { return 0 }
func (d *testDiscovery) RegisterBadAddr(addr string) { func (d *testDiscovery) RegisterSelf(p AddressablePeer) {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
d.bad = append(d.bad, addr) d.bad = append(d.bad, p.ConnectionAddr())
} }
func (d *testDiscovery) GetFanOut() int { func (d *testDiscovery) GetFanOut() int {
d.Lock() d.Lock()
@ -50,16 +49,16 @@ func (d *testDiscovery) NetworkSize() int {
defer d.Unlock() defer d.Unlock()
return len(d.connected) + len(d.backfill) return len(d.connected) + len(d.backfill)
} }
func (d *testDiscovery) RegisterGoodAddr(string, capability.Capabilities) {} func (d *testDiscovery) RegisterGood(AddressablePeer) {}
func (d *testDiscovery) RegisterConnectedAddr(addr string) { func (d *testDiscovery) RegisterConnected(p AddressablePeer) {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
d.connected = append(d.connected, addr) d.connected = append(d.connected, p.ConnectionAddr())
} }
func (d *testDiscovery) UnregisterConnectedAddr(addr string) { func (d *testDiscovery) UnregisterConnected(p AddressablePeer, force bool) {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
d.unregistered = append(d.unregistered, addr) d.unregistered = append(d.unregistered, p.ConnectionAddr())
} }
func (d *testDiscovery) UnconnectedPeers() []string { func (d *testDiscovery) UnconnectedPeers() []string {
d.Lock() d.Lock()
@ -100,6 +99,9 @@ func newLocalPeer(t *testing.T, s *Server) *localPeer {
} }
} }
func (p *localPeer) ConnectionAddr() string {
return p.netaddr.String()
}
func (p *localPeer) RemoteAddr() net.Addr { func (p *localPeer) RemoteAddr() net.Addr {
return &p.netaddr return &p.netaddr
} }

View file

@ -7,10 +7,12 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
) )
// Peer represents a network node neo-go is connected to. type AddressablePeer interface {
type Peer interface { // ConnectionAddr returns an address-like identifier of this connection
// RemoteAddr returns the remote address that we're connected to now. // before we have a proper one (after the handshake). It's either the
RemoteAddr() net.Addr // address from discoverer (if initiated from node) or one from socket
// (if connected to node from outside).
ConnectionAddr() string
// PeerAddr returns the remote address that should be used to establish // PeerAddr returns the remote address that should be used to establish
// a new connection to the node. It can differ from the RemoteAddr // a new connection to the node. It can differ from the RemoteAddr
// address in case the remote node is a client and its current // address in case the remote node is a client and its current
@ -18,6 +20,16 @@ type Peer interface {
// to connect to it. It's only valid after the handshake is completed. // to connect to it. It's only valid after the handshake is completed.
// Before that, it returns the same address as RemoteAddr. // Before that, it returns the same address as RemoteAddr.
PeerAddr() net.Addr PeerAddr() net.Addr
// Version returns peer's version message if the peer has handshaked
// already.
Version() *payload.Version
}
// Peer represents a network node neo-go is connected to.
type Peer interface {
AddressablePeer
// RemoteAddr returns the remote address that we're connected to now.
RemoteAddr() net.Addr
Disconnect(error) Disconnect(error)
// BroadcastPacket is a context-bound packet enqueuer, it either puts the // BroadcastPacket is a context-bound packet enqueuer, it either puts the
@ -49,7 +61,6 @@ type Peer interface {
// EnqueueHPPacket is similar to EnqueueHPMessage, but accepts a slice of // EnqueueHPPacket is similar to EnqueueHPMessage, but accepts a slice of
// message(s) bytes. // message(s) bytes.
EnqueueHPPacket([]byte) error EnqueueHPPacket([]byte) error
Version() *payload.Version
LastBlockIndex() uint32 LastBlockIndex() uint32
Handshaked() bool Handshaked() bool
IsFullNode() bool IsFullNode() bool

View file

@ -127,6 +127,7 @@ type (
register chan Peer register chan Peer
unregister chan peerDrop unregister chan peerDrop
handshake chan Peer
quit chan struct{} quit chan struct{}
relayFin chan struct{} relayFin chan struct{}
@ -181,6 +182,7 @@ func newServerFromConstructors(config ServerConfig, chain Ledger, stSync StateSy
relayFin: make(chan struct{}), relayFin: make(chan struct{}),
register: make(chan Peer), register: make(chan Peer),
unregister: make(chan peerDrop), unregister: make(chan peerDrop),
handshake: make(chan Peer),
txInMap: make(map[util.Uint256]struct{}), txInMap: make(map[util.Uint256]struct{}),
peers: make(map[Peer]bool), peers: make(map[Peer]bool),
syncReached: atomic.NewBool(false), syncReached: atomic.NewBool(false),
@ -398,10 +400,12 @@ func (s *Server) ConnectedPeers() []string {
func (s *Server) run() { func (s *Server) run() {
var ( var (
peerCheckTime = s.TimePerBlock * peerTimeFactor peerCheckTime = s.TimePerBlock * peerTimeFactor
peerCheckTimeout bool addrCheckTimeout bool
timer = time.NewTimer(peerCheckTime) addrTimer = time.NewTimer(peerCheckTime)
peerTimer = time.NewTimer(s.ProtoTickInterval)
) )
defer timer.Stop() defer addrTimer.Stop()
defer peerTimer.Stop()
go s.runProto() go s.runProto()
for loopCnt := 0; ; loopCnt++ { for loopCnt := 0; ; loopCnt++ {
var ( var (
@ -409,12 +413,16 @@ func (s *Server) run() {
// "Optimal" number of peers. // "Optimal" number of peers.
optimalN = s.discovery.GetFanOut() * 2 optimalN = s.discovery.GetFanOut() * 2
// Real number of peers. // Real number of peers.
peerN = s.PeerCount() peerN = s.HandshakedPeersCount()
// Timeout value for the next peerTimer, long one by default.
peerT = peerCheckTime
) )
if peerN < s.MinPeers { if peerN < s.MinPeers {
// Starting up or going below the minimum -> quickly get many new peers. // Starting up or going below the minimum -> quickly get many new peers.
s.discovery.RequestRemote(s.AttemptConnPeers) s.discovery.RequestRemote(s.AttemptConnPeers)
// Check/retry new connections soon.
peerT = s.ProtoTickInterval
} else if s.MinPeers > 0 && loopCnt%s.MinPeers == 0 && optimalN > peerN && optimalN < s.MaxPeers && optimalN < netSize { } else if s.MinPeers > 0 && loopCnt%s.MinPeers == 0 && optimalN > peerN && optimalN < s.MaxPeers && optimalN < netSize {
// Having some number of peers, but probably can get some more, the network is big. // Having some number of peers, but probably can get some more, the network is big.
// It also allows to start picking up new peers proactively, before we suddenly have <s.MinPeers of them. // It also allows to start picking up new peers proactively, before we suddenly have <s.MinPeers of them.
@ -425,16 +433,18 @@ func (s *Server) run() {
s.discovery.RequestRemote(connN) s.discovery.RequestRemote(connN)
} }
if peerCheckTimeout || s.discovery.PoolCount() < s.AttemptConnPeers { if addrCheckTimeout || s.discovery.PoolCount() < s.AttemptConnPeers {
s.broadcastHPMessage(NewMessage(CMDGetAddr, payload.NewNullPayload())) s.broadcastHPMessage(NewMessage(CMDGetAddr, payload.NewNullPayload()))
peerCheckTimeout = false addrCheckTimeout = false
} }
select { select {
case <-s.quit: case <-s.quit:
return return
case <-timer.C: case <-addrTimer.C:
peerCheckTimeout = true addrCheckTimeout = true
timer.Reset(peerCheckTime) addrTimer.Reset(peerCheckTime)
case <-peerTimer.C:
peerTimer.Reset(peerT)
case p := <-s.register: case p := <-s.register:
s.lock.Lock() s.lock.Lock()
s.peers[p] = true s.peers[p] = true
@ -462,31 +472,10 @@ func (s *Server) run() {
zap.Stringer("addr", drop.peer.RemoteAddr()), zap.Stringer("addr", drop.peer.RemoteAddr()),
zap.Error(drop.reason), zap.Error(drop.reason),
zap.Int("peerCount", s.PeerCount())) zap.Int("peerCount", s.PeerCount()))
addr := drop.peer.PeerAddr().String()
if errors.Is(drop.reason, errIdenticalID) { if errors.Is(drop.reason, errIdenticalID) {
s.discovery.RegisterBadAddr(addr) s.discovery.RegisterSelf(drop.peer)
} else if errors.Is(drop.reason, errAlreadyConnected) {
// There is a race condition when peer can be disconnected twice for the this reason
// which can lead to no connections to peer at all. Here we check for such a possibility.
stillConnected := false
s.lock.RLock()
verDrop := drop.peer.Version()
addr := drop.peer.PeerAddr().String()
if verDrop != nil {
for peer := range s.peers {
ver := peer.Version()
// Already connected, drop this connection.
if ver != nil && ver.Nonce == verDrop.Nonce && peer.PeerAddr().String() == addr {
stillConnected = true
}
}
}
s.lock.RUnlock()
if !stillConnected {
s.discovery.UnregisterConnectedAddr(addr)
}
} else { } else {
s.discovery.UnregisterConnectedAddr(addr) s.discovery.UnregisterConnected(drop.peer, errors.Is(drop.reason, errAlreadyConnected))
} }
updatePeersConnectedMetric(s.PeerCount()) updatePeersConnectedMetric(s.PeerCount())
} else { } else {
@ -494,6 +483,19 @@ func (s *Server) run() {
// because we have two goroutines sending signals here // because we have two goroutines sending signals here
s.lock.Unlock() s.lock.Unlock()
} }
case p := <-s.handshake:
ver := p.Version()
s.log.Info("started protocol",
zap.Stringer("addr", p.RemoteAddr()),
zap.ByteString("userAgent", ver.UserAgent),
zap.Uint32("startHeight", p.LastBlockIndex()),
zap.Uint32("id", ver.Nonce))
s.discovery.RegisterGood(p)
s.tryInitStateSync()
s.tryStartServices()
} }
} }
} }
@ -700,7 +702,6 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
} }
} }
s.lock.RUnlock() s.lock.RUnlock()
s.discovery.RegisterConnectedAddr(peerAddr)
return p.SendVersionAck(NewMessage(CMDVerack, payload.NewNullPayload())) return p.SendVersionAck(NewMessage(CMDVerack, payload.NewNullPayload()))
} }
@ -1195,11 +1196,9 @@ func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error {
if !p.CanProcessAddr() { if !p.CanProcessAddr() {
return errors.New("unexpected addr received") return errors.New("unexpected addr received")
} }
dups := make(map[string]bool)
for _, a := range addrs.Addrs { for _, a := range addrs.Addrs {
addr, err := a.GetTCPAddress() addr, err := a.GetTCPAddress()
if err == nil && !dups[addr] { if err == nil {
dups[addr] = true
s.discovery.BackFill(addr) s.discovery.BackFill(addr)
} }
} }
@ -1356,9 +1355,6 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return err return err
} }
go peer.StartProtocol() go peer.StartProtocol()
s.tryInitStateSync()
s.tryStartServices()
default: default:
return fmt.Errorf("received '%s' during handshake", msg.Command.String()) return fmt.Errorf("received '%s' during handshake", msg.Command.String())
} }

View file

@ -141,14 +141,17 @@ func TestServerRegisterPeer(t *testing.T) {
for i := range ps { for i := range ps {
ps[i] = newLocalPeer(t, s) ps[i] = newLocalPeer(t, s)
ps[i].netaddr.Port = i + 1 ps[i].netaddr.Port = i + 1
ps[i].version = &payload.Version{Nonce: uint32(i), UserAgent: []byte("fake")}
} }
startWithCleanup(t, s) startWithCleanup(t, s)
s.register <- ps[0] s.register <- ps[0]
require.Eventually(t, func() bool { return 1 == s.PeerCount() }, time.Second, time.Millisecond*10) require.Eventually(t, func() bool { return 1 == s.PeerCount() }, time.Second, time.Millisecond*10)
s.handshake <- ps[0]
s.register <- ps[1] s.register <- ps[1]
s.handshake <- ps[1]
require.Eventually(t, func() bool { return 2 == s.PeerCount() }, time.Second, time.Millisecond*10) require.Eventually(t, func() bool { return 2 == s.PeerCount() }, time.Second, time.Millisecond*10)
require.Equal(t, 0, len(s.discovery.UnconnectedPeers())) require.Equal(t, 0, len(s.discovery.UnconnectedPeers()))

View file

@ -13,7 +13,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
"go.uber.org/atomic" "go.uber.org/atomic"
"go.uber.org/zap"
) )
type handShakeStage uint8 type handShakeStage uint8
@ -48,6 +47,8 @@ type TCPPeer struct {
version *payload.Version version *payload.Version
// Index of the last block. // Index of the last block.
lastBlockIndex uint32 lastBlockIndex uint32
// pre-handshake non-canonical connection address.
addr string
lock sync.RWMutex lock sync.RWMutex
finale sync.Once finale sync.Once
@ -69,10 +70,11 @@ type TCPPeer struct {
} }
// NewTCPPeer returns a TCPPeer structure based on the given connection. // NewTCPPeer returns a TCPPeer structure based on the given connection.
func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { func NewTCPPeer(conn net.Conn, addr string, s *Server) *TCPPeer {
return &TCPPeer{ return &TCPPeer{
conn: conn, conn: conn,
server: s, server: s,
addr: addr,
done: make(chan struct{}), done: make(chan struct{}),
sendQ: make(chan []byte, requestQueueSize), sendQ: make(chan []byte, requestQueueSize),
p2pSendQ: make(chan []byte, p2pMsgQueueSize), p2pSendQ: make(chan []byte, p2pMsgQueueSize),
@ -256,13 +258,8 @@ func (p *TCPPeer) handleQueues() {
func (p *TCPPeer) StartProtocol() { func (p *TCPPeer) StartProtocol() {
var err error var err error
p.server.log.Info("started protocol", p.server.handshake <- p
zap.Stringer("addr", p.RemoteAddr()),
zap.ByteString("userAgent", p.Version().UserAgent),
zap.Uint32("startHeight", p.lastBlockIndex),
zap.Uint32("id", p.Version().Nonce))
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities)
err = p.server.requestBlocksOrHeaders(p) err = p.server.requestBlocksOrHeaders(p)
if err != nil { if err != nil {
p.Disconnect(err) p.Disconnect(err)
@ -384,6 +381,14 @@ func (p *TCPPeer) HandleVersionAck() error {
return nil return nil
} }
// ConnectionAddr implements the Peer interface.
func (p *TCPPeer) ConnectionAddr() string {
if p.addr != "" {
return p.addr
}
return p.conn.RemoteAddr().String()
}
// RemoteAddr implements the Peer interface. // RemoteAddr implements the Peer interface.
func (p *TCPPeer) RemoteAddr() net.Addr { func (p *TCPPeer) RemoteAddr() net.Addr {
return p.conn.RemoteAddr() return p.conn.RemoteAddr()

View file

@ -18,8 +18,8 @@ func connReadStub(conn net.Conn) {
func TestPeerHandshake(t *testing.T) { func TestPeerHandshake(t *testing.T) {
server, client := net.Pipe() server, client := net.Pipe()
tcpS := NewTCPPeer(server, newTestServer(t, ServerConfig{})) tcpS := NewTCPPeer(server, "", newTestServer(t, ServerConfig{}))
tcpC := NewTCPPeer(client, newTestServer(t, ServerConfig{})) tcpC := NewTCPPeer(client, "", newTestServer(t, ServerConfig{}))
// Something should read things written into the pipe. // Something should read things written into the pipe.
go connReadStub(tcpS.conn) go connReadStub(tcpS.conn)

View file

@ -30,14 +30,14 @@ func NewTCPTransport(s *Server, bindAddr string, log *zap.Logger) *TCPTransport
} }
// Dial implements the Transporter interface. // Dial implements the Transporter interface.
func (t *TCPTransport) Dial(addr string, timeout time.Duration) error { func (t *TCPTransport) Dial(addr string, timeout time.Duration) (AddressablePeer, error) {
conn, err := net.DialTimeout("tcp", addr, timeout) conn, err := net.DialTimeout("tcp", addr, timeout)
if err != nil { if err != nil {
return err return nil, err
} }
p := NewTCPPeer(conn, t.server) p := NewTCPPeer(conn, addr, t.server)
go p.handleConn() go p.handleConn()
return nil return p, nil
} }
// Accept implements the Transporter interface. // Accept implements the Transporter interface.
@ -69,7 +69,7 @@ func (t *TCPTransport) Accept() {
t.log.Warn("TCP accept error", zap.Error(err)) t.log.Warn("TCP accept error", zap.Error(err))
continue continue
} }
p := NewTCPPeer(conn, t.server) p := NewTCPPeer(conn, "", t.server)
go p.handleConn() go p.handleConn()
} }
} }

View file

@ -5,7 +5,7 @@ import "time"
// Transporter is an interface that allows us to abstract // Transporter is an interface that allows us to abstract
// any form of communication between the server and its peers. // any form of communication between the server and its peers.
type Transporter interface { type Transporter interface {
Dial(addr string, timeout time.Duration) error Dial(addr string, timeout time.Duration) (AddressablePeer, error)
Accept() Accept()
Proto() string Proto() string
Address() string Address() string