Merge pull request #396 from nspcc-dev/network-reconnections-and-fixes

This one fixes #390 and some connected problems. After this patchset the node reconnects to some other nodes if anything goes wrong and it better senses when something goes wrong. It also fixes some block handling problems based on the testnet connection experience.
This commit is contained in:
Roman Khimov 2019-09-16 16:57:10 +03:00 committed by GitHub
commit adba9e11ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 427 additions and 119 deletions

View file

@ -357,8 +357,7 @@ func (bc *Blockchain) persistBlock(block *Block) error {
Email: t.Email, Email: t.Email,
Description: t.Description, Description: t.Description,
} }
_ = contract
fmt.Printf("%+v", contract)
case *transaction.InvocationTX: case *transaction.InvocationTX:
} }
@ -430,6 +429,15 @@ func (bc *Blockchain) persist(ctx context.Context) (err error) {
"blockHeight": bc.BlockHeight(), "blockHeight": bc.BlockHeight(),
"took": time.Since(start), "took": time.Since(start),
}).Info("blockchain persist completed") }).Info("blockchain persist completed")
} else {
// So we have some blocks in cache but can't persist them?
// Either there are some stale blocks there or the other way
// around (which was seen in practice) --- there are some fresh
// blocks that we can't persist yet. Some of the latter can be useful
// or can be bogus (higher than the header height we expect at
// the moment). So try to reap oldies and strange newbies, if
// there are any.
bc.blockCache.ReapStrangeBlocks(bc.BlockHeight(), bc.HeaderHeight())
} }
return return

View file

@ -71,3 +71,17 @@ func (c *Cache) Delete(h util.Uint256) {
defer c.lock.Unlock() defer c.lock.Unlock()
delete(c.m, h) delete(c.m, h)
} }
// ReapStrangeBlocks drops blocks from cache that don't fit into the
// blkHeight-headHeight interval. Cache should only contain blocks that we
// expect to get and store.
func (c *Cache) ReapStrangeBlocks(blkHeight, headHeight uint32) {
c.lock.Lock()
defer c.lock.Unlock()
for i, b := range c.m {
block, ok := b.(*Block)
if ok && (block.Index < blkHeight || block.Index > headHeight) {
delete(c.m, i)
}
}
}

View file

