Merge pull request #396 from nspcc-dev/network-reconnections-and-fixes
This one fixes #390 and some connected problems. After this patchset the node reconnects to some other nodes if anything goes wrong and it better senses when something goes wrong. It also fixes some block handling problems based on the testnet connection experience.
This commit is contained in:
commit
adba9e11ee
13 changed files with 427 additions and 119 deletions
|
@ -357,8 +357,7 @@ func (bc *Blockchain) persistBlock(block *Block) error {
|
|||
Email: t.Email,
|
||||
Description: t.Description,
|
||||
}
|
||||
|
||||
fmt.Printf("%+v", contract)
|
||||
_ = contract
|
||||
|
||||
case *transaction.InvocationTX:
|
||||
}
|
||||
|
@ -430,6 +429,15 @@ func (bc *Blockchain) persist(ctx context.Context) (err error) {
|
|||
"blockHeight": bc.BlockHeight(),
|
||||
"took": time.Since(start),
|
||||
}).Info("blockchain persist completed")
|
||||
} else {
|
||||
// So we have some blocks in cache but can't persist them?
|
||||
// Either there are some stale blocks there or the other way
|
||||
// around (which was seen in practice) --- there are some fresh
|
||||
// blocks that we can't persist yet. Some of the latter can be useful
|
||||
// or can be bogus (higher than the header height we expect at
|
||||
// the moment). So try to reap oldies and strange newbies, if
|
||||
// there are any.
|
||||
bc.blockCache.ReapStrangeBlocks(bc.BlockHeight(), bc.HeaderHeight())
|
||||
}
|
||||
|
||||
return
|
||||
|
|
|
@ -71,3 +71,17 @@ func (c *Cache) Delete(h util.Uint256) {
|
|||
defer c.lock.Unlock()
|
||||
delete(c.m, h)
|
||||
}
|
||||
|
||||
// ReapStrangeBlocks drops blocks from cache that don't fit into the
|
||||
// blkHeight-headHeight interval. Cache should only contain blocks that we
|
||||
// expect to get and store.
|
||||
func (c *Cache) ReapStrangeBlocks(blkHeight, headHeight uint32) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
for i, b := range c.m {
|
||||
block, ok := b.(*Block)
|
||||
if ok && (block.Index < blkHeight || block.Index > headHeight) {
|
||||
delete(c.m, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
const (
|
||||
maxPoolSize = 200
|
||||
connRetries = 3
|
||||
)
|
||||
|
||||
// Discoverer is an interface that is responsible for maintaining
|
||||
|
@ -15,22 +16,28 @@ type Discoverer interface {
|
|||
PoolCount() int
|
||||
RequestRemote(int)
|
||||
RegisterBadAddr(string)
|
||||
RegisterGoodAddr(string)
|
||||
UnregisterConnectedAddr(string)
|
||||
UnconnectedPeers() []string
|
||||
BadPeers() []string
|
||||
GoodPeers() []string
|
||||
}
|
||||
|
||||
// DefaultDiscovery default implementation of the Discoverer interface.
|
||||
type DefaultDiscovery struct {
|
||||
transport Transporter
|
||||
dialTimeout time.Duration
|
||||
addrs map[string]bool
|
||||
badAddrs map[string]bool
|
||||
unconnectedAddrs map[string]bool
|
||||
connectedAddrs map[string]bool
|
||||
goodAddrs map[string]bool
|
||||
unconnectedAddrs map[string]int
|
||||
requestCh chan int
|
||||
connectedCh chan string
|
||||
backFill chan string
|
||||
badAddrCh chan string
|
||||
pool chan string
|
||||
goodCh chan string
|
||||
unconnectedCh chan string
|
||||
}
|
||||
|
||||
// NewDefaultDiscovery returns a new DefaultDiscovery.
|
||||
|
@ -38,11 +45,14 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
|
|||
d := &DefaultDiscovery{
|
||||
transport: ts,
|
||||
dialTimeout: dt,
|
||||
addrs: make(map[string]bool),
|
||||
badAddrs: make(map[string]bool),
|
||||
unconnectedAddrs: make(map[string]bool),
|
||||
connectedAddrs: make(map[string]bool),
|
||||
goodAddrs: make(map[string]bool),
|
||||
unconnectedAddrs: make(map[string]int),
|
||||
requestCh: make(chan int),
|
||||
connectedCh: make(chan string),
|
||||
goodCh: make(chan string),
|
||||
unconnectedCh: make(chan string),
|
||||
backFill: make(chan string),
|
||||
badAddrCh: make(chan string),
|
||||
pool: make(chan string, maxPoolSize),
|
||||
|
@ -54,9 +64,6 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
|
|||
// BackFill implements the Discoverer interface and will backfill the
|
||||
// the pool with the given addresses.
|
||||
func (d *DefaultDiscovery) BackFill(addrs ...string) {
|
||||
if len(d.pool) == maxPoolSize {
|
||||
return
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
d.backFill <- addr
|
||||
}
|
||||
|
@ -67,6 +74,17 @@ func (d *DefaultDiscovery) PoolCount() int {
|
|||
return len(d.pool)
|
||||
}
|
||||
|
||||
// pushToPoolOrDrop tries to push address given into the pool, but if the pool
|
||||
// is already full, it just drops it
|
||||
func (d *DefaultDiscovery) pushToPoolOrDrop(addr string) {
|
||||
select {
|
||||
case d.pool <- addr:
|
||||
// ok, queued
|
||||
default:
|
||||
// whatever
|
||||
}
|
||||
}
|
||||
|
||||
// RequestRemote will try to establish a connection with n nodes.
|
||||
func (d *DefaultDiscovery) RequestRemote(n int) {
|
||||
d.requestCh <- n
|
||||
|
@ -96,57 +114,87 @@ func (d *DefaultDiscovery) BadPeers() []string {
|
|||
return addrs
|
||||
}
|
||||
|
||||
func (d *DefaultDiscovery) work(addrCh chan string) {
|
||||
for {
|
||||
addr := <-addrCh
|
||||
if err := d.transport.Dial(addr, d.dialTimeout); err != nil {
|
||||
d.badAddrCh <- addr
|
||||
} else {
|
||||
d.connectedCh <- addr
|
||||
}
|
||||
// GoodPeers returns all addresses of known good peers (that at least once
|
||||
// succeded handshaking with us).
|
||||
func (d *DefaultDiscovery) GoodPeers() []string {
|
||||
addrs := make([]string, 0, len(d.goodAddrs))
|
||||
for addr := range d.goodAddrs {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
|
||||
// RegisterGoodAddr registers good known connected address that passed
|
||||
// handshake successfuly.
|
||||
func (d *DefaultDiscovery) RegisterGoodAddr(s string) {
|
||||
d.goodCh <- s
|
||||
}
|
||||
|
||||
// UnregisterConnectedAddr tells discoverer that this address is no longer
|
||||
// connected, but it still is considered as good one.
|
||||
func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) {
|
||||
d.unconnectedCh <- s
|
||||
}
|
||||
|
||||
func (d *DefaultDiscovery) tryAddress(addr string) {
|
||||
if err := d.transport.Dial(addr, d.dialTimeout); err != nil {
|
||||
d.badAddrCh <- addr
|
||||
} else {
|
||||
d.connectedCh <- addr
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultDiscovery) next() string {
|
||||
return <-d.pool
|
||||
func (d *DefaultDiscovery) requestToWork() {
|
||||
var requested int
|
||||
|
||||
for {
|
||||
for requested = <-d.requestCh; requested > 0; requested-- {
|
||||
select {
|
||||
case r := <-d.requestCh:
|
||||
if requested < r {
|
||||
requested = r
|
||||
}
|
||||
case addr := <-d.pool:
|
||||
if !d.connectedAddrs[addr] {
|
||||
go d.tryAddress(addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultDiscovery) run() {
|
||||
var (
|
||||
maxWorkers = 5
|
||||
workCh = make(chan string)
|
||||
)
|
||||
|
||||
for i := 0; i < maxWorkers; i++ {
|
||||
go d.work(workCh)
|
||||
}
|
||||
|
||||
go d.requestToWork()
|
||||
for {
|
||||
select {
|
||||
case addr := <-d.backFill:
|
||||
if _, ok := d.badAddrs[addr]; ok {
|
||||
if d.badAddrs[addr] || d.connectedAddrs[addr] ||
|
||||
d.unconnectedAddrs[addr] > 0 {
|
||||
break
|
||||
}
|
||||
if _, ok := d.addrs[addr]; !ok {
|
||||
d.addrs[addr] = true
|
||||
d.unconnectedAddrs[addr] = true
|
||||
d.pool <- addr
|
||||
}
|
||||
case n := <-d.requestCh:
|
||||
go func() {
|
||||
for i := 0; i < n; i++ {
|
||||
workCh <- d.next()
|
||||
}
|
||||
}()
|
||||
d.unconnectedAddrs[addr] = connRetries
|
||||
d.pushToPoolOrDrop(addr)
|
||||
case addr := <-d.badAddrCh:
|
||||
d.badAddrs[addr] = true
|
||||
delete(d.unconnectedAddrs, addr)
|
||||
go func() {
|
||||
workCh <- d.next()
|
||||
}()
|
||||
d.unconnectedAddrs[addr]--
|
||||
if d.unconnectedAddrs[addr] > 0 {
|
||||
d.pushToPoolOrDrop(addr)
|
||||
} else {
|
||||
d.badAddrs[addr] = true
|
||||
delete(d.unconnectedAddrs, addr)
|
||||
}
|
||||
d.RequestRemote(1)
|
||||
|
||||
case addr := <-d.connectedCh:
|
||||
delete(d.unconnectedAddrs, addr)
|
||||
if !d.connectedAddrs[addr] {
|
||||
d.connectedAddrs[addr] = true
|
||||
}
|
||||
case addr := <-d.goodCh:
|
||||
if !d.goodAddrs[addr] {
|
||||
d.goodAddrs[addr] = true
|
||||
}
|
||||
case addr := <-d.unconnectedCh:
|
||||
delete(d.connectedAddrs, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
27
pkg/network/handshakestage_string.go
Normal file
27
pkg/network/handshakestage_string.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
|
||||
|
||||
package network
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[nothingDone-0]
|
||||
_ = x[versionSent-1]
|
||||
_ = x[versionReceived-2]
|
||||
_ = x[verAckSent-3]
|
||||
_ = x[verAckReceived-4]
|
||||
}
|
||||
|
||||
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
|
||||
|
||||
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
|
||||
|
||||
func (i handShakeStage) String() string {
|
||||
if i >= handShakeStage(len(_handShakeStage_index)-1) {
|
||||
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
|
||||
}
|
|
@ -93,12 +93,15 @@ func (chain testChain) Verify(*transaction.Transaction) error {
|
|||
|
||||
type testDiscovery struct{}
|
||||
|
||||
func (d testDiscovery) BackFill(addrs ...string) {}
|
||||
func (d testDiscovery) PoolCount() int { return 0 }
|
||||
func (d testDiscovery) RegisterBadAddr(string) {}
|
||||
func (d testDiscovery) UnconnectedPeers() []string { return []string{} }
|
||||
func (d testDiscovery) RequestRemote(n int) {}
|
||||
func (d testDiscovery) BadPeers() []string { return []string{} }
|
||||
func (d testDiscovery) BackFill(addrs ...string) {}
|
||||
func (d testDiscovery) PoolCount() int { return 0 }
|
||||
func (d testDiscovery) RegisterBadAddr(string) {}
|
||||
func (d testDiscovery) RegisterGoodAddr(string) {}
|
||||
func (d testDiscovery) UnregisterConnectedAddr(string) {}
|
||||
func (d testDiscovery) UnconnectedPeers() []string { return []string{} }
|
||||
func (d testDiscovery) RequestRemote(n int) {}
|
||||
func (d testDiscovery) BadPeers() []string { return []string{} }
|
||||
func (d testDiscovery) GoodPeers() []string { return []string{} }
|
||||
|
||||
type localTransport struct{}
|
||||
|
||||
|
@ -114,6 +117,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {}
|
|||
type localPeer struct {
|
||||
netaddr net.TCPAddr
|
||||
version *payload.Version
|
||||
handshaked bool
|
||||
t *testing.T
|
||||
messageHandler func(t *testing.T, msg *Message)
|
||||
}
|
||||
|
@ -142,8 +146,23 @@ func (p *localPeer) Done() chan error {
|
|||
func (p *localPeer) Version() *payload.Version {
|
||||
return p.version
|
||||
}
|
||||
func (p *localPeer) SetVersion(v *payload.Version) {
|
||||
func (p *localPeer) HandleVersion(v *payload.Version) error {
|
||||
p.version = v
|
||||
return nil
|
||||
}
|
||||
func (p *localPeer) SendVersion(m *Message) error {
|
||||
return p.WriteMsg(m)
|
||||
}
|
||||
func (p *localPeer) SendVersionAck(m *Message) error {
|
||||
return p.WriteMsg(m)
|
||||
}
|
||||
func (p *localPeer) HandleVersionAck() error {
|
||||
p.handshaked = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *localPeer) Handshaked() bool {
|
||||
return p.handshaked
|
||||
}
|
||||
|
||||
func newTestServer() *Server {
|
||||
|
|
|
@ -3,6 +3,7 @@ package payload
|
|||
import (
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/util"
|
||||
|
@ -47,11 +48,28 @@ func (p *AddressAndTime) EncodeBinary(w io.Writer) error {
|
|||
return bw.Err
|
||||
}
|
||||
|
||||
// IPPortString makes a string from IP and port specified.
|
||||
func (p *AddressAndTime) IPPortString() string {
|
||||
var netip net.IP = make(net.IP, 16)
|
||||
|
||||
copy(netip, p.IP[:])
|
||||
port := strconv.Itoa(int(p.Port))
|
||||
return netip.String() + ":" + port
|
||||
}
|
||||
|
||||
// AddressList is a list with AddrAndTime.
|
||||
type AddressList struct {
|
||||
Addrs []*AddressAndTime
|
||||
}
|
||||
|
||||
// NewAddressList creates a list for n AddressAndTime elements.
|
||||
func NewAddressList(n int) *AddressList {
|
||||
alist := AddressList{
|
||||
Addrs: make([]*AddressAndTime, n),
|
||||
}
|
||||
return &alist
|
||||
}
|
||||
|
||||
// DecodeBinary implements the Payload interface.
|
||||
func (p *AddressList) DecodeBinary(r io.Reader) error {
|
||||
br := util.BinReader{R: r}
|
||||
|
|
|
@ -35,7 +35,7 @@ func TestEncodeDecodeAddress(t *testing.T) {
|
|||
|
||||
func TestEncodeDecodeAddressList(t *testing.T) {
|
||||
var lenList uint8 = 4
|
||||
addrList := &AddressList{make([]*AddressAndTime, lenList)}
|
||||
addrList := NewAddressList(int(lenList))
|
||||
for i := 0; i < int(lenList); i++ {
|
||||
e, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:200%d", i))
|
||||
addrList.Addrs[i] = NewAddressAndTime(e, time.Now())
|
||||
|
|
|
@ -7,3 +7,22 @@ type Payload interface {
|
|||
EncodeBinary(io.Writer) error
|
||||
DecodeBinary(io.Reader) error
|
||||
}
|
||||
|
||||
// NullPayload is a dummy payload with no fields.
|
||||
type NullPayload struct {
|
||||
}
|
||||
|
||||
// NewNullPayload returns zero-sized stub payload.
|
||||
func NewNullPayload() *NullPayload {
|
||||
return &NullPayload{}
|
||||
}
|
||||
|
||||
// DecodeBinary implements the Payload interface.
|
||||
func (p *NullPayload) DecodeBinary(r io.Reader) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeBinary implements the Payload interface.
|
||||
func (p *NullPayload) EncodeBinary(r io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -13,5 +13,9 @@ type Peer interface {
|
|||
WriteMsg(msg *Message) error
|
||||
Done() chan error
|
||||
Version() *payload.Version
|
||||
SetVersion(*payload.Version)
|
||||
Handshaked() bool
|
||||
SendVersion(*Message) error
|
||||
SendVersionAck(*Message) error
|
||||
HandleVersion(*payload.Version) error
|
||||
HandleVersionAck() error
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -15,13 +16,15 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
minPeers = 5
|
||||
maxBlockBatch = 200
|
||||
minPoolCount = 30
|
||||
// peer numbers are arbitrary at the moment
|
||||
minPeers = 5
|
||||
maxPeers = 20
|
||||
maxBlockBatch = 200
|
||||
maxAddrsToSend = 200
|
||||
minPoolCount = 30
|
||||
)
|
||||
|
||||
var (
|
||||
errPortMismatch = errors.New("port mismatch")
|
||||
errIdenticalID = errors.New("identical node id")
|
||||
errInvalidHandshake = errors.New("invalid handshake")
|
||||
errInvalidNetwork = errors.New("invalid network")
|
||||
|
@ -46,6 +49,7 @@ type (
|
|||
lock sync.RWMutex
|
||||
peers map[Peer]bool
|
||||
|
||||
addrReq chan *Message
|
||||
register chan Peer
|
||||
unregister chan peerDrop
|
||||
quit chan struct{}
|
||||
|
@ -64,6 +68,7 @@ func NewServer(config ServerConfig, chain core.Blockchainer) *Server {
|
|||
chain: chain,
|
||||
id: rand.Uint32(),
|
||||
quit: make(chan struct{}),
|
||||
addrReq: make(chan *Message, minPeers),
|
||||
register: make(chan Peer),
|
||||
unregister: make(chan peerDrop),
|
||||
peers: make(map[Peer]bool),
|
||||
|
@ -90,12 +95,7 @@ func (s *Server) Start(errChan chan error) {
|
|||
"headerHeight": s.chain.HeaderHeight(),
|
||||
}).Info("node started")
|
||||
|
||||
for _, addr := range s.Seeds {
|
||||
if err := s.transport.Dial(addr, s.DialTimeout); err != nil {
|
||||
log.Warnf("failed to connect to remote node %s", addr)
|
||||
continue
|
||||
}
|
||||
}
|
||||
s.discovery.BackFill(s.Seeds...)
|
||||
|
||||
go s.transport.Accept()
|
||||
s.run()
|
||||
|
@ -122,6 +122,19 @@ func (s *Server) BadPeers() []string {
|
|||
|
||||
func (s *Server) run() {
|
||||
for {
|
||||
c := s.PeerCount()
|
||||
if c < minPeers {
|
||||
s.discovery.RequestRemote(maxPeers - c)
|
||||
}
|
||||
if s.discovery.PoolCount() < minPoolCount {
|
||||
select {
|
||||
case s.addrReq <- NewMessage(s.Net, CMDGetAddr, payload.NewNullPayload()):
|
||||
// sent request
|
||||
default:
|
||||
// we have one in the queue already that is
|
||||
// gonna be served by some worker when it's ready
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-s.quit:
|
||||
s.transport.Close()
|
||||
|
@ -141,12 +154,19 @@ func (s *Server) run() {
|
|||
"addr": p.NetAddr(),
|
||||
}).Info("new peer connected")
|
||||
case drop := <-s.unregister:
|
||||
delete(s.peers, drop.peer)
|
||||
log.WithFields(log.Fields{
|
||||
"addr": drop.peer.NetAddr(),
|
||||
"reason": drop.reason,
|
||||
"peerCount": s.PeerCount(),
|
||||
}).Warn("peer disconnected")
|
||||
if s.peers[drop.peer] {
|
||||
delete(s.peers, drop.peer)
|
||||
log.WithFields(log.Fields{
|
||||
"addr": drop.peer.NetAddr(),
|
||||
"reason": drop.reason,
|
||||
"peerCount": s.PeerCount(),
|
||||
}).Warn("peer disconnected")
|
||||
addr := drop.peer.NetAddr().String()
|
||||
s.discovery.UnregisterConnectedAddr(addr)
|
||||
s.discovery.BackFill(addr)
|
||||
}
|
||||
// else the peer is already gone, which can happen
|
||||
// because we have two goroutines sending signals here
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -174,20 +194,34 @@ func (s *Server) startProtocol(p Peer) {
|
|||
"id": p.Version().Nonce,
|
||||
}).Info("started protocol")
|
||||
|
||||
s.requestHeaders(p)
|
||||
s.discovery.RegisterGoodAddr(p.NetAddr().String())
|
||||
err := s.requestHeaders(p)
|
||||
if err != nil {
|
||||
p.Disconnect(err)
|
||||
return
|
||||
}
|
||||
|
||||
timer := time.NewTimer(s.ProtoTickInterval)
|
||||
for {
|
||||
select {
|
||||
case err := <-p.Done():
|
||||
s.unregister <- peerDrop{p, err}
|
||||
return
|
||||
case err = <-p.Done():
|
||||
// time to stop
|
||||
case m := <-s.addrReq:
|
||||
err = p.WriteMsg(m)
|
||||
case <-timer.C:
|
||||
// Try to sync in headers and block with the peer if his block height is higher then ours.
|
||||
if p.Version().StartHeight > s.chain.BlockHeight() {
|
||||
s.requestBlocks(p)
|
||||
err = s.requestBlocks(p)
|
||||
}
|
||||
timer.Reset(s.ProtoTickInterval)
|
||||
if err == nil {
|
||||
timer.Reset(s.ProtoTickInterval)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
s.unregister <- peerDrop{p, err}
|
||||
timer.Stop()
|
||||
p.Disconnect(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -201,20 +235,23 @@ func (s *Server) sendVersion(p Peer) error {
|
|||
s.chain.BlockHeight(),
|
||||
s.Relay,
|
||||
)
|
||||
return p.WriteMsg(NewMessage(s.Net, CMDVersion, payload))
|
||||
return p.SendVersion(NewMessage(s.Net, CMDVersion, payload))
|
||||
}
|
||||
|
||||
// When a peer sends out his version we reply with verack after validating
|
||||
// the version.
|
||||
func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
|
||||
if p.NetAddr().Port != int(version.Port) {
|
||||
return errPortMismatch
|
||||
err := p.HandleVersion(version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.id == version.Nonce {
|
||||
return errIdenticalID
|
||||
}
|
||||
p.SetVersion(version)
|
||||
return p.WriteMsg(NewMessage(s.Net, CMDVerack, nil))
|
||||
if p.NetAddr().Port != int(version.Port) {
|
||||
return fmt.Errorf("port mismatch: connected to %d and peer sends %d", p.NetAddr().Port, version.Port)
|
||||
}
|
||||
return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil))
|
||||
}
|
||||
|
||||
// handleHeadersCmd will process the headers it received from its peer.
|
||||
|
@ -251,18 +288,42 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
|
|||
return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
|
||||
}
|
||||
|
||||
// handleAddrCmd will process received addresses.
|
||||
func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error {
|
||||
for _, a := range addrs.Addrs {
|
||||
s.discovery.BackFill(a.IPPortString())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleGetAddrCmd sends to the peer some good addresses that we know of.
|
||||
func (s *Server) handleGetAddrCmd(p Peer) error {
|
||||
addrs := s.discovery.GoodPeers()
|
||||
if len(addrs) > maxAddrsToSend {
|
||||
addrs = addrs[:maxAddrsToSend]
|
||||
}
|
||||
alist := payload.NewAddressList(len(addrs))
|
||||
ts := time.Now()
|
||||
for i, addr := range addrs {
|
||||
// we know it's a good address, so it can't fail
|
||||
netaddr, _ := net.ResolveTCPAddr("tcp", addr)
|
||||
alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts)
|
||||
}
|
||||
return p.WriteMsg(NewMessage(s.Net, CMDAddr, alist))
|
||||
}
|
||||
|
||||
// requestHeaders will send a getheaders message to the peer.
|
||||
// The peer will respond with headers op to a count of 2000.
|
||||
func (s *Server) requestHeaders(p Peer) {
|
||||
func (s *Server) requestHeaders(p Peer) error {
|
||||
start := []util.Uint256{s.chain.CurrentHeaderHash()}
|
||||
payload := payload.NewGetBlocks(start, util.Uint256{})
|
||||
p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload))
|
||||
return p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload))
|
||||
}
|
||||
|
||||
// requestBlocks will send a getdata message to the peer
|
||||
// to sync up in blocks. A maximum of maxBlockBatch will
|
||||
// send at once.
|
||||
func (s *Server) requestBlocks(p Peer) {
|
||||
func (s *Server) requestBlocks(p Peer) error {
|
||||
var (
|
||||
hashes []util.Uint256
|
||||
hashStart = s.chain.BlockHeight() + 1
|
||||
|
@ -275,10 +336,11 @@ func (s *Server) requestBlocks(p Peer) {
|
|||
}
|
||||
if len(hashes) > 0 {
|
||||
payload := payload.NewInventory(payload.BlockType, hashes)
|
||||
p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
|
||||
return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
|
||||
} else if s.chain.HeaderHeight() < p.Version().StartHeight {
|
||||
s.requestHeaders(p)
|
||||
return s.requestHeaders(p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMessage will process the given message.
|
||||
|
@ -289,26 +351,40 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
return errInvalidNetwork
|
||||
}
|
||||
|
||||
switch msg.CommandType() {
|
||||
case CMDVersion:
|
||||
version := msg.Payload.(*payload.Version)
|
||||
return s.handleVersionCmd(peer, version)
|
||||
case CMDHeaders:
|
||||
headers := msg.Payload.(*payload.Headers)
|
||||
go s.handleHeadersCmd(peer, headers)
|
||||
case CMDInv:
|
||||
inventory := msg.Payload.(*payload.Inventory)
|
||||
return s.handleInvCmd(peer, inventory)
|
||||
case CMDBlock:
|
||||
block := msg.Payload.(*core.Block)
|
||||
return s.handleBlockCmd(peer, block)
|
||||
case CMDVerack:
|
||||
// Make sure this peer has send his version before we start the
|
||||
// protocol with that peer.
|
||||
if peer.Version() == nil {
|
||||
return errInvalidHandshake
|
||||
if peer.Handshaked() {
|
||||
switch msg.CommandType() {
|
||||
case CMDAddr:
|
||||
addrs := msg.Payload.(*payload.AddressList)
|
||||
return s.handleAddrCmd(peer, addrs)
|
||||
case CMDGetAddr:
|
||||
// it has no payload
|
||||
return s.handleGetAddrCmd(peer)
|
||||
case CMDHeaders:
|
||||
headers := msg.Payload.(*payload.Headers)
|
||||
go s.handleHeadersCmd(peer, headers)
|
||||
case CMDInv:
|
||||
inventory := msg.Payload.(*payload.Inventory)
|
||||
return s.handleInvCmd(peer, inventory)
|
||||
case CMDBlock:
|
||||
block := msg.Payload.(*core.Block)
|
||||
return s.handleBlockCmd(peer, block)
|
||||
case CMDVersion, CMDVerack:
|
||||
return fmt.Errorf("received '%s' after the handshake", msg.CommandType())
|
||||
}
|
||||
} else {
|
||||
switch msg.CommandType() {
|
||||
case CMDVersion:
|
||||
version := msg.Payload.(*payload.Version)
|
||||
return s.handleVersionCmd(peer, version)
|
||||
case CMDVerack:
|
||||
err := peer.HandleVersionAck()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.startProtocol(peer)
|
||||
default:
|
||||
return fmt.Errorf("received '%s' during handshake", msg.CommandType())
|
||||
}
|
||||
go s.startProtocol(peer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ func TestServerNotSendsVerack(t *testing.T) {
|
|||
version := payload.NewVersion(1337, 2000, "/NEO-GO/", 0, true)
|
||||
err := s.handleVersionCmd(p, version)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, errPortMismatch, err)
|
||||
assert.Contains(t, err.Error(), "port mismatch")
|
||||
|
||||
// identical id's
|
||||
version = payload.NewVersion(1, 3000, "/NEO-GO/", 0, true)
|
||||
|
|
|
@ -1,12 +1,29 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
)
|
||||
|
||||
type handShakeStage uint8
|
||||
|
||||
//go:generate stringer -type=handShakeStage
|
||||
const (
|
||||
nothingDone handShakeStage = 0
|
||||
versionSent handShakeStage = 1
|
||||
versionReceived handShakeStage = 2
|
||||
verAckSent handShakeStage = 3
|
||||
verAckReceived handShakeStage = 4
|
||||
)
|
||||
|
||||
var (
|
||||
errStateMismatch = errors.New("tried to send protocol message before handshake completed")
|
||||
)
|
||||
|
||||
// TCPPeer represents a connected remote node in the
|
||||
// network over TCP.
|
||||
type TCPPeer struct {
|
||||
|
@ -17,6 +34,8 @@ type TCPPeer struct {
|
|||
// The version of the peer.
|
||||
version *payload.Version
|
||||
|
||||
handShake handShakeStage
|
||||
|
||||
done chan error
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
@ -35,8 +54,16 @@ func NewTCPPeer(conn net.Conn) *TCPPeer {
|
|||
}
|
||||
|
||||
// WriteMsg implements the Peer interface. This will write/encode the message
|
||||
// to the underlying connection.
|
||||
// to the underlying connection, this only works for messages other than Version
|
||||
// or VerAck.
|
||||
func (p *TCPPeer) WriteMsg(msg *Message) error {
|
||||
if !p.Handshaked() {
|
||||
return errStateMismatch
|
||||
}
|
||||
return p.writeMsg(msg)
|
||||
}
|
||||
|
||||
func (p *TCPPeer) writeMsg(msg *Message) error {
|
||||
select {
|
||||
case err := <-p.done:
|
||||
return err
|
||||
|
@ -45,6 +72,55 @@ func (p *TCPPeer) WriteMsg(msg *Message) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Handshaked returns status of the handshake, whether it's completed or not.
|
||||
func (p *TCPPeer) Handshaked() bool {
|
||||
return p.handShake == verAckReceived
|
||||
}
|
||||
|
||||
// SendVersion checks for the handshake state and sends a message to the peer.
|
||||
func (p *TCPPeer) SendVersion(msg *Message) error {
|
||||
if p.handShake != nothingDone {
|
||||
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
|
||||
}
|
||||
err := p.writeMsg(msg)
|
||||
if err == nil {
|
||||
p.handShake = versionSent
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleVersion checks for the handshake state and version message contents.
|
||||
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
|
||||
if p.handShake != versionSent {
|
||||
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
|
||||
}
|
||||
p.version = version
|
||||
p.handShake = versionReceived
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendVersionAck checks for the handshake state and sends a message to the peer.
|
||||
func (p *TCPPeer) SendVersionAck(msg *Message) error {
|
||||
if p.handShake != versionReceived {
|
||||
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
|
||||
}
|
||||
err := p.writeMsg(msg)
|
||||
if err == nil {
|
||||
p.handShake = verAckSent
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleVersionAck checks handshake sequence correctness when VerAck message
|
||||
// is received.
|
||||
func (p *TCPPeer) HandleVersionAck() error {
|
||||
if p.handShake != verAckSent {
|
||||
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
|
||||
}
|
||||
p.handShake = verAckReceived
|
||||
return nil
|
||||
}
|
||||
|
||||
// NetAddr implements the Peer interface.
|
||||
func (p *TCPPeer) NetAddr() *net.TCPAddr {
|
||||
return &p.addr
|
||||
|
@ -59,15 +135,16 @@ func (p *TCPPeer) Done() chan error {
|
|||
|
||||
// Disconnect will fill the peer's done channel with the given error.
|
||||
func (p *TCPPeer) Disconnect(err error) {
|
||||
p.done <- err
|
||||
p.conn.Close()
|
||||
select {
|
||||
case p.done <- err:
|
||||
// one message to the queue
|
||||
default:
|
||||
// the other side may already be gone, it's OK
|
||||
}
|
||||
}
|
||||
|
||||
// Version implements the Peer interface.
|
||||
func (p *TCPPeer) Version() *payload.Version {
|
||||
return p.version
|
||||
}
|
||||
|
||||
// SetVersion implements the Peer interface.
|
||||
func (p *TCPPeer) SetVersion(v *payload.Version) {
|
||||
p.version = v
|
||||
}
|
||||
|
|
|
@ -75,21 +75,19 @@ func (t *TCPTransport) handleConn(conn net.Conn) {
|
|||
err error
|
||||
)
|
||||
|
||||
defer func() {
|
||||
p.Disconnect(err)
|
||||
}()
|
||||
|
||||
t.server.register <- p
|
||||
|
||||
for {
|
||||
msg := &Message{}
|
||||
if err = msg.Decode(p.conn); err != nil {
|
||||
return
|
||||
break
|
||||
}
|
||||
if err = t.server.handleMessage(p, msg); err != nil {
|
||||
return
|
||||
break
|
||||
}
|
||||
}
|
||||
t.server.unregister <- peerDrop{p, err}
|
||||
p.Disconnect(err)
|
||||
}
|
||||
|
||||
// Close implements the Transporter interface.
|
||||
|
|
Loading…
Reference in a new issue