diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 2ad71205d..41530a452 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -158,7 +158,10 @@ func newLocalPeer(t *testing.T) *localPeer { } } -func (p *localPeer) NetAddr() *net.TCPAddr { +func (p *localPeer) RemoteAddr() net.Addr { + return &p.netaddr +} +func (p *localPeer) PeerAddr() net.Addr { return &p.netaddr } func (p *localPeer) Disconnect(err error) {} diff --git a/pkg/network/peer.go b/pkg/network/peer.go index 620aa3d91..f56beab01 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -8,7 +8,15 @@ import ( // Peer represents a network node neo-go is connected to. type Peer interface { - NetAddr() *net.TCPAddr + // RemoteAddr returns the remote address that we're connected to now. + RemoteAddr() net.Addr + // PeerAddr returns the remote address that should be used to establish + // a new connection to the node. It can differ from the RemoteAddr + // address in case where the remote node is a client and its current + // connection port is different from the one the other node should use + // to connect to it. It's only valid after the handshake is completed, + // before that it returns the same address as RemoteAddr. + PeerAddr() net.Addr Disconnect(error) WriteMsg(msg *Message) error Done() chan error diff --git a/pkg/network/server.go b/pkg/network/server.go index e24f479ab..8496a478b 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -160,12 +160,12 @@ func (s *Server) run() { // When a new peer is connected we send out our version immediately. if err := s.sendVersion(p); err != nil { log.WithFields(log.Fields{ - "addr": p.NetAddr(), + "addr": p.RemoteAddr(), }).Error(err) } s.peers[p] = true log.WithFields(log.Fields{ - "addr": p.NetAddr(), + "addr": p.RemoteAddr(), }).Info("new peer connected") updatePeersConnectedMetric(s.PeerCount()) @@ -173,11 +173,11 @@ func (s *Server) run() { if s.peers[drop.peer] { delete(s.peers, drop.peer) log.WithFields(log.Fields{ - "addr": drop.peer.NetAddr(), + "addr": drop.peer.RemoteAddr(), "reason": drop.reason, "peerCount": s.PeerCount(), }).Warn("peer disconnected") - addr := drop.peer.NetAddr().String() + addr := drop.peer.PeerAddr().String() s.discovery.UnregisterConnectedAddr(addr) s.discovery.BackFill(addr) updatePeersConnectedMetric(s.PeerCount()) @@ -205,13 +205,13 @@ func (s *Server) PeerCount() int { // every ProtoTickInterval with the peer. func (s *Server) startProtocol(p Peer) { log.WithFields(log.Fields{ - "addr": p.NetAddr(), + "addr": p.RemoteAddr(), "userAgent": string(p.Version().UserAgent), "startHeight": p.Version().StartHeight, "id": p.Version().Nonce, }).Info("started protocol") - s.discovery.RegisterGoodAddr(p.NetAddr().String()) + s.discovery.RegisterGoodAddr(p.PeerAddr().String()) err := s.requestHeaders(p) if err != nil { p.Disconnect(err) diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index a9a5b5a27..b1d613f93 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -2,8 +2,8 @@ package network import ( "errors" - "fmt" "net" + "strconv" "sync" "github.com/CityOfZion/neo-go/pkg/io" @@ -30,7 +30,6 @@ var ( type TCPPeer struct { // underlying TCP connection. conn net.Conn - addr net.TCPAddr // The version of the peer. version *payload.Version @@ -44,13 +43,9 @@ type TCPPeer struct { // NewTCPPeer returns a TCPPeer structure based on the given connection. func NewTCPPeer(conn net.Conn) *TCPPeer { - raddr := conn.RemoteAddr() - // can't fail because raddr is a real connection - tcpaddr, _ := net.ResolveTCPAddr(raddr.Network(), raddr.String()) return &TCPPeer{ conn: conn, done: make(chan error, 1), - addr: *tcpaddr, } } @@ -123,9 +118,28 @@ func (p *TCPPeer) HandleVersionAck() error { return nil } -// NetAddr implements the Peer interface. -func (p *TCPPeer) NetAddr() *net.TCPAddr { - return &p.addr +// RemoteAddr implements the Peer interface. +func (p *TCPPeer) RemoteAddr() net.Addr { + return p.conn.RemoteAddr() +} + +// PeerAddr implements the Peer interface. +func (p *TCPPeer) PeerAddr() net.Addr { + remote := p.conn.RemoteAddr() + // The network can be non-tcp in unit tests. + if !p.Handshaked() || remote.Network() != "tcp" { + return p.RemoteAddr() + } + host, _, err := net.SplitHostPort(remote.String()) + if err != nil { + return p.RemoteAddr() + } + addrString := net.JoinHostPort(host, strconv.Itoa(int(p.version.Port))) + tcpAddr, err := net.ResolveTCPAddr("tcp", addrString) + if err != nil { + return p.RemoteAddr() + } + return tcpAddr } // Done implements the Peer interface and notifies diff --git a/pkg/rpc/server.go b/pkg/rpc/server.go index 2fe141fe4..eb7304264 100644 --- a/pkg/rpc/server.go +++ b/pkg/rpc/server.go @@ -196,7 +196,7 @@ Methods: } for addr := range s.coreServer.Peers() { - peers.AddPeer("connected", addr.NetAddr().String()) + peers.AddPeer("connected", addr.PeerAddr().String()) } results = peers