core: implement statesync module

And support GetMPTData and MPTData P2P commands.
This commit is contained in:
Anna Shaleva 2021-07-30 16:57:42 +03:00
parent a22b1caa3e
commit d67ff30704
24 changed files with 1197 additions and 32 deletions

View file

@ -13,6 +13,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services"
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/mempool"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/native" "github.com/nspcc-dev/neo-go/pkg/core/native"
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
@ -42,6 +43,13 @@ type FakeChain struct {
UtilityTokenBalance *big.Int UtilityTokenBalance *big.Int
} }
// FakeStateSync implements StateSync interface.
type FakeStateSync struct {
IsActiveFlag bool
IsInitializedFlag bool
InitFunc func(h uint32) error
}
// NewFakeChain returns new FakeChain structure. // NewFakeChain returns new FakeChain structure.
func NewFakeChain() *FakeChain { func NewFakeChain() *FakeChain {
return &FakeChain{ return &FakeChain{
@ -294,6 +302,16 @@ func (chain *FakeChain) GetStateModule() blockchainer.StateRoot {
return nil return nil
} }
// GetStateSyncModule implements Blockchainer interface.
func (chain *FakeChain) GetStateSyncModule() blockchainer.StateSync {
return &FakeStateSync{}
}
// JumpToState implements Blockchainer interface.
func (chain *FakeChain) JumpToState(module blockchainer.StateSync) error {
panic("TODO")
}
// GetStorageItem implements Blockchainer interface. // GetStorageItem implements Blockchainer interface.
func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem { func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
panic("TODO") panic("TODO")
@ -436,3 +454,57 @@ func (chain *FakeChain) UnsubscribeFromNotifications(ch chan<- *state.Notificati
func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) { func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
panic("TODO") panic("TODO")
} }
// AddBlock implements StateSync interface.
func (s *FakeStateSync) AddBlock(block *block.Block) error {
panic("TODO")
}
// AddHeaders implements StateSync interface.
func (s *FakeStateSync) AddHeaders(...*block.Header) error {
panic("TODO")
}
// AddMPTNodes implements StateSync interface.
func (s *FakeStateSync) AddMPTNodes([][]byte) error {
panic("TODO")
}
// BlockHeight implements StateSync interface.
func (s *FakeStateSync) BlockHeight() uint32 {
panic("TODO")
}
// IsActive implements StateSync interface.
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag }
// IsInitialized implements StateSync interface.
func (s *FakeStateSync) IsInitialized() bool {
return s.IsInitializedFlag
}
// Init implements StateSync interface.
func (s *FakeStateSync) Init(currChainHeight uint32) error {
if s.InitFunc != nil {
return s.InitFunc(currChainHeight)
}
panic("TODO")
}
// NeedHeaders implements StateSync interface.
func (s *FakeStateSync) NeedHeaders() bool { return false }
// NeedMPTNodes implements StateSync interface.
func (s *FakeStateSync) NeedMPTNodes() bool {
panic("TODO")
}
// Traverse implements StateSync interface.
func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
panic("TODO")
}
// GetJumpHeight implements StateSync interface.
func (s *FakeStateSync) GetJumpHeight() (uint32, error) {
panic("TODO")
}

View file

@ -23,6 +23,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/native/noderoles" "github.com/nspcc-dev/neo-go/pkg/core/native/noderoles"
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/stateroot" "github.com/nspcc-dev/neo-go/pkg/core/stateroot"
"github.com/nspcc-dev/neo-go/pkg/core/statesync"
"github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
@ -409,6 +410,64 @@ func (bc *Blockchain) init() error {
return bc.updateExtensibleWhitelist(bHeight) return bc.updateExtensibleWhitelist(bHeight)
} }
// JumpToState is an atomic operation that changes Blockchain state to the one
// specified by the state sync point p. All the data needed for the jump must be
// collected by the state sync module.
func (bc *Blockchain) JumpToState(module blockchainer.StateSync) error {
bc.lock.Lock()
defer bc.lock.Unlock()
p, err := module.GetJumpHeight()
if err != nil {
return fmt.Errorf("failed to get jump height: %w", err)
}
if p+1 >= uint32(len(bc.headerHashes)) {
return fmt.Errorf("invalid state sync point")
}
bc.log.Info("jumping to state sync point", zap.Uint32("state sync point", p))
block, err := bc.dao.GetBlock(bc.headerHashes[p])
if err != nil {
return fmt.Errorf("failed to get current block: %w", err)
}
err = bc.dao.StoreAsCurrentBlock(block, nil)
if err != nil {
return fmt.Errorf("failed to store current block: %w", err)
}
bc.topBlock.Store(block)
atomic.StoreUint32(&bc.blockHeight, p)
atomic.StoreUint32(&bc.persistedHeight, p)
block, err = bc.dao.GetBlock(bc.headerHashes[p+1])
if err != nil {
return fmt.Errorf("failed to get block to init MPT: %w", err)
}
if err = bc.stateRoot.JumpToState(&state.MPTRoot{
Index: p,
Root: block.PrevStateRoot,
}, bc.config.KeepOnlyLatestState); err != nil {
return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err)
}
err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
if err != nil {
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
}
err = bc.contracts.Management.InitializeCache(bc.dao)
if err != nil {
return fmt.Errorf("can't init cache for Management native contract: %w", err)
}
bc.contracts.Designate.InitializeCache()
if err := bc.updateExtensibleWhitelist(p); err != nil {
return fmt.Errorf("failed to update extensible whitelist: %w", err)
}
updateBlockHeightMetric(p)
return nil
}
// Run runs chain loop, it needs to be run as goroutine and executing it is // Run runs chain loop, it needs to be run as goroutine and executing it is
// critical for correct Blockchain operation. // critical for correct Blockchain operation.
func (bc *Blockchain) Run() { func (bc *Blockchain) Run() {
@ -696,6 +755,11 @@ func (bc *Blockchain) GetStateModule() blockchainer.StateRoot {
return bc.stateRoot return bc.stateRoot
} }
// GetStateSyncModule returns new state sync service instance.
func (bc *Blockchain) GetStateSyncModule() blockchainer.StateSync {
return statesync.NewModule(bc, bc.log, bc.dao)
}
// storeBlock performs chain update using the block given, it executes all // storeBlock performs chain update using the block given, it executes all
// transactions with all appropriate side-effects and updates Blockchain state. // transactions with all appropriate side-effects and updates Blockchain state.
// This is the only way to change Blockchain state. // This is the only way to change Blockchain state.

View file

@ -21,7 +21,6 @@ import (
type Blockchainer interface { type Blockchainer interface {
ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction
GetConfig() config.ProtocolConfiguration GetConfig() config.ProtocolConfiguration
AddHeaders(...*block.Header) error
Blockqueuer // Blockqueuer interface Blockqueuer // Blockqueuer interface
CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error) CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error)
Close() Close()
@ -56,10 +55,12 @@ type Blockchainer interface {
GetStandByCommittee() keys.PublicKeys GetStandByCommittee() keys.PublicKeys
GetStandByValidators() keys.PublicKeys GetStandByValidators() keys.PublicKeys
GetStateModule() StateRoot GetStateModule() StateRoot
GetStateSyncModule() StateSync
GetStorageItem(id int32, key []byte) state.StorageItem GetStorageItem(id int32, key []byte) state.StorageItem
GetStorageItems(id int32) (map[string]state.StorageItem, error) GetStorageItems(id int32) (map[string]state.StorageItem, error)
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM
GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error)
JumpToState(module StateSync) error
SetOracle(service services.Oracle) SetOracle(service services.Oracle)
mempool.Feer // fee interface mempool.Feer // fee interface
ManagementContractHash() util.Uint160 ManagementContractHash() util.Uint160

View file

@ -5,5 +5,6 @@ import "github.com/nspcc-dev/neo-go/pkg/core/block"
// Blockqueuer is an interface for blockqueue. // Blockqueuer is an interface for blockqueue.
type Blockqueuer interface { type Blockqueuer interface {
AddBlock(block *block.Block) error AddBlock(block *block.Block) error
AddHeaders(...*block.Header) error
BlockHeight() uint32 BlockHeight() uint32
} }

View file

@ -9,6 +9,7 @@ import (
// StateRoot represents local state root module. // StateRoot represents local state root module.
type StateRoot interface { type StateRoot interface {
AddStateRoot(root *state.MPTRoot) error AddStateRoot(root *state.MPTRoot) error
CurrentLocalHeight() uint32
CurrentLocalStateRoot() util.Uint256 CurrentLocalStateRoot() util.Uint256
CurrentValidatedHeight() uint32 CurrentValidatedHeight() uint32
GetStateProof(root util.Uint256, key []byte) ([][]byte, error) GetStateProof(root util.Uint256, key []byte) ([][]byte, error)

View file

@ -0,0 +1,19 @@
package blockchainer
import (
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// StateSync represents state sync module.
type StateSync interface {
AddMPTNodes([][]byte) error
Blockqueuer // Blockqueuer interface
Init(currChainHeight uint32) error
IsActive() bool
IsInitialized() bool
GetJumpHeight() (uint32, error)
NeedHeaders() bool
NeedMPTNodes() bool
Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
}

View file

@ -1,5 +1,7 @@
package mpt package mpt
import "github.com/nspcc-dev/neo-go/pkg/util"
// lcp returns longest common prefix of a and b. // lcp returns longest common prefix of a and b.
// Note: it does no allocations. // Note: it does no allocations.
func lcp(a, b []byte) []byte { func lcp(a, b []byte) []byte {
@ -49,3 +51,36 @@ func fromNibbles(path []byte) []byte {
} }
return result return result
} }
// GetChildrenPaths returns a set of paths to node's children who are non-empty HashNodes
// based on the node's path.
func GetChildrenPaths(path []byte, node Node) map[util.Uint256][][]byte {
res := make(map[util.Uint256][][]byte)
switch n := node.(type) {
case *LeafNode, *HashNode, EmptyNode:
return nil
case *BranchNode:
for i, child := range n.Children {
if child.Type() == HashT {
cPath := make([]byte, len(path), len(path)+1)
copy(cPath, path)
if i != lastChild {
cPath = append(cPath, byte(i))
}
paths := res[child.Hash()]
paths = append(paths, cPath)
res[child.Hash()] = paths
}
}
case *ExtensionNode:
if n.next.Type() == HashT {
cPath := make([]byte, len(path)+len(n.key))
copy(cPath, path)
copy(cPath[len(path):], n.key)
res[n.next.Hash()] = [][]byte{cPath}
}
default:
panic("unknown Node type")
}
return res
}

View file

@ -3,6 +3,7 @@ package mpt
import ( import (
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -18,3 +19,49 @@ func TestToNibblesFromNibbles(t *testing.T) {
check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF}) check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF})
}) })
} }
func TestGetChildrenPaths(t *testing.T) {
h1 := NewHashNode(util.Uint256{1, 2, 3})
h2 := NewHashNode(util.Uint256{4, 5, 6})
h3 := NewHashNode(util.Uint256{7, 8, 9})
l := NewLeafNode([]byte{1, 2, 3})
ext1 := NewExtensionNode([]byte{8, 9}, h1)
ext2 := NewExtensionNode([]byte{7, 6}, l)
branch := NewBranchNode()
branch.Children[3] = h1
branch.Children[5] = l
branch.Children[6] = h1 // 3-th and 6-th children have the same hash
branch.Children[7] = h3
branch.Children[lastChild] = h2
testCases := map[string]struct {
node Node
expected map[util.Uint256][][]byte
}{
"Hash": {h1, nil},
"Leaf": {l, nil},
"Extension with next Hash": {ext1, map[util.Uint256][][]byte{h1.Hash(): {ext1.key}}},
"Extension with next non-Hash": {ext2, map[util.Uint256][][]byte{}},
"Branch": {branch, map[util.Uint256][][]byte{
h1.Hash(): {{0x03}, {0x06}},
h2.Hash(): {{}},
h3.Hash(): {{0x07}},
}},
}
parentPath := []byte{4, 5, 6}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
require.Equal(t, testCase.expected, GetChildrenPaths([]byte{}, testCase.node))
if testCase.expected != nil {
expectedWithPrefix := make(map[util.Uint256][][]byte, len(testCase.expected))
for h, paths := range testCase.expected {
var res [][]byte
for _, path := range paths {
res = append(res, append(parentPath, path...))
}
expectedWithPrefix[h] = res
}
require.Equal(t, expectedWithPrefix, GetChildrenPaths(parentPath, testCase.node))
}
})
}
}

