Merge pull request #396 from nspcc-dev/network-reconnections-and-fixes

This one fixes #390 and some connected problems. After this patchset the node reconnects to some other nodes if anything goes wrong and it better senses when something goes wrong. It also fixes some block handling problems based on the testnet connection experience.
This commit is contained in:
Roman Khimov 2019-09-16 16:57:10 +03:00 committed by GitHub
commit adba9e11ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 427 additions and 119 deletions

View file

@ -357,8 +357,7 @@ func (bc *Blockchain) persistBlock(block *Block) error {
Email: t.Email,
Description: t.Description,
}
fmt.Printf("%+v", contract)
_ = contract
case *transaction.InvocationTX:
}
@ -430,6 +429,15 @@ func (bc *Blockchain) persist(ctx context.Context) (err error) {
"blockHeight": bc.BlockHeight(),
"took": time.Since(start),
}).Info("blockchain persist completed")
} else {
// So we have some blocks in cache but can't persist them?
// Either there are some stale blocks there or the other way
// around (which was seen in practice) --- there are some fresh
// blocks that we can't persist yet. Some of the latter can be useful
// or can be bogus (higher than the header height we expect at
// the moment). So try to reap oldies and strange newbies, if
// there are any.
bc.blockCache.ReapStrangeBlocks(bc.BlockHeight(), bc.HeaderHeight())
}
return

View file

@ -71,3 +71,17 @@ func (c *Cache) Delete(h util.Uint256) {
defer c.lock.Unlock()
delete(c.m, h)
}
// ReapStrangeBlocks drops blocks from cache that don't fit into the
// blkHeight-headHeight interval. Cache should only contain blocks that we
// expect to get and store.
func (c *Cache) ReapStrangeBlocks(blkHeight, headHeight uint32) {
c.lock.Lock()
defer c.lock.Unlock()
for i, b := range c.m {
block, ok := b.(*Block)
if ok && (block.Index < blkHeight || block.Index > headHeight) {
delete(c.m, i)
}
}
}

View file

