diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index ca1f0bc6d..2d7297349 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -508,3 +508,8 @@ func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, func (s *FakeStateSync) GetJumpHeight() (uint32, error) { panic("TODO") } + +// GetUnknownMPTNodesBatch implements StateSync interface. +func (s *FakeStateSync) GetUnknownMPTNodesBatch(limit int) []util.Uint256 { + panic("TODO") +} diff --git a/pkg/core/blockchainer/state_sync.go b/pkg/core/blockchainer/state_sync.go index 66643020c..9c2161853 100644 --- a/pkg/core/blockchainer/state_sync.go +++ b/pkg/core/blockchainer/state_sync.go @@ -13,6 +13,7 @@ type StateSync interface { IsActive() bool IsInitialized() bool GetJumpHeight() (uint32, error) + GetUnknownMPTNodesBatch(limit int) []util.Uint256 NeedHeaders() bool NeedMPTNodes() bool Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error diff --git a/pkg/core/mpt/billet.go b/pkg/core/mpt/billet.go index f66415036..843a746ee 100644 --- a/pkg/core/mpt/billet.go +++ b/pkg/core/mpt/billet.go @@ -215,7 +215,7 @@ func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) b return curr, nil } if hn, ok := curr.(*HashNode); ok { - r, err := b.getFromStore(hn.Hash()) + r, err := b.GetFromStore(hn.Hash()) if err != nil { if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) { return hn, nil @@ -292,7 +292,8 @@ func (b *Billet) tryCollapseBranch(curr *BranchNode) Node { return res } -func (b *Billet) getFromStore(h util.Uint256) (Node, error) { +// GetFromStore returns MPT node from the storage. +func (b *Billet) GetFromStore(h util.Uint256) (Node, error) { data, err := b.Store.Get(makeStorageKey(h.BytesBE())) if err != nil { return nil, err diff --git a/pkg/core/statesync/module.go b/pkg/core/statesync/module.go index 24f38f14e..e7874bd39 100644 --- a/pkg/core/statesync/module.go +++ b/pkg/core/statesync/module.go @@ -328,24 +328,10 @@ func (s *Module) AddMPTNodes(nodes [][]byte) error { if r.Err != nil { return fmt.Errorf("failed to decode MPT node: %w", r.Err) } - nPaths, ok := s.mptpool.TryGet(n.Hash()) - if !ok { - // it can easily happen after receiving the same data from different peers. - return nil + err := s.restoreNode(n.Node) + if err != nil { + return err } - - var childrenPaths = make(map[util.Uint256][][]byte) - for _, path := range nPaths { - err := s.billet.RestoreHashNode(path, n.Node) - if err != nil { - return fmt.Errorf("failed to add MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err) - } - for h, paths := range mpt.GetChildrenPaths(path, n.Node) { - childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool - } - } - - s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths) } if s.mptpool.Count() == 0 { s.syncStage |= mptSynced @@ -356,6 +342,37 @@ func (s *Module) AddMPTNodes(nodes [][]byte) error { return nil } +func (s *Module) restoreNode(n mpt.Node) error { + nPaths, ok := s.mptpool.TryGet(n.Hash()) + if !ok { + // it can easily happen after receiving the same data from different peers. + return nil + } + var childrenPaths = make(map[util.Uint256][][]byte) + for _, path := range nPaths { + err := s.billet.RestoreHashNode(path, n) + if err != nil { + return fmt.Errorf("failed to restore MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err) + } + for h, paths := range mpt.GetChildrenPaths(path, n) { + childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool + } + } + + s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths) + + for h := range childrenPaths { + if child, err := s.billet.GetFromStore(h); err == nil { + // child is already in the storage, so we don't need to request it one more time. + err = s.restoreNode(child) + if err != nil { + return fmt.Errorf("unable to restore saved children: %w", err) + } + } + } + return nil +} + // checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1 // height are fetched, blocks up to P height are stored and MPT nodes for P height are stored. // If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller @@ -438,3 +455,11 @@ func (s *Module) GetJumpHeight() (uint32, error) { } return s.syncPoint, nil } + +// GetUnknownMPTNodesBatch returns set of currently unknown MPT nodes (`limit` at max). +func (s *Module) GetUnknownMPTNodesBatch(limit int) []util.Uint256 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.mptpool.GetBatch(limit) +} diff --git a/pkg/core/statesync/mptpool.go b/pkg/core/statesync/mptpool.go index 23610d97a..93bbb41a4 100644 --- a/pkg/core/statesync/mptpool.go +++ b/pkg/core/statesync/mptpool.go @@ -48,6 +48,26 @@ func (mp *Pool) GetAll() map[util.Uint256][][]byte { return mp.hashes } +// GetBatch returns set of unknown MPT nodes hashes (`limit` at max). +func (mp *Pool) GetBatch(limit int) []util.Uint256 { + mp.lock.RLock() + defer mp.lock.RUnlock() + + count := len(mp.hashes) + if count > limit { + count = limit + } + result := make([]util.Uint256, 0, limit) + for h := range mp.hashes { + if count == 0 { + break + } + result = append(result, h) + count-- + } + return result +} + // Remove removes MPT node from the pool by the specified hash. func (mp *Pool) Remove(hash util.Uint256) { mp.lock.Lock() diff --git a/pkg/core/statesync/mptpool_test.go b/pkg/core/statesync/mptpool_test.go new file mode 100644 index 000000000..bab32364b --- /dev/null +++ b/pkg/core/statesync/mptpool_test.go @@ -0,0 +1,104 @@ +package statesync + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/internal/random" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestPool_AddRemoveUpdate(t *testing.T) { + mp := NewPool() + + i1 := []byte{1, 2, 3} + i1h := util.Uint256{1, 2, 3} + i2 := []byte{2, 3, 4} + i2h := util.Uint256{2, 3, 4} + i3 := []byte{4, 5, 6} + i3h := util.Uint256{3, 4, 5} + i4 := []byte{3, 4, 5} // has the same hash as i3 + i5 := []byte{6, 7, 8} // has the same hash as i3 + mapAll := map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}, i3h: {i4, i3}} + + // No items + _, ok := mp.TryGet(i1h) + require.False(t, ok) + require.False(t, mp.ContainsKey(i1h)) + require.Equal(t, 0, mp.Count()) + require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll()) + + // Add i1, i2, check OK + mp.Add(i1h, i1) + mp.Add(i2h, i2) + itm, ok := mp.TryGet(i1h) + require.True(t, ok) + require.Equal(t, [][]byte{i1}, itm) + require.True(t, mp.ContainsKey(i1h)) + require.True(t, mp.ContainsKey(i2h)) + require.Equal(t, map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + + // Remove i1 and unexisting item + mp.Remove(i3h) + mp.Remove(i1h) + require.False(t, mp.ContainsKey(i1h)) + require.True(t, mp.ContainsKey(i2h)) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}}, mp.GetAll()) + require.Equal(t, 1, mp.Count()) + + // Update: remove nothing, add all + mp.Update(nil, mapAll) + require.Equal(t, mapAll, mp.GetAll()) + require.Equal(t, 3, mp.Count()) + // Update: remove all, add all + mp.Update(mapAll, mapAll) + require.Equal(t, mapAll, mp.GetAll()) // deletion first, addition after that + require.Equal(t, 3, mp.Count()) + // Update: remove all, add nothing + mp.Update(mapAll, nil) + require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll()) + require.Equal(t, 0, mp.Count()) + // Update: remove several, add several + mp.Update(map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + + // Update: remove nothing, add several with same hashes + mp.Update(nil, map[util.Uint256][][]byte{i3h: {i5, i4}}) // should be sorted by the pool + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i3, i5}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + // Update: remove several with same hashes, add nothing + mp.Update(map[util.Uint256][][]byte{i3h: {i5, i4}}, nil) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + // Update: remove several with same hashes, add several with same hashes + mp.Update(map[util.Uint256][][]byte{i3h: {i5, i3}}, map[util.Uint256][][]byte{i3h: {i5, i4}}) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i5}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) +} + +func TestPool_GetBatch(t *testing.T) { + check := func(t *testing.T, limit int, itemsCount int) { + mp := NewPool() + for i := 0; i < itemsCount; i++ { + mp.Add(random.Uint256(), []byte{0x01}) + } + batch := mp.GetBatch(limit) + if limit < itemsCount { + require.Equal(t, limit, len(batch)) + } else { + require.Equal(t, itemsCount, len(batch)) + } + } + + t.Run("limit less than items count", func(t *testing.T) { + check(t, 5, 6) + }) + t.Run("limit more than items count", func(t *testing.T) { + check(t, 6, 5) + }) + t.Run("items count limit", func(t *testing.T) { + check(t, 5, 5) + }) +} diff --git a/pkg/network/server.go b/pkg/network/server.go index b9f53ed0e..bf57ade5f 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -671,12 +671,23 @@ func (s *Server) requestBlocksOrHeaders(p Peer) error { } return nil } - var bq blockchainer.Blockqueuer = s.chain + var ( + bq blockchainer.Blockqueuer = s.chain + requestMPTNodes bool + ) if s.stateSync.IsActive() { bq = s.stateSync + requestMPTNodes = s.stateSync.NeedMPTNodes() } - if bq.BlockHeight() < p.LastBlockIndex() { - return s.requestBlocks(bq, p) + if bq.BlockHeight() >= p.LastBlockIndex() { + return nil + } + err := s.requestBlocks(bq, p) + if err != nil { + return err + } + if requestMPTNodes { + return s.requestMPTNodes(p, s.stateSync.GetUnknownMPTNodesBatch(payload.MaxMPTHashesCount)) } return nil } @@ -849,6 +860,20 @@ func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error { return s.stateSync.AddMPTNodes(data.Nodes) } +// requestMPTNodes requests specified MPT nodes from the peer or broadcasts +// request if peer is not specified. +func (s *Server) requestMPTNodes(p Peer, itms []util.Uint256) error { + if len(itms) == 0 { + return nil + } + if len(itms) > payload.MaxMPTHashesCount { + itms = itms[:payload.MaxMPTHashesCount] + } + pl := payload.NewMPTInventory(itms) + msg := NewMessage(CMDGetMPTData, pl) + return p.EnqueueP2PMessage(msg) +} + // handleGetBlocksCmd processes the getblocks request. func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { count := gb.Count