View file

@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func newProofTrie(t *testing.T) *Trie { func newProofTrie(t *testing.T, missingHashNode bool) *Trie {
l := NewLeafNode([]byte("somevalue")) l := NewLeafNode([]byte("somevalue"))
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l) e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
l2 := NewLeafNode([]byte("invalid")) l2 := NewLeafNode([]byte("invalid"))
@ -20,11 +20,14 @@ func newProofTrie(t *testing.T) *Trie {
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2"))) require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
tr.putToStore(l) tr.putToStore(l)
tr.putToStore(e) tr.putToStore(e)
if !missingHashNode {
tr.putToStore(l2)
}
return tr return tr
} }
func TestTrie_GetProof(t *testing.T) { func TestTrie_GetProof(t *testing.T) {
tr := newProofTrie(t) tr := newProofTrie(t, true)
t.Run("MissingKey", func(t *testing.T) { t.Run("MissingKey", func(t *testing.T) {
_, err := tr.GetProof([]byte{0x12}) _, err := tr.GetProof([]byte{0x12})
@ -43,7 +46,7 @@ func TestTrie_GetProof(t *testing.T) {
} }
func TestVerifyProof(t *testing.T) { func TestVerifyProof(t *testing.T) {
tr := newProofTrie(t) tr := newProofTrie(t, true)
t.Run("Simple", func(t *testing.T) { t.Run("Simple", func(t *testing.T) {
proof, err := tr.GetProof([]byte{0x12, 0x32}) proof, err := tr.GetProof([]byte{0x12, 0x32})

View file

@ -353,3 +353,8 @@ func (s *Designate) getRole(item stackitem.Item) (noderoles.Role, bool) {
u := bi.Uint64() u := bi.Uint64()
return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u)) return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u))
} }
// InitializeCache invalidates native Designate cache.
func (s *Designate) InitializeCache() {
s.rolesChangedFlag.Store(true)
}

View file

@ -114,6 +114,25 @@ func (s *Module) Init(height uint32, enableRefCount bool) error {
return nil return nil
} }
// JumpToState performs jump to the state specified by given stateroot index.
func (s *Module) JumpToState(sr *state.MPTRoot, enableRefCount bool) error {
if err := s.addLocalStateRoot(s.Store, sr); err != nil {
return fmt.Errorf("failed to store local state root: %w", err)
}
data := make([]byte, 4)
binary.LittleEndian.PutUint32(data, sr.Index)
if err := s.Store.Put([]byte{byte(storage.DataMPT), prefixValidated}, data); err != nil {
return fmt.Errorf("failed to store validated height: %w", err)
}
s.validatedHeight.Store(sr.Index)
s.currentLocal.Store(sr.Root)
s.localHeight.Store(sr.Index)
s.mpt = mpt.NewTrie(mpt.NewHashNode(sr.Root), enableRefCount, s.Store)
return nil
}
// AddMPTBatch updates using provided batch. // AddMPTBatch updates using provided batch.
func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) { func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) {
mpt := *s.mpt mpt := *s.mpt

View file

@ -0,0 +1,440 @@
/*
Package statesync implements module for the P2P state synchronisation process. The
module manages state synchronisation for non-archival nodes which are joining the
network and don't have the ability to resync from the genesis block.
Given the currently available state synchronisation point P, sate sync process
includes the following stages:
1. Fetching headers starting from height 0 up to P+1.
2. Fetching MPT nodes for height P stating from the corresponding state root.
3. Fetching blocks starting from height P-MaxTraceableBlocks (or 0) up to P.
Steps 2 and 3 are being performed in parallel. Once all the data are collected
and stored in the db, an atomic state jump is occurred to the state sync point P.
Further node operation process is performed using standard sync mechanism until
the node reaches synchronised state.
*/
package statesync
import (
"encoding/hex"
"errors"
"fmt"
"sync"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
"go.uber.org/zap"
)
// stateSyncStage is a type of state synchronisation stage.
type stateSyncStage uint8
const (
// inactive means that state exchange is disabled by the protocol configuration.
// Can't be combined with other states.
inactive stateSyncStage = 1 << iota
// none means that state exchange is enabled in the configuration, but
// initialisation of the state sync module wasn't yet performed, i.e.
// (*Module).Init wasn't called. Can't be combined with other states.
none
// initialized means that (*Module).Init was called, but other sync stages
// are not yet reached (i.e. that headers are requested, but not yet fetched).
// Can't be combined with other states.
initialized
// headersSynced means that headers for the current state sync point are fetched.
// May be combined with mptSynced and/or blocksSynced.
headersSynced
// mptSynced means that MPT nodes for the current state sync point are fetched.
// Always combined with headersSynced; may be combined with blocksSynced.
mptSynced
// blocksSynced means that blocks up to the current state sync point are stored.
// Always combined with headersSynced; may be combined with mptSynced.
blocksSynced
)
// Module represents state sync module and aimed to gather state-related data to
// perform an atomic state jump.
type Module struct {
lock sync.RWMutex
log *zap.Logger
// syncPoint is the state synchronisation point P we're currently working against.
syncPoint uint32
// syncStage is the stage of the sync process.
syncStage stateSyncStage
// syncInterval is the delta between two adjacent state sync points.
syncInterval uint32
// blockHeight is the index of the latest stored block.
blockHeight uint32
dao *dao.Simple
bc blockchainer.Blockchainer
mptpool *Pool
billet *mpt.Billet
}
// NewModule returns new instance of statesync module.
func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple) *Module {
if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) {
return &Module{
dao: s,
bc: bc,
syncStage: inactive,
}
}
return &Module{
dao: s,
bc: bc,
log: log,
syncInterval: uint32(bc.GetConfig().StateSyncInterval),
mptpool: NewPool(),
syncStage: none,
}
}
// Init initializes state sync module for the current chain's height with given
// callback for MPT nodes requests.
func (s *Module) Init(currChainHeight uint32) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.syncStage != none {
return errors.New("already initialized or inactive")
}
p := (currChainHeight / s.syncInterval) * s.syncInterval
if p < 2*s.syncInterval {
// chain is too low to start state exchange process, use the standard sync mechanism
s.syncStage = inactive
return nil
}
pOld, err := s.dao.GetStateSyncPoint()
if err == nil && pOld >= p-s.syncInterval {
// old point is still valid, so try to resync states for this point.
p = pOld
} else if s.bc.BlockHeight() > p-2*s.syncInterval {
// chain has already been synchronised up to old state sync point and regular blocks processing was started
s.syncStage = inactive
return nil
}
s.syncPoint = p
err = s.dao.PutStateSyncPoint(p)
if err != nil {
return fmt.Errorf("failed to store state synchronisation point %d: %w", p, err)
}
s.syncStage = initialized
s.log.Info("try to sync state for the latest state synchronisation point",
zap.Uint32("point", p),
zap.Uint32("evaluated chain's blockHeight", currChainHeight))
// check headers sync state first
ltstHeaderHeight := s.bc.HeaderHeight()
if ltstHeaderHeight > p {
s.syncStage = headersSynced
s.log.Info("headers are in sync",
zap.Uint32("headerHeight", s.bc.HeaderHeight()))
}
// check blocks sync state
s.blockHeight = s.getLatestSavedBlock(p)
if s.blockHeight >= p {
s.syncStage |= blocksSynced
s.log.Info("blocks are in sync",
zap.Uint32("blockHeight", s.blockHeight))
}
// check MPT sync state
if s.blockHeight > p {
s.syncStage |= mptSynced
s.log.Info("MPT is in sync",
zap.Uint32("stateroot height", s.bc.GetStateModule().CurrentLocalHeight()))
} else if s.syncStage&headersSynced != 0 {
header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(p + 1)))
if err != nil {
return fmt.Errorf("failed to get header to initialize MPT billet: %w", err)
}
s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store)
s.log.Info("MPT billet initialized",
zap.Uint32("height", s.syncPoint),
zap.String("state root", header.PrevStateRoot.StringBE()))
pool := NewPool()
pool.Add(header.PrevStateRoot, []byte{})
err = s.billet.Traverse(func(n mpt.Node, _ []byte) bool {
nPaths, ok := pool.TryGet(n.Hash())
if !ok {
// if this situation occurs, then it's a bug in MPT pool or Traverse.
panic("failed to get MPT node from the pool")
}
pool.Remove(n.Hash())
childrenPaths := make(map[util.Uint256][][]byte)
for _, path := range nPaths {
nChildrenPaths := mpt.GetChildrenPaths(path, n)
for hash, paths := range nChildrenPaths {
childrenPaths[hash] = append(childrenPaths[hash], paths...) // it's OK to have duplicates, they'll be handled by mempool
}
}
pool.Update(nil, childrenPaths)
return false
}, true)
if err != nil {
return fmt.Errorf("failed to traverse MPT while initialization: %w", err)
}
s.mptpool.Update(nil, pool.GetAll())
if s.mptpool.Count() == 0 {
s.syncStage |= mptSynced
s.log.Info("MPT is in sync",
zap.Uint32("stateroot height", p))
}
}
if s.syncStage == headersSynced|blocksSynced|mptSynced {
s.log.Info("state is in sync, starting regular blocks processing")
s.syncStage = inactive
}
return nil
}
// getLatestSavedBlock returns either current block index (if it's still relevant
// to continue state sync process) or H-1 where H is the index of the earliest
// block that should be saved next.
func (s *Module) getLatestSavedBlock(p uint32) uint32 {
var result uint32
mtb := s.bc.GetConfig().MaxTraceableBlocks
if p > mtb {
result = p - mtb
}
storedH, err := s.dao.GetStateSyncCurrentBlockHeight()
if err == nil && storedH > result {
result = storedH
}
actualH := s.bc.BlockHeight()
if actualH > result {
result = actualH
}
return result
}
// AddHeaders validates and adds specified headers to the chain.
func (s *Module) AddHeaders(hdrs ...*block.Header) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.syncStage != initialized {
return errors.New("headers were not requested")
}
hdrsErr := s.bc.AddHeaders(hdrs...)
if s.bc.HeaderHeight() > s.syncPoint {
s.syncStage = headersSynced
s.log.Info("headers for state sync are fetched",
zap.Uint32("header height", s.bc.HeaderHeight()))
header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint) + 1))
if err != nil {
s.log.Fatal("failed to get header to initialize MPT billet",
zap.Uint32("height", s.syncPoint+1),
zap.Error(err))
}
s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store)
s.mptpool.Add(header.PrevStateRoot, []byte{})
s.log.Info("MPT billet initialized",
zap.Uint32("height", s.syncPoint),
zap.String("state root", header.PrevStateRoot.StringBE()))
}
return hdrsErr
}
// AddBlock verifies and saves block skipping executable scripts.
func (s *Module) AddBlock(block *block.Block) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.syncStage&headersSynced == 0 || s.syncStage&blocksSynced != 0 {
return nil
}
if s.blockHeight == s.syncPoint {
return nil
}
expectedHeight := s.blockHeight + 1
if expectedHeight != block.Index {
return fmt.Errorf("expected %d, got %d: invalid block index", expectedHeight, block.Index)
}
if s.bc.GetConfig().StateRootInHeader != block.StateRootEnabled {
return fmt.Errorf("stateroot setting mismatch: %v != %v", s.bc.GetConfig().StateRootInHeader, block.StateRootEnabled)
}
if s.bc.GetConfig().VerifyBlocks {
merkle := block.ComputeMerkleRoot()
if !block.MerkleRoot.Equals(merkle) {
return errors.New("invalid block: MerkleRoot mismatch")
}
}
cache := s.dao.GetWrapped()
writeBuf := io.NewBufBinWriter()
if err := cache.StoreAsBlock(block, writeBuf); err != nil {
return err
}
writeBuf.Reset()
err := cache.PutStateSyncCurrentBlockHeight(block.Index)
if err != nil {
return fmt.Errorf("failed to store current block height: %w", err)
}
for _, tx := range block.Transactions {
if err := cache.StoreAsTransaction(tx, block.Index, writeBuf); err != nil {
return err
}
writeBuf.Reset()
}
_, err = cache.Persist()
if err != nil {
return fmt.Errorf("failed to persist results: %w", err)
}
s.blockHeight = block.Index
if s.blockHeight == s.syncPoint {
s.syncStage |= blocksSynced
s.log.Info("blocks are in sync",
zap.Uint32("blockHeight", s.blockHeight))
s.checkSyncIsCompleted()
}
return nil
}
// AddMPTNodes tries to add provided set of MPT nodes to the MPT billet if they are
// not yet collected.
func (s *Module) AddMPTNodes(nodes [][]byte) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.syncStage&headersSynced == 0 || s.syncStage&mptSynced != 0 {
return errors.New("MPT nodes were not requested")
}
for _, nBytes := range nodes {
var n mpt.NodeObject
r := io.NewBinReaderFromBuf(nBytes)
n.DecodeBinary(r)
if r.Err != nil {
return fmt.Errorf("failed to decode MPT node: %w", r.Err)
}
nPaths, ok := s.mptpool.TryGet(n.Hash())
if !ok {
// it can easily happen after receiving the same data from different peers.
return nil
}
var childrenPaths = make(map[util.Uint256][][]byte)
for _, path := range nPaths {
err := s.billet.RestoreHashNode(path, n.Node)
if err != nil {
return fmt.Errorf("failed to add MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
}
for h, paths := range mpt.GetChildrenPaths(path, n.Node) {
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
}
}
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
}
if s.mptpool.Count() == 0 {
s.syncStage |= mptSynced
s.log.Info("MPT is in sync",
zap.Uint32("height", s.syncPoint))
s.checkSyncIsCompleted()
}
return nil
}
// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1
// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored.
// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller
// should take care of it.
func (s *Module) checkSyncIsCompleted() {
if s.syncStage != headersSynced|mptSynced|blocksSynced {
return
}
s.log.Info("state is in sync",
zap.Uint32("state sync point", s.syncPoint))
err := s.bc.JumpToState(s)
if err != nil {
s.log.Fatal("failed to jump to the latest state sync point", zap.Error(err))
}
s.syncStage = inactive
s.dispose()
}
func (s *Module) dispose() {
s.billet = nil
}
// BlockHeight returns index of the last stored block.
func (s *Module) BlockHeight() uint32 {
s.lock.RLock()
defer s.lock.RUnlock()
return s.blockHeight
}
// IsActive tells whether state sync module is on and still gathering state
// synchronisation data (headers, blocks or MPT nodes).
func (s *Module) IsActive() bool {
s.lock.RLock()
defer s.lock.RUnlock()
return !(s.syncStage == inactive || (s.syncStage == headersSynced|mptSynced|blocksSynced))
}
// IsInitialized tells whether state sync module does not require initialization.
// If `false` is returned then Init can be safely called.
func (s *Module) IsInitialized() bool {
s.lock.RLock()
defer s.lock.RUnlock()
return s.syncStage != none
}
// NeedHeaders tells whether the module hasn't completed headers synchronisation.
func (s *Module) NeedHeaders() bool {
s.lock.RLock()
defer s.lock.RUnlock()
return s.syncStage == initialized
}
// NeedMPTNodes returns whether the module hasn't completed MPT synchronisation.
func (s *Module) NeedMPTNodes() bool {
s.lock.RLock()
defer s.lock.RUnlock()
return s.syncStage&headersSynced != 0 && s.syncStage&mptSynced == 0
}
// Traverse traverses local MPT nodes starting from the specified root down to its
// children calling `process` for each serialised node until stop condition is satisfied.
func (s *Module) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
s.lock.RLock()
defer s.lock.RUnlock()
b := mpt.NewBillet(root, s.bc.GetConfig().KeepOnlyLatestState, storage.NewMemCachedStore(s.dao.Store))
return b.Traverse(process, false)
}
// GetJumpHeight returns state sync point to jump to. It is not protected by mutex and should be called
// under the module lock.
func (s *Module) GetJumpHeight() (uint32, error) {
if s.syncStage != headersSynced|mptSynced|blocksSynced {
return 0, errors.New("state sync module has wong state to perform state jump")
}
return s.syncPoint, nil
}

