diff --git a/pkg/network/payload/headers.go b/pkg/network/payload/headers.go index 8635cb1e6..935bd05de 100644 --- a/pkg/network/payload/headers.go +++ b/pkg/network/payload/headers.go @@ -13,7 +13,7 @@ type Headers struct { // Users can at most request 2k header. const ( - maxHeadersAllowed = 2000 + MaxHeadersAllowed = 2000 ) // DecodeBinary implements Serializable interface. @@ -21,9 +21,9 @@ func (p *Headers) DecodeBinary(br *io.BinReader) { lenHeaders := br.ReadVarUint() // C# node does it silently - if lenHeaders > maxHeadersAllowed { - log.Warnf("received %d headers, capping to %d", lenHeaders, maxHeadersAllowed) - lenHeaders = maxHeadersAllowed + if lenHeaders > MaxHeadersAllowed { + log.Warnf("received %d headers, capping to %d", lenHeaders, MaxHeadersAllowed) + lenHeaders = MaxHeadersAllowed } p.Hdrs = make([]*core.Header, lenHeaders) diff --git a/pkg/network/server.go b/pkg/network/server.go index 97196ec9c..2cf5d00f3 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -37,6 +37,7 @@ var ( errMaxPeers = errors.New("max peers reached") errServerShutdown = errors.New("server shutdown") errInvalidInvType = errors.New("invalid inventory type") + errInvalidHashStart = errors.New("invalid requested HashStart") ) type ( @@ -421,6 +422,35 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { return nil } +// handleGetHeadersCmd processes the getheaders request. +func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { + if len(gh.HashStart) < 1 { + return errInvalidHashStart + } + startHash := gh.HashStart[0] + start, err := s.chain.GetHeader(startHash) + if err != nil { + return err + } + resp := payload.Headers{} + resp.Hdrs = make([]*core.Header, 0, payload.MaxHeadersAllowed) + for i := start.Index + 1; i < start.Index+1+payload.MaxHeadersAllowed; i++ { + hash := s.chain.GetHeaderHash(int(i)) + if hash.Equals(util.Uint256{}) || hash.Equals(gh.HashStop) { + break + } + header, err := s.chain.GetHeader(hash) + if err != nil { + break + } + resp.Hdrs = append(resp.Hdrs, header) + } + if len(resp.Hdrs) == 0 { + return nil + } + return p.WriteMsg(NewMessage(s.Net, CMDHeaders, &resp)) +} + // handleConsensusCmd processes received consensus payload. // It never returns an error. func (s *Server) handleConsensusCmd(cp *consensus.Payload) error { @@ -514,6 +544,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDGetData: inv := msg.Payload.(*payload.Inventory) return s.handleGetDataCmd(peer, inv) + case CMDGetHeaders: + gh := msg.Payload.(*payload.GetBlocks) + return s.handleGetHeadersCmd(peer, gh) case CMDHeaders: headers := msg.Payload.(*payload.Headers) go s.handleHeadersCmd(peer, headers)