forked from TrueCloudLab/neoneo-go
network: make node strictly follow handshake procedure
Don't accept other messages before handshake is completed, check handshake message sequence.
This commit is contained in:
parent
c6487423ae
commit
76c7cff67f
5 changed files with 164 additions and 35 deletions
27
pkg/network/handshakestage_string.go
Normal file
27
pkg/network/handshakestage_string.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package network
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
func _() {
|
||||||
|
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||||
|
// Re-run the stringer command to generate them again.
|
||||||
|
var x [1]struct{}
|
||||||
|
_ = x[nothingDone-0]
|
||||||
|
_ = x[versionSent-1]
|
||||||
|
_ = x[versionReceived-2]
|
||||||
|
_ = x[verAckSent-3]
|
||||||
|
_ = x[verAckReceived-4]
|
||||||
|
}
|
||||||
|
|
||||||
|
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
|
||||||
|
|
||||||
|
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
|
||||||
|
|
||||||
|
func (i handShakeStage) String() string {
|
||||||
|
if i >= handShakeStage(len(_handShakeStage_index)-1) {
|
||||||
|
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||||
|
}
|
||||||
|
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
|
||||||
|
}
|
|
@ -114,6 +114,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {}
|
||||||
type localPeer struct {
|
type localPeer struct {
|
||||||
netaddr net.TCPAddr
|
netaddr net.TCPAddr
|
||||||
version *payload.Version
|
version *payload.Version
|
||||||
|
handshaked bool
|
||||||
t *testing.T
|
t *testing.T
|
||||||
messageHandler func(t *testing.T, msg *Message)
|
messageHandler func(t *testing.T, msg *Message)
|
||||||
}
|
}
|
||||||
|
@ -142,8 +143,23 @@ func (p *localPeer) Done() chan error {
|
||||||
func (p *localPeer) Version() *payload.Version {
|
func (p *localPeer) Version() *payload.Version {
|
||||||
return p.version
|
return p.version
|
||||||
}
|
}
|
||||||
func (p *localPeer) SetVersion(v *payload.Version) {
|
func (p *localPeer) HandleVersion(v *payload.Version) error {
|
||||||
p.version = v
|
p.version = v
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (p *localPeer) SendVersion(m *Message) error {
|
||||||
|
return p.WriteMsg(m)
|
||||||
|
}
|
||||||
|
func (p *localPeer) SendVersionAck(m *Message) error {
|
||||||
|
return p.WriteMsg(m)
|
||||||
|
}
|
||||||
|
func (p *localPeer) HandleVersionAck() error {
|
||||||
|
p.handshaked = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *localPeer) Handshaked() bool {
|
||||||
|
return p.handshaked
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestServer() *Server {
|
func newTestServer() *Server {
|
||||||
|
|
|
@ -13,5 +13,9 @@ type Peer interface {
|
||||||
WriteMsg(msg *Message) error
|
WriteMsg(msg *Message) error
|
||||||
Done() chan error
|
Done() chan error
|
||||||
Version() *payload.Version
|
Version() *payload.Version
|
||||||
SetVersion(*payload.Version)
|
Handshaked() bool
|
||||||
|
SendVersion(*Message) error
|
||||||
|
SendVersionAck(*Message) error
|
||||||
|
HandleVersion(*payload.Version) error
|
||||||
|
HandleVersionAck() error
|
||||||
}
|
}
|
||||||
|
|
|
@ -216,20 +216,23 @@ func (s *Server) sendVersion(p Peer) error {
|
||||||
s.chain.BlockHeight(),
|
s.chain.BlockHeight(),
|
||||||
s.Relay,
|
s.Relay,
|
||||||
)
|
)
|
||||||
return p.WriteMsg(NewMessage(s.Net, CMDVersion, payload))
|
return p.SendVersion(NewMessage(s.Net, CMDVersion, payload))
|
||||||
}
|
}
|
||||||
|
|
||||||
// When a peer sends out his version we reply with verack after validating
|
// When a peer sends out his version we reply with verack after validating
|
||||||
// the version.
|
// the version.
|
||||||
func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
|
func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
|
||||||
if p.NetAddr().Port != int(version.Port) {
|
err := p.HandleVersion(version)
|
||||||
return errPortMismatch
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if s.id == version.Nonce {
|
if s.id == version.Nonce {
|
||||||
return errIdenticalID
|
return errIdenticalID
|
||||||
}
|
}
|
||||||
p.SetVersion(version)
|
if p.NetAddr().Port != int(version.Port) {
|
||||||
return p.WriteMsg(NewMessage(s.Net, CMDVerack, nil))
|
return errPortMismatch
|
||||||
|
}
|
||||||
|
return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleHeadersCmd will process the headers it received from its peer.
|
// handleHeadersCmd will process the headers it received from its peer.
|
||||||
|
@ -312,29 +315,37 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
||||||
return errInvalidNetwork
|
return errInvalidNetwork
|
||||||
}
|
}
|
||||||
|
|
||||||
switch msg.CommandType() {
|
if peer.Handshaked() {
|
||||||
case CMDAddr:
|
switch msg.CommandType() {
|
||||||
addrs := msg.Payload.(*payload.AddressList)
|
case CMDAddr:
|
||||||
return s.handleAddrCmd(peer, addrs)
|
addrs := msg.Payload.(*payload.AddressList)
|
||||||
case CMDVersion:
|
return s.handleAddrCmd(peer, addrs)
|
||||||
version := msg.Payload.(*payload.Version)
|
case CMDHeaders:
|
||||||
return s.handleVersionCmd(peer, version)
|
headers := msg.Payload.(*payload.Headers)
|
||||||
case CMDHeaders:
|
go s.handleHeadersCmd(peer, headers)
|
||||||
headers := msg.Payload.(*payload.Headers)
|
case CMDInv:
|
||||||
go s.handleHeadersCmd(peer, headers)
|
inventory := msg.Payload.(*payload.Inventory)
|
||||||
case CMDInv:
|
return s.handleInvCmd(peer, inventory)
|
||||||
inventory := msg.Payload.(*payload.Inventory)
|
case CMDBlock:
|
||||||
return s.handleInvCmd(peer, inventory)
|
block := msg.Payload.(*core.Block)
|
||||||
case CMDBlock:
|
return s.handleBlockCmd(peer, block)
|
||||||
block := msg.Payload.(*core.Block)
|
case CMDVersion, CMDVerack:
|
||||||
return s.handleBlockCmd(peer, block)
|
return fmt.Errorf("received '%s' after the handshake", msg.CommandType())
|
||||||
case CMDVerack:
|
}
|
||||||
// Make sure this peer has send his version before we start the
|
} else {
|
||||||
// protocol with that peer.
|
switch msg.CommandType() {
|
||||||
if peer.Version() == nil {
|
case CMDVersion:
|
||||||
return errInvalidHandshake
|
version := msg.Payload.(*payload.Version)
|
||||||
|
return s.handleVersionCmd(peer, version)
|
||||||
|
case CMDVerack:
|
||||||
|
err := peer.HandleVersionAck()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go s.startProtocol(peer)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("received '%s' during handshake", msg.CommandType())
|
||||||
}
|
}
|
||||||
go s.startProtocol(peer)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,29 @@
|
||||||
package network
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type handShakeStage uint8
|
||||||
|
|
||||||
|
//go:generate stringer -type=handShakeStage
|
||||||
|
const (
|
||||||
|
nothingDone handShakeStage = 0
|
||||||
|
versionSent handShakeStage = 1
|
||||||
|
versionReceived handShakeStage = 2
|
||||||
|
verAckSent handShakeStage = 3
|
||||||
|
verAckReceived handShakeStage = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errStateMismatch = errors.New("tried to send protocol message before handshake completed")
|
||||||
|
)
|
||||||
|
|
||||||
// TCPPeer represents a connected remote node in the
|
// TCPPeer represents a connected remote node in the
|
||||||
// network over TCP.
|
// network over TCP.
|
||||||
type TCPPeer struct {
|
type TCPPeer struct {
|
||||||
|
@ -17,6 +34,8 @@ type TCPPeer struct {
|
||||||
// The version of the peer.
|
// The version of the peer.
|
||||||
version *payload.Version
|
version *payload.Version
|
||||||
|
|
||||||
|
handShake handShakeStage
|
||||||
|
|
||||||
done chan error
|
done chan error
|
||||||
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
@ -35,8 +54,16 @@ func NewTCPPeer(conn net.Conn) *TCPPeer {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteMsg implements the Peer interface. This will write/encode the message
|
// WriteMsg implements the Peer interface. This will write/encode the message
|
||||||
// to the underlying connection.
|
// to the underlying connection, this only works for messages other than Version
|
||||||
|
// or VerAck.
|
||||||
func (p *TCPPeer) WriteMsg(msg *Message) error {
|
func (p *TCPPeer) WriteMsg(msg *Message) error {
|
||||||
|
if !p.Handshaked() {
|
||||||
|
return errStateMismatch
|
||||||
|
}
|
||||||
|
return p.writeMsg(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TCPPeer) writeMsg(msg *Message) error {
|
||||||
select {
|
select {
|
||||||
case err := <-p.done:
|
case err := <-p.done:
|
||||||
return err
|
return err
|
||||||
|
@ -45,6 +72,55 @@ func (p *TCPPeer) WriteMsg(msg *Message) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handshaked returns status of the handshake, whether it's completed or not.
|
||||||
|
func (p *TCPPeer) Handshaked() bool {
|
||||||
|
return p.handShake == verAckReceived
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendVersion checks for the handshake state and sends a message to the peer.
|
||||||
|
func (p *TCPPeer) SendVersion(msg *Message) error {
|
||||||
|
if p.handShake != nothingDone {
|
||||||
|
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
|
||||||
|
}
|
||||||
|
err := p.writeMsg(msg)
|
||||||
|
if err == nil {
|
||||||
|
p.handShake = versionSent
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleVersion checks for the handshake state and version message contents.
|
||||||
|
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
|
||||||
|
if p.handShake != versionSent {
|
||||||
|
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
|
||||||
|
}
|
||||||
|
p.version = version
|
||||||
|
p.handShake = versionReceived
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendVersionAck checks for the handshake state and sends a message to the peer.
|
||||||
|
func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
||||||
|
if p.handShake != versionReceived {
|
||||||
|
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
|
||||||
|
}
|
||||||
|
err := p.writeMsg(msg)
|
||||||
|
if err == nil {
|
||||||
|
p.handShake = verAckSent
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleVersionAck checks handshake sequence correctness when VerAck message
|
||||||
|
// is received.
|
||||||
|
func (p *TCPPeer) HandleVersionAck() error {
|
||||||
|
if p.handShake != verAckSent {
|
||||||
|
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
|
||||||
|
}
|
||||||
|
p.handShake = verAckReceived
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// NetAddr implements the Peer interface.
|
// NetAddr implements the Peer interface.
|
||||||
func (p *TCPPeer) NetAddr() *net.TCPAddr {
|
func (p *TCPPeer) NetAddr() *net.TCPAddr {
|
||||||
return &p.addr
|
return &p.addr
|
||||||
|
@ -67,8 +143,3 @@ func (p *TCPPeer) Disconnect(err error) {
|
||||||
func (p *TCPPeer) Version() *payload.Version {
|
func (p *TCPPeer) Version() *payload.Version {
|
||||||
return p.version
|
return p.version
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetVersion implements the Peer interface.
|
|
||||||
func (p *TCPPeer) SetVersion(v *payload.Version) {
|
|
||||||
p.version = v
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in a new issue