diff --git a/pkg/network/server.go b/pkg/network/server.go index 02e26bb4a..10b9f6e43 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -26,6 +26,7 @@ const ( ) var ( + errAlreadyConnected = errors.New("already connected") errIdenticalID = errors.New("identical node id") errInvalidHandshake = errors.New("invalid handshake") errInvalidNetwork = errors.New("invalid network") @@ -272,6 +273,16 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error { if s.id == version.Nonce { return errIdenticalID } + peerAddr := p.PeerAddr().String() + s.lock.RLock() + for peer := range s.peers { + // Already connected, drop this connection. + if peer.Handshaked() && peer.PeerAddr().String() == peerAddr && peer.Version().Nonce == version.Nonce { + s.lock.RUnlock() + return errAlreadyConnected + } + } + s.lock.RUnlock() return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil)) } diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index b24205460..70101d89d 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -6,6 +6,7 @@ import ( "github.com/CityOfZion/neo-go/pkg/network/payload" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSendVersion(t *testing.T) { @@ -57,14 +58,16 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) { // invalid version and disconnects the peer. func TestServerNotSendsVerack(t *testing.T) { var ( - s = newTestServer() - p = newLocalPeer(t) + s = newTestServer() + p = newLocalPeer(t) + p2 = newLocalPeer(t) ) s.id = 1 go s.run() na, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:3000") p.netaddr = *na + p2.netaddr = *na s.register <- p // identical id's @@ -72,6 +75,18 @@ func TestServerNotSendsVerack(t *testing.T) { err := s.handleVersionCmd(p, version) assert.NotNil(t, err) assert.Equal(t, errIdenticalID, err) + + // Different IDs, make handshake pass. + version.Nonce = 2 + require.NoError(t, s.handleVersionCmd(p, version)) + require.NoError(t, p.HandleVersionAck()) + require.Equal(t, true, p.Handshaked()) + + // Second handshake from the same peer should fail. + s.register <- p2 + err = s.handleVersionCmd(p2, version) + assert.NotNil(t, err) + require.Equal(t, errAlreadyConnected, err) } func TestRequestHeaders(t *testing.T) {