diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 77db5f08a..0022a9eac 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -71,7 +71,7 @@ type localPeer struct { server *Server version *payload.Version lastBlockIndex uint32 - handshaked bool + handshaked int32 // TODO: use atomic.Bool after #2626. isFullNode bool t *testing.T messageHandler func(t *testing.T, msg *Message) @@ -105,26 +105,7 @@ func (p *localPeer) Disconnect(err error) { p.server.unregister <- peerDrop{p, err} } -func (p *localPeer) EnqueueMessage(msg *Message) error { - b, err := msg.Bytes() - if err != nil { - return err - } - return p.EnqueueHPPacket(b) -} func (p *localPeer) BroadcastPacket(_ context.Context, m []byte) error { - return p.EnqueueHPPacket(m) -} -func (p *localPeer) EnqueueP2PMessage(msg *Message) error { - return p.EnqueueMessage(msg) -} -func (p *localPeer) EnqueueP2PPacket(m []byte) error { - return p.EnqueueHPPacket(m) -} -func (p *localPeer) BroadcastHPPacket(_ context.Context, m []byte) error { - return p.EnqueueHPPacket(m) -} -func (p *localPeer) EnqueueHPPacket(m []byte) error { msg := &Message{} r := io.NewBinReaderFromBuf(m) err := msg.Decode(r) @@ -133,6 +114,16 @@ func (p *localPeer) EnqueueHPPacket(m []byte) error { } return nil } +func (p *localPeer) EnqueueP2PMessage(msg *Message) error { + return p.EnqueueHPMessage(msg) +} +func (p *localPeer) BroadcastHPPacket(ctx context.Context, m []byte) error { + return p.BroadcastPacket(ctx, m) +} +func (p *localPeer) EnqueueHPMessage(msg *Message) error { + p.messageHandler(p.t, msg) + return nil +} func (p *localPeer) Version() *payload.Version { return p.version } @@ -148,21 +139,19 @@ func (p *localPeer) SendVersion() error { if err != nil { return err } - _ = p.EnqueueMessage(m) + _ = p.EnqueueHPMessage(m) return nil } func (p *localPeer) SendVersionAck(m *Message) error { - _ = p.EnqueueMessage(m) + _ = p.EnqueueHPMessage(m) return nil } func (p *localPeer) HandleVersionAck() error { - p.handshaked = true + atomic.StoreInt32(&p.handshaked, 1) return nil } -func (p *localPeer) SendPing(m *Message) error { +func (p *localPeer) SetPingTimer() { p.pingSent++ - _ = p.EnqueueMessage(m) - return nil } func (p *localPeer) HandlePing(ping *payload.Ping) error { p.lastBlockIndex = ping.LastBlockIndex @@ -176,7 +165,7 @@ func (p *localPeer) HandlePong(pong *payload.Ping) error { } func (p *localPeer) Handshaked() bool { - return p.handshaked + return atomic.LoadInt32(&p.handshaked) != 0 } func (p *localPeer) IsFullNode() bool { diff --git a/pkg/network/peer.go b/pkg/network/peer.go index ca1263cfd..9854165d5 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -20,10 +20,6 @@ type Peer interface { PeerAddr() net.Addr Disconnect(error) - // EnqueueMessage is a blocking packet enqueuer similar to EnqueueP2PMessage, - // but using the lowest priority queue. - EnqueueMessage(*Message) error - // BroadcastPacket is a context-bound packet enqueuer, it either puts the // given packet into the queue or exits with errors if the context expires // or peer disconnects. It accepts a slice of bytes that @@ -36,33 +32,25 @@ type Peer interface { // queue. BroadcastHPPacket(context.Context, []byte) error - // EnqueueP2PMessage is a temporary wrapper that sends a message via - // EnqueueP2PPacket if there is no error in serializing it. + // EnqueueP2PMessage is a blocking packet enqueuer, it doesn't return until + // it puts the given message into the queue. It returns an error if the peer + // has not yet completed handshaking. This queue is intended to be used for + // unicast peer to peer communication that is more important than broadcasts + // (handled by BroadcastPacket) but less important than high-priority + // messages (handled by EnqueueHPMessage). EnqueueP2PMessage(*Message) error - // EnqueueP2PPacket is a blocking packet enqueuer, it doesn't return until - // it puts the given packet into the queue. It accepts a slice of bytes that - // can be shared with other queues (so that message marshalling can be - // done once for all peers). It returns an error if the peer has not yet - // completed handshaking. This queue is intended to be used for unicast - // peer to peer communication that is more important than broadcasts - // (handled by BroadcastPacket) but less important than high-priority - // messages (handled by EnqueueHPPacket and BroadcastHPPacket). - EnqueueP2PPacket([]byte) error - - // EnqueueHPPacket is a blocking high priority packet enqueuer, it - // doesn't return until it puts the given packet into the high-priority + // EnqueueHPMessage is similar to EnqueueP2PMessage, but uses a high-priority // queue. - EnqueueHPPacket([]byte) error + EnqueueHPMessage(*Message) error Version() *payload.Version LastBlockIndex() uint32 Handshaked() bool IsFullNode() bool - // SendPing enqueues a ping message to be sent to the peer and does - // appropriate protocol handling like timeouts and outstanding pings - // management. - SendPing(*Message) error + // SetPingTimer adds an outgoing ping to the counter and sets a PingTimeout + // timer that will shut the connection down in case of no response. + SetPingTimer() // SendVersion checks handshake status and sends a version message to // the peer. SendVersion() error diff --git a/pkg/network/server.go b/pkg/network/server.go index 99a7291cd..8ef0f6eea 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -465,10 +465,7 @@ func (s *Server) runProto() { return case <-pingTimer.C: if s.chain.BlockHeight() == prevHeight { - // Get a copy of s.peers to avoid holding a lock while sending. - for _, peer := range s.getPeers(nil) { - _ = peer.SendPing(NewMessage(CMDPing, payload.NewPing(s.chain.BlockHeight(), s.id))) - } + s.broadcastMessage(NewMessage(CMDPing, payload.NewPing(s.chain.BlockHeight(), s.id))) } pingTimer.Reset(s.PingInterval) } @@ -751,14 +748,10 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { } if len(reqHashes) > 0 { msg := NewMessage(CMDGetData, payload.NewInventory(inv.Type, reqHashes)) - pkt, err := msg.Bytes() - if err != nil { - return err - } if inv.Type == payload.ExtensibleType { - return p.EnqueueHPPacket(pkt) + return p.EnqueueHPMessage(msg) } - return p.EnqueueP2PPacket(pkt) + return p.EnqueueP2PMessage(msg) } return nil } @@ -815,13 +808,11 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { } } if msg != nil { - pkt, err := msg.Bytes() - if err == nil { - if inv.Type == payload.ExtensibleType { - err = p.EnqueueHPPacket(pkt) - } else { - err = p.EnqueueP2PPacket(pkt) - } + var err error + if inv.Type == payload.ExtensibleType { + err = p.EnqueueHPMessage(msg) + } else { + err = p.EnqueueP2PMessage(msg) } if err != nil { return err @@ -1371,6 +1362,9 @@ func (s *Server) iteratePeersWithSendMsg(msg *Message, send func(Peer, context.C if msg.Command == CMDGetAddr { p.AddGetAddrSent() } + if msg.Command == CMDPing { + p.SetPingTimer() + } replies <- send(p, ctx, pkt) }(peer, ctx, pkt) } @@ -1394,12 +1388,12 @@ func (s *Server) iteratePeersWithSendMsg(msg *Message, send func(Peer, context.C // broadcastMessage sends the message to all available peers. func (s *Server) broadcastMessage(msg *Message) { - s.iteratePeersWithSendMsg(msg, Peer.BroadcastPacket, nil) + s.iteratePeersWithSendMsg(msg, Peer.BroadcastPacket, Peer.Handshaked) } // broadcastHPMessage sends the high-priority message to all available peers. func (s *Server) broadcastHPMessage(msg *Message) { - s.iteratePeersWithSendMsg(msg, Peer.BroadcastHPPacket, nil) + s.iteratePeersWithSendMsg(msg, Peer.BroadcastHPPacket, Peer.Handshaked) } // relayBlocksLoop subscribes to new blocks in the ledger and broadcasts them diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 578e82f7e..2e4b7cf8e 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -372,7 +372,7 @@ func TestServerNotSendsVerack(t *testing.T) { func (s *Server) testHandleMessage(t *testing.T, p Peer, cmd CommandType, pl payload.Payload) *Server { if p == nil { p = newLocalPeer(t, s) - p.(*localPeer).handshaked = true + p.(*localPeer).handshaked = 1 } msg := NewMessage(cmd, pl) require.NoError(t, s.handleMessage(p, msg)) @@ -419,7 +419,7 @@ func TestConsensus(t *testing.T) { atomic2.StoreUint32(&s.chain.(*fakechain.FakeChain).Blockheight, 4) p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 s.register <- p require.Eventually(t, func() bool { return 1 == s.PeerCount() }, time.Second, time.Millisecond*10) @@ -491,7 +491,7 @@ func (s *Server) testHandleGetData(t *testing.T, invType payload.InventoryType, var recvNotFound atomic.Bool p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.messageHandler = func(t *testing.T, msg *Message) { switch msg.Command { case CMDTX, CMDBlock, CMDExtensible, CMDP2PNotaryRequest: @@ -587,7 +587,7 @@ func TestGetBlocks(t *testing.T) { } var actual []util.Uint256 p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.messageHandler = func(t *testing.T, msg *Message) { if msg.Command == CMDInv { actual = msg.Payload.(*payload.Inventory).Hashes @@ -614,7 +614,7 @@ func TestGetBlockByIndex(t *testing.T) { var expected []*block.Block var actual []*block.Block p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.messageHandler = func(t *testing.T, msg *Message) { if msg.Command == CMDBlock { actual = append(actual, msg.Payload.(*block.Block)) @@ -652,7 +652,7 @@ func TestGetHeaders(t *testing.T) { var actual *payload.Headers p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.messageHandler = func(t *testing.T, msg *Message) { if msg.Command == CMDHeaders { actual = msg.Payload.(*payload.Headers) @@ -690,7 +690,7 @@ func TestInv(t *testing.T) { var actual []util.Uint256 p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.messageHandler = func(t *testing.T, msg *Message) { if msg.Command == CMDGetData { actual = msg.Payload.(*payload.Inventory).Hashes @@ -752,7 +752,7 @@ func TestHandleGetMPTData(t *testing.T) { t.Run("P2PStateExchange extensions off", func(t *testing.T) { s := startTestServer(t) p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{ Hashes: []util.Uint256{{1, 2, 3}}, }) @@ -776,7 +776,7 @@ func TestHandleGetMPTData(t *testing.T) { Nodes: [][]byte{node}, // no duplicates expected } p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.messageHandler = func(t *testing.T, msg *Message) { switch msg.Command { case CMDMPTData: @@ -809,7 +809,7 @@ func TestHandleMPTData(t *testing.T) { t.Run("P2PStateExchange extensions off", func(t *testing.T) { s := startTestServer(t) p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 msg := NewMessage(CMDMPTData, &payload.MPTData{ Nodes: [][]byte{{1, 2, 3}}, }) @@ -829,7 +829,7 @@ func TestHandleMPTData(t *testing.T) { startWithCleanup(t, s) p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 msg := NewMessage(CMDMPTData, &payload.MPTData{ Nodes: expected, }) @@ -842,7 +842,7 @@ func TestRequestMPTNodes(t *testing.T) { var actual []util.Uint256 p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.messageHandler = func(t *testing.T, msg *Message) { if msg.Command == CMDGetMPTData { actual = append(actual, msg.Payload.(*payload.MPTInventory).Hashes...) @@ -887,7 +887,7 @@ func TestRequestTx(t *testing.T) { var actual []util.Uint256 p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.messageHandler = func(t *testing.T, msg *Message) { if msg.Command == CMDGetData { actual = append(actual, msg.Payload.(*payload.Inventory).Hashes...) @@ -938,7 +938,7 @@ func TestAddrs(t *testing.T) { } p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.getAddrSent = 1 pl := &payload.AddressList{ Addrs: []*payload.AddressAndTime{ @@ -990,7 +990,7 @@ func TestMemPool(t *testing.T) { var actual []util.Uint256 p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.messageHandler = func(t *testing.T, msg *Message) { if msg.Command == CMDInv { actual = append(actual, msg.Payload.(*payload.Inventory).Hashes...) @@ -1070,12 +1070,12 @@ func TestTryInitStateSync(t *testing.T) { s := startTestServer(t) for _, h := range []uint32{10, 8, 7, 4, 11, 4} { p := newLocalPeer(t, s) - p.handshaked = true + p.handshaked = 1 p.lastBlockIndex = h s.register <- p } p := newLocalPeer(t, s) - p.handshaked = false // one disconnected peer to check it won't be taken into attention + p.handshaked = 0 // one disconnected peer to check it won't be taken into attention p.lastBlockIndex = 5 s.register <- p require.Eventually(t, func() bool { return 7 == s.PeerCount() }, time.Second, time.Millisecond*10) diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 330e2268f..9244716cc 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -81,9 +81,9 @@ func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { } } -// putBroadcastPacketIntoQueue puts the given message into the given queue if +// putPacketIntoQueue puts the given message into the given queue if // the peer has done handshaking using the given context. -func (p *TCPPeer) putBroadcastPacketIntoQueue(ctx context.Context, queue chan<- []byte, msg []byte) error { +func (p *TCPPeer) putPacketIntoQueue(ctx context.Context, queue chan<- []byte, msg []byte) error { if !p.Handshaked() { return errStateMismatch } @@ -97,29 +97,15 @@ func (p *TCPPeer) putBroadcastPacketIntoQueue(ctx context.Context, queue chan<- return nil } -// putPacketIntoQueue puts the given message into the given queue if the peer has -// done handshaking. -func (p *TCPPeer) putPacketIntoQueue(queue chan<- []byte, msg []byte) error { - if !p.Handshaked() { - return errStateMismatch - } - select { - case queue <- msg: - case <-p.done: - return errGone - } - return nil -} - // BroadcastPacket implements the Peer interface. func (p *TCPPeer) BroadcastPacket(ctx context.Context, msg []byte) error { - return p.putBroadcastPacketIntoQueue(ctx, p.sendQ, msg) + return p.putPacketIntoQueue(ctx, p.sendQ, msg) } // BroadcastHPPacket implements the Peer interface. It the peer is not yet // handshaked it's a noop. func (p *TCPPeer) BroadcastHPPacket(ctx context.Context, msg []byte) error { - return p.putBroadcastPacketIntoQueue(ctx, p.hpSendQ, msg) + return p.putPacketIntoQueue(ctx, p.hpSendQ, msg) } // putMessageIntoQueue serializes the given Message and puts it into given queue if @@ -129,18 +115,7 @@ func (p *TCPPeer) putMsgIntoQueue(queue chan<- []byte, msg *Message) error { if err != nil { return err } - return p.putPacketIntoQueue(queue, b) -} - -// EnqueueMessage is a temporary wrapper that sends a message via -// EnqueuePacket if there is no error in serializing it. -func (p *TCPPeer) EnqueueMessage(msg *Message) error { - return p.putMsgIntoQueue(p.sendQ, msg) -} - -// EnqueueP2PPacket implements the Peer interface. -func (p *TCPPeer) EnqueueP2PPacket(msg []byte) error { - return p.putPacketIntoQueue(p.p2pSendQ, msg) + return p.putPacketIntoQueue(context.Background(), queue, b) } // EnqueueP2PMessage implements the Peer interface. @@ -148,10 +123,9 @@ func (p *TCPPeer) EnqueueP2PMessage(msg *Message) error { return p.putMsgIntoQueue(p.p2pSendQ, msg) } -// EnqueueHPPacket implements the Peer interface. It the peer is not yet -// handshaked it's a noop. -func (p *TCPPeer) EnqueueHPPacket(msg []byte) error { - return p.putPacketIntoQueue(p.hpSendQ, msg) +// EnqueueHPMessage implements the Peer interface. +func (p *TCPPeer) EnqueueHPMessage(msg *Message) error { + return p.putMsgIntoQueue(p.hpSendQ, msg) } func (p *TCPPeer) writeMsg(msg *Message) error { @@ -454,12 +428,9 @@ func (p *TCPPeer) LastBlockIndex() uint32 { return p.lastBlockIndex } -// SendPing sends a ping message to the peer and does an appropriate accounting of -// outstanding pings and timeouts. -func (p *TCPPeer) SendPing(msg *Message) error { - if !p.Handshaked() { - return errStateMismatch - } +// SetPingTimer adds an outgoing ping to the counter and sets a PingTimeout timer +// that will shut the connection down in case of no response. +func (p *TCPPeer) SetPingTimer() { p.lock.Lock() p.pingSent++ if p.pingTimer == nil { @@ -468,7 +439,6 @@ func (p *TCPPeer) SendPing(msg *Message) error { }) } p.lock.Unlock() - return p.EnqueueMessage(msg) } // HandlePing handles a ping message received from the peer. diff --git a/pkg/network/tcp_peer_test.go b/pkg/network/tcp_peer_test.go index b1cdfa985..a4e2e8655 100644 --- a/pkg/network/tcp_peer_test.go +++ b/pkg/network/tcp_peer_test.go @@ -30,8 +30,8 @@ func TestPeerHandshake(t *testing.T) { require.Equal(t, false, tcpC.Handshaked()) // No ordinary messages can be written. - require.Error(t, tcpS.EnqueueMessage(&Message{})) - require.Error(t, tcpC.EnqueueMessage(&Message{})) + require.Error(t, tcpS.EnqueueP2PMessage(&Message{})) + require.Error(t, tcpC.EnqueueP2PMessage(&Message{})) // Try to mess with VersionAck on both client and server, it should fail. require.Error(t, tcpS.SendVersionAck(&Message{})) @@ -80,6 +80,6 @@ func TestPeerHandshake(t *testing.T) { require.Error(t, tcpS.SendVersionAck(&Message{})) // Now regular messaging can proceed. - require.NoError(t, tcpS.EnqueueMessage(&Message{})) - require.NoError(t, tcpC.EnqueueMessage(&Message{})) + require.NoError(t, tcpS.EnqueueP2PMessage(&Message{})) + require.NoError(t, tcpC.EnqueueP2PMessage(&Message{})) }