diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index b35290935..f41685f78 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -15,6 +15,7 @@ import ( "github.com/CityOfZion/neo-go/pkg/core/storage" "github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/crypto/keys" + "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/network/payload" "github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/vm" @@ -179,8 +180,24 @@ func (p *localPeer) PeerAddr() net.Addr { } func (p *localPeer) StartProtocol() {} func (p *localPeer) Disconnect(err error) {} -func (p *localPeer) WriteMsg(msg *Message) error { - p.messageHandler(p.t, msg) + +func (p *localPeer) EnqueueMessage(msg *Message) error { + b, err := msg.Bytes() + if err != nil { + return err + } + return p.EnqueuePacket(b) +} +func (p *localPeer) EnqueuePacket(m []byte) error { + return p.EnqueueHPPacket(m) +} +func (p *localPeer) EnqueueHPPacket(m []byte) error { + msg := &Message{} + r := io.NewBinReaderFromBuf(m) + err := msg.Decode(r) + if err == nil { + p.messageHandler(p.t, msg) + } return nil } func (p *localPeer) Version() *payload.Version { @@ -197,10 +214,12 @@ func (p *localPeer) HandleVersion(v *payload.Version) error { return nil } func (p *localPeer) SendVersion(m *Message) error { - return p.WriteMsg(m) + _ = p.EnqueueMessage(m) + return nil } func (p *localPeer) SendVersionAck(m *Message) error { - return p.WriteMsg(m) + _ = p.EnqueueMessage(m) + return nil } func (p *localPeer) HandleVersionAck() error { p.handshaked = true diff --git a/pkg/network/message.go b/pkg/network/message.go index fbe025fc5..181b52cd5 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -226,6 +226,18 @@ func (m *Message) Encode(br *io.BinWriter) error { return nil } +// Bytes serializes a Message into the new allocated buffer and returns it. +func (m *Message) Bytes() ([]byte, error) { + w := io.NewBufBinWriter() + if err := m.Encode(w.BinWriter); err != nil { + return nil, err + } + if w.Err != nil { + return nil, w.Err + } + return w.Bytes(), nil +} + // convert a command (string) to a byte slice filled with 0 bytes till // size 12. func cmdToByteArray(cmd CommandType) [cmdSize]byte { diff --git a/pkg/network/peer.go b/pkg/network/peer.go index a11f09de5..3fe9cb23d 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -18,7 +18,22 @@ type Peer interface { // before that it returns the same address as RemoteAddr. PeerAddr() net.Addr Disconnect(error) - WriteMsg(msg *Message) error + + // EnqueueMessage is a temporary wrapper that sends a message via + // EnqueuePacket if there is no error in serializing it. + EnqueueMessage(*Message) error + + // EnqueuePacket is a blocking packet enqueuer, it doesn't return until + // it puts given packet into the queue. It accepts a slice of bytes that + // can be shared with other queues (so that message marshalling can be + // done once for all peers). Does nothing is the peer is not yet + // completed handshaking. + EnqueuePacket([]byte) error + + // EnqueueHPPacket is a blocking high priority packet enqueuer, it + // doesn't return until it puts given packet into the high-priority + // queue. + EnqueueHPPacket([]byte) error Version() *payload.Version LastBlockIndex() uint32 UpdateLastBlockIndex(lbIndex uint32) diff --git a/pkg/network/server.go b/pkg/network/server.go index 7d9d6ed71..cf11d1df6 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -364,7 +364,7 @@ func (s *Server) handleBlockCmd(p Peer, block *block.Block) error { // handlePing processes ping request. func (s *Server) handlePing(p Peer, ping *payload.Ping) error { - return p.WriteMsg(NewMessage(s.Net, CMDPong, payload.NewPing(s.id, s.chain.BlockHeight()))) + return p.EnqueueMessage(NewMessage(s.Net, CMDPong, payload.NewPing(s.id, s.chain.BlockHeight()))) } // handlePing processes pong request. @@ -400,43 +400,49 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { } } if len(reqHashes) > 0 { - payload := payload.NewInventory(inv.Type, reqHashes) - return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) + msg := NewMessage(s.Net, CMDGetData, payload.NewInventory(inv.Type, reqHashes)) + pkt, err := msg.Bytes() + if err != nil { + return err + } + if inv.Type == payload.ConsensusType { + return p.EnqueueHPPacket(pkt) + } + return p.EnqueuePacket(pkt) } return nil } // handleInvCmd processes the received inventory. func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { - switch inv.Type { - case payload.TXType: - for _, hash := range inv.Hashes { + for _, hash := range inv.Hashes { + var msg *Message + + switch inv.Type { + case payload.TXType: tx, _, err := s.chain.GetTransaction(hash) if err == nil { - err = p.WriteMsg(NewMessage(s.Net, CMDTX, tx)) - if err != nil { - return err - } - + msg = NewMessage(s.Net, CMDTX, tx) } - } - case payload.BlockType: - for _, hash := range inv.Hashes { + case payload.BlockType: b, err := s.chain.GetBlock(hash) if err == nil { - err = p.WriteMsg(NewMessage(s.Net, CMDBlock, b)) - if err != nil { - return err - } + msg = NewMessage(s.Net, CMDBlock, b) + } + case payload.ConsensusType: + if cp := s.consensus.GetPayload(hash); cp != nil { + msg = NewMessage(s.Net, CMDConsensus, cp) } } - case payload.ConsensusType: - for _, hash := range inv.Hashes { - if cp := s.consensus.GetPayload(hash); cp != nil { - if err := p.WriteMsg(NewMessage(s.Net, CMDConsensus, cp)); err != nil { - return err - } + if msg != nil { + pkt, err := msg.Bytes() + if err != nil { + return err } + if inv.Type == payload.ConsensusType { + return p.EnqueueHPPacket(pkt) + } + return p.EnqueuePacket(pkt) } } return nil @@ -468,7 +474,8 @@ func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { return nil } payload := payload.NewInventory(payload.BlockType, blockHashes) - return p.WriteMsg(NewMessage(s.Net, CMDInv, payload)) + msg := NewMessage(s.Net, CMDInv, payload) + return p.EnqueueMessage(msg) } // handleGetHeadersCmd processes the getheaders request. @@ -497,7 +504,8 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { if len(resp.Hdrs) == 0 { return nil } - return p.WriteMsg(NewMessage(s.Net, CMDHeaders, &resp)) + msg := NewMessage(s.Net, CMDHeaders, &resp) + return p.EnqueueMessage(msg) } // handleConsensusCmd processes received consensus payload. @@ -538,7 +546,7 @@ func (s *Server) handleGetAddrCmd(p Peer) error { netaddr, _ := net.ResolveTCPAddr("tcp", addr) alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts) } - return p.WriteMsg(NewMessage(s.Net, CMDAddr, alist)) + return p.EnqueueMessage(NewMessage(s.Net, CMDAddr, alist)) } // requestHeaders sends a getheaders message to the peer. @@ -546,7 +554,7 @@ func (s *Server) handleGetAddrCmd(p Peer) error { func (s *Server) requestHeaders(p Peer) error { start := []util.Uint256{s.chain.CurrentHeaderHash()} payload := payload.NewGetBlocks(start, util.Uint256{}) - return p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload)) + return p.EnqueueMessage(NewMessage(s.Net, CMDGetHeaders, payload)) } // requestBlocks sends a getdata message to the peer @@ -565,7 +573,7 @@ func (s *Server) requestBlocks(p Peer) error { } if len(hashes) > 0 { payload := payload.NewInventory(payload.BlockType, hashes) - return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) + return p.EnqueueMessage(NewMessage(s.Net, CMDGetData, payload)) } else if s.chain.HeaderHeight() < p.Version().StartHeight { return s.requestHeaders(p) } @@ -667,7 +675,7 @@ func (s *Server) relayInventoryCmd(cmd CommandType, t payload.InventoryType, has continue } // Who cares about these messages anyway? - _ = peer.WriteMsg(msg) + _ = peer.EnqueueMessage(msg) } } diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 5be9c8622..62ceb417d 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -19,6 +19,9 @@ const ( versionReceived verAckSent verAckReceived + + requestQueueSize = 32 + hpRequestQueueSize = 4 ) var ( @@ -41,7 +44,9 @@ type TCPPeer struct { finale sync.Once handShake handShakeStage - done chan struct{} + done chan struct{} + sendQ chan []byte + hpSendQ chan []byte wg sync.WaitGroup @@ -52,29 +57,50 @@ type TCPPeer struct { // NewTCPPeer returns a TCPPeer structure based on the given connection. func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { return &TCPPeer{ - conn: conn, - server: s, - done: make(chan struct{}), + conn: conn, + server: s, + done: make(chan struct{}), + sendQ: make(chan []byte, requestQueueSize), + hpSendQ: make(chan []byte, hpRequestQueueSize), } } -// WriteMsg implements the Peer interface. This will write/encode the message -// to the underlying connection, this only works for messages other than Version -// or VerAck. -func (p *TCPPeer) WriteMsg(msg *Message) error { +// EnqueuePacket implements the Peer interface. +func (p *TCPPeer) EnqueuePacket(msg []byte) error { if !p.Handshaked() { return errStateMismatch } - return p.writeMsg(msg) + p.sendQ <- msg + return nil +} + +// EnqueueMessage is a temporary wrapper that sends a message via +// EnqueuePacket if there is no error in serializing it. +func (p *TCPPeer) EnqueueMessage(msg *Message) error { + b, err := msg.Bytes() + if err != nil { + return err + } + return p.EnqueuePacket(b) +} + +// EnqueueHPPacket implements the Peer interface. It the peer is not yet +// handshaked it's a noop. +func (p *TCPPeer) EnqueueHPPacket(msg []byte) error { + if !p.Handshaked() { + return errStateMismatch + } + p.hpSendQ <- msg + return nil } func (p *TCPPeer) writeMsg(msg *Message) error { - w := io.NewBufBinWriter() - if err := msg.Encode(w.BinWriter); err != nil { + b, err := msg.Bytes() + if err != nil { return err } - _, err := p.conn.Write(w.Bytes()) + _, err = p.conn.Write(b) return err } @@ -86,6 +112,7 @@ func (p *TCPPeer) handleConn() { p.server.register <- p + go p.handleQueues() // When a new peer is connected we send out our version immediately. err = p.server.sendVersion(p) if err == nil { @@ -108,6 +135,40 @@ func (p *TCPPeer) handleConn() { p.Disconnect(err) } +// handleQueues is a goroutine that is started automatically to handle +// send queues. +func (p *TCPPeer) handleQueues() { + var err error + + for { + var msg []byte + + // This one is to give priority to the hp queue + select { + case <-p.done: + return + case msg = <-p.hpSendQ: + default: + } + + // If there is no message in the hp queue, block until one + // appears in any of the queues. + if msg == nil { + select { + case <-p.done: + return + case msg = <-p.hpSendQ: + case msg = <-p.sendQ: + } + } + _, err = p.conn.Write(msg) + if err != nil { + break + } + } + p.Disconnect(err) +} + // StartProtocol starts a long running background loop that interacts // every ProtoTickInterval with the peer. It's only good to run after the // handshake. @@ -136,7 +197,12 @@ func (p *TCPPeer) StartProtocol() { case <-p.done: return case m := <-p.server.addrReq: - err = p.WriteMsg(m) + var pkt []byte + + pkt, err = m.Bytes() + if err == nil { + err = p.EnqueueHPPacket(pkt) + } case <-timer.C: // Try to sync in headers and block with the peer if his block height is higher then ours. if p.LastBlockIndex() > p.server.chain.BlockHeight() { @@ -153,7 +219,7 @@ func (p *TCPPeer) StartProtocol() { diff := uint32(time.Now().UTC().Unix()) - block.Timestamp if diff > uint32(p.server.PingInterval/time.Second) { p.UpdatePingSent(p.GetPingSent() + 1) - err = p.WriteMsg(NewMessage(p.server.Net, CMDPing, payload.NewPing(p.server.id, p.server.chain.HeaderHeight()))) + err = p.EnqueueMessage(NewMessage(p.server.Net, CMDPing, payload.NewPing(p.server.id, p.server.chain.HeaderHeight()))) } } } diff --git a/pkg/network/tcp_peer_test.go b/pkg/network/tcp_peer_test.go index c6c1c1cec..691e22b6d 100644 --- a/pkg/network/tcp_peer_test.go +++ b/pkg/network/tcp_peer_test.go @@ -30,8 +30,8 @@ func TestPeerHandshake(t *testing.T) { require.Equal(t, false, tcpC.Handshaked()) // No ordinary messages can be written. - require.Error(t, tcpS.WriteMsg(&Message{})) - require.Error(t, tcpC.WriteMsg(&Message{})) + require.Error(t, tcpS.EnqueueMessage(&Message{})) + require.Error(t, tcpC.EnqueueMessage(&Message{})) // Try to mess with VersionAck on both client and server, it should fail. require.Error(t, tcpS.SendVersionAck(&Message{})) @@ -80,6 +80,6 @@ func TestPeerHandshake(t *testing.T) { require.Error(t, tcpS.SendVersionAck(&Message{})) // Now regular messaging can proceed. - require.NoError(t, tcpS.WriteMsg(&Message{})) - require.NoError(t, tcpC.WriteMsg(&Message{})) + require.NoError(t, tcpS.EnqueueMessage(&Message{})) + require.NoError(t, tcpC.EnqueueMessage(&Message{})) }