diff --git a/pkg/network/payload/addr.go b/pkg/network/payload/addr.go index 880af540d..0403161ec 100644 --- a/pkg/network/payload/addr.go +++ b/pkg/network/payload/addr.go @@ -2,66 +2,21 @@ package payload import ( "encoding/binary" - "fmt" "io" - "strconv" - "strings" + + "github.com/anthdm/neo-go/pkg/util" ) -// Endpoint host + port of a node, compatible with net.Addr. -type Endpoint struct { - IP [16]byte // TODO: make a uint128 type - Port uint16 -} - -// EndpointFromString returns an Endpoint from the given string. -// For now this only handles the most simple hostport form. -// e.g. 127.0.0.1:3000 -// This should be enough to work with for now. -func EndpointFromString(s string) (Endpoint, error) { - hostPort := strings.Split(s, ":") - if len(hostPort) != 2 { - return Endpoint{}, fmt.Errorf("invalid address string: %s", s) - } - host := hostPort[0] - port := hostPort[1] - - ch := strings.Split(host, ".") - - buf := [16]byte{} - var n int - for i := 0; i < len(ch); i++ { - n = 12 + i - nn, _ := strconv.Atoi(ch[i]) - buf[n] = byte(nn) - } - - p, _ := strconv.Atoi(port) - - return Endpoint{buf, uint16(p)}, nil -} - -// Network implements the net.Addr interface. -func (e Endpoint) Network() string { return "tcp" } - -// String implements the net.Addr interface. -func (e Endpoint) String() string { - b := make([]uint8, 4) - for i := 0; i < 4; i++ { - b[i] = byte(e.IP[len(e.IP)-4+i]) - } - return fmt.Sprintf("%d.%d.%d.%d:%d", b[0], b[1], b[2], b[3], e.Port) -} - // AddrWithTime payload type AddrWithTime struct { // Timestamp the node connected to the network. Timestamp uint32 Services uint64 - Addr Endpoint + Addr util.Endpoint } -func NewAddrWithTime(addr Endpoint) *AddrWithTime { +// NewAddrWithTime return a pointer to AddrWithTime. +func NewAddrWithTime(addr util.Endpoint) *AddrWithTime { return &AddrWithTime{ Timestamp: 1337, Services: 1, diff --git a/pkg/network/payload/addr_test.go b/pkg/network/payload/addr_test.go index 35538aa33..c021b72df 100644 --- a/pkg/network/payload/addr_test.go +++ b/pkg/network/payload/addr_test.go @@ -5,10 +5,12 @@ import ( "fmt" "reflect" "testing" + + "github.com/anthdm/neo-go/pkg/util" ) func TestEncodeDecodeAddr(t *testing.T) { - e, err := EndpointFromString("127.0.0.1:2000") + e, err := util.EndpointFromString("127.0.0.1:2000") if err != nil { t.Fatal(err) } @@ -33,7 +35,7 @@ func TestEncodeDecodeAddressList(t *testing.T) { var lenList uint8 = 4 addrs := make([]*AddrWithTime, lenList) for i := 0; i < int(lenList); i++ { - e, _ := EndpointFromString(fmt.Sprintf("127.0.0.1:200%d", i)) + e, _ := util.EndpointFromString(fmt.Sprintf("127.0.0.1:200%d", i)) addrs[i] = NewAddrWithTime(e) } diff --git a/pkg/network/peer.go b/pkg/network/peer.go index d0c54dcb1..544b7cfe2 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -3,6 +3,8 @@ package network import ( "log" "net" + + "github.com/anthdm/neo-go/pkg/util" ) // Peer represents a remote node, backed by TCP transport. @@ -10,20 +12,22 @@ type Peer struct { id uint32 // underlying TCP connection conn net.Conn - // channel to coordinate message writes back to the connection. + // host and port information about this peer. + endpoint util.Endpoint + // channel to coordinate messages writen back to the connection. send chan *Message - // verack is true if this node has sended it's version. + // whether this peers version was acknowledged. verack bool - // whether we or him made the initial connection. - initiator bool } // NewPeer returns a (TCP) Peer. -func NewPeer(conn net.Conn, init bool) *Peer { +func NewPeer(conn net.Conn) *Peer { + e, _ := util.EndpointFromString(conn.RemoteAddr().String()) + return &Peer{ - conn: conn, - send: make(chan *Message), - initiator: init, + conn: conn, + send: make(chan *Message), + endpoint: e, } } diff --git a/pkg/network/server.go b/pkg/network/server.go index 1bd29c1ad..4b2680431 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -78,8 +78,6 @@ func NewServer(net NetMode) *Server { } s := &Server{ - // It is important to have this user agent correct. Otherwise we will get - // disconnected. id: util.RandUint32(1111111, 9999999), userAgent: fmt.Sprintf("\v/NEO:%s/", version), logger: logger, @@ -132,10 +130,6 @@ func (s *Server) shutdown() { } } -func (s *Server) disconnect(p *Peer) { - s.unregister <- p -} - func (s *Server) loop() { for { select { @@ -144,18 +138,9 @@ func (s *Server) loop() { // its peer will be received on this channel. // Any peer registration must happen via this channel. s.logger.Printf("peer registered from address %s", peer.conn.RemoteAddr()) - s.peers[peer] = true + s.handlePeerConnected(peer) - // Only respond with a version message if the peer initiated the connection. - if peer.initiator { - resp, err := s.handlePeerConnected() - if err != nil { - s.logger.Fatalf("handling initial peer connection failed: %s", err) - } else { - peer.send <- resp - } - } case peer := <-s.unregister: // unregister should take care of all the cleanup that has to be made. if _, ok := s.peers[peer]; ok { @@ -164,25 +149,35 @@ func (s *Server) loop() { delete(s.peers, peer) s.logger.Printf("peer %s disconnected", peer.conn.RemoteAddr()) } + case tuple := <-s.message: // When a remote node sends data over its connection it will be received // on this channel. + // All errors encountered should be return and handled here. if err := s.processMessage(tuple.msg, tuple.peer); err != nil { s.logger.Fatalf("failed to process message: %s", err) - s.disconnect(tuple.peer) + s.unregister <- tuple.peer } + case <-s.quit: s.shutdown() } } } -// TODO: unregister peers on error. -// processMessage processes the received message from a remote node. +// processMessage processes the message received from the peer. func (s *Server) processMessage(msg *Message, peer *Peer) error { - rpcLogger.Printf("[NODE %d] :: IN :: %s :: %+v", peer.id, msg.commandType(), msg.Payload) + command := msg.commandType() - switch msg.commandType() { + rpcLogger.Printf("[NODE %d] :: IN :: %s :: %+v", peer.id, command, msg.Payload) + + // Disconnect if the remote is sending messages other then version + // if we didn't verack this peer. + if !peer.verack && command != cmdVersion { + return errors.New("version noack") + } + + switch command { case cmdVersion: return s.handleVersionCmd(msg.Payload.(*payload.Version), peer) case cmdVerack: @@ -198,29 +193,31 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error { case cmdBlock: case cmdTX: default: - return errors.New("invalid RPC command received: " + string(msg.commandType())) + return fmt.Errorf("invalid RPC command received: %s", command) } return nil } -// When a new peer is connected we respond with the version command. -// No further communication should been made before both sides has received -// the version of eachother. -func (s *Server) handlePeerConnected() (*Message, error) { +// When a new peer is connected we send our version. +// No further communication should be made before both sides has received +// the versions of eachother. +func (s *Server) handlePeerConnected(peer *Peer) { + // TODO get heigth of block when thats implemented. payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay) msg := newMessage(s.net, cmdVersion, payload) - return msg, nil + + peer.send <- msg } // Version declares the server's version. func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error { - // TODO: check version and verify to trust that node. - - payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay) - // we respond with our version. - versionMsg := newMessage(s.net, cmdVersion, payload) - peer.send <- versionMsg + if s.id == v.Nonce { + return errors.New("remote nonce equal to server id") + } + if peer.endpoint.Port != v.Port { + return errors.New("port mismatch") + } // we respond with a verack, we successfully received peer's version // at this point. @@ -229,7 +226,7 @@ func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error { verackMsg := newMessage(s.net, cmdVerack, nil) peer.send <- verackMsg - go s.startProtocol(peer) + go s.sendLoop(peer) return nil } @@ -267,7 +264,7 @@ func (s *Server) handleGetAddrCmd(msg *Message, peer *Peer) error { return nil } -func (s *Server) startProtocol(peer *Peer) { +func (s *Server) sendLoop(peer *Peer) { // TODO: check if this peer is still connected. for { getaddrMsg := newMessage(s.net, cmdGetAddr, nil) diff --git a/pkg/network/tcp.go b/pkg/network/tcp.go index 457830f8b..e05cefbbb 100644 --- a/pkg/network/tcp.go +++ b/pkg/network/tcp.go @@ -16,7 +16,7 @@ func listenTCP(s *Server, port string) error { if err != nil { return err } - go handleConnection(s, conn, true) + go handleConnection(s, conn) } } @@ -30,7 +30,7 @@ func connectToRemoteNode(s *Server, address string) { return } s.logger.Printf("connected to %s", conn.RemoteAddr()) - go handleConnection(s, conn, false) + go handleConnection(s, conn) } func connectToSeeds(s *Server, addrs []string) { @@ -39,8 +39,8 @@ func connectToSeeds(s *Server, addrs []string) { } } -func handleConnection(s *Server, conn net.Conn, initiated bool) { - peer := NewPeer(conn, initiated) +func handleConnection(s *Server, conn net.Conn) { + peer := NewPeer(conn) s.register <- peer // remove the peer from connected peers and cleanup the connection. diff --git a/pkg/util/endpoint.go b/pkg/util/endpoint.go new file mode 100644 index 000000000..a6b5030d5 --- /dev/null +++ b/pkg/util/endpoint.go @@ -0,0 +1,52 @@ +package util + +import ( + "fmt" + "strconv" + "strings" +) + +// Endpoint host + port of a node, compatible with net.Addr. +type Endpoint struct { + IP [16]byte // TODO: make a uint128 type + Port uint16 +} + +// EndpointFromString returns an Endpoint from the given string. +// For now this only handles the most simple hostport form. +// e.g. 127.0.0.1:3000 +// This should be enough to work with for now. +func EndpointFromString(s string) (Endpoint, error) { + hostPort := strings.Split(s, ":") + if len(hostPort) != 2 { + return Endpoint{}, fmt.Errorf("invalid address string: %s", s) + } + host := hostPort[0] + port := hostPort[1] + + ch := strings.Split(host, ".") + + buf := [16]byte{} + var n int + for i := 0; i < len(ch); i++ { + n = 12 + i + nn, _ := strconv.Atoi(ch[i]) + buf[n] = byte(nn) + } + + p, _ := strconv.Atoi(port) + + return Endpoint{buf, uint16(p)}, nil +} + +// Network implements the net.Addr interface. +func (e Endpoint) Network() string { return "tcp" } + +// String implements the net.Addr interface. +func (e Endpoint) String() string { + b := make([]uint8, 4) + for i := 0; i < 4; i++ { + b[i] = byte(e.IP[len(e.IP)-4+i]) + } + return fmt.Sprintf("%d.%d.%d.%d:%d", b[0], b[1], b[2], b[3], e.Port) +}