diff --git a/pkg/network/server.go b/pkg/network/server.go index 74e9b5502..0fdfd2a16 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -68,11 +68,15 @@ type ( chain blockchainer.Blockchainer bQueue *blockQueue consensus consensus.Service + mempool *mempool.Pool notaryRequestPool *mempool.Pool extensiblePool *extpool.Pool notaryFeer NotaryFeer notaryModule *notary.Notary + txInLock sync.Mutex + txInMap map[util.Uint256]struct{} + lock sync.RWMutex peers map[Peer]bool @@ -136,8 +140,10 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai quit: make(chan struct{}), register: make(chan Peer), unregister: make(chan peerDrop), + txInMap: make(map[util.Uint256]struct{}), peers: make(map[Peer]bool), syncReached: atomic.NewBool(false), + mempool: chain.GetMemPool(), extensiblePool: extpool.New(chain, config.ExtensiblePoolSize), log: log, transactions: make(chan *transaction.Transaction, 64), @@ -282,7 +288,7 @@ func (s *Server) Shutdown() { s.transport.Close() s.discovery.Close() s.consensus.Shutdown() - for p := range s.Peers() { + for _, p := range s.getPeers(nil) { p.Disconnect(errServerShutdown) } s.bQueue.discard() @@ -425,7 +431,7 @@ func (s *Server) runProto() { case <-pingTimer.C: if s.chain.BlockHeight() == prevHeight { // Get a copy of s.peers to avoid holding a lock while sending. - for peer := range s.Peers() { + for _, peer := range s.getPeers(nil) { _ = peer.SendPing(NewMessage(CMDPing, payload.NewPing(s.chain.BlockHeight(), s.id))) } } @@ -483,15 +489,18 @@ func (s *Server) UnsubscribeFromNotaryRequests(ch chan<- mempoolevent.Event) { s.notaryRequestPool.UnsubscribeFromTransactions(ch) } -// Peers returns the current list of peers connected to -// the server. -func (s *Server) Peers() map[Peer]bool { +// getPeers returns current list of peers connected to the server filtered by +// isOK function if it's given. +func (s *Server) getPeers(isOK func(Peer) bool) []Peer { s.lock.RLock() defer s.lock.RUnlock() - peers := make(map[Peer]bool, len(s.peers)) - for k, v := range s.peers { - peers[k] = v + peers := make([]Peer, 0, len(s.peers)) + for k := range s.peers { + if isOK != nil && !isOK(k) { + continue + } + peers = append(peers, k) } return peers @@ -655,7 +664,7 @@ func (s *Server) handlePong(p Peer, pong *payload.Ping) error { func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { reqHashes := make([]util.Uint256, 0) var typExists = map[payload.InventoryType]func(util.Uint256) bool{ - payload.TXType: s.chain.HasTransaction, + payload.TXType: s.mempool.ContainsKey, payload.BlockType: s.chain.HasBlock, payload.ExtensibleType: func(h util.Uint256) bool { cp := s.extensiblePool.Get(h) @@ -688,7 +697,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { // handleMempoolCmd handles getmempool command. func (s *Server) handleMempoolCmd(p Peer) error { - txs := s.chain.GetMemPool().GetVerifiedTransactions() + txs := s.mempool.GetVerifiedTransactions() hs := make([]util.Uint256, 0, payload.MaxHashesCount) for i := range txs { hs = append(hs, txs[i].Hash()) @@ -874,10 +883,21 @@ func (s *Server) handleExtensibleCmd(e *payload.Extensible) error { func (s *Server) handleTxCmd(tx *transaction.Transaction) error { // It's OK for it to fail for various reasons like tx already existing // in the pool. + s.txInLock.Lock() + _, ok := s.txInMap[tx.Hash()] + if ok || s.mempool.ContainsKey(tx.Hash()) { + s.txInLock.Unlock() + return nil + } + s.txInMap[tx.Hash()] = struct{}{} + s.txInLock.Unlock() if s.verifyAndPoolTX(tx) == nil { s.consensus.OnTransaction(tx) s.broadcastTX(tx, nil) } + s.txInLock.Lock() + delete(s.txInMap, tx.Hash()) + s.txInLock.Unlock() return nil } @@ -1124,54 +1144,49 @@ func (s *Server) requestTx(hashes ...util.Uint256) { // passed, one is to send the message and the other is to filtrate peers (the // peer is considered invalid if it returns false). func (s *Server) iteratePeersWithSendMsg(msg *Message, send func(Peer, bool, []byte) error, peerOK func(Peer) bool) { + var deadN, peerN, sentN int + // Get a copy of s.peers to avoid holding a lock while sending. - peers := s.Peers() - if len(peers) == 0 { + peers := s.getPeers(peerOK) + peerN = len(peers) + if peerN == 0 { return } + mrand.Shuffle(peerN, func(i, j int) { + peers[i], peers[j] = peers[j], peers[i] + }) pkt, err := msg.Bytes() if err != nil { return } - success := make(map[Peer]bool, len(peers)) - okCount := 0 - sentCount := 0 - for peer := range peers { - if peerOK != nil && !peerOK(peer) { - success[peer] = false - continue - } - okCount++ - if err := send(peer, false, pkt); err != nil { - continue - } - if msg.Command == CMDGetAddr { - peer.AddGetAddrSent() - } - success[peer] = true - sentCount++ - } + // If true, this node isn't counted any more, either it's dead or we + // have already sent an Inv to it. + finished := make([]bool, peerN) - // Send to at least 2/3 of good peers. - if 3*sentCount >= 2*okCount { - return - } - - // Perform blocking send now. - for peer := range peers { - if _, ok := success[peer]; ok || peerOK != nil && !peerOK(peer) { - continue - } - if err := send(peer, true, pkt); err != nil { - continue - } - if msg.Command == CMDGetAddr { - peer.AddGetAddrSent() - } - sentCount++ - if 3*sentCount >= 2*okCount { - return + // Try non-blocking sends first and only block if have to. + for _, blocking := range []bool{false, true} { + for i, peer := range peers { + // Send to 2/3 of good peers. + if 3*sentN >= 2*(peerN-deadN) { + return + } + if finished[i] { + continue + } + err := send(peer, blocking, pkt) + switch err { + case nil: + if msg.Command == CMDGetAddr { + peer.AddGetAddrSent() + } + sentN++ + case errBusy: // Can be retried. + continue + default: + deadN++ + } + finished[i] = true } } } @@ -1247,8 +1262,7 @@ func (s *Server) initStaleMemPools() { threshold = cfg.ValidatorsCount * 2 } - mp := s.chain.GetMemPool() - mp.SetResendThreshold(uint32(threshold), s.broadcastTX) + s.mempool.SetResendThreshold(uint32(threshold), s.broadcastTX) if s.chain.P2PSigExtensionsEnabled() { s.notaryRequestPool.SetResendThreshold(uint32(threshold), s.broadcastP2PNotaryRequestPayload) } diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index d70618b54..0c9c9a358 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -694,7 +694,7 @@ func TestInv(t *testing.T) { }) t.Run("transaction", func(t *testing.T) { tx := newDummyTx() - s.chain.(*fakechain.FakeChain).PutTx(tx) + require.NoError(t, s.chain.GetMemPool().Add(tx, s.chain)) hs := []util.Uint256{random.Uint256(), tx.Hash(), random.Uint256()} s.testHandleMessage(t, p, CMDInv, &payload.Inventory{ Type: payload.TXType, diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 85678d297..8ff47a18c 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -26,6 +26,7 @@ const ( requestQueueSize = 32 p2pMsgQueueSize = 16 hpRequestQueueSize = 4 + incomingQueueSize = 1 // Each message can be up to 32MB in size. ) var ( @@ -57,6 +58,7 @@ type TCPPeer struct { sendQ chan []byte p2pSendQ chan []byte hpSendQ chan []byte + incoming chan *Message // track outstanding getaddr requests. getAddrSent atomic.Int32 @@ -75,6 +77,7 @@ func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { sendQ: make(chan []byte, requestQueueSize), p2pSendQ: make(chan []byte, p2pMsgQueueSize), hpSendQ: make(chan []byte, hpRequestQueueSize), + incoming: make(chan *Message, incomingQueueSize), } } @@ -158,6 +161,7 @@ func (p *TCPPeer) handleConn() { p.server.register <- p go p.handleQueues() + go p.handleIncoming() // When a new peer is connected we send out our version immediately. err = p.SendVersion() if err == nil { @@ -172,12 +176,22 @@ func (p *TCPPeer) handleConn() { } else if err != nil { break } - if err = p.server.handleMessage(p, msg); err != nil { - if p.Handshaked() { - err = fmt.Errorf("handling %s message: %w", msg.Command.String(), err) - } - break + p.incoming <- msg + } + } + close(p.incoming) + p.Disconnect(err) +} + +func (p *TCPPeer) handleIncoming() { + var err error + for msg := range p.incoming { + err = p.server.handleMessage(p, msg) + if err != nil { + if p.Handshaked() { + err = fmt.Errorf("handling %s message: %w", msg.Command.String(), err) } + break } } p.Disconnect(err)