diff --git a/pkg/network/peer.go b/pkg/network/peer.go index af38fe142..025f4b7d8 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -1,8 +1,6 @@ package network import ( - "net" - "github.com/anthdm/neo-go/pkg/util" ) @@ -44,81 +42,3 @@ func (p *LocalPeer) id() uint32 { return p.nonce } func (p *LocalPeer) verack() bool { return p.isVerack } func (p *LocalPeer) addr() util.Endpoint { return p.endpoint } func (p *LocalPeer) disconnect() {} - -// TCPPeer represents a remote node, backed by TCP transport. -type TCPPeer struct { - s *Server - // nonce (id) of the peer. - nonce uint32 - // underlying TCP connection - conn net.Conn - // host and port information about this peer. - endpoint util.Endpoint - // channel to coordinate messages writen back to the connection. - send chan *Message - // whether this peers version was acknowledged. - isVerack bool -} - -// NewTCPPeer returns a pointer to a TCP Peer. -func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { - e, _ := util.EndpointFromString(conn.RemoteAddr().String()) - - return &TCPPeer{ - conn: conn, - send: make(chan *Message), - endpoint: e, - s: s, - } -} - -func (p *TCPPeer) callVersion(msg *Message) { - p.send <- msg -} - -// id implements the peer interface -func (p *TCPPeer) id() uint32 { - return p.nonce -} - -// endpoint implements the peer interface -func (p *TCPPeer) addr() util.Endpoint { - return p.endpoint -} - -// verack implements the peer interface -func (p *TCPPeer) verack() bool { - return p.isVerack -} - -// callGetaddr will send the "getaddr" command to the remote. -func (p *TCPPeer) callGetaddr(msg *Message) { - p.send <- msg -} - -func (p *TCPPeer) disconnect() { - close(p.send) - p.conn.Close() -} - -// writeLoop writes messages to the underlying TCP connection. -// A goroutine writeLoop is started for each connection. -// There should be at most one writer to a connection executing -// all writes from this goroutine. -func (p *TCPPeer) writeLoop() { - // clean up the connection. - defer func() { - p.conn.Close() - }() - - for { - msg := <-p.send - - p.s.logger.Printf("OUT :: %s :: %+v", msg.commandType(), msg.Payload) - - // should we disconnect here? - if err := msg.encode(p.conn); err != nil { - p.s.logger.Printf("encode error: %s", err) - } - } -} diff --git a/pkg/network/server.go b/pkg/network/server.go index 36dd70fa9..7e73c57d7 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -53,12 +53,6 @@ type Server struct { relay bool // TCP listener of the server listener net.Listener - - // RPC channels - versionCh chan versionTuple - getaddrCh chan getaddrTuple - invCh chan invTuple - addrCh chan addrTuple } // NewServer returns a pointer to a new server. @@ -80,10 +74,6 @@ func NewServer(net NetMode) *Server { relay: true, // currently relay is not handled. net: net, quit: make(chan struct{}), - versionCh: make(chan versionTuple), - getaddrCh: make(chan getaddrTuple), - invCh: make(chan invTuple), - addrCh: make(chan addrTuple), } return s @@ -147,45 +137,6 @@ func (s *Server) loop() { s.logger.Printf("peer %s disconnected", peer.addr()) } - // Process the received version and respond with a verack. - case t := <-s.versionCh: - if s.id == t.request.Nonce { - t.peer.disconnect() - } - if t.peer.addr().Port != t.request.Port { - t.peer.disconnect() - } - t.response <- newMessage(ModeDevNet, cmdVerack, nil) - - // Process the getaddr cmd. - case t := <-s.getaddrCh: - t.response <- &Message{} // just for now. - - // Process the addr cmd. Register peer will handle the maxPeers connected. - case t := <-s.addrCh: - for _, addr := range t.request.Addrs { - if !s.peerAlreadyConnected(addr.Addr) { - // TODO: this is not transport abstracted. - go connectToRemoteNode(s, addr.Addr.String()) - } - } - t.response <- true - - // Process inventories cmd. - case t := <-s.invCh: - if !t.request.Type.Valid() { - t.peer.disconnect() - break - } - if len(t.request.Hashes) == 0 { - t.peer.disconnect() - break - } - - payload := payload.NewInventory(t.request.Type, t.request.Hashes) - msg := newMessage(s.net, cmdGetData, payload) - t.response <- msg - case <-s.quit: s.shutdown() } @@ -202,73 +153,47 @@ func (s *Server) handlePeerConnected(p Peer) { p.callVersion(msg) } -type versionTuple struct { - peer Peer - request *payload.Version - response chan *Message -} - func (s *Server) handleVersionCmd(msg *Message, p Peer) *Message { - t := versionTuple{ - peer: p, - request: msg.Payload.(*payload.Version), - response: make(chan *Message), + version := msg.Payload.(*payload.Version) + if s.id == version.Nonce { + p.disconnect() + return nil } - - s.versionCh <- t - - return <-t.response -} - -type getaddrTuple struct { - peer Peer - request *Message - response chan *Message + if p.addr().Port != version.Port { + p.disconnect() + return nil + } + return newMessage(ModeDevNet, cmdVerack, nil) } func (s *Server) handleGetaddrCmd(msg *Message, p Peer) *Message { - t := getaddrTuple{ - peer: p, - request: msg, - response: make(chan *Message), - } - - s.getaddrCh <- t - - return <-t.response -} - -type invTuple struct { - peer Peer - request *payload.Inventory - response chan *Message + return nil } func (s *Server) handleInvCmd(msg *Message, p Peer) *Message { - t := invTuple{ - request: msg.Payload.(*payload.Inventory), - response: make(chan *Message), + inv := msg.Payload.(*payload.Inventory) + if !inv.Type.Valid() { + p.disconnect() + return nil + } + if len(inv.Hashes) == 0 { + p.disconnect() + return nil } - s.invCh <- t - - return <-t.response + payload := payload.NewInventory(inv.Type, inv.Hashes) + resp := newMessage(s.net, cmdGetData, payload) + return resp } -type addrTuple struct { - request *payload.AddressList - response chan bool -} - -func (s *Server) handleAddrCmd(msg *Message, p Peer) bool { - t := addrTuple{ - request: msg.Payload.(*payload.AddressList), - response: make(chan bool), +func (s *Server) handleAddrCmd(msg *Message, p Peer) { + addrList := msg.Payload.(*payload.AddressList) + for _, addr := range addrList.Addrs { + if !s.peerAlreadyConnected(addr.Addr) { + // TODO: this is not transport abstracted. + go connectToRemoteNode(s, addr.Addr.String()) + } } - - s.addrCh <- t - - return <-t.response } // check if the addr is already connected to the server. diff --git a/pkg/network/tcp.go b/pkg/network/tcp.go index f9ffc39e0..e8421265e 100644 --- a/pkg/network/tcp.go +++ b/pkg/network/tcp.go @@ -5,6 +5,7 @@ import ( "net" "github.com/anthdm/neo-go/pkg/network/payload" + "github.com/anthdm/neo-go/pkg/util" ) func listenTCP(s *Server, port string) error { @@ -76,7 +77,7 @@ func handleConnection(s *Server, conn net.Conn) { func handleMessage(msg *Message, s *Server, p *TCPPeer) { command := msg.commandType() - s.logger.Printf("%d :: IN :: %s :: %v", p.id(), command, msg) + s.logger.Printf("IN :: %d :: %s :: %v", p.id(), command, msg) switch command { case cmdVersion: @@ -103,3 +104,82 @@ func handleMessage(msg *Message, s *Server, p *TCPPeer) { default: } } + +// TCPPeer represents a remote node, backed by TCP transport. +type TCPPeer struct { + s *Server + // nonce (id) of the peer. + nonce uint32 + // underlying TCP connection + conn net.Conn + // host and port information about this peer. + endpoint util.Endpoint + // channel to coordinate messages writen back to the connection. + send chan *Message + // whether this peers version was acknowledged. + isVerack bool +} + +// NewTCPPeer returns a pointer to a TCP Peer. +func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { + e, _ := util.EndpointFromString(conn.RemoteAddr().String()) + + return &TCPPeer{ + conn: conn, + send: make(chan *Message), + endpoint: e, + s: s, + } +} + +func (p *TCPPeer) callVersion(msg *Message) { + p.send <- msg +} + +// id implements the peer interface +func (p *TCPPeer) id() uint32 { + return p.nonce +} + +// endpoint implements the peer interface +func (p *TCPPeer) addr() util.Endpoint { + return p.endpoint +} + +// verack implements the peer interface +func (p *TCPPeer) verack() bool { + return p.isVerack +} + +// callGetaddr will send the "getaddr" command to the remote. +func (p *TCPPeer) callGetaddr(msg *Message) { + p.send <- msg +} + +// disconnect closes the send channel and the underlying connection. +func (p *TCPPeer) disconnect() { + close(p.send) + p.conn.Close() +} + +// writeLoop writes messages to the underlying TCP connection. +// A goroutine writeLoop is started for each connection. +// There should be at most one writer to a connection executing +// all writes from this goroutine. +func (p *TCPPeer) writeLoop() { + // clean up the connection. + defer func() { + p.conn.Close() + }() + + for { + msg := <-p.send + + p.s.logger.Printf("OUT :: %s :: %+v", msg.commandType(), msg.Payload) + + // should we disconnect here? + if err := msg.encode(p.conn); err != nil { + p.s.logger.Printf("encode error: %s", err) + } + } +}