diff --git a/pkg/network/payload/getblockbyindex.go b/pkg/network/payload/getblockbyindex.go index 450dd4503..a4e59db1d 100644 --- a/pkg/network/payload/getblockbyindex.go +++ b/pkg/network/payload/getblockbyindex.go @@ -9,11 +9,11 @@ import ( // GetBlockByIndex payload type GetBlockByIndex struct { IndexStart uint32 - Count uint16 + Count int16 } // NewGetBlockByIndex returns GetBlockByIndex payload with specified start index and count -func NewGetBlockByIndex(indexStart uint32, count uint16) *GetBlockByIndex { +func NewGetBlockByIndex(indexStart uint32, count int16) *GetBlockByIndex { return &GetBlockByIndex{ IndexStart: indexStart, Count: count, @@ -23,8 +23,8 @@ func NewGetBlockByIndex(indexStart uint32, count uint16) *GetBlockByIndex { // DecodeBinary implements Serializable interface. func (d *GetBlockByIndex) DecodeBinary(br *io.BinReader) { d.IndexStart = br.ReadU32LE() - d.Count = br.ReadU16LE() - if d.Count == 0 || d.Count > MaxHeadersAllowed { + d.Count = int16(br.ReadU16LE()) + if d.Count < -1 || d.Count == 0 || d.Count > MaxHeadersAllowed { br.Err = errors.New("invalid block count") } } @@ -32,5 +32,5 @@ func (d *GetBlockByIndex) DecodeBinary(br *io.BinReader) { // EncodeBinary implements Serializable interface. func (d *GetBlockByIndex) EncodeBinary(bw *io.BinWriter) { bw.WriteU32LE(d.IndexStart) - bw.WriteU16LE(d.Count) + bw.WriteU16LE(uint16(d.Count)) } diff --git a/pkg/network/server.go b/pkg/network/server.go index 874642588..6df448d27 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -611,7 +611,11 @@ func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { // handleGetBlockByIndexCmd processes the getblockbyindex request. func (s *Server) handleGetBlockByIndexCmd(p Peer, gbd *payload.GetBlockByIndex) error { - for i := gbd.IndexStart; i < gbd.IndexStart+uint32(gbd.Count); i++ { + count := gbd.Count + if gbd.Count < 0 || gbd.Count > payload.MaxHashesCount { + count = payload.MaxHashesCount + } + for i := gbd.IndexStart; i < gbd.IndexStart+uint32(count); i++ { b, err := s.chain.GetBlock(s.chain.GetHeaderHash(int(i))) if err != nil { return err