Refactor version msg

This commit is contained in:
anthdm 2018-01-29 19:17:49 +01:00
parent 6f08d967ba
commit d4a96267c6
6 changed files with 109 additions and 99 deletions

View file

@ -2,66 +2,21 @@ package payload
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"strconv"
"strings" "github.com/anthdm/neo-go/pkg/util"
) )
// Endpoint host + port of a node, compatible with net.Addr.
type Endpoint struct {
IP [16]byte // TODO: make a uint128 type
Port uint16
}
// EndpointFromString returns an Endpoint from the given string.
// For now this only handles the most simple hostport form.
// e.g. 127.0.0.1:3000
// This should be enough to work with for now.
func EndpointFromString(s string) (Endpoint, error) {
hostPort := strings.Split(s, ":")
if len(hostPort) != 2 {
return Endpoint{}, fmt.Errorf("invalid address string: %s", s)
}
host := hostPort[0]
port := hostPort[1]
ch := strings.Split(host, ".")
buf := [16]byte{}
var n int
for i := 0; i < len(ch); i++ {
n = 12 + i
nn, _ := strconv.Atoi(ch[i])
buf[n] = byte(nn)
}
p, _ := strconv.Atoi(port)
return Endpoint{buf, uint16(p)}, nil
}
// Network implements the net.Addr interface.
func (e Endpoint) Network() string { return "tcp" }
// String implements the net.Addr interface.
func (e Endpoint) String() string {
b := make([]uint8, 4)
for i := 0; i < 4; i++ {
b[i] = byte(e.IP[len(e.IP)-4+i])
}
return fmt.Sprintf("%d.%d.%d.%d:%d", b[0], b[1], b[2], b[3], e.Port)
}
// AddrWithTime payload // AddrWithTime payload
type AddrWithTime struct { type AddrWithTime struct {
// Timestamp the node connected to the network. // Timestamp the node connected to the network.
Timestamp uint32 Timestamp uint32
Services uint64 Services uint64
Addr Endpoint Addr util.Endpoint
} }
func NewAddrWithTime(addr Endpoint) *AddrWithTime { // NewAddrWithTime return a pointer to AddrWithTime.
func NewAddrWithTime(addr util.Endpoint) *AddrWithTime {
return &AddrWithTime{ return &AddrWithTime{
Timestamp: 1337, Timestamp: 1337,
Services: 1, Services: 1,

View file

@ -5,10 +5,12 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
"github.com/anthdm/neo-go/pkg/util"
) )
func TestEncodeDecodeAddr(t *testing.T) { func TestEncodeDecodeAddr(t *testing.T) {
e, err := EndpointFromString("127.0.0.1:2000") e, err := util.EndpointFromString("127.0.0.1:2000")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -33,7 +35,7 @@ func TestEncodeDecodeAddressList(t *testing.T) {
var lenList uint8 = 4 var lenList uint8 = 4
addrs := make([]*AddrWithTime, lenList) addrs := make([]*AddrWithTime, lenList)
for i := 0; i < int(lenList); i++ { for i := 0; i < int(lenList); i++ {
e, _ := EndpointFromString(fmt.Sprintf("127.0.0.1:200%d", i)) e, _ := util.EndpointFromString(fmt.Sprintf("127.0.0.1:200%d", i))
addrs[i] = NewAddrWithTime(e) addrs[i] = NewAddrWithTime(e)
} }

View file

@ -3,6 +3,8 @@ package network
import ( import (
"log" "log"
"net" "net"
"github.com/anthdm/neo-go/pkg/util"
) )
// Peer represents a remote node, backed by TCP transport. // Peer represents a remote node, backed by TCP transport.
@ -10,20 +12,22 @@ type Peer struct {
id uint32 id uint32
// underlying TCP connection // underlying TCP connection
conn net.Conn conn net.Conn
// channel to coordinate message writes back to the connection. // host and port information about this peer.
endpoint util.Endpoint
// channel to coordinate messages writen back to the connection.
send chan *Message send chan *Message
// verack is true if this node has sended it's version. // whether this peers version was acknowledged.
verack bool verack bool
// whether we or him made the initial connection.
initiator bool
} }
// NewPeer returns a (TCP) Peer. // NewPeer returns a (TCP) Peer.
func NewPeer(conn net.Conn, init bool) *Peer { func NewPeer(conn net.Conn) *Peer {
e, _ := util.EndpointFromString(conn.RemoteAddr().String())
return &Peer{ return &Peer{
conn: conn, conn: conn,
send: make(chan *Message), send: make(chan *Message),
initiator: init, endpoint: e,
} }
} }

View file

@ -78,8 +78,6 @@ func NewServer(net NetMode) *Server {
} }
s := &Server{ s := &Server{
// It is important to have this user agent correct. Otherwise we will get
// disconnected.
id: util.RandUint32(1111111, 9999999), id: util.RandUint32(1111111, 9999999),
userAgent: fmt.Sprintf("\v/NEO:%s/", version), userAgent: fmt.Sprintf("\v/NEO:%s/", version),
logger: logger, logger: logger,
@ -132,10 +130,6 @@ func (s *Server) shutdown() {
} }
} }
func (s *Server) disconnect(p *Peer) {
s.unregister <- p
}
func (s *Server) loop() { func (s *Server) loop() {
for { for {
select { select {
@ -144,18 +138,9 @@ func (s *Server) loop() {
// its peer will be received on this channel. // its peer will be received on this channel.
// Any peer registration must happen via this channel. // Any peer registration must happen via this channel.
s.logger.Printf("peer registered from address %s", peer.conn.RemoteAddr()) s.logger.Printf("peer registered from address %s", peer.conn.RemoteAddr())
s.peers[peer] = true s.peers[peer] = true
s.handlePeerConnected(peer)
// Only respond with a version message if the peer initiated the connection.
if peer.initiator {
resp, err := s.handlePeerConnected()
if err != nil {
s.logger.Fatalf("handling initial peer connection failed: %s", err)
} else {
peer.send <- resp
}
}
case peer := <-s.unregister: case peer := <-s.unregister:
// unregister should take care of all the cleanup that has to be made. // unregister should take care of all the cleanup that has to be made.
if _, ok := s.peers[peer]; ok { if _, ok := s.peers[peer]; ok {
@ -164,25 +149,35 @@ func (s *Server) loop() {
delete(s.peers, peer) delete(s.peers, peer)
s.logger.Printf("peer %s disconnected", peer.conn.RemoteAddr()) s.logger.Printf("peer %s disconnected", peer.conn.RemoteAddr())
} }
case tuple := <-s.message: case tuple := <-s.message:
// When a remote node sends data over its connection it will be received // When a remote node sends data over its connection it will be received
// on this channel. // on this channel.
// All errors encountered should be return and handled here.
if err := s.processMessage(tuple.msg, tuple.peer); err != nil { if err := s.processMessage(tuple.msg, tuple.peer); err != nil {
s.logger.Fatalf("failed to process message: %s", err) s.logger.Fatalf("failed to process message: %s", err)
s.disconnect(tuple.peer) s.unregister <- tuple.peer
} }
case <-s.quit: case <-s.quit:
s.shutdown() s.shutdown()
} }
} }
} }
// TODO: unregister peers on error. // processMessage processes the message received from the peer.
// processMessage processes the received message from a remote node.
func (s *Server) processMessage(msg *Message, peer *Peer) error { func (s *Server) processMessage(msg *Message, peer *Peer) error {
rpcLogger.Printf("[NODE %d] :: IN :: %s :: %+v", peer.id, msg.commandType(), msg.Payload) command := msg.commandType()
switch msg.commandType() { rpcLogger.Printf("[NODE %d] :: IN :: %s :: %+v", peer.id, command, msg.Payload)
// Disconnect if the remote is sending messages other then version
// if we didn't verack this peer.
if !peer.verack && command != cmdVersion {
return errors.New("version noack")
}
switch command {
case cmdVersion: case cmdVersion:
return s.handleVersionCmd(msg.Payload.(*payload.Version), peer) return s.handleVersionCmd(msg.Payload.(*payload.Version), peer)
case cmdVerack: case cmdVerack:
@ -198,29 +193,31 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error {
case cmdBlock: case cmdBlock:
case cmdTX: case cmdTX:
default: default:
return errors.New("invalid RPC command received: " + string(msg.commandType())) return fmt.Errorf("invalid RPC command received: %s", command)
} }
return nil return nil
} }
// When a new peer is connected we respond with the version command. // When a new peer is connected we send our version.
// No further communication should been made before both sides has received // No further communication should be made before both sides has received
// the version of eachother. // the versions of eachother.
func (s *Server) handlePeerConnected() (*Message, error) { func (s *Server) handlePeerConnected(peer *Peer) {
// TODO get heigth of block when thats implemented.
payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay) payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay)
msg := newMessage(s.net, cmdVersion, payload) msg := newMessage(s.net, cmdVersion, payload)
return msg, nil
peer.send <- msg
} }
// Version declares the server's version. // Version declares the server's version.
func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error { func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error {
// TODO: check version and verify to trust that node. if s.id == v.Nonce {
return errors.New("remote nonce equal to server id")
payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay) }
// we respond with our version. if peer.endpoint.Port != v.Port {
versionMsg := newMessage(s.net, cmdVersion, payload) return errors.New("port mismatch")
peer.send <- versionMsg }
// we respond with a verack, we successfully received peer's version // we respond with a verack, we successfully received peer's version
// at this point. // at this point.
@ -229,7 +226,7 @@ func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error {
verackMsg := newMessage(s.net, cmdVerack, nil) verackMsg := newMessage(s.net, cmdVerack, nil)
peer.send <- verackMsg peer.send <- verackMsg
go s.startProtocol(peer) go s.sendLoop(peer)
return nil return nil
} }
@ -267,7 +264,7 @@ func (s *Server) handleGetAddrCmd(msg *Message, peer *Peer) error {
return nil return nil
} }
func (s *Server) startProtocol(peer *Peer) { func (s *Server) sendLoop(peer *Peer) {
// TODO: check if this peer is still connected. // TODO: check if this peer is still connected.
for { for {
getaddrMsg := newMessage(s.net, cmdGetAddr, nil) getaddrMsg := newMessage(s.net, cmdGetAddr, nil)

View file

@ -16,7 +16,7 @@ func listenTCP(s *Server, port string) error {
if err != nil { if err != nil {
return err return err
} }
go handleConnection(s, conn, true) go handleConnection(s, conn)
} }
} }
@ -30,7 +30,7 @@ func connectToRemoteNode(s *Server, address string) {
return return
} }
s.logger.Printf("connected to %s", conn.RemoteAddr()) s.logger.Printf("connected to %s", conn.RemoteAddr())
go handleConnection(s, conn, false) go handleConnection(s, conn)
} }
func connectToSeeds(s *Server, addrs []string) { func connectToSeeds(s *Server, addrs []string) {
@ -39,8 +39,8 @@ func connectToSeeds(s *Server, addrs []string) {
} }
} }
func handleConnection(s *Server, conn net.Conn, initiated bool) { func handleConnection(s *Server, conn net.Conn) {
peer := NewPeer(conn, initiated) peer := NewPeer(conn)
s.register <- peer s.register <- peer
// remove the peer from connected peers and cleanup the connection. // remove the peer from connected peers and cleanup the connection.

52
pkg/util/endpoint.go Normal file
View file

@ -0,0 +1,52 @@
package util
import (
"fmt"
"strconv"
"strings"
)
// Endpoint host + port of a node, compatible with net.Addr.
type Endpoint struct {
IP [16]byte // TODO: make a uint128 type
Port uint16
}
// EndpointFromString returns an Endpoint from the given string.
// For now this only handles the most simple hostport form.
// e.g. 127.0.0.1:3000
// This should be enough to work with for now.
func EndpointFromString(s string) (Endpoint, error) {
hostPort := strings.Split(s, ":")
if len(hostPort) != 2 {
return Endpoint{}, fmt.Errorf("invalid address string: %s", s)
}
host := hostPort[0]
port := hostPort[1]
ch := strings.Split(host, ".")
buf := [16]byte{}
var n int
for i := 0; i < len(ch); i++ {
n = 12 + i
nn, _ := strconv.Atoi(ch[i])
buf[n] = byte(nn)
}
p, _ := strconv.Atoi(port)
return Endpoint{buf, uint16(p)}, nil
}
// Network implements the net.Addr interface.
func (e Endpoint) Network() string { return "tcp" }
// String implements the net.Addr interface.
func (e Endpoint) String() string {
b := make([]uint8, 4)
for i := 0; i < 4; i++ {
b[i] = byte(e.IP[len(e.IP)-4+i])
}
return fmt.Sprintf("%d.%d.%d.%d:%d", b[0], b[1], b[2], b[3], e.Port)
}