Merge pull request #2811 from nspcc-dev/network-connections-fix

Network connections fix
This commit is contained in:
Roman Khimov 2022-11-18 14:26:25 +07:00 committed by GitHub
commit bb9f17108e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 213 additions and 112 deletions

View file

@ -2,6 +2,7 @@ package network
import (
"math"
"math/rand"
"sync"
"sync/atomic"
"time"
@ -14,6 +15,11 @@ const (
connRetries = 3
)
var (
// Maximum waiting time before connection attempt.
tryMaxWait = time.Second / 2
)
// Discoverer is an interface that is responsible for maintaining
// a healthy connection pool.
type Discoverer interface {
@ -22,10 +28,10 @@ type Discoverer interface {
NetworkSize() int
PoolCount() int
RequestRemote(int)
RegisterBadAddr(string)
RegisterGoodAddr(string, capability.Capabilities)
RegisterConnectedAddr(string)
UnregisterConnectedAddr(string)
RegisterSelf(AddressablePeer)
RegisterGood(AddressablePeer)
RegisterConnected(AddressablePeer)
UnregisterConnected(AddressablePeer, bool)
UnconnectedPeers() []string
BadPeers() []string
GoodPeers() []AddressWithCapabilities
@ -39,15 +45,17 @@ type AddressWithCapabilities struct {
// DefaultDiscovery default implementation of the Discoverer interface.
type DefaultDiscovery struct {
seeds []string
seeds map[string]string
transport Transporter
lock sync.RWMutex
dialTimeout time.Duration
badAddrs map[string]bool
connectedAddrs map[string]bool
handshakedAddrs map[string]bool
goodAddrs map[string]capability.Capabilities
unconnectedAddrs map[string]int
attempted map[string]bool
outstanding int32
optimalFanOut int32
networkSize int32
requestCh chan int
@ -55,12 +63,17 @@ type DefaultDiscovery struct {
// NewDefaultDiscovery returns a new DefaultDiscovery.
func NewDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) *DefaultDiscovery {
var seeds = make(map[string]string)
for i := range addrs {
seeds[addrs[i]] = ""
}
d := &DefaultDiscovery{
seeds: addrs,
seeds: seeds,
transport: ts,
dialTimeout: dt,
badAddrs: make(map[string]bool),
connectedAddrs: make(map[string]bool),
handshakedAddrs: make(map[string]bool),
goodAddrs: make(map[string]capability.Capabilities),
unconnectedAddrs: make(map[string]int),
attempted: make(map[string]bool),
@ -83,7 +96,7 @@ func (d *DefaultDiscovery) BackFill(addrs ...string) {
func (d *DefaultDiscovery) backfill(addrs ...string) {
for _, addr := range addrs {
if d.badAddrs[addr] || d.connectedAddrs[addr] ||
if d.badAddrs[addr] || d.connectedAddrs[addr] || d.handshakedAddrs[addr] ||
d.unconnectedAddrs[addr] > 0 {
continue
}
@ -113,11 +126,13 @@ func (d *DefaultDiscovery) pushToPoolOrDrop(addr string) {
// RequestRemote tries to establish a connection with n nodes.
func (d *DefaultDiscovery) RequestRemote(requested int) {
outstanding := int(atomic.LoadInt32(&d.outstanding))
requested -= outstanding
for ; requested > 0; requested-- {
var nextAddr string
d.lock.Lock()
for addr := range d.unconnectedAddrs {
if !d.connectedAddrs[addr] && !d.attempted[addr] {
if !d.connectedAddrs[addr] && !d.handshakedAddrs[addr] && !d.attempted[addr] {
nextAddr = addr
break
}
@ -125,8 +140,8 @@ func (d *DefaultDiscovery) RequestRemote(requested int) {
if nextAddr == "" {
// Empty pool, try seeds.
for _, addr := range d.seeds {
if !d.connectedAddrs[addr] && !d.attempted[addr] {
for addr, ip := range d.seeds {
if ip == "" && !d.attempted[addr] {
nextAddr = addr
break
}
@ -140,30 +155,38 @@ func (d *DefaultDiscovery) RequestRemote(requested int) {
}
d.attempted[nextAddr] = true
d.lock.Unlock()
atomic.AddInt32(&d.outstanding, 1)
go d.tryAddress(nextAddr)
}
}
// RegisterBadAddr registers the given address as a bad address.
func (d *DefaultDiscovery) RegisterBadAddr(addr string) {
var isSeed bool
// RegisterSelf registers the given Peer as a bad one, because it's our own node.
func (d *DefaultDiscovery) RegisterSelf(p AddressablePeer) {
var connaddr = p.ConnectionAddr()
d.lock.Lock()
for _, seed := range d.seeds {
if addr == seed {
isSeed = true
break
delete(d.connectedAddrs, connaddr)
d.registerBad(connaddr, true)
d.registerBad(p.PeerAddr().String(), true)
d.lock.Unlock()
}
func (d *DefaultDiscovery) registerBad(addr string, force bool) {
_, isSeed := d.seeds[addr]
if isSeed {
if !force {
d.seeds[addr] = ""
} else {
d.seeds[addr] = "forever" // That's our own address, so never try connecting to it.
}
}
if !isSeed {
} else {
d.unconnectedAddrs[addr]--
if d.unconnectedAddrs[addr] <= 0 {
if d.unconnectedAddrs[addr] <= 0 || force {
d.badAddrs[addr] = true
delete(d.unconnectedAddrs, addr)
delete(d.goodAddrs, addr)
}
}
d.updateNetSize()
d.lock.Unlock()
}
// UnconnectedPeers returns all addresses of unconnected addrs.
@ -203,31 +226,53 @@ func (d *DefaultDiscovery) GoodPeers() []AddressWithCapabilities {
return addrs
}
// RegisterGoodAddr registers a known good connected address that has passed
// RegisterGood registers a known good connected peer that has passed
// handshake successfully.
func (d *DefaultDiscovery) RegisterGoodAddr(s string, c capability.Capabilities) {
func (d *DefaultDiscovery) RegisterGood(p AddressablePeer) {
var s = p.PeerAddr().String()
d.lock.Lock()
d.goodAddrs[s] = c
d.handshakedAddrs[s] = true
d.goodAddrs[s] = p.Version().Capabilities
delete(d.badAddrs, s)
d.lock.Unlock()
}
// UnregisterConnectedAddr tells the discoverer that this address is no longer
// UnregisterConnected tells the discoverer that this peer is no longer
// connected, but it is still considered a good one.
func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) {
func (d *DefaultDiscovery) UnregisterConnected(p AddressablePeer, duplicate bool) {
var (
peeraddr = p.PeerAddr().String()
connaddr = p.ConnectionAddr()
)
d.lock.Lock()
delete(d.connectedAddrs, s)
d.backfill(s)
delete(d.connectedAddrs, connaddr)
if !duplicate {
for addr, ip := range d.seeds {
if ip == peeraddr {
d.seeds[addr] = ""
break
}
}
delete(d.handshakedAddrs, peeraddr)
if _, ok := d.goodAddrs[peeraddr]; ok {
d.backfill(peeraddr)
}
}
d.lock.Unlock()
}
// RegisterConnectedAddr tells discoverer that the given address is now connected.
func (d *DefaultDiscovery) RegisterConnectedAddr(addr string) {
// RegisterConnected tells discoverer that the given peer is now connected.
func (d *DefaultDiscovery) RegisterConnected(p AddressablePeer) {
var addr = p.ConnectionAddr()
d.lock.Lock()
d.registerConnected(addr)
d.lock.Unlock()
}
func (d *DefaultDiscovery) registerConnected(addr string) {
delete(d.unconnectedAddrs, addr)
d.connectedAddrs[addr] = true
d.updateNetSize()
d.lock.Unlock()
}
// GetFanOut returns the optimal number of nodes to broadcast packets to.
@ -242,9 +287,9 @@ func (d *DefaultDiscovery) NetworkSize() int {
// updateNetSize updates network size estimation metric. Must be called under read lock.
func (d *DefaultDiscovery) updateNetSize() {
var netsize = len(d.connectedAddrs) + len(d.unconnectedAddrs) + 1 // 1 for the node itself.
var fanOut = 2.5 * math.Log(float64(netsize-1)) // -1 for the number of potential peers.
if netsize == 2 { // log(1) == 0.
var netsize = len(d.handshakedAddrs) + len(d.unconnectedAddrs) + 1 // 1 for the node itself.
var fanOut = 2.5 * math.Log(float64(netsize-1)) // -1 for the number of potential peers.
if netsize == 2 { // log(1) == 0.
fanOut = 1 // But we still want to push messages to the peer.
}
@ -255,12 +300,22 @@ func (d *DefaultDiscovery) updateNetSize() {
}
func (d *DefaultDiscovery) tryAddress(addr string) {
err := d.transport.Dial(addr, d.dialTimeout)
var tout = rand.Int63n(int64(tryMaxWait))
time.Sleep(time.Duration(tout)) // Have a sleep before working hard.
p, err := d.transport.Dial(addr, d.dialTimeout)
atomic.AddInt32(&d.outstanding, -1)
d.lock.Lock()
delete(d.attempted, addr)
if err == nil {
if _, ok := d.seeds[addr]; ok {
d.seeds[addr] = p.PeerAddr().String()
}
d.registerConnected(addr)
} else {
d.registerBad(addr, false)
}
d.lock.Unlock()
if err != nil {
d.RegisterBadAddr(addr)
time.Sleep(d.dialTimeout)
d.RequestRemote(1)
}

View file

@ -9,6 +9,7 @@ import (
"time"
"github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
atomic2 "go.uber.org/atomic"
@ -22,18 +23,40 @@ type fakeTransp struct {
addr string
}
type fakeAPeer struct {
addr string
peer string
version *payload.Version
}
func (f *fakeAPeer) ConnectionAddr() string {
return f.addr
}
func (f *fakeAPeer) PeerAddr() net.Addr {
tcpAddr, err := net.ResolveTCPAddr("tcp", f.peer)
if err != nil {
panic(err)
}
return tcpAddr
}
func (f *fakeAPeer) Version() *payload.Version {
return f.version
}
func newFakeTransp(s *Server) Transporter {
return &fakeTransp{}
}
func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error {
func (ft *fakeTransp) Dial(addr string, timeout time.Duration) (AddressablePeer, error) {
var ret error
if atomic.LoadInt32(&ft.retFalse) > 0 {
ret = errors.New("smth bad happened")
}
ft.dialCh <- addr
return ret
return &fakeAPeer{addr: addr, peer: addr}, ret
}
func (ft *fakeTransp) Accept() {
if ft.started.Load() {
@ -59,6 +82,7 @@ func TestDefaultDiscoverer(t *testing.T) {
ts.dialCh = make(chan string)
d := NewDefaultDiscovery(nil, time.Second/16, ts)
tryMaxWait = 1 // Don't waste time.
var set1 = []string{"1.1.1.1:10333", "2.2.2.2:10333"}
sort.Strings(set1)
@ -83,7 +107,7 @@ func TestDefaultDiscoverer(t *testing.T) {
select {
case a := <-ts.dialCh:
dialled = append(dialled, a)
d.RegisterConnectedAddr(a)
d.RegisterConnected(&fakeAPeer{addr: a, peer: a})
case <-time.After(time.Second):
t.Fatalf("timeout expecting for transport dial")
}
@ -97,10 +121,14 @@ func TestDefaultDiscoverer(t *testing.T) {
// Registered good addresses should end up in appropriate set.
for _, addr := range set1 {
d.RegisterGoodAddr(addr, capability.Capabilities{
{
Type: capability.FullNode,
Data: &capability.Node{StartHeight: 123},
d.RegisterGood(&fakeAPeer{
addr: addr,
peer: addr,
version: &payload.Version{
Capabilities: capability.Capabilities{{
Type: capability.FullNode,
Data: &capability.Node{StartHeight: 123},
}},
},
})
}
@ -130,7 +158,7 @@ func TestDefaultDiscoverer(t *testing.T) {
// Unregistering connected should work.
for _, addr := range set1 {
d.UnregisterConnectedAddr(addr)
d.UnregisterConnected(&fakeAPeer{addr: addr, peer: addr}, false)
}
assert.Equal(t, 2, len(d.UnconnectedPeers())) // They're re-added automatically.
assert.Equal(t, 0, len(d.BadPeers()))
@ -184,6 +212,7 @@ func TestSeedDiscovery(t *testing.T) {
sort.Strings(seeds)
d := NewDefaultDiscovery(seeds, time.Second/10, ts)
tryMaxWait = 1 // Don't waste time.
d.RequestRemote(len(seeds))
for i := 0; i < connRetries*2; i++ {

View file

@ -13,7 +13,6 @@ import (
"github.com/nspcc-dev/neo-go/internal/fakechain"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
@ -35,10 +34,10 @@ func (d *testDiscovery) BackFill(addrs ...string) {
d.backfill = append(d.backfill, addrs...)
}
func (d *testDiscovery) PoolCount() int { return 0 }
func (d *testDiscovery) RegisterBadAddr(addr string) {
func (d *testDiscovery) RegisterSelf(p AddressablePeer) {
d.Lock()
defer d.Unlock()
d.bad = append(d.bad, addr)
d.bad = append(d.bad, p.ConnectionAddr())
}
func (d *testDiscovery) GetFanOut() int {
d.Lock()
@ -50,16 +49,16 @@ func (d *testDiscovery) NetworkSize() int {
defer d.Unlock()
return len(d.connected) + len(d.backfill)
}
func (d *testDiscovery) RegisterGoodAddr(string, capability.Capabilities) {}
func (d *testDiscovery) RegisterConnectedAddr(addr string) {
func (d *testDiscovery) RegisterGood(AddressablePeer) {}
func (d *testDiscovery) RegisterConnected(p AddressablePeer) {
d.Lock()
defer d.Unlock()
d.connected = append(d.connected, addr)
d.connected = append(d.connected, p.ConnectionAddr())
}
func (d *testDiscovery) UnregisterConnectedAddr(addr string) {
func (d *testDiscovery) UnregisterConnected(p AddressablePeer, force bool) {
d.Lock()
defer d.Unlock()
d.unregistered = append(d.unregistered, addr)
d.unregistered = append(d.unregistered, p.ConnectionAddr())
}
func (d *testDiscovery) UnconnectedPeers() []string {
d.Lock()
@ -100,6 +99,9 @@ func newLocalPeer(t *testing.T, s *Server) *localPeer {
}
}
func (p *localPeer) ConnectionAddr() string {
return p.netaddr.String()
}
func (p *localPeer) RemoteAddr() net.Addr {
return &p.netaddr
}

View file

@ -7,10 +7,12 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network/payload"
)
// Peer represents a network node neo-go is connected to.
type Peer interface {
// RemoteAddr returns the remote address that we're connected to now.
RemoteAddr() net.Addr
type AddressablePeer interface {
// ConnectionAddr returns an address-like identifier of this connection
// before we have a proper one (after the handshake). It's either the
// address from discoverer (if initiated from node) or one from socket
// (if connected to node from outside).
ConnectionAddr() string
// PeerAddr returns the remote address that should be used to establish
// a new connection to the node. It can differ from the RemoteAddr
// address in case the remote node is a client and its current
@ -18,6 +20,16 @@ type Peer interface {
// to connect to it. It's only valid after the handshake is completed.
// Before that, it returns the same address as RemoteAddr.
PeerAddr() net.Addr
// Version returns peer's version message if the peer has handshaked
// already.
Version() *payload.Version
}
// Peer represents a network node neo-go is connected to.
type Peer interface {
AddressablePeer
// RemoteAddr returns the remote address that we're connected to now.
RemoteAddr() net.Addr
Disconnect(error)
// BroadcastPacket is a context-bound packet enqueuer, it either puts the
@ -49,7 +61,6 @@ type Peer interface {
// EnqueueHPPacket is similar to EnqueueHPMessage, but accepts a slice of
// message(s) bytes.
EnqueueHPPacket([]byte) error
Version() *payload.Version
LastBlockIndex() uint32
Handshaked() bool
IsFullNode() bool

View file

@ -127,6 +127,7 @@ type (
register chan Peer
unregister chan peerDrop
handshake chan Peer
quit chan struct{}
relayFin chan struct{}
@ -181,6 +182,7 @@ func newServerFromConstructors(config ServerConfig, chain Ledger, stSync StateSy
relayFin: make(chan struct{}),
register: make(chan Peer),
unregister: make(chan peerDrop),
handshake: make(chan Peer),
txInMap: make(map[util.Uint256]struct{}),
peers: make(map[Peer]bool),
syncReached: atomic.NewBool(false),
@ -398,10 +400,12 @@ func (s *Server) ConnectedPeers() []string {
func (s *Server) run() {
var (
peerCheckTime = s.TimePerBlock * peerTimeFactor
peerCheckTimeout bool
timer = time.NewTimer(peerCheckTime)
addrCheckTimeout bool
addrTimer = time.NewTimer(peerCheckTime)
peerTimer = time.NewTimer(s.ProtoTickInterval)
)
defer timer.Stop()
defer addrTimer.Stop()
defer peerTimer.Stop()
go s.runProto()
for loopCnt := 0; ; loopCnt++ {
var (
@ -409,12 +413,16 @@ func (s *Server) run() {
// "Optimal" number of peers.
optimalN = s.discovery.GetFanOut() * 2
// Real number of peers.
peerN = s.PeerCount()
peerN = s.HandshakedPeersCount()
// Timeout value for the next peerTimer, long one by default.
peerT = peerCheckTime
)
if peerN < s.MinPeers {
// Starting up or going below the minimum -> quickly get many new peers.
s.discovery.RequestRemote(s.AttemptConnPeers)
// Check/retry new connections soon.
peerT = s.ProtoTickInterval
} else if s.MinPeers > 0 && loopCnt%s.MinPeers == 0 && optimalN > peerN && optimalN < s.MaxPeers && optimalN < netSize {
// Having some number of peers, but probably can get some more, the network is big.
// It also allows to start picking up new peers proactively, before we suddenly have <s.MinPeers of them.
@ -425,16 +433,18 @@ func (s *Server) run() {
s.discovery.RequestRemote(connN)
}
if peerCheckTimeout || s.discovery.PoolCount() < s.AttemptConnPeers {
if addrCheckTimeout || s.discovery.PoolCount() < s.AttemptConnPeers {
s.broadcastHPMessage(NewMessage(CMDGetAddr, payload.NewNullPayload()))
peerCheckTimeout = false
addrCheckTimeout = false
}
select {
case <-s.quit:
return
case <-timer.C:
peerCheckTimeout = true
timer.Reset(peerCheckTime)
case <-addrTimer.C:
addrCheckTimeout = true
addrTimer.Reset(peerCheckTime)
case <-peerTimer.C:
peerTimer.Reset(peerT)
case p := <-s.register:
s.lock.Lock()
s.peers[p] = true
@ -462,31 +472,10 @@ func (s *Server) run() {
zap.Stringer("addr", drop.peer.RemoteAddr()),
zap.Error(drop.reason),
zap.Int("peerCount", s.PeerCount()))
addr := drop.peer.PeerAddr().String()
if errors.Is(drop.reason, errIdenticalID) {
s.discovery.RegisterBadAddr(addr)
} else if errors.Is(drop.reason, errAlreadyConnected) {
// There is a race condition when peer can be disconnected twice for the this reason
// which can lead to no connections to peer at all. Here we check for such a possibility.
stillConnected := false
s.lock.RLock()
verDrop := drop.peer.Version()
addr := drop.peer.PeerAddr().String()
if verDrop != nil {
for peer := range s.peers {
ver := peer.Version()
// Already connected, drop this connection.
if ver != nil && ver.Nonce == verDrop.Nonce && peer.PeerAddr().String() == addr {
stillConnected = true
}
}
}
s.lock.RUnlock()
if !stillConnected {
s.discovery.UnregisterConnectedAddr(addr)
}
s.discovery.RegisterSelf(drop.peer)
} else {
s.discovery.UnregisterConnectedAddr(addr)
s.discovery.UnregisterConnected(drop.peer, errors.Is(drop.reason, errAlreadyConnected))
}
updatePeersConnectedMetric(s.PeerCount())
} else {
@ -494,6 +483,19 @@ func (s *Server) run() {
// because we have two goroutines sending signals here
s.lock.Unlock()
}
case p := <-s.handshake:
ver := p.Version()
s.log.Info("started protocol",
zap.Stringer("addr", p.RemoteAddr()),
zap.ByteString("userAgent", ver.UserAgent),
zap.Uint32("startHeight", p.LastBlockIndex()),
zap.Uint32("id", ver.Nonce))
s.discovery.RegisterGood(p)
s.tryInitStateSync()
s.tryStartServices()
}
}
}
@ -700,7 +702,6 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
}
}
s.lock.RUnlock()
s.discovery.RegisterConnectedAddr(peerAddr)
return p.SendVersionAck(NewMessage(CMDVerack, payload.NewNullPayload()))
}
@ -1195,11 +1196,9 @@ func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error {
if !p.CanProcessAddr() {
return errors.New("unexpected addr received")
}
dups := make(map[string]bool)
for _, a := range addrs.Addrs {
addr, err := a.GetTCPAddress()
if err == nil && !dups[addr] {
dups[addr] = true
if err == nil {
s.discovery.BackFill(addr)
}
}
@ -1356,9 +1355,6 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return err
}
go peer.StartProtocol()
s.tryInitStateSync()
s.tryStartServices()
default:
return fmt.Errorf("received '%s' during handshake", msg.Command.String())
}

View file

@ -141,14 +141,17 @@ func TestServerRegisterPeer(t *testing.T) {
for i := range ps {
ps[i] = newLocalPeer(t, s)
ps[i].netaddr.Port = i + 1
ps[i].version = &payload.Version{Nonce: uint32(i), UserAgent: []byte("fake")}
}
startWithCleanup(t, s)
s.register <- ps[0]
require.Eventually(t, func() bool { return 1 == s.PeerCount() }, time.Second, time.Millisecond*10)
s.handshake <- ps[0]
s.register <- ps[1]
s.handshake <- ps[1]
require.Eventually(t, func() bool { return 2 == s.PeerCount() }, time.Second, time.Millisecond*10)
require.Equal(t, 0, len(s.discovery.UnconnectedPeers()))

View file

@ -13,7 +13,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"go.uber.org/atomic"
"go.uber.org/zap"
)
type handShakeStage uint8
@ -48,6 +47,8 @@ type TCPPeer struct {
version *payload.Version
// Index of the last block.
lastBlockIndex uint32
// pre-handshake non-canonical connection address.
addr string
lock sync.RWMutex
finale sync.Once
@ -69,10 +70,11 @@ type TCPPeer struct {
}
// NewTCPPeer returns a TCPPeer structure based on the given connection.
func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer {
func NewTCPPeer(conn net.Conn, addr string, s *Server) *TCPPeer {
return &TCPPeer{
conn: conn,
server: s,
addr: addr,
done: make(chan struct{}),
sendQ: make(chan []byte, requestQueueSize),
p2pSendQ: make(chan []byte, p2pMsgQueueSize),
@ -256,13 +258,8 @@ func (p *TCPPeer) handleQueues() {
func (p *TCPPeer) StartProtocol() {
var err error
p.server.log.Info("started protocol",
zap.Stringer("addr", p.RemoteAddr()),
zap.ByteString("userAgent", p.Version().UserAgent),
zap.Uint32("startHeight", p.lastBlockIndex),
zap.Uint32("id", p.Version().Nonce))
p.server.handshake <- p
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities)
err = p.server.requestBlocksOrHeaders(p)
if err != nil {
p.Disconnect(err)
@ -384,6 +381,14 @@ func (p *TCPPeer) HandleVersionAck() error {
return nil
}
// ConnectionAddr implements the Peer interface.
func (p *TCPPeer) ConnectionAddr() string {
if p.addr != "" {
return p.addr
}
return p.conn.RemoteAddr().String()
}
// RemoteAddr implements the Peer interface.
func (p *TCPPeer) RemoteAddr() net.Addr {
return p.conn.RemoteAddr()

View file

@ -18,8 +18,8 @@ func connReadStub(conn net.Conn) {
func TestPeerHandshake(t *testing.T) {
server, client := net.Pipe()
tcpS := NewTCPPeer(server, newTestServer(t, ServerConfig{}))
tcpC := NewTCPPeer(client, newTestServer(t, ServerConfig{}))
tcpS := NewTCPPeer(server, "", newTestServer(t, ServerConfig{}))
tcpC := NewTCPPeer(client, "", newTestServer(t, ServerConfig{}))
// Something should read things written into the pipe.
go connReadStub(tcpS.conn)

View file

@ -30,14 +30,14 @@ func NewTCPTransport(s *Server, bindAddr string, log *zap.Logger) *TCPTransport
}
// Dial implements the Transporter interface.
func (t *TCPTransport) Dial(addr string, timeout time.Duration) error {
func (t *TCPTransport) Dial(addr string, timeout time.Duration) (AddressablePeer, error) {
conn, err := net.DialTimeout("tcp", addr, timeout)
if err != nil {
return err
return nil, err
}
p := NewTCPPeer(conn, t.server)
p := NewTCPPeer(conn, addr, t.server)
go p.handleConn()
return nil
return p, nil
}
// Accept implements the Transporter interface.
@ -69,7 +69,7 @@ func (t *TCPTransport) Accept() {
t.log.Warn("TCP accept error", zap.Error(err))
continue
}
p := NewTCPPeer(conn, t.server)
p := NewTCPPeer(conn, "", t.server)
go p.handleConn()
}
}

View file

@ -5,7 +5,7 @@ import "time"
// Transporter is an interface that allows us to abstract
// any form of communication between the server and its peers.
type Transporter interface {
Dial(addr string, timeout time.Duration) error
Dial(addr string, timeout time.Duration) (AddressablePeer, error)
Accept()
Proto() string
Address() string