mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-25 23:42:23 +00:00
network: rework peer handshaking, fix #458
This allows to start handshaking from both client and server (mainnet/testnet nodes were seen to not care about string ordering for it), but still maintains some sane checks in the process. It also makes functions thread-safe because we have two goroutines servicing read and write side of the Peer connection, so they can clash on access to the struct fields. Add a test for it also.
This commit is contained in:
parent
e859e03240
commit
ec76ed23a5
3 changed files with 117 additions and 46 deletions
|
@ -1,27 +0,0 @@
|
|||
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
|
||||
|
||||
package network
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[nothingDone-0]
|
||||
_ = x[versionSent-1]
|
||||
_ = x[versionReceived-2]
|
||||
_ = x[verAckSent-3]
|
||||
_ = x[verAckReceived-4]
|
||||
}
|
||||
|
||||
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
|
||||
|
||||
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
|
||||
|
||||
func (i handShakeStage) String() string {
|
||||
if i >= handShakeStage(len(_handShakeStage_index)-1) {
|
||||
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
|
||||
}
|
|
@ -12,13 +12,11 @@ import (
|
|||
|
||||
type handShakeStage uint8
|
||||
|
||||
//go:generate stringer -type=handShakeStage
|
||||
const (
|
||||
nothingDone handShakeStage = 0
|
||||
versionSent handShakeStage = 1
|
||||
versionReceived handShakeStage = 2
|
||||
verAckSent handShakeStage = 3
|
||||
verAckReceived handShakeStage = 4
|
||||
versionSent handShakeStage = 1 << iota
|
||||
versionReceived
|
||||
verAckSent
|
||||
verAckReceived
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -34,6 +32,7 @@ type TCPPeer struct {
|
|||
// The version of the peer.
|
||||
version *payload.Version
|
||||
|
||||
lock sync.RWMutex
|
||||
handShake handShakeStage
|
||||
|
||||
done chan error
|
||||
|
@ -71,39 +70,50 @@ func (p *TCPPeer) writeMsg(msg *Message) error {
|
|||
|
||||
// Handshaked returns status of the handshake, whether it's completed or not.
|
||||
func (p *TCPPeer) Handshaked() bool {
|
||||
return p.handShake == verAckReceived
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
return p.handShake == (verAckReceived | verAckSent | versionReceived | versionSent)
|
||||
}
|
||||
|
||||
// SendVersion checks for the handshake state and sends a message to the peer.
|
||||
func (p *TCPPeer) SendVersion(msg *Message) error {
|
||||
if p.handShake != nothingDone {
|
||||
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
if p.handShake&versionSent != 0 {
|
||||
return errors.New("invalid handshake: already sent Version")
|
||||
}
|
||||
err := p.writeMsg(msg)
|
||||
if err == nil {
|
||||
p.handShake = versionSent
|
||||
p.handShake |= versionSent
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleVersion checks for the handshake state and version message contents.
|
||||
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
|
||||
if p.handShake != versionSent {
|
||||
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
if p.handShake&versionReceived != 0 {
|
||||
return errors.New("invalid handshake: already received Version")
|
||||
}
|
||||
p.version = version
|
||||
p.handShake = versionReceived
|
||||
p.handShake |= versionReceived
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendVersionAck checks for the handshake state and sends a message to the peer.
|
||||
func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
||||
if p.handShake != versionReceived {
|
||||
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
if p.handShake&versionReceived == 0 {
|
||||
return errors.New("invalid handshake: tried to send VersionAck, but no version received yet")
|
||||
}
|
||||
if p.handShake&verAckSent != 0 {
|
||||
return errors.New("invalid handshake: already sent VersionAck")
|
||||
}
|
||||
err := p.writeMsg(msg)
|
||||
if err == nil {
|
||||
p.handShake = verAckSent
|
||||
p.handShake |= verAckSent
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -111,10 +121,15 @@ func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
|||
// HandleVersionAck checks handshake sequence correctness when VerAck message
|
||||
// is received.
|
||||
func (p *TCPPeer) HandleVersionAck() error {
|
||||
if p.handShake != verAckSent {
|
||||
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
if p.handShake&versionSent == 0 {
|
||||
return errors.New("invalid handshake: received VersionAck, but no version sent yet")
|
||||
}
|
||||
p.handShake = verAckReceived
|
||||
if p.handShake&verAckReceived != 0 {
|
||||
return errors.New("invalid handshake: already received VersionAck")
|
||||
}
|
||||
p.handShake |= verAckReceived
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
83
pkg/network/tcp_peer_test.go
Normal file
83
pkg/network/tcp_peer_test.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func connReadStub(conn net.Conn) {
|
||||
b := make([]byte, 1024)
|
||||
var err error
|
||||
for ; err == nil; _, err = conn.Read(b) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerHandshake(t *testing.T) {
|
||||
server, client := net.Pipe()
|
||||
|
||||
tcpS := NewTCPPeer(server)
|
||||
tcpC := NewTCPPeer(client)
|
||||
|
||||
// Something should read things written into the pipe.
|
||||
go connReadStub(tcpS.conn)
|
||||
go connReadStub(tcpC.conn)
|
||||
|
||||
// No handshake yet.
|
||||
require.Equal(t, false, tcpS.Handshaked())
|
||||
require.Equal(t, false, tcpC.Handshaked())
|
||||
|
||||
// No ordinary messages can be written.
|
||||
require.Error(t, tcpS.WriteMsg(&Message{}))
|
||||
require.Error(t, tcpC.WriteMsg(&Message{}))
|
||||
|
||||
// Try to mess with VersionAck on both client and server, it should fail.
|
||||
require.Error(t, tcpS.SendVersionAck(&Message{}))
|
||||
require.Error(t, tcpS.HandleVersionAck())
|
||||
require.Error(t, tcpC.SendVersionAck(&Message{}))
|
||||
require.Error(t, tcpC.HandleVersionAck())
|
||||
|
||||
// No handshake yet.
|
||||
require.Equal(t, false, tcpS.Handshaked())
|
||||
require.Equal(t, false, tcpC.Handshaked())
|
||||
|
||||
// Now send and handle versions, but in a different order on client and
|
||||
// server.
|
||||
require.NoError(t, tcpC.SendVersion(&Message{}))
|
||||
require.NoError(t, tcpS.HandleVersion(&payload.Version{}))
|
||||
require.NoError(t, tcpC.HandleVersion(&payload.Version{}))
|
||||
require.NoError(t, tcpS.SendVersion(&Message{}))
|
||||
|
||||
// No handshake yet.
|
||||
require.Equal(t, false, tcpS.Handshaked())
|
||||
require.Equal(t, false, tcpC.Handshaked())
|
||||
|
||||
// These are sent/received and should fail now.
|
||||
require.Error(t, tcpC.SendVersion(&Message{}))
|
||||
require.Error(t, tcpS.HandleVersion(&payload.Version{}))
|
||||
require.Error(t, tcpC.HandleVersion(&payload.Version{}))
|
||||
require.Error(t, tcpS.SendVersion(&Message{}))
|
||||
|
||||
// Now send and handle ACK, again in a different order on client and
|
||||
// server.
|
||||
require.NoError(t, tcpC.SendVersionAck(&Message{}))
|
||||
require.NoError(t, tcpS.HandleVersionAck())
|
||||
require.NoError(t, tcpC.HandleVersionAck())
|
||||
require.NoError(t, tcpS.SendVersionAck(&Message{}))
|
||||
|
||||
// Handshaked now.
|
||||
require.Equal(t, true, tcpS.Handshaked())
|
||||
require.Equal(t, true, tcpC.Handshaked())
|
||||
|
||||
// Subsequent ACKing should fail.
|
||||
require.Error(t, tcpC.SendVersionAck(&Message{}))
|
||||
require.Error(t, tcpS.HandleVersionAck())
|
||||
require.Error(t, tcpC.HandleVersionAck())
|
||||
require.Error(t, tcpS.SendVersionAck(&Message{}))
|
||||
|
||||
// Now regular messaging can proceed.
|
||||
require.NoError(t, tcpS.WriteMsg(&Message{}))
|
||||
require.NoError(t, tcpC.WriteMsg(&Message{}))
|
||||
}
|
Loading…
Reference in a new issue