mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2025-01-03 19:22:49 +00:00
network: request unknown MPT nodes
In this commit: 1. Request unknown MPT nodes from peers. Note, that StateSync module itself shouldn't be responsible for nodes requests, that's a server duty. 2. Do not request the same node twice, check if it is in storage already. If so, then the only thing remaining is to update refcounter.
This commit is contained in:
parent
6a04880b49
commit
3b7807e897
7 changed files with 203 additions and 22 deletions
|
@ -508,3 +508,8 @@ func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node,
|
|||
func (s *FakeStateSync) GetJumpHeight() (uint32, error) {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// GetUnknownMPTNodesBatch implements StateSync interface.
|
||||
func (s *FakeStateSync) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
|
||||
panic("TODO")
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ type StateSync interface {
|
|||
IsActive() bool
|
||||
IsInitialized() bool
|
||||
GetJumpHeight() (uint32, error)
|
||||
GetUnknownMPTNodesBatch(limit int) []util.Uint256
|
||||
NeedHeaders() bool
|
||||
NeedMPTNodes() bool
|
||||
Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
|
||||
|
|
|
@ -215,7 +215,7 @@ func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) b
|
|||
return curr, nil
|
||||
}
|
||||
if hn, ok := curr.(*HashNode); ok {
|
||||
r, err := b.getFromStore(hn.Hash())
|
||||
r, err := b.GetFromStore(hn.Hash())
|
||||
if err != nil {
|
||||
if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) {
|
||||
return hn, nil
|
||||
|
@ -292,7 +292,8 @@ func (b *Billet) tryCollapseBranch(curr *BranchNode) Node {
|
|||
return res
|
||||
}
|
||||
|
||||
func (b *Billet) getFromStore(h util.Uint256) (Node, error) {
|
||||
// GetFromStore returns MPT node from the storage.
|
||||
func (b *Billet) GetFromStore(h util.Uint256) (Node, error) {
|
||||
data, err := b.Store.Get(makeStorageKey(h.BytesBE()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -328,24 +328,10 @@ func (s *Module) AddMPTNodes(nodes [][]byte) error {
|
|||
if r.Err != nil {
|
||||
return fmt.Errorf("failed to decode MPT node: %w", r.Err)
|
||||
}
|
||||
nPaths, ok := s.mptpool.TryGet(n.Hash())
|
||||
if !ok {
|
||||
// it can easily happen after receiving the same data from different peers.
|
||||
return nil
|
||||
err := s.restoreNode(n.Node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var childrenPaths = make(map[util.Uint256][][]byte)
|
||||
for _, path := range nPaths {
|
||||
err := s.billet.RestoreHashNode(path, n.Node)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
|
||||
}
|
||||
for h, paths := range mpt.GetChildrenPaths(path, n.Node) {
|
||||
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
|
||||
}
|
||||
}
|
||||
|
||||
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
|
||||
}
|
||||
if s.mptpool.Count() == 0 {
|
||||
s.syncStage |= mptSynced
|
||||
|
@ -356,6 +342,37 @@ func (s *Module) AddMPTNodes(nodes [][]byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Module) restoreNode(n mpt.Node) error {
|
||||
nPaths, ok := s.mptpool.TryGet(n.Hash())
|
||||
if !ok {
|
||||
// it can easily happen after receiving the same data from different peers.
|
||||
return nil
|
||||
}
|
||||
var childrenPaths = make(map[util.Uint256][][]byte)
|
||||
for _, path := range nPaths {
|
||||
err := s.billet.RestoreHashNode(path, n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to restore MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
|
||||
}
|
||||
for h, paths := range mpt.GetChildrenPaths(path, n) {
|
||||
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
|
||||
}
|
||||
}
|
||||
|
||||
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
|
||||
|
||||
for h := range childrenPaths {
|
||||
if child, err := s.billet.GetFromStore(h); err == nil {
|
||||
// child is already in the storage, so we don't need to request it one more time.
|
||||
err = s.restoreNode(child)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to restore saved children: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1
|
||||
// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored.
|
||||
// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller
|
||||
|
@ -438,3 +455,11 @@ func (s *Module) GetJumpHeight() (uint32, error) {
|
|||
}
|
||||
return s.syncPoint, nil
|
||||
}
|
||||
|
||||
// GetUnknownMPTNodesBatch returns set of currently unknown MPT nodes (`limit` at max).
|
||||
func (s *Module) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.mptpool.GetBatch(limit)
|
||||
}
|
||||
|
|
|
@ -48,6 +48,26 @@ func (mp *Pool) GetAll() map[util.Uint256][][]byte {
|
|||
return mp.hashes
|
||||
}
|
||||
|
||||
// GetBatch returns set of unknown MPT nodes hashes (`limit` at max).
|
||||
func (mp *Pool) GetBatch(limit int) []util.Uint256 {
|
||||
mp.lock.RLock()
|
||||
defer mp.lock.RUnlock()
|
||||
|
||||
count := len(mp.hashes)
|
||||
if count > limit {
|
||||
count = limit
|
||||
}
|
||||
result := make([]util.Uint256, 0, limit)
|
||||
for h := range mp.hashes {
|
||||
if count == 0 {
|
||||
break
|
||||
}
|
||||
result = append(result, h)
|
||||
count--
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Remove removes MPT node from the pool by the specified hash.
|
||||
func (mp *Pool) Remove(hash util.Uint256) {
|
||||
mp.lock.Lock()
|
||||
|
|
104
pkg/core/statesync/mptpool_test.go
Normal file
104
pkg/core/statesync/mptpool_test.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package statesync
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPool_AddRemoveUpdate(t *testing.T) {
|
||||
mp := NewPool()
|
||||
|
||||
i1 := []byte{1, 2, 3}
|
||||
i1h := util.Uint256{1, 2, 3}
|
||||
i2 := []byte{2, 3, 4}
|
||||
i2h := util.Uint256{2, 3, 4}
|
||||
i3 := []byte{4, 5, 6}
|
||||
i3h := util.Uint256{3, 4, 5}
|
||||
i4 := []byte{3, 4, 5} // has the same hash as i3
|
||||
i5 := []byte{6, 7, 8} // has the same hash as i3
|
||||
mapAll := map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}, i3h: {i4, i3}}
|
||||
|
||||
// No items
|
||||
_, ok := mp.TryGet(i1h)
|
||||
require.False(t, ok)
|
||||
require.False(t, mp.ContainsKey(i1h))
|
||||
require.Equal(t, 0, mp.Count())
|
||||
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
|
||||
|
||||
// Add i1, i2, check OK
|
||||
mp.Add(i1h, i1)
|
||||
mp.Add(i2h, i2)
|
||||
itm, ok := mp.TryGet(i1h)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, [][]byte{i1}, itm)
|
||||
require.True(t, mp.ContainsKey(i1h))
|
||||
require.True(t, mp.ContainsKey(i2h))
|
||||
require.Equal(t, map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, mp.GetAll())
|
||||
require.Equal(t, 2, mp.Count())
|
||||
|
||||
// Remove i1 and unexisting item
|
||||
mp.Remove(i3h)
|
||||
mp.Remove(i1h)
|
||||
require.False(t, mp.ContainsKey(i1h))
|
||||
require.True(t, mp.ContainsKey(i2h))
|
||||
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}}, mp.GetAll())
|
||||
require.Equal(t, 1, mp.Count())
|
||||
|
||||
// Update: remove nothing, add all
|
||||
mp.Update(nil, mapAll)
|
||||
require.Equal(t, mapAll, mp.GetAll())
|
||||
require.Equal(t, 3, mp.Count())
|
||||
// Update: remove all, add all
|
||||
mp.Update(mapAll, mapAll)
|
||||
require.Equal(t, mapAll, mp.GetAll()) // deletion first, addition after that
|
||||
require.Equal(t, 3, mp.Count())
|
||||
// Update: remove all, add nothing
|
||||
mp.Update(mapAll, nil)
|
||||
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
|
||||
require.Equal(t, 0, mp.Count())
|
||||
// Update: remove several, add several
|
||||
mp.Update(map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}})
|
||||
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
|
||||
require.Equal(t, 2, mp.Count())
|
||||
|
||||
// Update: remove nothing, add several with same hashes
|
||||
mp.Update(nil, map[util.Uint256][][]byte{i3h: {i5, i4}}) // should be sorted by the pool
|
||||
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i3, i5}}, mp.GetAll())
|
||||
require.Equal(t, 2, mp.Count())
|
||||
// Update: remove several with same hashes, add nothing
|
||||
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i4}}, nil)
|
||||
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
|
||||
require.Equal(t, 2, mp.Count())
|
||||
// Update: remove several with same hashes, add several with same hashes
|
||||
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i3}}, map[util.Uint256][][]byte{i3h: {i5, i4}})
|
||||
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i5}}, mp.GetAll())
|
||||
require.Equal(t, 2, mp.Count())
|
||||
}
|
||||
|
||||
func TestPool_GetBatch(t *testing.T) {
|
||||
check := func(t *testing.T, limit int, itemsCount int) {
|
||||
mp := NewPool()
|
||||
for i := 0; i < itemsCount; i++ {
|
||||
mp.Add(random.Uint256(), []byte{0x01})
|
||||
}
|
||||
batch := mp.GetBatch(limit)
|
||||
if limit < itemsCount {
|
||||
require.Equal(t, limit, len(batch))
|
||||
} else {
|
||||
require.Equal(t, itemsCount, len(batch))
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("limit less than items count", func(t *testing.T) {
|
||||
check(t, 5, 6)
|
||||
})
|
||||
t.Run("limit more than items count", func(t *testing.T) {
|
||||
check(t, 6, 5)
|
||||
})
|
||||
t.Run("items count limit", func(t *testing.T) {
|
||||
check(t, 5, 5)
|
||||
})
|
||||
}
|
|
@ -671,12 +671,23 @@ func (s *Server) requestBlocksOrHeaders(p Peer) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
var bq blockchainer.Blockqueuer = s.chain
|
||||
var (
|
||||
bq blockchainer.Blockqueuer = s.chain
|
||||
requestMPTNodes bool
|
||||
)
|
||||
if s.stateSync.IsActive() {
|
||||
bq = s.stateSync
|
||||
requestMPTNodes = s.stateSync.NeedMPTNodes()
|
||||
}
|
||||
if bq.BlockHeight() < p.LastBlockIndex() {
|
||||
return s.requestBlocks(bq, p)
|
||||
if bq.BlockHeight() >= p.LastBlockIndex() {
|
||||
return nil
|
||||
}
|
||||
err := s.requestBlocks(bq, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if requestMPTNodes {
|
||||
return s.requestMPTNodes(p, s.stateSync.GetUnknownMPTNodesBatch(payload.MaxMPTHashesCount))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -849,6 +860,20 @@ func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error {
|
|||
return s.stateSync.AddMPTNodes(data.Nodes)
|
||||
}
|
||||
|
||||
// requestMPTNodes requests specified MPT nodes from the peer or broadcasts
|
||||
// request if peer is not specified.
|
||||
func (s *Server) requestMPTNodes(p Peer, itms []util.Uint256) error {
|
||||
if len(itms) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(itms) > payload.MaxMPTHashesCount {
|
||||
itms = itms[:payload.MaxMPTHashesCount]
|
||||
}
|
||||
pl := payload.NewMPTInventory(itms)
|
||||
msg := NewMessage(CMDGetMPTData, pl)
|
||||
return p.EnqueueP2PMessage(msg)
|
||||
}
|
||||
|
||||
// handleGetBlocksCmd processes the getblocks request.
|
||||
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
|
||||
count := gb.Count
|
||||
|
|
Loading…
Reference in a new issue