View file

@ -0,0 +1,119 @@
package statesync
import (
"bytes"
"sort"
"sync"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// Pool stores unknown MPT nodes along with the corresponding paths (single node is
// allowed to have multiple MPT paths).
type Pool struct {
lock sync.RWMutex
hashes map[util.Uint256][][]byte
}
// NewPool returns new MPT node hashes pool.
func NewPool() *Pool {
return &Pool{
hashes: make(map[util.Uint256][][]byte),
}
}
// ContainsKey checks if MPT node with the specified hash is in the Pool.
func (mp *Pool) ContainsKey(hash util.Uint256) bool {
mp.lock.RLock()
defer mp.lock.RUnlock()
_, ok := mp.hashes[hash]
return ok
}
// TryGet returns a set of MPT paths for the specified HashNode.
func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) {
mp.lock.RLock()
defer mp.lock.RUnlock()
paths, ok := mp.hashes[hash]
return paths, ok
}
// GetAll returns all MPT nodes with the corresponding paths from the pool.
func (mp *Pool) GetAll() map[util.Uint256][][]byte {
mp.lock.RLock()
defer mp.lock.RUnlock()
return mp.hashes
}
// Remove removes MPT node from the pool by the specified hash.
func (mp *Pool) Remove(hash util.Uint256) {
mp.lock.Lock()
defer mp.lock.Unlock()
delete(mp.hashes, hash)
}
// Add adds path to the set of paths for the specified node.
func (mp *Pool) Add(hash util.Uint256, path []byte) {
mp.lock.Lock()
defer mp.lock.Unlock()
mp.addPaths(hash, [][]byte{path})
}
// Update is an atomic operation and removes/adds specified nodes from/to the pool.
func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) {
mp.lock.Lock()
defer mp.lock.Unlock()
for h, paths := range remove {
old := mp.hashes[h]
for _, path := range paths {
i := sort.Search(len(old), func(i int) bool {
return bytes.Compare(old[i], path) >= 0
})
if i < len(old) && bytes.Equal(old[i], path) {
old = append(old[:i], old[i+1:]...)
}
}
if len(old) == 0 {
delete(mp.hashes, h)
} else {
mp.hashes[h] = old
}
}
for h, paths := range add {
mp.addPaths(h, paths)
}
}
// addPaths adds set of the specified node paths to the pool.
func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) {
old := mp.hashes[nodeHash]
for _, path := range paths {
i := sort.Search(len(old), func(i int) bool {
return bytes.Compare(old[i], path) >= 0
})
if i < len(old) && bytes.Equal(old[i], path) {
// then path is already added
continue
}
old = append(old, path)
if i != len(old)-1 {
copy(old[i+1:], old[i:])
old[i] = path
}
}
mp.hashes[nodeHash] = old
}
// Count returns the number of nodes in the pool.
func (mp *Pool) Count() int {
mp.lock.RLock()
defer mp.lock.RUnlock()
return len(mp.hashes)
}

