network: request unknown MPT nodes

In this commit:

1. Request unknown MPT nodes from peers. Note, that StateSync module itself
shouldn't be responsible for nodes requests, that's a server duty.
2. Do not request the same node twice, check if it is in storage
already. If so, then the only thing remaining is to update refcounter.
This commit is contained in:
Anna Shaleva 2021-08-13 12:46:23 +03:00
parent 6a04880b49
commit 3b7807e897
7 changed files with 203 additions and 22 deletions

View file

@ -508,3 +508,8 @@ func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node,
func (s *FakeStateSync) GetJumpHeight() (uint32, error) { func (s *FakeStateSync) GetJumpHeight() (uint32, error) {
panic("TODO") panic("TODO")
} }
// GetUnknownMPTNodesBatch implements StateSync interface.
func (s *FakeStateSync) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
panic("TODO")
}

View file

@ -13,6 +13,7 @@ type StateSync interface {
IsActive() bool IsActive() bool
IsInitialized() bool IsInitialized() bool
GetJumpHeight() (uint32, error) GetJumpHeight() (uint32, error)
GetUnknownMPTNodesBatch(limit int) []util.Uint256
NeedHeaders() bool NeedHeaders() bool
NeedMPTNodes() bool NeedMPTNodes() bool
Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error

View file

@ -215,7 +215,7 @@ func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) b
return curr, nil return curr, nil
} }
if hn, ok := curr.(*HashNode); ok { if hn, ok := curr.(*HashNode); ok {
r, err := b.getFromStore(hn.Hash()) r, err := b.GetFromStore(hn.Hash())
if err != nil { if err != nil {
if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) { if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) {
return hn, nil return hn, nil
@ -292,7 +292,8 @@ func (b *Billet) tryCollapseBranch(curr *BranchNode) Node {
return res return res
} }
func (b *Billet) getFromStore(h util.Uint256) (Node, error) { // GetFromStore returns MPT node from the storage.
func (b *Billet) GetFromStore(h util.Uint256) (Node, error) {
data, err := b.Store.Get(makeStorageKey(h.BytesBE())) data, err := b.Store.Get(makeStorageKey(h.BytesBE()))
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -328,24 +328,10 @@ func (s *Module) AddMPTNodes(nodes [][]byte) error {
if r.Err != nil { if r.Err != nil {
return fmt.Errorf("failed to decode MPT node: %w", r.Err) return fmt.Errorf("failed to decode MPT node: %w", r.Err)
} }
nPaths, ok := s.mptpool.TryGet(n.Hash()) err := s.restoreNode(n.Node)
if !ok {
// it can easily happen after receiving the same data from different peers.
return nil
}
var childrenPaths = make(map[util.Uint256][][]byte)
for _, path := range nPaths {
err := s.billet.RestoreHashNode(path, n.Node)
if err != nil { if err != nil {
return fmt.Errorf("failed to add MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err) return err
} }
for h, paths := range mpt.GetChildrenPaths(path, n.Node) {
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
}
}
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
} }
if s.mptpool.Count() == 0 { if s.mptpool.Count() == 0 {
s.syncStage |= mptSynced s.syncStage |= mptSynced
@ -356,6 +342,37 @@ func (s *Module) AddMPTNodes(nodes [][]byte) error {
return nil return nil
} }
func (s *Module) restoreNode(n mpt.Node) error {
nPaths, ok := s.mptpool.TryGet(n.Hash())
if !ok {
// it can easily happen after receiving the same data from different peers.
return nil
}
var childrenPaths = make(map[util.Uint256][][]byte)
for _, path := range nPaths {
err := s.billet.RestoreHashNode(path, n)
if err != nil {
return fmt.Errorf("failed to restore MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
}
for h, paths := range mpt.GetChildrenPaths(path, n) {
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
}
}
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
for h := range childrenPaths {
if child, err := s.billet.GetFromStore(h); err == nil {
// child is already in the storage, so we don't need to request it one more time.
err = s.restoreNode(child)
if err != nil {
return fmt.Errorf("unable to restore saved children: %w", err)
}
}
}
return nil
}
// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1 // checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1
// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored. // height are fetched, blocks up to P height are stored and MPT nodes for P height are stored.
// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller // If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller
@ -438,3 +455,11 @@ func (s *Module) GetJumpHeight() (uint32, error) {
} }
return s.syncPoint, nil return s.syncPoint, nil
} }
// GetUnknownMPTNodesBatch returns set of currently unknown MPT nodes (`limit` at max).
func (s *Module) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
s.lock.RLock()
defer s.lock.RUnlock()
return s.mptpool.GetBatch(limit)
}

View file

@ -48,6 +48,26 @@ func (mp *Pool) GetAll() map[util.Uint256][][]byte {
return mp.hashes return mp.hashes
} }
// GetBatch returns set of unknown MPT nodes hashes (`limit` at max).
func (mp *Pool) GetBatch(limit int) []util.Uint256 {
mp.lock.RLock()
defer mp.lock.RUnlock()
count := len(mp.hashes)
if count > limit {
count = limit
}
result := make([]util.Uint256, 0, limit)
for h := range mp.hashes {
if count == 0 {
break
}
result = append(result, h)
count--
}
return result
}
// Remove removes MPT node from the pool by the specified hash. // Remove removes MPT node from the pool by the specified hash.
func (mp *Pool) Remove(hash util.Uint256) { func (mp *Pool) Remove(hash util.Uint256) {
mp.lock.Lock() mp.lock.Lock()

View file

@ -0,0 +1,104 @@
package statesync
import (
"testing"
"github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func TestPool_AddRemoveUpdate(t *testing.T) {
mp := NewPool()
i1 := []byte{1, 2, 3}
i1h := util.Uint256{1, 2, 3}
i2 := []byte{2, 3, 4}
i2h := util.Uint256{2, 3, 4}
i3 := []byte{4, 5, 6}
i3h := util.Uint256{3, 4, 5}
i4 := []byte{3, 4, 5} // has the same hash as i3
i5 := []byte{6, 7, 8} // has the same hash as i3
mapAll := map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}, i3h: {i4, i3}}
// No items
_, ok := mp.TryGet(i1h)
require.False(t, ok)
require.False(t, mp.ContainsKey(i1h))
require.Equal(t, 0, mp.Count())
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
// Add i1, i2, check OK
mp.Add(i1h, i1)
mp.Add(i2h, i2)
itm, ok := mp.TryGet(i1h)
require.True(t, ok)
require.Equal(t, [][]byte{i1}, itm)
require.True(t, mp.ContainsKey(i1h))
require.True(t, mp.ContainsKey(i2h))
require.Equal(t, map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Remove i1 and unexisting item
mp.Remove(i3h)
mp.Remove(i1h)
require.False(t, mp.ContainsKey(i1h))
require.True(t, mp.ContainsKey(i2h))
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}}, mp.GetAll())
require.Equal(t, 1, mp.Count())
// Update: remove nothing, add all
mp.Update(nil, mapAll)
require.Equal(t, mapAll, mp.GetAll())
require.Equal(t, 3, mp.Count())
// Update: remove all, add all
mp.Update(mapAll, mapAll)
require.Equal(t, mapAll, mp.GetAll()) // deletion first, addition after that
require.Equal(t, 3, mp.Count())
// Update: remove all, add nothing
mp.Update(mapAll, nil)
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
require.Equal(t, 0, mp.Count())
// Update: remove several, add several
mp.Update(map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}})
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Update: remove nothing, add several with same hashes
mp.Update(nil, map[util.Uint256][][]byte{i3h: {i5, i4}}) // should be sorted by the pool
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i3, i5}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Update: remove several with same hashes, add nothing
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i4}}, nil)
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Update: remove several with same hashes, add several with same hashes
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i3}}, map[util.Uint256][][]byte{i3h: {i5, i4}})
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i5}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
}
func TestPool_GetBatch(t *testing.T) {
check := func(t *testing.T, limit int, itemsCount int) {
mp := NewPool()
for i := 0; i < itemsCount; i++ {
mp.Add(random.Uint256(), []byte{0x01})
}
batch := mp.GetBatch(limit)
if limit < itemsCount {
require.Equal(t, limit, len(batch))
} else {
require.Equal(t, itemsCount, len(batch))
}
}
t.Run("limit less than items count", func(t *testing.T) {
check(t, 5, 6)
})
t.Run("limit more than items count", func(t *testing.T) {
check(t, 6, 5)
})
t.Run("items count limit", func(t *testing.T) {
check(t, 5, 5)
})
}

