Refactor tcp transport (#11)

* refactored tcp transport

* return errors on outgoing messages

* TCP transport should report its error after reading from connection

* handle error returned from peer transport

* bump version

* cleaned up error
Anthony De Meulemeester, 2018-02-02 11:02:25 +01:00 (committed by GitHub)
parent 4050dbeeb8
commit 6e3f1ec43e
5 changed files with 179 additions and 72 deletions
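The heart of the refactor is that the peer's call* methods now return an error instead of being fire-and-forget, so the server can log the failure and disconnect the peer. Below is a minimal, self-contained sketch of that pattern; Message, failingPeer and the handler body are simplified stand-ins for illustration, not the project's actual code.

package main

import (
	"errors"
	"fmt"
)

// Message is a simplified stand-in for the network message type.
type Message struct {
	Command string
}

// Peer mirrors the reduced shape of the refactored interface: outgoing
// calls now report whether delivery succeeded.
type Peer interface {
	callVersion(*Message) error
	disconnect()
}

// failingPeer is a hypothetical implementation whose sends always fail,
// used to show the error travelling back to the caller.
type failingPeer struct{}

func (p *failingPeer) callVersion(*Message) error { return errors.New("connection closed") }
func (p *failingPeer) disconnect()                { fmt.Println("peer disconnected") }

// handlePeerConnected propagates the error from the outgoing version
// message instead of swallowing it.
func handlePeerConnected(p Peer) error {
	return p.callVersion(&Message{Command: "version"})
}

func main() {
	p := &failingPeer{}
	if err := handlePeerConnected(p); err != nil {
		fmt.Printf("failed handling peer connection: %s\n", err)
		p.disconnect()
	}
}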


@@ -1 +1 @@
-0.2.0
+0.3.0


@@ -10,8 +10,10 @@ type Peer interface {
 	id() uint32
 	addr() util.Endpoint
 	disconnect()
-	callVersion(*Message)
-	callGetaddr(*Message)
+	callVersion(*Message) error
+	callGetaddr(*Message) error
+	callVerack(*Message) error
+	callGetdata(*Message) error
 }

 // LocalPeer is the simplest kind of peer, mapped to a server in the
@@ -28,12 +30,20 @@ func NewLocalPeer(s *Server) *LocalPeer {
 	return &LocalPeer{endpoint: e, s: s}
 }

-func (p *LocalPeer) callVersion(msg *Message) {
-	p.s.handleVersionCmd(msg, p)
+func (p *LocalPeer) callVersion(msg *Message) error {
+	return p.s.handleVersionCmd(msg, p)
 }

-func (p *LocalPeer) callGetaddr(msg *Message) {
-	p.s.handleGetaddrCmd(msg, p)
+func (p *LocalPeer) callVerack(msg *Message) error {
+	return nil
+}
+
+func (p *LocalPeer) callGetaddr(msg *Message) error {
+	return p.s.handleGetaddrCmd(msg, p)
+}
+
+func (p *LocalPeer) callGetdata(msg *Message) error {
+	return nil
 }

 func (p *LocalPeer) id() uint32 { return p.nonce }


@@ -1,6 +1,7 @@
 package network

 import (
+	"errors"
 	"fmt"
 	"log"
 	"net"
@@ -159,7 +160,7 @@ func (s *Server) shutdown() {
 	// disconnect and remove all connected peers.
 	for peer := range s.peers {
-		s.unregister <- peer
+		peer.disconnect()
 	}
 }
@@ -173,13 +174,17 @@ func (s *Server) loop() {
 			if len(s.peers) < maxPeers {
 				s.logger.Printf("peer registered from address %s", peer.addr())
 				s.peers[peer] = true
-				s.handlePeerConnected(peer)
+				if err := s.handlePeerConnected(peer); err != nil {
+					s.logger.Printf("failed handling peer connection: %s", err)
+					peer.disconnect()
+				}
 			}

-		// Unregister should take care of all the cleanup that has to be made.
+		// unregister safely deletes a peer. For disconnecting peers use the
+		// disconnect() method on the peer; it will call unregister and terminate its routines.
 		case peer := <-s.unregister:
 			if _, ok := s.peers[peer]; ok {
-				peer.disconnect()
 				delete(s.peers, peer)
 				s.logger.Printf("peer %s disconnected", peer.addr())
 			}
@@ -196,71 +201,68 @@ func (s *Server) loop() {
 // When a new peer is connected we send our version.
 // No further communication should be made before both sides have received
 // each other's version.
-func (s *Server) handlePeerConnected(p Peer) {
+func (s *Server) handlePeerConnected(p Peer) error {
 	// TODO: get the blockheight of this server once core implemented this.
 	payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay)
 	msg := newMessage(s.net, cmdVersion, payload)
-	p.callVersion(msg)
+	return p.callVersion(msg)
 }

-func (s *Server) handleVersionCmd(msg *Message, p Peer) *Message {
+func (s *Server) handleVersionCmd(msg *Message, p Peer) error {
 	version := msg.Payload.(*payload.Version)
 	if s.id == version.Nonce {
-		// s.unregister <- p
-		return nil
+		return errors.New("identical nonce")
 	}
 	if p.addr().Port != version.Port {
-		// s.unregister <- p
-		return nil
+		return fmt.Errorf("port mismatch: %d", version.Port)
 	}
-	return newMessage(ModeDevNet, cmdVerack, nil)
+	return p.callVerack(newMessage(s.net, cmdVerack, nil))
 }

-func (s *Server) handleGetaddrCmd(msg *Message, p Peer) *Message {
+func (s *Server) handleGetaddrCmd(msg *Message, p Peer) error {
 	return nil
 }

 // The node can broadcast the object information it owns by this message.
 // The message can be sent automatically or can be used to answer getblocks messages.
-func (s *Server) handleInvCmd(msg *Message, p Peer) *Message {
+func (s *Server) handleInvCmd(msg *Message, p Peer) error {
 	inv := msg.Payload.(*payload.Inventory)
 	if !inv.Type.Valid() {
-		s.unregister <- p
-		return nil
+		return fmt.Errorf("invalid inventory type %s", inv.Type)
 	}
 	if len(inv.Hashes) == 0 {
-		s.unregister <- p
-		return nil
+		return errors.New("inventory should have at least 1 hash got 0")
 	}

 	// todo: only grab the hashes that we don't know.
 	payload := payload.NewInventory(inv.Type, inv.Hashes)
 	resp := newMessage(s.net, cmdGetData, payload)
-	return resp
+	return p.callGetdata(resp)
 }

 // handleBlockCmd processes the received block.
-func (s *Server) handleBlockCmd(msg *Message, p Peer) {
+func (s *Server) handleBlockCmd(msg *Message, p Peer) error {
 	block := msg.Payload.(*core.Block)
 	hash, err := block.Hash()
 	if err != nil {
-		// not quite sure what to do here.
-		// should we disconnect the client or just silently log and move on?
-		s.logger.Printf("failed to generate block hash: %s", err)
-		return
+		return err
 	}

 	fmt.Println(hash)

 	if s.bc.HasBlock(hash) {
-		return
+		return nil
 	}
+
+	return nil
 }

 // After receiving the getaddr message, the node returns an addr message as response
 // and provides information about the known nodes on the network.
-func (s *Server) handleAddrCmd(msg *Message, p Peer) {
+func (s *Server) handleAddrCmd(msg *Message, p Peer) error {
 	addrList := msg.Payload.(*payload.AddressList)
 	for _, addr := range addrList.Addrs {
 		if !s.peerAlreadyConnected(addr.Addr) {
@@ -268,6 +270,7 @@ func (s *Server) handleAddrCmd(msg *Message, p Peer) {
 			go connectToRemoteNode(s, addr.Addr.String())
 		}
 	}
+	return nil
 }

 func (s *Server) relayInventory(inv *payload.Inventory) {
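The server changes above pivot on a single ownership rule: disconnect() is the only entry point for tearing a peer down, and it funnels the peer into the unregister channel, where the server loop removes it from the peer map. The sketch below illustrates that flow with assumed, simplified names (peer, server) and a sync.Once guard standing in for the select/default guard the diff uses on the TCP peer; it is an illustration, not the project's code.

package main

import (
	"fmt"
	"sync"
)

type peer struct {
	name string
	srv  *server
	once sync.Once
}

// disconnect funnels all cleanup through the server's unregister channel,
// guarded by sync.Once so calling it twice is harmless.
func (p *peer) disconnect() {
	p.once.Do(func() {
		p.srv.unregister <- p
	})
}

type server struct {
	register   chan *peer
	unregister chan *peer
	peers      map[*peer]bool
	done       chan struct{}
}

// loop owns the peer map; registration and removal only happen here.
func (s *server) loop() {
	for {
		select {
		case p := <-s.register:
			s.peers[p] = true
			fmt.Println("registered", p.name)
		case p := <-s.unregister:
			if _, ok := s.peers[p]; ok {
				delete(s.peers, p)
				fmt.Println("unregistered", p.name)
			}
			if len(s.peers) == 0 {
				close(s.done)
				return
			}
		}
	}
}

func main() {
	s := &server{
		register:   make(chan *peer),
		unregister: make(chan *peer),
		peers:      make(map[*peer]bool),
		done:       make(chan struct{}),
	}
	go s.loop()

	p := &peer{name: "peer-1", srv: s}
	s.register <- p
	p.disconnect()
	<-s.done
}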


@@ -6,6 +6,32 @@ import (
 	"github.com/CityOfZion/neo-go/pkg/network/payload"
 )

+func TestHandleVersionFailWrongPort(t *testing.T) {
+	s := NewServer(ModeDevNet)
+	go s.loop()
+
+	p := NewLocalPeer(s)
+	version := payload.NewVersion(1337, 1, "/NEO:0.0.0/", 0, true)
+	msg := newMessage(ModeDevNet, cmdVersion, version)
+
+	if err := s.handleVersionCmd(msg, p); err == nil {
+		t.Fatal("expected error got nil")
+	}
+}
+
+func TestHandleVersionFailIdenticalNonce(t *testing.T) {
+	s := NewServer(ModeDevNet)
+	go s.loop()
+
+	p := NewLocalPeer(s)
+	version := payload.NewVersion(s.id, 1, "/NEO:0.0.0/", 0, true)
+	msg := newMessage(ModeDevNet, cmdVersion, version)
+
+	if err := s.handleVersionCmd(msg, p); err == nil {
+		t.Fatal("expected error got nil")
+	}
+}
+
 func TestHandleVersion(t *testing.T) {
 	s := NewServer(ModeDevNet)
 	go s.loop()
@@ -15,12 +41,8 @@ func TestHandleVersion(t *testing.T) {
 	version := payload.NewVersion(1337, p.addr().Port, "/NEO:0.0.0/", 0, true)
 	msg := newMessage(ModeDevNet, cmdVersion, version)

-	resp := s.handleVersionCmd(msg, p)
-	if resp.commandType() != cmdVerack {
-		t.Fatalf("expected response message to be verack got %s", resp.commandType())
-	}
-	if resp.Payload != nil {
-		t.Fatal("verack payload should be nil")
+	if err := s.handleVersionCmd(msg, p); err != nil {
+		t.Fatal(err)
 	}
 }


@@ -2,6 +2,7 @@ package network

 import (
 	"bytes"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -52,8 +53,7 @@ func handleConnection(s *Server, conn net.Conn) {
 	// remove the peer from connected peers and cleanup the connection.
 	defer func() {
-		// all cleanup will happen in the server's loop when unregister is received.
-		s.unregister <- peer
+		peer.disconnect()
 	}()

 	// Start a goroutine that will handle all outgoing messages.
@@ -66,17 +66,17 @@ func handleConnection(s *Server, conn net.Conn) {
 	for {
 		_, err := conn.Read(buf)
 		if err == io.EOF {
-			break
+			return
 		}
 		if err != nil {
 			s.logger.Printf("conn read error: %s", err)
-			break
+			return
 		}

 		msg := &Message{}
 		if err := msg.decode(bytes.NewReader(buf)); err != nil {
 			s.logger.Printf("decode error %s", err)
-			break
+			return
 		}

 		peer.receive <- msg
@@ -85,9 +85,11 @@ func handleConnection(s *Server, conn net.Conn) {
 // handleMessage hands the message received from a TCP connection over to the server.
 func handleMessage(s *Server, p *TCPPeer) {
+	var err error
+
 	// Disconnect the peer when we break out of the loop.
 	defer func() {
-		s.unregister <- p
+		p.disconnect()
 	}()

 	for {
@ -98,30 +100,58 @@ func handleMessage(s *Server, p *TCPPeer) {
switch command { switch command {
case cmdVersion: case cmdVersion:
resp := s.handleVersionCmd(msg, p) if err = s.handleVersionCmd(msg, p); err != nil {
return
}
p.nonce = msg.Payload.(*payload.Version).Nonce p.nonce = msg.Payload.(*payload.Version).Nonce
p.send <- resp
// When a node receives a connection request, it declares its version immediately.
// There will be no other communication until both sides are getting versions of each other.
// When a node receives the version message, it replies to a verack as a response immediately.
// NOTE: The current official NEO nodes dont mimic this behaviour. There is small chance that the
// official nodes will not respond directly with a verack after we sended our version.
// is this a bug? - anthdm 02/02/2018
msgVerack := <-p.receive
if msgVerack.commandType() != cmdVerack {
s.logger.Printf("expected verack after sended out version")
return
}
// start the protocol
go s.sendLoop(p)
case cmdAddr: case cmdAddr:
s.handleAddrCmd(msg, p) err = s.handleAddrCmd(msg, p)
case cmdGetAddr: case cmdGetAddr:
s.handleGetaddrCmd(msg, p) err = s.handleGetaddrCmd(msg, p)
case cmdInv: case cmdInv:
resp := s.handleInvCmd(msg, p) err = s.handleInvCmd(msg, p)
p.send <- resp
case cmdBlock: case cmdBlock:
s.handleBlockCmd(msg, p) err = s.handleBlockCmd(msg, p)
case cmdConsensus: case cmdConsensus:
case cmdTX: case cmdTX:
case cmdVerack: case cmdVerack:
go s.sendLoop(p) // If we receive a verack here we disconnect. We already handled the verack
// when we sended our version.
err = errors.New("received verack twice")
case cmdGetHeaders: case cmdGetHeaders:
case cmdGetBlocks: case cmdGetBlocks:
case cmdGetData: case cmdGetData:
case cmdHeaders: case cmdHeaders:
} }
// catch all errors here and disconnect.
if err != nil {
s.logger.Printf("processing message failed: %s", err)
return
}
} }
} }
type sendTuple struct {
msg *Message
err chan error
}
// TCPPeer represents a remote node, backed by TCP transport. // TCPPeer represents a remote node, backed by TCP transport.
type TCPPeer struct { type TCPPeer struct {
s *Server s *Server
@@ -132,7 +162,7 @@ type TCPPeer struct {
 	// host and port information about this peer.
 	endpoint util.Endpoint
 	// channel to coordinate messages written back to the connection.
-	send chan *Message
+	send chan sendTuple
 	// channel to receive from underlying connection.
 	receive chan *Message
 }
@@ -143,15 +173,22 @@ func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer {
 	return &TCPPeer{
 		conn:     conn,
-		send:     make(chan *Message),
+		send:     make(chan sendTuple),
 		receive:  make(chan *Message),
 		endpoint: e,
 		s:        s,
 	}
 }

-func (p *TCPPeer) callVersion(msg *Message) {
-	p.send <- msg
+func (p *TCPPeer) callVersion(msg *Message) error {
+	t := sendTuple{
+		msg: msg,
+		err: make(chan error),
+	}
+	p.send <- t
+	return <-t.err
 }

 // id implements the peer interface
@@ -165,16 +202,51 @@ func (p *TCPPeer) addr() util.Endpoint {
 }

 // callGetaddr will send the "getaddr" command to the remote.
-func (p *TCPPeer) callGetaddr(msg *Message) {
-	p.send <- msg
+func (p *TCPPeer) callGetaddr(msg *Message) error {
+	t := sendTuple{
+		msg: msg,
+		err: make(chan error),
+	}
+	p.send <- t
+	return <-t.err
 }

-// disconnect closes the send channel and the underlying connection.
-// TODO: this needs some love. We will get send on closed channel.
+func (p *TCPPeer) callVerack(msg *Message) error {
+	t := sendTuple{
+		msg: msg,
+		err: make(chan error),
+	}
+	p.send <- t
+	return <-t.err
+}
+
+func (p *TCPPeer) callGetdata(msg *Message) error {
+	t := sendTuple{
+		msg: msg,
+		err: make(chan error),
+	}
+	p.send <- t
+	return <-t.err
+}
+
+// disconnect disconnects the peer, cleaning up all its resources.
+// Three goroutines need to be cleaned up (writeLoop, handleConnection and handleMessage).
 func (p *TCPPeer) disconnect() {
-	p.conn.Close()
-	close(p.send)
-	close(p.receive)
+	select {
+	case <-p.send:
+	case <-p.receive:
+	default:
+		close(p.send)
+		close(p.receive)
+		p.s.unregister <- p
+		p.conn.Close()
+	}
 }

 // writeLoop writes messages to the underlying TCP connection.
@@ -184,17 +256,17 @@ func (p *TCPPeer) disconnect() {
 func (p *TCPPeer) writeLoop() {
 	// clean up the connection.
 	defer func() {
-		p.conn.Close()
+		p.disconnect()
 	}()

 	for {
-		msg := <-p.send
-
-		p.s.logger.Printf("OUT :: %s :: %+v", msg.commandType(), msg.Payload)
-
-		// should we disconnect here?
-		if err := msg.encode(p.conn); err != nil {
-			p.s.logger.Printf("encode error: %s", err)
+		t := <-p.send
+		if t.msg == nil {
+			return
 		}
+
+		p.s.logger.Printf("OUT :: %s :: %+v", t.msg.commandType(), t.msg.Payload)
+		t.err <- t.msg.encode(p.conn)
 	}
 }
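The sendTuple type introduced above is what lets outgoing calls report errors: each message travels to writeLoop together with its own error channel, and the caller blocks until the encode result comes back. Below is a runnable sketch of the same pattern; message, tcpPeer and brokenConn are simplified illustrative types, not the project's.

package main

import (
	"errors"
	"fmt"
	"io"
)

// message is a stand-in for the real Message type; encode just writes the
// command string to the connection.
type message struct {
	command string
}

func (m *message) encode(w io.Writer) error {
	_, err := io.WriteString(w, m.command+"\n")
	return err
}

// sendTuple pairs an outgoing message with a channel carrying its write result.
type sendTuple struct {
	msg *message
	err chan error
}

type tcpPeer struct {
	conn io.Writer
	send chan sendTuple
}

// writeLoop is the single writer on the connection; it reports each encode
// result back through the tuple's err channel.
func (p *tcpPeer) writeLoop() {
	for t := range p.send {
		t.err <- t.msg.encode(p.conn)
	}
}

// callVersion sends a version message and returns the outcome of the write.
func (p *tcpPeer) callVersion(msg *message) error {
	t := sendTuple{msg: msg, err: make(chan error)}
	p.send <- t
	return <-t.err
}

// brokenConn is a hypothetical writer that always fails, to show the error
// surfacing at the call site.
type brokenConn struct{}

func (brokenConn) Write([]byte) (int, error) { return 0, errors.New("broken pipe") }

func main() {
	p := &tcpPeer{conn: brokenConn{}, send: make(chan sendTuple)}
	go p.writeLoop()

	if err := p.callVersion(&message{command: "version"}); err != nil {
		fmt.Println("outgoing message failed:", err)
	}
}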