diff --git a/pkg/network/message.go b/pkg/network/message.go index 46eaa91f6..35402b844 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -40,17 +40,18 @@ const ( CMDPong CommandType = 0x19 // synchronization - CMDGetHeaders CommandType = 0x20 - CMDHeaders CommandType = 0x21 - CMDGetBlocks CommandType = 0x24 - CMDMempool CommandType = 0x25 - CMDInv CommandType = 0x27 - CMDGetData CommandType = 0x28 - CMDUnknown CommandType = 0x2a - CMDTX CommandType = 0x2b - CMDBlock CommandType = 0x2c - CMDConsensus CommandType = 0x2d - CMDReject CommandType = 0x2f + CMDGetHeaders CommandType = 0x20 + CMDHeaders CommandType = 0x21 + CMDGetBlocks CommandType = 0x24 + CMDMempool CommandType = 0x25 + CMDInv CommandType = 0x27 + CMDGetData CommandType = 0x28 + CMDGetBlockData CommandType = 0x29 + CMDUnknown CommandType = 0x2a + CMDTX CommandType = 0x2b + CMDBlock CommandType = 0x2c + CMDConsensus CommandType = 0x2d + CMDReject CommandType = 0x2f // SPV protocol CMDFilterLoad CommandType = 0x30 @@ -123,6 +124,8 @@ func (m *Message) decodePayload(br *io.BinReader) error { fallthrough case CMDGetHeaders: p = &payload.GetBlocks{} + case CMDGetBlockData: + p = &payload.GetBlockData{} case CMDHeaders: p = &payload.Headers{} case CMDTX: diff --git a/pkg/network/message_string.go b/pkg/network/message_string.go index 7d5d058ff..174bdfb53 100644 --- a/pkg/network/message_string.go +++ b/pkg/network/message_string.go @@ -20,6 +20,7 @@ func _() { _ = x[CMDMempool-37] _ = x[CMDInv-39] _ = x[CMDGetData-40] + _ = x[CMDGetBlockData-41] _ = x[CMDUnknown-42] _ = x[CMDTX-43] _ = x[CMDBlock-44] @@ -38,11 +39,10 @@ const ( _CommandType_name_2 = "CMDPingCMDPong" _CommandType_name_3 = "CMDGetHeadersCMDHeaders" _CommandType_name_4 = "CMDGetBlocksCMDMempool" - _CommandType_name_5 = "CMDInvCMDGetData" - _CommandType_name_6 = "CMDUnknownCMDTXCMDBlockCMDConsensus" - _CommandType_name_7 = "CMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear" - _CommandType_name_8 = "CMDMerkleBlock" - _CommandType_name_9 = "CMDAlert" + _CommandType_name_5 = "CMDInvCMDGetDataCMDGetBlockDataCMDUnknownCMDTXCMDBlockCMDConsensus" + _CommandType_name_6 = "CMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear" + _CommandType_name_7 = "CMDMerkleBlock" + _CommandType_name_8 = "CMDAlert" ) var ( @@ -51,9 +51,8 @@ var ( _CommandType_index_2 = [...]uint8{0, 7, 14} _CommandType_index_3 = [...]uint8{0, 13, 23} _CommandType_index_4 = [...]uint8{0, 12, 22} - _CommandType_index_5 = [...]uint8{0, 6, 16} - _CommandType_index_6 = [...]uint8{0, 10, 15, 23, 35} - _CommandType_index_7 = [...]uint8{0, 9, 22, 34, 48} + _CommandType_index_5 = [...]uint8{0, 6, 16, 31, 41, 46, 54, 66} + _CommandType_index_6 = [...]uint8{0, 9, 22, 34, 48} ) func (i CommandType) String() string { @@ -72,19 +71,16 @@ func (i CommandType) String() string { case 36 <= i && i <= 37: i -= 36 return _CommandType_name_4[_CommandType_index_4[i]:_CommandType_index_4[i+1]] - case 39 <= i && i <= 40: + case 39 <= i && i <= 45: i -= 39 return _CommandType_name_5[_CommandType_index_5[i]:_CommandType_index_5[i+1]] - case 42 <= i && i <= 45: - i -= 42 - return _CommandType_name_6[_CommandType_index_6[i]:_CommandType_index_6[i+1]] case 47 <= i && i <= 50: i -= 47 - return _CommandType_name_7[_CommandType_index_7[i]:_CommandType_index_7[i+1]] + return _CommandType_name_6[_CommandType_index_6[i]:_CommandType_index_6[i+1]] case i == 56: - return _CommandType_name_8 + return _CommandType_name_7 case i == 64: - return _CommandType_name_9 + return _CommandType_name_8 default: return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/network/payload/getblockdata.go b/pkg/network/payload/getblockdata.go new file mode 100644 index 000000000..2ed9c0d9f --- /dev/null +++ b/pkg/network/payload/getblockdata.go @@ -0,0 +1,39 @@ +package payload + +import ( + "errors" + + "github.com/nspcc-dev/neo-go/pkg/io" +) + +// maximum number of blocks to query about +const maxBlockCount = 500 + +// GetBlockData payload +type GetBlockData struct { + IndexStart uint32 + Count uint16 +} + +// NewGetBlockData returns GetBlockData payload with specified start index and count +func NewGetBlockData(indexStart uint32, count uint16) *GetBlockData { + return &GetBlockData{ + IndexStart: indexStart, + Count: count, + } +} + +// DecodeBinary implements Serializable interface. +func (d *GetBlockData) DecodeBinary(br *io.BinReader) { + d.IndexStart = br.ReadU32LE() + d.Count = br.ReadU16LE() + if d.Count == 0 || d.Count > maxBlockCount { + br.Err = errors.New("invalid block count") + } +} + +// EncodeBinary implements Serializable interface. +func (d *GetBlockData) EncodeBinary(bw *io.BinWriter) { + bw.WriteU32LE(d.IndexStart) + bw.WriteU16LE(d.Count) +} diff --git a/pkg/network/payload/getblockdata_test.go b/pkg/network/payload/getblockdata_test.go new file mode 100644 index 000000000..6704af858 --- /dev/null +++ b/pkg/network/payload/getblockdata_test.go @@ -0,0 +1,25 @@ +package payload + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/stretchr/testify/require" +) + +func TestGetBlockDataEncodeDecode(t *testing.T) { + d := NewGetBlockData(123, 100) + testserdes.EncodeDecodeBinary(t, d, new(GetBlockData)) + + // invalid block count + d = NewGetBlockData(5, 0) + data, err := testserdes.EncodeBinary(d) + require.NoError(t, err) + require.Error(t, testserdes.DecodeBinary(data, new(GetBlockData))) + + // invalid block count + d = NewGetBlockData(5, maxBlockCount+1) + data, err = testserdes.EncodeBinary(d) + require.NoError(t, err) + require.Error(t, testserdes.DecodeBinary(data, new(GetBlockData))) +} diff --git a/pkg/network/server.go b/pkg/network/server.go index 5cd5384e0..1ecc7782d 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -560,6 +560,19 @@ func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { return p.EnqueueP2PMessage(msg) } +// handleGetBlockDataCmd processes the getblockdata request. +func (s *Server) handleGetBlockDataCmd(p Peer, gbd *payload.GetBlockData) error { + for i := gbd.IndexStart; i < gbd.IndexStart+uint32(gbd.Count); i++ { + b, err := s.chain.GetBlock(s.chain.GetHeaderHash(int(i))) + if err != nil { + return err + } + msg := NewMessage(CMDBlock, b) + return p.EnqueueP2PMessage(msg) + } + return nil +} + // handleGetHeadersCmd processes the getheaders request. func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { count := gh.Count @@ -685,6 +698,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDGetBlocks: gb := msg.Payload.(*payload.GetBlocks) return s.handleGetBlocksCmd(peer, gb) + case CMDGetBlockData: + gbd := msg.Payload.(*payload.GetBlockData) + return s.handleGetBlockDataCmd(peer, gbd) case CMDGetData: inv := msg.Payload.(*payload.Inventory) return s.handleGetDataCmd(peer, inv)