diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 30ba4321a..0c6bd9281 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -155,6 +155,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {} type localPeer struct { netaddr net.TCPAddr + server *Server version *payload.Version lastBlockIndex uint32 handshaked bool @@ -163,10 +164,11 @@ type localPeer struct { pingSent int } -func newLocalPeer(t *testing.T) *localPeer { +func newLocalPeer(t *testing.T, s *Server) *localPeer { naddr, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:0") return &localPeer{ t: t, + server: s, netaddr: *naddr, messageHandler: defaultMessageHandler, } @@ -210,7 +212,8 @@ func (p *localPeer) HandleVersion(v *payload.Version) error { p.version = v return nil } -func (p *localPeer) SendVersion(m *Message) error { +func (p *localPeer) SendVersion() error { + m := p.server.getVersionMsg() _ = p.EnqueueMessage(m) return nil } diff --git a/pkg/network/peer.go b/pkg/network/peer.go index d063c5ddf..9f2443d0c 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -42,7 +42,9 @@ type Peer interface { // appropriate protocol handling like timeouts and outstanding pings // management. SendPing() error - SendVersion(*Message) error + // SendVersion checks handshake status and sends a version message to + // the peer. + SendVersion() error SendVersionAck(*Message) error // StartProtocol is a goroutine to be run after the handshake. It // implements basic peer-related protocol handling. diff --git a/pkg/network/server.go b/pkg/network/server.go index 3bf0fa249..082d3c519 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -307,8 +307,8 @@ func (s *Server) HandshakedPeersCount() int { return count } -// When a peer connects to the server, we will send our version immediately. -func (s *Server) sendVersion(p Peer) error { +// getVersionMsg returns current version message. +func (s *Server) getVersionMsg() *Message { payload := payload.NewVersion( s.id, s.Port, @@ -316,7 +316,7 @@ func (s *Server) sendVersion(p Peer) error { s.chain.BlockHeight(), s.Relay, ) - return p.SendVersion(NewMessage(s.Net, CMDVersion, payload)) + return NewMessage(s.Net, CMDVersion, payload) } // When a peer sends out his version we reply with verack after validating diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 39f2caedc..f5dede1bd 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -12,7 +12,7 @@ import ( func TestSendVersion(t *testing.T) { var ( s = newTestServer(t) - p = newLocalPeer(t) + p = newLocalPeer(t, s) ) s.Port = 3000 s.UserAgent = "/test/" @@ -29,7 +29,7 @@ func TestSendVersion(t *testing.T) { assert.Equal(t, uint32(0), version.StartHeight) } - if err := s.sendVersion(p); err != nil { + if err := p.SendVersion(); err != nil { t.Fatal(err) } } @@ -38,7 +38,7 @@ func TestSendVersion(t *testing.T) { func TestVerackAfterHandleVersionCmd(t *testing.T) { var ( s = newTestServer(t) - p = newLocalPeer(t) + p = newLocalPeer(t, s) ) na, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:3000") p.netaddr = *na @@ -59,8 +59,8 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) { func TestServerNotSendsVerack(t *testing.T) { var ( s = newTestServer(t) - p = newLocalPeer(t) - p2 = newLocalPeer(t) + p = newLocalPeer(t, s) + p2 = newLocalPeer(t, s) ) s.id = 1 go s.run() @@ -92,7 +92,7 @@ func TestServerNotSendsVerack(t *testing.T) { func TestRequestHeaders(t *testing.T) { var ( s = newTestServer(t) - p = newLocalPeer(t) + p = newLocalPeer(t, s) ) p.messageHandler = func(t *testing.T, msg *Message) { assert.IsType(t, &payload.GetBlocks{}, msg.Payload) diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 5683cfe39..b41726bbc 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -117,7 +117,7 @@ func (p *TCPPeer) handleConn() { go p.handleQueues() // When a new peer is connected we send out our version immediately. - err = p.server.sendVersion(p) + err = p.SendVersion() if err == nil { r := io.NewBinReaderFromIO(p.conn) for { @@ -235,7 +235,8 @@ func (p *TCPPeer) Handshaked() bool { } // SendVersion checks for the handshake state and sends a message to the peer. -func (p *TCPPeer) SendVersion(msg *Message) error { +func (p *TCPPeer) SendVersion() error { + msg := p.server.getVersionMsg() p.lock.Lock() defer p.lock.Unlock() if p.handShake&versionSent != 0 { diff --git a/pkg/network/tcp_peer_test.go b/pkg/network/tcp_peer_test.go index 691e22b6d..5e2e6366d 100644 --- a/pkg/network/tcp_peer_test.go +++ b/pkg/network/tcp_peer_test.go @@ -18,8 +18,8 @@ func connReadStub(conn net.Conn) { func TestPeerHandshake(t *testing.T) { server, client := net.Pipe() - tcpS := NewTCPPeer(server, nil) - tcpC := NewTCPPeer(client, nil) + tcpS := NewTCPPeer(server, newTestServer(t)) + tcpC := NewTCPPeer(client, newTestServer(t)) // Something should read things written into the pipe. go connReadStub(tcpS.conn) @@ -45,22 +45,22 @@ func TestPeerHandshake(t *testing.T) { // Now send and handle versions, but in a different order on client and // server. - require.NoError(t, tcpC.SendVersion(&Message{})) + require.NoError(t, tcpC.SendVersion()) require.Error(t, tcpC.HandleVersionAck()) // Didn't receive version yet. require.NoError(t, tcpS.HandleVersion(&payload.Version{})) require.Error(t, tcpS.SendVersionAck(&Message{})) // Didn't send version yet. require.NoError(t, tcpC.HandleVersion(&payload.Version{})) - require.NoError(t, tcpS.SendVersion(&Message{})) + require.NoError(t, tcpS.SendVersion()) // No handshake yet. require.Equal(t, false, tcpS.Handshaked()) require.Equal(t, false, tcpC.Handshaked()) // These are sent/received and should fail now. - require.Error(t, tcpC.SendVersion(&Message{})) + require.Error(t, tcpC.SendVersion()) require.Error(t, tcpS.HandleVersion(&payload.Version{})) require.Error(t, tcpC.HandleVersion(&payload.Version{})) - require.Error(t, tcpS.SendVersion(&Message{})) + require.Error(t, tcpS.SendVersion()) // Now send and handle ACK, again in a different order on client and // server.