Merge pull request #478 from nspcc-dev/handshake-and-peers-fix
Fixes #458, MaxPeers handling and some other related things.
This commit is contained in:
commit
79d0c7446a
18 changed files with 256 additions and 85 deletions
|
@ -72,6 +72,7 @@ type (
|
||||||
DialTimeout time.Duration `yaml:"DialTimeout"`
|
DialTimeout time.Duration `yaml:"DialTimeout"`
|
||||||
ProtoTickInterval time.Duration `yaml:"ProtoTickInterval"`
|
ProtoTickInterval time.Duration `yaml:"ProtoTickInterval"`
|
||||||
MaxPeers int `yaml:"MaxPeers"`
|
MaxPeers int `yaml:"MaxPeers"`
|
||||||
|
AttemptConnPeers int `yaml:"AttemptConnPeers"`
|
||||||
MinPeers int `yaml:"MinPeers"`
|
MinPeers int `yaml:"MinPeers"`
|
||||||
Monitoring metrics.PrometheusConfig `yaml:"Monitoring"`
|
Monitoring metrics.PrometheusConfig `yaml:"Monitoring"`
|
||||||
RPC RPCConfig `yaml:"RPC"`
|
RPC RPCConfig `yaml:"RPC"`
|
||||||
|
|
|
@ -42,16 +42,17 @@ ApplicationConfiguration:
|
||||||
# DB: 0
|
# DB: 0
|
||||||
# BoltDBOptions:
|
# BoltDBOptions:
|
||||||
# FilePath: "./chains/mainnet.bolt"
|
# FilePath: "./chains/mainnet.bolt"
|
||||||
NodePort: 20333
|
NodePort: 10333
|
||||||
Relay: true
|
Relay: true
|
||||||
DialTimeout: 3
|
DialTimeout: 3
|
||||||
ProtoTickInterval: 2
|
ProtoTickInterval: 2
|
||||||
MaxPeers: 50
|
MaxPeers: 100
|
||||||
|
AttemptConnPeers: 20
|
||||||
MinPeers: 5
|
MinPeers: 5
|
||||||
RPC:
|
RPC:
|
||||||
Enabled: true
|
Enabled: true
|
||||||
EnableCORSWorkaround: false
|
EnableCORSWorkaround: false
|
||||||
Port: 20332
|
Port: 10332
|
||||||
Monitoring:
|
Monitoring:
|
||||||
Enabled: true
|
Enabled: true
|
||||||
Port: 2112
|
Port: 2112
|
||||||
|
|
|
@ -34,7 +34,8 @@ ApplicationConfiguration:
|
||||||
Relay: true
|
Relay: true
|
||||||
DialTimeout: 3
|
DialTimeout: 3
|
||||||
ProtoTickInterval: 2
|
ProtoTickInterval: 2
|
||||||
MaxPeers: 50
|
MaxPeers: 10
|
||||||
|
AttemptConnPeers: 5
|
||||||
MinPeers: 3
|
MinPeers: 3
|
||||||
RPC:
|
RPC:
|
||||||
Enabled: true
|
Enabled: true
|
||||||
|
|
|
@ -31,7 +31,8 @@ ApplicationConfiguration:
|
||||||
Relay: true
|
Relay: true
|
||||||
DialTimeout: 3
|
DialTimeout: 3
|
||||||
ProtoTickInterval: 2
|
ProtoTickInterval: 2
|
||||||
MaxPeers: 50
|
MaxPeers: 10
|
||||||
|
AttemptConnPeers: 5
|
||||||
MinPeers: 3
|
MinPeers: 3
|
||||||
RPC:
|
RPC:
|
||||||
Enabled: true
|
Enabled: true
|
||||||
|
|
|
@ -31,7 +31,8 @@ ApplicationConfiguration:
|
||||||
Relay: true
|
Relay: true
|
||||||
DialTimeout: 3
|
DialTimeout: 3
|
||||||
ProtoTickInterval: 2
|
ProtoTickInterval: 2
|
||||||
MaxPeers: 50
|
MaxPeers: 10
|
||||||
|
AttemptConnPeers: 5
|
||||||
MinPeers: 3
|
MinPeers: 3
|
||||||
RPC:
|
RPC:
|
||||||
Enabled: true
|
Enabled: true
|
||||||
|
|
|
@ -31,7 +31,8 @@ ApplicationConfiguration:
|
||||||
Relay: true
|
Relay: true
|
||||||
DialTimeout: 3
|
DialTimeout: 3
|
||||||
ProtoTickInterval: 2
|
ProtoTickInterval: 2
|
||||||
MaxPeers: 50
|
MaxPeers: 10
|
||||||
|
AttemptConnPeers: 5
|
||||||
MinPeers: 3
|
MinPeers: 3
|
||||||
RPC:
|
RPC:
|
||||||
Enabled: true
|
Enabled: true
|
||||||
|
|
|
@ -37,7 +37,8 @@ ApplicationConfiguration:
|
||||||
Relay: true
|
Relay: true
|
||||||
DialTimeout: 3
|
DialTimeout: 3
|
||||||
ProtoTickInterval: 2
|
ProtoTickInterval: 2
|
||||||
MaxPeers: 50
|
MaxPeers: 10
|
||||||
|
AttemptConnPeers: 5
|
||||||
MinPeers: 3
|
MinPeers: 3
|
||||||
RPC:
|
RPC:
|
||||||
Enabled: true
|
Enabled: true
|
||||||
|
|
|
@ -46,7 +46,8 @@ ApplicationConfiguration:
|
||||||
Relay: true
|
Relay: true
|
||||||
DialTimeout: 3
|
DialTimeout: 3
|
||||||
ProtoTickInterval: 2
|
ProtoTickInterval: 2
|
||||||
MaxPeers: 50
|
MaxPeers: 100
|
||||||
|
AttemptConnPeers: 20
|
||||||
MinPeers: 5
|
MinPeers: 5
|
||||||
RPC:
|
RPC:
|
||||||
Enabled: true
|
Enabled: true
|
||||||
|
|
|
@ -37,6 +37,7 @@ ApplicationConfiguration:
|
||||||
DialTimeout: 3
|
DialTimeout: 3
|
||||||
ProtoTickInterval: 2
|
ProtoTickInterval: 2
|
||||||
MaxPeers: 50
|
MaxPeers: 50
|
||||||
|
AttemptConnPeers: 5
|
||||||
MinPeers: 1
|
MinPeers: 1
|
||||||
RPC:
|
RPC:
|
||||||
Enabled: true
|
Enabled: true
|
||||||
|
|
|
@ -1,27 +0,0 @@
|
||||||
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
|
|
||||||
|
|
||||||
package network
|
|
||||||
|
|
||||||
import "strconv"
|
|
||||||
|
|
||||||
func _() {
|
|
||||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
|
||||||
// Re-run the stringer command to generate them again.
|
|
||||||
var x [1]struct{}
|
|
||||||
_ = x[nothingDone-0]
|
|
||||||
_ = x[versionSent-1]
|
|
||||||
_ = x[versionReceived-2]
|
|
||||||
_ = x[verAckSent-3]
|
|
||||||
_ = x[verAckReceived-4]
|
|
||||||
}
|
|
||||||
|
|
||||||
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
|
|
||||||
|
|
||||||
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
|
|
||||||
|
|
||||||
func (i handShakeStage) String() string {
|
|
||||||
if i >= handShakeStage(len(_handShakeStage_index)-1) {
|
|
||||||
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
|
|
||||||
}
|
|
||||||
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
|
|
||||||
}
|
|
|
@ -158,7 +158,10 @@ func newLocalPeer(t *testing.T) *localPeer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *localPeer) NetAddr() *net.TCPAddr {
|
func (p *localPeer) RemoteAddr() net.Addr {
|
||||||
|
return &p.netaddr
|
||||||
|
}
|
||||||
|
func (p *localPeer) PeerAddr() net.Addr {
|
||||||
return &p.netaddr
|
return &p.netaddr
|
||||||
}
|
}
|
||||||
func (p *localPeer) Disconnect(err error) {}
|
func (p *localPeer) Disconnect(err error) {}
|
||||||
|
|
|
@ -8,7 +8,15 @@ import (
|
||||||
|
|
||||||
// Peer represents a network node neo-go is connected to.
|
// Peer represents a network node neo-go is connected to.
|
||||||
type Peer interface {
|
type Peer interface {
|
||||||
NetAddr() *net.TCPAddr
|
// RemoteAddr returns the remote address that we're connected to now.
|
||||||
|
RemoteAddr() net.Addr
|
||||||
|
// PeerAddr returns the remote address that should be used to establish
|
||||||
|
// a new connection to the node. It can differ from the RemoteAddr
|
||||||
|
// address in case where the remote node is a client and its current
|
||||||
|
// connection port is different from the one the other node should use
|
||||||
|
// to connect to it. It's only valid after the handshake is completed,
|
||||||
|
// before that it returns the same address as RemoteAddr.
|
||||||
|
PeerAddr() net.Addr
|
||||||
Disconnect(error)
|
Disconnect(error)
|
||||||
WriteMsg(msg *Message) error
|
WriteMsg(msg *Message) error
|
||||||
Done() chan error
|
Done() chan error
|
||||||
|
|
|
@ -18,17 +18,20 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// peer numbers are arbitrary at the moment.
|
// peer numbers are arbitrary at the moment.
|
||||||
defaultMinPeers = 5
|
defaultMinPeers = 5
|
||||||
maxPeers = 20
|
defaultAttemptConnPeers = 20
|
||||||
maxBlockBatch = 200
|
defaultMaxPeers = 100
|
||||||
maxAddrsToSend = 200
|
maxBlockBatch = 200
|
||||||
minPoolCount = 30
|
maxAddrsToSend = 200
|
||||||
|
minPoolCount = 30
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
errAlreadyConnected = errors.New("already connected")
|
||||||
errIdenticalID = errors.New("identical node id")
|
errIdenticalID = errors.New("identical node id")
|
||||||
errInvalidHandshake = errors.New("invalid handshake")
|
errInvalidHandshake = errors.New("invalid handshake")
|
||||||
errInvalidNetwork = errors.New("invalid network")
|
errInvalidNetwork = errors.New("invalid network")
|
||||||
|
errMaxPeers = errors.New("max peers reached")
|
||||||
errServerShutdown = errors.New("server shutdown")
|
errServerShutdown = errors.New("server shutdown")
|
||||||
errInvalidInvType = errors.New("invalid inventory type")
|
errInvalidInvType = errors.New("invalid inventory type")
|
||||||
)
|
)
|
||||||
|
@ -85,6 +88,22 @@ func NewServer(config ServerConfig, chain core.Blockchainer) *Server {
|
||||||
s.MinPeers = defaultMinPeers
|
s.MinPeers = defaultMinPeers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.MaxPeers <= 0 {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"MaxPeers configured": s.MaxPeers,
|
||||||
|
"MaxPeers actual": defaultMaxPeers,
|
||||||
|
}).Info("bad MaxPeers configured, using the default value")
|
||||||
|
s.MaxPeers = defaultMaxPeers
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.AttemptConnPeers <= 0 {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"AttemptConnPeers configured": s.AttemptConnPeers,
|
||||||
|
"AttemptConnPeers actual": defaultAttemptConnPeers,
|
||||||
|
}).Info("bad AttemptConnPeers configured, using the default value")
|
||||||
|
s.AttemptConnPeers = defaultAttemptConnPeers
|
||||||
|
}
|
||||||
|
|
||||||
s.transport = NewTCPTransport(s, fmt.Sprintf(":%d", config.ListenTCP))
|
s.transport = NewTCPTransport(s, fmt.Sprintf(":%d", config.ListenTCP))
|
||||||
s.discovery = NewDefaultDiscovery(
|
s.discovery = NewDefaultDiscovery(
|
||||||
s.DialTimeout,
|
s.DialTimeout,
|
||||||
|
@ -136,9 +155,8 @@ func (s *Server) BadPeers() []string {
|
||||||
|
|
||||||
func (s *Server) run() {
|
func (s *Server) run() {
|
||||||
for {
|
for {
|
||||||
c := s.PeerCount()
|
if s.PeerCount() < s.MinPeers {
|
||||||
if c < s.ServerConfig.MinPeers {
|
s.discovery.RequestRemote(s.AttemptConnPeers)
|
||||||
s.discovery.RequestRemote(maxPeers - c)
|
|
||||||
}
|
}
|
||||||
if s.discovery.PoolCount() < minPoolCount {
|
if s.discovery.PoolCount() < minPoolCount {
|
||||||
select {
|
select {
|
||||||
|
@ -160,30 +178,47 @@ func (s *Server) run() {
|
||||||
// When a new peer is connected we send out our version immediately.
|
// When a new peer is connected we send out our version immediately.
|
||||||
if err := s.sendVersion(p); err != nil {
|
if err := s.sendVersion(p); err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"addr": p.NetAddr(),
|
"addr": p.RemoteAddr(),
|
||||||
}).Error(err)
|
}).Error(err)
|
||||||
}
|
}
|
||||||
|
s.lock.Lock()
|
||||||
s.peers[p] = true
|
s.peers[p] = true
|
||||||
|
s.lock.Unlock()
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"addr": p.NetAddr(),
|
"addr": p.RemoteAddr(),
|
||||||
}).Info("new peer connected")
|
}).Info("new peer connected")
|
||||||
|
peerCount := s.PeerCount()
|
||||||
|
if peerCount > s.MaxPeers {
|
||||||
|
s.lock.RLock()
|
||||||
|
// Pick a random peer and drop connection to it.
|
||||||
|
for peer := range s.peers {
|
||||||
|
peer.Disconnect(errMaxPeers)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
s.lock.RUnlock()
|
||||||
|
}
|
||||||
updatePeersConnectedMetric(s.PeerCount())
|
updatePeersConnectedMetric(s.PeerCount())
|
||||||
|
|
||||||
case drop := <-s.unregister:
|
case drop := <-s.unregister:
|
||||||
|
s.lock.Lock()
|
||||||
if s.peers[drop.peer] {
|
if s.peers[drop.peer] {
|
||||||
delete(s.peers, drop.peer)
|
delete(s.peers, drop.peer)
|
||||||
|
s.lock.Unlock()
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"addr": drop.peer.NetAddr(),
|
"addr": drop.peer.RemoteAddr(),
|
||||||
"reason": drop.reason,
|
"reason": drop.reason,
|
||||||
"peerCount": s.PeerCount(),
|
"peerCount": s.PeerCount(),
|
||||||
}).Warn("peer disconnected")
|
}).Warn("peer disconnected")
|
||||||
addr := drop.peer.NetAddr().String()
|
addr := drop.peer.PeerAddr().String()
|
||||||
s.discovery.UnregisterConnectedAddr(addr)
|
s.discovery.UnregisterConnectedAddr(addr)
|
||||||
s.discovery.BackFill(addr)
|
s.discovery.BackFill(addr)
|
||||||
updatePeersConnectedMetric(s.PeerCount())
|
updatePeersConnectedMetric(s.PeerCount())
|
||||||
|
} else {
|
||||||
|
// else the peer is already gone, which can happen
|
||||||
|
// because we have two goroutines sending signals here
|
||||||
|
s.lock.Unlock()
|
||||||
}
|
}
|
||||||
// else the peer is already gone, which can happen
|
|
||||||
// because we have two goroutines sending signals here
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -205,13 +240,13 @@ func (s *Server) PeerCount() int {
|
||||||
// every ProtoTickInterval with the peer.
|
// every ProtoTickInterval with the peer.
|
||||||
func (s *Server) startProtocol(p Peer) {
|
func (s *Server) startProtocol(p Peer) {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"addr": p.NetAddr(),
|
"addr": p.RemoteAddr(),
|
||||||
"userAgent": string(p.Version().UserAgent),
|
"userAgent": string(p.Version().UserAgent),
|
||||||
"startHeight": p.Version().StartHeight,
|
"startHeight": p.Version().StartHeight,
|
||||||
"id": p.Version().Nonce,
|
"id": p.Version().Nonce,
|
||||||
}).Info("started protocol")
|
}).Info("started protocol")
|
||||||
|
|
||||||
s.discovery.RegisterGoodAddr(p.NetAddr().String())
|
s.discovery.RegisterGoodAddr(p.PeerAddr().String())
|
||||||
err := s.requestHeaders(p)
|
err := s.requestHeaders(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.Disconnect(err)
|
p.Disconnect(err)
|
||||||
|
@ -265,6 +300,16 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
|
||||||
if s.id == version.Nonce {
|
if s.id == version.Nonce {
|
||||||
return errIdenticalID
|
return errIdenticalID
|
||||||
}
|
}
|
||||||
|
peerAddr := p.PeerAddr().String()
|
||||||
|
s.lock.RLock()
|
||||||
|
for peer := range s.peers {
|
||||||
|
// Already connected, drop this connection.
|
||||||
|
if peer.Handshaked() && peer.PeerAddr().String() == peerAddr && peer.Version().Nonce == version.Nonce {
|
||||||
|
s.lock.RUnlock()
|
||||||
|
return errAlreadyConnected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.lock.RUnlock()
|
||||||
return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil))
|
return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,11 @@ type (
|
||||||
// connect with some new ones.
|
// connect with some new ones.
|
||||||
MinPeers int
|
MinPeers int
|
||||||
|
|
||||||
|
// AttemptConnPeers it the number of connection to try to
|
||||||
|
// establish when the connection count drops below the MinPeers
|
||||||
|
// value.
|
||||||
|
AttemptConnPeers int
|
||||||
|
|
||||||
// MaxPeers it the maximum numbers of peers that can
|
// MaxPeers it the maximum numbers of peers that can
|
||||||
// be connected to the server.
|
// be connected to the server.
|
||||||
MaxPeers int
|
MaxPeers int
|
||||||
|
@ -64,6 +69,7 @@ func NewServerConfig(cfg config.Config) ServerConfig {
|
||||||
DialTimeout: appConfig.DialTimeout * time.Second,
|
DialTimeout: appConfig.DialTimeout * time.Second,
|
||||||
ProtoTickInterval: appConfig.ProtoTickInterval * time.Second,
|
ProtoTickInterval: appConfig.ProtoTickInterval * time.Second,
|
||||||
MaxPeers: appConfig.MaxPeers,
|
MaxPeers: appConfig.MaxPeers,
|
||||||
|
AttemptConnPeers: appConfig.AttemptConnPeers,
|
||||||
MinPeers: appConfig.MinPeers,
|
MinPeers: appConfig.MinPeers,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSendVersion(t *testing.T) {
|
func TestSendVersion(t *testing.T) {
|
||||||
|
@ -57,14 +58,16 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) {
|
||||||
// invalid version and disconnects the peer.
|
// invalid version and disconnects the peer.
|
||||||
func TestServerNotSendsVerack(t *testing.T) {
|
func TestServerNotSendsVerack(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
s = newTestServer()
|
s = newTestServer()
|
||||||
p = newLocalPeer(t)
|
p = newLocalPeer(t)
|
||||||
|
p2 = newLocalPeer(t)
|
||||||
)
|
)
|
||||||
s.id = 1
|
s.id = 1
|
||||||
go s.run()
|
go s.run()
|
||||||
|
|
||||||
na, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:3000")
|
na, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:3000")
|
||||||
p.netaddr = *na
|
p.netaddr = *na
|
||||||
|
p2.netaddr = *na
|
||||||
s.register <- p
|
s.register <- p
|
||||||
|
|
||||||
// identical id's
|
// identical id's
|
||||||
|
@ -72,6 +75,18 @@ func TestServerNotSendsVerack(t *testing.T) {
|
||||||
err := s.handleVersionCmd(p, version)
|
err := s.handleVersionCmd(p, version)
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
assert.Equal(t, errIdenticalID, err)
|
assert.Equal(t, errIdenticalID, err)
|
||||||
|
|
||||||
|
// Different IDs, make handshake pass.
|
||||||
|
version.Nonce = 2
|
||||||
|
require.NoError(t, s.handleVersionCmd(p, version))
|
||||||
|
require.NoError(t, p.HandleVersionAck())
|
||||||
|
require.Equal(t, true, p.Handshaked())
|
||||||
|
|
||||||
|
// Second handshake from the same peer should fail.
|
||||||
|
s.register <- p2
|
||||||
|
err = s.handleVersionCmd(p2, version)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
require.Equal(t, errAlreadyConnected, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestHeaders(t *testing.T) {
|
func TestRequestHeaders(t *testing.T) {
|
||||||
|
|
|
@ -2,8 +2,8 @@ package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/CityOfZion/neo-go/pkg/io"
|
"github.com/CityOfZion/neo-go/pkg/io"
|
||||||
|
@ -12,13 +12,11 @@ import (
|
||||||
|
|
||||||
type handShakeStage uint8
|
type handShakeStage uint8
|
||||||
|
|
||||||
//go:generate stringer -type=handShakeStage
|
|
||||||
const (
|
const (
|
||||||
nothingDone handShakeStage = 0
|
versionSent handShakeStage = 1 << iota
|
||||||
versionSent handShakeStage = 1
|
versionReceived
|
||||||
versionReceived handShakeStage = 2
|
verAckSent
|
||||||
verAckSent handShakeStage = 3
|
verAckReceived
|
||||||
verAckReceived handShakeStage = 4
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -30,11 +28,11 @@ var (
|
||||||
type TCPPeer struct {
|
type TCPPeer struct {
|
||||||
// underlying TCP connection.
|
// underlying TCP connection.
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
addr net.TCPAddr
|
|
||||||
|
|
||||||
// The version of the peer.
|
// The version of the peer.
|
||||||
version *payload.Version
|
version *payload.Version
|
||||||
|
|
||||||
|
lock sync.RWMutex
|
||||||
handShake handShakeStage
|
handShake handShakeStage
|
||||||
|
|
||||||
done chan error
|
done chan error
|
||||||
|
@ -44,13 +42,9 @@ type TCPPeer struct {
|
||||||
|
|
||||||
// NewTCPPeer returns a TCPPeer structure based on the given connection.
|
// NewTCPPeer returns a TCPPeer structure based on the given connection.
|
||||||
func NewTCPPeer(conn net.Conn) *TCPPeer {
|
func NewTCPPeer(conn net.Conn) *TCPPeer {
|
||||||
raddr := conn.RemoteAddr()
|
|
||||||
// can't fail because raddr is a real connection
|
|
||||||
tcpaddr, _ := net.ResolveTCPAddr(raddr.Network(), raddr.String())
|
|
||||||
return &TCPPeer{
|
return &TCPPeer{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
done: make(chan error, 1),
|
done: make(chan error, 1),
|
||||||
addr: *tcpaddr,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,39 +70,50 @@ func (p *TCPPeer) writeMsg(msg *Message) error {
|
||||||
|
|
||||||
// Handshaked returns status of the handshake, whether it's completed or not.
|
// Handshaked returns status of the handshake, whether it's completed or not.
|
||||||
func (p *TCPPeer) Handshaked() bool {
|
func (p *TCPPeer) Handshaked() bool {
|
||||||
return p.handShake == verAckReceived
|
p.lock.RLock()
|
||||||
|
defer p.lock.RUnlock()
|
||||||
|
return p.handShake == (verAckReceived | verAckSent | versionReceived | versionSent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendVersion checks for the handshake state and sends a message to the peer.
|
// SendVersion checks for the handshake state and sends a message to the peer.
|
||||||
func (p *TCPPeer) SendVersion(msg *Message) error {
|
func (p *TCPPeer) SendVersion(msg *Message) error {
|
||||||
if p.handShake != nothingDone {
|
p.lock.Lock()
|
||||||
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
|
defer p.lock.Unlock()
|
||||||
|
if p.handShake&versionSent != 0 {
|
||||||
|
return errors.New("invalid handshake: already sent Version")
|
||||||
}
|
}
|
||||||
err := p.writeMsg(msg)
|
err := p.writeMsg(msg)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
p.handShake = versionSent
|
p.handShake |= versionSent
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleVersion checks for the handshake state and version message contents.
|
// HandleVersion checks for the handshake state and version message contents.
|
||||||
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
|
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
|
||||||
if p.handShake != versionSent {
|
p.lock.Lock()
|
||||||
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
|
defer p.lock.Unlock()
|
||||||
|
if p.handShake&versionReceived != 0 {
|
||||||
|
return errors.New("invalid handshake: already received Version")
|
||||||
}
|
}
|
||||||
p.version = version
|
p.version = version
|
||||||
p.handShake = versionReceived
|
p.handShake |= versionReceived
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendVersionAck checks for the handshake state and sends a message to the peer.
|
// SendVersionAck checks for the handshake state and sends a message to the peer.
|
||||||
func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
||||||
if p.handShake != versionReceived {
|
p.lock.Lock()
|
||||||
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
|
defer p.lock.Unlock()
|
||||||
|
if p.handShake&versionReceived == 0 {
|
||||||
|
return errors.New("invalid handshake: tried to send VersionAck, but no version received yet")
|
||||||
|
}
|
||||||
|
if p.handShake&verAckSent != 0 {
|
||||||
|
return errors.New("invalid handshake: already sent VersionAck")
|
||||||
}
|
}
|
||||||
err := p.writeMsg(msg)
|
err := p.writeMsg(msg)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
p.handShake = verAckSent
|
p.handShake |= verAckSent
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -116,16 +121,40 @@ func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
||||||
// HandleVersionAck checks handshake sequence correctness when VerAck message
|
// HandleVersionAck checks handshake sequence correctness when VerAck message
|
||||||
// is received.
|
// is received.
|
||||||
func (p *TCPPeer) HandleVersionAck() error {
|
func (p *TCPPeer) HandleVersionAck() error {
|
||||||
if p.handShake != verAckSent {
|
p.lock.Lock()
|
||||||
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
|
defer p.lock.Unlock()
|
||||||
|
if p.handShake&versionSent == 0 {
|
||||||
|
return errors.New("invalid handshake: received VersionAck, but no version sent yet")
|
||||||
}
|
}
|
||||||
p.handShake = verAckReceived
|
if p.handShake&verAckReceived != 0 {
|
||||||
|
return errors.New("invalid handshake: already received VersionAck")
|
||||||
|
}
|
||||||
|
p.handShake |= verAckReceived
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NetAddr implements the Peer interface.
|
// RemoteAddr implements the Peer interface.
|
||||||
func (p *TCPPeer) NetAddr() *net.TCPAddr {
|
func (p *TCPPeer) RemoteAddr() net.Addr {
|
||||||
return &p.addr
|
return p.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerAddr implements the Peer interface.
|
||||||
|
func (p *TCPPeer) PeerAddr() net.Addr {
|
||||||
|
remote := p.conn.RemoteAddr()
|
||||||
|
// The network can be non-tcp in unit tests.
|
||||||
|
if !p.Handshaked() || remote.Network() != "tcp" {
|
||||||
|
return p.RemoteAddr()
|
||||||
|
}
|
||||||
|
host, _, err := net.SplitHostPort(remote.String())
|
||||||
|
if err != nil {
|
||||||
|
return p.RemoteAddr()
|
||||||
|
}
|
||||||
|
addrString := net.JoinHostPort(host, strconv.Itoa(int(p.version.Port)))
|
||||||
|
tcpAddr, err := net.ResolveTCPAddr("tcp", addrString)
|
||||||
|
if err != nil {
|
||||||
|
return p.RemoteAddr()
|
||||||
|
}
|
||||||
|
return tcpAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Done implements the Peer interface and notifies
|
// Done implements the Peer interface and notifies
|
||||||
|
|
83
pkg/network/tcp_peer_test.go
Normal file
83
pkg/network/tcp_peer_test.go
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func connReadStub(conn net.Conn) {
|
||||||
|
b := make([]byte, 1024)
|
||||||
|
var err error
|
||||||
|
for ; err == nil; _, err = conn.Read(b) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerHandshake(t *testing.T) {
|
||||||
|
server, client := net.Pipe()
|
||||||
|
|
||||||
|
tcpS := NewTCPPeer(server)
|
||||||
|
tcpC := NewTCPPeer(client)
|
||||||
|
|
||||||
|
// Something should read things written into the pipe.
|
||||||
|
go connReadStub(tcpS.conn)
|
||||||
|
go connReadStub(tcpC.conn)
|
||||||
|
|
||||||
|
// No handshake yet.
|
||||||
|
require.Equal(t, false, tcpS.Handshaked())
|
||||||
|
require.Equal(t, false, tcpC.Handshaked())
|
||||||
|
|
||||||
|
// No ordinary messages can be written.
|
||||||
|
require.Error(t, tcpS.WriteMsg(&Message{}))
|
||||||
|
require.Error(t, tcpC.WriteMsg(&Message{}))
|
||||||
|
|
||||||
|
// Try to mess with VersionAck on both client and server, it should fail.
|
||||||
|
require.Error(t, tcpS.SendVersionAck(&Message{}))
|
||||||
|
require.Error(t, tcpS.HandleVersionAck())
|
||||||
|
require.Error(t, tcpC.SendVersionAck(&Message{}))
|
||||||
|
require.Error(t, tcpC.HandleVersionAck())
|
||||||
|
|
||||||
|
// No handshake yet.
|
||||||
|
require.Equal(t, false, tcpS.Handshaked())
|
||||||
|
require.Equal(t, false, tcpC.Handshaked())
|
||||||
|
|
||||||
|
// Now send and handle versions, but in a different order on client and
|
||||||
|
// server.
|
||||||
|
require.NoError(t, tcpC.SendVersion(&Message{}))
|
||||||
|
require.NoError(t, tcpS.HandleVersion(&payload.Version{}))
|
||||||
|
require.NoError(t, tcpC.HandleVersion(&payload.Version{}))
|
||||||
|
require.NoError(t, tcpS.SendVersion(&Message{}))
|
||||||
|
|
||||||
|
// No handshake yet.
|
||||||
|
require.Equal(t, false, tcpS.Handshaked())
|
||||||
|
require.Equal(t, false, tcpC.Handshaked())
|
||||||
|
|
||||||
|
// These are sent/received and should fail now.
|
||||||
|
require.Error(t, tcpC.SendVersion(&Message{}))
|
||||||
|
require.Error(t, tcpS.HandleVersion(&payload.Version{}))
|
||||||
|
require.Error(t, tcpC.HandleVersion(&payload.Version{}))
|
||||||
|
require.Error(t, tcpS.SendVersion(&Message{}))
|
||||||
|
|
||||||
|
// Now send and handle ACK, again in a different order on client and
|
||||||
|
// server.
|
||||||
|
require.NoError(t, tcpC.SendVersionAck(&Message{}))
|
||||||
|
require.NoError(t, tcpS.HandleVersionAck())
|
||||||
|
require.NoError(t, tcpC.HandleVersionAck())
|
||||||
|
require.NoError(t, tcpS.SendVersionAck(&Message{}))
|
||||||
|
|
||||||
|
// Handshaked now.
|
||||||
|
require.Equal(t, true, tcpS.Handshaked())
|
||||||
|
require.Equal(t, true, tcpC.Handshaked())
|
||||||
|
|
||||||
|
// Subsequent ACKing should fail.
|
||||||
|
require.Error(t, tcpC.SendVersionAck(&Message{}))
|
||||||
|
require.Error(t, tcpS.HandleVersionAck())
|
||||||
|
require.Error(t, tcpC.HandleVersionAck())
|
||||||
|
require.Error(t, tcpS.SendVersionAck(&Message{}))
|
||||||
|
|
||||||
|
// Now regular messaging can proceed.
|
||||||
|
require.NoError(t, tcpS.WriteMsg(&Message{}))
|
||||||
|
require.NoError(t, tcpC.WriteMsg(&Message{}))
|
||||||
|
}
|
|
@ -196,7 +196,7 @@ Methods:
|
||||||
}
|
}
|
||||||
|
|
||||||
for addr := range s.coreServer.Peers() {
|
for addr := range s.coreServer.Peers() {
|
||||||
peers.AddPeer("connected", addr.NetAddr().String())
|
peers.AddPeer("connected", addr.PeerAddr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
results = peers
|
results = peers
|
||||||
|
|
Loading…
Reference in a new issue