View file

@ -4,6 +4,7 @@ import (
"github.com/Workiva/go-datastructures/queue" "github.com/Workiva/go-datastructures/queue"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"go.uber.org/atomic"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -13,6 +14,7 @@ type blockQueue struct {
checkBlocks chan struct{} checkBlocks chan struct{}
chain blockchainer.Blockqueuer chain blockchainer.Blockqueuer
relayF func(*block.Block) relayF func(*block.Block)
discarded *atomic.Bool
} }
const ( const (
@ -32,6 +34,7 @@ func newBlockQueue(capacity int, bc blockchainer.Blockqueuer, log *zap.Logger, r
checkBlocks: make(chan struct{}, 1), checkBlocks: make(chan struct{}, 1),
chain: bc, chain: bc,
relayF: relayer, relayF: relayer,
discarded: atomic.NewBool(false),
} }
} }
@ -91,8 +94,10 @@ func (bq *blockQueue) putBlock(block *block.Block) error {
} }
func (bq *blockQueue) discard() { func (bq *blockQueue) discard() {
if bq.discarded.CAS(false, true) {
close(bq.checkBlocks) close(bq.checkBlocks)
bq.queue.Dispose() bq.queue.Dispose()
}
} }
func (bq *blockQueue) length() int { func (bq *blockQueue) length() int {

View file

@ -71,6 +71,8 @@ const (
CMDBlock = CommandType(payload.BlockType) CMDBlock = CommandType(payload.BlockType)
CMDExtensible = CommandType(payload.ExtensibleType) CMDExtensible = CommandType(payload.ExtensibleType)
CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType) CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType)
CMDGetMPTData CommandType = 0x51 // 0x5.. commands are used for extensions (P2PNotary, state exchange cmds)
CMDMPTData CommandType = 0x52
CMDReject CommandType = 0x2f CMDReject CommandType = 0x2f
// SPV protocol. // SPV protocol.
@ -136,6 +138,10 @@ func (m *Message) decodePayload() error {
p = &payload.Version{} p = &payload.Version{}
case CMDInv, CMDGetData: case CMDInv, CMDGetData:
p = &payload.Inventory{} p = &payload.Inventory{}
case CMDGetMPTData:
p = &payload.MPTInventory{}
case CMDMPTData:
p = &payload.MPTData{}
case CMDAddr: case CMDAddr:
p = &payload.AddressList{} p = &payload.AddressList{}
case CMDBlock: case CMDBlock:
@ -221,7 +227,7 @@ func (m *Message) tryCompressPayload() error {
if m.Flags&Compressed == 0 { if m.Flags&Compressed == 0 {
switch m.Payload.(type) { switch m.Payload.(type) {
case *payload.Headers, *payload.MerkleBlock, payload.NullPayload, case *payload.Headers, *payload.MerkleBlock, payload.NullPayload,
*payload.Inventory: *payload.Inventory, *payload.MPTInventory:
break break
default: default:
size := len(compressedPayload) size := len(compressedPayload)

View file

@ -26,6 +26,8 @@ func _() {
_ = x[CMDBlock-44] _ = x[CMDBlock-44]
_ = x[CMDExtensible-46] _ = x[CMDExtensible-46]
_ = x[CMDP2PNotaryRequest-80] _ = x[CMDP2PNotaryRequest-80]
_ = x[CMDGetMPTData-81]
_ = x[CMDMPTData-82]
_ = x[CMDReject-47] _ = x[CMDReject-47]
_ = x[CMDFilterLoad-48] _ = x[CMDFilterLoad-48]
_ = x[CMDFilterAdd-49] _ = x[CMDFilterAdd-49]
@ -44,7 +46,7 @@ const (
_CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear" _CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
_CommandType_name_7 = "CMDMerkleBlock" _CommandType_name_7 = "CMDMerkleBlock"
_CommandType_name_8 = "CMDAlert" _CommandType_name_8 = "CMDAlert"
_CommandType_name_9 = "CMDP2PNotaryRequest" _CommandType_name_9 = "CMDP2PNotaryRequestCMDGetMPTDataCMDMPTData"
) )
var ( var (
@ -55,6 +57,7 @@ var (
_CommandType_index_4 = [...]uint8{0, 12, 22} _CommandType_index_4 = [...]uint8{0, 12, 22}
_CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58} _CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58}
_CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61} _CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61}
_CommandType_index_9 = [...]uint8{0, 19, 32, 42}
) )
func (i CommandType) String() string { func (i CommandType) String() string {
@ -83,8 +86,9 @@ func (i CommandType) String() string {
return _CommandType_name_7 return _CommandType_name_7
case i == 64: case i == 64:
return _CommandType_name_8 return _CommandType_name_8
case i == 80: case 80 <= i && i <= 82:
return _CommandType_name_9 i -= 80
return _CommandType_name_9[_CommandType_index_9[i]:_CommandType_index_9[i+1]]
default: default:
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")" return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
} }

View file

@ -258,6 +258,21 @@ func TestEncodeDecodeNotFound(t *testing.T) {
}) })
} }
func TestEncodeDecodeGetMPTData(t *testing.T) {
testEncodeDecode(t, CMDGetMPTData, &payload.MPTInventory{
Hashes: []util.Uint256{
{1, 2, 3},
{4, 5, 6},
},
})
}
func TestEncodeDecodeMPTData(t *testing.T) {
testEncodeDecode(t, CMDMPTData, &payload.MPTData{
Nodes: [][]byte{{1, 2, 3}, {4, 5, 6}},
})
}
func TestInvalidMessages(t *testing.T) { func TestInvalidMessages(t *testing.T) {
t.Run("CMDBlock, empty payload", func(t *testing.T) { t.Run("CMDBlock, empty payload", func(t *testing.T) {
testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{}) testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{})

View file

@ -0,0 +1,35 @@
package payload
import (
"errors"
"github.com/nspcc-dev/neo-go/pkg/io"
)
// MPTData represents the set of serialized MPT nodes.
type MPTData struct {
Nodes [][]byte
}
// EncodeBinary implements io.Serializable.
func (d *MPTData) EncodeBinary(w *io.BinWriter) {
w.WriteVarUint(uint64(len(d.Nodes)))
for _, n := range d.Nodes {
w.WriteVarBytes(n)
}
}
// DecodeBinary implements io.Serializable.
func (d *MPTData) DecodeBinary(r *io.BinReader) {
sz := r.ReadVarUint()
if sz == 0 {
r.Err = errors.New("empty MPT nodes list")
return
}
for i := uint64(0); i < sz; i++ {
d.Nodes = append(d.Nodes, r.ReadVarBytes())
if r.Err != nil {
return
}
}
}

View file

@ -0,0 +1,24 @@
package payload
import (
"testing"
"github.com/nspcc-dev/neo-go/internal/testserdes"
"github.com/stretchr/testify/require"
)
func TestMPTData_EncodeDecodeBinary(t *testing.T) {
t.Run("empty", func(t *testing.T) {
d := new(MPTData)
bytes, err := testserdes.EncodeBinary(d)
require.NoError(t, err)
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTData)))
})
t.Run("good", func(t *testing.T) {
d := &MPTData{
Nodes: [][]byte{{}, {1}, {1, 2, 3}},
}
testserdes.EncodeDecodeBinary(t, d, new(MPTData))
})
}

View file

@ -0,0 +1,32 @@
package payload
import (
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// MaxMPTHashesCount is the maximum number of requested MPT nodes hashes.
const MaxMPTHashesCount = 32
// MPTInventory payload.
type MPTInventory struct {
// A list of requested MPT nodes hashes.
Hashes []util.Uint256
}
// NewMPTInventory return a pointer to an MPTInventory.
func NewMPTInventory(hashes []util.Uint256) *MPTInventory {
return &MPTInventory{
Hashes: hashes,
}
}
// DecodeBinary implements Serializable interface.
func (p *MPTInventory) DecodeBinary(br *io.BinReader) {
br.ReadArray(&p.Hashes, MaxMPTHashesCount)
}
// EncodeBinary implements Serializable interface.
func (p *MPTInventory) EncodeBinary(bw *io.BinWriter) {
bw.WriteArray(p.Hashes)
}

View file

@ -0,0 +1,38 @@
package payload
import (
"testing"
"github.com/nspcc-dev/neo-go/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func TestMPTInventory_EncodeDecodeBinary(t *testing.T) {
t.Run("empty", func(t *testing.T) {
testserdes.EncodeDecodeBinary(t, NewMPTInventory([]util.Uint256{}), new(MPTInventory))
})
t.Run("good", func(t *testing.T) {
inv := NewMPTInventory([]util.Uint256{{1, 2, 3}, {2, 3, 4}})
testserdes.EncodeDecodeBinary(t, inv, new(MPTInventory))
})
t.Run("too large", func(t *testing.T) {
check := func(t *testing.T, count int, fail bool) {
h := make([]util.Uint256, count)
for i := range h {
h[i] = util.Uint256{1, 2, 3}
}
if fail {
bytes, err := testserdes.EncodeBinary(NewMPTInventory(h))
require.NoError(t, err)
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTInventory)))
} else {
testserdes.EncodeDecodeBinary(t, NewMPTInventory(h), new(MPTInventory))
}
}
check(t, MaxMPTHashesCount, false)
check(t, MaxMPTHashesCount+1, true)
})
}

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
mrand "math/rand" mrand "math/rand"
"net" "net"
"sort"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -17,7 +18,9 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/mempool"
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/extpool" "github.com/nspcc-dev/neo-go/pkg/network/extpool"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
@ -67,6 +70,7 @@ type (
discovery Discoverer discovery Discoverer
chain blockchainer.Blockchainer chain blockchainer.Blockchainer
bQueue *blockQueue bQueue *blockQueue
bSyncQueue *blockQueue
consensus consensus.Service consensus consensus.Service
mempool *mempool.Pool mempool *mempool.Pool
notaryRequestPool *mempool.Pool notaryRequestPool *mempool.Pool
@ -93,6 +97,7 @@ type (
oracle *oracle.Oracle oracle *oracle.Oracle
stateRoot stateroot.Service stateRoot stateroot.Service
stateSync blockchainer.StateSync
log *zap.Logger log *zap.Logger
} }
@ -191,6 +196,10 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
} }
s.stateRoot = sr s.stateRoot = sr
sSync := chain.GetStateSyncModule()
s.stateSync = sSync
s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil)
if config.OracleCfg.Enabled { if config.OracleCfg.Enabled {
orcCfg := oracle.Config{ orcCfg := oracle.Config{
Log: log, Log: log,
@ -277,6 +286,7 @@ func (s *Server) Start(errChan chan error) {
go s.broadcastTxLoop() go s.broadcastTxLoop()
go s.relayBlocksLoop() go s.relayBlocksLoop()
go s.bQueue.run() go s.bQueue.run()
go s.bSyncQueue.run()
go s.transport.Accept() go s.transport.Accept()
setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10)) setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10))
s.run() s.run()
@ -292,6 +302,7 @@ func (s *Server) Shutdown() {
p.Disconnect(errServerShutdown) p.Disconnect(errServerShutdown)
} }
s.bQueue.discard() s.bQueue.discard()
s.bSyncQueue.discard()
if s.StateRootCfg.Enabled { if s.StateRootCfg.Enabled {
s.stateRoot.Shutdown() s.stateRoot.Shutdown()
} }
@ -573,6 +584,10 @@ func (s *Server) IsInSync() bool {
var peersNumber int var peersNumber int
var notHigher int var notHigher int
if s.stateSync.IsActive() {
return false
}
if s.MinPeers == 0 { if s.MinPeers == 0 {
return true return true
} }
@ -630,6 +645,9 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
// handleBlockCmd processes the received block received from its peer. // handleBlockCmd processes the received block received from its peer.
func (s *Server) handleBlockCmd(p Peer, block *block.Block) error { func (s *Server) handleBlockCmd(p Peer, block *block.Block) error {
if s.stateSync.IsActive() {
return s.bSyncQueue.putBlock(block)
}
return s.bQueue.putBlock(block) return s.bQueue.putBlock(block)
} }
@ -639,25 +657,46 @@ func (s *Server) handlePing(p Peer, ping *payload.Ping) error {
if err != nil { if err != nil {
return err return err
} }
if s.chain.BlockHeight() < ping.LastBlockIndex { err = s.requestBlocksOrHeaders(p)
err = s.requestBlocks(p)
if err != nil { if err != nil {
return err return err
} }
}
return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id))) return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id)))
} }
func (s *Server) requestBlocksOrHeaders(p Peer) error {
if s.stateSync.NeedHeaders() {
if s.chain.HeaderHeight() < p.LastBlockIndex() {
return s.requestHeaders(p)
}
return nil
}
var bq blockchainer.Blockqueuer = s.chain
if s.stateSync.IsActive() {
bq = s.stateSync
}
if bq.BlockHeight() < p.LastBlockIndex() {
return s.requestBlocks(bq, p)
}
return nil
}
// requestHeaders sends a CMDGetHeaders message to the peer to sync up in headers.
func (s *Server) requestHeaders(p Peer) error {
// TODO: optimize
currHeight := s.chain.HeaderHeight()
needHeight := currHeight + 1
payload := payload.NewGetBlockByIndex(needHeight, -1)
return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload))
}
// handlePing processes pong request. // handlePing processes pong request.
func (s *Server) handlePong(p Peer, pong *payload.Ping) error { func (s *Server) handlePong(p Peer, pong *payload.Ping) error {
err := p.HandlePong(pong) err := p.HandlePong(pong)
if err != nil { if err != nil {
return err return err
} }
if s.chain.BlockHeight() < pong.LastBlockIndex { return s.requestBlocksOrHeaders(p)
return s.requestBlocks(p)
}
return nil
} }
// handleInvCmd processes the received inventory. // handleInvCmd processes the received inventory.
@ -766,6 +805,50 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
return nil return nil
} }
// handleGetMPTDataCmd processes the received MPT inventory.
func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error {
if !s.chain.GetConfig().P2PStateExchangeExtensions {
return errors.New("GetMPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
}
if s.chain.GetConfig().KeepOnlyLatestState {
// TODO: implement keeping MPT states for P1 and P2 height (#2095, #2152 related)
return errors.New("GetMPTDataCMD was received, but only latest MPT state is supported")
}
resp := payload.MPTData{}
capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes)))
for _, h := range inv.Hashes {
if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type
break
}
err := s.stateSync.Traverse(h,
func(_ mpt.Node, node []byte) bool {
l := len(node)
size := l + io.GetVarSize(l)
if size > capLeft {
return true
}
resp.Nodes = append(resp.Nodes, node)
capLeft -= size
return false
})
if err != nil {
return fmt.Errorf("failed to traverse MPT starting from %s: %w", h.StringBE(), err)
}
}
if len(resp.Nodes) > 0 {
msg := NewMessage(CMDMPTData, &resp)
return p.EnqueueP2PMessage(msg)
}
return nil
}
func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error {
if !s.chain.GetConfig().P2PStateExchangeExtensions {
return errors.New("MPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
}
return s.stateSync.AddMPTNodes(data.Nodes)
}
// handleGetBlocksCmd processes the getblocks request. // handleGetBlocksCmd processes the getblocks request.
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
count := gb.Count count := gb.Count
@ -845,6 +928,11 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error
return p.EnqueueP2PMessage(msg) return p.EnqueueP2PMessage(msg)
} }
// handleHeadersCmd processes headers payload.
func (s *Server) handleHeadersCmd(p Peer, h *payload.Headers) error {
return s.stateSync.AddHeaders(h.Hdrs...)
}
// handleExtensibleCmd processes received extensible payload. // handleExtensibleCmd processes received extensible payload.
func (s *Server) handleExtensibleCmd(e *payload.Extensible) error { func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
if !s.syncReached.Load() { if !s.syncReached.Load() {
@ -993,8 +1081,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error {
// 1. Block range is divided into chunks of payload.MaxHashesCount. // 1. Block range is divided into chunks of payload.MaxHashesCount.
// 2. Send requests for chunk in increasing order. // 2. Send requests for chunk in increasing order.
// 3. After all requests were sent, request random height. // 3. After all requests were sent, request random height.
func (s *Server) requestBlocks(p Peer) error { func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error {
var currHeight = s.chain.BlockHeight() var currHeight = bq.BlockHeight()
var peerHeight = p.LastBlockIndex() var peerHeight = p.LastBlockIndex()
var needHeight uint32 var needHeight uint32
// lastRequestedHeight can only be increased. // lastRequestedHeight can only be increased.
@ -1051,9 +1139,18 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
case CMDGetData: case CMDGetData:
inv := msg.Payload.(*payload.Inventory) inv := msg.Payload.(*payload.Inventory)
return s.handleGetDataCmd(peer, inv) return s.handleGetDataCmd(peer, inv)
case CMDGetMPTData:
inv := msg.Payload.(*payload.MPTInventory)
return s.handleGetMPTDataCmd(peer, inv)
case CMDMPTData:
inv := msg.Payload.(*payload.MPTData)
return s.handleMPTDataCmd(peer, inv)
case CMDGetHeaders: case CMDGetHeaders:
gh := msg.Payload.(*payload.GetBlockByIndex) gh := msg.Payload.(*payload.GetBlockByIndex)
return s.handleGetHeadersCmd(peer, gh) return s.handleGetHeadersCmd(peer, gh)
case CMDHeaders:
h := msg.Payload.(*payload.Headers)
return s.handleHeadersCmd(peer, h)
case CMDInv: case CMDInv:
inventory := msg.Payload.(*payload.Inventory) inventory := msg.Payload.(*payload.Inventory)
return s.handleInvCmd(peer, inventory) return s.handleInvCmd(peer, inventory)
@ -1093,6 +1190,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
} }
go peer.StartProtocol() go peer.StartProtocol()
s.tryInitStateSync()
s.tryStartServices() s.tryStartServices()
default: default:
return fmt.Errorf("received '%s' during handshake", msg.Command.String()) return fmt.Errorf("received '%s' during handshake", msg.Command.String())
@ -1101,6 +1199,52 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return nil return nil
} }
func (s *Server) tryInitStateSync() {
if !s.stateSync.IsActive() {
s.bSyncQueue.discard()
return
}
if s.stateSync.IsInitialized() {
return
}
var peersNumber int
s.lock.RLock()
heights := make([]uint32, 0)
for p := range s.peers {
if p.Handshaked() {
peersNumber++
peerLastBlock := p.LastBlockIndex()
i := sort.Search(len(heights), func(i int) bool {
return heights[i] >= peerLastBlock
})
heights = append(heights, peerLastBlock)
if i != len(heights)-1 {
copy(heights[i+1:], heights[i:])
heights[i] = peerLastBlock
}
}
}
s.lock.RUnlock()
if peersNumber >= s.MinPeers && len(heights) > 0 {
// choose the height of the median peer as current chain's height
h := heights[len(heights)/2]
err := s.stateSync.Init(h)
if err != nil {
s.log.Fatal("failed to init state sync module",
zap.Uint32("evaluated chain's blockHeight", h),
zap.Uint32("blockHeight", s.chain.BlockHeight()),
zap.Uint32("headerHeight", s.chain.HeaderHeight()),
zap.Error(err))
}
// module can be inactive after init (i.e. full state is collected and ordinary block processing is needed)
if !s.stateSync.IsActive() {
s.bSyncQueue.discard()
}
}
}
func (s *Server) handleNewPayload(p *payload.Extensible) { func (s *Server) handleNewPayload(p *payload.Extensible) {
_, err := s.extensiblePool.Add(p) _, err := s.extensiblePool.Add(p)
if err != nil { if err != nil {

View file

@ -2,6 +2,7 @@ package network
import ( import (
"errors" "errors"
"fmt"
"math/big" "math/big"
"net" "net"
"strconv" "strconv"
@ -46,7 +47,10 @@ func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs =
func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") } func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") }
func TestNewServer(t *testing.T) { func TestNewServer(t *testing.T) {
bc := &fakechain.FakeChain{} bc := &fakechain.FakeChain{ProtocolConfiguration: config.ProtocolConfiguration{
P2PStateExchangeExtensions: true,
StateRootInHeader: true,
}}
s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery) s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery)
require.Error(t, err) require.Error(t, err)
@ -899,3 +903,39 @@ func TestVerifyNotaryRequest(t *testing.T) {
require.NoError(t, verifyNotaryRequest(bc, nil, r)) require.NoError(t, verifyNotaryRequest(bc, nil, r))
}) })
} }
func TestTryInitStateSync(t *testing.T) {
t.Run("module inactive", func(t *testing.T) {
s := startTestServer(t)
s.tryInitStateSync()
})
t.Run("module already initialized", func(t *testing.T) {
s := startTestServer(t)
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: true}
s.tryInitStateSync()
})
t.Run("good", func(t *testing.T) {
s := startTestServer(t)
for _, h := range []uint32{10, 8, 7, 4, 11, 4} {
p := newLocalPeer(t, s)
p.handshaked = true
p.lastBlockIndex = h
s.peers[p] = true
}
p := newLocalPeer(t, s)
p.handshaked = false // one disconnected peer to check it won't be taken into attention
p.lastBlockIndex = 5
s.peers[p] = true
var expectedH uint32 = 8 // median peer
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: false, InitFunc: func(h uint32) error {
if h != expectedH {
return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)
}
return nil
}}
s.tryInitStateSync()
})
}

View file

@ -267,13 +267,11 @@ func (p *TCPPeer) StartProtocol() {
zap.Uint32("id", p.Version().Nonce)) zap.Uint32("id", p.Version().Nonce))
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities) p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities)
if p.server.chain.BlockHeight() < p.LastBlockIndex() { err = p.server.requestBlocksOrHeaders(p)
err = p.server.requestBlocks(p)
if err != nil { if err != nil {
p.Disconnect(err) p.Disconnect(err)
return return
} }
}
timer := time.NewTimer(p.server.ProtoTickInterval) timer := time.NewTimer(p.server.ProtoTickInterval)
for { for {
@ -281,10 +279,8 @@ func (p *TCPPeer) StartProtocol() {
case <-p.done: case <-p.done:
return return
case <-timer.C: case <-timer.C:
// Try to sync in headers and block with the peer if his block height is higher then ours. // Try to sync in headers and block with the peer if his block height is higher than ours.
if p.LastBlockIndex() > p.server.chain.BlockHeight() { err = p.server.requestBlocksOrHeaders(p)
err = p.server.requestBlocks(p)
}
if err == nil { if err == nil {
timer.Reset(p.server.ProtoTickInterval) timer.Reset(p.server.ProtoTickInterval)
} }