From ec76ed23a56473286fe78e972ca94f01b09da15e Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 6 Nov 2019 11:06:00 +0300 Subject: [PATCH] network: rework peer handshaking, fix #458 This allows to start handshaking from both client and server (mainnet/testnet nodes were seen to not care about string ordering for it), but still maintains some sane checks in the process. It also makes functions thread-safe because we have two goroutines servicing read and write side of the Peer connection, so they can clash on access to the struct fields. Add a test for it also. --- pkg/network/handshakestage_string.go | 27 --------- pkg/network/tcp_peer.go | 53 +++++++++++------- pkg/network/tcp_peer_test.go | 83 ++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 46 deletions(-) delete mode 100644 pkg/network/handshakestage_string.go create mode 100644 pkg/network/tcp_peer_test.go diff --git a/pkg/network/handshakestage_string.go b/pkg/network/handshakestage_string.go deleted file mode 100644 index 40cac11bb..000000000 --- a/pkg/network/handshakestage_string.go +++ /dev/null @@ -1,27 +0,0 @@ -// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT. - -package network - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[nothingDone-0] - _ = x[versionSent-1] - _ = x[versionReceived-2] - _ = x[verAckSent-3] - _ = x[verAckReceived-4] -} - -const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived" - -var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61} - -func (i handShakeStage) String() string { - if i >= handShakeStage(len(_handShakeStage_index)-1) { - return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")" - } - return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]] -} diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index b1d613f93..e1e0923a5 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -12,13 +12,11 @@ import ( type handShakeStage uint8 -//go:generate stringer -type=handShakeStage const ( - nothingDone handShakeStage = 0 - versionSent handShakeStage = 1 - versionReceived handShakeStage = 2 - verAckSent handShakeStage = 3 - verAckReceived handShakeStage = 4 + versionSent handShakeStage = 1 << iota + versionReceived + verAckSent + verAckReceived ) var ( @@ -34,6 +32,7 @@ type TCPPeer struct { // The version of the peer. version *payload.Version + lock sync.RWMutex handShake handShakeStage done chan error @@ -71,39 +70,50 @@ func (p *TCPPeer) writeMsg(msg *Message) error { // Handshaked returns status of the handshake, whether it's completed or not. func (p *TCPPeer) Handshaked() bool { - return p.handShake == verAckReceived + p.lock.RLock() + defer p.lock.RUnlock() + return p.handShake == (verAckReceived | verAckSent | versionReceived | versionSent) } // SendVersion checks for the handshake state and sends a message to the peer. func (p *TCPPeer) SendVersion(msg *Message) error { - if p.handShake != nothingDone { - return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String()) + p.lock.Lock() + defer p.lock.Unlock() + if p.handShake&versionSent != 0 { + return errors.New("invalid handshake: already sent Version") } err := p.writeMsg(msg) if err == nil { - p.handShake = versionSent + p.handShake |= versionSent } return err } // HandleVersion checks for the handshake state and version message contents. func (p *TCPPeer) HandleVersion(version *payload.Version) error { - if p.handShake != versionSent { - return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String()) + p.lock.Lock() + defer p.lock.Unlock() + if p.handShake&versionReceived != 0 { + return errors.New("invalid handshake: already received Version") } p.version = version - p.handShake = versionReceived + p.handShake |= versionReceived return nil } // SendVersionAck checks for the handshake state and sends a message to the peer. func (p *TCPPeer) SendVersionAck(msg *Message) error { - if p.handShake != versionReceived { - return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String()) + p.lock.Lock() + defer p.lock.Unlock() + if p.handShake&versionReceived == 0 { + return errors.New("invalid handshake: tried to send VersionAck, but no version received yet") + } + if p.handShake&verAckSent != 0 { + return errors.New("invalid handshake: already sent VersionAck") } err := p.writeMsg(msg) if err == nil { - p.handShake = verAckSent + p.handShake |= verAckSent } return err } @@ -111,10 +121,15 @@ func (p *TCPPeer) SendVersionAck(msg *Message) error { // HandleVersionAck checks handshake sequence correctness when VerAck message // is received. func (p *TCPPeer) HandleVersionAck() error { - if p.handShake != verAckSent { - return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String()) + p.lock.Lock() + defer p.lock.Unlock() + if p.handShake&versionSent == 0 { + return errors.New("invalid handshake: received VersionAck, but no version sent yet") } - p.handShake = verAckReceived + if p.handShake&verAckReceived != 0 { + return errors.New("invalid handshake: already received VersionAck") + } + p.handShake |= verAckReceived return nil } diff --git a/pkg/network/tcp_peer_test.go b/pkg/network/tcp_peer_test.go new file mode 100644 index 000000000..b4e1bb3c0 --- /dev/null +++ b/pkg/network/tcp_peer_test.go @@ -0,0 +1,83 @@ +package network + +import ( + "net" + "testing" + + "github.com/CityOfZion/neo-go/pkg/network/payload" + "github.com/stretchr/testify/require" +) + +func connReadStub(conn net.Conn) { + b := make([]byte, 1024) + var err error + for ; err == nil; _, err = conn.Read(b) { + } +} + +func TestPeerHandshake(t *testing.T) { + server, client := net.Pipe() + + tcpS := NewTCPPeer(server) + tcpC := NewTCPPeer(client) + + // Something should read things written into the pipe. + go connReadStub(tcpS.conn) + go connReadStub(tcpC.conn) + + // No handshake yet. + require.Equal(t, false, tcpS.Handshaked()) + require.Equal(t, false, tcpC.Handshaked()) + + // No ordinary messages can be written. + require.Error(t, tcpS.WriteMsg(&Message{})) + require.Error(t, tcpC.WriteMsg(&Message{})) + + // Try to mess with VersionAck on both client and server, it should fail. + require.Error(t, tcpS.SendVersionAck(&Message{})) + require.Error(t, tcpS.HandleVersionAck()) + require.Error(t, tcpC.SendVersionAck(&Message{})) + require.Error(t, tcpC.HandleVersionAck()) + + // No handshake yet. + require.Equal(t, false, tcpS.Handshaked()) + require.Equal(t, false, tcpC.Handshaked()) + + // Now send and handle versions, but in a different order on client and + // server. + require.NoError(t, tcpC.SendVersion(&Message{})) + require.NoError(t, tcpS.HandleVersion(&payload.Version{})) + require.NoError(t, tcpC.HandleVersion(&payload.Version{})) + require.NoError(t, tcpS.SendVersion(&Message{})) + + // No handshake yet. + require.Equal(t, false, tcpS.Handshaked()) + require.Equal(t, false, tcpC.Handshaked()) + + // These are sent/received and should fail now. + require.Error(t, tcpC.SendVersion(&Message{})) + require.Error(t, tcpS.HandleVersion(&payload.Version{})) + require.Error(t, tcpC.HandleVersion(&payload.Version{})) + require.Error(t, tcpS.SendVersion(&Message{})) + + // Now send and handle ACK, again in a different order on client and + // server. + require.NoError(t, tcpC.SendVersionAck(&Message{})) + require.NoError(t, tcpS.HandleVersionAck()) + require.NoError(t, tcpC.HandleVersionAck()) + require.NoError(t, tcpS.SendVersionAck(&Message{})) + + // Handshaked now. + require.Equal(t, true, tcpS.Handshaked()) + require.Equal(t, true, tcpC.Handshaked()) + + // Subsequent ACKing should fail. + require.Error(t, tcpC.SendVersionAck(&Message{})) + require.Error(t, tcpS.HandleVersionAck()) + require.Error(t, tcpC.HandleVersionAck()) + require.Error(t, tcpS.SendVersionAck(&Message{})) + + // Now regular messaging can proceed. + require.NoError(t, tcpS.WriteMsg(&Message{})) + require.NoError(t, tcpC.WriteMsg(&Message{})) +}