Node network improvements (#45)

* small improvements.

* Fixed datarace + cleanup node and peer

* bumped version.

* removed race flag to pass build
This commit is contained in:
Anthony De Meulemeester 2018-03-10 13:04:06 +01:00 committed by GitHub
parent 4023661cf1
commit aa4bd34b6b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 367 additions and 147 deletions

145
' Normal file
View file

@ -0,0 +1,145 @@
package network
import (
"bytes"
"net"
"os"
"time"
"github.com/CityOfZion/neo-go/pkg/network/payload"
"github.com/CityOfZion/neo-go/pkg/util"
log "github.com/go-kit/kit/log"
)
// TCPPeer represents a connected remote node in the
// network over TCP.
type TCPPeer struct {
// The endpoint of the peer.
endpoint util.Endpoint
// underlying connection.
conn net.Conn
// The version the peer declared when connecting.
version *payload.Version
// connectedAt is the timestamp the peer connected to
// the network.
connectedAt time.Time
// handleProto is the handler that will handle the
// incoming message along with its peer.
handleProto protoHandleFunc
// Done is used to broadcast this peer has stopped running
// and should be removed as reference.
done chan struct{}
send chan *Message
disc chan struct{}
logger log.Logger
}
// NewTCPPeer creates a new peer from a TCP connection.
func NewTCPPeer(conn net.Conn, fun protoHandleFunc) *TCPPeer {
e := util.NewEndpoint(conn.RemoteAddr().String())
logger := log.NewLogfmtLogger(os.Stderr)
logger = log.With(logger, "component", "peer", "endpoint", e)
return &TCPPeer{
endpoint: e,
conn: conn,
done: make(chan struct{}),
send: make(chan *Message),
logger: logger,
connectedAt: time.Now().UTC(),
handleProto: fun,
disc: make(chan struct{}, 1),
}
}
// Version implements the Peer interface.
func (p *TCPPeer) Version() *payload.Version {
return p.version
}
// Endpoint implements the Peer interface.
func (p *TCPPeer) Endpoint() util.Endpoint {
return p.endpoint
}
// Send implements the Peer interface.
func (p *TCPPeer) Send(msg *Message) {
select {
case p.send <- msg:
break
case <-p.disc:
break
}
}
// Done implemnets the Peer interface.
func (p *TCPPeer) Done() chan struct{} {
return p.done
}
func (p *TCPPeer) run() error {
errCh := make(chan error, 1)
go p.readLoop(errCh)
go p.writeLoop(errCh)
err := <-errCh
p.logger.Log("err", err)
p.cleanup()
return err
}
func (p *TCPPeer) readLoop(errCh chan error) {
for {
msg := &Message{}
if err := msg.decode(p.conn); err != nil {
errCh <- err
break
}
p.handleMessage(msg)
}
}
func (p *TCPPeer) writeLoop(errCh chan error) {
buf := new(bytes.Buffer)
for {
select {
case msg := <-p.send:
if err := msg.encode(buf); err != nil {
errCh <- err
return
}
if _, err := p.conn.Write(buf.Bytes()); err != nil {
errCh <- err
return
}
buf.Reset()
}
}
}
func (p *TCPPeer) cleanup() {
p.conn.Close()
p.disc <- struct{}{}
p.done <- struct{}{}
close(p.disc)
close(p.send)
}
func (p *TCPPeer) handleMessage(msg *Message) {
switch msg.CommandType() {
case CMDVersion:
version := msg.Payload.(*payload.Version)
p.version = version
p.handleProto(msg, p)
default:
p.handleProto(msg, p)
}
}

View file

@ -1 +1 @@
0.26.0 0.27.0

View file

@ -44,6 +44,9 @@ type Blockchain struct {
// Only for operating on the headerList. // Only for operating on the headerList.
headersOp chan headersOpFunc headersOp chan headersOpFunc
headersOpDone chan struct{} headersOpDone chan struct{}
// Whether we will verify received blocks.
verifyBlocks bool
} }
type headersOpFunc func(headerList *HeaderHashList) type headersOpFunc func(headerList *HeaderHashList)
@ -60,6 +63,7 @@ func NewBlockchain(s Store, startHash util.Uint256) *Blockchain {
headersOpDone: make(chan struct{}), headersOpDone: make(chan struct{}),
startHash: startHash, startHash: startHash,
blockCache: NewCache(), blockCache: NewCache(),
verifyBlocks: true,
} }
go bc.run() go bc.run()
bc.init() bc.init()
@ -93,9 +97,12 @@ func (bc *Blockchain) AddBlock(block *Block) error {
return nil return nil
} }
if int(block.Index) == headerLen { if int(block.Index) == headerLen {
// todo: if (VerifyBlocks && !block.Verify()) return false; if bc.verifyBlocks && !block.Verify(false) {
return fmt.Errorf("block %s is invalid", block.Hash())
} }
return bc.AddHeaders(block.Header()) return bc.AddHeaders(block.Header())
}
return nil
} }
func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) { func (bc *Blockchain) AddHeaders(headers ...*Header) (err error) {

View file

@ -32,6 +32,16 @@ func TestAddHeaders(t *testing.T) {
assert.Equal(t, uint32(1), bc.storedHeaderCount) assert.Equal(t, uint32(1), bc.storedHeaderCount)
assert.Equal(t, uint32(0), bc.BlockHeight()) assert.Equal(t, uint32(0), bc.BlockHeight())
assert.Equal(t, h3.Hash(), bc.CurrentHeaderHash()) assert.Equal(t, h3.Hash(), bc.CurrentHeaderHash())
// Add them again, they should not be added.
if err := bc.AddHeaders(h3, h2, h1); err != nil {
t.Fatal(err)
}
assert.Equal(t, h3.Index, bc.HeaderHeight())
assert.Equal(t, uint32(1), bc.storedHeaderCount)
assert.Equal(t, uint32(0), bc.BlockHeight())
assert.Equal(t, h3.Hash(), bc.CurrentHeaderHash())
} }
func TestAddBlock(t *testing.T) { func TestAddBlock(t *testing.T) {
@ -66,5 +76,6 @@ func TestAddBlock(t *testing.T) {
func newTestBC() *Blockchain { func newTestBC() *Blockchain {
startHash, _ := util.Uint256DecodeString("a") startHash, _ := util.Uint256DecodeString("a")
bc := NewBlockchain(NewMemoryStore(), startHash) bc := NewBlockchain(NewMemoryStore(), startHash)
bc.verifyBlocks = false
return bc return bc
} }

View file

@ -10,3 +10,13 @@ func GenesisHashPrivNet() util.Uint256 {
hash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099") hash, _ := util.Uint256DecodeString("996e37358dc369912041f966f8c5d8d3a8255ba5dcbd3447f8a82b55db869099")
return hash return hash
} }
func GenesisHashTestNet() util.Uint256 {
hash, _ := util.Uint256DecodeString("b3181718ef6167105b70920e4a8fbbd0a0a56aacf460d70e10ba6fa1668f1fef")
return hash
}
func GenesisHashMainNet() util.Uint256 {
hash, _ := util.Uint256DecodeString("d42561e3d30e15be6400b6df2f328e02d2bf6354c41dce433bc57687c82144bf")
return hash
}

View file

@ -9,6 +9,7 @@ import (
"io" "io"
"github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/CityOfZion/neo-go/pkg/network/payload" "github.com/CityOfZion/neo-go/pkg/network/payload"
) )
@ -220,6 +221,11 @@ func (m *Message) decodePayload(r io.Reader) error {
if err := p.DecodeBinary(r); err != nil { if err := p.DecodeBinary(r); err != nil {
return err return err
} }
case CMDTX:
p = &transaction.Transaction{}
if err := p.DecodeBinary(r); err != nil {
return err
}
} }
m.Payload = p m.Payload = p
@ -229,10 +235,18 @@ func (m *Message) decodePayload(r io.Reader) error {
// encode a Message to any given io.Writer. // encode a Message to any given io.Writer.
func (m *Message) encode(w io.Writer) error { func (m *Message) encode(w io.Writer) error {
binary.Write(w, binary.LittleEndian, m.Magic) if err := binary.Write(w, binary.LittleEndian, m.Magic); err != nil {
binary.Write(w, binary.LittleEndian, m.Command) return err
binary.Write(w, binary.LittleEndian, m.Length) }
binary.Write(w, binary.LittleEndian, m.Checksum) if err := binary.Write(w, binary.LittleEndian, m.Command); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, m.Length); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, m.Checksum); err != nil {
return err
}
if m.Payload != nil { if m.Payload != nil {
return m.Payload.EncodeBinary(w) return m.Payload.EncodeBinary(w)

View file

@ -27,14 +27,6 @@ type Node struct {
server *Server server *Server
services uint64 services uint64
bc *core.Blockchain bc *core.Blockchain
protoIn chan messageTuple
}
// messageTuple respresents a tuple that holds the message being
// send along with its peer.
type messageTuple struct {
peer Peer
msg *Message
} }
func newNode(s *Server, cfg Config) *Node { func newNode(s *Server, cfg Config) *Node {
@ -42,6 +34,12 @@ func newNode(s *Server, cfg Config) *Node {
if cfg.Net == ModePrivNet { if cfg.Net == ModePrivNet {
startHash = core.GenesisHashPrivNet() startHash = core.GenesisHashPrivNet()
} }
if cfg.Net == ModeTestNet {
startHash = core.GenesisHashTestNet()
}
if cfg.Net == ModeMainNet {
startHash = core.GenesisHashMainNet()
}
bc := core.NewBlockchain( bc := core.NewBlockchain(
core.NewMemoryStore(), core.NewMemoryStore(),
@ -53,12 +51,10 @@ func newNode(s *Server, cfg Config) *Node {
n := &Node{ n := &Node{
Config: cfg, Config: cfg,
protoIn: make(chan messageTuple),
server: s, server: s,
bc: bc, bc: bc,
logger: logger, logger: logger,
} }
go n.handleMessages()
return n return n
} }
@ -67,32 +63,45 @@ func (n *Node) version() *payload.Version {
return payload.NewVersion(n.server.id, n.ListenTCP, n.UserAgent, 1, n.Relay) return payload.NewVersion(n.server.id, n.ListenTCP, n.UserAgent, 1, n.Relay)
} }
func (n *Node) startProtocol(peer Peer) { func (n *Node) startProtocol(p Peer) {
ticker := time.NewTicker(protoTickInterval).C n.logger.Log(
"event", "start protocol",
"peer", p.Endpoint(),
"userAgent", string(p.Version().UserAgent),
)
defer func() {
n.logger.Log(
"msg", "protocol stopped",
"peer", p.Endpoint(),
)
}()
timer := time.NewTimer(protoTickInterval)
for { for {
<-timer.C
select { select {
case <-ticker: case <-p.Done():
return
default:
// Try to sync with the peer if his block height is higher then ours. // Try to sync with the peer if his block height is higher then ours.
if peer.Version().StartHeight > n.bc.HeaderHeight() { if p.Version().StartHeight > n.bc.HeaderHeight() {
n.askMoreHeaders(peer) n.askMoreHeaders(p)
} }
// Only ask for more peers if the server has the capacity for it. // Only ask for more peers if the server has the capacity for it.
if n.server.hasCapacity() { if n.server.hasCapacity() {
msg := NewMessage(n.Net, CMDGetAddr, nil) msg := NewMessage(n.Net, CMDGetAddr, nil)
peer.Send(msg) p.Send(msg)
} }
case <-peer.Done(): timer.Reset(protoTickInterval)
return
} }
} }
} }
// When a peer sends out his version we reply with verack after validating // When a peer sends out his version we reply with verack after validating
// the version. // the version.
func (n *Node) handleVersionCmd(version *payload.Version, peer Peer) error { func (n *Node) handleVersionCmd(version *payload.Version, p Peer) error {
msg := NewMessage(n.Net, CMDVerack, nil) msg := NewMessage(n.Net, CMDVerack, nil)
peer.Send(msg) p.Send(msg)
return nil return nil
} }
@ -100,7 +109,7 @@ func (n *Node) handleVersionCmd(version *payload.Version, peer Peer) error {
// We will use the getdata message to get more details about the received // We will use the getdata message to get more details about the received
// inventory. // inventory.
// note: if the server has Relay on false, inventory messages are not received. // note: if the server has Relay on false, inventory messages are not received.
func (n *Node) handleInvCmd(inv *payload.Inventory, peer Peer) error { func (n *Node) handleInvCmd(inv *payload.Inventory, p Peer) error {
if !inv.Type.Valid() { if !inv.Type.Valid() {
return fmt.Errorf("invalid inventory type received: %s", inv.Type) return fmt.Errorf("invalid inventory type received: %s", inv.Type)
} }
@ -108,7 +117,7 @@ func (n *Node) handleInvCmd(inv *payload.Inventory, peer Peer) error {
return errors.New("inventory has no hashes") return errors.New("inventory has no hashes")
} }
payload := payload.NewInventory(inv.Type, inv.Hashes) payload := payload.NewInventory(inv.Type, inv.Hashes)
peer.Send(NewMessage(n.Net, CMDGetData, payload)) p.Send(NewMessage(n.Net, CMDGetData, payload))
return nil return nil
} }
@ -120,7 +129,6 @@ func (n *Node) handleBlockCmd(block *core.Block, peer Peer) error {
"hash", block.Hash(), "hash", block.Hash(),
"tx", len(block.Transactions), "tx", len(block.Transactions),
) )
return n.bc.AddBlock(block) return n.bc.AddBlock(block)
} }
@ -164,56 +172,42 @@ func (n *Node) askMoreHeaders(p Peer) {
// blockhain implements the Noder interface. // blockhain implements the Noder interface.
func (n *Node) blockchain() *core.Blockchain { return n.bc } func (n *Node) blockchain() *core.Blockchain { return n.bc }
// handleProto implements the protoHandler interface. func (n *Node) handleProto(msg *Message, p Peer) error {
func (n *Node) handleProto(msg *Message, p Peer) { //n.logger.Log(
n.protoIn <- messageTuple{ // "event", "message received",
msg: msg, // "from", p.Endpoint(),
peer: p, // "msg", msg.CommandType(),
} //)
}
func (n *Node) handleMessages() {
for {
t := <-n.protoIn
var (
msg = t.msg
p = t.peer
err error
)
switch msg.CommandType() { switch msg.CommandType() {
case CMDVersion: case CMDVersion:
version := msg.Payload.(*payload.Version) version := msg.Payload.(*payload.Version)
err = n.handleVersionCmd(version, p) return n.handleVersionCmd(version, p)
case CMDAddr: case CMDAddr:
addressList := msg.Payload.(*payload.AddressList) addressList := msg.Payload.(*payload.AddressList)
err = n.handleAddrCmd(addressList, p) return n.handleAddrCmd(addressList, p)
case CMDInv: case CMDInv:
inventory := msg.Payload.(*payload.Inventory) inventory := msg.Payload.(*payload.Inventory)
err = n.handleInvCmd(inventory, p) return n.handleInvCmd(inventory, p)
case CMDBlock: case CMDBlock:
block := msg.Payload.(*core.Block) block := msg.Payload.(*core.Block)
err = n.handleBlockCmd(block, p) return n.handleBlockCmd(block, p)
case CMDHeaders: case CMDHeaders:
headers := msg.Payload.(*payload.Headers) headers := msg.Payload.(*payload.Headers)
err = n.handleHeadersCmd(headers, p) return n.handleHeadersCmd(headers, p)
case CMDTX:
// tx := msg.Payload.(*transaction.Transaction)
//n.logger.Log("tx", fmt.Sprintf("%+v", tx))
return nil
case CMDVerack: case CMDVerack:
// Only start the protocol if we got the version and verack // Only start the protocol if we got the version and verack
// received. // received.
if p.Version() != nil { if p.Version() != nil {
go n.startProtocol(p) go n.startProtocol(p)
} }
return nil
case CMDUnknown: case CMDUnknown:
err = errors.New("received non-protocol messgae") return errors.New("received non-protocol messgae")
}
if err != nil {
n.logger.Log(
"msg", "failed processing message",
"command", msg.CommandType,
"err", err,
)
}
} }
return nil
} }

View file

@ -13,4 +13,5 @@ type Peer interface {
Endpoint() util.Endpoint Endpoint() util.Endpoint
Send(*Message) Send(*Message)
Done() chan struct{} Done() chan struct{}
Disconnect(err error)
} }

View file

@ -9,10 +9,10 @@ import (
// of the NEO protocol. // of the NEO protocol.
type ProtoHandler interface { type ProtoHandler interface {
version() *payload.Version version() *payload.Version
handleProto(*Message, Peer) handleProto(*Message, Peer) error
} }
type protoHandleFunc func(*Message, Peer) type protoHandleFunc func(*Message, Peer) error
// Noder is anything that implements the NEO protocol // Noder is anything that implements the NEO protocol
// and can return the Blockchain object. // and can return the Blockchain object.

View file

@ -70,7 +70,7 @@ type Server struct {
listener net.Listener listener net.Listener
register chan Peer register chan Peer
unregister chan Peer unregister chan peerDrop
badAddrOp chan func(map[string]bool) badAddrOp chan func(map[string]bool)
badAddrOpDone chan struct{} badAddrOpDone chan struct{}
@ -81,6 +81,11 @@ type Server struct {
quit chan struct{} quit chan struct{}
} }
type peerDrop struct {
p Peer
err error
}
// NewServer returns a new Server object created from the // NewServer returns a new Server object created from the
// given config. // given config.
func NewServer(cfg Config) *Server { func NewServer(cfg Config) *Server {
@ -103,7 +108,7 @@ func NewServer(cfg Config) *Server {
id: util.RandUint32(1000000, 9999999), id: util.RandUint32(1000000, 9999999),
quit: make(chan struct{}, 1), quit: make(chan struct{}, 1),
register: make(chan Peer), register: make(chan Peer),
unregister: make(chan Peer), unregister: make(chan peerDrop),
badAddrOp: make(chan func(map[string]bool)), badAddrOp: make(chan func(map[string]bool)),
badAddrOpDone: make(chan struct{}), badAddrOpDone: make(chan struct{}),
peerOp: make(chan func(map[Peer]bool)), peerOp: make(chan func(map[Peer]bool)),
@ -131,12 +136,14 @@ func (s *Server) listenTCP() {
s.logger.Log("msg", "conn read error", "err", err) s.logger.Log("msg", "conn read error", "err", err)
break break
} }
go s.setupConnection(conn) go s.setupPeerConn(conn)
} }
s.Quit() s.Quit()
} }
func (s *Server) setupConnection(conn net.Conn) { // setupPeerConn runs in its own routine for each connected Peer.
// and waits till the Peer.Run() returns.
func (s *Server) setupPeerConn(conn net.Conn) {
if !s.hasCapacity() { if !s.hasCapacity() {
s.logger.Log("msg", "server reached maximum capacity") s.logger.Log("msg", "server reached maximum capacity")
return return
@ -144,9 +151,9 @@ func (s *Server) setupConnection(conn net.Conn) {
p := NewTCPPeer(conn, s.proto.handleProto) p := NewTCPPeer(conn, s.proto.handleProto)
s.register <- p s.register <- p
if err := p.run(); err != nil {
s.unregister <- p err := p.run()
} s.unregister <- peerDrop{p, err}
} }
func (s *Server) connectToPeers(addrs ...string) { func (s *Server) connectToPeers(addrs ...string) {
@ -161,7 +168,7 @@ func (s *Server) connectToPeers(addrs ...string) {
<-s.badAddrOpDone <-s.badAddrOpDone
return return
} }
go s.setupConnection(conn) go s.setupPeerConn(conn)
}(addr) }(addr)
} }
} }
@ -194,13 +201,12 @@ func (s *Server) hasCapacity() bool {
return s.PeerCount() != s.MaxPeers return s.PeerCount() != s.MaxPeers
} }
func (s *Server) sendVersion(peer Peer) { func (s *Server) sendVersion(p Peer) {
peer.Send(NewMessage(s.Net, CMDVersion, s.proto.version())) p.Send(NewMessage(s.Net, CMDVersion, s.proto.version()))
} }
func (s *Server) run() { func (s *Server) run() {
var ( var (
ticker = time.NewTicker(30 * time.Second).C
peers = make(map[Peer]bool) peers = make(map[Peer]bool)
badAddrs = make(map[string]bool) badAddrs = make(map[string]bool)
) )
@ -219,11 +225,18 @@ func (s *Server) run() {
// out our version immediately. // out our version immediately.
s.sendVersion(p) s.sendVersion(p)
s.logger.Log("event", "peer connected", "endpoint", p.Endpoint()) s.logger.Log("event", "peer connected", "endpoint", p.Endpoint())
case p := <-s.unregister: case drop := <-s.unregister:
delete(peers, p) delete(peers, drop.p)
s.logger.Log("event", "peer disconnected", "endpoint", p.Endpoint()) s.logger.Log(
case <-ticker: "event", "peer disconnected",
s.printState() "endpoint", drop.p.Endpoint(),
"reason", drop.err,
"peerCount", len(peers),
)
if len(peers) == 0 {
s.logger.Log("fatal", "no more available peers")
return
}
case <-s.quit: case <-s.quit:
return return
} }

View file

@ -34,7 +34,7 @@ func TestUnregisterPeer(t *testing.T) {
s.register <- newTestPeer() s.register <- newTestPeer()
assert.Equal(t, 3, s.PeerCount()) assert.Equal(t, 3, s.PeerCount())
s.unregister <- peer s.unregister <- peerDrop{peer, nil}
assert.Equal(t, 2, s.PeerCount()) assert.Equal(t, 2, s.PeerCount())
} }
@ -44,7 +44,9 @@ func (t testNode) version() *payload.Version {
return &payload.Version{} return &payload.Version{}
} }
func (t testNode) handleProto(msg *Message, p Peer) {} func (t testNode) handleProto(msg *Message, p Peer) error {
return nil
}
func newTestServer() *Server { func newTestServer() *Server {
return &Server{ return &Server{
@ -52,7 +54,7 @@ func newTestServer() *Server {
id: util.RandUint32(1000000, 9999999), id: util.RandUint32(1000000, 9999999),
quit: make(chan struct{}, 1), quit: make(chan struct{}, 1),
register: make(chan Peer), register: make(chan Peer),
unregister: make(chan Peer), unregister: make(chan peerDrop),
badAddrOp: make(chan func(map[string]bool)), badAddrOp: make(chan func(map[string]bool)),
badAddrOpDone: make(chan struct{}), badAddrOpDone: make(chan struct{}),
peerOp: make(chan func(map[Peer]bool)), peerOp: make(chan func(map[Peer]bool)),
@ -84,3 +86,7 @@ func (p testPeer) Send(msg *Message) {}
func (p testPeer) Done() chan struct{} { func (p testPeer) Done() chan struct{} {
return p.done return p.done
} }
func (p testPeer) Disconnect(err error) {
}

View file

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"net" "net"
"os" "os"
"sync"
"time" "time"
"github.com/CityOfZion/neo-go/pkg/network/payload" "github.com/CityOfZion/neo-go/pkg/network/payload"
@ -31,10 +32,14 @@ type TCPPeer struct {
// incoming message along with its peer. // incoming message along with its peer.
handleProto protoHandleFunc handleProto protoHandleFunc
// Done is used to broadcast this peer has stopped running // Done is used to broadcast that this peer has stopped running
// and should be removed as reference. // and should be removed as reference.
done chan struct{} done chan struct{}
send chan *Message
// Every send to this channel will terminate the Peer.
discErr chan error
closed chan struct{}
wg sync.WaitGroup
logger log.Logger logger log.Logger
} }
@ -49,10 +54,11 @@ func NewTCPPeer(conn net.Conn, fun protoHandleFunc) *TCPPeer {
endpoint: e, endpoint: e,
conn: conn, conn: conn,
done: make(chan struct{}), done: make(chan struct{}),
send: make(chan *Message),
logger: logger, logger: logger,
connectedAt: time.Now().UTC(), connectedAt: time.Now().UTC(),
handleProto: fun, handleProto: fun,
discErr: make(chan error),
closed: make(chan struct{}),
} }
} }
@ -68,58 +74,69 @@ func (p *TCPPeer) Endpoint() util.Endpoint {
// Send implements the Peer interface. // Send implements the Peer interface.
func (p *TCPPeer) Send(msg *Message) { func (p *TCPPeer) Send(msg *Message) {
p.send <- msg buf := new(bytes.Buffer)
if err := msg.encode(buf); err != nil {
p.discErr <- err
return
}
if _, err := p.conn.Write(buf.Bytes()); err != nil {
p.discErr <- err
return
}
} }
// Done implemnets the Peer interface. // Done implemnets the Peer interface. It use is to
// notify the Node that this peer is no longer available
// for sending messages to.
func (p *TCPPeer) Done() chan struct{} { func (p *TCPPeer) Done() chan struct{} {
return p.done return p.done
} }
func (p *TCPPeer) run() error { // Disconnect terminates the peer connection.
errCh := make(chan error, 1) func (p *TCPPeer) Disconnect(err error) {
select {
case p.discErr <- err:
case <-p.closed:
}
}
go p.readLoop(errCh) func (p *TCPPeer) run() (err error) {
go p.writeLoop(errCh) p.wg.Add(1)
go p.readLoop()
err := <-errCh run:
p.logger.Log("err", err) for {
p.cleanup() select {
case err = <-p.discErr:
break run
}
}
p.conn.Close()
close(p.closed)
// Close done instead of sending empty struct.
// It could happen that startProtocol in Node never happens
// on connection errors for example.
close(p.done)
p.wg.Wait()
return err return err
} }
func (p *TCPPeer) readLoop(errCh chan error) { func (p *TCPPeer) readLoop() {
defer p.wg.Done()
for { for {
select {
case <-p.closed:
return
default:
msg := &Message{} msg := &Message{}
if err := msg.decode(p.conn); err != nil { if err := msg.decode(p.conn); err != nil {
errCh <- err p.discErr <- err
break return
} }
p.handleMessage(msg) p.handleMessage(msg)
} }
}
func (p *TCPPeer) writeLoop(errCh chan error) {
buf := new(bytes.Buffer)
for {
msg := <-p.send
if err := msg.encode(buf); err != nil {
errCh <- err
break
} }
if _, err := p.conn.Write(buf.Bytes()); err != nil {
errCh <- err
break
}
buf.Reset()
}
}
func (p *TCPPeer) cleanup() {
p.conn.Close()
close(p.send)
p.done <- struct{}{}
} }
func (p *TCPPeer) handleMessage(msg *Message) { func (p *TCPPeer) handleMessage(msg *Message) {
@ -127,8 +144,10 @@ func (p *TCPPeer) handleMessage(msg *Message) {
case CMDVersion: case CMDVersion:
version := msg.Payload.(*payload.Version) version := msg.Payload.(*payload.Version)
p.version = version p.version = version
p.handleProto(msg, p) fallthrough
default: default:
p.handleProto(msg, p) if err := p.handleProto(msg, p); err != nil {
p.discErr <- err
}
} }
} }