diff --git a/pkg/network/peer.go b/pkg/network/peer.go index 830844cad..af38fe142 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -1,8 +1,6 @@ package network import ( - "fmt" - "log" "net" "github.com/anthdm/neo-go/pkg/util" @@ -12,94 +10,94 @@ import ( // be backed by any concrete transport: local, HTTP, tcp. type Peer interface { id() uint32 - endpoint() util.Endpoint - send(*Message) + addr() util.Endpoint verack() bool - verify(uint32) disconnect() + callVersion(*Message) + callGetaddr(*Message) } -// LocalPeer is a peer without any transport, mainly used for testing. +// LocalPeer is the simplest kind of peer, mapped to a server in the +// same process-space. type LocalPeer struct { - _id uint32 - _verack bool - _endpoint util.Endpoint - _send chan *Message + s *Server + nonce uint32 + isVerack bool + endpoint util.Endpoint } // NewLocalPeer return a LocalPeer. -func NewLocalPeer() *LocalPeer { +func NewLocalPeer(s *Server) *LocalPeer { e, _ := util.EndpointFromString("1.1.1.1:1111") - return &LocalPeer{_endpoint: e} + return &LocalPeer{endpoint: e, s: s} } -func (p *LocalPeer) id() uint32 { return p._id } -func (p *LocalPeer) verack() bool { return p._verack } -func (p *LocalPeer) endpoint() util.Endpoint { return p._endpoint } -func (p *LocalPeer) disconnect() {} - -func (p *LocalPeer) send(msg *Message) { - p._send <- msg +func (p *LocalPeer) callVersion(msg *Message) { + p.s.handleVersionCmd(msg, p) } -func (p *LocalPeer) verify(id uint32) { - fmt.Println(id) - p._verack = true - p._id = id +func (p *LocalPeer) callGetaddr(msg *Message) { + p.s.handleGetaddrCmd(msg, p) } +func (p *LocalPeer) id() uint32 { return p.nonce } +func (p *LocalPeer) verack() bool { return p.isVerack } +func (p *LocalPeer) addr() util.Endpoint { return p.endpoint } +func (p *LocalPeer) disconnect() {} + // TCPPeer represents a remote node, backed by TCP transport. type TCPPeer struct { - _id uint32 + s *Server + // nonce (id) of the peer. + nonce uint32 // underlying TCP connection conn net.Conn // host and port information about this peer. - _endpoint util.Endpoint + endpoint util.Endpoint // channel to coordinate messages writen back to the connection. - _send chan *Message + send chan *Message // whether this peers version was acknowledged. - _verack bool + isVerack bool } // NewTCPPeer returns a pointer to a TCP Peer. -func NewTCPPeer(conn net.Conn) *TCPPeer { +func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { e, _ := util.EndpointFromString(conn.RemoteAddr().String()) return &TCPPeer{ - conn: conn, - _send: make(chan *Message), - _endpoint: e, + conn: conn, + send: make(chan *Message), + endpoint: e, + s: s, } } +func (p *TCPPeer) callVersion(msg *Message) { + p.send <- msg +} + // id implements the peer interface func (p *TCPPeer) id() uint32 { - return p._id + return p.nonce } // endpoint implements the peer interface -func (p *TCPPeer) endpoint() util.Endpoint { - return p._endpoint +func (p *TCPPeer) addr() util.Endpoint { + return p.endpoint } // verack implements the peer interface func (p *TCPPeer) verack() bool { - return p._verack + return p.isVerack } -// verify implements the peer interface -func (p *TCPPeer) verify(id uint32) { - p._id = id - p._verack = true -} - -// send implements the peer interface -func (p *TCPPeer) send(msg *Message) { - p._send <- msg +// callGetaddr will send the "getaddr" command to the remote. +func (p *TCPPeer) callGetaddr(msg *Message) { + p.send <- msg } func (p *TCPPeer) disconnect() { - close(p._send) + close(p.send) p.conn.Close() } @@ -114,12 +112,13 @@ func (p *TCPPeer) writeLoop() { }() for { - msg := <-p._send + msg := <-p.send - rpcLogger.Printf("[SERVER] :: OUT :: %s :: %+v", msg.commandType(), msg.Payload) + p.s.logger.Printf("OUT :: %s :: %+v", msg.commandType(), msg.Payload) + // should we disconnect here? if err := msg.encode(p.conn); err != nil { - log.Printf("encode error: %s", err) + p.s.logger.Printf("encode error: %s", err) } } } diff --git a/pkg/network/server.go b/pkg/network/server.go index 30a003bfb..36dd70fa9 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -1,7 +1,6 @@ package network import ( - "errors" "fmt" "log" "net" @@ -9,7 +8,6 @@ import ( "strconv" "time" - "github.com/anthdm/neo-go/pkg/core" "github.com/anthdm/neo-go/pkg/network/payload" "github.com/anthdm/neo-go/pkg/util" ) @@ -20,11 +18,7 @@ const ( // official ports according to the protocol. portMainNet = 10333 portTestNet = 20333 -) - -var ( - // rpcLogger used for debugging RPC messages between nodes. - rpcLogger = log.New(os.Stdout, "", 0) + maxPeers = 50 ) type messageTuple struct { @@ -35,13 +29,10 @@ type messageTuple struct { // Server is the representation of a full working NEO TCP node. type Server struct { logger *log.Logger - // id of the server id uint32 - // the port the TCP listener is listening on. port uint16 - // userAgent of the server. userAgent string // The "magic" mode the server is currently running on. @@ -50,26 +41,29 @@ type Server struct { net NetMode // map that holds all connected peers to this server. peers map[Peer]bool - - register chan Peer + // channel for handling new registerd peers. + register chan Peer + // channel for safely removing and disconnecting peers. unregister chan Peer - // channel for coordinating messages. message chan messageTuple - // channel used to gracefull shutdown the server. quit chan struct{} - // Whether this server will receive and forward messages. relay bool - // TCP listener of the server listener net.Listener + + // RPC channels + versionCh chan versionTuple + getaddrCh chan getaddrTuple + invCh chan invTuple + addrCh chan addrTuple } // NewServer returns a pointer to a new server. func NewServer(net NetMode) *Server { - logger := log.New(os.Stdout, "NEO SERVER :: ", 0) + logger := log.New(os.Stdout, "[NEO SERVER] :: ", 0) if net != ModeTestNet && net != ModeMainNet && net != ModeDevNet { logger.Fatalf("invalid network mode %d", net) @@ -83,9 +77,13 @@ func NewServer(net NetMode) *Server { register: make(chan Peer), unregister: make(chan Peer), message: make(chan messageTuple), - relay: true, + relay: true, // currently relay is not handled. net: net, quit: make(chan struct{}), + versionCh: make(chan versionTuple), + getaddrCh: make(chan getaddrTuple), + invCh: make(chan invTuple), + addrCh: make(chan addrTuple), } return s @@ -131,30 +129,62 @@ func (s *Server) shutdown() { func (s *Server) loop() { for { select { + // When a new connection is been established, (by this server or remote node) + // its peer will be received on this channel. + // Any peer registration must happen via this channel. case peer := <-s.register: - // When a new connection is been established, (by this server or remote node) - // its peer will be received on this channel. - // Any peer registration must happen via this channel. - s.logger.Printf("peer registered from address %s", peer.endpoint()) - s.peers[peer] = true - s.handlePeerConnected(peer) + if len(s.peers) < maxPeers { + s.logger.Printf("peer registered from address %s", peer.addr()) + s.peers[peer] = true + s.handlePeerConnected(peer) + } + // Unregister should take care of all the cleanup that has to be made. case peer := <-s.unregister: - // unregister should take care of all the cleanup that has to be made. if _, ok := s.peers[peer]; ok { peer.disconnect() delete(s.peers, peer) - s.logger.Printf("peer %s disconnected", peer.endpoint()) + s.logger.Printf("peer %s disconnected", peer.addr()) } - case tuple := <-s.message: - // When a remote node sends data over its connection it will be received - // on this channel. - // All errors encountered should be return and handled here. - if err := s.processMessage(tuple.msg, tuple.peer); err != nil { - s.logger.Fatalf("failed to process message: %s", err) - s.unregister <- tuple.peer + // Process the received version and respond with a verack. + case t := <-s.versionCh: + if s.id == t.request.Nonce { + t.peer.disconnect() } + if t.peer.addr().Port != t.request.Port { + t.peer.disconnect() + } + t.response <- newMessage(ModeDevNet, cmdVerack, nil) + + // Process the getaddr cmd. + case t := <-s.getaddrCh: + t.response <- &Message{} // just for now. + + // Process the addr cmd. Register peer will handle the maxPeers connected. + case t := <-s.addrCh: + for _, addr := range t.request.Addrs { + if !s.peerAlreadyConnected(addr.Addr) { + // TODO: this is not transport abstracted. + go connectToRemoteNode(s, addr.Addr.String()) + } + } + t.response <- true + + // Process inventories cmd. + case t := <-s.invCh: + if !t.request.Type.Valid() { + t.peer.disconnect() + break + } + if len(t.request.Hashes) == 0 { + t.peer.disconnect() + break + } + + payload := payload.NewInventory(t.request.Type, t.request.Hashes) + msg := newMessage(s.net, cmdGetData, payload) + t.response <- msg case <-s.quit: s.shutdown() @@ -162,135 +192,100 @@ func (s *Server) loop() { } } -// processMessage processes the message received from the peer. -func (s *Server) processMessage(msg *Message, peer Peer) error { - command := msg.commandType() - - rpcLogger.Printf("[NODE %d] :: IN :: %s :: %+v", peer.id(), command, msg.Payload) - - // Disconnect if the remote is sending messages other then version - // if we didn't verack this peer. - if !peer.verack() && command != cmdVersion { - return errors.New("version noack") - } - - switch command { - case cmdVersion: - return s.handleVersionCmd(msg.Payload.(*payload.Version), peer) - case cmdVerack: - case cmdGetAddr: - // return s.handleGetAddrCmd(msg, peer) - case cmdAddr: - return s.handleAddrCmd(msg.Payload.(*payload.AddressList), peer) - case cmdGetHeaders: - case cmdHeaders: - case cmdGetBlocks: - case cmdInv: - return s.handleInvCmd(msg.Payload.(*payload.Inventory), peer) - case cmdGetData: - case cmdBlock: - return s.handleBlockCmd(msg.Payload.(*core.Block), peer) - case cmdTX: - case cmdConsensus: - default: - return fmt.Errorf("invalid RPC command received: %s", command) - } - - return nil -} - // When a new peer is connected we send our version. // No further communication should be made before both sides has received // the versions of eachother. -func (s *Server) handlePeerConnected(peer Peer) { - // TODO get heigth of block when thats implemented. +func (s *Server) handlePeerConnected(p Peer) { + // TODO: get the blockheight of this server once core implemented this. payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay) msg := newMessage(s.net, cmdVersion, payload) - - peer.send(msg) + p.callVersion(msg) } -// Version declares the server's version. -func (s *Server) handleVersionCmd(v *payload.Version, peer Peer) error { - if s.id == v.Nonce { - return errors.New("remote nonce equal to server id") - } - - if peer.endpoint().Port != v.Port { - return errors.New("port mismatch") - } - - // we respond with a verack, we successfully received peer's version - // at this point. - peer.verify(v.Nonce) - verackMsg := newMessage(s.net, cmdVerack, nil) - peer.send(verackMsg) - - go s.sendLoop(peer) - - return nil +type versionTuple struct { + peer Peer + request *payload.Version + response chan *Message } -// When the remote node reveals its known peers we try to connect to all of them. -func (s *Server) handleAddrCmd(addrList *payload.AddressList, peer Peer) error { - for _, addr := range addrList.Addrs { - if !s.peerAlreadyConnected(addr.Addr) { - go connectToRemoteNode(s, addr.Addr.String()) - } - } - return nil -} - -func (s *Server) handleInvCmd(inv *payload.Inventory, peer Peer) error { - if !inv.Type.Valid() { - return fmt.Errorf("invalid inventory type: %s", inv.Type) - } - if len(inv.Hashes) == 0 { - return nil +func (s *Server) handleVersionCmd(msg *Message, p Peer) *Message { + t := versionTuple{ + peer: p, + request: msg.Payload.(*payload.Version), + response: make(chan *Message), } - payload := payload.NewInventory(inv.Type, inv.Hashes) - msg := newMessage(s.net, cmdGetData, payload) + s.versionCh <- t - peer.send(msg) - - return nil + return <-t.response } -func (s *Server) handleBlockCmd(block *core.Block, peer Peer) error { - fmt.Println("Block received") - fmt.Printf("%+v\n", block) - return nil +type getaddrTuple struct { + peer Peer + request *Message + response chan *Message } +func (s *Server) handleGetaddrCmd(msg *Message, p Peer) *Message { + t := getaddrTuple{ + peer: p, + request: msg, + response: make(chan *Message), + } + + s.getaddrCh <- t + + return <-t.response +} + +type invTuple struct { + peer Peer + request *payload.Inventory + response chan *Message +} + +func (s *Server) handleInvCmd(msg *Message, p Peer) *Message { + t := invTuple{ + request: msg.Payload.(*payload.Inventory), + response: make(chan *Message), + } + + s.invCh <- t + + return <-t.response +} + +type addrTuple struct { + request *payload.AddressList + response chan bool +} + +func (s *Server) handleAddrCmd(msg *Message, p Peer) bool { + t := addrTuple{ + request: msg.Payload.(*payload.AddressList), + response: make(chan bool), + } + + s.addrCh <- t + + return <-t.response +} + +// check if the addr is already connected to the server. func (s *Server) peerAlreadyConnected(addr net.Addr) bool { - // TODO: check for race conditions - //s.mtx.RLock() - //defer s.mtx.RUnlock() - - // What about ourself ^^ - for peer := range s.peers { - if peer.endpoint().String() == addr.String() { + if peer.addr().String() == addr.String() { return true } } return false } -// After receiving the "getaddr" the server needs to respond with an "addr" message. -// providing information about the other nodes in the network. -// e.g. this server's connected peers. -func (s *Server) handleGetAddrCmd(msg *Message, peer *Peer) error { - // TODO - return nil -} - func (s *Server) sendLoop(peer Peer) { // TODO: check if this peer is still connected. for { getaddrMsg := newMessage(s.net, cmdGetAddr, nil) - peer.send(getaddrMsg) + peer.callGetaddr(getaddrMsg) time.Sleep(120 * time.Second) } diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 5445c3168..ad07d93e1 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -2,25 +2,39 @@ package network import ( "testing" + + "github.com/anthdm/neo-go/pkg/network/payload" ) func TestHandleVersion(t *testing.T) { - // s := NewServer(ModeDevNet) - // go s.Start(":3000", nil) + s := NewServer(ModeDevNet) + go s.loop() - // p := NewLocalPeer() - // s.register <- p + p := NewLocalPeer(s) - // version := payload.NewVersion(1337, p.endpoint().Port, "/NEO:0.0.0/.", 0, true) - // s.handleVersionCmd(version, p) + version := payload.NewVersion(1337, p.addr().Port, "/NEO:0.0.0/", 0, true) + msg := newMessage(ModeDevNet, cmdVersion, version) - // if len(s.peers) != 1 { - // t.Fatalf("expecting the server to have %d peers got %d", 1, len(s.peers)) - // } - // if p.id() != 1337 { - // t.Fatalf("expecting peer's id to be %d got %d", 1337, p._id) - // } - // if !p.verack() { - // t.Fatal("expecting peer to be verified") - // } + resp := s.handleVersionCmd(msg, p) + if resp.commandType() != cmdVerack { + t.Fatalf("expected response message to be verack got %s", resp.commandType()) + } + if resp.Payload != nil { + t.Fatal("verack payload should be nil") + } +} + +func TestHandleAddrCmd(t *testing.T) { + // todo +} + +func TestHandleGetAddrCmd(t *testing.T) { + // todo +} + +func TestHandleInv(t *testing.T) { + // todo +} +func TestHandleBlockCmd(t *testing.T) { + // todo } diff --git a/pkg/network/tcp.go b/pkg/network/tcp.go index 3a8db8501..f9ffc39e0 100644 --- a/pkg/network/tcp.go +++ b/pkg/network/tcp.go @@ -1,8 +1,10 @@ package network import ( - "io" + "bytes" "net" + + "github.com/anthdm/neo-go/pkg/network/payload" ) func listenTCP(s *Server, port string) error { @@ -31,7 +33,6 @@ func connectToRemoteNode(s *Server, address string) { } return } - s.logger.Printf("connected to %s", conn.RemoteAddr()) go handleConnection(s, conn) } @@ -42,7 +43,7 @@ func connectToSeeds(s *Server, addrs []string) { } func handleConnection(s *Server, conn net.Conn) { - peer := NewTCPPeer(conn) + peer := NewTCPPeer(conn, s) s.register <- peer // remove the peer from connected peers and cleanup the connection. @@ -54,20 +55,51 @@ func handleConnection(s *Server, conn net.Conn) { // Start a goroutine that will handle all writes to the registered peer. go peer.writeLoop() - // Read from the connection and decode it into an RPCMessage and - // tell the server there is message available for proccesing. + // Read from the connection and decode it into a Message ready for processing. + buf := make([]byte, 1024) for { - msg := &Message{} - if err := msg.decode(conn); err != nil { - // remote connection probably closed. - if err == io.EOF { - s.logger.Printf("conn read error: %s", err) - break - } - // remove this node on any decode errors. - s.logger.Printf("RPC :: decode error %s", err) + _, err := conn.Read(buf) + if err != nil { + s.logger.Printf("conn read error: %s", err) break } - s.message <- messageTuple{peer, msg} + + msg := &Message{} + if err := msg.decode(bytes.NewReader(buf)); err != nil { + s.logger.Printf("decode error %s", err) + break + } + handleMessage(msg, s, peer) + } +} + +func handleMessage(msg *Message, s *Server, p *TCPPeer) { + command := msg.commandType() + + s.logger.Printf("%d :: IN :: %s :: %v", p.id(), command, msg) + + switch command { + case cmdVersion: + resp := s.handleVersionCmd(msg, p) + p.isVerack = true + p.nonce = msg.Payload.(*payload.Version).Nonce + p.send <- resp + case cmdAddr: + s.handleAddrCmd(msg, p) + case cmdGetAddr: + s.handleGetaddrCmd(msg, p) + case cmdInv: + resp := s.handleInvCmd(msg, p) + p.send <- resp + case cmdBlock: + case cmdConsensus: + case cmdTX: + case cmdVerack: + go s.sendLoop(p) + case cmdGetHeaders: + case cmdGetBlocks: + case cmdGetData: + case cmdHeaders: + default: } }