mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2025-01-25 15:14:48 +00:00
4f708c037d
Fix potential memory leak with a lot of connected clients that keep requesting things from node and then disconnect.
505 lines
12 KiB
Go
505 lines
12 KiB
Go
package network
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
|
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
|
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
|
"go.uber.org/atomic"
|
|
)
|
|
|
|
// handShakeStage is a bit set that records which steps of the version
// handshake have already happened for a peer.
type handShakeStage uint8

// Handshake progress flags, one bit per step. A peer is considered
// handshaked only when all four bits are set.
const (
	versionSent handShakeStage = 1 << iota
	versionReceived
	verAckSent
	verAckReceived
)

// Capacities of the buffered channels created for every peer.
const (
	requestQueueSize   = 32
	p2pMsgQueueSize    = 16
	hpRequestQueueSize = 4
	incomingQueueSize  = 1 // Each message can be up to 32MB in size.
)
|
|
|
|
// Sentinel errors used by TCPPeer.
var (
	// errGone is returned when a packet can't be queued because the peer's
	// done channel is already closed.
	errGone = errors.New("the peer is gone already")
	// errStateMismatch is returned when a packet is queued for a peer that
	// hasn't completed the handshake.
	errStateMismatch = errors.New("tried to send protocol message before handshake completed")
	// errPingPong is the disconnect reason for a ping that wasn't answered
	// before the ping timer fired.
	errPingPong = errors.New("ping/pong timeout")
	// errUnexpectedPong is returned for a pong that has no matching ping.
	errUnexpectedPong = errors.New("pong message wasn't expected")
)
|
|
|
|
// TCPPeer represents a connected remote node in the
|
|
// network over TCP.
|
|
type TCPPeer struct {
|
|
// underlying TCP connection.
|
|
conn net.Conn
|
|
// The server this peer belongs to.
|
|
server *Server
|
|
// The version of the peer.
|
|
version *payload.Version
|
|
// Index of the last block.
|
|
lastBlockIndex uint32
|
|
// pre-handshake non-canonical connection address.
|
|
addr string
|
|
|
|
lock sync.RWMutex
|
|
finale sync.Once
|
|
handShake handShakeStage
|
|
isFullNode bool
|
|
|
|
done chan struct{}
|
|
sendQ chan []byte
|
|
p2pSendQ chan []byte
|
|
hpSendQ chan []byte
|
|
incoming chan *Message
|
|
|
|
// track outstanding getaddr requests.
|
|
getAddrSent atomic.Int32
|
|
|
|
// number of sent pings.
|
|
pingSent int
|
|
pingTimer *time.Timer
|
|
}
|
|
|
|
// NewTCPPeer returns a TCPPeer structure based on the given connection.
|
|
func NewTCPPeer(conn net.Conn, addr string, s *Server) *TCPPeer {
|
|
return &TCPPeer{
|
|
conn: conn,
|
|
server: s,
|
|
addr: addr,
|
|
done: make(chan struct{}),
|
|
sendQ: make(chan []byte, requestQueueSize),
|
|
p2pSendQ: make(chan []byte, p2pMsgQueueSize),
|
|
hpSendQ: make(chan []byte, hpRequestQueueSize),
|
|
incoming: make(chan *Message, incomingQueueSize),
|
|
}
|
|
}
|
|
|
|
// putPacketIntoQueue puts the given message into the given queue if
|
|
// the peer has done handshaking using the given context.
|
|
func (p *TCPPeer) putPacketIntoQueue(ctx context.Context, queue chan<- []byte, msg []byte) error {
|
|
if !p.Handshaked() {
|
|
return errStateMismatch
|
|
}
|
|
select {
|
|
case queue <- msg:
|
|
case <-p.done:
|
|
return errGone
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// BroadcastPacket implements the Peer interface.
|
|
func (p *TCPPeer) BroadcastPacket(ctx context.Context, msg []byte) error {
|
|
return p.putPacketIntoQueue(ctx, p.sendQ, msg)
|
|
}
|
|
|
|
// BroadcastHPPacket implements the Peer interface. It the peer is not yet
|
|
// handshaked it's a noop.
|
|
func (p *TCPPeer) BroadcastHPPacket(ctx context.Context, msg []byte) error {
|
|
return p.putPacketIntoQueue(ctx, p.hpSendQ, msg)
|
|
}
|
|
|
|
// putMessageIntoQueue serializes the given Message and puts it into given queue if
|
|
// the peer has done handshaking.
|
|
func (p *TCPPeer) putMsgIntoQueue(queue chan<- []byte, msg *Message) error {
|
|
b, err := msg.Bytes()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return p.putPacketIntoQueue(context.Background(), queue, b)
|
|
}
|
|
|
|
// EnqueueP2PMessage implements the Peer interface.
|
|
func (p *TCPPeer) EnqueueP2PMessage(msg *Message) error {
|
|
return p.putMsgIntoQueue(p.p2pSendQ, msg)
|
|
}
|
|
|
|
// EnqueueHPMessage implements the Peer interface.
|
|
func (p *TCPPeer) EnqueueHPMessage(msg *Message) error {
|
|
return p.putMsgIntoQueue(p.hpSendQ, msg)
|
|
}
|
|
|
|
// EnqueueP2PPacket implements the Peer interface.
|
|
func (p *TCPPeer) EnqueueP2PPacket(b []byte) error {
|
|
return p.putPacketIntoQueue(context.Background(), p.p2pSendQ, b)
|
|
}
|
|
|
|
// EnqueueHPPacket implements the Peer interface.
|
|
func (p *TCPPeer) EnqueueHPPacket(b []byte) error {
|
|
return p.putPacketIntoQueue(context.Background(), p.hpSendQ, b)
|
|
}
|
|
|
|
func (p *TCPPeer) writeMsg(msg *Message) error {
|
|
b, err := msg.Bytes()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = p.conn.Write(b)
|
|
|
|
return err
|
|
}
|
|
|
|
// handleConn handles the read side of the connection, it should be started as
|
|
// a goroutine right after a new peer setup.
|
|
func (p *TCPPeer) handleConn() {
|
|
var err error
|
|
|
|
p.server.register <- p
|
|
|
|
go p.handleQueues()
|
|
go p.handleIncoming()
|
|
// When a new peer is connected, we send out our version immediately.
|
|
err = p.SendVersion()
|
|
if err == nil {
|
|
r := io.NewBinReaderFromIO(p.conn)
|
|
for {
|
|
msg := &Message{StateRootInHeader: p.server.config.StateRootInHeader}
|
|
err = msg.Decode(r)
|
|
|
|
if errors.Is(err, payload.ErrTooManyHeaders) {
|
|
p.server.log.Warn("not all headers were processed")
|
|
r.Err = nil
|
|
} else if err != nil {
|
|
break
|
|
}
|
|
p.incoming <- msg
|
|
}
|
|
}
|
|
p.Disconnect(err)
|
|
close(p.incoming)
|
|
}
|
|
|
|
func (p *TCPPeer) handleIncoming() {
|
|
var err error
|
|
for msg := range p.incoming {
|
|
err = p.server.handleMessage(p, msg)
|
|
if err != nil {
|
|
if p.Handshaked() {
|
|
err = fmt.Errorf("handling %s message: %w", msg.Command.String(), err)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
p.Disconnect(err)
|
|
}
|
|
|
|
// handleQueues is a goroutine that is started automatically to handle
|
|
// send queues.
|
|
func (p *TCPPeer) handleQueues() {
|
|
var err error
|
|
// p2psend queue shares its time with send queue in around
|
|
// ((p2pSkipDivisor - 1) * 2 + 1)/1 ratio, ratio because the third
|
|
// select can still choose p2psend over send.
|
|
var p2pSkipCounter uint32
|
|
const p2pSkipDivisor = 4
|
|
|
|
var writeTimeout = p.server.TimePerBlock
|
|
for {
|
|
var msg []byte
|
|
|
|
// This one is to give priority to the hp queue
|
|
select {
|
|
case <-p.done:
|
|
return
|
|
case msg = <-p.hpSendQ:
|
|
default:
|
|
}
|
|
|
|
// Skip this select every p2pSkipDivisor iteration.
|
|
if msg == nil && p2pSkipCounter%p2pSkipDivisor != 0 {
|
|
// Then look at the p2p queue.
|
|
select {
|
|
case <-p.done:
|
|
return
|
|
case msg = <-p.hpSendQ:
|
|
case msg = <-p.p2pSendQ:
|
|
default:
|
|
}
|
|
}
|
|
// If there is no message in HP or P2P queues, block until one
|
|
// appears in any of the queues.
|
|
if msg == nil {
|
|
select {
|
|
case <-p.done:
|
|
return
|
|
case msg = <-p.hpSendQ:
|
|
case msg = <-p.p2pSendQ:
|
|
case msg = <-p.sendQ:
|
|
}
|
|
}
|
|
err = p.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
|
if err != nil {
|
|
break
|
|
}
|
|
_, err = p.conn.Write(msg)
|
|
if err != nil {
|
|
break
|
|
}
|
|
p2pSkipCounter++
|
|
}
|
|
p.Disconnect(err)
|
|
drainloop:
|
|
for {
|
|
select {
|
|
case <-p.hpSendQ:
|
|
case <-p.p2pSendQ:
|
|
case <-p.sendQ:
|
|
default:
|
|
break drainloop
|
|
}
|
|
}
|
|
}
|
|
|
|
// StartProtocol starts a long running background loop that interacts
|
|
// every ProtoTickInterval with the peer. It's only good to run after the
|
|
// handshake.
|
|
func (p *TCPPeer) StartProtocol() {
|
|
var err error
|
|
|
|
p.server.handshake <- p
|
|
|
|
err = p.server.requestBlocksOrHeaders(p)
|
|
if err != nil {
|
|
p.Disconnect(err)
|
|
return
|
|
}
|
|
|
|
timer := time.NewTimer(p.server.ProtoTickInterval)
|
|
for {
|
|
select {
|
|
case <-p.done:
|
|
return
|
|
case <-timer.C:
|
|
// Try to sync in headers and block with the peer if his block height is higher than ours.
|
|
err = p.server.requestBlocksOrHeaders(p)
|
|
if err == nil {
|
|
timer.Reset(p.server.ProtoTickInterval)
|
|
}
|
|
}
|
|
if err != nil {
|
|
timer.Stop()
|
|
p.Disconnect(err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Handshaked returns status of the handshake, whether it's completed or not.
|
|
func (p *TCPPeer) Handshaked() bool {
|
|
p.lock.RLock()
|
|
defer p.lock.RUnlock()
|
|
return p.handshaked()
|
|
}
|
|
|
|
// handshaked is internal unlocked version of Handshaked().
|
|
func (p *TCPPeer) handshaked() bool {
|
|
return p.handShake == (verAckReceived | verAckSent | versionReceived | versionSent)
|
|
}
|
|
|
|
// IsFullNode returns whether the node has full capability or TCP/WS only.
|
|
func (p *TCPPeer) IsFullNode() bool {
|
|
p.lock.RLock()
|
|
defer p.lock.RUnlock()
|
|
return p.handshaked() && p.isFullNode
|
|
}
|
|
|
|
// SendVersion checks for the handshake state and sends a message to the peer.
|
|
func (p *TCPPeer) SendVersion() error {
|
|
msg, err := p.server.getVersionMsg(p.conn.LocalAddr())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.lock.Lock()
|
|
defer p.lock.Unlock()
|
|
if p.handShake&versionSent != 0 {
|
|
return errors.New("invalid handshake: already sent Version")
|
|
}
|
|
err = p.writeMsg(msg)
|
|
if err == nil {
|
|
p.handShake |= versionSent
|
|
}
|
|
return err
|
|
}
|
|
|
|
// HandleVersion checks for the handshake state and version message contents.
|
|
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
|
|
p.lock.Lock()
|
|
defer p.lock.Unlock()
|
|
if p.handShake&versionReceived != 0 {
|
|
return errors.New("invalid handshake: already received Version")
|
|
}
|
|
p.version = version
|
|
for _, cap := range version.Capabilities {
|
|
if cap.Type == capability.FullNode {
|
|
p.isFullNode = true
|
|
p.lastBlockIndex = cap.Data.(*capability.Node).StartHeight
|
|
break
|
|
}
|
|
}
|
|
|
|
p.handShake |= versionReceived
|
|
return nil
|
|
}
|
|
|
|
// SendVersionAck checks for the handshake state and sends a message to the peer.
|
|
func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
|
p.lock.Lock()
|
|
defer p.lock.Unlock()
|
|
if p.handShake&versionReceived == 0 {
|
|
return errors.New("invalid handshake: tried to send VersionAck, but no version received yet")
|
|
}
|
|
if p.handShake&versionSent == 0 {
|
|
return errors.New("invalid handshake: tried to send VersionAck, but didn't send Version yet")
|
|
}
|
|
if p.handShake&verAckSent != 0 {
|
|
return errors.New("invalid handshake: already sent VersionAck")
|
|
}
|
|
err := p.writeMsg(msg)
|
|
if err == nil {
|
|
p.handShake |= verAckSent
|
|
}
|
|
return err
|
|
}
|
|
|
|
// HandleVersionAck checks handshake sequence correctness when VerAck message
|
|
// is received.
|
|
func (p *TCPPeer) HandleVersionAck() error {
|
|
p.lock.Lock()
|
|
defer p.lock.Unlock()
|
|
if p.handShake&versionSent == 0 {
|
|
return errors.New("invalid handshake: received VersionAck, but no version sent yet")
|
|
}
|
|
if p.handShake&versionReceived == 0 {
|
|
return errors.New("invalid handshake: received VersionAck, but no version received yet")
|
|
}
|
|
if p.handShake&verAckReceived != 0 {
|
|
return errors.New("invalid handshake: already received VersionAck")
|
|
}
|
|
p.handShake |= verAckReceived
|
|
return nil
|
|
}
|
|
|
|
// ConnectionAddr implements the Peer interface.
|
|
func (p *TCPPeer) ConnectionAddr() string {
|
|
if p.addr != "" {
|
|
return p.addr
|
|
}
|
|
return p.conn.RemoteAddr().String()
|
|
}
|
|
|
|
// RemoteAddr implements the Peer interface.
|
|
func (p *TCPPeer) RemoteAddr() net.Addr {
|
|
return p.conn.RemoteAddr()
|
|
}
|
|
|
|
// PeerAddr implements the Peer interface.
|
|
func (p *TCPPeer) PeerAddr() net.Addr {
|
|
remote := p.conn.RemoteAddr()
|
|
// The network can be non-tcp in unit tests.
|
|
if p.version == nil || remote.Network() != "tcp" {
|
|
return p.RemoteAddr()
|
|
}
|
|
host, _, err := net.SplitHostPort(remote.String())
|
|
if err != nil {
|
|
return p.RemoteAddr()
|
|
}
|
|
var port uint16
|
|
for _, cap := range p.version.Capabilities {
|
|
if cap.Type == capability.TCPServer {
|
|
port = cap.Data.(*capability.Server).Port
|
|
}
|
|
}
|
|
if port == 0 {
|
|
return p.RemoteAddr()
|
|
}
|
|
addrString := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
|
tcpAddr, err := net.ResolveTCPAddr("tcp", addrString)
|
|
if err != nil {
|
|
return p.RemoteAddr()
|
|
}
|
|
return tcpAddr
|
|
}
|
|
|
|
// Disconnect will fill the peer's done channel with the given error.
|
|
func (p *TCPPeer) Disconnect(err error) {
|
|
p.finale.Do(func() {
|
|
close(p.done)
|
|
p.conn.Close()
|
|
p.server.unregister <- peerDrop{p, err}
|
|
})
|
|
}
|
|
|
|
// Version implements the Peer interface.
|
|
func (p *TCPPeer) Version() *payload.Version {
|
|
return p.version
|
|
}
|
|
|
|
// LastBlockIndex returns the last block index.
|
|
func (p *TCPPeer) LastBlockIndex() uint32 {
|
|
p.lock.RLock()
|
|
defer p.lock.RUnlock()
|
|
return p.lastBlockIndex
|
|
}
|
|
|
|
// SetPingTimer adds an outgoing ping to the counter and sets a PingTimeout timer
|
|
// that will shut the connection down in case of no response.
|
|
func (p *TCPPeer) SetPingTimer() {
|
|
p.lock.Lock()
|
|
p.pingSent++
|
|
if p.pingTimer == nil {
|
|
p.pingTimer = time.AfterFunc(p.server.PingTimeout, func() {
|
|
p.Disconnect(errPingPong)
|
|
})
|
|
}
|
|
p.lock.Unlock()
|
|
}
|
|
|
|
// HandlePing handles a ping message received from the peer.
|
|
func (p *TCPPeer) HandlePing(ping *payload.Ping) error {
|
|
p.lock.Lock()
|
|
defer p.lock.Unlock()
|
|
p.lastBlockIndex = ping.LastBlockIndex
|
|
return nil
|
|
}
|
|
|
|
// HandlePong handles a pong message received from the peer and does an appropriate
|
|
// accounting of outstanding pings and timeouts.
|
|
func (p *TCPPeer) HandlePong(pong *payload.Ping) error {
|
|
p.lock.Lock()
|
|
defer p.lock.Unlock()
|
|
if p.pingTimer != nil && !p.pingTimer.Stop() {
|
|
return errPingPong
|
|
}
|
|
p.pingTimer = nil
|
|
p.pingSent--
|
|
if p.pingSent < 0 {
|
|
return errUnexpectedPong
|
|
}
|
|
p.lastBlockIndex = pong.LastBlockIndex
|
|
return nil
|
|
}
|
|
|
|
// AddGetAddrSent increments internal outstanding getaddr requests counter. Then,
|
|
// the peer can only send one addr reply per getaddr request.
|
|
func (p *TCPPeer) AddGetAddrSent() {
|
|
p.getAddrSent.Inc()
|
|
}
|
|
|
|
// CanProcessAddr decrements internal outstanding getaddr requests counter and
|
|
// answers whether the addr command from the peer can be safely processed.
|
|
func (p *TCPPeer) CanProcessAddr() bool {
|
|
v := p.getAddrSent.Dec()
|
|
return v >= 0
|
|
}
|