Refactor tcp transport (#11)

* refactored tcp transport

* return errors on outgoing messages

* TCP transport should report its error after reading from connection

* handle error returned from peer transport

* bump version

* cleaned up error
This commit is contained in:
Anthony De Meulemeester 2018-02-02 11:02:25 +01:00 committed by GitHub
parent 4050dbeeb8
commit 6e3f1ec43e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 179 additions and 72 deletions

View file

@ -1 +1 @@
0.2.0
0.3.0

View file

@ -10,8 +10,10 @@ type Peer interface {
id() uint32
addr() util.Endpoint
disconnect()
callVersion(*Message)
callGetaddr(*Message)
callVersion(*Message) error
callGetaddr(*Message) error
callVerack(*Message) error
callGetdata(*Message) error
}
// LocalPeer is the simplest kind of peer, mapped to a server in the
@ -28,12 +30,20 @@ func NewLocalPeer(s *Server) *LocalPeer {
return &LocalPeer{endpoint: e, s: s}
}
func (p *LocalPeer) callVersion(msg *Message) {
p.s.handleVersionCmd(msg, p)
func (p *LocalPeer) callVersion(msg *Message) error {
return p.s.handleVersionCmd(msg, p)
}
func (p *LocalPeer) callGetaddr(msg *Message) {
p.s.handleGetaddrCmd(msg, p)
func (p *LocalPeer) callVerack(msg *Message) error {
return nil
}
func (p *LocalPeer) callGetaddr(msg *Message) error {
return p.s.handleGetaddrCmd(msg, p)
}
func (p *LocalPeer) callGetdata(msg *Message) error {
return nil
}
func (p *LocalPeer) id() uint32 { return p.nonce }

View file

@ -1,6 +1,7 @@
package network
import (
"errors"
"fmt"
"log"
"net"
@ -159,7 +160,7 @@ func (s *Server) shutdown() {
// disconnect and remove all connected peers.
for peer := range s.peers {
s.unregister <- peer
peer.disconnect()
}
}
@ -173,13 +174,17 @@ func (s *Server) loop() {
if len(s.peers) < maxPeers {
s.logger.Printf("peer registered from address %s", peer.addr())
s.peers[peer] = true
s.handlePeerConnected(peer)
if err := s.handlePeerConnected(peer); err != nil {
s.logger.Printf("failed handling peer connection: %s", err)
peer.disconnect()
}
}
// Unregister should take care of all the cleanup that has to be made.
// unregister safely deletes a peer. For disconnecting peers use the
// disconnect() method on the peer, it will call unregister and terminates its routines.
case peer := <-s.unregister:
if _, ok := s.peers[peer]; ok {
peer.disconnect()
delete(s.peers, peer)
s.logger.Printf("peer %s disconnected", peer.addr())
}
@ -196,71 +201,68 @@ func (s *Server) loop() {
// When a new peer is connected we send our version.
// No further communication should be made before both sides has received
// the versions of eachother.
func (s *Server) handlePeerConnected(p Peer) {
func (s *Server) handlePeerConnected(p Peer) error {
// TODO: get the blockheight of this server once core implemented this.
payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay)
msg := newMessage(s.net, cmdVersion, payload)
p.callVersion(msg)
return p.callVersion(msg)
}
func (s *Server) handleVersionCmd(msg *Message, p Peer) *Message {
func (s *Server) handleVersionCmd(msg *Message, p Peer) error {
version := msg.Payload.(*payload.Version)
if s.id == version.Nonce {
// s.unregister <- p
return nil
return errors.New("identical nonce")
}
if p.addr().Port != version.Port {
// s.unregister <- p
return nil
}
return newMessage(ModeDevNet, cmdVerack, nil)
return fmt.Errorf("port mismatch: %d", version.Port)
}
func (s *Server) handleGetaddrCmd(msg *Message, p Peer) *Message {
return p.callVerack(newMessage(s.net, cmdVerack, nil))
}
func (s *Server) handleGetaddrCmd(msg *Message, p Peer) error {
return nil
}
// The node can broadcast the object information it owns by this message.
// The message can be sent automatically or can be used to answer getbloks messages.
func (s *Server) handleInvCmd(msg *Message, p Peer) *Message {
func (s *Server) handleInvCmd(msg *Message, p Peer) error {
inv := msg.Payload.(*payload.Inventory)
if !inv.Type.Valid() {
s.unregister <- p
return nil
return fmt.Errorf("invalid inventory type %s", inv.Type)
}
if len(inv.Hashes) == 0 {
s.unregister <- p
return nil
return errors.New("inventory should have at least 1 hash got 0")
}
// todo: only grab the hashes that we dont know.
payload := payload.NewInventory(inv.Type, inv.Hashes)
resp := newMessage(s.net, cmdGetData, payload)
return resp
return p.callGetdata(resp)
}
// handleBlockCmd processes the received block.
func (s *Server) handleBlockCmd(msg *Message, p Peer) {
func (s *Server) handleBlockCmd(msg *Message, p Peer) error {
block := msg.Payload.(*core.Block)
hash, err := block.Hash()
if err != nil {
// not quite sure what to do here.
// should we disconnect the client or just silently log and move on?
s.logger.Printf("failed to generate block hash: %s", err)
return
return err
}
fmt.Println(hash)
if s.bc.HasBlock(hash) {
return
return nil
}
return nil
}
// After receiving the getaddr message, the node returns an addr message as response
// and provides information about the known nodes on the network.
func (s *Server) handleAddrCmd(msg *Message, p Peer) {
func (s *Server) handleAddrCmd(msg *Message, p Peer) error {
addrList := msg.Payload.(*payload.AddressList)
for _, addr := range addrList.Addrs {
if !s.peerAlreadyConnected(addr.Addr) {
@ -268,6 +270,7 @@ func (s *Server) handleAddrCmd(msg *Message, p Peer) {
go connectToRemoteNode(s, addr.Addr.String())
}
}
return nil
}
func (s *Server) relayInventory(inv *payload.Inventory) {

View file

@ -6,6 +6,32 @@ import (
"github.com/CityOfZion/neo-go/pkg/network/payload"
)
func TestHandleVersionFailWrongPort(t *testing.T) {
s := NewServer(ModeDevNet)
go s.loop()
p := NewLocalPeer(s)
version := payload.NewVersion(1337, 1, "/NEO:0.0.0/", 0, true)
msg := newMessage(ModeDevNet, cmdVersion, version)
if err := s.handleVersionCmd(msg, p); err == nil {
t.Fatal("expected error got nil")
}
}
func TestHandleVersionFailIdenticalNonce(t *testing.T) {
s := NewServer(ModeDevNet)
go s.loop()
p := NewLocalPeer(s)
version := payload.NewVersion(s.id, 1, "/NEO:0.0.0/", 0, true)
msg := newMessage(ModeDevNet, cmdVersion, version)
if err := s.handleVersionCmd(msg, p); err == nil {
t.Fatal("expected error got nil")
}
}
func TestHandleVersion(t *testing.T) {
s := NewServer(ModeDevNet)
go s.loop()
@ -15,12 +41,8 @@ func TestHandleVersion(t *testing.T) {
version := payload.NewVersion(1337, p.addr().Port, "/NEO:0.0.0/", 0, true)
msg := newMessage(ModeDevNet, cmdVersion, version)
resp := s.handleVersionCmd(msg, p)
if resp.commandType() != cmdVerack {
t.Fatalf("expected response message to be verack got %s", resp.commandType())
}
if resp.Payload != nil {
t.Fatal("verack payload should be nil")
if err := s.handleVersionCmd(msg, p); err != nil {
t.Fatal(err)
}
}

View file

@ -2,6 +2,7 @@ package network
import (
"bytes"
"errors"
"fmt"
"io"
"net"
@ -52,8 +53,7 @@ func handleConnection(s *Server, conn net.Conn) {
// remove the peer from connected peers and cleanup the connection.
defer func() {
// all cleanup will happen in the server's loop when unregister is received.
s.unregister <- peer
peer.disconnect()
}()
// Start a goroutine that will handle all outgoing messages.
@ -66,17 +66,17 @@ func handleConnection(s *Server, conn net.Conn) {
for {
_, err := conn.Read(buf)
if err == io.EOF {
break
return
}
if err != nil {
s.logger.Printf("conn read error: %s", err)
break
return
}
msg := &Message{}
if err := msg.decode(bytes.NewReader(buf)); err != nil {
s.logger.Printf("decode error %s", err)
break
return
}
peer.receive <- msg
@ -85,9 +85,11 @@ func handleConnection(s *Server, conn net.Conn) {
// handleMessage hands the message received from a TCP connection over to the server.
func handleMessage(s *Server, p *TCPPeer) {
var err error
// Disconnect the peer when we break out of the loop.
defer func() {
s.unregister <- p
p.disconnect()
}()
for {
@ -98,29 +100,57 @@ func handleMessage(s *Server, p *TCPPeer) {
switch command {
case cmdVersion:
resp := s.handleVersionCmd(msg, p)
if err = s.handleVersionCmd(msg, p); err != nil {
return
}
p.nonce = msg.Payload.(*payload.Version).Nonce
p.send <- resp
// When a node receives a connection request, it declares its version immediately.
// There will be no other communication until both sides are getting versions of each other.
// When a node receives the version message, it replies to a verack as a response immediately.
// NOTE: The current official NEO nodes dont mimic this behaviour. There is small chance that the
// official nodes will not respond directly with a verack after we sended our version.
// is this a bug? - anthdm 02/02/2018
msgVerack := <-p.receive
if msgVerack.commandType() != cmdVerack {
s.logger.Printf("expected verack after sended out version")
return
}
// start the protocol
go s.sendLoop(p)
case cmdAddr:
s.handleAddrCmd(msg, p)
err = s.handleAddrCmd(msg, p)
case cmdGetAddr:
s.handleGetaddrCmd(msg, p)
err = s.handleGetaddrCmd(msg, p)
case cmdInv:
resp := s.handleInvCmd(msg, p)
p.send <- resp
err = s.handleInvCmd(msg, p)
case cmdBlock:
s.handleBlockCmd(msg, p)
err = s.handleBlockCmd(msg, p)
case cmdConsensus:
case cmdTX:
case cmdVerack:
go s.sendLoop(p)
// If we receive a verack here we disconnect. We already handled the verack
// when we sended our version.
err = errors.New("received verack twice")
case cmdGetHeaders:
case cmdGetBlocks:
case cmdGetData:
case cmdHeaders:
}
// catch all errors here and disconnect.
if err != nil {
s.logger.Printf("processing message failed: %s", err)
return
}
}
}
type sendTuple struct {
msg *Message
err chan error
}
// TCPPeer represents a remote node, backed by TCP transport.
type TCPPeer struct {
@ -132,7 +162,7 @@ type TCPPeer struct {
// host and port information about this peer.
endpoint util.Endpoint
// channel to coordinate messages writen back to the connection.
send chan *Message
send chan sendTuple
// channel to receive from underlying connection.
receive chan *Message
}
@ -143,15 +173,22 @@ func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer {
return &TCPPeer{
conn: conn,
send: make(chan *Message),
send: make(chan sendTuple),
receive: make(chan *Message),
endpoint: e,
s: s,
}
}
func (p *TCPPeer) callVersion(msg *Message) {
p.send <- msg
func (p *TCPPeer) callVersion(msg *Message) error {
t := sendTuple{
msg: msg,
err: make(chan error),
}
p.send <- t
return <-t.err
}
// id implements the peer interface
@ -165,16 +202,51 @@ func (p *TCPPeer) addr() util.Endpoint {
}
// callGetaddr will send the "getaddr" command to the remote.
func (p *TCPPeer) callGetaddr(msg *Message) {
p.send <- msg
func (p *TCPPeer) callGetaddr(msg *Message) error {
t := sendTuple{
msg: msg,
err: make(chan error),
}
// disconnect closes the send channel and the underlying connection.
// TODO: this needs some love. We will get send on closed channel.
p.send <- t
return <-t.err
}
func (p *TCPPeer) callVerack(msg *Message) error {
t := sendTuple{
msg: msg,
err: make(chan error),
}
p.send <- t
return <-t.err
}
func (p *TCPPeer) callGetdata(msg *Message) error {
t := sendTuple{
msg: msg,
err: make(chan error),
}
p.send <- t
return <-t.err
}
// disconnect disconnects the peer, cleaning up all its resources.
// 3 goroutines needs to be cleanup (writeLoop, handleConnection and handleMessage)
func (p *TCPPeer) disconnect() {
p.conn.Close()
select {
case <-p.send:
case <-p.receive:
default:
close(p.send)
close(p.receive)
p.s.unregister <- p
p.conn.Close()
}
}
// writeLoop writes messages to the underlying TCP connection.
@ -184,17 +256,17 @@ func (p *TCPPeer) disconnect() {
func (p *TCPPeer) writeLoop() {
// clean up the connection.
defer func() {
p.conn.Close()
p.disconnect()
}()
for {
msg := <-p.send
t := <-p.send
if t.msg == nil {
return
}
p.s.logger.Printf("OUT :: %s :: %+v", msg.commandType(), msg.Payload)
p.s.logger.Printf("OUT :: %s :: %+v", t.msg.commandType(), t.msg.Payload)
// should we disconnect here?
if err := msg.encode(p.conn); err != nil {
p.s.logger.Printf("encode error: %s", err)
}
t.err <- t.msg.encode(p.conn)
}
}