network: rework peer handshaking, fix #458

This allows to start handshaking from both client and server (mainnet/testnet
nodes were seen to not care about string ordering for it), but still maintains
some sane checks in the process. It also makes functions thread-safe because
we have two goroutines servicing read and write side of the Peer connection,
so they can clash on access to the struct fields.

Add a test for it also.
This commit is contained in:
Roman Khimov 2019-11-06 11:06:00 +03:00
parent e859e03240
commit ec76ed23a5
3 changed files with 117 additions and 46 deletions

View file

@ -1,27 +0,0 @@
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
package network
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[nothingDone-0]
_ = x[versionSent-1]
_ = x[versionReceived-2]
_ = x[verAckSent-3]
_ = x[verAckReceived-4]
}
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
func (i handShakeStage) String() string {
if i >= handShakeStage(len(_handShakeStage_index)-1) {
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
}

View file

@ -12,13 +12,11 @@ import (
type handShakeStage uint8 type handShakeStage uint8
//go:generate stringer -type=handShakeStage
const ( const (
nothingDone handShakeStage = 0 versionSent handShakeStage = 1 << iota
versionSent handShakeStage = 1 versionReceived
versionReceived handShakeStage = 2 verAckSent
verAckSent handShakeStage = 3 verAckReceived
verAckReceived handShakeStage = 4
) )
var ( var (
@ -34,6 +32,7 @@ type TCPPeer struct {
// The version of the peer. // The version of the peer.
version *payload.Version version *payload.Version
lock sync.RWMutex
handShake handShakeStage handShake handShakeStage
done chan error done chan error
@ -71,39 +70,50 @@ func (p *TCPPeer) writeMsg(msg *Message) error {
// Handshaked returns status of the handshake, whether it's completed or not. // Handshaked returns status of the handshake, whether it's completed or not.
func (p *TCPPeer) Handshaked() bool { func (p *TCPPeer) Handshaked() bool {
return p.handShake == verAckReceived p.lock.RLock()
defer p.lock.RUnlock()
return p.handShake == (verAckReceived | verAckSent | versionReceived | versionSent)
} }
// SendVersion checks for the handshake state and sends a message to the peer. // SendVersion checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersion(msg *Message) error { func (p *TCPPeer) SendVersion(msg *Message) error {
if p.handShake != nothingDone { p.lock.Lock()
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String()) defer p.lock.Unlock()
if p.handShake&versionSent != 0 {
return errors.New("invalid handshake: already sent Version")
} }
err := p.writeMsg(msg) err := p.writeMsg(msg)
if err == nil { if err == nil {
p.handShake = versionSent p.handShake |= versionSent
} }
return err return err
} }
// HandleVersion checks for the handshake state and version message contents. // HandleVersion checks for the handshake state and version message contents.
func (p *TCPPeer) HandleVersion(version *payload.Version) error { func (p *TCPPeer) HandleVersion(version *payload.Version) error {
if p.handShake != versionSent { p.lock.Lock()
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String()) defer p.lock.Unlock()
if p.handShake&versionReceived != 0 {
return errors.New("invalid handshake: already received Version")
} }
p.version = version p.version = version
p.handShake = versionReceived p.handShake |= versionReceived
return nil return nil
} }
// SendVersionAck checks for the handshake state and sends a message to the peer. // SendVersionAck checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersionAck(msg *Message) error { func (p *TCPPeer) SendVersionAck(msg *Message) error {
if p.handShake != versionReceived { p.lock.Lock()
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String()) defer p.lock.Unlock()
if p.handShake&versionReceived == 0 {
return errors.New("invalid handshake: tried to send VersionAck, but no version received yet")
}
if p.handShake&verAckSent != 0 {
return errors.New("invalid handshake: already sent VersionAck")
} }
err := p.writeMsg(msg) err := p.writeMsg(msg)
if err == nil { if err == nil {
p.handShake = verAckSent p.handShake |= verAckSent
} }
return err return err
} }
@ -111,10 +121,15 @@ func (p *TCPPeer) SendVersionAck(msg *Message) error {
// HandleVersionAck checks handshake sequence correctness when VerAck message // HandleVersionAck checks handshake sequence correctness when VerAck message
// is received. // is received.
func (p *TCPPeer) HandleVersionAck() error { func (p *TCPPeer) HandleVersionAck() error {
if p.handShake != verAckSent { p.lock.Lock()
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String()) defer p.lock.Unlock()
if p.handShake&versionSent == 0 {
return errors.New("invalid handshake: received VersionAck, but no version sent yet")
} }
p.handShake = verAckReceived if p.handShake&verAckReceived != 0 {
return errors.New("invalid handshake: already received VersionAck")
}
p.handShake |= verAckReceived
return nil return nil
} }

View file

@ -0,0 +1,83 @@
package network
import (
"net"
"testing"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/stretchr/testify/require"
)
func connReadStub(conn net.Conn) {
b := make([]byte, 1024)
var err error
for ; err == nil; _, err = conn.Read(b) {
}
}
func TestPeerHandshake(t *testing.T) {
server, client := net.Pipe()
tcpS := NewTCPPeer(server)
tcpC := NewTCPPeer(client)
// Something should read things written into the pipe.
go connReadStub(tcpS.conn)
go connReadStub(tcpC.conn)
// No handshake yet.
require.Equal(t, false, tcpS.Handshaked())
require.Equal(t, false, tcpC.Handshaked())
// No ordinary messages can be written.
require.Error(t, tcpS.WriteMsg(&Message{}))
require.Error(t, tcpC.WriteMsg(&Message{}))
// Try to mess with VersionAck on both client and server, it should fail.
require.Error(t, tcpS.SendVersionAck(&Message{}))
require.Error(t, tcpS.HandleVersionAck())
require.Error(t, tcpC.SendVersionAck(&Message{}))
require.Error(t, tcpC.HandleVersionAck())
// No handshake yet.
require.Equal(t, false, tcpS.Handshaked())
require.Equal(t, false, tcpC.Handshaked())
// Now send and handle versions, but in a different order on client and
// server.
require.NoError(t, tcpC.SendVersion(&Message{}))
require.NoError(t, tcpS.HandleVersion(&payload.Version{}))
require.NoError(t, tcpC.HandleVersion(&payload.Version{}))
require.NoError(t, tcpS.SendVersion(&Message{}))
// No handshake yet.
require.Equal(t, false, tcpS.Handshaked())
require.Equal(t, false, tcpC.Handshaked())
// These are sent/received and should fail now.
require.Error(t, tcpC.SendVersion(&Message{}))
require.Error(t, tcpS.HandleVersion(&payload.Version{}))
require.Error(t, tcpC.HandleVersion(&payload.Version{}))
require.Error(t, tcpS.SendVersion(&Message{}))
// Now send and handle ACK, again in a different order on client and
// server.
require.NoError(t, tcpC.SendVersionAck(&Message{}))
require.NoError(t, tcpS.HandleVersionAck())
require.NoError(t, tcpC.HandleVersionAck())
require.NoError(t, tcpS.SendVersionAck(&Message{}))
// Handshaked now.
require.Equal(t, true, tcpS.Handshaked())
require.Equal(t, true, tcpC.Handshaked())
// Subsequent ACKing should fail.
require.Error(t, tcpC.SendVersionAck(&Message{}))
require.Error(t, tcpS.HandleVersionAck())
require.Error(t, tcpC.HandleVersionAck())
require.Error(t, tcpS.SendVersionAck(&Message{}))
// Now regular messaging can proceed.
require.NoError(t, tcpS.WriteMsg(&Message{}))
require.NoError(t, tcpC.WriteMsg(&Message{}))
}