network: move SendVersion() to the Peer

Only leave server-specific `getVersionMsg()` in the Server, all the other
logic is peer-related.
This commit is contained in:
Roman Khimov 2020-01-21 17:26:08 +03:00
parent 9befd8de99
commit 1f672e0da7
6 changed files with 26 additions and 20 deletions

View file

@ -155,6 +155,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {}
type localPeer struct { type localPeer struct {
netaddr net.TCPAddr netaddr net.TCPAddr
server *Server
version *payload.Version version *payload.Version
lastBlockIndex uint32 lastBlockIndex uint32
handshaked bool handshaked bool
@ -163,10 +164,11 @@ type localPeer struct {
pingSent int pingSent int
} }
func newLocalPeer(t *testing.T) *localPeer { func newLocalPeer(t *testing.T, s *Server) *localPeer {
naddr, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:0") naddr, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:0")
return &localPeer{ return &localPeer{
t: t, t: t,
server: s,
netaddr: *naddr, netaddr: *naddr,
messageHandler: defaultMessageHandler, messageHandler: defaultMessageHandler,
} }
@ -210,7 +212,8 @@ func (p *localPeer) HandleVersion(v *payload.Version) error {
p.version = v p.version = v
return nil return nil
} }
func (p *localPeer) SendVersion(m *Message) error { func (p *localPeer) SendVersion() error {
m := p.server.getVersionMsg()
_ = p.EnqueueMessage(m) _ = p.EnqueueMessage(m)
return nil return nil
} }

View file

@ -42,7 +42,9 @@ type Peer interface {
// appropriate protocol handling like timeouts and outstanding pings // appropriate protocol handling like timeouts and outstanding pings
// management. // management.
SendPing() error SendPing() error
SendVersion(*Message) error // SendVersion checks handshake status and sends a version message to
// the peer.
SendVersion() error
SendVersionAck(*Message) error SendVersionAck(*Message) error
// StartProtocol is a goroutine to be run after the handshake. It // StartProtocol is a goroutine to be run after the handshake. It
// implements basic peer-related protocol handling. // implements basic peer-related protocol handling.

View file

@ -307,8 +307,8 @@ func (s *Server) HandshakedPeersCount() int {
return count return count
} }
// When a peer connects to the server, we will send our version immediately. // getVersionMsg returns current version message.
func (s *Server) sendVersion(p Peer) error { func (s *Server) getVersionMsg() *Message {
payload := payload.NewVersion( payload := payload.NewVersion(
s.id, s.id,
s.Port, s.Port,
@ -316,7 +316,7 @@ func (s *Server) sendVersion(p Peer) error {
s.chain.BlockHeight(), s.chain.BlockHeight(),
s.Relay, s.Relay,
) )
return p.SendVersion(NewMessage(s.Net, CMDVersion, payload)) return NewMessage(s.Net, CMDVersion, payload)
} }
// When a peer sends out his version we reply with verack after validating // When a peer sends out his version we reply with verack after validating

View file

@ -12,7 +12,7 @@ import (
func TestSendVersion(t *testing.T) { func TestSendVersion(t *testing.T) {
var ( var (
s = newTestServer(t) s = newTestServer(t)
p = newLocalPeer(t) p = newLocalPeer(t, s)
) )
s.Port = 3000 s.Port = 3000
s.UserAgent = "/test/" s.UserAgent = "/test/"
@ -29,7 +29,7 @@ func TestSendVersion(t *testing.T) {
assert.Equal(t, uint32(0), version.StartHeight) assert.Equal(t, uint32(0), version.StartHeight)
} }
if err := s.sendVersion(p); err != nil { if err := p.SendVersion(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
@ -38,7 +38,7 @@ func TestSendVersion(t *testing.T) {
func TestVerackAfterHandleVersionCmd(t *testing.T) { func TestVerackAfterHandleVersionCmd(t *testing.T) {
var ( var (
s = newTestServer(t) s = newTestServer(t)
p = newLocalPeer(t) p = newLocalPeer(t, s)
) )
na, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:3000") na, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:3000")
p.netaddr = *na p.netaddr = *na
@ -59,8 +59,8 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) {
func TestServerNotSendsVerack(t *testing.T) { func TestServerNotSendsVerack(t *testing.T) {
var ( var (
s = newTestServer(t) s = newTestServer(t)
p = newLocalPeer(t) p = newLocalPeer(t, s)
p2 = newLocalPeer(t) p2 = newLocalPeer(t, s)
) )
s.id = 1 s.id = 1
go s.run() go s.run()
@ -92,7 +92,7 @@ func TestServerNotSendsVerack(t *testing.T) {
func TestRequestHeaders(t *testing.T) { func TestRequestHeaders(t *testing.T) {
var ( var (
s = newTestServer(t) s = newTestServer(t)
p = newLocalPeer(t) p = newLocalPeer(t, s)
) )
p.messageHandler = func(t *testing.T, msg *Message) { p.messageHandler = func(t *testing.T, msg *Message) {
assert.IsType(t, &payload.GetBlocks{}, msg.Payload) assert.IsType(t, &payload.GetBlocks{}, msg.Payload)

View file

@ -117,7 +117,7 @@ func (p *TCPPeer) handleConn() {
go p.handleQueues() go p.handleQueues()
// When a new peer is connected we send out our version immediately. // When a new peer is connected we send out our version immediately.
err = p.server.sendVersion(p) err = p.SendVersion()
if err == nil { if err == nil {
r := io.NewBinReaderFromIO(p.conn) r := io.NewBinReaderFromIO(p.conn)
for { for {
@ -235,7 +235,8 @@ func (p *TCPPeer) Handshaked() bool {
} }
// SendVersion checks for the handshake state and sends a message to the peer. // SendVersion checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersion(msg *Message) error { func (p *TCPPeer) SendVersion() error {
msg := p.server.getVersionMsg()
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
if p.handShake&versionSent != 0 { if p.handShake&versionSent != 0 {

View file

@ -18,8 +18,8 @@ func connReadStub(conn net.Conn) {
func TestPeerHandshake(t *testing.T) { func TestPeerHandshake(t *testing.T) {
server, client := net.Pipe() server, client := net.Pipe()
tcpS := NewTCPPeer(server, nil) tcpS := NewTCPPeer(server, newTestServer(t))
tcpC := NewTCPPeer(client, nil) tcpC := NewTCPPeer(client, newTestServer(t))
// Something should read things written into the pipe. // Something should read things written into the pipe.
go connReadStub(tcpS.conn) go connReadStub(tcpS.conn)
@ -45,22 +45,22 @@ func TestPeerHandshake(t *testing.T) {
// Now send and handle versions, but in a different order on client and // Now send and handle versions, but in a different order on client and
// server. // server.
require.NoError(t, tcpC.SendVersion(&Message{})) require.NoError(t, tcpC.SendVersion())
require.Error(t, tcpC.HandleVersionAck()) // Didn't receive version yet. require.Error(t, tcpC.HandleVersionAck()) // Didn't receive version yet.
require.NoError(t, tcpS.HandleVersion(&payload.Version{})) require.NoError(t, tcpS.HandleVersion(&payload.Version{}))
require.Error(t, tcpS.SendVersionAck(&Message{})) // Didn't send version yet. require.Error(t, tcpS.SendVersionAck(&Message{})) // Didn't send version yet.
require.NoError(t, tcpC.HandleVersion(&payload.Version{})) require.NoError(t, tcpC.HandleVersion(&payload.Version{}))
require.NoError(t, tcpS.SendVersion(&Message{})) require.NoError(t, tcpS.SendVersion())
// No handshake yet. // No handshake yet.
require.Equal(t, false, tcpS.Handshaked()) require.Equal(t, false, tcpS.Handshaked())
require.Equal(t, false, tcpC.Handshaked()) require.Equal(t, false, tcpC.Handshaked())
// These are sent/received and should fail now. // These are sent/received and should fail now.
require.Error(t, tcpC.SendVersion(&Message{})) require.Error(t, tcpC.SendVersion())
require.Error(t, tcpS.HandleVersion(&payload.Version{})) require.Error(t, tcpS.HandleVersion(&payload.Version{}))
require.Error(t, tcpC.HandleVersion(&payload.Version{})) require.Error(t, tcpC.HandleVersion(&payload.Version{}))
require.Error(t, tcpS.SendVersion(&Message{})) require.Error(t, tcpS.SendVersion())
// Now send and handle ACK, again in a different order on client and // Now send and handle ACK, again in a different order on client and
// server. // server.