diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index bf5361f35..f41685f78 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -15,6 +15,7 @@ import ( "github.com/CityOfZion/neo-go/pkg/core/storage" "github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/crypto/keys" + "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/network/payload" "github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/vm" @@ -177,14 +178,27 @@ func (p *localPeer) RemoteAddr() net.Addr { func (p *localPeer) PeerAddr() net.Addr { return &p.netaddr } +func (p *localPeer) StartProtocol() {} func (p *localPeer) Disconnect(err error) {} -func (p *localPeer) WriteMsg(msg *Message) error { - p.messageHandler(p.t, msg) - return nil + +func (p *localPeer) EnqueueMessage(msg *Message) error { + b, err := msg.Bytes() + if err != nil { + return err + } + return p.EnqueuePacket(b) } -func (p *localPeer) Done() chan error { - done := make(chan error) - return done +func (p *localPeer) EnqueuePacket(m []byte) error { + return p.EnqueueHPPacket(m) +} +func (p *localPeer) EnqueueHPPacket(m []byte) error { + msg := &Message{} + r := io.NewBinReaderFromBuf(m) + err := msg.Decode(r) + if err == nil { + p.messageHandler(p.t, msg) + } + return nil } func (p *localPeer) Version() *payload.Version { return p.version @@ -200,10 +214,12 @@ func (p *localPeer) HandleVersion(v *payload.Version) error { return nil } func (p *localPeer) SendVersion(m *Message) error { - return p.WriteMsg(m) + _ = p.EnqueueMessage(m) + return nil } func (p *localPeer) SendVersionAck(m *Message) error { - return p.WriteMsg(m) + _ = p.EnqueueMessage(m) + return nil } func (p *localPeer) HandleVersionAck() error { p.handshaked = true diff --git a/pkg/network/message.go b/pkg/network/message.go index fbe025fc5..181b52cd5 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -226,6 +226,18 @@ func (m *Message) Encode(br *io.BinWriter) error { return nil } +// Bytes serializes a Message into the new allocated buffer and returns it. +func (m *Message) Bytes() ([]byte, error) { + w := io.NewBufBinWriter() + if err := m.Encode(w.BinWriter); err != nil { + return nil, err + } + if w.Err != nil { + return nil, w.Err + } + return w.Bytes(), nil +} + // convert a command (string) to a byte slice filled with 0 bytes till // size 12. func cmdToByteArray(cmd CommandType) [cmdSize]byte { diff --git a/pkg/network/peer.go b/pkg/network/peer.go index 2562153fd..3fe9cb23d 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -18,14 +18,31 @@ type Peer interface { // before that it returns the same address as RemoteAddr. PeerAddr() net.Addr Disconnect(error) - WriteMsg(msg *Message) error - Done() chan error + + // EnqueueMessage is a temporary wrapper that sends a message via + // EnqueuePacket if there is no error in serializing it. + EnqueueMessage(*Message) error + + // EnqueuePacket is a blocking packet enqueuer, it doesn't return until + // it puts given packet into the queue. It accepts a slice of bytes that + // can be shared with other queues (so that message marshalling can be + // done once for all peers). Does nothing is the peer is not yet + // completed handshaking. + EnqueuePacket([]byte) error + + // EnqueueHPPacket is a blocking high priority packet enqueuer, it + // doesn't return until it puts given packet into the high-priority + // queue. + EnqueueHPPacket([]byte) error Version() *payload.Version LastBlockIndex() uint32 UpdateLastBlockIndex(lbIndex uint32) Handshaked() bool SendVersion(*Message) error SendVersionAck(*Message) error + // StartProtocol is a goroutine to be run after the handshake. It + // implements basic peer-related protocol handling. + StartProtocol() HandleVersion(*payload.Version) error HandleVersionAck() error GetPingSent() int diff --git a/pkg/network/server.go b/pkg/network/server.go index e242295f7..cf11d1df6 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -305,71 +305,6 @@ func (s *Server) HandshakedPeersCount() int { return count } -// startProtocol starts a long running background loop that interacts -// every ProtoTickInterval with the peer. -func (s *Server) startProtocol(p Peer) { - var err error - - s.log.Info("started protocol", - zap.Stringer("addr", p.RemoteAddr()), - zap.ByteString("userAgent", p.Version().UserAgent), - zap.Uint32("startHeight", p.Version().StartHeight), - zap.Uint32("id", p.Version().Nonce)) - - s.discovery.RegisterGoodAddr(p.PeerAddr().String()) - if s.chain.HeaderHeight() < p.LastBlockIndex() { - err = s.requestHeaders(p) - if err != nil { - p.Disconnect(err) - return - } - } - - timer := time.NewTimer(s.ProtoTickInterval) - pingTimer := time.NewTimer(s.PingTimeout) - for { - select { - case err = <-p.Done(): - // time to stop - case m := <-s.addrReq: - err = p.WriteMsg(m) - case <-timer.C: - // Try to sync in headers and block with the peer if his block height is higher then ours. - if p.LastBlockIndex() > s.chain.BlockHeight() { - err = s.requestBlocks(p) - } - if err == nil { - timer.Reset(s.ProtoTickInterval) - } - if s.chain.HeaderHeight() >= p.LastBlockIndex() { - block, errGetBlock := s.chain.GetBlock(s.chain.CurrentBlockHash()) - if errGetBlock != nil { - err = errGetBlock - } else { - diff := uint32(time.Now().UTC().Unix()) - block.Timestamp - if diff > uint32(s.PingInterval/time.Second) { - p.UpdatePingSent(p.GetPingSent() + 1) - err = p.WriteMsg(NewMessage(s.Net, CMDPing, payload.NewPing(s.id, s.chain.HeaderHeight()))) - } - } - } - case <-pingTimer.C: - if p.GetPingSent() > defaultPingLimit { - err = errors.New("ping/pong timeout") - } else { - pingTimer.Reset(s.PingTimeout) - p.UpdatePingSent(0) - } - } - if err != nil { - s.unregister <- peerDrop{p, err} - timer.Stop() - p.Disconnect(err) - return - } - } -} - // When a peer connects to the server, we will send our version immediately. func (s *Server) sendVersion(p Peer) error { payload := payload.NewVersion( @@ -429,7 +364,7 @@ func (s *Server) handleBlockCmd(p Peer, block *block.Block) error { // handlePing processes ping request. func (s *Server) handlePing(p Peer, ping *payload.Ping) error { - return p.WriteMsg(NewMessage(s.Net, CMDPong, payload.NewPing(s.id, s.chain.BlockHeight()))) + return p.EnqueueMessage(NewMessage(s.Net, CMDPong, payload.NewPing(s.id, s.chain.BlockHeight()))) } // handlePing processes pong request. @@ -465,43 +400,49 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { } } if len(reqHashes) > 0 { - payload := payload.NewInventory(inv.Type, reqHashes) - return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) + msg := NewMessage(s.Net, CMDGetData, payload.NewInventory(inv.Type, reqHashes)) + pkt, err := msg.Bytes() + if err != nil { + return err + } + if inv.Type == payload.ConsensusType { + return p.EnqueueHPPacket(pkt) + } + return p.EnqueuePacket(pkt) } return nil } // handleInvCmd processes the received inventory. func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { - switch inv.Type { - case payload.TXType: - for _, hash := range inv.Hashes { + for _, hash := range inv.Hashes { + var msg *Message + + switch inv.Type { + case payload.TXType: tx, _, err := s.chain.GetTransaction(hash) if err == nil { - err = p.WriteMsg(NewMessage(s.Net, CMDTX, tx)) - if err != nil { - return err - } - + msg = NewMessage(s.Net, CMDTX, tx) } - } - case payload.BlockType: - for _, hash := range inv.Hashes { + case payload.BlockType: b, err := s.chain.GetBlock(hash) if err == nil { - err = p.WriteMsg(NewMessage(s.Net, CMDBlock, b)) - if err != nil { - return err - } + msg = NewMessage(s.Net, CMDBlock, b) + } + case payload.ConsensusType: + if cp := s.consensus.GetPayload(hash); cp != nil { + msg = NewMessage(s.Net, CMDConsensus, cp) } } - case payload.ConsensusType: - for _, hash := range inv.Hashes { - if cp := s.consensus.GetPayload(hash); cp != nil { - if err := p.WriteMsg(NewMessage(s.Net, CMDConsensus, cp)); err != nil { - return err - } + if msg != nil { + pkt, err := msg.Bytes() + if err != nil { + return err } + if inv.Type == payload.ConsensusType { + return p.EnqueueHPPacket(pkt) + } + return p.EnqueuePacket(pkt) } } return nil @@ -533,7 +474,8 @@ func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { return nil } payload := payload.NewInventory(payload.BlockType, blockHashes) - return p.WriteMsg(NewMessage(s.Net, CMDInv, payload)) + msg := NewMessage(s.Net, CMDInv, payload) + return p.EnqueueMessage(msg) } // handleGetHeadersCmd processes the getheaders request. @@ -562,7 +504,8 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { if len(resp.Hdrs) == 0 { return nil } - return p.WriteMsg(NewMessage(s.Net, CMDHeaders, &resp)) + msg := NewMessage(s.Net, CMDHeaders, &resp) + return p.EnqueueMessage(msg) } // handleConsensusCmd processes received consensus payload. @@ -603,7 +546,7 @@ func (s *Server) handleGetAddrCmd(p Peer) error { netaddr, _ := net.ResolveTCPAddr("tcp", addr) alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts) } - return p.WriteMsg(NewMessage(s.Net, CMDAddr, alist)) + return p.EnqueueMessage(NewMessage(s.Net, CMDAddr, alist)) } // requestHeaders sends a getheaders message to the peer. @@ -611,7 +554,7 @@ func (s *Server) handleGetAddrCmd(p Peer) error { func (s *Server) requestHeaders(p Peer) error { start := []util.Uint256{s.chain.CurrentHeaderHash()} payload := payload.NewGetBlocks(start, util.Uint256{}) - return p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload)) + return p.EnqueueMessage(NewMessage(s.Net, CMDGetHeaders, payload)) } // requestBlocks sends a getdata message to the peer @@ -630,7 +573,7 @@ func (s *Server) requestBlocks(p Peer) error { } if len(hashes) > 0 { payload := payload.NewInventory(payload.BlockType, hashes) - return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) + return p.EnqueueMessage(NewMessage(s.Net, CMDGetData, payload)) } else if s.chain.HeaderHeight() < p.Version().StartHeight { return s.requestHeaders(p) } @@ -701,7 +644,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { if err != nil { return err } - go s.startProtocol(peer) + go peer.StartProtocol() s.tryStartConsensus() default: @@ -732,7 +675,7 @@ func (s *Server) relayInventoryCmd(cmd CommandType, t payload.InventoryType, has continue } // Who cares about these messages anyway? - _ = peer.WriteMsg(msg) + _ = peer.EnqueueMessage(msg) } } diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index b1dd94a82..62ceb417d 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -5,9 +5,11 @@ import ( "net" "strconv" "sync" + "time" "github.com/CityOfZion/neo-go/pkg/io" "github.com/CityOfZion/neo-go/pkg/network/payload" + "go.uber.org/zap" ) type handShakeStage uint8 @@ -17,6 +19,9 @@ const ( versionReceived verAckSent verAckReceived + + requestQueueSize = 32 + hpRequestQueueSize = 4 ) var ( @@ -28,16 +33,20 @@ var ( type TCPPeer struct { // underlying TCP connection. conn net.Conn - + // The server this peer belongs to. + server *Server // The version of the peer. version *payload.Version // Index of the last block. lastBlockIndex uint32 lock sync.RWMutex + finale sync.Once handShake handShakeStage - done chan error + done chan struct{} + sendQ chan []byte + hpSendQ chan []byte wg sync.WaitGroup @@ -46,36 +55,187 @@ type TCPPeer struct { } // NewTCPPeer returns a TCPPeer structure based on the given connection. -func NewTCPPeer(conn net.Conn) *TCPPeer { +func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { return &TCPPeer{ - conn: conn, - done: make(chan error, 1), + conn: conn, + server: s, + done: make(chan struct{}), + sendQ: make(chan []byte, requestQueueSize), + hpSendQ: make(chan []byte, hpRequestQueueSize), } } -// WriteMsg implements the Peer interface. This will write/encode the message -// to the underlying connection, this only works for messages other than Version -// or VerAck. -func (p *TCPPeer) WriteMsg(msg *Message) error { +// EnqueuePacket implements the Peer interface. +func (p *TCPPeer) EnqueuePacket(msg []byte) error { if !p.Handshaked() { return errStateMismatch } - return p.writeMsg(msg) + p.sendQ <- msg + return nil +} + +// EnqueueMessage is a temporary wrapper that sends a message via +// EnqueuePacket if there is no error in serializing it. +func (p *TCPPeer) EnqueueMessage(msg *Message) error { + b, err := msg.Bytes() + if err != nil { + return err + } + return p.EnqueuePacket(b) +} + +// EnqueueHPPacket implements the Peer interface. It the peer is not yet +// handshaked it's a noop. +func (p *TCPPeer) EnqueueHPPacket(msg []byte) error { + if !p.Handshaked() { + return errStateMismatch + } + p.hpSendQ <- msg + return nil } func (p *TCPPeer) writeMsg(msg *Message) error { - select { - case err := <-p.done: + b, err := msg.Bytes() + if err != nil { return err - default: - w := io.NewBufBinWriter() - if err := msg.Encode(w.BinWriter); err != nil { - return err + } + + _, err = p.conn.Write(b) + + return err +} + +// handleConn handles the read side of the connection, it should be started as +// a goroutine right after the new peer setup. +func (p *TCPPeer) handleConn() { + var err error + + p.server.register <- p + + go p.handleQueues() + // When a new peer is connected we send out our version immediately. + err = p.server.sendVersion(p) + if err == nil { + r := io.NewBinReaderFromIO(p.conn) + for { + msg := &Message{} + err = msg.Decode(r) + + if err == payload.ErrTooManyHeaders { + p.server.log.Warn("not all headers were processed") + r.Err = nil + } else if err != nil { + break + } + if err = p.server.handleMessage(p, msg); err != nil { + break + } + } + } + p.Disconnect(err) +} + +// handleQueues is a goroutine that is started automatically to handle +// send queues. +func (p *TCPPeer) handleQueues() { + var err error + + for { + var msg []byte + + // This one is to give priority to the hp queue + select { + case <-p.done: + return + case msg = <-p.hpSendQ: + default: } - _, err := p.conn.Write(w.Bytes()) + // If there is no message in the hp queue, block until one + // appears in any of the queues. + if msg == nil { + select { + case <-p.done: + return + case msg = <-p.hpSendQ: + case msg = <-p.sendQ: + } + } + _, err = p.conn.Write(msg) + if err != nil { + break + } + } + p.Disconnect(err) +} - return err +// StartProtocol starts a long running background loop that interacts +// every ProtoTickInterval with the peer. It's only good to run after the +// handshake. +func (p *TCPPeer) StartProtocol() { + var err error + + p.server.log.Info("started protocol", + zap.Stringer("addr", p.RemoteAddr()), + zap.ByteString("userAgent", p.Version().UserAgent), + zap.Uint32("startHeight", p.Version().StartHeight), + zap.Uint32("id", p.Version().Nonce)) + + p.server.discovery.RegisterGoodAddr(p.PeerAddr().String()) + if p.server.chain.HeaderHeight() < p.LastBlockIndex() { + err = p.server.requestHeaders(p) + if err != nil { + p.Disconnect(err) + return + } + } + + timer := time.NewTimer(p.server.ProtoTickInterval) + pingTimer := time.NewTimer(p.server.PingTimeout) + for { + select { + case <-p.done: + return + case m := <-p.server.addrReq: + var pkt []byte + + pkt, err = m.Bytes() + if err == nil { + err = p.EnqueueHPPacket(pkt) + } + case <-timer.C: + // Try to sync in headers and block with the peer if his block height is higher then ours. + if p.LastBlockIndex() > p.server.chain.BlockHeight() { + err = p.server.requestBlocks(p) + } + if err == nil { + timer.Reset(p.server.ProtoTickInterval) + } + if p.server.chain.HeaderHeight() >= p.LastBlockIndex() { + block, errGetBlock := p.server.chain.GetBlock(p.server.chain.CurrentBlockHash()) + if errGetBlock != nil { + err = errGetBlock + } else { + diff := uint32(time.Now().UTC().Unix()) - block.Timestamp + if diff > uint32(p.server.PingInterval/time.Second) { + p.UpdatePingSent(p.GetPingSent() + 1) + err = p.EnqueueMessage(NewMessage(p.server.Net, CMDPing, payload.NewPing(p.server.id, p.server.chain.HeaderHeight()))) + } + } + } + case <-pingTimer.C: + if p.GetPingSent() > defaultPingLimit { + err = errors.New("ping/pong timeout") + } else { + pingTimer.Reset(p.server.PingTimeout) + p.UpdatePingSent(0) + } + } + if err != nil { + timer.Stop() + p.Disconnect(err) + return + } } } @@ -175,22 +335,13 @@ func (p *TCPPeer) PeerAddr() net.Addr { return tcpAddr } -// Done implements the Peer interface and notifies -// all other resources operating on it that this peer -// is no longer running. -func (p *TCPPeer) Done() chan error { - return p.done -} - // Disconnect will fill the peer's done channel with the given error. func (p *TCPPeer) Disconnect(err error) { - p.conn.Close() - select { - case p.done <- err: - // one message to the queue - default: - // the other side may already be gone, it's OK - } + p.finale.Do(func() { + p.server.unregister <- peerDrop{p, err} + p.conn.Close() + close(p.done) + }) } // Version implements the Peer interface. diff --git a/pkg/network/tcp_peer_test.go b/pkg/network/tcp_peer_test.go index 5223bf64b..691e22b6d 100644 --- a/pkg/network/tcp_peer_test.go +++ b/pkg/network/tcp_peer_test.go @@ -18,8 +18,8 @@ func connReadStub(conn net.Conn) { func TestPeerHandshake(t *testing.T) { server, client := net.Pipe() - tcpS := NewTCPPeer(server) - tcpC := NewTCPPeer(client) + tcpS := NewTCPPeer(server, nil) + tcpC := NewTCPPeer(client, nil) // Something should read things written into the pipe. go connReadStub(tcpS.conn) @@ -30,8 +30,8 @@ func TestPeerHandshake(t *testing.T) { require.Equal(t, false, tcpC.Handshaked()) // No ordinary messages can be written. - require.Error(t, tcpS.WriteMsg(&Message{})) - require.Error(t, tcpC.WriteMsg(&Message{})) + require.Error(t, tcpS.EnqueueMessage(&Message{})) + require.Error(t, tcpC.EnqueueMessage(&Message{})) // Try to mess with VersionAck on both client and server, it should fail. require.Error(t, tcpS.SendVersionAck(&Message{})) @@ -80,6 +80,6 @@ func TestPeerHandshake(t *testing.T) { require.Error(t, tcpS.SendVersionAck(&Message{})) // Now regular messaging can proceed. - require.NoError(t, tcpS.WriteMsg(&Message{})) - require.NoError(t, tcpC.WriteMsg(&Message{})) + require.NoError(t, tcpS.EnqueueMessage(&Message{})) + require.NoError(t, tcpC.EnqueueMessage(&Message{})) } diff --git a/pkg/network/tcp_transport.go b/pkg/network/tcp_transport.go index 58896156e..a9b6feb3b 100644 --- a/pkg/network/tcp_transport.go +++ b/pkg/network/tcp_transport.go @@ -5,8 +5,6 @@ import ( "regexp" "time" - "github.com/CityOfZion/neo-go/pkg/io" - "github.com/CityOfZion/neo-go/pkg/network/payload" "go.uber.org/zap" ) @@ -36,7 +34,8 @@ func (t *TCPTransport) Dial(addr string, timeout time.Duration) error { if err != nil { return err } - go t.handleConn(conn) + p := NewTCPPeer(conn, t.server) + go p.handleConn() return nil } @@ -59,7 +58,8 @@ func (t *TCPTransport) Accept() { } continue } - go t.handleConn(conn) + p := NewTCPPeer(conn, t.server) + go p.handleConn() } } @@ -73,37 +73,6 @@ func (t *TCPTransport) isCloseError(err error) bool { return false } -func (t *TCPTransport) handleConn(conn net.Conn) { - var ( - p = NewTCPPeer(conn) - err error - ) - - t.server.register <- p - - // When a new peer is connected we send out our version immediately. - err = t.server.sendVersion(p) - if err == nil { - r := io.NewBinReaderFromIO(p.conn) - for { - msg := &Message{} - err = msg.Decode(r) - - if err == payload.ErrTooManyHeaders { - t.log.Warn("not all headers were processed") - r.Err = nil - } else if err != nil { - break - } - if err = t.server.handleMessage(p, msg); err != nil { - break - } - } - } - t.server.unregister <- peerDrop{p, err} - p.Disconnect(err) -} - // Close implements the Transporter interface. func (t *TCPTransport) Close() { t.listener.Close()