diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 30ba4321a..0c6bd9281 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -155,6 +155,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {} type localPeer struct { netaddr net.TCPAddr + server *Server version *payload.Version lastBlockIndex uint32 handshaked bool @@ -163,10 +164,11 @@ type localPeer struct { pingSent int } -func newLocalPeer(t *testing.T) *localPeer { +func newLocalPeer(t *testing.T, s *Server) *localPeer { naddr, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:0") return &localPeer{ t: t, + server: s, netaddr: *naddr, messageHandler: defaultMessageHandler, } @@ -210,7 +212,8 @@ func (p *localPeer) HandleVersion(v *payload.Version) error { p.version = v return nil } -func (p *localPeer) SendVersion(m *Message) error { +func (p *localPeer) SendVersion() error { + m := p.server.getVersionMsg() _ = p.EnqueueMessage(m) return nil } diff --git a/pkg/network/peer.go b/pkg/network/peer.go index d063c5ddf..9f2443d0c 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -42,7 +42,9 @@ type Peer interface { // appropriate protocol handling like timeouts and outstanding pings // management. SendPing() error - SendVersion(*Message) error + // SendVersion checks handshake status and sends a version message to + // the peer. + SendVersion() error SendVersionAck(*Message) error // StartProtocol is a goroutine to be run after the handshake. It // implements basic peer-related protocol handling. diff --git a/pkg/network/server.go b/pkg/network/server.go index f96ba8383..f08e0e54c 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -61,7 +61,6 @@ type ( lock sync.RWMutex peers map[Peer]bool - addrReq chan *Message register chan Peer unregister chan peerDrop quit chan struct{} @@ -97,7 +96,6 @@ func NewServer(config ServerConfig, chain core.Blockchainer, log *zap.Logger) (* bQueue: newBlockQueue(maxBlockBatch, chain, log), id: randomID(), quit: make(chan struct{}), - addrReq: make(chan *Message, config.MinPeers), register: make(chan Peer), unregister: make(chan peerDrop), peers: make(map[Peer]bool), @@ -152,6 +150,12 @@ func NewServer(config ServerConfig, chain core.Blockchainer, log *zap.Logger) (* return s, nil } +// MkMsg creates a new message based on the server configured network and given +// parameters. +func (s *Server) MkMsg(cmd CommandType, p payload.Payload) *Message { + return NewMessage(s.Net, cmd, p) +} + // ID returns the servers ID. func (s *Server) ID() uint32 { return s.id @@ -197,13 +201,7 @@ func (s *Server) run() { s.discovery.RequestRemote(s.AttemptConnPeers) } if s.discovery.PoolCount() < minPoolCount { - select { - case s.addrReq <- NewMessage(s.Net, CMDGetAddr, payload.NewNullPayload()): - // sent request - default: - // we have one in the queue already that is - // gonna be served by some worker when it's ready - } + s.broadcastHPMessage(s.MkMsg(CMDGetAddr, payload.NewNullPayload())) } select { case <-s.quit: @@ -307,8 +305,8 @@ func (s *Server) HandshakedPeersCount() int { return count } -// When a peer connects to the server, we will send our version immediately. -func (s *Server) sendVersion(p Peer) error { +// getVersionMsg returns current version message. +func (s *Server) getVersionMsg() *Message { payload := payload.NewVersion( s.id, s.Port, @@ -316,7 +314,7 @@ func (s *Server) sendVersion(p Peer) error { s.chain.BlockHeight(), s.Relay, ) - return p.SendVersion(NewMessage(s.Net, CMDVersion, payload)) + return s.MkMsg(CMDVersion, payload) } // When a peer sends out his version we reply with verack after validating @@ -339,7 +337,7 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error { } } s.lock.RUnlock() - return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil)) + return p.SendVersionAck(s.MkMsg(CMDVerack, nil)) } // handleHeadersCmd processes the headers received from its peer. @@ -367,7 +365,7 @@ func (s *Server) handleBlockCmd(p Peer, block *block.Block) error { // handlePing processes ping request. func (s *Server) handlePing(p Peer, ping *payload.Ping) error { - return p.EnqueueMessage(NewMessage(s.Net, CMDPong, payload.NewPing(s.id, s.chain.BlockHeight()))) + return p.EnqueueMessage(s.MkMsg(CMDPong, payload.NewPing(s.id, s.chain.BlockHeight()))) } // handlePing processes pong request. @@ -401,7 +399,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { } } if len(reqHashes) > 0 { - msg := NewMessage(s.Net, CMDGetData, payload.NewInventory(inv.Type, reqHashes)) + msg := s.MkMsg(CMDGetData, payload.NewInventory(inv.Type, reqHashes)) pkt, err := msg.Bytes() if err != nil { return err @@ -423,16 +421,16 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { case payload.TXType: tx, _, err := s.chain.GetTransaction(hash) if err == nil { - msg = NewMessage(s.Net, CMDTX, tx) + msg = s.MkMsg(CMDTX, tx) } case payload.BlockType: b, err := s.chain.GetBlock(hash) if err == nil { - msg = NewMessage(s.Net, CMDBlock, b) + msg = s.MkMsg(CMDBlock, b) } case payload.ConsensusType: if cp := s.consensus.GetPayload(hash); cp != nil { - msg = NewMessage(s.Net, CMDConsensus, cp) + msg = s.MkMsg(CMDConsensus, cp) } } if msg != nil { @@ -475,7 +473,7 @@ func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { return nil } payload := payload.NewInventory(payload.BlockType, blockHashes) - msg := NewMessage(s.Net, CMDInv, payload) + msg := s.MkMsg(CMDInv, payload) return p.EnqueueMessage(msg) } @@ -505,7 +503,7 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { if len(resp.Hdrs) == 0 { return nil } - msg := NewMessage(s.Net, CMDHeaders, &resp) + msg := s.MkMsg(CMDHeaders, &resp) return p.EnqueueMessage(msg) } @@ -547,7 +545,7 @@ func (s *Server) handleGetAddrCmd(p Peer) error { netaddr, _ := net.ResolveTCPAddr("tcp", addr) alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts) } - return p.EnqueueMessage(NewMessage(s.Net, CMDAddr, alist)) + return p.EnqueueMessage(s.MkMsg(CMDAddr, alist)) } // requestHeaders sends a getheaders message to the peer. @@ -555,7 +553,7 @@ func (s *Server) handleGetAddrCmd(p Peer) error { func (s *Server) requestHeaders(p Peer) error { start := []util.Uint256{s.chain.CurrentHeaderHash()} payload := payload.NewGetBlocks(start, util.Uint256{}) - return p.EnqueueMessage(NewMessage(s.Net, CMDGetHeaders, payload)) + return p.EnqueueMessage(s.MkMsg(CMDGetHeaders, payload)) } // requestBlocks sends a getdata message to the peer @@ -574,7 +572,7 @@ func (s *Server) requestBlocks(p Peer) error { } if len(hashes) > 0 { payload := payload.NewInventory(payload.BlockType, hashes) - return p.EnqueueMessage(NewMessage(s.Net, CMDGetData, payload)) + return p.EnqueueMessage(s.MkMsg(CMDGetData, payload)) } else if s.chain.HeaderHeight() < p.LastBlockIndex() { return s.requestHeaders(p) } @@ -656,7 +654,10 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { } func (s *Server) handleNewPayload(p *consensus.Payload) { - s.relayInventoryCmd(CMDInv, payload.ConsensusType, p.Hash()) + msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()})) + // It's high priority because it directly affects consensus process, + // even though it's just an inv. + s.broadcastHPMessage(msg) } // getLastBlockTime returns unix timestamp for the moment when the last block @@ -670,25 +671,44 @@ func (s *Server) requestTx(hashes ...util.Uint256) { return } - s.relayInventoryCmd(CMDGetData, payload.TXType, hashes...) + msg := s.MkMsg(CMDGetData, payload.NewInventory(payload.TXType, hashes)) + // It's high priority because it directly affects consensus process, + // even though it's getdata. + s.broadcastHPMessage(msg) } -func (s *Server) relayInventoryCmd(cmd CommandType, t payload.InventoryType, hashes ...util.Uint256) { - payload := payload.NewInventory(t, hashes) - msg := NewMessage(s.Net, cmd, payload) - +// iteratePeersWithSendMsg sends given message to all peers using two functions +// passed, one is to send the message and the other is to filtrate peers (the +// peer is considered invalid if it returns false). +func (s *Server) iteratePeersWithSendMsg(msg *Message, send func(Peer, []byte) error, peerOK func(Peer) bool) { + pkt, err := msg.Bytes() + if err != nil { + return + } + // Get a copy of s.peers to avoid holding a lock while sending. for peer := range s.Peers() { - if !peer.Handshaked() || !peer.Version().Relay { + if peerOK != nil && !peerOK(peer) { continue } // Who cares about these messages anyway? - _ = peer.EnqueueMessage(msg) + _ = send(peer, pkt) } } +// broadcastMessage sends the message to all available peers. +func (s *Server) broadcastMessage(msg *Message) { + s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, nil) +} + +// broadcastHPMessage sends the high-priority message to all available peers. +func (s *Server) broadcastHPMessage(msg *Message) { + s.iteratePeersWithSendMsg(msg, Peer.EnqueueHPPacket, nil) +} + // relayBlock tells all the other connected nodes about the given block. func (s *Server) relayBlock(b *block.Block) { - s.relayInventoryCmd(CMDInv, payload.BlockType, b.Hash()) + msg := s.MkMsg(CMDInv, payload.NewInventory(payload.BlockType, []util.Uint256{b.Hash()})) + s.broadcastMessage(msg) } // RelayTxn a new transaction to the local node and the connected peers. @@ -710,7 +730,13 @@ func (s *Server) RelayTxn(t *transaction.Transaction) RelayReason { return RelayOutOfMemory } - s.relayInventoryCmd(CMDInv, payload.TXType, t.Hash()) + msg := s.MkMsg(CMDInv, payload.NewInventory(payload.TXType, []util.Uint256{t.Hash()})) + + // We need to filter out non-relaying nodes, so plain broadcast + // functions don't fit here. + s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, func(p Peer) bool { + return p.Handshaked() && p.Version().Relay + }) return RelaySucceed } diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 39f2caedc..f5dede1bd 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -12,7 +12,7 @@ import ( func TestSendVersion(t *testing.T) { var ( s = newTestServer(t) - p = newLocalPeer(t) + p = newLocalPeer(t, s) ) s.Port = 3000 s.UserAgent = "/test/" @@ -29,7 +29,7 @@ func TestSendVersion(t *testing.T) { assert.Equal(t, uint32(0), version.StartHeight) } - if err := s.sendVersion(p); err != nil { + if err := p.SendVersion(); err != nil { t.Fatal(err) } } @@ -38,7 +38,7 @@ func TestSendVersion(t *testing.T) { func TestVerackAfterHandleVersionCmd(t *testing.T) { var ( s = newTestServer(t) - p = newLocalPeer(t) + p = newLocalPeer(t, s) ) na, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:3000") p.netaddr = *na @@ -59,8 +59,8 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) { func TestServerNotSendsVerack(t *testing.T) { var ( s = newTestServer(t) - p = newLocalPeer(t) - p2 = newLocalPeer(t) + p = newLocalPeer(t, s) + p2 = newLocalPeer(t, s) ) s.id = 1 go s.run() @@ -92,7 +92,7 @@ func TestServerNotSendsVerack(t *testing.T) { func TestRequestHeaders(t *testing.T) { var ( s = newTestServer(t) - p = newLocalPeer(t) + p = newLocalPeer(t, s) ) p.messageHandler = func(t *testing.T, msg *Message) { assert.IsType(t, &payload.GetBlocks{}, msg.Payload) diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 5683cfe39..096966359 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -117,7 +117,7 @@ func (p *TCPPeer) handleConn() { go p.handleQueues() // When a new peer is connected we send out our version immediately. - err = p.server.sendVersion(p) + err = p.SendVersion() if err == nil { r := io.NewBinReaderFromIO(p.conn) for { @@ -198,13 +198,6 @@ func (p *TCPPeer) StartProtocol() { select { case <-p.done: return - case m := <-p.server.addrReq: - var pkt []byte - - pkt, err = m.Bytes() - if err == nil { - err = p.EnqueueHPPacket(pkt) - } case <-timer.C: // Try to sync in headers and block with the peer if his block height is higher then ours. if p.LastBlockIndex() > p.server.chain.BlockHeight() { @@ -235,7 +228,8 @@ func (p *TCPPeer) Handshaked() bool { } // SendVersion checks for the handshake state and sends a message to the peer. -func (p *TCPPeer) SendVersion(msg *Message) error { +func (p *TCPPeer) SendVersion() error { + msg := p.server.getVersionMsg() p.lock.Lock() defer p.lock.Unlock() if p.handShake&versionSent != 0 { @@ -355,7 +349,7 @@ func (p *TCPPeer) SendPing() error { }) } p.lock.Unlock() - return p.EnqueueMessage(NewMessage(p.server.Net, CMDPing, payload.NewPing(p.server.id, p.server.chain.HeaderHeight()))) + return p.EnqueueMessage(p.server.MkMsg(CMDPing, payload.NewPing(p.server.id, p.server.chain.HeaderHeight()))) } // HandlePong handles a pong message received from the peer and does appropriate diff --git a/pkg/network/tcp_peer_test.go b/pkg/network/tcp_peer_test.go index 691e22b6d..5e2e6366d 100644 --- a/pkg/network/tcp_peer_test.go +++ b/pkg/network/tcp_peer_test.go @@ -18,8 +18,8 @@ func connReadStub(conn net.Conn) { func TestPeerHandshake(t *testing.T) { server, client := net.Pipe() - tcpS := NewTCPPeer(server, nil) - tcpC := NewTCPPeer(client, nil) + tcpS := NewTCPPeer(server, newTestServer(t)) + tcpC := NewTCPPeer(client, newTestServer(t)) // Something should read things written into the pipe. go connReadStub(tcpS.conn) @@ -45,22 +45,22 @@ func TestPeerHandshake(t *testing.T) { // Now send and handle versions, but in a different order on client and // server. - require.NoError(t, tcpC.SendVersion(&Message{})) + require.NoError(t, tcpC.SendVersion()) require.Error(t, tcpC.HandleVersionAck()) // Didn't receive version yet. require.NoError(t, tcpS.HandleVersion(&payload.Version{})) require.Error(t, tcpS.SendVersionAck(&Message{})) // Didn't send version yet. require.NoError(t, tcpC.HandleVersion(&payload.Version{})) - require.NoError(t, tcpS.SendVersion(&Message{})) + require.NoError(t, tcpS.SendVersion()) // No handshake yet. require.Equal(t, false, tcpS.Handshaked()) require.Equal(t, false, tcpC.Handshaked()) // These are sent/received and should fail now. - require.Error(t, tcpC.SendVersion(&Message{})) + require.Error(t, tcpC.SendVersion()) require.Error(t, tcpS.HandleVersion(&payload.Version{})) require.Error(t, tcpC.HandleVersion(&payload.Version{})) - require.Error(t, tcpS.SendVersion(&Message{})) + require.Error(t, tcpS.SendVersion()) // Now send and handle ACK, again in a different order on client and // server.