mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-25 23:42:23 +00:00
Refactor tcp transport (#11)
* refactored tcp transport * return errors on outgoing messages * TCP transport should report its error after reading from connection * handle error returned from peer transport * bump version * cleaned up error
This commit is contained in:
parent
4050dbeeb8
commit
6e3f1ec43e
5 changed files with 179 additions and 72 deletions
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
|||
0.2.0
|
||||
0.3.0
|
|
@ -10,8 +10,10 @@ type Peer interface {
|
|||
id() uint32
|
||||
addr() util.Endpoint
|
||||
disconnect()
|
||||
callVersion(*Message)
|
||||
callGetaddr(*Message)
|
||||
callVersion(*Message) error
|
||||
callGetaddr(*Message) error
|
||||
callVerack(*Message) error
|
||||
callGetdata(*Message) error
|
||||
}
|
||||
|
||||
// LocalPeer is the simplest kind of peer, mapped to a server in the
|
||||
|
@ -28,12 +30,20 @@ func NewLocalPeer(s *Server) *LocalPeer {
|
|||
return &LocalPeer{endpoint: e, s: s}
|
||||
}
|
||||
|
||||
func (p *LocalPeer) callVersion(msg *Message) {
|
||||
p.s.handleVersionCmd(msg, p)
|
||||
func (p *LocalPeer) callVersion(msg *Message) error {
|
||||
return p.s.handleVersionCmd(msg, p)
|
||||
}
|
||||
|
||||
func (p *LocalPeer) callGetaddr(msg *Message) {
|
||||
p.s.handleGetaddrCmd(msg, p)
|
||||
func (p *LocalPeer) callVerack(msg *Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *LocalPeer) callGetaddr(msg *Message) error {
|
||||
return p.s.handleGetaddrCmd(msg, p)
|
||||
}
|
||||
|
||||
func (p *LocalPeer) callGetdata(msg *Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *LocalPeer) id() uint32 { return p.nonce }
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
|
@ -159,7 +160,7 @@ func (s *Server) shutdown() {
|
|||
|
||||
// disconnect and remove all connected peers.
|
||||
for peer := range s.peers {
|
||||
s.unregister <- peer
|
||||
peer.disconnect()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -173,13 +174,17 @@ func (s *Server) loop() {
|
|||
if len(s.peers) < maxPeers {
|
||||
s.logger.Printf("peer registered from address %s", peer.addr())
|
||||
s.peers[peer] = true
|
||||
s.handlePeerConnected(peer)
|
||||
|
||||
if err := s.handlePeerConnected(peer); err != nil {
|
||||
s.logger.Printf("failed handling peer connection: %s", err)
|
||||
peer.disconnect()
|
||||
}
|
||||
}
|
||||
|
||||
// Unregister should take care of all the cleanup that has to be made.
|
||||
// unregister safely deletes a peer. For disconnecting peers use the
|
||||
// disconnect() method on the peer, it will call unregister and terminates its routines.
|
||||
case peer := <-s.unregister:
|
||||
if _, ok := s.peers[peer]; ok {
|
||||
peer.disconnect()
|
||||
delete(s.peers, peer)
|
||||
s.logger.Printf("peer %s disconnected", peer.addr())
|
||||
}
|
||||
|
@ -196,71 +201,68 @@ func (s *Server) loop() {
|
|||
// When a new peer is connected we send our version.
|
||||
// No further communication should be made before both sides has received
|
||||
// the versions of eachother.
|
||||
func (s *Server) handlePeerConnected(p Peer) {
|
||||
func (s *Server) handlePeerConnected(p Peer) error {
|
||||
// TODO: get the blockheight of this server once core implemented this.
|
||||
payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay)
|
||||
msg := newMessage(s.net, cmdVersion, payload)
|
||||
p.callVersion(msg)
|
||||
return p.callVersion(msg)
|
||||
}
|
||||
|
||||
func (s *Server) handleVersionCmd(msg *Message, p Peer) *Message {
|
||||
func (s *Server) handleVersionCmd(msg *Message, p Peer) error {
|
||||
version := msg.Payload.(*payload.Version)
|
||||
if s.id == version.Nonce {
|
||||
// s.unregister <- p
|
||||
return nil
|
||||
return errors.New("identical nonce")
|
||||
}
|
||||
if p.addr().Port != version.Port {
|
||||
// s.unregister <- p
|
||||
return nil
|
||||
return fmt.Errorf("port mismatch: %d", version.Port)
|
||||
}
|
||||
return newMessage(ModeDevNet, cmdVerack, nil)
|
||||
|
||||
return p.callVerack(newMessage(s.net, cmdVerack, nil))
|
||||
}
|
||||
|
||||
func (s *Server) handleGetaddrCmd(msg *Message, p Peer) *Message {
|
||||
func (s *Server) handleGetaddrCmd(msg *Message, p Peer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The node can broadcast the object information it owns by this message.
|
||||
// The message can be sent automatically or can be used to answer getbloks messages.
|
||||
func (s *Server) handleInvCmd(msg *Message, p Peer) *Message {
|
||||
func (s *Server) handleInvCmd(msg *Message, p Peer) error {
|
||||
inv := msg.Payload.(*payload.Inventory)
|
||||
if !inv.Type.Valid() {
|
||||
s.unregister <- p
|
||||
return nil
|
||||
return fmt.Errorf("invalid inventory type %s", inv.Type)
|
||||
}
|
||||
if len(inv.Hashes) == 0 {
|
||||
s.unregister <- p
|
||||
return nil
|
||||
return errors.New("inventory should have at least 1 hash got 0")
|
||||
}
|
||||
|
||||
// todo: only grab the hashes that we dont know.
|
||||
|
||||
payload := payload.NewInventory(inv.Type, inv.Hashes)
|
||||
resp := newMessage(s.net, cmdGetData, payload)
|
||||
return resp
|
||||
|
||||
return p.callGetdata(resp)
|
||||
}
|
||||
|
||||
// handleBlockCmd processes the received block.
|
||||
func (s *Server) handleBlockCmd(msg *Message, p Peer) {
|
||||
func (s *Server) handleBlockCmd(msg *Message, p Peer) error {
|
||||
block := msg.Payload.(*core.Block)
|
||||
hash, err := block.Hash()
|
||||
if err != nil {
|
||||
// not quite sure what to do here.
|
||||
// should we disconnect the client or just silently log and move on?
|
||||
s.logger.Printf("failed to generate block hash: %s", err)
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(hash)
|
||||
|
||||
if s.bc.HasBlock(hash) {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// After receiving the getaddr message, the node returns an addr message as response
|
||||
// and provides information about the known nodes on the network.
|
||||
func (s *Server) handleAddrCmd(msg *Message, p Peer) {
|
||||
func (s *Server) handleAddrCmd(msg *Message, p Peer) error {
|
||||
addrList := msg.Payload.(*payload.AddressList)
|
||||
for _, addr := range addrList.Addrs {
|
||||
if !s.peerAlreadyConnected(addr.Addr) {
|
||||
|
@ -268,6 +270,7 @@ func (s *Server) handleAddrCmd(msg *Message, p Peer) {
|
|||
go connectToRemoteNode(s, addr.Addr.String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) relayInventory(inv *payload.Inventory) {
|
||||
|
|
|
@ -6,6 +6,32 @@ import (
|
|||
"github.com/CityOfZion/neo-go/pkg/network/payload"
|
||||
)
|
||||
|
||||
func TestHandleVersionFailWrongPort(t *testing.T) {
|
||||
s := NewServer(ModeDevNet)
|
||||
go s.loop()
|
||||
|
||||
p := NewLocalPeer(s)
|
||||
|
||||
version := payload.NewVersion(1337, 1, "/NEO:0.0.0/", 0, true)
|
||||
msg := newMessage(ModeDevNet, cmdVersion, version)
|
||||
if err := s.handleVersionCmd(msg, p); err == nil {
|
||||
t.Fatal("expected error got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleVersionFailIdenticalNonce(t *testing.T) {
|
||||
s := NewServer(ModeDevNet)
|
||||
go s.loop()
|
||||
|
||||
p := NewLocalPeer(s)
|
||||
|
||||
version := payload.NewVersion(s.id, 1, "/NEO:0.0.0/", 0, true)
|
||||
msg := newMessage(ModeDevNet, cmdVersion, version)
|
||||
if err := s.handleVersionCmd(msg, p); err == nil {
|
||||
t.Fatal("expected error got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleVersion(t *testing.T) {
|
||||
s := NewServer(ModeDevNet)
|
||||
go s.loop()
|
||||
|
@ -15,12 +41,8 @@ func TestHandleVersion(t *testing.T) {
|
|||
version := payload.NewVersion(1337, p.addr().Port, "/NEO:0.0.0/", 0, true)
|
||||
msg := newMessage(ModeDevNet, cmdVersion, version)
|
||||
|
||||
resp := s.handleVersionCmd(msg, p)
|
||||
if resp.commandType() != cmdVerack {
|
||||
t.Fatalf("expected response message to be verack got %s", resp.commandType())
|
||||
}
|
||||
if resp.Payload != nil {
|
||||
t.Fatal("verack payload should be nil")
|
||||
if err := s.handleVersionCmd(msg, p); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package network
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -52,8 +53,7 @@ func handleConnection(s *Server, conn net.Conn) {
|
|||
|
||||
// remove the peer from connected peers and cleanup the connection.
|
||||
defer func() {
|
||||
// all cleanup will happen in the server's loop when unregister is received.
|
||||
s.unregister <- peer
|
||||
peer.disconnect()
|
||||
}()
|
||||
|
||||
// Start a goroutine that will handle all outgoing messages.
|
||||
|
@ -66,17 +66,17 @@ func handleConnection(s *Server, conn net.Conn) {
|
|||
for {
|
||||
_, err := conn.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.logger.Printf("conn read error: %s", err)
|
||||
break
|
||||
return
|
||||
}
|
||||
|
||||
msg := &Message{}
|
||||
if err := msg.decode(bytes.NewReader(buf)); err != nil {
|
||||
s.logger.Printf("decode error %s", err)
|
||||
break
|
||||
return
|
||||
}
|
||||
|
||||
peer.receive <- msg
|
||||
|
@ -85,9 +85,11 @@ func handleConnection(s *Server, conn net.Conn) {
|
|||
|
||||
// handleMessage hands the message received from a TCP connection over to the server.
|
||||
func handleMessage(s *Server, p *TCPPeer) {
|
||||
var err error
|
||||
|
||||
// Disconnect the peer when we break out of the loop.
|
||||
defer func() {
|
||||
s.unregister <- p
|
||||
p.disconnect()
|
||||
}()
|
||||
|
||||
for {
|
||||
|
@ -98,30 +100,58 @@ func handleMessage(s *Server, p *TCPPeer) {
|
|||
|
||||
switch command {
|
||||
case cmdVersion:
|
||||
resp := s.handleVersionCmd(msg, p)
|
||||
if err = s.handleVersionCmd(msg, p); err != nil {
|
||||
return
|
||||
}
|
||||
p.nonce = msg.Payload.(*payload.Version).Nonce
|
||||
p.send <- resp
|
||||
|
||||
// When a node receives a connection request, it declares its version immediately.
|
||||
// There will be no other communication until both sides are getting versions of each other.
|
||||
// When a node receives the version message, it replies to a verack as a response immediately.
|
||||
// NOTE: The current official NEO nodes dont mimic this behaviour. There is small chance that the
|
||||
// official nodes will not respond directly with a verack after we sended our version.
|
||||
// is this a bug? - anthdm 02/02/2018
|
||||
msgVerack := <-p.receive
|
||||
if msgVerack.commandType() != cmdVerack {
|
||||
s.logger.Printf("expected verack after sended out version")
|
||||
return
|
||||
}
|
||||
|
||||
// start the protocol
|
||||
go s.sendLoop(p)
|
||||
case cmdAddr:
|
||||
s.handleAddrCmd(msg, p)
|
||||
err = s.handleAddrCmd(msg, p)
|
||||
case cmdGetAddr:
|
||||
s.handleGetaddrCmd(msg, p)
|
||||
err = s.handleGetaddrCmd(msg, p)
|
||||
case cmdInv:
|
||||
resp := s.handleInvCmd(msg, p)
|
||||
p.send <- resp
|
||||
err = s.handleInvCmd(msg, p)
|
||||
case cmdBlock:
|
||||
s.handleBlockCmd(msg, p)
|
||||
err = s.handleBlockCmd(msg, p)
|
||||
case cmdConsensus:
|
||||
case cmdTX:
|
||||
case cmdVerack:
|
||||
go s.sendLoop(p)
|
||||
// If we receive a verack here we disconnect. We already handled the verack
|
||||
// when we sended our version.
|
||||
err = errors.New("received verack twice")
|
||||
case cmdGetHeaders:
|
||||
case cmdGetBlocks:
|
||||
case cmdGetData:
|
||||
case cmdHeaders:
|
||||
}
|
||||
|
||||
// catch all errors here and disconnect.
|
||||
if err != nil {
|
||||
s.logger.Printf("processing message failed: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type sendTuple struct {
|
||||
msg *Message
|
||||
err chan error
|
||||
}
|
||||
|
||||
// TCPPeer represents a remote node, backed by TCP transport.
|
||||
type TCPPeer struct {
|
||||
s *Server
|
||||
|
@ -132,7 +162,7 @@ type TCPPeer struct {
|
|||
// host and port information about this peer.
|
||||
endpoint util.Endpoint
|
||||
// channel to coordinate messages writen back to the connection.
|
||||
send chan *Message
|
||||
send chan sendTuple
|
||||
// channel to receive from underlying connection.
|
||||
receive chan *Message
|
||||
}
|
||||
|
@ -143,15 +173,22 @@ func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer {
|
|||
|
||||
return &TCPPeer{
|
||||
conn: conn,
|
||||
send: make(chan *Message),
|
||||
send: make(chan sendTuple),
|
||||
receive: make(chan *Message),
|
||||
endpoint: e,
|
||||
s: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TCPPeer) callVersion(msg *Message) {
|
||||
p.send <- msg
|
||||
func (p *TCPPeer) callVersion(msg *Message) error {
|
||||
t := sendTuple{
|
||||
msg: msg,
|
||||
err: make(chan error),
|
||||
}
|
||||
|
||||
p.send <- t
|
||||
|
||||
return <-t.err
|
||||
}
|
||||
|
||||
// id implements the peer interface
|
||||
|
@ -165,16 +202,51 @@ func (p *TCPPeer) addr() util.Endpoint {
|
|||
}
|
||||
|
||||
// callGetaddr will send the "getaddr" command to the remote.
|
||||
func (p *TCPPeer) callGetaddr(msg *Message) {
|
||||
p.send <- msg
|
||||
func (p *TCPPeer) callGetaddr(msg *Message) error {
|
||||
t := sendTuple{
|
||||
msg: msg,
|
||||
err: make(chan error),
|
||||
}
|
||||
|
||||
p.send <- t
|
||||
|
||||
return <-t.err
|
||||
}
|
||||
|
||||
// disconnect closes the send channel and the underlying connection.
|
||||
// TODO: this needs some love. We will get send on closed channel.
|
||||
func (p *TCPPeer) callVerack(msg *Message) error {
|
||||
t := sendTuple{
|
||||
msg: msg,
|
||||
err: make(chan error),
|
||||
}
|
||||
|
||||
p.send <- t
|
||||
|
||||
return <-t.err
|
||||
}
|
||||
|
||||
func (p *TCPPeer) callGetdata(msg *Message) error {
|
||||
t := sendTuple{
|
||||
msg: msg,
|
||||
err: make(chan error),
|
||||
}
|
||||
|
||||
p.send <- t
|
||||
|
||||
return <-t.err
|
||||
}
|
||||
|
||||
// disconnect disconnects the peer, cleaning up all its resources.
|
||||
// 3 goroutines needs to be cleanup (writeLoop, handleConnection and handleMessage)
|
||||
func (p *TCPPeer) disconnect() {
|
||||
p.conn.Close()
|
||||
close(p.send)
|
||||
close(p.receive)
|
||||
select {
|
||||
case <-p.send:
|
||||
case <-p.receive:
|
||||
default:
|
||||
close(p.send)
|
||||
close(p.receive)
|
||||
p.s.unregister <- p
|
||||
p.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// writeLoop writes messages to the underlying TCP connection.
|
||||
|
@ -184,17 +256,17 @@ func (p *TCPPeer) disconnect() {
|
|||
func (p *TCPPeer) writeLoop() {
|
||||
// clean up the connection.
|
||||
defer func() {
|
||||
p.conn.Close()
|
||||
p.disconnect()
|
||||
}()
|
||||
|
||||
for {
|
||||
msg := <-p.send
|
||||
|
||||
p.s.logger.Printf("OUT :: %s :: %+v", msg.commandType(), msg.Payload)
|
||||
|
||||
// should we disconnect here?
|
||||
if err := msg.encode(p.conn); err != nil {
|
||||
p.s.logger.Printf("encode error: %s", err)
|
||||
t := <-p.send
|
||||
if t.msg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.s.logger.Printf("OUT :: %s :: %+v", t.msg.commandType(), t.msg.Payload)
|
||||
|
||||
t.err <- t.msg.encode(p.conn)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue