network: move per-peer goroutines into the TCPPeer

As they're directly tied to it.
This commit is contained in:
Roman Khimov 2020-01-15 17:03:42 +03:00
parent 32213b1454
commit 907a236285
6 changed files with 115 additions and 107 deletions

View file

@ -177,6 +177,7 @@ func (p *localPeer) RemoteAddr() net.Addr {
func (p *localPeer) PeerAddr() net.Addr {
return &p.netaddr
}
func (p *localPeer) StartProtocol() {}
func (p *localPeer) Disconnect(err error) {}
func (p *localPeer) WriteMsg(msg *Message) error {
p.messageHandler(p.t, msg)

View file

@ -26,6 +26,9 @@ type Peer interface {
Handshaked() bool
SendVersion(*Message) error
SendVersionAck(*Message) error
// StartProtocol is a goroutine to be run after the handshake. It
// implements basic peer-related protocol handling.
StartProtocol()
HandleVersion(*payload.Version) error
HandleVersionAck() error
GetPingSent() int

View file

@ -305,71 +305,6 @@ func (s *Server) HandshakedPeersCount() int {
return count
}
// startProtocol starts a long running background loop that interacts
// every ProtoTickInterval with the peer.
func (s *Server) startProtocol(p Peer) {
var err error
s.log.Info("started protocol",
zap.Stringer("addr", p.RemoteAddr()),
zap.ByteString("userAgent", p.Version().UserAgent),
zap.Uint32("startHeight", p.Version().StartHeight),
zap.Uint32("id", p.Version().Nonce))
s.discovery.RegisterGoodAddr(p.PeerAddr().String())
if s.chain.HeaderHeight() < p.LastBlockIndex() {
err = s.requestHeaders(p)
if err != nil {
p.Disconnect(err)
return
}
}
timer := time.NewTimer(s.ProtoTickInterval)
pingTimer := time.NewTimer(s.PingTimeout)
for {
select {
case err = <-p.Done():
// time to stop
case m := <-s.addrReq:
err = p.WriteMsg(m)
case <-timer.C:
// Try to sync in headers and block with the peer if his block height is higher then ours.
if p.LastBlockIndex() > s.chain.BlockHeight() {
err = s.requestBlocks(p)
}
if err == nil {
timer.Reset(s.ProtoTickInterval)
}
if s.chain.HeaderHeight() >= p.LastBlockIndex() {
block, errGetBlock := s.chain.GetBlock(s.chain.CurrentBlockHash())
if errGetBlock != nil {
err = errGetBlock
} else {
diff := uint32(time.Now().UTC().Unix()) - block.Timestamp
if diff > uint32(s.PingInterval/time.Second) {
p.UpdatePingSent(p.GetPingSent() + 1)
err = p.WriteMsg(NewMessage(s.Net, CMDPing, payload.NewPing(s.id, s.chain.HeaderHeight())))
}
}
}
case <-pingTimer.C:
if p.GetPingSent() > defaultPingLimit {
err = errors.New("ping/pong timeout")
} else {
pingTimer.Reset(s.PingTimeout)
p.UpdatePingSent(0)
}
}
if err != nil {
s.unregister <- peerDrop{p, err}
timer.Stop()
p.Disconnect(err)
return
}
}
}
// When a peer connects to the server, we will send our version immediately.
func (s *Server) sendVersion(p Peer) error {
payload := payload.NewVersion(
@ -701,7 +636,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
if err != nil {
return err
}
go s.startProtocol(peer)
go peer.StartProtocol()
s.tryStartConsensus()
default:

View file

@ -5,9 +5,11 @@ import (
"net"
"strconv"
"sync"
"time"
"github.com/CityOfZion/neo-go/pkg/io"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"go.uber.org/zap"
)
type handShakeStage uint8
@ -28,7 +30,8 @@ var (
type TCPPeer struct {
// underlying TCP connection.
conn net.Conn
// The server this peer belongs to.
server *Server
// The version of the peer.
version *payload.Version
// Index of the last block.
@ -46,9 +49,10 @@ type TCPPeer struct {
}
// NewTCPPeer returns a TCPPeer structure based on the given connection.
func NewTCPPeer(conn net.Conn) *TCPPeer {
func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer {
return &TCPPeer{
conn: conn,
server: s,
done: make(chan error, 1),
}
}
@ -79,6 +83,102 @@ func (p *TCPPeer) writeMsg(msg *Message) error {
}
}
// handleConn handles the read side of the connection, it should be started as
// a goroutine right after the new peer setup.
func (p *TCPPeer) handleConn() {
var err error
p.server.register <- p
// When a new peer is connected we send out our version immediately.
err = p.server.sendVersion(p)
if err == nil {
r := io.NewBinReaderFromIO(p.conn)
for {
msg := &Message{}
err = msg.Decode(r)
if err == payload.ErrTooManyHeaders {
p.server.log.Warn("not all headers were processed")
r.Err = nil
} else if err != nil {
break
}
if err = p.server.handleMessage(p, msg); err != nil {
break
}
}
}
p.server.unregister <- peerDrop{p, err}
p.Disconnect(err)
}
// StartProtocol starts a long running background loop that interacts
// every ProtoTickInterval with the peer. It's only good to run after the
// handshake.
func (p *TCPPeer) StartProtocol() {
var err error
p.server.log.Info("started protocol",
zap.Stringer("addr", p.RemoteAddr()),
zap.ByteString("userAgent", p.Version().UserAgent),
zap.Uint32("startHeight", p.Version().StartHeight),
zap.Uint32("id", p.Version().Nonce))
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String())
if p.server.chain.HeaderHeight() < p.LastBlockIndex() {
err = p.server.requestHeaders(p)
if err != nil {
p.Disconnect(err)
return
}
}
timer := time.NewTimer(p.server.ProtoTickInterval)
pingTimer := time.NewTimer(p.server.PingTimeout)
for {
select {
case err = <-p.Done():
// time to stop
case m := <-p.server.addrReq:
err = p.WriteMsg(m)
case <-timer.C:
// Try to sync in headers and block with the peer if his block height is higher then ours.
if p.LastBlockIndex() > p.server.chain.BlockHeight() {
err = p.server.requestBlocks(p)
}
if err == nil {
timer.Reset(p.server.ProtoTickInterval)
}
if p.server.chain.HeaderHeight() >= p.LastBlockIndex() {
block, errGetBlock := p.server.chain.GetBlock(p.server.chain.CurrentBlockHash())
if errGetBlock != nil {
err = errGetBlock
} else {
diff := uint32(time.Now().UTC().Unix()) - block.Timestamp
if diff > uint32(p.server.PingInterval/time.Second) {
p.UpdatePingSent(p.GetPingSent() + 1)
err = p.WriteMsg(NewMessage(p.server.Net, CMDPing, payload.NewPing(p.server.id, p.server.chain.HeaderHeight())))
}
}
}
case <-pingTimer.C:
if p.GetPingSent() > defaultPingLimit {
err = errors.New("ping/pong timeout")
} else {
pingTimer.Reset(p.server.PingTimeout)
p.UpdatePingSent(0)
}
}
if err != nil {
p.server.unregister <- peerDrop{p, err}
timer.Stop()
p.Disconnect(err)
return
}
}
}
// Handshaked returns status of the handshake, whether it's completed or not.
func (p *TCPPeer) Handshaked() bool {
p.lock.RLock()

View file

@ -18,8 +18,8 @@ func connReadStub(conn net.Conn) {
func TestPeerHandshake(t *testing.T) {
server, client := net.Pipe()
tcpS := NewTCPPeer(server)
tcpC := NewTCPPeer(client)
tcpS := NewTCPPeer(server, nil)
tcpC := NewTCPPeer(client, nil)
// Something should read things written into the pipe.
go connReadStub(tcpS.conn)

View file

@ -5,8 +5,6 @@ import (
"regexp"
"time"
"github.com/CityOfZion/neo-go/pkg/io"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"go.uber.org/zap"
)
@ -36,7 +34,8 @@ func (t *TCPTransport) Dial(addr string, timeout time.Duration) error {
if err != nil {
return err
}
go t.handleConn(conn)
p := NewTCPPeer(conn, t.server)
go p.handleConn()
return nil
}
@ -59,7 +58,8 @@ func (t *TCPTransport) Accept() {
}
continue
}
go t.handleConn(conn)
p := NewTCPPeer(conn, t.server)
go p.handleConn()
}
}
@ -73,37 +73,6 @@ func (t *TCPTransport) isCloseError(err error) bool {
return false
}
func (t *TCPTransport) handleConn(conn net.Conn) {
var (
p = NewTCPPeer(conn)
err error
)
t.server.register <- p
// When a new peer is connected we send out our version immediately.
err = t.server.sendVersion(p)
if err == nil {
r := io.NewBinReaderFromIO(p.conn)
for {
msg := &Message{}
err = msg.Decode(r)
if err == payload.ErrTooManyHeaders {
t.log.Warn("not all headers were processed")
r.Err = nil
} else if err != nil {
break
}
if err = t.server.handleMessage(p, msg); err != nil {
break
}
}
}
t.server.unregister <- peerDrop{p, err}
p.Disconnect(err)
}
// Close implements the Transporter interface.
func (t *TCPTransport) Close() {
t.listener.Close()