network: make node strictly follow handshake procedure

Don't accept other messages before handshake is completed, check handshake
message sequence.
This commit is contained in:
Roman Khimov 2019-09-13 15:43:22 +03:00
parent c6487423ae
commit 76c7cff67f
5 changed files with 164 additions and 35 deletions

View file

@ -0,0 +1,27 @@
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
package network
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[nothingDone-0]
_ = x[versionSent-1]
_ = x[versionReceived-2]
_ = x[verAckSent-3]
_ = x[verAckReceived-4]
}
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
func (i handShakeStage) String() string {
if i >= handShakeStage(len(_handShakeStage_index)-1) {
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
}

View file

@ -114,6 +114,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {}
type localPeer struct {
netaddr net.TCPAddr
version *payload.Version
handshaked bool
t *testing.T
messageHandler func(t *testing.T, msg *Message)
}
@ -142,8 +143,23 @@ func (p *localPeer) Done() chan error {
func (p *localPeer) Version() *payload.Version {
return p.version
}
func (p *localPeer) SetVersion(v *payload.Version) {
func (p *localPeer) HandleVersion(v *payload.Version) error {
p.version = v
return nil
}
func (p *localPeer) SendVersion(m *Message) error {
return p.WriteMsg(m)
}
func (p *localPeer) SendVersionAck(m *Message) error {
return p.WriteMsg(m)
}
func (p *localPeer) HandleVersionAck() error {
p.handshaked = true
return nil
}
func (p *localPeer) Handshaked() bool {
return p.handshaked
}
func newTestServer() *Server {

View file

@ -13,5 +13,9 @@ type Peer interface {
WriteMsg(msg *Message) error
Done() chan error
Version() *payload.Version
SetVersion(*payload.Version)
Handshaked() bool
SendVersion(*Message) error
SendVersionAck(*Message) error
HandleVersion(*payload.Version) error
HandleVersionAck() error
}

View file

@ -216,20 +216,23 @@ func (s *Server) sendVersion(p Peer) error {
s.chain.BlockHeight(),
s.Relay,
)
return p.WriteMsg(NewMessage(s.Net, CMDVersion, payload))
return p.SendVersion(NewMessage(s.Net, CMDVersion, payload))
}
// When a peer sends out his version we reply with verack after validating
// the version.
func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
if p.NetAddr().Port != int(version.Port) {
return errPortMismatch
err := p.HandleVersion(version)
if err != nil {
return err
}
if s.id == version.Nonce {
return errIdenticalID
}
p.SetVersion(version)
return p.WriteMsg(NewMessage(s.Net, CMDVerack, nil))
if p.NetAddr().Port != int(version.Port) {
return errPortMismatch
}
return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil))
}
// handleHeadersCmd will process the headers it received from its peer.
@ -312,29 +315,37 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return errInvalidNetwork
}
switch msg.CommandType() {
case CMDAddr:
addrs := msg.Payload.(*payload.AddressList)
return s.handleAddrCmd(peer, addrs)
case CMDVersion:
version := msg.Payload.(*payload.Version)
return s.handleVersionCmd(peer, version)
case CMDHeaders:
headers := msg.Payload.(*payload.Headers)
go s.handleHeadersCmd(peer, headers)
case CMDInv:
inventory := msg.Payload.(*payload.Inventory)
return s.handleInvCmd(peer, inventory)
case CMDBlock:
block := msg.Payload.(*core.Block)
return s.handleBlockCmd(peer, block)
case CMDVerack:
// Make sure this peer has send his version before we start the
// protocol with that peer.
if peer.Version() == nil {
return errInvalidHandshake
if peer.Handshaked() {
switch msg.CommandType() {
case CMDAddr:
addrs := msg.Payload.(*payload.AddressList)
return s.handleAddrCmd(peer, addrs)
case CMDHeaders:
headers := msg.Payload.(*payload.Headers)
go s.handleHeadersCmd(peer, headers)
case CMDInv:
inventory := msg.Payload.(*payload.Inventory)
return s.handleInvCmd(peer, inventory)
case CMDBlock:
block := msg.Payload.(*core.Block)
return s.handleBlockCmd(peer, block)
case CMDVersion, CMDVerack:
return fmt.Errorf("received '%s' after the handshake", msg.CommandType())
}
} else {
switch msg.CommandType() {
case CMDVersion:
version := msg.Payload.(*payload.Version)
return s.handleVersionCmd(peer, version)
case CMDVerack:
err := peer.HandleVersionAck()
if err != nil {
return err
}
go s.startProtocol(peer)
default:
return fmt.Errorf("received '%s' during handshake", msg.CommandType())
}
go s.startProtocol(peer)
}
return nil
}

View file

@ -1,12 +1,29 @@
package network
import (
"errors"
"fmt"
"net"
"sync"
"github.com/CityOfZion/neo-go/pkg/network/payload"
)
type handShakeStage uint8
//go:generate stringer -type=handShakeStage
const (
nothingDone handShakeStage = 0
versionSent handShakeStage = 1
versionReceived handShakeStage = 2
verAckSent handShakeStage = 3
verAckReceived handShakeStage = 4
)
var (
errStateMismatch = errors.New("tried to send protocol message before handshake completed")
)
// TCPPeer represents a connected remote node in the
// network over TCP.
type TCPPeer struct {
@ -17,6 +34,8 @@ type TCPPeer struct {
// The version of the peer.
version *payload.Version
handShake handShakeStage
done chan error
wg sync.WaitGroup
@ -35,8 +54,16 @@ func NewTCPPeer(conn net.Conn) *TCPPeer {
}
// WriteMsg implements the Peer interface. This will write/encode the message
// to the underlying connection.
// to the underlying connection, this only works for messages other than Version
// or VerAck.
func (p *TCPPeer) WriteMsg(msg *Message) error {
if !p.Handshaked() {
return errStateMismatch
}
return p.writeMsg(msg)
}
func (p *TCPPeer) writeMsg(msg *Message) error {
select {
case err := <-p.done:
return err
@ -45,6 +72,55 @@ func (p *TCPPeer) WriteMsg(msg *Message) error {
}
}
// Handshaked returns status of the handshake, whether it's completed or not.
func (p *TCPPeer) Handshaked() bool {
return p.handShake == verAckReceived
}
// SendVersion checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersion(msg *Message) error {
if p.handShake != nothingDone {
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
}
err := p.writeMsg(msg)
if err == nil {
p.handShake = versionSent
}
return err
}
// HandleVersion checks for the handshake state and version message contents.
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
if p.handShake != versionSent {
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
}
p.version = version
p.handShake = versionReceived
return nil
}
// SendVersionAck checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersionAck(msg *Message) error {
if p.handShake != versionReceived {
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
}
err := p.writeMsg(msg)
if err == nil {
p.handShake = verAckSent
}
return err
}
// HandleVersionAck checks handshake sequence correctness when VerAck message
// is received.
func (p *TCPPeer) HandleVersionAck() error {
if p.handShake != verAckSent {
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
}
p.handShake = verAckReceived
return nil
}
// NetAddr implements the Peer interface.
func (p *TCPPeer) NetAddr() *net.TCPAddr {
return &p.addr
@ -67,8 +143,3 @@ func (p *TCPPeer) Disconnect(err error) {
func (p *TCPPeer) Version() *payload.Version {
return p.version
}
// SetVersion implements the Peer interface.
func (p *TCPPeer) SetVersion(v *payload.Version) {
p.version = v
}