diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index d06b12847..a44ae9f00 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -2,6 +2,7 @@ package network import ( "context" + "errors" "fmt" "net" "sync" @@ -115,17 +116,25 @@ func (p *localPeer) Disconnect(err error) { } func (p *localPeer) BroadcastPacket(_ context.Context, m []byte) error { + if len(m) == 0 { + return errors.New("empty msg") + } msg := &Message{} r := io.NewBinReaderFromBuf(m) - err := msg.Decode(r) - if err == nil { - p.messageHandler(p.t, msg) + for r.Len() > 0 { + err := msg.Decode(r) + if err == nil { + p.messageHandler(p.t, msg) + } } return nil } func (p *localPeer) EnqueueP2PMessage(msg *Message) error { return p.EnqueueHPMessage(msg) } +func (p *localPeer) EnqueueP2PPacket(m []byte) error { + return p.BroadcastPacket(context.TODO(), m) +} func (p *localPeer) BroadcastHPPacket(ctx context.Context, m []byte) error { return p.BroadcastPacket(ctx, m) } @@ -133,6 +142,9 @@ func (p *localPeer) EnqueueHPMessage(msg *Message) error { p.messageHandler(p.t, msg) return nil } +func (p *localPeer) EnqueueHPPacket(m []byte) error { + return p.BroadcastPacket(context.TODO(), m) +} func (p *localPeer) Version() *payload.Version { return p.version } diff --git a/pkg/network/peer.go b/pkg/network/peer.go index 9854165d5..6dfcf16e0 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -39,10 +39,16 @@ type Peer interface { // (handled by BroadcastPacket) but less important than high-priority // messages (handled by EnqueueHPMessage). EnqueueP2PMessage(*Message) error + // EnqueueP2PPacket is similar to EnqueueP2PMessage, but accepts a slice of + // message(s) bytes. + EnqueueP2PPacket([]byte) error // EnqueueHPMessage is similar to EnqueueP2PMessage, but uses a high-priority // queue. EnqueueHPMessage(*Message) error + // EnqueueHPPacket is similar to EnqueueHPMessage, but accepts a slice of + // message(s) bytes. + EnqueueHPPacket([]byte) error Version() *payload.Version LastBlockIndex() uint32 Handshaked() bool diff --git a/pkg/network/server.go b/pkg/network/server.go index 57aa06ffe..49bd63bc0 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -808,7 +808,15 @@ func (s *Server) handleMempoolCmd(p Peer) error { // handleInvCmd processes the received inventory. func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { - var notFound []util.Uint256 + var ( + err error + notFound []util.Uint256 + reply = io.NewBufBinWriter() + send = p.EnqueueP2PPacket + ) + if inv.Type == payload.ExtensibleType { + send = p.EnqueueHPPacket + } for _, hash := range inv.Hashes { var msg *Message @@ -839,19 +847,37 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { } } if msg != nil { - var err error - if inv.Type == payload.ExtensibleType { - err = p.EnqueueHPMessage(msg) - } else { - err = p.EnqueueP2PMessage(msg) - } + err = addMessageToPacket(reply, msg, send) if err != nil { return err } } } if len(notFound) != 0 { - return p.EnqueueP2PMessage(NewMessage(CMDNotFound, payload.NewInventory(inv.Type, notFound))) + err = addMessageToPacket(reply, NewMessage(CMDNotFound, payload.NewInventory(inv.Type, notFound)), send) + if err != nil { + return err + } + } + if reply.Len() == 0 { + return nil + } + return send(reply.Bytes()) +} + +// addMessageToPacket serializes given message into the given buffer and sends whole +// batch if it exceeds MaxSize/2 memory limit (to prevent DoS). +func addMessageToPacket(batch *io.BufBinWriter, msg *Message, send func([]byte) error) error { + err := msg.Encode(batch.BinWriter) + if err != nil { + return err + } + if batch.Len() > payload.MaxSize/2 { + err = send(batch.Bytes()) + if err != nil { + return err + } + batch.Reset() } return nil } @@ -945,6 +971,7 @@ func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { // handleGetBlockByIndexCmd processes the getblockbyindex request. func (s *Server) handleGetBlockByIndexCmd(p Peer, gbd *payload.GetBlockByIndex) error { + var reply = io.NewBufBinWriter() count := gbd.Count if gbd.Count < 0 || gbd.Count > payload.MaxHashesCount { count = payload.MaxHashesCount @@ -958,12 +985,15 @@ func (s *Server) handleGetBlockByIndexCmd(p Peer, gbd *payload.GetBlockByIndex) if err != nil { break } - msg := NewMessage(CMDBlock, b) - if err = p.EnqueueP2PMessage(msg); err != nil { + err = addMessageToPacket(reply, NewMessage(CMDBlock, b), p.EnqueueP2PPacket) + if err != nil { return err } } - return nil + if reply.Len() == 0 { + return nil + } + return p.EnqueueP2PPacket(reply.Bytes()) } // handleGetHeadersCmd processes the getheaders request. diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 9244716cc..392934615 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -128,6 +128,16 @@ func (p *TCPPeer) EnqueueHPMessage(msg *Message) error { return p.putMsgIntoQueue(p.hpSendQ, msg) } +// EnqueueP2PPacket implements the Peer interface. +func (p *TCPPeer) EnqueueP2PPacket(b []byte) error { + return p.putPacketIntoQueue(context.Background(), p.p2pSendQ, b) +} + +// EnqueueHPPacket implements the Peer interface. +func (p *TCPPeer) EnqueueHPPacket(b []byte) error { + return p.putPacketIntoQueue(context.Background(), p.hpSendQ, b) +} + func (p *TCPPeer) writeMsg(msg *Message) error { b, err := msg.Bytes() if err != nil {