@ -6,6 +6,7 @@ import (
const (
maxPoolSize = 200
connRetries = 3
)
// Discoverer is an interface that is responsible for maintaining
@ -15,22 +16,28 @@ type Discoverer interface {
PoolCount() int
RequestRemote(int)
RegisterBadAddr(string)
RegisterGoodAddr(string)
UnregisterConnectedAddr(string)
UnconnectedPeers() []string
BadPeers() []string
GoodPeers() []string
}
// DefaultDiscovery default implementation of the Discoverer interface.
type DefaultDiscovery struct {
transport Transporter
dialTimeout time.Duration
addrs map[string]bool
badAddrs map[string]bool
unconnectedAddrs map[string]bool
connectedAddrs map[string]bool
goodAddrs map[string]bool
unconnectedAddrs map[string]int
requestCh chan int
connectedCh chan string
backFill chan string
badAddrCh chan string
pool chan string
goodCh chan string
unconnectedCh chan string
}
// NewDefaultDiscovery returns a new DefaultDiscovery.
@ -38,11 +45,14 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
d := &DefaultDiscovery{
transport: ts,
dialTimeout: dt,
addrs: make(map[string]bool),
badAddrs: make(map[string]bool),
unconnectedAddrs: make(map[string]bool),
connectedAddrs: make(map[string]bool),
goodAddrs: make(map[string]bool),
unconnectedAddrs: make(map[string]int),
requestCh: make(chan int),
connectedCh: make(chan string),
goodCh: make(chan string),
unconnectedCh: make(chan string),
backFill: make(chan string),
badAddrCh: make(chan string),
pool: make(chan string, maxPoolSize),
@ -54,9 +64,6 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery {
// BackFill implements the Discoverer interface and will backfill the
// the pool with the given addresses.
func (d *DefaultDiscovery) BackFill(addrs ...string) {
if len(d.pool) == maxPoolSize {
return
}
for _, addr := range addrs {
d.backFill <- addr
}
@ -67,6 +74,17 @@ func (d *DefaultDiscovery) PoolCount() int {
return len(d.pool)
}
// pushToPoolOrDrop tries to push address given into the pool, but if the pool
// is already full, it just drops it
func (d *DefaultDiscovery) pushToPoolOrDrop(addr string) {
select {
case d.pool <- addr:
// ok, queued
default:
// whatever
}
}
// RequestRemote will try to establish a connection with n nodes.
func (d *DefaultDiscovery) RequestRemote(n int) {
d.requestCh <- n
@ -96,57 +114,87 @@ func (d *DefaultDiscovery) BadPeers() []string {
return addrs
}
func (d *DefaultDiscovery) work(addrCh chan string) {
for {
addr := <-addrCh
if err := d.transport.Dial(addr, d.dialTimeout); err != nil {
d.badAddrCh <- addr
} else {
d.connectedCh <- addr
}
// GoodPeers returns all addresses of known good peers (that at least once
// succeded handshaking with us).
func (d *DefaultDiscovery) GoodPeers() []string {
addrs := make([]string, 0, len(d.goodAddrs))
for addr := range d.goodAddrs {
addrs = append(addrs, addr)
}
return addrs
}
// RegisterGoodAddr registers good known connected address that passed
// handshake successfuly.
func (d *DefaultDiscovery) RegisterGoodAddr(s string) {
d.goodCh <- s
}
// UnregisterConnectedAddr tells discoverer that this address is no longer
// connected, but it still is considered as good one.
func (d *DefaultDiscovery) UnregisterConnectedAddr(s string) {
d.unconnectedCh <- s
}
func (d *DefaultDiscovery) tryAddress(addr string) {
if err := d.transport.Dial(addr, d.dialTimeout); err != nil {
d.badAddrCh <- addr
} else {
d.connectedCh <- addr
}
}
func (d *DefaultDiscovery) next() string {
return <-d.pool
func (d *DefaultDiscovery) requestToWork() {
var requested int
for {
for requested = <-d.requestCh; requested > 0; requested-- {
select {
case r := <-d.requestCh:
if requested < r {
requested = r
}
case addr := <-d.pool:
if !d.connectedAddrs[addr] {
go d.tryAddress(addr)
}
}
}
}
}
func (d *DefaultDiscovery) run() {
var (
maxWorkers = 5
workCh = make(chan string)
)
for i := 0; i < maxWorkers; i++ {
go d.work(workCh)
}
go d.requestToWork()
for {
select {
case addr := <-d.backFill:
if _, ok := d.badAddrs[addr]; ok {
if d.badAddrs[addr] || d.connectedAddrs[addr] ||
d.unconnectedAddrs[addr] > 0 {
break
}
if _, ok := d.addrs[addr]; !ok {
d.addrs[addr] = true
d.unconnectedAddrs[addr] = true
d.pool <- addr
}
case n := <-d.requestCh:
go func() {
for i := 0; i < n; i++ {
workCh <- d.next()
}
}()
d.unconnectedAddrs[addr] = connRetries
d.pushToPoolOrDrop(addr)
case addr := <-d.badAddrCh:
d.badAddrs[addr] = true
delete(d.unconnectedAddrs, addr)
go func() {
workCh <- d.next()
}()
d.unconnectedAddrs[addr]--
if d.unconnectedAddrs[addr] > 0 {
d.pushToPoolOrDrop(addr)
} else {
d.badAddrs[addr] = true
delete(d.unconnectedAddrs, addr)
}
d.RequestRemote(1)
case addr := <-d.connectedCh:
delete(d.unconnectedAddrs, addr)
if !d.connectedAddrs[addr] {
d.connectedAddrs[addr] = true
}
case addr := <-d.goodCh:
if !d.goodAddrs[addr] {
d.goodAddrs[addr] = true
}
case addr := <-d.unconnectedCh:
delete(d.connectedAddrs, addr)
}
}
}

View file

@ -0,0 +1,27 @@
// Code generated by "stringer -type=handShakeStage"; DO NOT EDIT.
package network
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[nothingDone-0]
_ = x[versionSent-1]
_ = x[versionReceived-2]
_ = x[verAckSent-3]
_ = x[verAckReceived-4]
}
const _handShakeStage_name = "nothingDoneversionSentversionReceivedverAckSentverAckReceived"
var _handShakeStage_index = [...]uint8{0, 11, 22, 37, 47, 61}
func (i handShakeStage) String() string {
if i >= handShakeStage(len(_handShakeStage_index)-1) {
return "handShakeStage(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _handShakeStage_name[_handShakeStage_index[i]:_handShakeStage_index[i+1]]
}

View file

@ -93,12 +93,15 @@ func (chain testChain) Verify(*transaction.Transaction) error {
type testDiscovery struct{}
func (d testDiscovery) BackFill(addrs ...string) {}
func (d testDiscovery) PoolCount() int { return 0 }
func (d testDiscovery) RegisterBadAddr(string) {}
func (d testDiscovery) UnconnectedPeers() []string { return []string{} }
func (d testDiscovery) RequestRemote(n int) {}
func (d testDiscovery) BadPeers() []string { return []string{} }
func (d testDiscovery) BackFill(addrs ...string) {}
func (d testDiscovery) PoolCount() int { return 0 }
func (d testDiscovery) RegisterBadAddr(string) {}
func (d testDiscovery) RegisterGoodAddr(string) {}
func (d testDiscovery) UnregisterConnectedAddr(string) {}
func (d testDiscovery) UnconnectedPeers() []string { return []string{} }
func (d testDiscovery) RequestRemote(n int) {}
func (d testDiscovery) BadPeers() []string { return []string{} }
func (d testDiscovery) GoodPeers() []string { return []string{} }
type localTransport struct{}
@ -114,6 +117,7 @@ var defaultMessageHandler = func(t *testing.T, msg *Message) {}
type localPeer struct {
netaddr net.TCPAddr
version *payload.Version
handshaked bool
t *testing.T
messageHandler func(t *testing.T, msg *Message)
}
@ -142,8 +146,23 @@ func (p *localPeer) Done() chan error {
func (p *localPeer) Version() *payload.Version {
return p.version
}
func (p *localPeer) SetVersion(v *payload.Version) {
func (p *localPeer) HandleVersion(v *payload.Version) error {
p.version = v
return nil
}
func (p *localPeer) SendVersion(m *Message) error {
return p.WriteMsg(m)
}
func (p *localPeer) SendVersionAck(m *Message) error {
return p.WriteMsg(m)
}
func (p *localPeer) HandleVersionAck() error {
p.handshaked = true
return nil
}
func (p *localPeer) Handshaked() bool {
return p.handshaked
}
func newTestServer() *Server {

View file

@ -3,6 +3,7 @@ package payload
import (
"io"
"net"
"strconv"
"time"
"github.com/CityOfZion/neo-go/pkg/util"
@ -47,11 +48,28 @@ func (p *AddressAndTime) EncodeBinary(w io.Writer) error {
return bw.Err
}
// IPPortString makes a string from IP and port specified.
func (p *AddressAndTime) IPPortString() string {
var netip net.IP = make(net.IP, 16)
copy(netip, p.IP[:])
port := strconv.Itoa(int(p.Port))
return netip.String() + ":" + port
}
// AddressList is a list with AddrAndTime.
type AddressList struct {
Addrs []*AddressAndTime
}
// NewAddressList creates a list for n AddressAndTime elements.
func NewAddressList(n int) *AddressList {
alist := AddressList{
Addrs: make([]*AddressAndTime, n),
}
return &alist
}
// DecodeBinary implements the Payload interface.
func (p *AddressList) DecodeBinary(r io.Reader) error {
br := util.BinReader{R: r}

View file

@ -35,7 +35,7 @@ func TestEncodeDecodeAddress(t *testing.T) {
func TestEncodeDecodeAddressList(t *testing.T) {
var lenList uint8 = 4
addrList := &AddressList{make([]*AddressAndTime, lenList)}
addrList := NewAddressList(int(lenList))
for i := 0; i < int(lenList); i++ {
e, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:200%d", i))
addrList.Addrs[i] = NewAddressAndTime(e, time.Now())

View file

@ -7,3 +7,22 @@ type Payload interface {
EncodeBinary(io.Writer) error
DecodeBinary(io.Reader) error
}
// NullPayload is a dummy payload with no fields.
type NullPayload struct {
}
// NewNullPayload returns zero-sized stub payload.
func NewNullPayload() *NullPayload {
return &NullPayload{}
}
// DecodeBinary implements the Payload interface.
func (p *NullPayload) DecodeBinary(r io.Reader) error {
return nil
}
// EncodeBinary implements the Payload interface.
func (p *NullPayload) EncodeBinary(r io.Writer) error {
return nil
}

View file

@ -13,5 +13,9 @@ type Peer interface {
WriteMsg(msg *Message) error
Done() chan error
Version() *payload.Version
SetVersion(*payload.Version)
Handshaked() bool
SendVersion(*Message) error
SendVersionAck(*Message) error
HandleVersion(*payload.Version) error
HandleVersionAck() error
}

View file

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"math/rand"
"net"
"sync"
"time"
@ -15,13 +16,15 @@ import (
)
const (
minPeers = 5
maxBlockBatch = 200
minPoolCount = 30
// peer numbers are arbitrary at the moment
minPeers = 5
maxPeers = 20
maxBlockBatch = 200
maxAddrsToSend = 200
minPoolCount = 30
)
var (
errPortMismatch = errors.New("port mismatch")
errIdenticalID = errors.New("identical node id")
errInvalidHandshake = errors.New("invalid handshake")
errInvalidNetwork = errors.New("invalid network")
@ -46,6 +49,7 @@ type (
lock sync.RWMutex
peers map[Peer]bool
addrReq chan *Message
register chan Peer
unregister chan peerDrop
quit chan struct{}
@ -64,6 +68,7 @@ func NewServer(config ServerConfig, chain core.Blockchainer) *Server {
chain: chain,
id: rand.Uint32(),
quit: make(chan struct{}),
addrReq: make(chan *Message, minPeers),
register: make(chan Peer),
unregister: make(chan peerDrop),
peers: make(map[Peer]bool),
@ -90,12 +95,7 @@ func (s *Server) Start(errChan chan error) {
"headerHeight": s.chain.HeaderHeight(),
}).Info("node started")
for _, addr := range s.Seeds {
if err := s.transport.Dial(addr, s.DialTimeout); err != nil {
log.Warnf("failed to connect to remote node %s", addr)
continue
}
}
s.discovery.BackFill(s.Seeds...)
go s.transport.Accept()
s.run()
@ -122,6 +122,19 @@ func (s *Server) BadPeers() []string {
func (s *Server) run() {
for {
c := s.PeerCount()
if c < minPeers {
s.discovery.RequestRemote(maxPeers - c)
}
if s.discovery.PoolCount() < minPoolCount {
select {
case s.addrReq <- NewMessage(s.Net, CMDGetAddr, payload.NewNullPayload()):
// sent request
default:
// we have one in the queue already that is
// gonna be served by some worker when it's ready
}
}
select {
case <-s.quit:
s.transport.Close()
@ -141,12 +154,19 @@ func (s *Server) run() {
"addr": p.NetAddr(),
}).Info("new peer connected")
case drop := <-s.unregister:
delete(s.peers, drop.peer)
log.WithFields(log.Fields{
"addr": drop.peer.NetAddr(),
"reason": drop.reason,
"peerCount": s.PeerCount(),
}).Warn("peer disconnected")
if s.peers[drop.peer] {
delete(s.peers, drop.peer)
log.WithFields(log.Fields{
"addr": drop.peer.NetAddr(),
"reason": drop.reason,
"peerCount": s.PeerCount(),
}).Warn("peer disconnected")
addr := drop.peer.NetAddr().String()
s.discovery.UnregisterConnectedAddr(addr)
s.discovery.BackFill(addr)
}
// else the peer is already gone, which can happen
// because we have two goroutines sending signals here
}
}
}
@ -174,20 +194,34 @@ func (s *Server) startProtocol(p Peer) {
"id": p.Version().Nonce,
}).Info("started protocol")
s.requestHeaders(p)
s.discovery.RegisterGoodAddr(p.NetAddr().String())
err := s.requestHeaders(p)
if err != nil {
p.Disconnect(err)
return
}
timer := time.NewTimer(s.ProtoTickInterval)
for {
select {
case err := <-p.Done():
s.unregister <- peerDrop{p, err}
return
case err = <-p.Done():
// time to stop
case m := <-s.addrReq:
err = p.WriteMsg(m)
case <-timer.C:
// Try to sync in headers and block with the peer if his block height is higher then ours.
if p.Version().StartHeight > s.chain.BlockHeight() {
s.requestBlocks(p)
err = s.requestBlocks(p)
}
timer.Reset(s.ProtoTickInterval)
if err == nil {
timer.Reset(s.ProtoTickInterval)
}
}
if err != nil {
s.unregister <- peerDrop{p, err}
timer.Stop()
p.Disconnect(err)
return
}
}
}
@ -201,20 +235,23 @@ func (s *Server) sendVersion(p Peer) error {
s.chain.BlockHeight(),
s.Relay,
)
return p.WriteMsg(NewMessage(s.Net, CMDVersion, payload))
return p.SendVersion(NewMessage(s.Net, CMDVersion, payload))
}
// When a peer sends out his version we reply with verack after validating
// the version.
func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
if p.NetAddr().Port != int(version.Port) {
return errPortMismatch
err := p.HandleVersion(version)
if err != nil {
return err
}
if s.id == version.Nonce {
return errIdenticalID
}
p.SetVersion(version)
return p.WriteMsg(NewMessage(s.Net, CMDVerack, nil))
if p.NetAddr().Port != int(version.Port) {
return fmt.Errorf("port mismatch: connected to %d and peer sends %d", p.NetAddr().Port, version.Port)
}
return p.SendVersionAck(NewMessage(s.Net, CMDVerack, nil))
}
// handleHeadersCmd will process the headers it received from its peer.
@ -251,18 +288,42 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
}
// handleAddrCmd will process received addresses.
func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error {
for _, a := range addrs.Addrs {
s.discovery.BackFill(a.IPPortString())
}
return nil
}
// handleGetAddrCmd sends to the peer some good addresses that we know of.
func (s *Server) handleGetAddrCmd(p Peer) error {
addrs := s.discovery.GoodPeers()
if len(addrs) > maxAddrsToSend {
addrs = addrs[:maxAddrsToSend]
}
alist := payload.NewAddressList(len(addrs))
ts := time.Now()
for i, addr := range addrs {
// we know it's a good address, so it can't fail
netaddr, _ := net.ResolveTCPAddr("tcp", addr)
alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts)
}
return p.WriteMsg(NewMessage(s.Net, CMDAddr, alist))
}
// requestHeaders will send a getheaders message to the peer.
// The peer will respond with headers op to a count of 2000.
func (s *Server) requestHeaders(p Peer) {
func (s *Server) requestHeaders(p Peer) error {
start := []util.Uint256{s.chain.CurrentHeaderHash()}
payload := payload.NewGetBlocks(start, util.Uint256{})
p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload))
return p.WriteMsg(NewMessage(s.Net, CMDGetHeaders, payload))
}
// requestBlocks will send a getdata message to the peer
// to sync up in blocks. A maximum of maxBlockBatch will
// send at once.
func (s *Server) requestBlocks(p Peer) {
func (s *Server) requestBlocks(p Peer) error {
var (
hashes []util.Uint256
hashStart = s.chain.BlockHeight() + 1
@ -275,10 +336,11 @@ func (s *Server) requestBlocks(p Peer) {
}
if len(hashes) > 0 {
payload := payload.NewInventory(payload.BlockType, hashes)
p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
} else if s.chain.HeaderHeight() < p.Version().StartHeight {
s.requestHeaders(p)
return s.requestHeaders(p)
}
return nil
}
// handleMessage will process the given message.
@ -289,26 +351,40 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return errInvalidNetwork
}
switch msg.CommandType() {
case CMDVersion:
version := msg.Payload.(*payload.Version)
return s.handleVersionCmd(peer, version)
case CMDHeaders:
headers := msg.Payload.(*payload.Headers)
go s.handleHeadersCmd(peer, headers)
case CMDInv:
inventory := msg.Payload.(*payload.Inventory)
return s.handleInvCmd(peer, inventory)
case CMDBlock:
block := msg.Payload.(*core.Block)
return s.handleBlockCmd(peer, block)
case CMDVerack:
// Make sure this peer has send his version before we start the
// protocol with that peer.
if peer.Version() == nil {
return errInvalidHandshake
if peer.Handshaked() {
switch msg.CommandType() {
case CMDAddr:
addrs := msg.Payload.(*payload.AddressList)
return s.handleAddrCmd(peer, addrs)
case CMDGetAddr:
// it has no payload
return s.handleGetAddrCmd(peer)
case CMDHeaders:
headers := msg.Payload.(*payload.Headers)
go s.handleHeadersCmd(peer, headers)
case CMDInv:
inventory := msg.Payload.(*payload.Inventory)
return s.handleInvCmd(peer, inventory)
case CMDBlock:
block := msg.Payload.(*core.Block)
return s.handleBlockCmd(peer, block)
case CMDVersion, CMDVerack:
return fmt.Errorf("received '%s' after the handshake", msg.CommandType())
}
} else {
switch msg.CommandType() {
case CMDVersion:
version := msg.Payload.(*payload.Version)
return s.handleVersionCmd(peer, version)
case CMDVerack:
err := peer.HandleVersionAck()
if err != nil {
return err
}
go s.startProtocol(peer)
default:
return fmt.Errorf("received '%s' during handshake", msg.CommandType())
}
go s.startProtocol(peer)
}
return nil
}

View file

@ -71,7 +71,7 @@ func TestServerNotSendsVerack(t *testing.T) {
version := payload.NewVersion(1337, 2000, "/NEO-GO/", 0, true)
err := s.handleVersionCmd(p, version)
assert.NotNil(t, err)
assert.Equal(t, errPortMismatch, err)
assert.Contains(t, err.Error(), "port mismatch")
// identical id's
version = payload.NewVersion(1, 3000, "/NEO-GO/", 0, true)

View file

@ -1,12 +1,29 @@
package network
import (
"errors"
"fmt"
"net"
"sync"
"github.com/CityOfZion/neo-go/pkg/network/payload"
)
type handShakeStage uint8
//go:generate stringer -type=handShakeStage
const (
nothingDone handShakeStage = 0
versionSent handShakeStage = 1
versionReceived handShakeStage = 2
verAckSent handShakeStage = 3
verAckReceived handShakeStage = 4
)
var (
errStateMismatch = errors.New("tried to send protocol message before handshake completed")
)
// TCPPeer represents a connected remote node in the
// network over TCP.
type TCPPeer struct {
@ -17,6 +34,8 @@ type TCPPeer struct {
// The version of the peer.
version *payload.Version
handShake handShakeStage
done chan error
wg sync.WaitGroup
@ -35,8 +54,16 @@ func NewTCPPeer(conn net.Conn) *TCPPeer {
}
// WriteMsg implements the Peer interface. This will write/encode the message
// to the underlying connection.
// to the underlying connection, this only works for messages other than Version
// or VerAck.
func (p *TCPPeer) WriteMsg(msg *Message) error {
if !p.Handshaked() {
return errStateMismatch
}
return p.writeMsg(msg)
}
func (p *TCPPeer) writeMsg(msg *Message) error {
select {
case err := <-p.done:
return err
@ -45,6 +72,55 @@ func (p *TCPPeer) WriteMsg(msg *Message) error {
}
}
// Handshaked returns status of the handshake, whether it's completed or not.
func (p *TCPPeer) Handshaked() bool {
return p.handShake == verAckReceived
}
// SendVersion checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersion(msg *Message) error {
if p.handShake != nothingDone {
return fmt.Errorf("invalid handshake: tried to send Version in %s state", p.handShake.String())
}
err := p.writeMsg(msg)
if err == nil {
p.handShake = versionSent
}
return err
}
// HandleVersion checks for the handshake state and version message contents.
func (p *TCPPeer) HandleVersion(version *payload.Version) error {
if p.handShake != versionSent {
return fmt.Errorf("invalid handshake: received Version in %s state", p.handShake.String())
}
p.version = version
p.handShake = versionReceived
return nil
}
// SendVersionAck checks for the handshake state and sends a message to the peer.
func (p *TCPPeer) SendVersionAck(msg *Message) error {
if p.handShake != versionReceived {
return fmt.Errorf("invalid handshake: tried to send VersionAck in %s state", p.handShake.String())
}
err := p.writeMsg(msg)
if err == nil {
p.handShake = verAckSent
}
return err
}
// HandleVersionAck checks handshake sequence correctness when VerAck message
// is received.
func (p *TCPPeer) HandleVersionAck() error {
if p.handShake != verAckSent {
return fmt.Errorf("invalid handshake: received VersionAck in %s state", p.handShake.String())
}
p.handShake = verAckReceived
return nil
}
// NetAddr implements the Peer interface.
func (p *TCPPeer) NetAddr() *net.TCPAddr {
return &p.addr
@ -59,15 +135,16 @@ func (p *TCPPeer) Done() chan error {
// Disconnect will fill the peer's done channel with the given error.
func (p *TCPPeer) Disconnect(err error) {
p.done <- err
p.conn.Close()
select {
case p.done <- err:
// one message to the queue
default:
// the other side may already be gone, it's OK
}
}
// Version implements the Peer interface.
func (p *TCPPeer) Version() *payload.Version {
return p.version
}
// SetVersion implements the Peer interface.
func (p *TCPPeer) SetVersion(v *payload.Version) {
p.version = v
}

View file

@ -75,21 +75,19 @@ func (t *TCPTransport) handleConn(conn net.Conn) {
err error
)
defer func() {
p.Disconnect(err)
}()
t.server.register <- p
for {
msg := &Message{}
if err = msg.Decode(p.conn); err != nil {
return
break
}
if err = t.server.handleMessage(p, msg); err != nil {
return
break
}
}
t.server.unregister <- peerDrop{p, err}
p.Disconnect(err)
}
// Close implements the Transporter interface.