diff --git a/VERSION b/VERSION index 341cf11fa..9325c3ccd 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.2.0 \ No newline at end of file +0.3.0 \ No newline at end of file diff --git a/pkg/network/peer.go b/pkg/network/peer.go index c83b7b24c..d301da5a8 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -10,8 +10,10 @@ type Peer interface { id() uint32 addr() util.Endpoint disconnect() - callVersion(*Message) - callGetaddr(*Message) + callVersion(*Message) error + callGetaddr(*Message) error + callVerack(*Message) error + callGetdata(*Message) error } // LocalPeer is the simplest kind of peer, mapped to a server in the @@ -28,12 +30,20 @@ func NewLocalPeer(s *Server) *LocalPeer { return &LocalPeer{endpoint: e, s: s} } -func (p *LocalPeer) callVersion(msg *Message) { - p.s.handleVersionCmd(msg, p) +func (p *LocalPeer) callVersion(msg *Message) error { + return p.s.handleVersionCmd(msg, p) } -func (p *LocalPeer) callGetaddr(msg *Message) { - p.s.handleGetaddrCmd(msg, p) +func (p *LocalPeer) callVerack(msg *Message) error { + return nil +} + +func (p *LocalPeer) callGetaddr(msg *Message) error { + return p.s.handleGetaddrCmd(msg, p) +} + +func (p *LocalPeer) callGetdata(msg *Message) error { + return nil } func (p *LocalPeer) id() uint32 { return p.nonce } diff --git a/pkg/network/server.go b/pkg/network/server.go index 5bc145d36..7f5975dc9 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -1,6 +1,7 @@ package network import ( + "errors" "fmt" "log" "net" @@ -159,7 +160,7 @@ func (s *Server) shutdown() { // disconnect and remove all connected peers. for peer := range s.peers { - s.unregister <- peer + peer.disconnect() } } @@ -173,13 +174,17 @@ func (s *Server) loop() { if len(s.peers) < maxPeers { s.logger.Printf("peer registered from address %s", peer.addr()) s.peers[peer] = true - s.handlePeerConnected(peer) + + if err := s.handlePeerConnected(peer); err != nil { + s.logger.Printf("failed handling peer connection: %s", err) + peer.disconnect() + } } - // Unregister should take care of all the cleanup that has to be made. + // unregister safely deletes a peer. For disconnecting peers use the + // disconnect() method on the peer, it will call unregister and terminates its routines. case peer := <-s.unregister: if _, ok := s.peers[peer]; ok { - peer.disconnect() delete(s.peers, peer) s.logger.Printf("peer %s disconnected", peer.addr()) } @@ -196,71 +201,68 @@ func (s *Server) loop() { // When a new peer is connected we send our version. // No further communication should be made before both sides has received // the versions of eachother. -func (s *Server) handlePeerConnected(p Peer) { +func (s *Server) handlePeerConnected(p Peer) error { // TODO: get the blockheight of this server once core implemented this. payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay) msg := newMessage(s.net, cmdVersion, payload) - p.callVersion(msg) + return p.callVersion(msg) } -func (s *Server) handleVersionCmd(msg *Message, p Peer) *Message { +func (s *Server) handleVersionCmd(msg *Message, p Peer) error { version := msg.Payload.(*payload.Version) if s.id == version.Nonce { - // s.unregister <- p - return nil + return errors.New("identical nonce") } if p.addr().Port != version.Port { - // s.unregister <- p - return nil + return fmt.Errorf("port mismatch: %d", version.Port) } - return newMessage(ModeDevNet, cmdVerack, nil) + + return p.callVerack(newMessage(s.net, cmdVerack, nil)) } -func (s *Server) handleGetaddrCmd(msg *Message, p Peer) *Message { +func (s *Server) handleGetaddrCmd(msg *Message, p Peer) error { return nil } // The node can broadcast the object information it owns by this message. // The message can be sent automatically or can be used to answer getbloks messages. -func (s *Server) handleInvCmd(msg *Message, p Peer) *Message { +func (s *Server) handleInvCmd(msg *Message, p Peer) error { inv := msg.Payload.(*payload.Inventory) if !inv.Type.Valid() { - s.unregister <- p - return nil + return fmt.Errorf("invalid inventory type %s", inv.Type) } if len(inv.Hashes) == 0 { - s.unregister <- p - return nil + return errors.New("inventory should have at least 1 hash got 0") } // todo: only grab the hashes that we dont know. payload := payload.NewInventory(inv.Type, inv.Hashes) resp := newMessage(s.net, cmdGetData, payload) - return resp + + return p.callGetdata(resp) } // handleBlockCmd processes the received block. -func (s *Server) handleBlockCmd(msg *Message, p Peer) { +func (s *Server) handleBlockCmd(msg *Message, p Peer) error { block := msg.Payload.(*core.Block) hash, err := block.Hash() if err != nil { - // not quite sure what to do here. - // should we disconnect the client or just silently log and move on? - s.logger.Printf("failed to generate block hash: %s", err) - return + return err } fmt.Println(hash) if s.bc.HasBlock(hash) { - return + return nil } + + return nil } // After receiving the getaddr message, the node returns an addr message as response // and provides information about the known nodes on the network. -func (s *Server) handleAddrCmd(msg *Message, p Peer) { +func (s *Server) handleAddrCmd(msg *Message, p Peer) error { addrList := msg.Payload.(*payload.AddressList) for _, addr := range addrList.Addrs { if !s.peerAlreadyConnected(addr.Addr) { @@ -268,6 +270,7 @@ func (s *Server) handleAddrCmd(msg *Message, p Peer) { go connectToRemoteNode(s, addr.Addr.String()) } } + return nil } func (s *Server) relayInventory(inv *payload.Inventory) { diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 115c2b74c..34fbe3063 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -6,6 +6,32 @@ import ( "github.com/CityOfZion/neo-go/pkg/network/payload" ) +func TestHandleVersionFailWrongPort(t *testing.T) { + s := NewServer(ModeDevNet) + go s.loop() + + p := NewLocalPeer(s) + + version := payload.NewVersion(1337, 1, "/NEO:0.0.0/", 0, true) + msg := newMessage(ModeDevNet, cmdVersion, version) + if err := s.handleVersionCmd(msg, p); err == nil { + t.Fatal("expected error got nil") + } +} + +func TestHandleVersionFailIdenticalNonce(t *testing.T) { + s := NewServer(ModeDevNet) + go s.loop() + + p := NewLocalPeer(s) + + version := payload.NewVersion(s.id, 1, "/NEO:0.0.0/", 0, true) + msg := newMessage(ModeDevNet, cmdVersion, version) + if err := s.handleVersionCmd(msg, p); err == nil { + t.Fatal("expected error got nil") + } +} + func TestHandleVersion(t *testing.T) { s := NewServer(ModeDevNet) go s.loop() @@ -15,12 +41,8 @@ func TestHandleVersion(t *testing.T) { version := payload.NewVersion(1337, p.addr().Port, "/NEO:0.0.0/", 0, true) msg := newMessage(ModeDevNet, cmdVersion, version) - resp := s.handleVersionCmd(msg, p) - if resp.commandType() != cmdVerack { - t.Fatalf("expected response message to be verack got %s", resp.commandType()) - } - if resp.Payload != nil { - t.Fatal("verack payload should be nil") + if err := s.handleVersionCmd(msg, p); err != nil { + t.Fatal(err) } } diff --git a/pkg/network/tcp.go b/pkg/network/tcp.go index dffb56b4d..88243a9e1 100644 --- a/pkg/network/tcp.go +++ b/pkg/network/tcp.go @@ -2,6 +2,7 @@ package network import ( "bytes" + "errors" "fmt" "io" "net" @@ -52,8 +53,7 @@ func handleConnection(s *Server, conn net.Conn) { // remove the peer from connected peers and cleanup the connection. defer func() { - // all cleanup will happen in the server's loop when unregister is received. - s.unregister <- peer + peer.disconnect() }() // Start a goroutine that will handle all outgoing messages. @@ -66,17 +66,17 @@ func handleConnection(s *Server, conn net.Conn) { for { _, err := conn.Read(buf) if err == io.EOF { - break + return } if err != nil { s.logger.Printf("conn read error: %s", err) - break + return } msg := &Message{} if err := msg.decode(bytes.NewReader(buf)); err != nil { s.logger.Printf("decode error %s", err) - break + return } peer.receive <- msg @@ -85,9 +85,11 @@ func handleConnection(s *Server, conn net.Conn) { // handleMessage hands the message received from a TCP connection over to the server. func handleMessage(s *Server, p *TCPPeer) { + var err error + // Disconnect the peer when we break out of the loop. defer func() { - s.unregister <- p + p.disconnect() }() for { @@ -98,30 +100,58 @@ func handleMessage(s *Server, p *TCPPeer) { switch command { case cmdVersion: - resp := s.handleVersionCmd(msg, p) + if err = s.handleVersionCmd(msg, p); err != nil { + return + } p.nonce = msg.Payload.(*payload.Version).Nonce - p.send <- resp + + // When a node receives a connection request, it declares its version immediately. + // There will be no other communication until both sides are getting versions of each other. + // When a node receives the version message, it replies to a verack as a response immediately. + // NOTE: The current official NEO nodes dont mimic this behaviour. There is small chance that the + // official nodes will not respond directly with a verack after we sended our version. + // is this a bug? - anthdm 02/02/2018 + msgVerack := <-p.receive + if msgVerack.commandType() != cmdVerack { + s.logger.Printf("expected verack after sended out version") + return + } + + // start the protocol + go s.sendLoop(p) case cmdAddr: - s.handleAddrCmd(msg, p) + err = s.handleAddrCmd(msg, p) case cmdGetAddr: - s.handleGetaddrCmd(msg, p) + err = s.handleGetaddrCmd(msg, p) case cmdInv: - resp := s.handleInvCmd(msg, p) - p.send <- resp + err = s.handleInvCmd(msg, p) case cmdBlock: - s.handleBlockCmd(msg, p) + err = s.handleBlockCmd(msg, p) case cmdConsensus: case cmdTX: case cmdVerack: - go s.sendLoop(p) + // If we receive a verack here we disconnect. We already handled the verack + // when we sended our version. + err = errors.New("received verack twice") case cmdGetHeaders: case cmdGetBlocks: case cmdGetData: case cmdHeaders: } + + // catch all errors here and disconnect. + if err != nil { + s.logger.Printf("processing message failed: %s", err) + return + } } } +type sendTuple struct { + msg *Message + err chan error +} + // TCPPeer represents a remote node, backed by TCP transport. type TCPPeer struct { s *Server @@ -132,7 +162,7 @@ type TCPPeer struct { // host and port information about this peer. endpoint util.Endpoint // channel to coordinate messages writen back to the connection. - send chan *Message + send chan sendTuple // channel to receive from underlying connection. receive chan *Message } @@ -143,15 +173,22 @@ func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { return &TCPPeer{ conn: conn, - send: make(chan *Message), + send: make(chan sendTuple), receive: make(chan *Message), endpoint: e, s: s, } } -func (p *TCPPeer) callVersion(msg *Message) { - p.send <- msg +func (p *TCPPeer) callVersion(msg *Message) error { + t := sendTuple{ + msg: msg, + err: make(chan error), + } + + p.send <- t + + return <-t.err } // id implements the peer interface @@ -165,16 +202,51 @@ func (p *TCPPeer) addr() util.Endpoint { } // callGetaddr will send the "getaddr" command to the remote. -func (p *TCPPeer) callGetaddr(msg *Message) { - p.send <- msg +func (p *TCPPeer) callGetaddr(msg *Message) error { + t := sendTuple{ + msg: msg, + err: make(chan error), + } + + p.send <- t + + return <-t.err } -// disconnect closes the send channel and the underlying connection. -// TODO: this needs some love. We will get send on closed channel. +func (p *TCPPeer) callVerack(msg *Message) error { + t := sendTuple{ + msg: msg, + err: make(chan error), + } + + p.send <- t + + return <-t.err +} + +func (p *TCPPeer) callGetdata(msg *Message) error { + t := sendTuple{ + msg: msg, + err: make(chan error), + } + + p.send <- t + + return <-t.err +} + +// disconnect disconnects the peer, cleaning up all its resources. +// 3 goroutines needs to be cleanup (writeLoop, handleConnection and handleMessage) func (p *TCPPeer) disconnect() { - p.conn.Close() - close(p.send) - close(p.receive) + select { + case <-p.send: + case <-p.receive: + default: + close(p.send) + close(p.receive) + p.s.unregister <- p + p.conn.Close() + } } // writeLoop writes messages to the underlying TCP connection. @@ -184,17 +256,17 @@ func (p *TCPPeer) disconnect() { func (p *TCPPeer) writeLoop() { // clean up the connection. defer func() { - p.conn.Close() + p.disconnect() }() for { - msg := <-p.send - - p.s.logger.Printf("OUT :: %s :: %+v", msg.commandType(), msg.Payload) - - // should we disconnect here? - if err := msg.encode(p.conn); err != nil { - p.s.logger.Printf("encode error: %s", err) + t := <-p.send + if t.msg == nil { + return } + + p.s.logger.Printf("OUT :: %s :: %+v", t.msg.commandType(), t.msg.Payload) + + t.err <- t.msg.encode(p.conn) } }