@ -6,6 +6,7 @@ import (
const ( const (
maxPoolSize = 200 maxPoolSize = 200
connRetries = 3
) )
// Discoverer is an interface that is responsible for maintaining // Discoverer is an interface that is responsible for maintaining
@ -15,22 +16,28 @@ type Discoverer interface {
PoolCount() int PoolCount() int
RequestRemote(int) RequestRemote(int)
RegisterBadAddr(string) RegisterBadAddr(string)
RegisterGoodAddr(string)
UnregisterConnectedAddr(string)
UnconnectedPeers() []string UnconnectedPeers() []string
BadPeers() []string BadPeers() []string
GoodPeers() []string
} }
// DefaultDiscovery default implementation of the Discoverer interface. // DefaultDiscovery default implementation of the Discoverer interface.
type DefaultDiscovery struct { type DefaultDiscovery struct {
transport Transporter transport Transporter
dialTimeout time.Duration dialTimeout time.Duration
addrs map[string]bool
badAddrs map[string]bool badAddrs map[string]bool
unconnectedAddrs map[string]bool connectedAddrs map[string]bool
goodAddrs map[string]bool
unconnectedAddrs map[string]int
requestCh chan int requestCh chan int
connectedCh chan string connectedCh chan string
backFill chan string backFill chan string
badAddrCh chan string badAddrCh chan string
pool chan string pool chan string
goodCh chan string
unconnectedCh chan string
} }
// NewDefaultDiscovery returns a new DefaultDiscovery. // NewDefaultDiscovery returns a new DefaultDiscovery.
@ -38,11 +45,14 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
d := &DefaultDiscovery{ d := &DefaultDiscovery{
transport: ts, transport: ts,
dialTimeout: dt, dialTimeout: dt,
addrs: make(map[string]bool),
badAddrs: make(map[string]bool), badAddrs: make(map[string]bool),
unconnectedAddrs: make(map[string]bool), connectedAddrs: make(map[string]bool),
goodAddrs: make(map[string]bool),
unconnectedAddrs: make(map[string]int),
requestCh: make(chan int), requestCh: make(chan int),
connectedCh: make(chan string), connectedCh: make(chan string),
goodCh: make(chan string),
unconnectedCh: make(chan string),
backFill: make(chan string), backFill: make(chan string),
badAddrCh: make(chan string), badAddrCh: make(chan string),
pool: make(chan string, maxPoolSize), pool: make(chan string, maxPoolSize),
@ -54,9 +64,6 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
// BackFill implements the Discoverer interface and will backfill the // BackFill implements the Discoverer interface and will backfill the
// the pool with the given addresses. // the pool with the given addresses.
func (d *DefaultDiscovery) BackFill(addrs ...string) { func (d *DefaultDiscovery) BackFill(addrs ...string) {
if len(d.pool) == maxPoolSize {
return
}
for _, addr := range addrs { for _, addr := range addrs {
d.backFill <- addr d.backFill <- addr
} }
@ -67,6 +74,17 @@ func (d *DefaultDiscovery) PoolCount() int {
return len(d.pool) return len(d.pool)
} }
// pushToPoolOrDrop tries to push address given into the pool, but if the pool
// is already full, it just drops it
func (d *DefaultDiscovery) pushToPoolOrDrop(addr string) {
select {
case d.pool <- addr:
// ok, queued
default:
// whatever
}
}
// RequestRemote will try to establish a connection with n nodes. // RequestRemote will try to establish a connection with n nodes.
func (d *DefaultDiscovery) RequestRemote(n int) { func (d *DefaultDiscovery) RequestRemote(n int) {
d.requestCh <- n d.requestCh <- n
@ -96,57 +114,87 @@ func (d *DefaultDiscovery) BadPeers() []string {
return addrs return addrs
} }
func (d *DefaultDiscovery) work(addrCh chan string) { // GoodPeers returns all addresses of known good peers (that at least once
for { // succeded handshaking with us).
addr := <-addrCh func (d *DefaultDiscovery) GoodPeers() []string {
addrs := make([]string, 0, len(d.goodAddrs))
for addr := range d.goodAddrs {
addrs = append(addrs, addr)
}
return addrs
}
// RegisterGoodAddr registers good known connected address that passed
// handshake successfuly.
func (d *DefaultDiscovery) RegisterGoodAddr(s string) {
d.goodCh <- s
}
// UnregisterConnectedAddr tells discoverer that this address is no longer
// connected, but it still is considered as good one.
func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) {
d.unconnectedCh <- s
}
func (d *DefaultDiscovery) tryAddress(addr string) {
if err := d.transport.Dial(addr, d.dialTimeout); err != nil { if err := d.transport.Dial(addr, d.dialTimeout); err != nil {
d.badAddrCh <- addr d.badAddrCh <- addr
} else { } else {
d.connectedCh <- addr d.connectedCh <- addr
} }
} }
}
func (d *DefaultDiscovery) next() string { func (d *DefaultDiscovery) requestToWork() {
return <-d.pool var requested int
for {
for requested = <-d.requestCh; requested > 0; requested-- {
select {
case r := <-d.requestCh:
if requested < r {
requested = r
}
case addr := <-d.pool:
if !d.connectedAddrs[addr] {
go d.tryAddress(addr)
}
}
}
}
} }
func (d *DefaultDiscovery) run() { func (d *DefaultDiscovery) run() {
var ( go d.requestToWork()
maxWorkers = 5
workCh = make(chan string)
)
for i := 0; i < maxWorkers; i++ {
go d.work(workCh)
}
for { for {
select { select {
case addr := <-d.backFill: case addr := <-d.backFill:
if _, ok := d.badAddrs[addr]; ok { if d.badAddrs[addr] || d.connectedAddrs[addr] ||
d.unconnectedAddrs[addr] > 0 {
break break
} }
if _, ok := d.addrs[addr]; !ok { d.unconnectedAddrs[addr] = connRetries
d.addrs[addr] = true d.pushToPoolOrDrop(addr)
d.unconnectedAddrs[addr] = true
d.pool <- addr
}
case n := <-d.requestCh:
go func() {
for i := 0; i < n; i++ {
workCh <- d.next()
}
}()
case addr := <-d.badAddrCh: case addr := <-d.badAddrCh:
d.unconnectedAddrs[addr]--
if d.unconnectedAddrs[addr] > 0 {
d.pushToPoolOrDrop(addr)
} else {
d.badAddrs[addr] = true d.badAddrs[addr] = true
delete(d.unconnectedAddrs, addr) delete(d.unconnectedAddrs, addr)
go func() { }
workCh <- d.next() d.RequestRemote(1)
}()
case addr := <-d.connectedCh: case addr := <-d.connectedCh:
delete(d.unconnectedAddrs, addr) delete(d.unconnectedAddrs, addr)
if !d.connectedAddrs[addr] {
d.connectedAddrs[addr] = true
}
case addr := <-d.goodCh:
if !d.goodAddrs[addr] {
d.goodAddrs[addr] = true
}
case addr := <-d.unconnectedCh:
delete(d.connectedAddrs, addr)
} }
} }
} }

View file

@ -0,0 +1,27 @@
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
package network
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[nothingDone-0]
_ = x[versionSent-1]
_ = x[versionReceived-2]
_ = x[verAckSent-3]
_ = x[verAckReceived-4]
}
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
func (i handShakeStage) String() string {
if i >= handShakeStage(len(_handShakeStage_index)-1) {
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
}

View file

@ -96,9 +96,12 @@ type testDiscovery struct{}
func (d testDiscovery) BackFill(addrs ...string) {} func (d testDiscovery) BackFill(addrs ...string) {}
func (d testDiscovery) PoolCount() int { return 0 } func (d testDiscovery) PoolCount() int { return 0 }
func (d testDiscovery) RegisterBadAddr(string) {} func (d testDiscovery) RegisterBadAddr(string) {}
func (d testDiscovery) RegisterGoodAddr(string) {}
func (d testDiscovery) UnregisterConnectedAddr(string) {}
func (d testDiscovery) UnconnectedPeers() []string { return []string{} } func (d testDiscovery) UnconnectedPeers() []string { return []string{} }
func (d testDiscovery) RequestRemote(n int) {} func (d testDiscovery) RequestRemote(n int) {}
func (d testDiscovery) BadPeers() []string { return []string{} } func (d testDiscovery) BadPeers() []string { return []string{} }
func (d testDiscovery) GoodPeers() []string { return []string{} }
type localTransport struct{} type localTransport struct{}
@ -114,6 +117,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {}
type localPeer struct { type localPeer struct {
netaddr net.TCPAddr netaddr net.TCPAddr
version *payload.Version version *payload.Version
handshaked bool
t *testing.T t *testing.T
messageHandler func(t *testing.T, msg *Message) messageHandler func(t *testing.T, msg *Message)
} }
@ -142,8 +146,23 @@ func (p *localPeer) Done() chan error {
func (p *localPeer) Version() *payload.Version { func (p *localPeer) Version() *payload.Version {
return p.version return p.version
} }
func (p *localPeer) SetVersion(v *payload.Version) { func (p *localPeer) HandleVersion(v *payload.Version) error {
p.version = v p.version = v
return nil
}
func (p *localPeer) SendVersion(m *Message) error {
return p.WriteMsg(m)
}
func (p *localPeer) SendVersionAck(m *Message) error {
return p.WriteMsg(m)
}
func (p *localPeer) HandleVersionAck() error {
p.handshaked = true
return nil
}
func (p *localPeer) Handshaked() bool {
return p.handshaked
} }
func newTestServer() *Server { func newTestServer() *Server {

View file

@ -3,6 +3,7 @@ package payload
import ( import (
"io" "io"
"net" "net"
"strconv"
"time" "time"
"github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/util"
@ -47,11 +48,28 @@ func (p *AddressAndTime) EncodeBinary(w io.Writer) error {
return bw.Err return bw.Err
} }
// IPPortString makes a string from IP and port specified.
func (p *AddressAndTime) IPPortString() string {
var netip net.IP = make(net.IP, 16)
copy(netip, p.IP[:])
port := strconv.Itoa(int(p.Port))
return netip.String() + ":" + port
}
// AddressList is a list with AddrAndTime. // AddressList is a list with AddrAndTime.
type AddressList struct { type AddressList struct {
Addrs []*AddressAndTime Addrs []*AddressAndTime
} }
// NewAddressList creates a list for n AddressAndTime elements.
func NewAddressList(n int) *AddressList {
alist := AddressList{
Addrs: make([]*AddressAndTime, n),
}
return &alist
}
// DecodeBinary implements the Payload interface. // DecodeBinary implements the Payload interface.
func (p *AddressList) DecodeBinary(r io.Reader) error { func (p *AddressList) DecodeBinary(r io.Reader) error {
br := util.BinReader{R: r} br := util.BinReader{R: r}

View file

@ -35,7 +35,7 @@ func TestEncodeDecodeAddress(t *testing.T) {
func TestEncodeDecodeAddressList(t *testing.T) { func TestEncodeDecodeAddressList(t *testing.T) {
var lenList uint8 = 4 var lenList uint8 = 4
addrList := &AddressList{make([]*AddressAndTime, lenList)} addrList := NewAddressList(int(lenList))
for i := 0; i < int(lenList); i++ { for i := 0; i < int(lenList); i++ {
e, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:200%d", i)) e, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:200%d", i))
addrList.Addrs[i] = NewAddressAndTime(e, time.Now()) addrList.Addrs[i] = NewAddressAndTime(e, time.Now())

View file

@ -7,3 +7,22 @@ type Payload interface {
EncodeBinary(io.Writer) error EncodeBinary(io.Writer) error
DecodeBinary(io.Reader) error DecodeBinary(io.Reader) error
} }
// NullPayload is a dummy payload with no fields.
type NullPayload struct {
}
// NewNullPayload returns zero-sized stub payload.
func NewNullPayload() *NullPayload {
return &NullPayload{}
}
// DecodeBinary implements the Payload interface.
func (p *NullPayload) DecodeBinary(r io.Reader) error {
return nil
}
// EncodeBinary implements the Payload interface.
func (p *NullPayload) EncodeBinary(r io.Writer) error {
return nil
}

View file

@ -13,5 +13,9 @@ type Peer interface {
WriteMsg(msg *Message) error WriteMsg(msg *Message) error
Done() chan error Done() chan error
Version() *payload.Version Version() *payload.Version
SetVersion(*payload.Version) Handshaked() bool
SendVersion(*Message) error
SendVersionAck(*Message) error
HandleVersion(*payload.Version) error
HandleVersionAck() error
} }

View file

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"math/rand" "math/rand"
"net"
"sync" "sync"
"time" "time"
@ -15,13 +16,15 @@ import (
) )
const ( const (
// peer numbers are arbitrary at the moment
minPeers = 5 minPeers = 5
maxPeers = 20
maxBlockBatch = 200 maxBlockBatch = 200
maxAddrsToSend = 200
minPoolCount = 30 minPoolCount = 30
) )
var ( var (
errPortMismatch = errors.New("port mismatch")
errIdenticalID = errors.New("identical node id") errIdenticalID = errors.New("identical node id")
errInvalidHandshake = errors.New("invalid handshake") errInvalidHandshake = errors.New("invalid handshake")
errInvalidNetwork = errors.New("invalid network") errInvalidNetwork = errors.New("invalid network")
@ -46,6 +49,7 @@ type (
lock sync.RWMutex lock sync.RWMutex
peers map[Peer]bool peers map[Peer]bool
addrReq chan *Message
register chan Peer register chan Peer
unregister chan peerDrop unregister chan peerDrop
quit chan struct{} quit chan struct{}
@ -64,6 +68,7 @@ func NewServer(config ServerConfig, chain core.Blockchainer) *Server {
chain: chain, chain: chain,
id: rand.Uint32(), id: rand.Uint32(),
quit: make(chan struct{}), quit: make(chan struct{}),
addrReq: make(chan *Message, minPeers),
register: make(chan Peer), register: make(chan Peer),
unregister: make(chan peerDrop), unregister: make(chan peerDrop),
peers: make(map[Peer]bool), peers: make(map[Peer]bool),
@ -90,12 +95,7 @@ func (s *Server) Start(errChan chan error) {
"headerHeight": s.chain.HeaderHeight(), "headerHeight": s.chain.HeaderHeight(),
}).Info("node started") }).Info("node started")
for _, addr := range s.Seeds { s.discovery.BackFill(s.Seeds...)
if err := s.transport.Dial(addr, s.DialTimeout); err != nil {
log.Warnf("failed to connect to remote node %s", addr)
continue
}
}
go s.transport.Accept() go s.transport.Accept()
s.run() s.run()
@ -122,6 +122,19 @@ func (s *Server) BadPeers() []string {
func (s *Server) run() { func (s *Server) run() {
for { for {
c := s.PeerCount()
if c < minPeers {
s.discovery.RequestRemote(maxPeers - c)
}
if s.discovery.PoolCount() < minPoolCount {
select {
case s.addrReq <- NewMessage(s.Net, CMDGetAddr, payload.NewNullPayload()):
// sent request
default:
// we have one in the queue already that is
// gonna be served by some worker when it's ready
}
}
select { select {
case <-s.quit: case <-s.quit:
s.transport.Close() s.transport.Close()
@ -141,12 +154,19 @@ func (s *Server) run() {
"addr": p.NetAddr(), "addr": p.NetAddr(),
}).Info("new peer connected") }).Info("new peer connected")
case drop := <-s.unregister: case drop := <-s.unregister:
if s.peers[drop.peer] {
delete(s.peers, drop.peer) delete(s.peers, drop.peer)
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"addr": drop.peer.NetAddr(), "addr": drop.peer.NetAddr(),
"reason": drop.reason, "reason": drop.reason,
"peerCount": s.PeerCount(), "peerCount": s.PeerCount(),
}).Warn("peer disconnected") }).Warn("peer disconnected")
addr := drop.peer.NetAddr().String()
s.discovery.UnregisterConnectedAddr(addr)
s.discovery.BackFill(addr)
}
// else the peer is already gone, which can happen
// because we have two goroutines sending signals here
} }
} }
} }
@ -174,22 +194,36 @@ func (s *Server) startProtocol(p Peer) {
"id": p.Version().Nonce, "id": p.Version().Nonce,
}).Info("started protocol") }).Info("started protocol")
s.requestHeaders(p) s.discovery.RegisterGoodAddr(p.NetAddr().String())
err := s.requestHeaders(p)
if err != nil {
p.Disconnect(err)
return
}
timer := time.NewTimer(s.ProtoTickInterval) timer := time.NewTimer(s.ProtoTickInterval)
for { for {
select { select {
case err := <-p.Done(): case err = <-p.Done():
s.unregister <- peerDrop{p, err} // time to stop
return case m := <-s.addrReq:
err = p.WriteMsg(m)
case <-timer.C: case <-timer.C:
// Try to sync in headers and block with the peer if his block height is higher then ours. // Try to sync in headers and block with the peer if his block height is higher then ours.
if p.Version().StartHeight > s.chain.BlockHeight() { if p.Version().StartHeight > s.chain.BlockHeight() {
s.requestBlocks(p) err = s.requestBlocks(p)
} }
if err == nil {
timer.Reset(s.ProtoTickInterval) timer.Reset(s.ProtoTickInterval)
} }
} }
if err != nil {
s.unregister <- peerDrop{p, err}
timer.Stop()
p.Disconnect(err)
return
}
}
} }
// When a peer connects to the server, we will send our version immediately. // When a peer connects to the server, we will send our version immediately.
@ -201,20 +235,23 @@ func (s *Server) sendVersion(p Peer) error {
s.chain.BlockHeight(), s.chain.BlockHeight(),
s.Relay, s.Relay,
) )
return p.WriteMsg(NewMessage(s.Net, CMDVersion, payload)) return p.SendVersion(NewMessage(s.Net, CMDVersion, payload))
} }
// When a peer sends out his version we reply with verack after validating // When a peer sends out his version we reply with verack after validating
// the version. // the version.
func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error { func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
if p.NetAddr().Port != int(version.Port) { err := p.HandleVersion(version)
return errPortMismatch if err != nil {
return err
} }
if s.id == version.Nonce { if s.id == version.Nonce {
return errIdenticalID return errIdenticalID
} }
p.SetVersion(version) if p.NetAddr().Port != int(version.Port) {
return p.WriteMsg(NewMessage(s.Net, CMDVerack, nil)) return fmt.Errorf("port mismatch: connected to %d and peer sends %d", p.NetAddr().Port, version.Port)
}
return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil))
} }
// handleHeadersCmd will process the headers it received from its peer. // handleHeadersCmd will process the headers it received from its peer.
@ -251,18 +288,42 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
} }
// handleAddrCmd will process received addresses.
func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error {
for _, a := range addrs.Addrs {
s.discovery.BackFill(a.IPPortString())
}
return nil
}
// handleGetAddrCmd sends to the peer some good addresses that we know of.
func (s *Server) handleGetAddrCmd(p Peer) error {
addrs := s.discovery.GoodPeers()
if len(addrs) > maxAddrsToSend {
addrs = addrs[:maxAddrsToSend]
}
alist := payload.NewAddressList(len(addrs))
ts := time.Now()
for i, addr := range addrs {
// we know it's a good address, so it can't fail
netaddr, _ := net.ResolveTCPAddr("tcp", addr)
alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts)
}
return p.WriteMsg(NewMessage(s.Net, CMDAddr, alist))
}
// requestHeaders will send a getheaders message to the peer. // requestHeaders will send a getheaders message to the peer.
// The peer will respond with headers op to a count of 2000. // The peer will respond with headers op to a count of 2000.
func (s *Server) requestHeaders(p Peer) { func (s *Server) requestHeaders(p Peer) error {
start := []util.Uint256{s.chain.CurrentHeaderHash()} start := []util.Uint256{s.chain.CurrentHeaderHash()}
payload := payload.NewGetBlocks(start, util.Uint256{}) payload := payload.NewGetBlocks(start, util.Uint256{})
p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload)) return p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload))
} }
// requestBlocks will send a getdata message to the peer // requestBlocks will send a getdata message to the peer
// to sync up in blocks. A maximum of maxBlockBatch will // to sync up in blocks. A maximum of maxBlockBatch will
// send at once. // send at once.
func (s *Server) requestBlocks(p Peer) { func (s *Server) requestBlocks(p Peer) error {
var ( var (
hashes []util.Uint256 hashes []util.Uint256
hashStart = s.chain.BlockHeight() + 1 hashStart = s.chain.BlockHeight() + 1
@ -275,10 +336,11 @@ func (s *Server) requestBlocks(p Peer) {
} }
if len(hashes) > 0 { if len(hashes) > 0 {
payload := payload.NewInventory(payload.BlockType, hashes) payload := payload.NewInventory(payload.BlockType, hashes)
p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
} else if s.chain.HeaderHeight() < p.Version().StartHeight { } else if s.chain.HeaderHeight() < p.Version().StartHeight {
s.requestHeaders(p) return s.requestHeaders(p)
} }
return nil
} }
// handleMessage will process the given message. // handleMessage will process the given message.
@ -289,10 +351,14 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return errInvalidNetwork return errInvalidNetwork
} }
if peer.Handshaked() {
switch msg.CommandType() { switch msg.CommandType() {
case CMDVersion: case CMDAddr:
version := msg.Payload.(*payload.Version) addrs := msg.Payload.(*payload.AddressList)
return s.handleVersionCmd(peer, version) return s.handleAddrCmd(peer, addrs)
case CMDGetAddr:
// it has no payload
return s.handleGetAddrCmd(peer)
case CMDHeaders: case CMDHeaders:
headers := msg.Payload.(*payload.Headers) headers := msg.Payload.(*payload.Headers)
go s.handleHeadersCmd(peer, headers) go s.handleHeadersCmd(peer, headers)
@ -302,13 +368,23 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
case CMDBlock: case CMDBlock:
block := msg.Payload.(*core.Block) block := msg.Payload.(*core.Block)
return s.handleBlockCmd(peer, block) return s.handleBlockCmd(peer, block)
case CMDVersion, CMDVerack:
return fmt.Errorf("received '%s' after the handshake", msg.CommandType())
}
} else {
switch msg.CommandType() {
case CMDVersion:
version := msg.Payload.(*payload.Version)
return s.handleVersionCmd(peer, version)
case CMDVerack: case CMDVerack:
// Make sure this peer has send his version before we start the err := peer.HandleVersionAck()
// protocol with that peer. if err != nil {
if peer.Version() == nil { return err
return errInvalidHandshake
} }
go s.startProtocol(peer) go s.startProtocol(peer)
default:
return fmt.Errorf("received '%s' during handshake", msg.CommandType())
}
} }
return nil return nil
} }

