network: make node strictly follow handshake procedure

Don't accept other messages before handshake is completed, check handshake
message sequence.
This commit is contained in:
Roman Khimov 2019-09-13 15:43:22 +03:00
parent c6487423ae
commit 76c7cff67f
5 changed files with 164 additions and 35 deletions

View file

@ -0,0 +1,27 @@
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
package network
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[nothingDone-0]
_ = x[versionSent-1]
_ = x[versionReceived-2]
_ = x[verAckSent-3]
_ = x[verAckReceived-4]
}
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
func (i handShakeStage) String() string {
if i >= handShakeStage(len(_handShakeStage_index)-1) {
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
}

View file

@ -114,6 +114,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {}
type localPeer struct { type localPeer struct {
netaddr net.TCPAddr netaddr net.TCPAddr
version *payload.Version version *payload.Version
handshaked bool
t *testing.T t *testing.T
messageHandler func(t *testing.T, msg *Message) messageHandler func(t *testing.T, msg *Message)
} }
@ -142,8 +143,23 @@ func (p *localPeer) Done() chan error {
func (p *localPeer) Version() *payload.Version { func (p *localPeer) Version() *payload.Version {
return p.version return p.version
} }
func (p *localPeer) SetVersion(v *payload.Version) { func (p *localPeer) HandleVersion(v *payload.Version) error {
p.version = v p.version = v
return nil
}
func (p *localPeer) SendVersion(m *Message) error {
return p.WriteMsg(m)
}
func (p *localPeer) SendVersionAck(m *Message) error {
return p.WriteMsg(m)
}
func (p *localPeer) HandleVersionAck() error {
p.handshaked = true
return nil
}
func (p *localPeer) Handshaked() bool {
return p.handshaked
} }
func newTestServer() *Server { func newTestServer() *Server {

View file

@ -13,5 +13,9 @@ type Peer interface {
WriteMsg(msg *Message) error WriteMsg(msg *Message) error
Done() chan error Done() chan error
Version() *payload.Version Version() *payload.Version
SetVersion(*payload.Version) Handshaked() bool
SendVersion(*Message) error
SendVersionAck(*Message) error
HandleVersion(*payload.Version) error
HandleVersionAck() error
} }

View file

@ -216,20 +216,23 @@ func (s *Server) sendVersion(p Peer) error {
s.chain.BlockHeight(), s.chain.BlockHeight(),
s.Relay, s.Relay,
) )
return p.WriteMsg(NewMessage(s.Net, CMDVersion, payload)) return p.SendVersion(NewMessage(s.Net, CMDVersion, payload))
} }
// When a peer sends out his version we reply with verack after validating // When a peer sends out his version we reply with verack after validating
// the version. // the version.
func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error { func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
if p.NetAddr().Port != int(version.Port) { err := p.HandleVersion(version)
return errPortMismatch if err != nil {
return err
} }
if s.id == version.Nonce { if s.id == version.Nonce {
return errIdenticalID return errIdenticalID
} }
p.SetVersion(version) if p.NetAddr().Port != int(version.Port) {
return p.WriteMsg(NewMessage(s.Net, CMDVerack, nil)) return errPortMismatch
}
return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil))
} }
// handleHeadersCmd will process the headers it received from its peer. // handleHeadersCmd will process the headers it received from its peer.
@ -312,29 +315,37 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return errInvalidNetwork return errInvalidNetwork
} }
switch msg.CommandType() { if peer.Handshaked() {
case CMDAddr: switch msg.CommandType() {
addrs := msg.Payload.(*payload.AddressList) case CMDAddr:
return s.handleAddrCmd(peer, addrs) addrs := msg.Payload.(*payload.AddressList)
case CMDVersion: return s.handleAddrCmd(peer, addrs)
version := msg.Payload.(*payload.Version) case CMDHeaders:
return s.handleVersionCmd(peer, version) headers := msg.Payload.(*payload.Headers)
case CMDHeaders: go s.handleHeadersCmd(peer, headers)
headers := msg.Payload.(*payload.Headers) case CMDInv:
go s.handleHeadersCmd(peer, headers) inventory := msg.Payload.(*payload.Inventory)
case CMDInv: return s.handleInvCmd(peer, inventory)
inventory := msg.Payload.(*payload.Inventory) case CMDBlock:
return s.handleInvCmd(peer, inventory) block := msg.Payload.(*core.Block)
case CMDBlock: return s.handleBlockCmd(peer, block)
block := msg.Payload.(*core.Block) case CMDVersion, CMDVerack:
return s.handleBlockCmd(peer, block) return fmt.Errorf("received '%s' after the handshake", msg.CommandType())
case CMDVerack: }
// Make sure this peer has send his version before we start the } else {
// protocol with that peer. switch msg.CommandType() {
if peer.Version() == nil { case CMDVersion:
return errInvalidHandshake version := msg.Payload.(*payload.Version)
return s.handleVersionCmd(peer, version)
case CMDVerack:
err := peer.HandleVersionAck()
if err != nil {
return err
}
go s.startProtocol(peer)
default:
return fmt.Errorf("received '%s' during handshake", msg.CommandType())
} }
go s.startProtocol(peer)
} }
return nil return nil
} }

