neoneo-go/pkg/network/tcp_peer.go
Roman Khimov d3bb8ddf8f network: handle errors and connection close more correctly
This makes writer side handle errors properly and fixes communication between
reader and writer goroutine to always correctly unregister the peer. This is
especially important for the case where error occurs before handshake
completes as in this case we don't even have goroutine in startProtocol()
running.
2019-09-16 16:32:04 +03:00

150 lines
3.6 KiB
Go

package network
import (
"errors"
"fmt"
"net"
"sync"
"github.com/CityOfZion/neo-go/pkg/network/payload"
)
type handShakeStage uint8
//go:generate stringer -type=handShakeStage
const (
nothingDone handShakeStage = 0
versionSent handShakeStage = 1
versionReceived handShakeStage = 2
verAckSent handShakeStage = 3
verAckReceived handShakeStage = 4
)
var (
errStateMismatch = errors.New("tried to send protocol message before handshake completed")
)
// TCPPeer represents a connected remote node in the
// network over TCP.
type TCPPeer struct {
// underlying TCP connection.
conn net.Conn
addr net.TCPAddr
// The version of the peer.
version *payload.Version
handShake handShakeStage
done chan error
wg sync.WaitGroup
}
// NewTCPPeer returns a TCPPeer structure based on the given connection.
func NewTCPPeer(conn net.Conn) *TCPPeer {
raddr := conn.RemoteAddr()
// can't fail because raddr is a real connection
tcpaddr, _ := net.ResolveTCPAddr(raddr.Network(), raddr.String())
return &TCPPeer{
conn: conn,
done: make(chan error, 1),
addr: *tcpaddr,
}
}
// WriteMsg implements the Peer interface. This will write/encode the message
// to the underlying connection, this only works for messages other than Version
// or VerAck.
func (p *TCPPeer) WriteMsg(msg *Message) error {
if !p.Handshaked() {
return errStateMismatch
}
return p.writeMsg(msg)
}
func (p *TCPPeer) writeMsg(msg *Message) error {
select {
case err := <-p.done:
return err
default:
return msg.Encode(p.conn)
}
}
// Handshaked returns status of the handshake, whether it's completed or not.
func (p *TCPPeer) Handshaked() bool {
return p.handShake == verAckReceived
}
// SendVersion checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersion(msg *Message) error {
if p.handShake != nothingDone {
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
}
err := p.writeMsg(msg)
if err == nil {
p.handShake = versionSent
}
return err
}
// HandleVersion checks for the handshake state and version message contents.
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
if p.handShake != versionSent {
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
}
p.version = version
p.handShake = versionReceived
return nil
}
// SendVersionAck checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersionAck(msg *Message) error {
if p.handShake != versionReceived {
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
}
err := p.writeMsg(msg)
if err == nil {
p.handShake = verAckSent
}
return err
}
// HandleVersionAck checks handshake sequence correctness when VerAck message
// is received.
func (p *TCPPeer) HandleVersionAck() error {
if p.handShake != verAckSent {
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
}
p.handShake = verAckReceived
return nil
}
// NetAddr implements the Peer interface.
func (p *TCPPeer) NetAddr() *net.TCPAddr {
return &p.addr
}
// Done implements the Peer interface and notifies
// all other resources operating on it that this peer
// is no longer running.
func (p *TCPPeer) Done() chan error {
return p.done
}
// Disconnect will fill the peer's done channel with the given error.
func (p *TCPPeer) Disconnect(err error) {
p.conn.Close()
select {
case p.done <- err:
// one message to the queue
default:
// the other side may already be gone, it's OK
}
}
// Version implements the Peer interface.
func (p *TCPPeer) Version() *payload.Version {
return p.version
}