diff --git a/pkg/network/message.go b/pkg/network/message.go index 69995ab92..46eaa91f6 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -3,7 +3,6 @@ package network import ( "fmt" - "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -15,9 +14,6 @@ import ( // Message is the complete message send between nodes. type Message struct { - // NetMode of the node that sends this message. - Magic config.NetMode - // Command is byte command code. Command CommandType @@ -67,7 +63,7 @@ const ( ) // NewMessage returns a new message with the given payload. -func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Message { +func NewMessage(cmd CommandType, p payload.Payload) *Message { var ( size uint32 ) @@ -83,7 +79,6 @@ func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Messa } return &Message{ - Magic: magic, Command: cmd, Length: size, Payload: p, @@ -92,7 +87,6 @@ func NewMessage(magic config.NetMode, cmd CommandType, p payload.Payload) *Messa // Decode decodes a Message from the given reader. func (m *Message) Decode(br *io.BinReader) error { - m.Magic = config.NetMode(br.ReadU32LE()) m.Command = CommandType(br.ReadB()) m.Length = br.ReadU32LE() if br.Err != nil { @@ -150,7 +144,6 @@ func (m *Message) decodePayload(br *io.BinReader) error { // Encode encodes a Message to any given BinWriter. func (m *Message) Encode(br *io.BinWriter) error { - br.WriteU32LE(uint32(m.Magic)) br.WriteB(byte(m.Command)) br.WriteU32LE(m.Length) if m.Payload != nil { diff --git a/pkg/network/payload/version.go b/pkg/network/payload/version.go index 3bb6bc345..86be2b80c 100644 --- a/pkg/network/payload/version.go +++ b/pkg/network/payload/version.go @@ -3,6 +3,7 @@ package payload import ( "time" + "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/io" ) @@ -21,6 +22,8 @@ const ( // Version payload. type Version struct { + // NetMode of the node + Magic config.NetMode // currently the version of the protocol is 0 Version uint32 // currently 1 @@ -40,8 +43,9 @@ type Version struct { } // NewVersion returns a pointer to a Version payload. -func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version { +func NewVersion(magic config.NetMode, id uint32, p uint16, ua string, h uint32, r bool) *Version { return &Version{ + Magic: magic, Version: 0, Services: nodePeerService, Timestamp: uint32(time.Now().UTC().Unix()), @@ -55,6 +59,7 @@ func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version { // DecodeBinary implements Serializable interface. func (p *Version) DecodeBinary(br *io.BinReader) { + p.Magic = config.NetMode(br.ReadU32LE()) p.Version = br.ReadU32LE() p.Services = br.ReadU64LE() p.Timestamp = br.ReadU32LE() @@ -67,6 +72,7 @@ func (p *Version) DecodeBinary(br *io.BinReader) { // EncodeBinary implements Serializable interface. func (p *Version) EncodeBinary(br *io.BinWriter) { + br.WriteU32LE(uint32(p.Magic)) br.WriteU32LE(p.Version) br.WriteU64LE(p.Services) br.WriteU32LE(p.Timestamp) diff --git a/pkg/network/payload/version_test.go b/pkg/network/payload/version_test.go index 1f44a6d7b..c8c8ccad6 100644 --- a/pkg/network/payload/version_test.go +++ b/pkg/network/payload/version_test.go @@ -3,18 +3,20 @@ package payload import ( "testing" + "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" "github.com/stretchr/testify/assert" ) func TestVersionEncodeDecode(t *testing.T) { + var magic config.NetMode = 56753 var port uint16 = 3000 var id uint32 = 13337 useragent := "/NEO:0.0.1/" var height uint32 = 100500 var relay = true - version := NewVersion(id, port, useragent, height, relay) + version := NewVersion(magic, id, port, useragent, height, relay) versionDecoded := &Version{} testserdes.EncodeDecodeBinary(t, version, versionDecoded) diff --git a/pkg/network/server.go b/pkg/network/server.go index 81e49d7b8..ff99d3d9a 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -157,12 +157,6 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo return s, nil } -// MkMsg creates a new message based on the server configured network and given -// parameters. -func (s *Server) MkMsg(cmd CommandType, p payload.Payload) *Message { - return NewMessage(s.Net, cmd, p) -} - // ID returns the servers ID. func (s *Server) ID() uint32 { return s.id @@ -230,7 +224,7 @@ func (s *Server) run() { s.discovery.RequestRemote(s.AttemptConnPeers) } if s.discovery.PoolCount() < minPoolCount { - s.broadcastHPMessage(s.MkMsg(CMDGetAddr, payload.NewNullPayload())) + s.broadcastHPMessage(NewMessage(CMDGetAddr, payload.NewNullPayload())) } select { case <-s.quit: @@ -292,7 +286,7 @@ func (s *Server) runProto() { if s.chain.BlockHeight() == prevHeight { // Get a copy of s.peers to avoid holding a lock while sending. for peer := range s.Peers() { - _ = peer.SendPing(s.MkMsg(CMDPing, payload.NewPing(s.id, s.chain.HeaderHeight()))) + _ = peer.SendPing(NewMessage(CMDPing, payload.NewPing(s.id, s.chain.HeaderHeight()))) } } pingTimer.Reset(s.PingInterval) @@ -354,13 +348,14 @@ func (s *Server) HandshakedPeersCount() int { // getVersionMsg returns current version message. func (s *Server) getVersionMsg() *Message { payload := payload.NewVersion( + s.Net, s.id, s.Port, s.UserAgent, s.chain.BlockHeight(), s.Relay, ) - return s.MkMsg(CMDVersion, payload) + return NewMessage(CMDVersion, payload) } // IsInSync answers the question of whether the server is in sync with the @@ -406,6 +401,11 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error { if s.id == version.Nonce { return errIdenticalID } + // Make sure both server and peer are operating on + // the same network. + if s.Net != version.Magic { + return errInvalidNetwork + } peerAddr := p.PeerAddr().String() s.discovery.RegisterConnectedAddr(peerAddr) s.lock.RLock() @@ -421,7 +421,7 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error { } } s.lock.RUnlock() - return p.SendVersionAck(s.MkMsg(CMDVerack, nil)) + return p.SendVersionAck(NewMessage(CMDVerack, nil)) } // handleHeadersCmd processes the headers received from its peer. @@ -448,7 +448,7 @@ func (s *Server) handleBlockCmd(p Peer, block *block.Block) error { // handlePing processes ping request. func (s *Server) handlePing(p Peer, ping *payload.Ping) error { - return p.EnqueueP2PMessage(s.MkMsg(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id))) + return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id))) } // handlePing processes pong request. @@ -482,7 +482,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { } } if len(reqHashes) > 0 { - msg := s.MkMsg(CMDGetData, payload.NewInventory(inv.Type, reqHashes)) + msg := NewMessage(CMDGetData, payload.NewInventory(inv.Type, reqHashes)) pkt, err := msg.Bytes() if err != nil { return err @@ -504,16 +504,16 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { case payload.TXType: tx, _, err := s.chain.GetTransaction(hash) if err == nil { - msg = s.MkMsg(CMDTX, tx) + msg = NewMessage(CMDTX, tx) } case payload.BlockType: b, err := s.chain.GetBlock(hash) if err == nil { - msg = s.MkMsg(CMDBlock, b) + msg = NewMessage(CMDBlock, b) } case payload.ConsensusType: if cp := s.consensus.GetPayload(hash); cp != nil { - msg = s.MkMsg(CMDConsensus, cp) + msg = NewMessage(CMDConsensus, cp) } } if msg != nil { @@ -559,7 +559,7 @@ func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { return nil } payload := payload.NewInventory(payload.BlockType, blockHashes) - msg := s.MkMsg(CMDInv, payload) + msg := NewMessage(CMDInv, payload) return p.EnqueueP2PMessage(msg) } @@ -589,7 +589,7 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { if len(resp.Hdrs) == 0 { return nil } - msg := s.MkMsg(CMDHeaders, &resp) + msg := NewMessage(CMDHeaders, &resp) return p.EnqueueP2PMessage(msg) } @@ -633,7 +633,7 @@ func (s *Server) handleGetAddrCmd(p Peer) error { netaddr, _ := net.ResolveTCPAddr("tcp", addr) alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts) } - return p.EnqueueP2PMessage(s.MkMsg(CMDAddr, alist)) + return p.EnqueueP2PMessage(NewMessage(CMDAddr, alist)) } // requestHeaders sends a getheaders message to the peer. @@ -641,7 +641,7 @@ func (s *Server) handleGetAddrCmd(p Peer) error { func (s *Server) requestHeaders(p Peer) error { start := []util.Uint256{s.chain.CurrentHeaderHash()} payload := payload.NewGetBlocks(start, util.Uint256{}) - return p.EnqueueP2PMessage(s.MkMsg(CMDGetHeaders, payload)) + return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload)) } // requestBlocks sends a getdata message to the peer @@ -660,7 +660,7 @@ func (s *Server) requestBlocks(p Peer) error { } if len(hashes) > 0 { payload := payload.NewInventory(payload.BlockType, hashes) - return p.EnqueueP2PMessage(s.MkMsg(CMDGetData, payload)) + return p.EnqueueP2PMessage(NewMessage(CMDGetData, payload)) } else if s.chain.HeaderHeight() < p.LastBlockIndex() { return s.requestHeaders(p) } @@ -673,12 +673,6 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { zap.Stringer("addr", peer.RemoteAddr()), zap.String("type", msg.Command.String())) - // Make sure both server and peer are operating on - // the same network. - if msg.Magic != s.Net { - return errInvalidNetwork - } - if peer.Handshaked() { if inv, ok := msg.Payload.(*payload.Inventory); ok { if !inv.Type.Valid() || len(inv.Hashes) == 0 { @@ -746,7 +740,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { } func (s *Server) handleNewPayload(p *consensus.Payload) { - msg := s.MkMsg(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()})) + msg := NewMessage(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()})) // It's high priority because it directly affects consensus process, // even though it's just an inv. s.broadcastHPMessage(msg) @@ -757,7 +751,7 @@ func (s *Server) requestTx(hashes ...util.Uint256) { return } - msg := s.MkMsg(CMDGetData, payload.NewInventory(payload.TXType, hashes)) + msg := NewMessage(CMDGetData, payload.NewInventory(payload.TXType, hashes)) // It's high priority because it directly affects consensus process, // even though it's getdata. s.broadcastHPMessage(msg) @@ -793,7 +787,7 @@ func (s *Server) broadcastHPMessage(msg *Message) { // relayBlock tells all the other connected nodes about the given block. func (s *Server) relayBlock(b *block.Block) { - msg := s.MkMsg(CMDInv, payload.NewInventory(payload.BlockType, []util.Uint256{b.Hash()})) + msg := NewMessage(CMDInv, payload.NewInventory(payload.BlockType, []util.Uint256{b.Hash()})) // Filter out nodes that are more current (avoid spamming the network // during initial sync). s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, func(p Peer) bool { @@ -837,7 +831,7 @@ func (s *Server) broadcastTX(t *transaction.Transaction) { } func (s *Server) broadcastTxHashes(hs []util.Uint256) { - msg := s.MkMsg(CMDInv, payload.NewInventory(payload.TXType, hs)) + msg := NewMessage(CMDInv, payload.NewInventory(payload.TXType, hs)) // We need to filter out non-relaying nodes, so plain broadcast // functions don't fit here. diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index a166c7773..19ed92b1d 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -45,7 +45,7 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) { p.messageHandler = func(t *testing.T, msg *Message) { assert.Equal(t, CMDVerack, msg.Command) } - version := payload.NewVersion(1337, 3000, "/NEO-GO/", 0, true) + version := payload.NewVersion(0, 1337, 3000, "/NEO-GO/", 0, true) require.NoError(t, s.handleVersionCmd(p, version)) } @@ -59,6 +59,7 @@ func TestServerNotSendsVerack(t *testing.T) { p2 = newLocalPeer(t, s) ) s.id = 1 + s.Net = 56753 finished := make(chan struct{}) go func() { s.run() @@ -76,13 +77,20 @@ func TestServerNotSendsVerack(t *testing.T) { s.register <- p // identical id's - version := payload.NewVersion(1, 3000, "/NEO-GO/", 0, true) + version := payload.NewVersion(56753, 1, 3000, "/NEO-GO/", 0, true) err := s.handleVersionCmd(p, version) assert.NotNil(t, err) assert.Equal(t, errIdenticalID, err) - // Different IDs, make handshake pass. + // Different IDs, but also different magics version.Nonce = 2 + version.Magic = 56752 + err = s.handleVersionCmd(p, version) + assert.NotNil(t, err) + assert.Equal(t, errInvalidNetwork, err) + + // Different IDs and same network, make handshake pass. + version.Magic = 56753 require.NoError(t, s.handleVersionCmd(p, version)) require.NoError(t, p.HandleVersionAck()) require.Equal(t, true, p.Handshaked())