network: make node strictly follow handshake procedure
Don't accept other messages before handshake is completed, check handshake message sequence.
This commit is contained in:
parent
c6487423ae
commit
76c7cff67f
5 changed files with 164 additions and 35 deletions
27
pkg/network/handshakestage_string.go
Normal file
27
pkg/network/handshakestage_string.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
|
||||
|
||||
package network
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[nothingDone-0]
|
||||
_ = x[versionSent-1]
|
||||
_ = x[versionReceived-2]
|
||||
_ = x[verAckSent-3]
|
||||
_ = x[verAckReceived-4]
|
||||
}
|
||||
|
||||
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
|
||||
|
||||
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
|
||||
|
||||
func (i handShakeStage) String() string {
|
||||
if i >= handShakeStage(len(_handShakeStage_index)-1) {
|
||||
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
|
||||
}
|
|
@ -114,6 +114,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {}
|
|||
type localPeer struct {
|
||||
netaddr net.TCPAddr
|
||||
version *payload.Version
|
||||
handshaked bool
|
||||
t *testing.T
|
||||
messageHandler func(t *testing.T, msg *Message)
|
||||
}
|
||||
|
@ -142,8 +143,23 @@ func (p *localPeer) Done() chan error {
|
|||
func (p *localPeer) Version() *payload.Version {
|
||||
return p.version
|
||||
}
|
||||
func (p *localPeer) SetVersion(v *payload.Version) {
|
||||
func (p *localPeer) HandleVersion(v *payload.Version) error {
|
||||
p.version = v
|
||||
return nil
|
||||
}
|
||||
func (p *localPeer) SendVersion(m *Message) error {
|
||||
return p.WriteMsg(m)
|
||||
}
|
||||
func (p *localPeer) SendVersionAck(m *Message) error {
|
||||
return p.WriteMsg(m)
|
||||
}
|
||||
func (p *localPeer) HandleVersionAck() error {
|
||||
p.handshaked = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *localPeer) Handshaked() bool {
|
||||
return p.handshaked
|
||||
}
|
||||
|
||||
func newTestServer() *Server {
|
||||
|
|
|
@ -13,5 +13,9 @@ type Peer interface {
|
|||
WriteMsg(msg *Message) error
|
||||
Done() chan error
|
||||
Version() *payload.Version
|
||||
SetVersion(*payload.Version)
|
||||
Handshaked() bool
|
||||
SendVersion(*Message) error
|
||||
SendVersionAck(*Message) error
|
||||
HandleVersion(*payload.Version) error
|
||||
HandleVersionAck() error
|
||||
}
|
||||
|
|
|
@ -216,20 +216,23 @@ func (s *Server) sendVersion(p Peer) error {
|
|||
s.chain.BlockHeight(),
|
||||
s.Relay,
|
||||
)
|
||||
return p.WriteMsg(NewMessage(s.Net, CMDVersion, payload))
|
||||
return p.SendVersion(NewMessage(s.Net, CMDVersion, payload))
|
||||
}
|
||||
|
||||
// When a peer sends out his version we reply with verack after validating
|
||||
// the version.
|
||||
func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
|
||||
if p.NetAddr().Port != int(version.Port) {
|
||||
return errPortMismatch
|
||||
err := p.HandleVersion(version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.id == version.Nonce {
|
||||
return errIdenticalID
|
||||
}
|
||||
p.SetVersion(version)
|
||||
return p.WriteMsg(NewMessage(s.Net, CMDVerack, nil))
|
||||
if p.NetAddr().Port != int(version.Port) {
|
||||
return errPortMismatch
|
||||
}
|
||||
return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil))
|
||||
}
|
||||
|
||||
// handleHeadersCmd will process the headers it received from its peer.
|
||||
|
@ -312,29 +315,37 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
return errInvalidNetwork
|
||||
}
|
||||
|
||||
switch msg.CommandType() {
|
||||
case CMDAddr:
|
||||
addrs := msg.Payload.(*payload.AddressList)
|
||||
return s.handleAddrCmd(peer, addrs)
|
||||
case CMDVersion:
|
||||
version := msg.Payload.(*payload.Version)
|
||||
return s.handleVersionCmd(peer, version)
|
||||
case CMDHeaders:
|
||||
headers := msg.Payload.(*payload.Headers)
|
||||
go s.handleHeadersCmd(peer, headers)
|
||||
case CMDInv:
|
||||
inventory := msg.Payload.(*payload.Inventory)
|
||||
return s.handleInvCmd(peer, inventory)
|
||||
case CMDBlock:
|
||||
block := msg.Payload.(*core.Block)
|
||||
return s.handleBlockCmd(peer, block)
|
||||
case CMDVerack:
|
||||
// Make sure this peer has send his version before we start the
|
||||
// protocol with that peer.
|
||||
if peer.Version() == nil {
|
||||
return errInvalidHandshake
|
||||
if peer.Handshaked() {
|
||||
switch msg.CommandType() {
|
||||
case CMDAddr:
|
||||
addrs := msg.Payload.(*payload.AddressList)
|
||||
return s.handleAddrCmd(peer, addrs)
|
||||
case CMDHeaders:
|
||||
headers := msg.Payload.(*payload.Headers)
|
||||
go s.handleHeadersCmd(peer, headers)
|
||||
case CMDInv:
|
||||
inventory := msg.Payload.(*payload.Inventory)
|
||||
return s.handleInvCmd(peer, inventory)
|
||||
case CMDBlock:
|
||||
block := msg.Payload.(*core.Block)
|
||||
return s.handleBlockCmd(peer, block)
|
||||
case CMDVersion, CMDVerack:
|
||||
return fmt.Errorf("received '%s' after the handshake", msg.CommandType())
|
||||
}
|
||||
} else {
|
||||
switch msg.CommandType() {
|
||||
case CMDVersion:
|
||||
version := msg.Payload.(*payload.Version)
|
||||
return s.handleVersionCmd(peer, version)
|
||||
case CMDVerack:
|
||||
err := peer.HandleVersionAck()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.startProtocol(peer)
|
||||
default:
|
||||
return fmt.Errorf("received '%s' during handshake", msg.CommandType())
|
||||
}
|
||||
go s.startProtocol(peer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,12 +1,29 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
)
|
||||
|
||||
type handShakeStage uint8
|
||||
|
||||
//go:generate stringer -type=handShakeStage
|
||||
const (
|
||||
nothingDone handShakeStage = 0
|
||||
versionSent handShakeStage = 1
|
||||
versionReceived handShakeStage = 2
|
||||
verAckSent handShakeStage = 3
|
||||
verAckReceived handShakeStage = 4
|
||||
)
|
||||
|
||||
var (
|
||||
errStateMismatch = errors.New("tried to send protocol message before handshake completed")
|
||||
)
|
||||
|
||||
// TCPPeer represents a connected remote node in the
|
||||
// network over TCP.
|
||||
type TCPPeer struct {
|
||||
|
@ -17,6 +34,8 @@ type TCPPeer struct {
|
|||
// The version of the peer.
|
||||
version *payload.Version
|
||||
|
||||
handShake handShakeStage
|
||||
|
||||
done chan error
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
@ -35,8 +54,16 @@ func NewTCPPeer(conn net.Conn) *TCPPeer {
|
|||
}
|
||||
|
||||
// WriteMsg implements the Peer interface. This will write/encode the message
|
||||
// to the underlying connection.
|
||||
// to the underlying connection, this only works for messages other than Version
|
||||
// or VerAck.
|
||||
func (p *TCPPeer) WriteMsg(msg *Message) error {
|
||||
if !p.Handshaked() {
|
||||
return errStateMismatch
|
||||
}
|
||||
return p.writeMsg(msg)
|
||||
}
|
||||
|
||||
func (p *TCPPeer) writeMsg(msg *Message) error {
|
||||
select {
|
||||
case err := <-p.done:
|
||||
return err
|
||||
|
@ -45,6 +72,55 @@ func (p *TCPPeer) WriteMsg(msg *Message) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Handshaked returns status of the handshake, whether it's completed or not.
|
||||
func (p *TCPPeer) Handshaked() bool {
|
||||
return p.handShake == verAckReceived
|
||||
}
|
||||
|
||||
// SendVersion checks for the handshake state and sends a message to the peer.
|
||||
func (p *TCPPeer) SendVersion(msg *Message) error {
|
||||
if p.handShake != nothingDone {
|
||||
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
|
||||
}
|
||||
err := p.writeMsg(msg)
|
||||
if err == nil {
|
||||
p.handShake = versionSent
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleVersion checks for the handshake state and version message contents.
|
||||
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
|
||||
if p.handShake != versionSent {
|
||||
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
|
||||
}
|
||||
p.version = version
|
||||
p.handShake = versionReceived
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendVersionAck checks for the handshake state and sends a message to the peer.
|
||||
func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
||||
if p.handShake != versionReceived {
|
||||
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
|
||||
}
|
||||
err := p.writeMsg(msg)
|
||||
if err == nil {
|
||||
p.handShake = verAckSent
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleVersionAck checks handshake sequence correctness when VerAck message
|
||||
// is received.
|
||||
func (p *TCPPeer) HandleVersionAck() error {
|
||||
if p.handShake != verAckSent {
|
||||
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
|
||||
}
|
||||
p.handShake = verAckReceived
|
||||
return nil
|
||||
}
|
||||
|
||||
// NetAddr implements the Peer interface.
|
||||
func (p *TCPPeer) NetAddr() *net.TCPAddr {
|
||||
return &p.addr
|
||||
|
@ -67,8 +143,3 @@ func (p *TCPPeer) Disconnect(err error) {
|
|||
func (p *TCPPeer) Version() *payload.Version {
|
||||
return p.version
|
||||
}
|
||||
|
||||
// SetVersion implements the Peer interface.
|
||||
func (p *TCPPeer) SetVersion(v *payload.Version) {
|
||||
p.version = v
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue