network: request unknown MPT nodes

In this commit:

1. Request unknown MPT nodes from peers. Note, that StateSync module itself
shouldn't be responsible for nodes requests, that's a server duty.
2. Do not request the same node twice, check if it is in storage
already. If so, then the only thing remaining is to update refcounter.
This commit is contained in:
Anna Shaleva 2021-08-13 12:46:23 +03:00
parent 6a04880b49
commit 3b7807e897
7 changed files with 203 additions and 22 deletions

View file

@ -508,3 +508,8 @@ func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node,
func (s *FakeStateSync) GetJumpHeight() (uint32, error) {
panic("TODO")
}
// GetUnknownMPTNodesBatch implements StateSync interface.
func (s *FakeStateSync) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
panic("TODO")
}

View file

@ -13,6 +13,7 @@ type StateSync interface {
IsActive() bool
IsInitialized() bool
GetJumpHeight() (uint32, error)
GetUnknownMPTNodesBatch(limit int) []util.Uint256
NeedHeaders() bool
NeedMPTNodes() bool
Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error

View file

@ -215,7 +215,7 @@ func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) b
return curr, nil
}
if hn, ok := curr.(*HashNode); ok {
r, err := b.getFromStore(hn.Hash())
r, err := b.GetFromStore(hn.Hash())
if err != nil {
if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) {
return hn, nil
@ -292,7 +292,8 @@ func (b *Billet) tryCollapseBranch(curr *BranchNode) Node {
return res
}
func (b *Billet) getFromStore(h util.Uint256) (Node, error) {
// GetFromStore returns MPT node from the storage.
func (b *Billet) GetFromStore(h util.Uint256) (Node, error) {
data, err := b.Store.Get(makeStorageKey(h.BytesBE()))
if err != nil {
return nil, err

View file

@ -328,24 +328,10 @@ func (s *Module) AddMPTNodes(nodes [][]byte) error {
if r.Err != nil {
return fmt.Errorf("failed to decode MPT node: %w", r.Err)
}
nPaths, ok := s.mptpool.TryGet(n.Hash())
if !ok {
// it can easily happen after receiving the same data from different peers.
return nil
err := s.restoreNode(n.Node)
if err != nil {
return err
}
var childrenPaths = make(map[util.Uint256][][]byte)
for _, path := range nPaths {
err := s.billet.RestoreHashNode(path, n.Node)
if err != nil {
return fmt.Errorf("failed to add MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
}
for h, paths := range mpt.GetChildrenPaths(path, n.Node) {
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
}
}
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
}
if s.mptpool.Count() == 0 {
s.syncStage |= mptSynced
@ -356,6 +342,37 @@ func (s *Module) AddMPTNodes(nodes [][]byte) error {
return nil
}
func (s *Module) restoreNode(n mpt.Node) error {
nPaths, ok := s.mptpool.TryGet(n.Hash())
if !ok {
// it can easily happen after receiving the same data from different peers.
return nil
}
var childrenPaths = make(map[util.Uint256][][]byte)
for _, path := range nPaths {
err := s.billet.RestoreHashNode(path, n)
if err != nil {
return fmt.Errorf("failed to restore MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
}
for h, paths := range mpt.GetChildrenPaths(path, n) {
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
}
}
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
for h := range childrenPaths {
if child, err := s.billet.GetFromStore(h); err == nil {
// child is already in the storage, so we don't need to request it one more time.
err = s.restoreNode(child)
if err != nil {
return fmt.Errorf("unable to restore saved children: %w", err)
}
}
}
return nil
}
// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1
// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored.
// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller
@ -438,3 +455,11 @@ func (s *Module) GetJumpHeight() (uint32, error) {
}
return s.syncPoint, nil
}
// GetUnknownMPTNodesBatch returns set of currently unknown MPT nodes (`limit` at max).
func (s *Module) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
s.lock.RLock()
defer s.lock.RUnlock()
return s.mptpool.GetBatch(limit)
}

View file

@ -48,6 +48,26 @@ func (mp *Pool) GetAll() map[util.Uint256][][]byte {
return mp.hashes
}
// GetBatch returns set of unknown MPT nodes hashes (`limit` at max).
func (mp *Pool) GetBatch(limit int) []util.Uint256 {
mp.lock.RLock()
defer mp.lock.RUnlock()
count := len(mp.hashes)
if count > limit {
count = limit
}
result := make([]util.Uint256, 0, limit)
for h := range mp.hashes {
if count == 0 {
break
}
result = append(result, h)
count--
}
return result
}
// Remove removes MPT node from the pool by the specified hash.
func (mp *Pool) Remove(hash util.Uint256) {
mp.lock.Lock()

View file

@ -0,0 +1,104 @@
package statesync
import (
"testing"
"github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func TestPool_AddRemoveUpdate(t *testing.T) {
mp := NewPool()
i1 := []byte{1, 2, 3}
i1h := util.Uint256{1, 2, 3}
i2 := []byte{2, 3, 4}
i2h := util.Uint256{2, 3, 4}
i3 := []byte{4, 5, 6}
i3h := util.Uint256{3, 4, 5}
i4 := []byte{3, 4, 5} // has the same hash as i3
i5 := []byte{6, 7, 8} // has the same hash as i3
mapAll := map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}, i3h: {i4, i3}}
// No items
_, ok := mp.TryGet(i1h)
require.False(t, ok)
require.False(t, mp.ContainsKey(i1h))
require.Equal(t, 0, mp.Count())
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
// Add i1, i2, check OK
mp.Add(i1h, i1)
mp.Add(i2h, i2)
itm, ok := mp.TryGet(i1h)
require.True(t, ok)
require.Equal(t, [][]byte{i1}, itm)
require.True(t, mp.ContainsKey(i1h))
require.True(t, mp.ContainsKey(i2h))
require.Equal(t, map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Remove i1 and unexisting item
mp.Remove(i3h)
mp.Remove(i1h)
require.False(t, mp.ContainsKey(i1h))
require.True(t, mp.ContainsKey(i2h))
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}}, mp.GetAll())
require.Equal(t, 1, mp.Count())
// Update: remove nothing, add all
mp.Update(nil, mapAll)
require.Equal(t, mapAll, mp.GetAll())
require.Equal(t, 3, mp.Count())
// Update: remove all, add all
mp.Update(mapAll, mapAll)
require.Equal(t, mapAll, mp.GetAll()) // deletion first, addition after that
require.Equal(t, 3, mp.Count())
// Update: remove all, add nothing
mp.Update(mapAll, nil)
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
require.Equal(t, 0, mp.Count())
// Update: remove several, add several
mp.Update(map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}})
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Update: remove nothing, add several with same hashes
mp.Update(nil, map[util.Uint256][][]byte{i3h: {i5, i4}}) // should be sorted by the pool
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i3, i5}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Update: remove several with same hashes, add nothing
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i4}}, nil)
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Update: remove several with same hashes, add several with same hashes
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i3}}, map[util.Uint256][][]byte{i3h: {i5, i4}})
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i5}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
}
func TestPool_GetBatch(t *testing.T) {
check := func(t *testing.T, limit int, itemsCount int) {
mp := NewPool()
for i := 0; i < itemsCount; i++ {
mp.Add(random.Uint256(), []byte{0x01})
}
batch := mp.GetBatch(limit)
if limit < itemsCount {
require.Equal(t, limit, len(batch))
} else {
require.Equal(t, itemsCount, len(batch))
}
}
t.Run("limit less than items count", func(t *testing.T) {
check(t, 5, 6)
})
t.Run("limit more than items count", func(t *testing.T) {
check(t, 6, 5)
})
t.Run("items count limit", func(t *testing.T) {
check(t, 5, 5)
})
}

View file

@ -671,12 +671,23 @@ func (s *Server) requestBlocksOrHeaders(p Peer) error {
}
return nil
}
var bq blockchainer.Blockqueuer = s.chain
var (
bq blockchainer.Blockqueuer = s.chain
requestMPTNodes bool
)
if s.stateSync.IsActive() {
bq = s.stateSync
requestMPTNodes = s.stateSync.NeedMPTNodes()
}
if bq.BlockHeight() < p.LastBlockIndex() {
return s.requestBlocks(bq, p)
if bq.BlockHeight() >= p.LastBlockIndex() {
return nil
}
err := s.requestBlocks(bq, p)
if err != nil {
return err
}
if requestMPTNodes {
return s.requestMPTNodes(p, s.stateSync.GetUnknownMPTNodesBatch(payload.MaxMPTHashesCount))
}
return nil
}
@ -849,6 +860,20 @@ func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error {
return s.stateSync.AddMPTNodes(data.Nodes)
}
// requestMPTNodes requests specified MPT nodes from the peer or broadcasts
// request if peer is not specified.
func (s *Server) requestMPTNodes(p Peer, itms []util.Uint256) error {
if len(itms) == 0 {
return nil
}
if len(itms) > payload.MaxMPTHashesCount {
itms = itms[:payload.MaxMPTHashesCount]
}
pl := payload.NewMPTInventory(itms)
msg := NewMessage(CMDGetMPTData, pl)
return p.EnqueueP2PMessage(msg)
}
// handleGetBlocksCmd processes the getblocks request.
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
count := gb.Count