View file

@ -671,12 +671,23 @@ func (s *Server) requestBlocksOrHeaders(p Peer) error {
} }
return nil return nil
} }
var bq blockchainer.Blockqueuer = s.chain var (
bq blockchainer.Blockqueuer = s.chain
requestMPTNodes bool
)
if s.stateSync.IsActive() { if s.stateSync.IsActive() {
bq = s.stateSync bq = s.stateSync
requestMPTNodes = s.stateSync.NeedMPTNodes()
} }
if bq.BlockHeight() < p.LastBlockIndex() { if bq.BlockHeight() >= p.LastBlockIndex() {
return s.requestBlocks(bq, p) return nil
}
err := s.requestBlocks(bq, p)
if err != nil {
return err
}
if requestMPTNodes {
return s.requestMPTNodes(p, s.stateSync.GetUnknownMPTNodesBatch(payload.MaxMPTHashesCount))
} }
return nil return nil
} }
@ -849,6 +860,20 @@ func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error {
return s.stateSync.AddMPTNodes(data.Nodes) return s.stateSync.AddMPTNodes(data.Nodes)
} }
// requestMPTNodes requests specified MPT nodes from the peer or broadcasts
// request if peer is not specified.
func (s *Server) requestMPTNodes(p Peer, itms []util.Uint256) error {
if len(itms) == 0 {
return nil
}
if len(itms) > payload.MaxMPTHashesCount {
itms = itms[:payload.MaxMPTHashesCount]
}
pl := payload.NewMPTInventory(itms)
msg := NewMessage(CMDGetMPTData, pl)
return p.EnqueueP2PMessage(msg)
}
// handleGetBlocksCmd processes the getblocks request. // handleGetBlocksCmd processes the getblocks request.
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
count := gb.Count count := gb.Count