diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index ef8145858..ca1f0bc6d 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services" "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/mempool" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/native" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -42,6 +43,13 @@ type FakeChain struct { UtilityTokenBalance *big.Int } +// FakeStateSync implements StateSync interface. +type FakeStateSync struct { + IsActiveFlag bool + IsInitializedFlag bool + InitFunc func(h uint32) error +} + // NewFakeChain returns new FakeChain structure. func NewFakeChain() *FakeChain { return &FakeChain{ @@ -294,6 +302,16 @@ func (chain *FakeChain) GetStateModule() blockchainer.StateRoot { return nil } +// GetStateSyncModule implements Blockchainer interface. +func (chain *FakeChain) GetStateSyncModule() blockchainer.StateSync { + return &FakeStateSync{} +} + +// JumpToState implements Blockchainer interface. +func (chain *FakeChain) JumpToState(module blockchainer.StateSync) error { + panic("TODO") +} + // GetStorageItem implements Blockchainer interface. func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem { panic("TODO") @@ -436,3 +454,57 @@ func (chain *FakeChain) UnsubscribeFromNotifications(ch chan<- *state.Notificati func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) { panic("TODO") } + +// AddBlock implements StateSync interface. +func (s *FakeStateSync) AddBlock(block *block.Block) error { + panic("TODO") +} + +// AddHeaders implements StateSync interface. +func (s *FakeStateSync) AddHeaders(...*block.Header) error { + panic("TODO") +} + +// AddMPTNodes implements StateSync interface. +func (s *FakeStateSync) AddMPTNodes([][]byte) error { + panic("TODO") +} + +// BlockHeight implements StateSync interface. +func (s *FakeStateSync) BlockHeight() uint32 { + panic("TODO") +} + +// IsActive implements StateSync interface. +func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag } + +// IsInitialized implements StateSync interface. +func (s *FakeStateSync) IsInitialized() bool { + return s.IsInitializedFlag +} + +// Init implements StateSync interface. +func (s *FakeStateSync) Init(currChainHeight uint32) error { + if s.InitFunc != nil { + return s.InitFunc(currChainHeight) + } + panic("TODO") +} + +// NeedHeaders implements StateSync interface. +func (s *FakeStateSync) NeedHeaders() bool { return false } + +// NeedMPTNodes implements StateSync interface. +func (s *FakeStateSync) NeedMPTNodes() bool { + panic("TODO") +} + +// Traverse implements StateSync interface. +func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { + panic("TODO") +} + +// GetJumpHeight implements StateSync interface. +func (s *FakeStateSync) GetJumpHeight() (uint32, error) { + panic("TODO") +} diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 1a20308f2..18fdf786f 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -23,6 +23,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/native/noderoles" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/stateroot" + "github.com/nspcc-dev/neo-go/pkg/core/statesync" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -409,6 +410,64 @@ func (bc *Blockchain) init() error { return bc.updateExtensibleWhitelist(bHeight) } +// JumpToState is an atomic operation that changes Blockchain state to the one +// specified by the state sync point p. All the data needed for the jump must be +// collected by the state sync module. +func (bc *Blockchain) JumpToState(module blockchainer.StateSync) error { + bc.lock.Lock() + defer bc.lock.Unlock() + + p, err := module.GetJumpHeight() + if err != nil { + return fmt.Errorf("failed to get jump height: %w", err) + } + if p+1 >= uint32(len(bc.headerHashes)) { + return fmt.Errorf("invalid state sync point") + } + + bc.log.Info("jumping to state sync point", zap.Uint32("state sync point", p)) + + block, err := bc.dao.GetBlock(bc.headerHashes[p]) + if err != nil { + return fmt.Errorf("failed to get current block: %w", err) + } + err = bc.dao.StoreAsCurrentBlock(block, nil) + if err != nil { + return fmt.Errorf("failed to store current block: %w", err) + } + bc.topBlock.Store(block) + atomic.StoreUint32(&bc.blockHeight, p) + atomic.StoreUint32(&bc.persistedHeight, p) + + block, err = bc.dao.GetBlock(bc.headerHashes[p+1]) + if err != nil { + return fmt.Errorf("failed to get block to init MPT: %w", err) + } + if err = bc.stateRoot.JumpToState(&state.MPTRoot{ + Index: p, + Root: block.PrevStateRoot, + }, bc.config.KeepOnlyLatestState); err != nil { + return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err) + } + + err = bc.contracts.NEO.InitializeCache(bc, bc.dao) + if err != nil { + return fmt.Errorf("can't init cache for NEO native contract: %w", err) + } + err = bc.contracts.Management.InitializeCache(bc.dao) + if err != nil { + return fmt.Errorf("can't init cache for Management native contract: %w", err) + } + bc.contracts.Designate.InitializeCache() + + if err := bc.updateExtensibleWhitelist(p); err != nil { + return fmt.Errorf("failed to update extensible whitelist: %w", err) + } + + updateBlockHeightMetric(p) + return nil +} + // Run runs chain loop, it needs to be run as goroutine and executing it is // critical for correct Blockchain operation. func (bc *Blockchain) Run() { @@ -696,6 +755,11 @@ func (bc *Blockchain) GetStateModule() blockchainer.StateRoot { return bc.stateRoot } +// GetStateSyncModule returns new state sync service instance. +func (bc *Blockchain) GetStateSyncModule() blockchainer.StateSync { + return statesync.NewModule(bc, bc.log, bc.dao) +} + // storeBlock performs chain update using the block given, it executes all // transactions with all appropriate side-effects and updates Blockchain state. // This is the only way to change Blockchain state. diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go index a0d5af6f4..9f9ae4e03 100644 --- a/pkg/core/blockchainer/blockchainer.go +++ b/pkg/core/blockchainer/blockchainer.go @@ -21,7 +21,6 @@ import ( type Blockchainer interface { ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction GetConfig() config.ProtocolConfiguration - AddHeaders(...*block.Header) error Blockqueuer // Blockqueuer interface CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error) Close() @@ -56,10 +55,12 @@ type Blockchainer interface { GetStandByCommittee() keys.PublicKeys GetStandByValidators() keys.PublicKeys GetStateModule() StateRoot + GetStateSyncModule() StateSync GetStorageItem(id int32, key []byte) state.StorageItem GetStorageItems(id int32) (map[string]state.StorageItem, error) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) + JumpToState(module StateSync) error SetOracle(service services.Oracle) mempool.Feer // fee interface ManagementContractHash() util.Uint160 diff --git a/pkg/core/blockchainer/blockqueuer.go b/pkg/core/blockchainer/blockqueuer.go index 384ccd489..bae45281f 100644 --- a/pkg/core/blockchainer/blockqueuer.go +++ b/pkg/core/blockchainer/blockqueuer.go @@ -5,5 +5,6 @@ import "github.com/nspcc-dev/neo-go/pkg/core/block" // Blockqueuer is an interface for blockqueue. type Blockqueuer interface { AddBlock(block *block.Block) error + AddHeaders(...*block.Header) error BlockHeight() uint32 } diff --git a/pkg/core/blockchainer/state_root.go b/pkg/core/blockchainer/state_root.go index 979e15963..64f7d8b7c 100644 --- a/pkg/core/blockchainer/state_root.go +++ b/pkg/core/blockchainer/state_root.go @@ -9,6 +9,7 @@ import ( // StateRoot represents local state root module. type StateRoot interface { AddStateRoot(root *state.MPTRoot) error + CurrentLocalHeight() uint32 CurrentLocalStateRoot() util.Uint256 CurrentValidatedHeight() uint32 GetStateProof(root util.Uint256, key []byte) ([][]byte, error) diff --git a/pkg/core/blockchainer/state_sync.go b/pkg/core/blockchainer/state_sync.go new file mode 100644 index 000000000..66643020c --- /dev/null +++ b/pkg/core/blockchainer/state_sync.go @@ -0,0 +1,19 @@ +package blockchainer + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// StateSync represents state sync module. +type StateSync interface { + AddMPTNodes([][]byte) error + Blockqueuer // Blockqueuer interface + Init(currChainHeight uint32) error + IsActive() bool + IsInitialized() bool + GetJumpHeight() (uint32, error) + NeedHeaders() bool + NeedMPTNodes() bool + Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error +} diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go index ceabb31f7..63c02c089 100644 --- a/pkg/core/mpt/helpers.go +++ b/pkg/core/mpt/helpers.go @@ -1,5 +1,7 @@ package mpt +import "github.com/nspcc-dev/neo-go/pkg/util" + // lcp returns longest common prefix of a and b. // Note: it does no allocations. func lcp(a, b []byte) []byte { @@ -49,3 +51,36 @@ func fromNibbles(path []byte) []byte { } return result } + +// GetChildrenPaths returns a set of paths to node's children who are non-empty HashNodes +// based on the node's path. +func GetChildrenPaths(path []byte, node Node) map[util.Uint256][][]byte { + res := make(map[util.Uint256][][]byte) + switch n := node.(type) { + case *LeafNode, *HashNode, EmptyNode: + return nil + case *BranchNode: + for i, child := range n.Children { + if child.Type() == HashT { + cPath := make([]byte, len(path), len(path)+1) + copy(cPath, path) + if i != lastChild { + cPath = append(cPath, byte(i)) + } + paths := res[child.Hash()] + paths = append(paths, cPath) + res[child.Hash()] = paths + } + } + case *ExtensionNode: + if n.next.Type() == HashT { + cPath := make([]byte, len(path)+len(n.key)) + copy(cPath, path) + copy(cPath[len(path):], n.key) + res[n.next.Hash()] = [][]byte{cPath} + } + default: + panic("unknown Node type") + } + return res +} diff --git a/pkg/core/mpt/helpers_test.go b/pkg/core/mpt/helpers_test.go index ee5634bfd..28181dcc2 100644 --- a/pkg/core/mpt/helpers_test.go +++ b/pkg/core/mpt/helpers_test.go @@ -3,6 +3,7 @@ package mpt import ( "testing" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" ) @@ -18,3 +19,49 @@ func TestToNibblesFromNibbles(t *testing.T) { check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF}) }) } + +func TestGetChildrenPaths(t *testing.T) { + h1 := NewHashNode(util.Uint256{1, 2, 3}) + h2 := NewHashNode(util.Uint256{4, 5, 6}) + h3 := NewHashNode(util.Uint256{7, 8, 9}) + l := NewLeafNode([]byte{1, 2, 3}) + ext1 := NewExtensionNode([]byte{8, 9}, h1) + ext2 := NewExtensionNode([]byte{7, 6}, l) + branch := NewBranchNode() + branch.Children[3] = h1 + branch.Children[5] = l + branch.Children[6] = h1 // 3-th and 6-th children have the same hash + branch.Children[7] = h3 + branch.Children[lastChild] = h2 + testCases := map[string]struct { + node Node + expected map[util.Uint256][][]byte + }{ + "Hash": {h1, nil}, + "Leaf": {l, nil}, + "Extension with next Hash": {ext1, map[util.Uint256][][]byte{h1.Hash(): {ext1.key}}}, + "Extension with next non-Hash": {ext2, map[util.Uint256][][]byte{}}, + "Branch": {branch, map[util.Uint256][][]byte{ + h1.Hash(): {{0x03}, {0x06}}, + h2.Hash(): {{}}, + h3.Hash(): {{0x07}}, + }}, + } + parentPath := []byte{4, 5, 6} + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + require.Equal(t, testCase.expected, GetChildrenPaths([]byte{}, testCase.node)) + if testCase.expected != nil { + expectedWithPrefix := make(map[util.Uint256][][]byte, len(testCase.expected)) + for h, paths := range testCase.expected { + var res [][]byte + for _, path := range paths { + res = append(res, append(parentPath, path...)) + } + expectedWithPrefix[h] = res + } + require.Equal(t, expectedWithPrefix, GetChildrenPaths(parentPath, testCase.node)) + } + }) + } +} diff --git a/pkg/core/mpt/proof_test.go b/pkg/core/mpt/proof_test.go index 75a76408f..d733fb6d0 100644 --- a/pkg/core/mpt/proof_test.go +++ b/pkg/core/mpt/proof_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" ) -func newProofTrie(t *testing.T) *Trie { +func newProofTrie(t *testing.T, missingHashNode bool) *Trie { l := NewLeafNode([]byte("somevalue")) e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l) l2 := NewLeafNode([]byte("invalid")) @@ -20,11 +20,14 @@ func newProofTrie(t *testing.T) *Trie { require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2"))) tr.putToStore(l) tr.putToStore(e) + if !missingHashNode { + tr.putToStore(l2) + } return tr } func TestTrie_GetProof(t *testing.T) { - tr := newProofTrie(t) + tr := newProofTrie(t, true) t.Run("MissingKey", func(t *testing.T) { _, err := tr.GetProof([]byte{0x12}) @@ -43,7 +46,7 @@ func TestTrie_GetProof(t *testing.T) { } func TestVerifyProof(t *testing.T) { - tr := newProofTrie(t) + tr := newProofTrie(t, true) t.Run("Simple", func(t *testing.T) { proof, err := tr.GetProof([]byte{0x12, 0x32}) diff --git a/pkg/core/native/designate.go b/pkg/core/native/designate.go index 8c75eebc8..844690a3d 100644 --- a/pkg/core/native/designate.go +++ b/pkg/core/native/designate.go @@ -353,3 +353,8 @@ func (s *Designate) getRole(item stackitem.Item) (noderoles.Role, bool) { u := bi.Uint64() return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u)) } + +// InitializeCache invalidates native Designate cache. +func (s *Designate) InitializeCache() { + s.rolesChangedFlag.Store(true) +} diff --git a/pkg/core/stateroot/module.go b/pkg/core/stateroot/module.go index 530dbf583..20a3396e9 100644 --- a/pkg/core/stateroot/module.go +++ b/pkg/core/stateroot/module.go @@ -114,6 +114,25 @@ func (s *Module) Init(height uint32, enableRefCount bool) error { return nil } +// JumpToState performs jump to the state specified by given stateroot index. +func (s *Module) JumpToState(sr *state.MPTRoot, enableRefCount bool) error { + if err := s.addLocalStateRoot(s.Store, sr); err != nil { + return fmt.Errorf("failed to store local state root: %w", err) + } + + data := make([]byte, 4) + binary.LittleEndian.PutUint32(data, sr.Index) + if err := s.Store.Put([]byte{byte(storage.DataMPT), prefixValidated}, data); err != nil { + return fmt.Errorf("failed to store validated height: %w", err) + } + s.validatedHeight.Store(sr.Index) + + s.currentLocal.Store(sr.Root) + s.localHeight.Store(sr.Index) + s.mpt = mpt.NewTrie(mpt.NewHashNode(sr.Root), enableRefCount, s.Store) + return nil +} + // AddMPTBatch updates using provided batch. func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) { mpt := *s.mpt diff --git a/pkg/core/statesync/module.go b/pkg/core/statesync/module.go new file mode 100644 index 000000000..24f38f14e --- /dev/null +++ b/pkg/core/statesync/module.go @@ -0,0 +1,440 @@ +/* +Package statesync implements module for the P2P state synchronisation process. The +module manages state synchronisation for non-archival nodes which are joining the +network and don't have the ability to resync from the genesis block. + +Given the currently available state synchronisation point P, sate sync process +includes the following stages: + +1. Fetching headers starting from height 0 up to P+1. +2. Fetching MPT nodes for height P stating from the corresponding state root. +3. Fetching blocks starting from height P-MaxTraceableBlocks (or 0) up to P. + +Steps 2 and 3 are being performed in parallel. Once all the data are collected +and stored in the db, an atomic state jump is occurred to the state sync point P. +Further node operation process is performed using standard sync mechanism until +the node reaches synchronised state. +*/ +package statesync + +import ( + "encoding/hex" + "errors" + "fmt" + "sync" + + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" + "github.com/nspcc-dev/neo-go/pkg/core/dao" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" + "go.uber.org/zap" +) + +// stateSyncStage is a type of state synchronisation stage. +type stateSyncStage uint8 + +const ( + // inactive means that state exchange is disabled by the protocol configuration. + // Can't be combined with other states. + inactive stateSyncStage = 1 << iota + // none means that state exchange is enabled in the configuration, but + // initialisation of the state sync module wasn't yet performed, i.e. + // (*Module).Init wasn't called. Can't be combined with other states. + none + // initialized means that (*Module).Init was called, but other sync stages + // are not yet reached (i.e. that headers are requested, but not yet fetched). + // Can't be combined with other states. + initialized + // headersSynced means that headers for the current state sync point are fetched. + // May be combined with mptSynced and/or blocksSynced. + headersSynced + // mptSynced means that MPT nodes for the current state sync point are fetched. + // Always combined with headersSynced; may be combined with blocksSynced. + mptSynced + // blocksSynced means that blocks up to the current state sync point are stored. + // Always combined with headersSynced; may be combined with mptSynced. + blocksSynced +) + +// Module represents state sync module and aimed to gather state-related data to +// perform an atomic state jump. +type Module struct { + lock sync.RWMutex + log *zap.Logger + + // syncPoint is the state synchronisation point P we're currently working against. + syncPoint uint32 + // syncStage is the stage of the sync process. + syncStage stateSyncStage + // syncInterval is the delta between two adjacent state sync points. + syncInterval uint32 + // blockHeight is the index of the latest stored block. + blockHeight uint32 + + dao *dao.Simple + bc blockchainer.Blockchainer + mptpool *Pool + + billet *mpt.Billet +} + +// NewModule returns new instance of statesync module. +func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple) *Module { + if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) { + return &Module{ + dao: s, + bc: bc, + syncStage: inactive, + } + } + return &Module{ + dao: s, + bc: bc, + log: log, + syncInterval: uint32(bc.GetConfig().StateSyncInterval), + mptpool: NewPool(), + syncStage: none, + } +} + +// Init initializes state sync module for the current chain's height with given +// callback for MPT nodes requests. +func (s *Module) Init(currChainHeight uint32) error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.syncStage != none { + return errors.New("already initialized or inactive") + } + + p := (currChainHeight / s.syncInterval) * s.syncInterval + if p < 2*s.syncInterval { + // chain is too low to start state exchange process, use the standard sync mechanism + s.syncStage = inactive + return nil + } + pOld, err := s.dao.GetStateSyncPoint() + if err == nil && pOld >= p-s.syncInterval { + // old point is still valid, so try to resync states for this point. + p = pOld + } else if s.bc.BlockHeight() > p-2*s.syncInterval { + // chain has already been synchronised up to old state sync point and regular blocks processing was started + s.syncStage = inactive + return nil + } + + s.syncPoint = p + err = s.dao.PutStateSyncPoint(p) + if err != nil { + return fmt.Errorf("failed to store state synchronisation point %d: %w", p, err) + } + s.syncStage = initialized + s.log.Info("try to sync state for the latest state synchronisation point", + zap.Uint32("point", p), + zap.Uint32("evaluated chain's blockHeight", currChainHeight)) + + // check headers sync state first + ltstHeaderHeight := s.bc.HeaderHeight() + if ltstHeaderHeight > p { + s.syncStage = headersSynced + s.log.Info("headers are in sync", + zap.Uint32("headerHeight", s.bc.HeaderHeight())) + } + + // check blocks sync state + s.blockHeight = s.getLatestSavedBlock(p) + if s.blockHeight >= p { + s.syncStage |= blocksSynced + s.log.Info("blocks are in sync", + zap.Uint32("blockHeight", s.blockHeight)) + } + + // check MPT sync state + if s.blockHeight > p { + s.syncStage |= mptSynced + s.log.Info("MPT is in sync", + zap.Uint32("stateroot height", s.bc.GetStateModule().CurrentLocalHeight())) + } else if s.syncStage&headersSynced != 0 { + header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(p + 1))) + if err != nil { + return fmt.Errorf("failed to get header to initialize MPT billet: %w", err) + } + s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store) + s.log.Info("MPT billet initialized", + zap.Uint32("height", s.syncPoint), + zap.String("state root", header.PrevStateRoot.StringBE())) + pool := NewPool() + pool.Add(header.PrevStateRoot, []byte{}) + err = s.billet.Traverse(func(n mpt.Node, _ []byte) bool { + nPaths, ok := pool.TryGet(n.Hash()) + if !ok { + // if this situation occurs, then it's a bug in MPT pool or Traverse. + panic("failed to get MPT node from the pool") + } + pool.Remove(n.Hash()) + childrenPaths := make(map[util.Uint256][][]byte) + for _, path := range nPaths { + nChildrenPaths := mpt.GetChildrenPaths(path, n) + for hash, paths := range nChildrenPaths { + childrenPaths[hash] = append(childrenPaths[hash], paths...) // it's OK to have duplicates, they'll be handled by mempool + } + } + pool.Update(nil, childrenPaths) + return false + }, true) + if err != nil { + return fmt.Errorf("failed to traverse MPT while initialization: %w", err) + } + s.mptpool.Update(nil, pool.GetAll()) + if s.mptpool.Count() == 0 { + s.syncStage |= mptSynced + s.log.Info("MPT is in sync", + zap.Uint32("stateroot height", p)) + } + } + + if s.syncStage == headersSynced|blocksSynced|mptSynced { + s.log.Info("state is in sync, starting regular blocks processing") + s.syncStage = inactive + } + return nil +} + +// getLatestSavedBlock returns either current block index (if it's still relevant +// to continue state sync process) or H-1 where H is the index of the earliest +// block that should be saved next. +func (s *Module) getLatestSavedBlock(p uint32) uint32 { + var result uint32 + mtb := s.bc.GetConfig().MaxTraceableBlocks + if p > mtb { + result = p - mtb + } + storedH, err := s.dao.GetStateSyncCurrentBlockHeight() + if err == nil && storedH > result { + result = storedH + } + actualH := s.bc.BlockHeight() + if actualH > result { + result = actualH + } + return result +} + +// AddHeaders validates and adds specified headers to the chain. +func (s *Module) AddHeaders(hdrs ...*block.Header) error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.syncStage != initialized { + return errors.New("headers were not requested") + } + + hdrsErr := s.bc.AddHeaders(hdrs...) + if s.bc.HeaderHeight() > s.syncPoint { + s.syncStage = headersSynced + s.log.Info("headers for state sync are fetched", + zap.Uint32("header height", s.bc.HeaderHeight())) + + header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint) + 1)) + if err != nil { + s.log.Fatal("failed to get header to initialize MPT billet", + zap.Uint32("height", s.syncPoint+1), + zap.Error(err)) + } + s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store) + s.mptpool.Add(header.PrevStateRoot, []byte{}) + s.log.Info("MPT billet initialized", + zap.Uint32("height", s.syncPoint), + zap.String("state root", header.PrevStateRoot.StringBE())) + } + return hdrsErr +} + +// AddBlock verifies and saves block skipping executable scripts. +func (s *Module) AddBlock(block *block.Block) error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.syncStage&headersSynced == 0 || s.syncStage&blocksSynced != 0 { + return nil + } + + if s.blockHeight == s.syncPoint { + return nil + } + expectedHeight := s.blockHeight + 1 + if expectedHeight != block.Index { + return fmt.Errorf("expected %d, got %d: invalid block index", expectedHeight, block.Index) + } + if s.bc.GetConfig().StateRootInHeader != block.StateRootEnabled { + return fmt.Errorf("stateroot setting mismatch: %v != %v", s.bc.GetConfig().StateRootInHeader, block.StateRootEnabled) + } + if s.bc.GetConfig().VerifyBlocks { + merkle := block.ComputeMerkleRoot() + if !block.MerkleRoot.Equals(merkle) { + return errors.New("invalid block: MerkleRoot mismatch") + } + } + cache := s.dao.GetWrapped() + writeBuf := io.NewBufBinWriter() + if err := cache.StoreAsBlock(block, writeBuf); err != nil { + return err + } + writeBuf.Reset() + + err := cache.PutStateSyncCurrentBlockHeight(block.Index) + if err != nil { + return fmt.Errorf("failed to store current block height: %w", err) + } + + for _, tx := range block.Transactions { + if err := cache.StoreAsTransaction(tx, block.Index, writeBuf); err != nil { + return err + } + writeBuf.Reset() + } + + _, err = cache.Persist() + if err != nil { + return fmt.Errorf("failed to persist results: %w", err) + } + s.blockHeight = block.Index + if s.blockHeight == s.syncPoint { + s.syncStage |= blocksSynced + s.log.Info("blocks are in sync", + zap.Uint32("blockHeight", s.blockHeight)) + s.checkSyncIsCompleted() + } + return nil +} + +// AddMPTNodes tries to add provided set of MPT nodes to the MPT billet if they are +// not yet collected. +func (s *Module) AddMPTNodes(nodes [][]byte) error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.syncStage&headersSynced == 0 || s.syncStage&mptSynced != 0 { + return errors.New("MPT nodes were not requested") + } + + for _, nBytes := range nodes { + var n mpt.NodeObject + r := io.NewBinReaderFromBuf(nBytes) + n.DecodeBinary(r) + if r.Err != nil { + return fmt.Errorf("failed to decode MPT node: %w", r.Err) + } + nPaths, ok := s.mptpool.TryGet(n.Hash()) + if !ok { + // it can easily happen after receiving the same data from different peers. + return nil + } + + var childrenPaths = make(map[util.Uint256][][]byte) + for _, path := range nPaths { + err := s.billet.RestoreHashNode(path, n.Node) + if err != nil { + return fmt.Errorf("failed to add MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err) + } + for h, paths := range mpt.GetChildrenPaths(path, n.Node) { + childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool + } + } + + s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths) + } + if s.mptpool.Count() == 0 { + s.syncStage |= mptSynced + s.log.Info("MPT is in sync", + zap.Uint32("height", s.syncPoint)) + s.checkSyncIsCompleted() + } + return nil +} + +// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1 +// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored. +// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller +// should take care of it. +func (s *Module) checkSyncIsCompleted() { + if s.syncStage != headersSynced|mptSynced|blocksSynced { + return + } + s.log.Info("state is in sync", + zap.Uint32("state sync point", s.syncPoint)) + err := s.bc.JumpToState(s) + if err != nil { + s.log.Fatal("failed to jump to the latest state sync point", zap.Error(err)) + } + s.syncStage = inactive + s.dispose() +} + +func (s *Module) dispose() { + s.billet = nil +} + +// BlockHeight returns index of the last stored block. +func (s *Module) BlockHeight() uint32 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.blockHeight +} + +// IsActive tells whether state sync module is on and still gathering state +// synchronisation data (headers, blocks or MPT nodes). +func (s *Module) IsActive() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return !(s.syncStage == inactive || (s.syncStage == headersSynced|mptSynced|blocksSynced)) +} + +// IsInitialized tells whether state sync module does not require initialization. +// If `false` is returned then Init can be safely called. +func (s *Module) IsInitialized() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.syncStage != none +} + +// NeedHeaders tells whether the module hasn't completed headers synchronisation. +func (s *Module) NeedHeaders() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.syncStage == initialized +} + +// NeedMPTNodes returns whether the module hasn't completed MPT synchronisation. +func (s *Module) NeedMPTNodes() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.syncStage&headersSynced != 0 && s.syncStage&mptSynced == 0 +} + +// Traverse traverses local MPT nodes starting from the specified root down to its +// children calling `process` for each serialised node until stop condition is satisfied. +func (s *Module) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { + s.lock.RLock() + defer s.lock.RUnlock() + + b := mpt.NewBillet(root, s.bc.GetConfig().KeepOnlyLatestState, storage.NewMemCachedStore(s.dao.Store)) + return b.Traverse(process, false) +} + +// GetJumpHeight returns state sync point to jump to. It is not protected by mutex and should be called +// under the module lock. +func (s *Module) GetJumpHeight() (uint32, error) { + if s.syncStage != headersSynced|mptSynced|blocksSynced { + return 0, errors.New("state sync module has wong state to perform state jump") + } + return s.syncPoint, nil +} diff --git a/pkg/core/statesync/mptpool.go b/pkg/core/statesync/mptpool.go new file mode 100644 index 000000000..23610d97a --- /dev/null +++ b/pkg/core/statesync/mptpool.go @@ -0,0 +1,119 @@ +package statesync + +import ( + "bytes" + "sort" + "sync" + + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// Pool stores unknown MPT nodes along with the corresponding paths (single node is +// allowed to have multiple MPT paths). +type Pool struct { + lock sync.RWMutex + hashes map[util.Uint256][][]byte +} + +// NewPool returns new MPT node hashes pool. +func NewPool() *Pool { + return &Pool{ + hashes: make(map[util.Uint256][][]byte), + } +} + +// ContainsKey checks if MPT node with the specified hash is in the Pool. +func (mp *Pool) ContainsKey(hash util.Uint256) bool { + mp.lock.RLock() + defer mp.lock.RUnlock() + + _, ok := mp.hashes[hash] + return ok +} + +// TryGet returns a set of MPT paths for the specified HashNode. +func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) { + mp.lock.RLock() + defer mp.lock.RUnlock() + + paths, ok := mp.hashes[hash] + return paths, ok +} + +// GetAll returns all MPT nodes with the corresponding paths from the pool. +func (mp *Pool) GetAll() map[util.Uint256][][]byte { + mp.lock.RLock() + defer mp.lock.RUnlock() + + return mp.hashes +} + +// Remove removes MPT node from the pool by the specified hash. +func (mp *Pool) Remove(hash util.Uint256) { + mp.lock.Lock() + defer mp.lock.Unlock() + + delete(mp.hashes, hash) +} + +// Add adds path to the set of paths for the specified node. +func (mp *Pool) Add(hash util.Uint256, path []byte) { + mp.lock.Lock() + defer mp.lock.Unlock() + + mp.addPaths(hash, [][]byte{path}) +} + +// Update is an atomic operation and removes/adds specified nodes from/to the pool. +func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) { + mp.lock.Lock() + defer mp.lock.Unlock() + + for h, paths := range remove { + old := mp.hashes[h] + for _, path := range paths { + i := sort.Search(len(old), func(i int) bool { + return bytes.Compare(old[i], path) >= 0 + }) + if i < len(old) && bytes.Equal(old[i], path) { + old = append(old[:i], old[i+1:]...) + } + } + if len(old) == 0 { + delete(mp.hashes, h) + } else { + mp.hashes[h] = old + } + } + for h, paths := range add { + mp.addPaths(h, paths) + } +} + +// addPaths adds set of the specified node paths to the pool. +func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) { + old := mp.hashes[nodeHash] + for _, path := range paths { + i := sort.Search(len(old), func(i int) bool { + return bytes.Compare(old[i], path) >= 0 + }) + if i < len(old) && bytes.Equal(old[i], path) { + // then path is already added + continue + } + old = append(old, path) + if i != len(old)-1 { + copy(old[i+1:], old[i:]) + old[i] = path + } + } + mp.hashes[nodeHash] = old +} + +// Count returns the number of nodes in the pool. +func (mp *Pool) Count() int { + mp.lock.RLock() + defer mp.lock.RUnlock() + + return len(mp.hashes) +} diff --git a/pkg/network/blockqueue.go b/pkg/network/blockqueue.go index 8a21cab1b..7cc796886 100644 --- a/pkg/network/blockqueue.go +++ b/pkg/network/blockqueue.go @@ -4,6 +4,7 @@ import ( "github.com/Workiva/go-datastructures/queue" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" + "go.uber.org/atomic" "go.uber.org/zap" ) @@ -13,6 +14,7 @@ type blockQueue struct { checkBlocks chan struct{} chain blockchainer.Blockqueuer relayF func(*block.Block) + discarded *atomic.Bool } const ( @@ -32,6 +34,7 @@ func newBlockQueue(capacity int, bc blockchainer.Blockqueuer, log *zap.Logger, r checkBlocks: make(chan struct{}, 1), chain: bc, relayF: relayer, + discarded: atomic.NewBool(false), } } @@ -91,8 +94,10 @@ func (bq *blockQueue) putBlock(block *block.Block) error { } func (bq *blockQueue) discard() { - close(bq.checkBlocks) - bq.queue.Dispose() + if bq.discarded.CAS(false, true) { + close(bq.checkBlocks) + bq.queue.Dispose() + } } func (bq *blockQueue) length() int { diff --git a/pkg/network/message.go b/pkg/network/message.go index cbdaa554f..fb605c8a7 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -71,6 +71,8 @@ const ( CMDBlock = CommandType(payload.BlockType) CMDExtensible = CommandType(payload.ExtensibleType) CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType) + CMDGetMPTData CommandType = 0x51 // 0x5.. commands are used for extensions (P2PNotary, state exchange cmds) + CMDMPTData CommandType = 0x52 CMDReject CommandType = 0x2f // SPV protocol. @@ -136,6 +138,10 @@ func (m *Message) decodePayload() error { p = &payload.Version{} case CMDInv, CMDGetData: p = &payload.Inventory{} + case CMDGetMPTData: + p = &payload.MPTInventory{} + case CMDMPTData: + p = &payload.MPTData{} case CMDAddr: p = &payload.AddressList{} case CMDBlock: @@ -221,7 +227,7 @@ func (m *Message) tryCompressPayload() error { if m.Flags&Compressed == 0 { switch m.Payload.(type) { case *payload.Headers, *payload.MerkleBlock, payload.NullPayload, - *payload.Inventory: + *payload.Inventory, *payload.MPTInventory: break default: size := len(compressedPayload) diff --git a/pkg/network/message_string.go b/pkg/network/message_string.go index 7da007079..2ebdacd9b 100644 --- a/pkg/network/message_string.go +++ b/pkg/network/message_string.go @@ -26,6 +26,8 @@ func _() { _ = x[CMDBlock-44] _ = x[CMDExtensible-46] _ = x[CMDP2PNotaryRequest-80] + _ = x[CMDGetMPTData-81] + _ = x[CMDMPTData-82] _ = x[CMDReject-47] _ = x[CMDFilterLoad-48] _ = x[CMDFilterAdd-49] @@ -44,7 +46,7 @@ const ( _CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear" _CommandType_name_7 = "CMDMerkleBlock" _CommandType_name_8 = "CMDAlert" - _CommandType_name_9 = "CMDP2PNotaryRequest" + _CommandType_name_9 = "CMDP2PNotaryRequestCMDGetMPTDataCMDMPTData" ) var ( @@ -55,6 +57,7 @@ var ( _CommandType_index_4 = [...]uint8{0, 12, 22} _CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58} _CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61} + _CommandType_index_9 = [...]uint8{0, 19, 32, 42} ) func (i CommandType) String() string { @@ -83,8 +86,9 @@ func (i CommandType) String() string { return _CommandType_name_7 case i == 64: return _CommandType_name_8 - case i == 80: - return _CommandType_name_9 + case 80 <= i && i <= 82: + i -= 80 + return _CommandType_name_9[_CommandType_index_9[i]:_CommandType_index_9[i+1]] default: return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/network/message_test.go b/pkg/network/message_test.go index c4976a702..38b620189 100644 --- a/pkg/network/message_test.go +++ b/pkg/network/message_test.go @@ -258,6 +258,21 @@ func TestEncodeDecodeNotFound(t *testing.T) { }) } +func TestEncodeDecodeGetMPTData(t *testing.T) { + testEncodeDecode(t, CMDGetMPTData, &payload.MPTInventory{ + Hashes: []util.Uint256{ + {1, 2, 3}, + {4, 5, 6}, + }, + }) +} + +func TestEncodeDecodeMPTData(t *testing.T) { + testEncodeDecode(t, CMDMPTData, &payload.MPTData{ + Nodes: [][]byte{{1, 2, 3}, {4, 5, 6}}, + }) +} + func TestInvalidMessages(t *testing.T) { t.Run("CMDBlock, empty payload", func(t *testing.T) { testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{}) diff --git a/pkg/network/payload/mptdata.go b/pkg/network/payload/mptdata.go new file mode 100644 index 000000000..ba607a86b --- /dev/null +++ b/pkg/network/payload/mptdata.go @@ -0,0 +1,35 @@ +package payload + +import ( + "errors" + + "github.com/nspcc-dev/neo-go/pkg/io" +) + +// MPTData represents the set of serialized MPT nodes. +type MPTData struct { + Nodes [][]byte +} + +// EncodeBinary implements io.Serializable. +func (d *MPTData) EncodeBinary(w *io.BinWriter) { + w.WriteVarUint(uint64(len(d.Nodes))) + for _, n := range d.Nodes { + w.WriteVarBytes(n) + } +} + +// DecodeBinary implements io.Serializable. +func (d *MPTData) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + if sz == 0 { + r.Err = errors.New("empty MPT nodes list") + return + } + for i := uint64(0); i < sz; i++ { + d.Nodes = append(d.Nodes, r.ReadVarBytes()) + if r.Err != nil { + return + } + } +} diff --git a/pkg/network/payload/mptdata_test.go b/pkg/network/payload/mptdata_test.go new file mode 100644 index 000000000..d9db7bff6 --- /dev/null +++ b/pkg/network/payload/mptdata_test.go @@ -0,0 +1,24 @@ +package payload + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/stretchr/testify/require" +) + +func TestMPTData_EncodeDecodeBinary(t *testing.T) { + t.Run("empty", func(t *testing.T) { + d := new(MPTData) + bytes, err := testserdes.EncodeBinary(d) + require.NoError(t, err) + require.Error(t, testserdes.DecodeBinary(bytes, new(MPTData))) + }) + + t.Run("good", func(t *testing.T) { + d := &MPTData{ + Nodes: [][]byte{{}, {1}, {1, 2, 3}}, + } + testserdes.EncodeDecodeBinary(t, d, new(MPTData)) + }) +} diff --git a/pkg/network/payload/mptinventory.go b/pkg/network/payload/mptinventory.go new file mode 100644 index 000000000..66aaf5c02 --- /dev/null +++ b/pkg/network/payload/mptinventory.go @@ -0,0 +1,32 @@ +package payload + +import ( + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MaxMPTHashesCount is the maximum number of requested MPT nodes hashes. +const MaxMPTHashesCount = 32 + +// MPTInventory payload. +type MPTInventory struct { + // A list of requested MPT nodes hashes. + Hashes []util.Uint256 +} + +// NewMPTInventory return a pointer to an MPTInventory. +func NewMPTInventory(hashes []util.Uint256) *MPTInventory { + return &MPTInventory{ + Hashes: hashes, + } +} + +// DecodeBinary implements Serializable interface. +func (p *MPTInventory) DecodeBinary(br *io.BinReader) { + br.ReadArray(&p.Hashes, MaxMPTHashesCount) +} + +// EncodeBinary implements Serializable interface. +func (p *MPTInventory) EncodeBinary(bw *io.BinWriter) { + bw.WriteArray(p.Hashes) +} diff --git a/pkg/network/payload/mptinventory_test.go b/pkg/network/payload/mptinventory_test.go new file mode 100644 index 000000000..a0c052a67 --- /dev/null +++ b/pkg/network/payload/mptinventory_test.go @@ -0,0 +1,38 @@ +package payload + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestMPTInventory_EncodeDecodeBinary(t *testing.T) { + t.Run("empty", func(t *testing.T) { + testserdes.EncodeDecodeBinary(t, NewMPTInventory([]util.Uint256{}), new(MPTInventory)) + }) + + t.Run("good", func(t *testing.T) { + inv := NewMPTInventory([]util.Uint256{{1, 2, 3}, {2, 3, 4}}) + testserdes.EncodeDecodeBinary(t, inv, new(MPTInventory)) + }) + + t.Run("too large", func(t *testing.T) { + check := func(t *testing.T, count int, fail bool) { + h := make([]util.Uint256, count) + for i := range h { + h[i] = util.Uint256{1, 2, 3} + } + if fail { + bytes, err := testserdes.EncodeBinary(NewMPTInventory(h)) + require.NoError(t, err) + require.Error(t, testserdes.DecodeBinary(bytes, new(MPTInventory))) + } else { + testserdes.EncodeDecodeBinary(t, NewMPTInventory(h), new(MPTInventory)) + } + } + check(t, MaxMPTHashesCount, false) + check(t, MaxMPTHashesCount+1, true) + }) +} diff --git a/pkg/network/server.go b/pkg/network/server.go index 457add0b0..b9f53ed0e 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -7,6 +7,7 @@ import ( "fmt" mrand "math/rand" "net" + "sort" "strconv" "sync" "time" @@ -17,7 +18,9 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/extpool" "github.com/nspcc-dev/neo-go/pkg/network/payload" @@ -67,6 +70,7 @@ type ( discovery Discoverer chain blockchainer.Blockchainer bQueue *blockQueue + bSyncQueue *blockQueue consensus consensus.Service mempool *mempool.Pool notaryRequestPool *mempool.Pool @@ -93,6 +97,7 @@ type ( oracle *oracle.Oracle stateRoot stateroot.Service + stateSync blockchainer.StateSync log *zap.Logger } @@ -191,6 +196,10 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai } s.stateRoot = sr + sSync := chain.GetStateSyncModule() + s.stateSync = sSync + s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil) + if config.OracleCfg.Enabled { orcCfg := oracle.Config{ Log: log, @@ -277,6 +286,7 @@ func (s *Server) Start(errChan chan error) { go s.broadcastTxLoop() go s.relayBlocksLoop() go s.bQueue.run() + go s.bSyncQueue.run() go s.transport.Accept() setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10)) s.run() @@ -292,6 +302,7 @@ func (s *Server) Shutdown() { p.Disconnect(errServerShutdown) } s.bQueue.discard() + s.bSyncQueue.discard() if s.StateRootCfg.Enabled { s.stateRoot.Shutdown() } @@ -573,6 +584,10 @@ func (s *Server) IsInSync() bool { var peersNumber int var notHigher int + if s.stateSync.IsActive() { + return false + } + if s.MinPeers == 0 { return true } @@ -630,6 +645,9 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error { // handleBlockCmd processes the received block received from its peer. func (s *Server) handleBlockCmd(p Peer, block *block.Block) error { + if s.stateSync.IsActive() { + return s.bSyncQueue.putBlock(block) + } return s.bQueue.putBlock(block) } @@ -639,25 +657,46 @@ func (s *Server) handlePing(p Peer, ping *payload.Ping) error { if err != nil { return err } - if s.chain.BlockHeight() < ping.LastBlockIndex { - err = s.requestBlocks(p) - if err != nil { - return err - } + err = s.requestBlocksOrHeaders(p) + if err != nil { + return err } return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id))) } +func (s *Server) requestBlocksOrHeaders(p Peer) error { + if s.stateSync.NeedHeaders() { + if s.chain.HeaderHeight() < p.LastBlockIndex() { + return s.requestHeaders(p) + } + return nil + } + var bq blockchainer.Blockqueuer = s.chain + if s.stateSync.IsActive() { + bq = s.stateSync + } + if bq.BlockHeight() < p.LastBlockIndex() { + return s.requestBlocks(bq, p) + } + return nil +} + +// requestHeaders sends a CMDGetHeaders message to the peer to sync up in headers. +func (s *Server) requestHeaders(p Peer) error { + // TODO: optimize + currHeight := s.chain.HeaderHeight() + needHeight := currHeight + 1 + payload := payload.NewGetBlockByIndex(needHeight, -1) + return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload)) +} + // handlePing processes pong request. func (s *Server) handlePong(p Peer, pong *payload.Ping) error { err := p.HandlePong(pong) if err != nil { return err } - if s.chain.BlockHeight() < pong.LastBlockIndex { - return s.requestBlocks(p) - } - return nil + return s.requestBlocksOrHeaders(p) } // handleInvCmd processes the received inventory. @@ -766,6 +805,50 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { return nil } +// handleGetMPTDataCmd processes the received MPT inventory. +func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error { + if !s.chain.GetConfig().P2PStateExchangeExtensions { + return errors.New("GetMPTDataCMD was received, but P2PStateExchangeExtensions are disabled") + } + if s.chain.GetConfig().KeepOnlyLatestState { + // TODO: implement keeping MPT states for P1 and P2 height (#2095, #2152 related) + return errors.New("GetMPTDataCMD was received, but only latest MPT state is supported") + } + resp := payload.MPTData{} + capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes))) + for _, h := range inv.Hashes { + if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type + break + } + err := s.stateSync.Traverse(h, + func(_ mpt.Node, node []byte) bool { + l := len(node) + size := l + io.GetVarSize(l) + if size > capLeft { + return true + } + resp.Nodes = append(resp.Nodes, node) + capLeft -= size + return false + }) + if err != nil { + return fmt.Errorf("failed to traverse MPT starting from %s: %w", h.StringBE(), err) + } + } + if len(resp.Nodes) > 0 { + msg := NewMessage(CMDMPTData, &resp) + return p.EnqueueP2PMessage(msg) + } + return nil +} + +func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error { + if !s.chain.GetConfig().P2PStateExchangeExtensions { + return errors.New("MPTDataCMD was received, but P2PStateExchangeExtensions are disabled") + } + return s.stateSync.AddMPTNodes(data.Nodes) +} + // handleGetBlocksCmd processes the getblocks request. func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { count := gb.Count @@ -845,6 +928,11 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error return p.EnqueueP2PMessage(msg) } +// handleHeadersCmd processes headers payload. +func (s *Server) handleHeadersCmd(p Peer, h *payload.Headers) error { + return s.stateSync.AddHeaders(h.Hdrs...) +} + // handleExtensibleCmd processes received extensible payload. func (s *Server) handleExtensibleCmd(e *payload.Extensible) error { if !s.syncReached.Load() { @@ -993,8 +1081,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error { // 1. Block range is divided into chunks of payload.MaxHashesCount. // 2. Send requests for chunk in increasing order. // 3. After all requests were sent, request random height. -func (s *Server) requestBlocks(p Peer) error { - var currHeight = s.chain.BlockHeight() +func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error { + var currHeight = bq.BlockHeight() var peerHeight = p.LastBlockIndex() var needHeight uint32 // lastRequestedHeight can only be increased. @@ -1051,9 +1139,18 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDGetData: inv := msg.Payload.(*payload.Inventory) return s.handleGetDataCmd(peer, inv) + case CMDGetMPTData: + inv := msg.Payload.(*payload.MPTInventory) + return s.handleGetMPTDataCmd(peer, inv) + case CMDMPTData: + inv := msg.Payload.(*payload.MPTData) + return s.handleMPTDataCmd(peer, inv) case CMDGetHeaders: gh := msg.Payload.(*payload.GetBlockByIndex) return s.handleGetHeadersCmd(peer, gh) + case CMDHeaders: + h := msg.Payload.(*payload.Headers) + return s.handleHeadersCmd(peer, h) case CMDInv: inventory := msg.Payload.(*payload.Inventory) return s.handleInvCmd(peer, inventory) @@ -1093,6 +1190,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { } go peer.StartProtocol() + s.tryInitStateSync() s.tryStartServices() default: return fmt.Errorf("received '%s' during handshake", msg.Command.String()) @@ -1101,6 +1199,52 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { return nil } +func (s *Server) tryInitStateSync() { + if !s.stateSync.IsActive() { + s.bSyncQueue.discard() + return + } + + if s.stateSync.IsInitialized() { + return + } + + var peersNumber int + s.lock.RLock() + heights := make([]uint32, 0) + for p := range s.peers { + if p.Handshaked() { + peersNumber++ + peerLastBlock := p.LastBlockIndex() + i := sort.Search(len(heights), func(i int) bool { + return heights[i] >= peerLastBlock + }) + heights = append(heights, peerLastBlock) + if i != len(heights)-1 { + copy(heights[i+1:], heights[i:]) + heights[i] = peerLastBlock + } + } + } + s.lock.RUnlock() + if peersNumber >= s.MinPeers && len(heights) > 0 { + // choose the height of the median peer as current chain's height + h := heights[len(heights)/2] + err := s.stateSync.Init(h) + if err != nil { + s.log.Fatal("failed to init state sync module", + zap.Uint32("evaluated chain's blockHeight", h), + zap.Uint32("blockHeight", s.chain.BlockHeight()), + zap.Uint32("headerHeight", s.chain.HeaderHeight()), + zap.Error(err)) + } + + // module can be inactive after init (i.e. full state is collected and ordinary block processing is needed) + if !s.stateSync.IsActive() { + s.bSyncQueue.discard() + } + } +} func (s *Server) handleNewPayload(p *payload.Extensible) { _, err := s.extensiblePool.Add(p) if err != nil { diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 0c9c9a358..4b1de5b64 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -2,6 +2,7 @@ package network import ( "errors" + "fmt" "math/big" "net" "strconv" @@ -46,7 +47,10 @@ func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") } func TestNewServer(t *testing.T) { - bc := &fakechain.FakeChain{} + bc := &fakechain.FakeChain{ProtocolConfiguration: config.ProtocolConfiguration{ + P2PStateExchangeExtensions: true, + StateRootInHeader: true, + }} s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery) require.Error(t, err) @@ -899,3 +903,39 @@ func TestVerifyNotaryRequest(t *testing.T) { require.NoError(t, verifyNotaryRequest(bc, nil, r)) }) } + +func TestTryInitStateSync(t *testing.T) { + t.Run("module inactive", func(t *testing.T) { + s := startTestServer(t) + s.tryInitStateSync() + }) + + t.Run("module already initialized", func(t *testing.T) { + s := startTestServer(t) + s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: true} + s.tryInitStateSync() + }) + + t.Run("good", func(t *testing.T) { + s := startTestServer(t) + for _, h := range []uint32{10, 8, 7, 4, 11, 4} { + p := newLocalPeer(t, s) + p.handshaked = true + p.lastBlockIndex = h + s.peers[p] = true + } + p := newLocalPeer(t, s) + p.handshaked = false // one disconnected peer to check it won't be taken into attention + p.lastBlockIndex = 5 + s.peers[p] = true + var expectedH uint32 = 8 // median peer + + s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: false, InitFunc: func(h uint32) error { + if h != expectedH { + return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h) + } + return nil + }} + s.tryInitStateSync() + }) +} diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 8ff47a18c..5b159fda1 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -267,12 +267,10 @@ func (p *TCPPeer) StartProtocol() { zap.Uint32("id", p.Version().Nonce)) p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities) - if p.server.chain.BlockHeight() < p.LastBlockIndex() { - err = p.server.requestBlocks(p) - if err != nil { - p.Disconnect(err) - return - } + err = p.server.requestBlocksOrHeaders(p) + if err != nil { + p.Disconnect(err) + return } timer := time.NewTimer(p.server.ProtoTickInterval) @@ -281,10 +279,8 @@ func (p *TCPPeer) StartProtocol() { case <-p.done: return case <-timer.C: - // Try to sync in headers and block with the peer if his block height is higher then ours. - if p.LastBlockIndex() > p.server.chain.BlockHeight() { - err = p.server.requestBlocks(p) - } + // Try to sync in headers and block with the peer if his block height is higher than ours. + err = p.server.requestBlocksOrHeaders(p) if err == nil { timer.Reset(p.server.ProtoTickInterval) }