protocol: move magic exchange to version payload

closes #889
This commit is contained in:
Anna Shaleva 2020-05-21 13:35:44 +03:00
parent 23b814ad4d
commit 64a2fb63e1
5 changed files with 46 additions and 43 deletions

View file

@ -3,7 +3,6 @@ package network
import ( import (
"fmt" "fmt"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/consensus"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
@ -15,9 +14,6 @@ import (
// Message is the complete message send between nodes. // Message is the complete message send between nodes.
type Message struct { type Message struct {
// NetMode of the node that sends this message.
Magic config.NetMode
// Command is byte command code. // Command is byte command code.
Command CommandType Command CommandType
@ -67,7 +63,7 @@ const (
) )
// NewMessage returns a new message with the given payload. // NewMessage returns a new message with the given payload.
func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Message { func NewMessage(cmd CommandType, p payload.Payload) *Message {
var ( var (
size uint32 size uint32
) )
@ -83,7 +79,6 @@ func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Messa
} }
return &Message{ return &Message{
Magic: magic,
Command: cmd, Command: cmd,
Length: size, Length: size,
Payload: p, Payload: p,
@ -92,7 +87,6 @@ func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Messa
// Decode decodes a Message from the given reader. // Decode decodes a Message from the given reader.
func (m *Message) Decode(br *io.BinReader) error { func (m *Message) Decode(br *io.BinReader) error {
m.Magic = config.NetMode(br.ReadU32LE())
m.Command = CommandType(br.ReadB()) m.Command = CommandType(br.ReadB())
m.Length = br.ReadU32LE() m.Length = br.ReadU32LE()
if br.Err != nil { if br.Err != nil {
@ -150,7 +144,6 @@ func (m *Message) decodePayload(br *io.BinReader) error {
// Encode encodes a Message to any given BinWriter. // Encode encodes a Message to any given BinWriter.
func (m *Message) Encode(br *io.BinWriter) error { func (m *Message) Encode(br *io.BinWriter) error {
br.WriteU32LE(uint32(m.Magic))
br.WriteB(byte(m.Command)) br.WriteB(byte(m.Command))
br.WriteU32LE(m.Length) br.WriteU32LE(m.Length)
if m.Payload != nil { if m.Payload != nil {

View file

@ -3,6 +3,7 @@ package payload
import ( import (
"time" "time"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
) )
@ -21,6 +22,8 @@ const (
// Version payload. // Version payload.
type Version struct { type Version struct {
// NetMode of the node
Magic config.NetMode
// currently the version of the protocol is 0 // currently the version of the protocol is 0
Version uint32 Version uint32
// currently 1 // currently 1
@ -40,8 +43,9 @@ type Version struct {
} }
// NewVersion returns a pointer to a Version payload. // NewVersion returns a pointer to a Version payload.
func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version { func NewVersion(magic config.NetMode, id uint32, p uint16, ua string, h uint32, r bool) *Version {
return &Version{ return &Version{
Magic: magic,
Version: 0, Version: 0,
Services: nodePeerService, Services: nodePeerService,
Timestamp: uint32(time.Now().UTC().Unix()), Timestamp: uint32(time.Now().UTC().Unix()),
@ -55,6 +59,7 @@ func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version {
// DecodeBinary implements Serializable interface. // DecodeBinary implements Serializable interface.
func (p *Version) DecodeBinary(br *io.BinReader) { func (p *Version) DecodeBinary(br *io.BinReader) {
p.Magic = config.NetMode(br.ReadU32LE())
p.Version = br.ReadU32LE() p.Version = br.ReadU32LE()
p.Services = br.ReadU64LE() p.Services = br.ReadU64LE()
p.Timestamp = br.ReadU32LE() p.Timestamp = br.ReadU32LE()
@ -67,6 +72,7 @@ func (p *Version) DecodeBinary(br *io.BinReader) {
// EncodeBinary implements Serializable interface. // EncodeBinary implements Serializable interface.
func (p *Version) EncodeBinary(br *io.BinWriter) { func (p *Version) EncodeBinary(br *io.BinWriter) {
br.WriteU32LE(uint32(p.Magic))
br.WriteU32LE(p.Version) br.WriteU32LE(p.Version)
br.WriteU64LE(p.Services) br.WriteU64LE(p.Services)
br.WriteU32LE(p.Timestamp) br.WriteU32LE(p.Timestamp)

View file

@ -3,18 +3,20 @@ package payload
import ( import (
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestVersionEncodeDecode(t *testing.T) { func TestVersionEncodeDecode(t *testing.T) {
var magic config.NetMode = 56753
var port uint16 = 3000 var port uint16 = 3000
var id uint32 = 13337 var id uint32 = 13337
useragent := "/NEO:0.0.1/" useragent := "/NEO:0.0.1/"
var height uint32 = 100500 var height uint32 = 100500
var relay = true var relay = true
version := NewVersion(id, port, useragent, height, relay) version := NewVersion(magic, id, port, useragent, height, relay)
versionDecoded := &Version{} versionDecoded := &Version{}
testserdes.EncodeDecodeBinary(t, version, versionDecoded) testserdes.EncodeDecodeBinary(t, version, versionDecoded)

View file

@ -157,12 +157,6 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo
return s, nil return s, nil
} }
// MkMsg creates a new message based on the server configured network and given
// parameters.
func (s *Server) MkMsg(cmd CommandType, p payload.Payload) *Message {
return NewMessage(s.Net, cmd, p)
}
// ID returns the servers ID. // ID returns the servers ID.
func (s *Server) ID() uint32 { func (s *Server) ID() uint32 {
return s.id return s.id
@ -230,7 +224,7 @@ func (s *Server) run() {
s.discovery.RequestRemote(s.AttemptConnPeers) s.discovery.RequestRemote(s.AttemptConnPeers)
} }
if s.discovery.PoolCount() < minPoolCount { if s.discovery.PoolCount() < minPoolCount {
s.broadcastHPMessage(s.MkMsg(CMDGetAddr, payload.NewNullPayload())) s.broadcastHPMessage(NewMessage(CMDGetAddr, payload.NewNullPayload()))
} }
select { select {
case <-s.quit: case <-s.quit:
@ -292,7 +286,7 @@ func (s *Server) runProto() {
if s.chain.BlockHeight() == prevHeight { if s.chain.BlockHeight() == prevHeight {
// Get a copy of s.peers to avoid holding a lock while sending. // Get a copy of s.peers to avoid holding a lock while sending.
for peer := range s.Peers() { for peer := range s.Peers() {
_ = peer.SendPing(s.MkMsg(CMDPing, payload.NewPing(s.id, s.chain.HeaderHeight()))) _ = peer.SendPing(NewMessage(CMDPing, payload.NewPing(s.id, s.chain.HeaderHeight())))
} }
} }
pingTimer.Reset(s.PingInterval) pingTimer.Reset(s.PingInterval)
@ -354,13 +348,14 @@ func (s *Server) HandshakedPeersCount() int {
// getVersionMsg returns current version message. // getVersionMsg returns current version message.
func (s *Server) getVersionMsg() *Message { func (s *Server) getVersionMsg() *Message {
payload := payload.NewVersion( payload := payload.NewVersion(
s.Net,
s.id, s.id,
s.Port, s.Port,
s.UserAgent, s.UserAgent,
s.chain.BlockHeight(), s.chain.BlockHeight(),
s.Relay, s.Relay,
) )
return s.MkMsg(CMDVersion, payload) return NewMessage(CMDVersion, payload)
} }
// IsInSync answers the question of whether the server is in sync with the // IsInSync answers the question of whether the server is in sync with the
@ -406,6 +401,11 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
if s.id == version.Nonce { if s.id == version.Nonce {
return errIdenticalID return errIdenticalID
} }
// Make sure both server and peer are operating on
// the same network.
if s.Net != version.Magic {
return errInvalidNetwork
}
peerAddr := p.PeerAddr().String() peerAddr := p.PeerAddr().String()
s.discovery.RegisterConnectedAddr(peerAddr) s.discovery.RegisterConnectedAddr(peerAddr)
s.lock.RLock() s.lock.RLock()
@ -421,7 +421,7 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
} }
} }
s.lock.RUnlock() s.lock.RUnlock()
return p.SendVersionAck(s.MkMsg(CMDVerack, nil)) return p.SendVersionAck(NewMessage(CMDVerack, nil))
} }
// handleHeadersCmd processes the headers received from its peer. // handleHeadersCmd processes the headers received from its peer.
@ -448,7 +448,7 @@ func (s *Server) handleBlockCmd(p Peer, block *block.Block) error {
// handlePing processes ping request. // handlePing processes ping request.
func (s *Server) handlePing(p Peer, ping *payload.Ping) error { func (s *Server) handlePing(p Peer, ping *payload.Ping) error {
return p.EnqueueP2PMessage(s.MkMsg(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id))) return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id)))
} }
// handlePing processes pong request. // handlePing processes pong request.
@ -482,7 +482,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
} }
} }
if len(reqHashes) > 0 { if len(reqHashes) > 0 {
msg := s.MkMsg(CMDGetData, payload.NewInventory(inv.Type, reqHashes)) msg := NewMessage(CMDGetData, payload.NewInventory(inv.Type, reqHashes))
pkt, err := msg.Bytes() pkt, err := msg.Bytes()
if err != nil { if err != nil {
return err return err
@ -504,16 +504,16 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
case payload.TXType: case payload.TXType:
tx, _, err := s.chain.GetTransaction(hash) tx, _, err := s.chain.GetTransaction(hash)
if err == nil { if err == nil {
msg = s.MkMsg(CMDTX, tx) msg = NewMessage(CMDTX, tx)
} }
case payload.BlockType: case payload.BlockType:
b, err := s.chain.GetBlock(hash) b, err := s.chain.GetBlock(hash)
if err == nil { if err == nil {
msg = s.MkMsg(CMDBlock, b) msg = NewMessage(CMDBlock, b)
} }
case payload.ConsensusType: case payload.ConsensusType:
if cp := s.consensus.GetPayload(hash); cp != nil { if cp := s.consensus.GetPayload(hash); cp != nil {
msg = s.MkMsg(CMDConsensus, cp) msg = NewMessage(CMDConsensus, cp)
} }
} }
if msg != nil { if msg != nil {
@ -559,7 +559,7 @@ func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
return nil return nil
} }
payload := payload.NewInventory(payload.BlockType, blockHashes) payload := payload.NewInventory(payload.BlockType, blockHashes)
msg := s.MkMsg(CMDInv, payload) msg := NewMessage(CMDInv, payload)
return p.EnqueueP2PMessage(msg) return p.EnqueueP2PMessage(msg)
} }
@ -589,7 +589,7 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error {
if len(resp.Hdrs) == 0 { if len(resp.Hdrs) == 0 {
return nil return nil
} }
msg := s.MkMsg(CMDHeaders, &resp) msg := NewMessage(CMDHeaders, &resp)
return p.EnqueueP2PMessage(msg) return p.EnqueueP2PMessage(msg)
} }
@ -633,7 +633,7 @@ func (s *Server) handleGetAddrCmd(p Peer) error {
netaddr, _ := net.ResolveTCPAddr("tcp", addr) netaddr, _ := net.ResolveTCPAddr("tcp", addr)
alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts) alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts)
} }
return p.EnqueueP2PMessage(s.MkMsg(CMDAddr, alist)) return p.EnqueueP2PMessage(NewMessage(CMDAddr, alist))
} }
// requestHeaders sends a getheaders message to the peer. // requestHeaders sends a getheaders message to the peer.
@ -641,7 +641,7 @@ func (s *Server) handleGetAddrCmd(p Peer) error {
func (s *Server) requestHeaders(p Peer) error { func (s *Server) requestHeaders(p Peer) error {
start := []util.Uint256{s.chain.CurrentHeaderHash()} start := []util.Uint256{s.chain.CurrentHeaderHash()}
payload := payload.NewGetBlocks(start, util.Uint256{}) payload := payload.NewGetBlocks(start, util.Uint256{})
return p.EnqueueP2PMessage(s.MkMsg(CMDGetHeaders, payload)) return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload))
} }
// requestBlocks sends a getdata message to the peer // requestBlocks sends a getdata message to the peer
@ -660,7 +660,7 @@ func (s *Server) requestBlocks(p Peer) error {
} }
if len(hashes) > 0 { if len(hashes) > 0 {
payload := payload.NewInventory(payload.BlockType, hashes) payload := payload.NewInventory(payload.BlockType, hashes)
return p.EnqueueP2PMessage(s.MkMsg(CMDGetData, payload)) return p.EnqueueP2PMessage(NewMessage(CMDGetData, payload))
} else if s.chain.HeaderHeight() < p.LastBlockIndex() { } else if s.chain.HeaderHeight() < p.LastBlockIndex() {
return s.requestHeaders(p) return s.requestHeaders(p)
} }
@ -673,12 +673,6 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
zap.Stringer("addr", peer.RemoteAddr()), zap.Stringer("addr", peer.RemoteAddr()),
zap.String("type", msg.Command.String())) zap.String("type", msg.Command.String()))
// Make sure both server and peer are operating on
// the same network.
if msg.Magic != s.Net {
return errInvalidNetwork
}
if peer.Handshaked() { if peer.Handshaked() {
if inv, ok := msg.Payload.(*payload.Inventory); ok { if inv, ok := msg.Payload.(*payload.Inventory); ok {
if !inv.Type.Valid() || len(inv.Hashes) == 0 { if !inv.Type.Valid() || len(inv.Hashes) == 0 {
@ -746,7 +740,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
} }
func (s *Server) handleNewPayload(p *consensus.Payload) { func (s *Server) handleNewPayload(p *consensus.Payload) {
msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()})) msg := NewMessage(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()}))
// It's high priority because it directly affects consensus process, // It's high priority because it directly affects consensus process,
// even though it's just an inv. // even though it's just an inv.
s.broadcastHPMessage(msg) s.broadcastHPMessage(msg)
@ -757,7 +751,7 @@ func (s *Server) requestTx(hashes ...util.Uint256) {
return return
} }
msg := s.MkMsg(CMDGetData, payload.NewInventory(payload.TXType, hashes)) msg := NewMessage(CMDGetData, payload.NewInventory(payload.TXType, hashes))
// It's high priority because it directly affects consensus process, // It's high priority because it directly affects consensus process,
// even though it's getdata. // even though it's getdata.
s.broadcastHPMessage(msg) s.broadcastHPMessage(msg)
@ -793,7 +787,7 @@ func (s *Server) broadcastHPMessage(msg *Message) {
// relayBlock tells all the other connected nodes about the given block. // relayBlock tells all the other connected nodes about the given block.
func (s *Server) relayBlock(b *block.Block) { func (s *Server) relayBlock(b *block.Block) {
msg := s.MkMsg(CMDInv, payload.NewInventory(payload.BlockType, []util.Uint256{b.Hash()})) msg := NewMessage(CMDInv, payload.NewInventory(payload.BlockType, []util.Uint256{b.Hash()}))
// Filter out nodes that are more current (avoid spamming the network // Filter out nodes that are more current (avoid spamming the network
// during initial sync). // during initial sync).
s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, func(p Peer) bool { s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, func(p Peer) bool {
@ -837,7 +831,7 @@ func (s *Server) broadcastTX(t *transaction.Transaction) {
} }
func (s *Server) broadcastTxHashes(hs []util.Uint256) { func (s *Server) broadcastTxHashes(hs []util.Uint256) {
msg := s.MkMsg(CMDInv, payload.NewInventory(payload.TXType, hs)) msg := NewMessage(CMDInv, payload.NewInventory(payload.TXType, hs))
// We need to filter out non-relaying nodes, so plain broadcast // We need to filter out non-relaying nodes, so plain broadcast
// functions don't fit here. // functions don't fit here.

View file

@ -45,7 +45,7 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) {
p.messageHandler = func(t *testing.T, msg *Message) { p.messageHandler = func(t *testing.T, msg *Message) {
assert.Equal(t, CMDVerack, msg.Command) assert.Equal(t, CMDVerack, msg.Command)
} }
version := payload.NewVersion(1337, 3000, "/NEO-GO/", 0, true) version := payload.NewVersion(0, 1337, 3000, "/NEO-GO/", 0, true)
require.NoError(t, s.handleVersionCmd(p, version)) require.NoError(t, s.handleVersionCmd(p, version))
} }
@ -59,6 +59,7 @@ func TestServerNotSendsVerack(t *testing.T) {
p2 = newLocalPeer(t, s) p2 = newLocalPeer(t, s)
) )
s.id = 1 s.id = 1
s.Net = 56753
finished := make(chan struct{}) finished := make(chan struct{})
go func() { go func() {
s.run() s.run()
@ -76,13 +77,20 @@ func TestServerNotSendsVerack(t *testing.T) {
s.register <- p s.register <- p
// identical id's // identical id's
version := payload.NewVersion(1, 3000, "/NEO-GO/", 0, true) version := payload.NewVersion(56753, 1, 3000, "/NEO-GO/", 0, true)
err := s.handleVersionCmd(p, version) err := s.handleVersionCmd(p, version)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, errIdenticalID, err) assert.Equal(t, errIdenticalID, err)
// Different IDs, make handshake pass. // Different IDs, but also different magics
version.Nonce = 2 version.Nonce = 2
version.Magic = 56752
err = s.handleVersionCmd(p, version)
assert.NotNil(t, err)
assert.Equal(t, errInvalidNetwork, err)
// Different IDs and same network, make handshake pass.
version.Magic = 56753
require.NoError(t, s.handleVersionCmd(p, version)) require.NoError(t, s.handleVersionCmd(p, version))
require.NoError(t, p.HandleVersionAck()) require.NoError(t, p.HandleVersionAck())
require.Equal(t, true, p.Handshaked()) require.Equal(t, true, p.Handshaked())