View file

@ -1,12 +1,29 @@
package network package network
import ( import (
"errors"
"fmt"
"net" "net"
"sync" "sync"
"github.com/CityOfZion/neo-go/pkg/network/payload" "github.com/CityOfZion/neo-go/pkg/network/payload"
) )
type handShakeStage uint8
//go:generate stringer -type=handShakeStage
const (
nothingDone handShakeStage = 0
versionSent handShakeStage = 1
versionReceived handShakeStage = 2
verAckSent handShakeStage = 3
verAckReceived handShakeStage = 4
)
var (
errStateMismatch = errors.New("tried to send protocol message before handshake completed")
)
// TCPPeer represents a connected remote node in the // TCPPeer represents a connected remote node in the
// network over TCP. // network over TCP.
type TCPPeer struct { type TCPPeer struct {
@ -17,6 +34,8 @@ type TCPPeer struct {
// The version of the peer. // The version of the peer.
version *payload.Version version *payload.Version
handShake handShakeStage
done chan error done chan error
wg sync.WaitGroup wg sync.WaitGroup
@ -35,8 +54,16 @@ func NewTCPPeer(conn net.Conn) *TCPPeer {
} }
// WriteMsg implements the Peer interface. This will write/encode the message // WriteMsg implements the Peer interface. This will write/encode the message
// to the underlying connection. // to the underlying connection, this only works for messages other than Version
// or VerAck.
func (p *TCPPeer) WriteMsg(msg *Message) error { func (p *TCPPeer) WriteMsg(msg *Message) error {
if !p.Handshaked() {
return errStateMismatch
}
return p.writeMsg(msg)
}
func (p *TCPPeer) writeMsg(msg *Message) error {
select { select {
case err := <-p.done: case err := <-p.done:
return err return err
@ -45,6 +72,55 @@ func (p *TCPPeer) WriteMsg(msg *Message) error {
} }
} }
// Handshaked returns status of the handshake, whether it's completed or not.
func (p *TCPPeer) Handshaked() bool {
return p.handShake == verAckReceived
}
// SendVersion checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersion(msg *Message) error {
if p.handShake != nothingDone {
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
}
err := p.writeMsg(msg)
if err == nil {
p.handShake = versionSent
}
return err
}
// HandleVersion checks for the handshake state and version message contents.
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
if p.handShake != versionSent {
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
}
p.version = version
p.handShake = versionReceived
return nil
}
// SendVersionAck checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersionAck(msg *Message) error {
if p.handShake != versionReceived {
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
}
err := p.writeMsg(msg)
if err == nil {
p.handShake = verAckSent
}
return err
}
// HandleVersionAck checks handshake sequence correctness when VerAck message
// is received.
func (p *TCPPeer) HandleVersionAck() error {
if p.handShake != verAckSent {
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
}
p.handShake = verAckReceived
return nil
}
// NetAddr implements the Peer interface. // NetAddr implements the Peer interface.
func (p *TCPPeer) NetAddr() *net.TCPAddr { func (p *TCPPeer) NetAddr() *net.TCPAddr {
return &p.addr return &p.addr
@ -67,8 +143,3 @@ func (p *TCPPeer) Disconnect(err error) {
func (p *TCPPeer) Version() *payload.Version { func (p *TCPPeer) Version() *payload.Version {
return p.version return p.version
} }
// SetVersion implements the Peer interface.
func (p *TCPPeer) SetVersion(v *payload.Version) {
p.version = v
}