diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index 243a31109..0054d246b 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -49,6 +49,7 @@ type FakeChain struct { type FakeStateSync struct { IsActiveFlag uatomic.Bool IsInitializedFlag uatomic.Bool + RequestHeaders uatomic.Bool InitFunc func(h uint32) error TraverseFunc func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error AddMPTNodesFunc func(nodes [][]byte) error @@ -503,7 +504,7 @@ func (s *FakeStateSync) Init(currChainHeight uint32) error { } // NeedHeaders implements StateSync interface. -func (s *FakeStateSync) NeedHeaders() bool { return false } +func (s *FakeStateSync) NeedHeaders() bool { return s.RequestHeaders.Load() } // NeedMPTNodes implements StateSync interface. func (s *FakeStateSync) NeedMPTNodes() bool { diff --git a/pkg/network/server.go b/pkg/network/server.go index 16a6a46f5..9ec819978 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -84,8 +84,10 @@ type ( lock sync.RWMutex peers map[Peer]bool - // lastRequestedHeight contains last requested height. - lastRequestedHeight atomic.Uint32 + // lastRequestedBlock contains a height of the last requested block. + lastRequestedBlock atomic.Uint32 + // lastRequestedHeader contains a height of the last requested header. + lastRequestedHeader atomic.Uint32 register chan Peer unregister chan peerDrop @@ -694,11 +696,8 @@ func (s *Server) requestBlocksOrHeaders(p Peer) error { // requestHeaders sends a CMDGetHeaders message to the peer to sync up in headers. func (s *Server) requestHeaders(p Peer) error { - // TODO: optimize - currHeight := s.chain.HeaderHeight() - needHeight := currHeight + 1 - payload := payload.NewGetBlockByIndex(needHeight, -1) - return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload)) + pl := getRequestBlocksPayload(p, s.chain.HeaderHeight(), &s.lastRequestedHeader) + return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, pl)) } // handlePing processes pong request. @@ -1112,22 +1111,26 @@ func (s *Server) handleGetAddrCmd(p Peer) error { // 2. Send requests for chunk in increasing order. // 3. After all requests were sent, request random height. func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error { - var currHeight = bq.BlockHeight() + pl := getRequestBlocksPayload(p, bq.BlockHeight(), &s.lastRequestedBlock) + return p.EnqueueP2PMessage(NewMessage(CMDGetBlockByIndex, pl)) +} + +func getRequestBlocksPayload(p Peer, currHeight uint32, lastRequestedHeight *atomic.Uint32) *payload.GetBlockByIndex { var peerHeight = p.LastBlockIndex() var needHeight uint32 - // lastRequestedHeight can only be increased. + // lastRequestedBlock can only be increased. for { - old := s.lastRequestedHeight.Load() + old := lastRequestedHeight.Load() if old <= currHeight { needHeight = currHeight + 1 - if !s.lastRequestedHeight.CAS(old, needHeight) { + if !lastRequestedHeight.CAS(old, needHeight) { continue } } else if old < currHeight+(blockCacheSize-payload.MaxHashesCount) { needHeight = currHeight + 1 if peerHeight > old+payload.MaxHashesCount { needHeight = old + payload.MaxHashesCount - if !s.lastRequestedHeight.CAS(old, needHeight) { + if !lastRequestedHeight.CAS(old, needHeight) { continue } } @@ -1137,8 +1140,7 @@ func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error { } break } - payload := payload.NewGetBlockByIndex(needHeight, -1) - return p.EnqueueP2PMessage(NewMessage(CMDGetBlockByIndex, payload)) + return payload.NewGetBlockByIndex(needHeight, -1) } // handleMessage processes the given message. diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index e4c109c9f..9b35c09db 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -186,26 +186,34 @@ func TestServerRegisterPeer(t *testing.T) { } func TestGetBlocksByIndex(t *testing.T) { + testGetBlocksByIndex(t, CMDGetBlockByIndex) +} + +func testGetBlocksByIndex(t *testing.T, cmd CommandType) { s := newTestServer(t, ServerConfig{Port: 0, UserAgent: "/test/"}) + start := s.chain.BlockHeight() + if cmd == CMDGetHeaders { + start = s.chain.HeaderHeight() + s.stateSync.(*fakechain.FakeStateSync).RequestHeaders.Store(true) + } ps := make([]*localPeer, 10) expectsCmd := make([]CommandType, 10) expectedHeight := make([][]uint32, 10) - start := s.chain.BlockHeight() for i := range ps { i := i ps[i] = newLocalPeer(t, s) ps[i].messageHandler = func(t *testing.T, msg *Message) { require.Equal(t, expectsCmd[i], msg.Command) - if expectsCmd[i] == CMDGetBlockByIndex { + if expectsCmd[i] == cmd { p, ok := msg.Payload.(*payload.GetBlockByIndex) require.True(t, ok) require.Contains(t, expectedHeight[i], p.IndexStart) expectsCmd[i] = CMDPong } else if expectsCmd[i] == CMDPong { - expectsCmd[i] = CMDGetBlockByIndex + expectsCmd[i] = cmd } } - expectsCmd[i] = CMDGetBlockByIndex + expectsCmd[i] = cmd expectedHeight[i] = []uint32{start + 1} } go s.transport.Accept() @@ -678,6 +686,9 @@ func TestGetHeaders(t *testing.T) { s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: 123, Count: -1}) require.Nil(t, actual) }) + t.Run("distribute requests between peers", func(t *testing.T) { + testGetBlocksByIndex(t, CMDGetHeaders) + }) } func TestInv(t *testing.T) {