network: handle errors and connection close more correctly

This makes writer side handle errors properly and fixes communication between
reader and writer goroutine to always correctly unregister the peer. This is
especially important for the case where error occurs before handshake
completes as in this case we don't even have goroutine in startProtocol()
running.
This commit is contained in:
Roman Khimov 2019-09-13 15:36:53 +03:00
parent 76c7cff67f
commit d3bb8ddf8f
3 changed files with 45 additions and 26 deletions

View file

@ -153,6 +153,7 @@ func (s *Server) run() {
"addr": p.NetAddr(), "addr": p.NetAddr(),
}).Info("new peer connected") }).Info("new peer connected")
case drop := <-s.unregister: case drop := <-s.unregister:
if s.peers[drop.peer] {
delete(s.peers, drop.peer) delete(s.peers, drop.peer)
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"addr": drop.peer.NetAddr(), "addr": drop.peer.NetAddr(),
@ -161,6 +162,9 @@ func (s *Server) run() {
}).Warn("peer disconnected") }).Warn("peer disconnected")
s.discovery.BackFill(drop.peer.NetAddr().String()) s.discovery.BackFill(drop.peer.NetAddr().String())
} }
// else the peer is already gone, which can happen
// because we have two goroutines sending signals here
}
} }
} }
@ -187,24 +191,35 @@ func (s *Server) startProtocol(p Peer) {
"id": p.Version().Nonce, "id": p.Version().Nonce,
}).Info("started protocol") }).Info("started protocol")
s.requestHeaders(p) err := s.requestHeaders(p)
if err != nil {
p.Disconnect(err)
return
}
timer := time.NewTimer(s.ProtoTickInterval) timer := time.NewTimer(s.ProtoTickInterval)
for { for {
select { select {
case err := <-p.Done(): case err = <-p.Done():
s.unregister <- peerDrop{p, err} // time to stop
return
case m := <-s.addrReq: case m := <-s.addrReq:
p.WriteMsg(m) err = p.WriteMsg(m)
case <-timer.C: case <-timer.C:
// Try to sync in headers and block with the peer if his block height is higher then ours. // Try to sync in headers and block with the peer if his block height is higher then ours.
if p.Version().StartHeight > s.chain.BlockHeight() { if p.Version().StartHeight > s.chain.BlockHeight() {
s.requestBlocks(p) err = s.requestBlocks(p)
} }
if err == nil {
timer.Reset(s.ProtoTickInterval) timer.Reset(s.ProtoTickInterval)
} }
} }
if err != nil {
s.unregister <- peerDrop{p, err}
timer.Stop()
p.Disconnect(err)
return
}
}
} }
// When a peer connects to the server, we will send our version immediately. // When a peer connects to the server, we will send our version immediately.
@ -279,16 +294,16 @@ func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error {
// requestHeaders will send a getheaders message to the peer. // requestHeaders will send a getheaders message to the peer.
// The peer will respond with headers op to a count of 2000. // The peer will respond with headers op to a count of 2000.
func (s *Server) requestHeaders(p Peer) { func (s *Server) requestHeaders(p Peer) error {
start := []util.Uint256{s.chain.CurrentHeaderHash()} start := []util.Uint256{s.chain.CurrentHeaderHash()}
payload := payload.NewGetBlocks(start, util.Uint256{}) payload := payload.NewGetBlocks(start, util.Uint256{})
p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload)) return p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload))
} }
// requestBlocks will send a getdata message to the peer // requestBlocks will send a getdata message to the peer
// to sync up in blocks. A maximum of maxBlockBatch will // to sync up in blocks. A maximum of maxBlockBatch will
// send at once. // send at once.
func (s *Server) requestBlocks(p Peer) { func (s *Server) requestBlocks(p Peer) error {
var ( var (
hashes []util.Uint256 hashes []util.Uint256
hashStart = s.chain.BlockHeight() + 1 hashStart = s.chain.BlockHeight() + 1
@ -301,10 +316,11 @@ func (s *Server) requestBlocks(p Peer) {
} }
if len(hashes) > 0 { if len(hashes) > 0 {
payload := payload.NewInventory(payload.BlockType, hashes) payload := payload.NewInventory(payload.BlockType, hashes)
p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
} else if s.chain.HeaderHeight() < p.Version().StartHeight { } else if s.chain.HeaderHeight() < p.Version().StartHeight {
s.requestHeaders(p) return s.requestHeaders(p)
} }
return nil
} }
// handleMessage will process the given message. // handleMessage will process the given message.

View file

@ -136,7 +136,12 @@ func (p *TCPPeer) Done() chan error {
// Disconnect will fill the peer's done channel with the given error. // Disconnect will fill the peer's done channel with the given error.
func (p *TCPPeer) Disconnect(err error) { func (p *TCPPeer) Disconnect(err error) {
p.conn.Close() p.conn.Close()
p.done <- err select {
case p.done <- err:
// one message to the queue
default:
// the other side may already be gone, it's OK
}
} }
// Version implements the Peer interface. // Version implements the Peer interface.

View file

@ -75,21 +75,19 @@ func (t *TCPTransport) handleConn(conn net.Conn) {
err error err error
) )
defer func() {
p.Disconnect(err)
}()
t.server.register <- p t.server.register <- p
for { for {
msg := &Message{} msg := &Message{}
if err = msg.Decode(p.conn); err != nil { if err = msg.Decode(p.conn); err != nil {
return break
} }
if err = t.server.handleMessage(p, msg); err != nil { if err = t.server.handleMessage(p, msg); err != nil {
return break
} }
} }
t.server.unregister <- peerDrop{p, err}
p.Disconnect(err)
} }
// Close implements the Transporter interface. // Close implements the Transporter interface.