mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-12-01 23:45:50 +00:00
network: rework peer handshaking, fix #458
This allows to start handshaking from both client and server (mainnet/testnet nodes were seen to not care about string ordering for it), but still maintains some sane checks in the process. It also makes functions thread-safe because we have two goroutines servicing read and write side of the Peer connection, so they can clash on access to the struct fields. Add a test for it also.
This commit is contained in:
parent
e859e03240
commit
ec76ed23a5
3 changed files with 117 additions and 46 deletions
|
@ -1,27 +0,0 @@
|
||||||
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
|
|
||||||
|
|
||||||
package network
|
|
||||||
|
|
||||||
import "strconv"
|
|
||||||
|
|
||||||
func _() {
|
|
||||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
|
||||||
// Re-run the stringer command to generate them again.
|
|
||||||
var x [1]struct{}
|
|
||||||
_ = x[nothingDone-0]
|
|
||||||
_ = x[versionSent-1]
|
|
||||||
_ = x[versionReceived-2]
|
|
||||||
_ = x[verAckSent-3]
|
|
||||||
_ = x[verAckReceived-4]
|
|
||||||
}
|
|
||||||
|
|
||||||
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
|
|
||||||
|
|
||||||
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
|
|
||||||
|
|
||||||
func (i handShakeStage) String() string {
|
|
||||||
if i >= handShakeStage(len(_handShakeStage_index)-1) {
|
|
||||||
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
|
|
||||||
}
|
|
||||||
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
|
|
||||||
}
|
|
|
@ -12,13 +12,11 @@ import (
|
||||||
|
|
||||||
type handShakeStage uint8
|
type handShakeStage uint8
|
||||||
|
|
||||||
//go:generate stringer -type=handShakeStage
|
|
||||||
const (
|
const (
|
||||||
nothingDone handShakeStage = 0
|
versionSent handShakeStage = 1 << iota
|
||||||
versionSent handShakeStage = 1
|
versionReceived
|
||||||
versionReceived handShakeStage = 2
|
verAckSent
|
||||||
verAckSent handShakeStage = 3
|
verAckReceived
|
||||||
verAckReceived handShakeStage = 4
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -34,6 +32,7 @@ type TCPPeer struct {
|
||||||
// The version of the peer.
|
// The version of the peer.
|
||||||
version *payload.Version
|
version *payload.Version
|
||||||
|
|
||||||
|
lock sync.RWMutex
|
||||||
handShake handShakeStage
|
handShake handShakeStage
|
||||||
|
|
||||||
done chan error
|
done chan error
|
||||||
|
@ -71,39 +70,50 @@ func (p *TCPPeer) writeMsg(msg *Message) error {
|
||||||
|
|
||||||
// Handshaked returns status of the handshake, whether it's completed or not.
|
// Handshaked returns status of the handshake, whether it's completed or not.
|
||||||
func (p *TCPPeer) Handshaked() bool {
|
func (p *TCPPeer) Handshaked() bool {
|
||||||
return p.handShake == verAckReceived
|
p.lock.RLock()
|
||||||
|
defer p.lock.RUnlock()
|
||||||
|
return p.handShake == (verAckReceived | verAckSent | versionReceived | versionSent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendVersion checks for the handshake state and sends a message to the peer.
|
// SendVersion checks for the handshake state and sends a message to the peer.
|
||||||
func (p *TCPPeer) SendVersion(msg *Message) error {
|
func (p *TCPPeer) SendVersion(msg *Message) error {
|
||||||
if p.handShake != nothingDone {
|
p.lock.Lock()
|
||||||
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
|
defer p.lock.Unlock()
|
||||||
|
if p.handShake&versionSent != 0 {
|
||||||
|
return errors.New("invalid handshake: already sent Version")
|
||||||
}
|
}
|
||||||
err := p.writeMsg(msg)
|
err := p.writeMsg(msg)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
p.handShake = versionSent
|
p.handShake |= versionSent
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleVersion checks for the handshake state and version message contents.
|
// HandleVersion checks for the handshake state and version message contents.
|
||||||
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
|
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
|
||||||
if p.handShake != versionSent {
|
p.lock.Lock()
|
||||||
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
|
defer p.lock.Unlock()
|
||||||
|
if p.handShake&versionReceived != 0 {
|
||||||
|
return errors.New("invalid handshake: already received Version")
|
||||||
}
|
}
|
||||||
p.version = version
|
p.version = version
|
||||||
p.handShake = versionReceived
|
p.handShake |= versionReceived
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendVersionAck checks for the handshake state and sends a message to the peer.
|
// SendVersionAck checks for the handshake state and sends a message to the peer.
|
||||||
func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
||||||
if p.handShake != versionReceived {
|
p.lock.Lock()
|
||||||
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
|
defer p.lock.Unlock()
|
||||||
|
if p.handShake&versionReceived == 0 {
|
||||||
|
return errors.New("invalid handshake: tried to send VersionAck, but no version received yet")
|
||||||
|
}
|
||||||
|
if p.handShake&verAckSent != 0 {
|
||||||
|
return errors.New("invalid handshake: already sent VersionAck")
|
||||||
}
|
}
|
||||||
err := p.writeMsg(msg)
|
err := p.writeMsg(msg)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
p.handShake = verAckSent
|
p.handShake |= verAckSent
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -111,10 +121,15 @@ func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
||||||
// HandleVersionAck checks handshake sequence correctness when VerAck message
|
// HandleVersionAck checks handshake sequence correctness when VerAck message
|
||||||
// is received.
|
// is received.
|
||||||
func (p *TCPPeer) HandleVersionAck() error {
|
func (p *TCPPeer) HandleVersionAck() error {
|
||||||
if p.handShake != verAckSent {
|
p.lock.Lock()
|
||||||
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
|
defer p.lock.Unlock()
|
||||||
|
if p.handShake&versionSent == 0 {
|
||||||
|
return errors.New("invalid handshake: received VersionAck, but no version sent yet")
|
||||||
}
|
}
|
||||||
p.handShake = verAckReceived
|
if p.handShake&verAckReceived != 0 {
|
||||||
|
return errors.New("invalid handshake: already received VersionAck")
|
||||||
|
}
|
||||||
|
p.handShake |= verAckReceived
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
83
pkg/network/tcp_peer_test.go
Normal file
83
pkg/network/tcp_peer_test.go
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func connReadStub(conn net.Conn) {
|
||||||
|
b := make([]byte, 1024)
|
||||||
|
var err error
|
||||||
|
for ; err == nil; _, err = conn.Read(b) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerHandshake(t *testing.T) {
|
||||||
|
server, client := net.Pipe()
|
||||||
|
|
||||||
|
tcpS := NewTCPPeer(server)
|
||||||
|
tcpC := NewTCPPeer(client)
|
||||||
|
|
||||||
|
// Something should read things written into the pipe.
|
||||||
|
go connReadStub(tcpS.conn)
|
||||||
|
go connReadStub(tcpC.conn)
|
||||||
|
|
||||||
|
// No handshake yet.
|
||||||
|
require.Equal(t, false, tcpS.Handshaked())
|
||||||
|
require.Equal(t, false, tcpC.Handshaked())
|
||||||
|
|
||||||
|
// No ordinary messages can be written.
|
||||||
|
require.Error(t, tcpS.WriteMsg(&Message{}))
|
||||||
|
require.Error(t, tcpC.WriteMsg(&Message{}))
|
||||||
|
|
||||||
|
// Try to mess with VersionAck on both client and server, it should fail.
|
||||||
|
require.Error(t, tcpS.SendVersionAck(&Message{}))
|
||||||
|
require.Error(t, tcpS.HandleVersionAck())
|
||||||
|
require.Error(t, tcpC.SendVersionAck(&Message{}))
|
||||||
|
require.Error(t, tcpC.HandleVersionAck())
|
||||||
|
|
||||||
|
// No handshake yet.
|
||||||
|
require.Equal(t, false, tcpS.Handshaked())
|
||||||
|
require.Equal(t, false, tcpC.Handshaked())
|
||||||
|
|
||||||
|
// Now send and handle versions, but in a different order on client and
|
||||||
|
// server.
|
||||||
|
require.NoError(t, tcpC.SendVersion(&Message{}))
|
||||||
|
require.NoError(t, tcpS.HandleVersion(&payload.Version{}))
|
||||||
|
require.NoError(t, tcpC.HandleVersion(&payload.Version{}))
|
||||||
|
require.NoError(t, tcpS.SendVersion(&Message{}))
|
||||||
|
|
||||||
|
// No handshake yet.
|
||||||
|
require.Equal(t, false, tcpS.Handshaked())
|
||||||
|
require.Equal(t, false, tcpC.Handshaked())
|
||||||
|
|
||||||
|
// These are sent/received and should fail now.
|
||||||
|
require.Error(t, tcpC.SendVersion(&Message{}))
|
||||||
|
require.Error(t, tcpS.HandleVersion(&payload.Version{}))
|
||||||
|
require.Error(t, tcpC.HandleVersion(&payload.Version{}))
|
||||||
|
require.Error(t, tcpS.SendVersion(&Message{}))
|
||||||
|
|
||||||
|
// Now send and handle ACK, again in a different order on client and
|
||||||
|
// server.
|
||||||
|
require.NoError(t, tcpC.SendVersionAck(&Message{}))
|
||||||
|
require.NoError(t, tcpS.HandleVersionAck())
|
||||||
|
require.NoError(t, tcpC.HandleVersionAck())
|
||||||
|
require.NoError(t, tcpS.SendVersionAck(&Message{}))
|
||||||
|
|
||||||
|
// Handshaked now.
|
||||||
|
require.Equal(t, true, tcpS.Handshaked())
|
||||||
|
require.Equal(t, true, tcpC.Handshaked())
|
||||||
|
|
||||||
|
// Subsequent ACKing should fail.
|
||||||
|
require.Error(t, tcpC.SendVersionAck(&Message{}))
|
||||||
|
require.Error(t, tcpS.HandleVersionAck())
|
||||||
|
require.Error(t, tcpC.HandleVersionAck())
|
||||||
|
require.Error(t, tcpS.SendVersionAck(&Message{}))
|
||||||
|
|
||||||
|
// Now regular messaging can proceed.
|
||||||
|
require.NoError(t, tcpS.WriteMsg(&Message{}))
|
||||||
|
require.NoError(t, tcpC.WriteMsg(&Message{}))
|
||||||
|
}
|
Loading…
Reference in a new issue