diff --git a/pkg/network/payload/getblocks.go b/pkg/network/payload/getblocks.go index 5a07d38b4..8a0bf0a56 100644 --- a/pkg/network/payload/getblocks.go +++ b/pkg/network/payload/getblocks.go @@ -1,6 +1,8 @@ package payload import ( + "errors" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -13,27 +15,30 @@ const ( // GetBlocks contains fields and methods to be shared with the type GetBlocks struct { // hash of latest block that node requests - HashStart []util.Uint256 - // hash of last block that node requests - HashStop util.Uint256 + HashStart util.Uint256 + Count int16 } // NewGetBlocks returns a pointer to a GetBlocks object. -func NewGetBlocks(start []util.Uint256, stop util.Uint256) *GetBlocks { +func NewGetBlocks(start util.Uint256, count int16) *GetBlocks { return &GetBlocks{ HashStart: start, - HashStop: stop, + Count: count, } } // DecodeBinary implements Serializable interface. func (p *GetBlocks) DecodeBinary(br *io.BinReader) { - br.ReadArray(&p.HashStart) - br.ReadBytes(p.HashStop[:]) + p.HashStart.DecodeBinary(br) + p.Count = int16(br.ReadU16LE()) + if p.Count < -1 || p.Count == 0 { + br.Err = errors.New("invalid count") + } + } // EncodeBinary implements Serializable interface. func (p *GetBlocks) EncodeBinary(bw *io.BinWriter) { - bw.WriteArray(p.HashStart) - bw.WriteBytes(p.HashStop[:]) + p.HashStart.EncodeBinary(bw) + bw.WriteU16LE(uint16(p.Count)) } diff --git a/pkg/network/payload/getblocks_test.go b/pkg/network/payload/getblocks_test.go index baa3f730e..8d720b12b 100644 --- a/pkg/network/payload/getblocks_test.go +++ b/pkg/network/payload/getblocks_test.go @@ -5,31 +5,24 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" - "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" ) func TestGetBlockEncodeDecode(t *testing.T) { - start := []util.Uint256{ - hash.Sha256([]byte("a")), - hash.Sha256([]byte("b")), - hash.Sha256([]byte("c")), - hash.Sha256([]byte("d")), - } + start := hash.Sha256([]byte("a")) - p := NewGetBlocks(start, util.Uint256{}) + p := NewGetBlocks(start, 124) testserdes.EncodeDecodeBinary(t, p, new(GetBlocks)) -} -func TestGetBlockEncodeDecodeWithHashStop(t *testing.T) { - var ( - start = []util.Uint256{ - hash.Sha256([]byte("a")), - hash.Sha256([]byte("b")), - hash.Sha256([]byte("c")), - hash.Sha256([]byte("d")), - } - stop = hash.Sha256([]byte("e")) - ) - p := NewGetBlocks(start, stop) - testserdes.EncodeDecodeBinary(t, p, new(GetBlocks)) + // invalid count + p = NewGetBlocks(start, -2) + data, err := testserdes.EncodeBinary(p) + require.NoError(t, err) + require.Error(t, testserdes.DecodeBinary(data, new(GetBlocks))) + + // invalid count + p = NewGetBlocks(start, 0) + data, err = testserdes.EncodeBinary(p) + require.NoError(t, err) + require.Error(t, testserdes.DecodeBinary(data, new(GetBlocks))) } diff --git a/pkg/network/server.go b/pkg/network/server.go index ff99d3d9a..5cd5384e0 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -535,21 +535,18 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { // handleGetBlocksCmd processes the getblocks request. func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { - if len(gb.HashStart) < 1 { - return errInvalidHashStart + count := gb.Count + if gb.Count < 0 || gb.Count > payload.MaxHashesCount { + count = payload.MaxHashesCount } - startHash := gb.HashStart[0] - if startHash.Equals(gb.HashStop) { - return nil - } - start, err := s.chain.GetHeader(startHash) + start, err := s.chain.GetHeader(gb.HashStart) if err != nil { return err } blockHashes := make([]util.Uint256, 0) - for i := start.Index + 1; i < start.Index+1+payload.MaxHashesCount; i++ { + for i := start.Index + 1; i < start.Index+uint32(count); i++ { hash := s.chain.GetHeaderHash(int(i)) - if hash.Equals(util.Uint256{}) || hash.Equals(gb.HashStop) { + if hash.Equals(util.Uint256{}) { break } blockHashes = append(blockHashes, hash) @@ -565,19 +562,19 @@ func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { // handleGetHeadersCmd processes the getheaders request. func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { - if len(gh.HashStart) < 1 { - return errInvalidHashStart + count := gh.Count + if gh.Count < 0 || gh.Count > payload.MaxHashesCount { + count = payload.MaxHashesCount } - startHash := gh.HashStart[0] - start, err := s.chain.GetHeader(startHash) + start, err := s.chain.GetHeader(gh.HashStart) if err != nil { return err } resp := payload.Headers{} resp.Hdrs = make([]*block.Header, 0, payload.MaxHeadersAllowed) - for i := start.Index + 1; i < start.Index+1+payload.MaxHeadersAllowed; i++ { + for i := start.Index + 1; i < start.Index+uint32(count); i++ { hash := s.chain.GetHeaderHash(int(i)) - if hash.Equals(util.Uint256{}) || hash.Equals(gh.HashStop) { + if hash.Equals(util.Uint256{}) { break } header, err := s.chain.GetHeader(hash) @@ -637,10 +634,9 @@ func (s *Server) handleGetAddrCmd(p Peer) error { } // requestHeaders sends a getheaders message to the peer. -// The peer will respond with headers op to a count of 2000. +// The peer will respond with headers op to a count of 500. func (s *Server) requestHeaders(p Peer) error { - start := []util.Uint256{s.chain.CurrentHeaderHash()} - payload := payload.NewGetBlocks(start, util.Uint256{}) + payload := payload.NewGetBlocks(s.chain.CurrentHeaderHash(), -1) return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload)) }