View file

@ -71,7 +71,7 @@ func TestServerNotSendsVerack(t *testing.T) {
version := payload.NewVersion(1337, 2000, "/NEO-GO/", 0, true) version := payload.NewVersion(1337, 2000, "/NEO-GO/", 0, true)
err := s.handleVersionCmd(p, version) err := s.handleVersionCmd(p, version)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, errPortMismatch, err) assert.Contains(t, err.Error(), "port mismatch")
// identical id's // identical id's
version = payload.NewVersion(1, 3000, "/NEO-GO/", 0, true) version = payload.NewVersion(1, 3000, "/NEO-GO/", 0, true)

View file

@ -1,12 +1,29 @@
package network package network
import ( import (
"errors"
"fmt"
"net" "net"
"sync" "sync"
"github.com/CityOfZion/neo-go/pkg/network/payload" "github.com/CityOfZion/neo-go/pkg/network/payload"
) )
type handShakeStage uint8
//go:generate stringer -type=handShakeStage
const (
nothingDone handShakeStage = 0
versionSent handShakeStage = 1
versionReceived handShakeStage = 2
verAckSent handShakeStage = 3
verAckReceived handShakeStage = 4
)
var (
errStateMismatch = errors.New("tried to send protocol message before handshake completed")
)
// TCPPeer represents a connected remote node in the // TCPPeer represents a connected remote node in the
// network over TCP. // network over TCP.
type TCPPeer struct { type TCPPeer struct {
@ -17,6 +34,8 @@ type TCPPeer struct {
// The version of the peer. // The version of the peer.
version *payload.Version version *payload.Version
handShake handShakeStage
done chan error done chan error
wg sync.WaitGroup wg sync.WaitGroup
@ -35,8 +54,16 @@ func NewTCPPeer(conn net.Conn) *TCPPeer {
} }
// WriteMsg implements the Peer interface. This will write/encode the message // WriteMsg implements the Peer interface. This will write/encode the message
// to the underlying connection. // to the underlying connection, this only works for messages other than Version
// or VerAck.
func (p *TCPPeer) WriteMsg(msg *Message) error { func (p *TCPPeer) WriteMsg(msg *Message) error {
if !p.Handshaked() {
return errStateMismatch
}
return p.writeMsg(msg)
}
func (p *TCPPeer) writeMsg(msg *Message) error {
select { select {
case err := <-p.done: case err := <-p.done:
return err return err
@ -45,6 +72,55 @@ func (p *TCPPeer) WriteMsg(msg *Message) error {
} }
} }
// Handshaked returns status of the handshake, whether it's completed or not.
func (p *TCPPeer) Handshaked() bool {
return p.handShake == verAckReceived
}
// SendVersion checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersion(msg *Message) error {
if p.handShake != nothingDone {
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
}
err := p.writeMsg(msg)
if err == nil {
p.handShake = versionSent
}
return err
}
// HandleVersion checks for the handshake state and version message contents.
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
if p.handShake != versionSent {
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
}
p.version = version
p.handShake = versionReceived
return nil
}
// SendVersionAck checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersionAck(msg *Message) error {
if p.handShake != versionReceived {
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
}
err := p.writeMsg(msg)
if err == nil {
p.handShake = verAckSent
}
return err
}
// HandleVersionAck checks handshake sequence correctness when VerAck message
// is received.
func (p *TCPPeer) HandleVersionAck() error {
if p.handShake != verAckSent {
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
}
p.handShake = verAckReceived
return nil
}
// NetAddr implements the Peer interface. // NetAddr implements the Peer interface.
func (p *TCPPeer) NetAddr() *net.TCPAddr { func (p *TCPPeer) NetAddr() *net.TCPAddr {
return &p.addr return &p.addr
@ -59,15 +135,16 @@ func (p *TCPPeer) Done() chan error {
// Disconnect will fill the peer's done channel with the given error. // Disconnect will fill the peer's done channel with the given error.
func (p *TCPPeer) Disconnect(err error) { func (p *TCPPeer) Disconnect(err error) {
p.done <- err p.conn.Close()
select {
case p.done <- err:
// one message to the queue
default:
// the other side may already be gone, it's OK
}
} }
// Version implements the Peer interface. // Version implements the Peer interface.
func (p *TCPPeer) Version() *payload.Version { func (p *TCPPeer) Version() *payload.Version {
return p.version return p.version
} }
// SetVersion implements the Peer interface.
func (p *TCPPeer) SetVersion(v *payload.Version) {
p.version = v
}

View file

@ -75,21 +75,19 @@ func (t *TCPTransport) handleConn(conn net.Conn) {
err error err error
) )
defer func() {
p.Disconnect(err)
}()
t.server.register <- p t.server.register <- p
for { for {
msg := &Message{} msg := &Message{}
if err = msg.Decode(p.conn); err != nil { if err = msg.Decode(p.conn); err != nil {
return break
} }
if err = t.server.handleMessage(p, msg); err != nil { if err = t.server.handleMessage(p, msg); err != nil {
return break
} }
} }
t.server.unregister <- peerDrop{p, err}
p.Disconnect(err)
} }
// Close implements the Transporter interface. // Close implements the Transporter interface.