diff --git a/pkg/network/handshakestage_string.go b/pkg/network/handshakestage_string.go new file mode 100644 index 000000000..40cac11bb --- /dev/null +++ b/pkg/network/handshakestage_string.go @@ -0,0 +1,27 @@ +// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT. + +package network + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[nothingDone-0] + _ = x[versionSent-1] + _ = x[versionReceived-2] + _ = x[verAckSent-3] + _ = x[verAckReceived-4] +} + +const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived" + +var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61} + +func (i handShakeStage) String() string { + if i >= handShakeStage(len(_handShakeStage_index)-1) { + return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]] +} diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index c231f96dd..adff3331a 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -114,6 +114,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {} type localPeer struct { netaddr net.TCPAddr version *payload.Version + handshaked bool t *testing.T messageHandler func(t *testing.T, msg *Message) } @@ -142,8 +143,23 @@ func (p *localPeer) Done() chan error { func (p *localPeer) Version() *payload.Version { return p.version } -func (p *localPeer) SetVersion(v *payload.Version) { +func (p *localPeer) HandleVersion(v *payload.Version) error { p.version = v + return nil +} +func (p *localPeer) SendVersion(m *Message) error { + return p.WriteMsg(m) +} +func (p *localPeer) SendVersionAck(m *Message) error { + return p.WriteMsg(m) +} +func (p *localPeer) HandleVersionAck() error { + p.handshaked = true + return nil +} + +func (p *localPeer) Handshaked() bool { + return p.handshaked } func newTestServer() *Server { diff --git a/pkg/network/peer.go b/pkg/network/peer.go index 1298d615a..620aa3d91 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -13,5 +13,9 @@ type Peer interface { WriteMsg(msg *Message) error Done() chan error Version() *payload.Version - SetVersion(*payload.Version) + Handshaked() bool + SendVersion(*Message) error + SendVersionAck(*Message) error + HandleVersion(*payload.Version) error + HandleVersionAck() error } diff --git a/pkg/network/server.go b/pkg/network/server.go index 007f9ff3e..6afd7c861 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -216,20 +216,23 @@ func (s *Server) sendVersion(p Peer) error { s.chain.BlockHeight(), s.Relay, ) - return p.WriteMsg(NewMessage(s.Net, CMDVersion, payload)) + return p.SendVersion(NewMessage(s.Net, CMDVersion, payload)) } // When a peer sends out his version we reply with verack after validating // the version. func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error { - if p.NetAddr().Port != int(version.Port) { - return errPortMismatch + err := p.HandleVersion(version) + if err != nil { + return err } if s.id == version.Nonce { return errIdenticalID } - p.SetVersion(version) - return p.WriteMsg(NewMessage(s.Net, CMDVerack, nil)) + if p.NetAddr().Port != int(version.Port) { + return errPortMismatch + } + return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil)) } // handleHeadersCmd will process the headers it received from its peer. @@ -312,29 +315,37 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { return errInvalidNetwork } - switch msg.CommandType() { - case CMDAddr: - addrs := msg.Payload.(*payload.AddressList) - return s.handleAddrCmd(peer, addrs) - case CMDVersion: - version := msg.Payload.(*payload.Version) - return s.handleVersionCmd(peer, version) - case CMDHeaders: - headers := msg.Payload.(*payload.Headers) - go s.handleHeadersCmd(peer, headers) - case CMDInv: - inventory := msg.Payload.(*payload.Inventory) - return s.handleInvCmd(peer, inventory) - case CMDBlock: - block := msg.Payload.(*core.Block) - return s.handleBlockCmd(peer, block) - case CMDVerack: - // Make sure this peer has send his version before we start the - // protocol with that peer. - if peer.Version() == nil { - return errInvalidHandshake + if peer.Handshaked() { + switch msg.CommandType() { + case CMDAddr: + addrs := msg.Payload.(*payload.AddressList) + return s.handleAddrCmd(peer, addrs) + case CMDHeaders: + headers := msg.Payload.(*payload.Headers) + go s.handleHeadersCmd(peer, headers) + case CMDInv: + inventory := msg.Payload.(*payload.Inventory) + return s.handleInvCmd(peer, inventory) + case CMDBlock: + block := msg.Payload.(*core.Block) + return s.handleBlockCmd(peer, block) + case CMDVersion, CMDVerack: + return fmt.Errorf("received '%s' after the handshake", msg.CommandType()) + } + } else { + switch msg.CommandType() { + case CMDVersion: + version := msg.Payload.(*payload.Version) + return s.handleVersionCmd(peer, version) + case CMDVerack: + err := peer.HandleVersionAck() + if err != nil { + return err + } + go s.startProtocol(peer) + default: + return fmt.Errorf("received '%s' during handshake", msg.CommandType()) } - go s.startProtocol(peer) } return nil } diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 7703e7bd2..2ebc12078 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -1,12 +1,29 @@ package network import ( + "errors" + "fmt" "net" "sync" "github.com/CityOfZion/neo-go/pkg/network/payload" ) +type handShakeStage uint8 + +//go:generate stringer -type=handShakeStage +const ( + nothingDone handShakeStage = 0 + versionSent handShakeStage = 1 + versionReceived handShakeStage = 2 + verAckSent handShakeStage = 3 + verAckReceived handShakeStage = 4 +) + +var ( + errStateMismatch = errors.New("tried to send protocol message before handshake completed") +) + // TCPPeer represents a connected remote node in the // network over TCP. type TCPPeer struct { @@ -17,6 +34,8 @@ type TCPPeer struct { // The version of the peer. version *payload.Version + handShake handShakeStage + done chan error wg sync.WaitGroup @@ -35,8 +54,16 @@ func NewTCPPeer(conn net.Conn) *TCPPeer { } // WriteMsg implements the Peer interface. This will write/encode the message -// to the underlying connection. +// to the underlying connection, this only works for messages other than Version +// or VerAck. func (p *TCPPeer) WriteMsg(msg *Message) error { + if !p.Handshaked() { + return errStateMismatch + } + return p.writeMsg(msg) +} + +func (p *TCPPeer) writeMsg(msg *Message) error { select { case err := <-p.done: return err @@ -45,6 +72,55 @@ func (p *TCPPeer) WriteMsg(msg *Message) error { } } +// Handshaked returns status of the handshake, whether it's completed or not. +func (p *TCPPeer) Handshaked() bool { + return p.handShake == verAckReceived +} + +// SendVersion checks for the handshake state and sends a message to the peer. +func (p *TCPPeer) SendVersion(msg *Message) error { + if p.handShake != nothingDone { + return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String()) + } + err := p.writeMsg(msg) + if err == nil { + p.handShake = versionSent + } + return err +} + +// HandleVersion checks for the handshake state and version message contents. +func (p *TCPPeer) HandleVersion(version *payload.Version) error { + if p.handShake != versionSent { + return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String()) + } + p.version = version + p.handShake = versionReceived + return nil +} + +// SendVersionAck checks for the handshake state and sends a message to the peer. +func (p *TCPPeer) SendVersionAck(msg *Message) error { + if p.handShake != versionReceived { + return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String()) + } + err := p.writeMsg(msg) + if err == nil { + p.handShake = verAckSent + } + return err +} + +// HandleVersionAck checks handshake sequence correctness when VerAck message +// is received. +func (p *TCPPeer) HandleVersionAck() error { + if p.handShake != verAckSent { + return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String()) + } + p.handShake = verAckReceived + return nil +} + // NetAddr implements the Peer interface. func (p *TCPPeer) NetAddr() *net.TCPAddr { return &p.addr @@ -67,8 +143,3 @@ func (p *TCPPeer) Disconnect(err error) { func (p *TCPPeer) Version() *payload.Version { return p.version } - -// SetVersion implements the Peer interface. -func (p *TCPPeer) SetVersion(v *payload.Version) { - p.version = v -}