forked from TrueCloudLab/neoneo-go
core: implement statesync module
And support GetMPTData and MPTData P2P commands.
This commit is contained in:
parent
a22b1caa3e
commit
d67ff30704
24 changed files with 1197 additions and 32 deletions
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/native"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
|
@ -42,6 +43,13 @@ type FakeChain struct {
|
|||
UtilityTokenBalance *big.Int
|
||||
}
|
||||
|
||||
// FakeStateSync implements StateSync interface.
|
||||
type FakeStateSync struct {
|
||||
IsActiveFlag bool
|
||||
IsInitializedFlag bool
|
||||
InitFunc func(h uint32) error
|
||||
}
|
||||
|
||||
// NewFakeChain returns new FakeChain structure.
|
||||
func NewFakeChain() *FakeChain {
|
||||
return &FakeChain{
|
||||
|
@ -294,6 +302,16 @@ func (chain *FakeChain) GetStateModule() blockchainer.StateRoot {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetStateSyncModule implements Blockchainer interface.
|
||||
func (chain *FakeChain) GetStateSyncModule() blockchainer.StateSync {
|
||||
return &FakeStateSync{}
|
||||
}
|
||||
|
||||
// JumpToState implements Blockchainer interface.
|
||||
func (chain *FakeChain) JumpToState(module blockchainer.StateSync) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// GetStorageItem implements Blockchainer interface.
|
||||
func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
|
||||
panic("TODO")
|
||||
|
@ -436,3 +454,57 @@ func (chain *FakeChain) UnsubscribeFromNotifications(ch chan<- *state.Notificati
|
|||
func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// AddBlock implements StateSync interface.
|
||||
func (s *FakeStateSync) AddBlock(block *block.Block) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// AddHeaders implements StateSync interface.
|
||||
func (s *FakeStateSync) AddHeaders(...*block.Header) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// AddMPTNodes implements StateSync interface.
|
||||
func (s *FakeStateSync) AddMPTNodes([][]byte) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// BlockHeight implements StateSync interface.
|
||||
func (s *FakeStateSync) BlockHeight() uint32 {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// IsActive implements StateSync interface.
|
||||
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag }
|
||||
|
||||
// IsInitialized implements StateSync interface.
|
||||
func (s *FakeStateSync) IsInitialized() bool {
|
||||
return s.IsInitializedFlag
|
||||
}
|
||||
|
||||
// Init implements StateSync interface.
|
||||
func (s *FakeStateSync) Init(currChainHeight uint32) error {
|
||||
if s.InitFunc != nil {
|
||||
return s.InitFunc(currChainHeight)
|
||||
}
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// NeedHeaders implements StateSync interface.
|
||||
func (s *FakeStateSync) NeedHeaders() bool { return false }
|
||||
|
||||
// NeedMPTNodes implements StateSync interface.
|
||||
func (s *FakeStateSync) NeedMPTNodes() bool {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// Traverse implements StateSync interface.
|
||||
func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// GetJumpHeight implements StateSync interface.
|
||||
func (s *FakeStateSync) GetJumpHeight() (uint32, error) {
|
||||
panic("TODO")
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/native/noderoles"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/stateroot"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/statesync"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
|
@ -409,6 +410,64 @@ func (bc *Blockchain) init() error {
|
|||
return bc.updateExtensibleWhitelist(bHeight)
|
||||
}
|
||||
|
||||
// JumpToState is an atomic operation that changes Blockchain state to the one
|
||||
// specified by the state sync point p. All the data needed for the jump must be
|
||||
// collected by the state sync module.
|
||||
func (bc *Blockchain) JumpToState(module blockchainer.StateSync) error {
|
||||
bc.lock.Lock()
|
||||
defer bc.lock.Unlock()
|
||||
|
||||
p, err := module.GetJumpHeight()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get jump height: %w", err)
|
||||
}
|
||||
if p+1 >= uint32(len(bc.headerHashes)) {
|
||||
return fmt.Errorf("invalid state sync point")
|
||||
}
|
||||
|
||||
bc.log.Info("jumping to state sync point", zap.Uint32("state sync point", p))
|
||||
|
||||
block, err := bc.dao.GetBlock(bc.headerHashes[p])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current block: %w", err)
|
||||
}
|
||||
err = bc.dao.StoreAsCurrentBlock(block, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store current block: %w", err)
|
||||
}
|
||||
bc.topBlock.Store(block)
|
||||
atomic.StoreUint32(&bc.blockHeight, p)
|
||||
atomic.StoreUint32(&bc.persistedHeight, p)
|
||||
|
||||
block, err = bc.dao.GetBlock(bc.headerHashes[p+1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get block to init MPT: %w", err)
|
||||
}
|
||||
if err = bc.stateRoot.JumpToState(&state.MPTRoot{
|
||||
Index: p,
|
||||
Root: block.PrevStateRoot,
|
||||
}, bc.config.KeepOnlyLatestState); err != nil {
|
||||
return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err)
|
||||
}
|
||||
|
||||
err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
|
||||
}
|
||||
err = bc.contracts.Management.InitializeCache(bc.dao)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't init cache for Management native contract: %w", err)
|
||||
}
|
||||
bc.contracts.Designate.InitializeCache()
|
||||
|
||||
if err := bc.updateExtensibleWhitelist(p); err != nil {
|
||||
return fmt.Errorf("failed to update extensible whitelist: %w", err)
|
||||
}
|
||||
|
||||
updateBlockHeightMetric(p)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run runs chain loop, it needs to be run as goroutine and executing it is
|
||||
// critical for correct Blockchain operation.
|
||||
func (bc *Blockchain) Run() {
|
||||
|
@ -696,6 +755,11 @@ func (bc *Blockchain) GetStateModule() blockchainer.StateRoot {
|
|||
return bc.stateRoot
|
||||
}
|
||||
|
||||
// GetStateSyncModule returns new state sync service instance.
|
||||
func (bc *Blockchain) GetStateSyncModule() blockchainer.StateSync {
|
||||
return statesync.NewModule(bc, bc.log, bc.dao)
|
||||
}
|
||||
|
||||
// storeBlock performs chain update using the block given, it executes all
|
||||
// transactions with all appropriate side-effects and updates Blockchain state.
|
||||
// This is the only way to change Blockchain state.
|
||||
|
|
|
@ -21,7 +21,6 @@ import (
|
|||
type Blockchainer interface {
|
||||
ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction
|
||||
GetConfig() config.ProtocolConfiguration
|
||||
AddHeaders(...*block.Header) error
|
||||
Blockqueuer // Blockqueuer interface
|
||||
CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error)
|
||||
Close()
|
||||
|
@ -56,10 +55,12 @@ type Blockchainer interface {
|
|||
GetStandByCommittee() keys.PublicKeys
|
||||
GetStandByValidators() keys.PublicKeys
|
||||
GetStateModule() StateRoot
|
||||
GetStateSyncModule() StateSync
|
||||
GetStorageItem(id int32, key []byte) state.StorageItem
|
||||
GetStorageItems(id int32) (map[string]state.StorageItem, error)
|
||||
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM
|
||||
GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error)
|
||||
JumpToState(module StateSync) error
|
||||
SetOracle(service services.Oracle)
|
||||
mempool.Feer // fee interface
|
||||
ManagementContractHash() util.Uint160
|
||||
|
|
|
@ -5,5 +5,6 @@ import "github.com/nspcc-dev/neo-go/pkg/core/block"
|
|||
// Blockqueuer is an interface for blockqueue.
|
||||
type Blockqueuer interface {
|
||||
AddBlock(block *block.Block) error
|
||||
AddHeaders(...*block.Header) error
|
||||
BlockHeight() uint32
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
// StateRoot represents local state root module.
|
||||
type StateRoot interface {
|
||||
AddStateRoot(root *state.MPTRoot) error
|
||||
CurrentLocalHeight() uint32
|
||||
CurrentLocalStateRoot() util.Uint256
|
||||
CurrentValidatedHeight() uint32
|
||||
GetStateProof(root util.Uint256, key []byte) ([][]byte, error)
|
||||
|
|
19
pkg/core/blockchainer/state_sync.go
Normal file
19
pkg/core/blockchainer/state_sync.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package blockchainer
|
||||
|
||||
import (
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// StateSync represents state sync module.
|
||||
type StateSync interface {
|
||||
AddMPTNodes([][]byte) error
|
||||
Blockqueuer // Blockqueuer interface
|
||||
Init(currChainHeight uint32) error
|
||||
IsActive() bool
|
||||
IsInitialized() bool
|
||||
GetJumpHeight() (uint32, error)
|
||||
NeedHeaders() bool
|
||||
NeedMPTNodes() bool
|
||||
Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
|
||||
}
|
|
@ -1,5 +1,7 @@
|
|||
package mpt
|
||||
|
||||
import "github.com/nspcc-dev/neo-go/pkg/util"
|
||||
|
||||
// lcp returns longest common prefix of a and b.
|
||||
// Note: it does no allocations.
|
||||
func lcp(a, b []byte) []byte {
|
||||
|
@ -49,3 +51,36 @@ func fromNibbles(path []byte) []byte {
|
|||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetChildrenPaths returns a set of paths to node's children who are non-empty HashNodes
|
||||
// based on the node's path.
|
||||
func GetChildrenPaths(path []byte, node Node) map[util.Uint256][][]byte {
|
||||
res := make(map[util.Uint256][][]byte)
|
||||
switch n := node.(type) {
|
||||
case *LeafNode, *HashNode, EmptyNode:
|
||||
return nil
|
||||
case *BranchNode:
|
||||
for i, child := range n.Children {
|
||||
if child.Type() == HashT {
|
||||
cPath := make([]byte, len(path), len(path)+1)
|
||||
copy(cPath, path)
|
||||
if i != lastChild {
|
||||
cPath = append(cPath, byte(i))
|
||||
}
|
||||
paths := res[child.Hash()]
|
||||
paths = append(paths, cPath)
|
||||
res[child.Hash()] = paths
|
||||
}
|
||||
}
|
||||
case *ExtensionNode:
|
||||
if n.next.Type() == HashT {
|
||||
cPath := make([]byte, len(path)+len(n.key))
|
||||
copy(cPath, path)
|
||||
copy(cPath[len(path):], n.key)
|
||||
res[n.next.Hash()] = [][]byte{cPath}
|
||||
}
|
||||
default:
|
||||
panic("unknown Node type")
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package mpt
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -18,3 +19,49 @@ func TestToNibblesFromNibbles(t *testing.T) {
|
|||
check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetChildrenPaths(t *testing.T) {
|
||||
h1 := NewHashNode(util.Uint256{1, 2, 3})
|
||||
h2 := NewHashNode(util.Uint256{4, 5, 6})
|
||||
h3 := NewHashNode(util.Uint256{7, 8, 9})
|
||||
l := NewLeafNode([]byte{1, 2, 3})
|
||||
ext1 := NewExtensionNode([]byte{8, 9}, h1)
|
||||
ext2 := NewExtensionNode([]byte{7, 6}, l)
|
||||
branch := NewBranchNode()
|
||||
branch.Children[3] = h1
|
||||
branch.Children[5] = l
|
||||
branch.Children[6] = h1 // 3-th and 6-th children have the same hash
|
||||
branch.Children[7] = h3
|
||||
branch.Children[lastChild] = h2
|
||||
testCases := map[string]struct {
|
||||
node Node
|
||||
expected map[util.Uint256][][]byte
|
||||
}{
|
||||
"Hash": {h1, nil},
|
||||
"Leaf": {l, nil},
|
||||
"Extension with next Hash": {ext1, map[util.Uint256][][]byte{h1.Hash(): {ext1.key}}},
|
||||
"Extension with next non-Hash": {ext2, map[util.Uint256][][]byte{}},
|
||||
"Branch": {branch, map[util.Uint256][][]byte{
|
||||
h1.Hash(): {{0x03}, {0x06}},
|
||||
h2.Hash(): {{}},
|
||||
h3.Hash(): {{0x07}},
|
||||
}},
|
||||
}
|
||||
parentPath := []byte{4, 5, 6}
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
require.Equal(t, testCase.expected, GetChildrenPaths([]byte{}, testCase.node))
|
||||
if testCase.expected != nil {
|
||||
expectedWithPrefix := make(map[util.Uint256][][]byte, len(testCase.expected))
|
||||
for h, paths := range testCase.expected {
|
||||
var res [][]byte
|
||||
for _, path := range paths {
|
||||
res = append(res, append(parentPath, path...))
|
||||
}
|
||||
expectedWithPrefix[h] = res
|
||||
}
|
||||
require.Equal(t, expectedWithPrefix, GetChildrenPaths(parentPath, testCase.node))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newProofTrie(t *testing.T) *Trie {
|
||||
func newProofTrie(t *testing.T, missingHashNode bool) *Trie {
|
||||
l := NewLeafNode([]byte("somevalue"))
|
||||
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
|
||||
l2 := NewLeafNode([]byte("invalid"))
|
||||
|
@ -20,11 +20,14 @@ func newProofTrie(t *testing.T) *Trie {
|
|||
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
|
||||
tr.putToStore(l)
|
||||
tr.putToStore(e)
|
||||
if !missingHashNode {
|
||||
tr.putToStore(l2)
|
||||
}
|
||||
return tr
|
||||
}
|
||||
|
||||
func TestTrie_GetProof(t *testing.T) {
|
||||
tr := newProofTrie(t)
|
||||
tr := newProofTrie(t, true)
|
||||
|
||||
t.Run("MissingKey", func(t *testing.T) {
|
||||
_, err := tr.GetProof([]byte{0x12})
|
||||
|
@ -43,7 +46,7 @@ func TestTrie_GetProof(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestVerifyProof(t *testing.T) {
|
||||
tr := newProofTrie(t)
|
||||
tr := newProofTrie(t, true)
|
||||
|
||||
t.Run("Simple", func(t *testing.T) {
|
||||
proof, err := tr.GetProof([]byte{0x12, 0x32})
|
||||
|
|
|
@ -353,3 +353,8 @@ func (s *Designate) getRole(item stackitem.Item) (noderoles.Role, bool) {
|
|||
u := bi.Uint64()
|
||||
return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u))
|
||||
}
|
||||
|
||||
// InitializeCache invalidates native Designate cache.
|
||||
func (s *Designate) InitializeCache() {
|
||||
s.rolesChangedFlag.Store(true)
|
||||
}
|
||||
|
|
|
@ -114,6 +114,25 @@ func (s *Module) Init(height uint32, enableRefCount bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// JumpToState performs jump to the state specified by given stateroot index.
|
||||
func (s *Module) JumpToState(sr *state.MPTRoot, enableRefCount bool) error {
|
||||
if err := s.addLocalStateRoot(s.Store, sr); err != nil {
|
||||
return fmt.Errorf("failed to store local state root: %w", err)
|
||||
}
|
||||
|
||||
data := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(data, sr.Index)
|
||||
if err := s.Store.Put([]byte{byte(storage.DataMPT), prefixValidated}, data); err != nil {
|
||||
return fmt.Errorf("failed to store validated height: %w", err)
|
||||
}
|
||||
s.validatedHeight.Store(sr.Index)
|
||||
|
||||
s.currentLocal.Store(sr.Root)
|
||||
s.localHeight.Store(sr.Index)
|
||||
s.mpt = mpt.NewTrie(mpt.NewHashNode(sr.Root), enableRefCount, s.Store)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMPTBatch updates using provided batch.
|
||||
func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) {
|
||||
mpt := *s.mpt
|
||||
|
|
440
pkg/core/statesync/module.go
Normal file
440
pkg/core/statesync/module.go
Normal file
|
@ -0,0 +1,440 @@
|
|||
/*
|
||||
Package statesync implements module for the P2P state synchronisation process. The
|
||||
module manages state synchronisation for non-archival nodes which are joining the
|
||||
network and don't have the ability to resync from the genesis block.
|
||||
|
||||
Given the currently available state synchronisation point P, sate sync process
|
||||
includes the following stages:
|
||||
|
||||
1. Fetching headers starting from height 0 up to P+1.
|
||||
2. Fetching MPT nodes for height P stating from the corresponding state root.
|
||||
3. Fetching blocks starting from height P-MaxTraceableBlocks (or 0) up to P.
|
||||
|
||||
Steps 2 and 3 are being performed in parallel. Once all the data are collected
|
||||
and stored in the db, an atomic state jump is occurred to the state sync point P.
|
||||
Further node operation process is performed using standard sync mechanism until
|
||||
the node reaches synchronised state.
|
||||
*/
|
||||
package statesync
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// stateSyncStage is a type of state synchronisation stage.
|
||||
type stateSyncStage uint8
|
||||
|
||||
const (
|
||||
// inactive means that state exchange is disabled by the protocol configuration.
|
||||
// Can't be combined with other states.
|
||||
inactive stateSyncStage = 1 << iota
|
||||
// none means that state exchange is enabled in the configuration, but
|
||||
// initialisation of the state sync module wasn't yet performed, i.e.
|
||||
// (*Module).Init wasn't called. Can't be combined with other states.
|
||||
none
|
||||
// initialized means that (*Module).Init was called, but other sync stages
|
||||
// are not yet reached (i.e. that headers are requested, but not yet fetched).
|
||||
// Can't be combined with other states.
|
||||
initialized
|
||||
// headersSynced means that headers for the current state sync point are fetched.
|
||||
// May be combined with mptSynced and/or blocksSynced.
|
||||
headersSynced
|
||||
// mptSynced means that MPT nodes for the current state sync point are fetched.
|
||||
// Always combined with headersSynced; may be combined with blocksSynced.
|
||||
mptSynced
|
||||
// blocksSynced means that blocks up to the current state sync point are stored.
|
||||
// Always combined with headersSynced; may be combined with mptSynced.
|
||||
blocksSynced
|
||||
)
|
||||
|
||||
// Module represents state sync module and aimed to gather state-related data to
|
||||
// perform an atomic state jump.
|
||||
type Module struct {
|
||||
lock sync.RWMutex
|
||||
log *zap.Logger
|
||||
|
||||
// syncPoint is the state synchronisation point P we're currently working against.
|
||||
syncPoint uint32
|
||||
// syncStage is the stage of the sync process.
|
||||
syncStage stateSyncStage
|
||||
// syncInterval is the delta between two adjacent state sync points.
|
||||
syncInterval uint32
|
||||
// blockHeight is the index of the latest stored block.
|
||||
blockHeight uint32
|
||||
|
||||
dao *dao.Simple
|
||||
bc blockchainer.Blockchainer
|
||||
mptpool *Pool
|
||||
|
||||
billet *mpt.Billet
|
||||
}
|
||||
|
||||
// NewModule returns new instance of statesync module.
|
||||
func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple) *Module {
|
||||
if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) {
|
||||
return &Module{
|
||||
dao: s,
|
||||
bc: bc,
|
||||
syncStage: inactive,
|
||||
}
|
||||
}
|
||||
return &Module{
|
||||
dao: s,
|
||||
bc: bc,
|
||||
log: log,
|
||||
syncInterval: uint32(bc.GetConfig().StateSyncInterval),
|
||||
mptpool: NewPool(),
|
||||
syncStage: none,
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes state sync module for the current chain's height with given
|
||||
// callback for MPT nodes requests.
|
||||
func (s *Module) Init(currChainHeight uint32) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.syncStage != none {
|
||||
return errors.New("already initialized or inactive")
|
||||
}
|
||||
|
||||
p := (currChainHeight / s.syncInterval) * s.syncInterval
|
||||
if p < 2*s.syncInterval {
|
||||
// chain is too low to start state exchange process, use the standard sync mechanism
|
||||
s.syncStage = inactive
|
||||
return nil
|
||||
}
|
||||
pOld, err := s.dao.GetStateSyncPoint()
|
||||
if err == nil && pOld >= p-s.syncInterval {
|
||||
// old point is still valid, so try to resync states for this point.
|
||||
p = pOld
|
||||
} else if s.bc.BlockHeight() > p-2*s.syncInterval {
|
||||
// chain has already been synchronised up to old state sync point and regular blocks processing was started
|
||||
s.syncStage = inactive
|
||||
return nil
|
||||
}
|
||||
|
||||
s.syncPoint = p
|
||||
err = s.dao.PutStateSyncPoint(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store state synchronisation point %d: %w", p, err)
|
||||
}
|
||||
s.syncStage = initialized
|
||||
s.log.Info("try to sync state for the latest state synchronisation point",
|
||||
zap.Uint32("point", p),
|
||||
zap.Uint32("evaluated chain's blockHeight", currChainHeight))
|
||||
|
||||
// check headers sync state first
|
||||
ltstHeaderHeight := s.bc.HeaderHeight()
|
||||
if ltstHeaderHeight > p {
|
||||
s.syncStage = headersSynced
|
||||
s.log.Info("headers are in sync",
|
||||
zap.Uint32("headerHeight", s.bc.HeaderHeight()))
|
||||
}
|
||||
|
||||
// check blocks sync state
|
||||
s.blockHeight = s.getLatestSavedBlock(p)
|
||||
if s.blockHeight >= p {
|
||||
s.syncStage |= blocksSynced
|
||||
s.log.Info("blocks are in sync",
|
||||
zap.Uint32("blockHeight", s.blockHeight))
|
||||
}
|
||||
|
||||
// check MPT sync state
|
||||
if s.blockHeight > p {
|
||||
s.syncStage |= mptSynced
|
||||
s.log.Info("MPT is in sync",
|
||||
zap.Uint32("stateroot height", s.bc.GetStateModule().CurrentLocalHeight()))
|
||||
} else if s.syncStage&headersSynced != 0 {
|
||||
header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(p + 1)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get header to initialize MPT billet: %w", err)
|
||||
}
|
||||
s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store)
|
||||
s.log.Info("MPT billet initialized",
|
||||
zap.Uint32("height", s.syncPoint),
|
||||
zap.String("state root", header.PrevStateRoot.StringBE()))
|
||||
pool := NewPool()
|
||||
pool.Add(header.PrevStateRoot, []byte{})
|
||||
err = s.billet.Traverse(func(n mpt.Node, _ []byte) bool {
|
||||
nPaths, ok := pool.TryGet(n.Hash())
|
||||
if !ok {
|
||||
// if this situation occurs, then it's a bug in MPT pool or Traverse.
|
||||
panic("failed to get MPT node from the pool")
|
||||
}
|
||||
pool.Remove(n.Hash())
|
||||
childrenPaths := make(map[util.Uint256][][]byte)
|
||||
for _, path := range nPaths {
|
||||
nChildrenPaths := mpt.GetChildrenPaths(path, n)
|
||||
for hash, paths := range nChildrenPaths {
|
||||
childrenPaths[hash] = append(childrenPaths[hash], paths...) // it's OK to have duplicates, they'll be handled by mempool
|
||||
}
|
||||
}
|
||||
pool.Update(nil, childrenPaths)
|
||||
return false
|
||||
}, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to traverse MPT while initialization: %w", err)
|
||||
}
|
||||
s.mptpool.Update(nil, pool.GetAll())
|
||||
if s.mptpool.Count() == 0 {
|
||||
s.syncStage |= mptSynced
|
||||
s.log.Info("MPT is in sync",
|
||||
zap.Uint32("stateroot height", p))
|
||||
}
|
||||
}
|
||||
|
||||
if s.syncStage == headersSynced|blocksSynced|mptSynced {
|
||||
s.log.Info("state is in sync, starting regular blocks processing")
|
||||
s.syncStage = inactive
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLatestSavedBlock returns either current block index (if it's still relevant
|
||||
// to continue state sync process) or H-1 where H is the index of the earliest
|
||||
// block that should be saved next.
|
||||
func (s *Module) getLatestSavedBlock(p uint32) uint32 {
|
||||
var result uint32
|
||||
mtb := s.bc.GetConfig().MaxTraceableBlocks
|
||||
if p > mtb {
|
||||
result = p - mtb
|
||||
}
|
||||
storedH, err := s.dao.GetStateSyncCurrentBlockHeight()
|
||||
if err == nil && storedH > result {
|
||||
result = storedH
|
||||
}
|
||||
actualH := s.bc.BlockHeight()
|
||||
if actualH > result {
|
||||
result = actualH
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AddHeaders validates and adds specified headers to the chain.
|
||||
func (s *Module) AddHeaders(hdrs ...*block.Header) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.syncStage != initialized {
|
||||
return errors.New("headers were not requested")
|
||||
}
|
||||
|
||||
hdrsErr := s.bc.AddHeaders(hdrs...)
|
||||
if s.bc.HeaderHeight() > s.syncPoint {
|
||||
s.syncStage = headersSynced
|
||||
s.log.Info("headers for state sync are fetched",
|
||||
zap.Uint32("header height", s.bc.HeaderHeight()))
|
||||
|
||||
header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint) + 1))
|
||||
if err != nil {
|
||||
s.log.Fatal("failed to get header to initialize MPT billet",
|
||||
zap.Uint32("height", s.syncPoint+1),
|
||||
zap.Error(err))
|
||||
}
|
||||
s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store)
|
||||
s.mptpool.Add(header.PrevStateRoot, []byte{})
|
||||
s.log.Info("MPT billet initialized",
|
||||
zap.Uint32("height", s.syncPoint),
|
||||
zap.String("state root", header.PrevStateRoot.StringBE()))
|
||||
}
|
||||
return hdrsErr
|
||||
}
|
||||
|
||||
// AddBlock verifies and saves block skipping executable scripts.
|
||||
func (s *Module) AddBlock(block *block.Block) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.syncStage&headersSynced == 0 || s.syncStage&blocksSynced != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.blockHeight == s.syncPoint {
|
||||
return nil
|
||||
}
|
||||
expectedHeight := s.blockHeight + 1
|
||||
if expectedHeight != block.Index {
|
||||
return fmt.Errorf("expected %d, got %d: invalid block index", expectedHeight, block.Index)
|
||||
}
|
||||
if s.bc.GetConfig().StateRootInHeader != block.StateRootEnabled {
|
||||
return fmt.Errorf("stateroot setting mismatch: %v != %v", s.bc.GetConfig().StateRootInHeader, block.StateRootEnabled)
|
||||
}
|
||||
if s.bc.GetConfig().VerifyBlocks {
|
||||
merkle := block.ComputeMerkleRoot()
|
||||
if !block.MerkleRoot.Equals(merkle) {
|
||||
return errors.New("invalid block: MerkleRoot mismatch")
|
||||
}
|
||||
}
|
||||
cache := s.dao.GetWrapped()
|
||||
writeBuf := io.NewBufBinWriter()
|
||||
if err := cache.StoreAsBlock(block, writeBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
writeBuf.Reset()
|
||||
|
||||
err := cache.PutStateSyncCurrentBlockHeight(block.Index)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store current block height: %w", err)
|
||||
}
|
||||
|
||||
for _, tx := range block.Transactions {
|
||||
if err := cache.StoreAsTransaction(tx, block.Index, writeBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
writeBuf.Reset()
|
||||
}
|
||||
|
||||
_, err = cache.Persist()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to persist results: %w", err)
|
||||
}
|
||||
s.blockHeight = block.Index
|
||||
if s.blockHeight == s.syncPoint {
|
||||
s.syncStage |= blocksSynced
|
||||
s.log.Info("blocks are in sync",
|
||||
zap.Uint32("blockHeight", s.blockHeight))
|
||||
s.checkSyncIsCompleted()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMPTNodes tries to add provided set of MPT nodes to the MPT billet if they are
|
||||
// not yet collected.
|
||||
func (s *Module) AddMPTNodes(nodes [][]byte) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.syncStage&headersSynced == 0 || s.syncStage&mptSynced != 0 {
|
||||
return errors.New("MPT nodes were not requested")
|
||||
}
|
||||
|
||||
for _, nBytes := range nodes {
|
||||
var n mpt.NodeObject
|
||||
r := io.NewBinReaderFromBuf(nBytes)
|
||||
n.DecodeBinary(r)
|
||||
if r.Err != nil {
|
||||
return fmt.Errorf("failed to decode MPT node: %w", r.Err)
|
||||
}
|
||||
nPaths, ok := s.mptpool.TryGet(n.Hash())
|
||||
if !ok {
|
||||
// it can easily happen after receiving the same data from different peers.
|
||||
return nil
|
||||
}
|
||||
|
||||
var childrenPaths = make(map[util.Uint256][][]byte)
|
||||
for _, path := range nPaths {
|
||||
err := s.billet.RestoreHashNode(path, n.Node)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
|
||||
}
|
||||
for h, paths := range mpt.GetChildrenPaths(path, n.Node) {
|
||||
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
|
||||
}
|
||||
}
|
||||
|
||||
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
|
||||
}
|
||||
if s.mptpool.Count() == 0 {
|
||||
s.syncStage |= mptSynced
|
||||
s.log.Info("MPT is in sync",
|
||||
zap.Uint32("height", s.syncPoint))
|
||||
s.checkSyncIsCompleted()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1
|
||||
// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored.
|
||||
// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller
|
||||
// should take care of it.
|
||||
func (s *Module) checkSyncIsCompleted() {
|
||||
if s.syncStage != headersSynced|mptSynced|blocksSynced {
|
||||
return
|
||||
}
|
||||
s.log.Info("state is in sync",
|
||||
zap.Uint32("state sync point", s.syncPoint))
|
||||
err := s.bc.JumpToState(s)
|
||||
if err != nil {
|
||||
s.log.Fatal("failed to jump to the latest state sync point", zap.Error(err))
|
||||
}
|
||||
s.syncStage = inactive
|
||||
s.dispose()
|
||||
}
|
||||
|
||||
func (s *Module) dispose() {
|
||||
s.billet = nil
|
||||
}
|
||||
|
||||
// BlockHeight returns index of the last stored block.
|
||||
func (s *Module) BlockHeight() uint32 {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.blockHeight
|
||||
}
|
||||
|
||||
// IsActive tells whether state sync module is on and still gathering state
|
||||
// synchronisation data (headers, blocks or MPT nodes).
|
||||
func (s *Module) IsActive() bool {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return !(s.syncStage == inactive || (s.syncStage == headersSynced|mptSynced|blocksSynced))
|
||||
}
|
||||
|
||||
// IsInitialized tells whether state sync module does not require initialization.
|
||||
// If `false` is returned then Init can be safely called.
|
||||
func (s *Module) IsInitialized() bool {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.syncStage != none
|
||||
}
|
||||
|
||||
// NeedHeaders tells whether the module hasn't completed headers synchronisation.
|
||||
func (s *Module) NeedHeaders() bool {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.syncStage == initialized
|
||||
}
|
||||
|
||||
// NeedMPTNodes returns whether the module hasn't completed MPT synchronisation.
|
||||
func (s *Module) NeedMPTNodes() bool {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.syncStage&headersSynced != 0 && s.syncStage&mptSynced == 0
|
||||
}
|
||||
|
||||
// Traverse traverses local MPT nodes starting from the specified root down to its
|
||||
// children calling `process` for each serialised node until stop condition is satisfied.
|
||||
func (s *Module) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
b := mpt.NewBillet(root, s.bc.GetConfig().KeepOnlyLatestState, storage.NewMemCachedStore(s.dao.Store))
|
||||
return b.Traverse(process, false)
|
||||
}
|
||||
|
||||
// GetJumpHeight returns state sync point to jump to. It is not protected by mutex and should be called
|
||||
// under the module lock.
|
||||
func (s *Module) GetJumpHeight() (uint32, error) {
|
||||
if s.syncStage != headersSynced|mptSynced|blocksSynced {
|
||||
return 0, errors.New("state sync module has wong state to perform state jump")
|
||||
}
|
||||
return s.syncPoint, nil
|
||||
}
|
119
pkg/core/statesync/mptpool.go
Normal file
119
pkg/core/statesync/mptpool.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package statesync
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// Pool stores unknown MPT nodes along with the corresponding paths (single node is
|
||||
// allowed to have multiple MPT paths).
|
||||
type Pool struct {
|
||||
lock sync.RWMutex
|
||||
hashes map[util.Uint256][][]byte
|
||||
}
|
||||
|
||||
// NewPool returns new MPT node hashes pool.
|
||||
func NewPool() *Pool {
|
||||
return &Pool{
|
||||
hashes: make(map[util.Uint256][][]byte),
|
||||
}
|
||||
}
|
||||
|
||||
// ContainsKey checks if MPT node with the specified hash is in the Pool.
|
||||
func (mp *Pool) ContainsKey(hash util.Uint256) bool {
|
||||
mp.lock.RLock()
|
||||
defer mp.lock.RUnlock()
|
||||
|
||||
_, ok := mp.hashes[hash]
|
||||
return ok
|
||||
}
|
||||
|
||||
// TryGet returns a set of MPT paths for the specified HashNode.
|
||||
func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) {
|
||||
mp.lock.RLock()
|
||||
defer mp.lock.RUnlock()
|
||||
|
||||
paths, ok := mp.hashes[hash]
|
||||
return paths, ok
|
||||
}
|
||||
|
||||
// GetAll returns all MPT nodes with the corresponding paths from the pool.
|
||||
func (mp *Pool) GetAll() map[util.Uint256][][]byte {
|
||||
mp.lock.RLock()
|
||||
defer mp.lock.RUnlock()
|
||||
|
||||
return mp.hashes
|
||||
}
|
||||
|
||||
// Remove removes MPT node from the pool by the specified hash.
|
||||
func (mp *Pool) Remove(hash util.Uint256) {
|
||||
mp.lock.Lock()
|
||||
defer mp.lock.Unlock()
|
||||
|
||||
delete(mp.hashes, hash)
|
||||
}
|
||||
|
||||
// Add adds path to the set of paths for the specified node.
|
||||
func (mp *Pool) Add(hash util.Uint256, path []byte) {
|
||||
mp.lock.Lock()
|
||||
defer mp.lock.Unlock()
|
||||
|
||||
mp.addPaths(hash, [][]byte{path})
|
||||
}
|
||||
|
||||
// Update is an atomic operation and removes/adds specified nodes from/to the pool.
|
||||
func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) {
|
||||
mp.lock.Lock()
|
||||
defer mp.lock.Unlock()
|
||||
|
||||
for h, paths := range remove {
|
||||
old := mp.hashes[h]
|
||||
for _, path := range paths {
|
||||
i := sort.Search(len(old), func(i int) bool {
|
||||
return bytes.Compare(old[i], path) >= 0
|
||||
})
|
||||
if i < len(old) && bytes.Equal(old[i], path) {
|
||||
old = append(old[:i], old[i+1:]...)
|
||||
}
|
||||
}
|
||||
if len(old) == 0 {
|
||||
delete(mp.hashes, h)
|
||||
} else {
|
||||
mp.hashes[h] = old
|
||||
}
|
||||
}
|
||||
for h, paths := range add {
|
||||
mp.addPaths(h, paths)
|
||||
}
|
||||
}
|
||||
|
||||
// addPaths adds set of the specified node paths to the pool.
|
||||
func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) {
|
||||
old := mp.hashes[nodeHash]
|
||||
for _, path := range paths {
|
||||
i := sort.Search(len(old), func(i int) bool {
|
||||
return bytes.Compare(old[i], path) >= 0
|
||||
})
|
||||
if i < len(old) && bytes.Equal(old[i], path) {
|
||||
// then path is already added
|
||||
continue
|
||||
}
|
||||
old = append(old, path)
|
||||
if i != len(old)-1 {
|
||||
copy(old[i+1:], old[i:])
|
||||
old[i] = path
|
||||
}
|
||||
}
|
||||
mp.hashes[nodeHash] = old
|
||||
}
|
||||
|
||||
// Count returns the number of nodes in the pool.
|
||||
func (mp *Pool) Count() int {
|
||||
mp.lock.RLock()
|
||||
defer mp.lock.RUnlock()
|
||||
|
||||
return len(mp.hashes)
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"github.com/Workiva/go-datastructures/queue"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
@ -13,6 +14,7 @@ type blockQueue struct {
|
|||
checkBlocks chan struct{}
|
||||
chain blockchainer.Blockqueuer
|
||||
relayF func(*block.Block)
|
||||
discarded *atomic.Bool
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -32,6 +34,7 @@ func newBlockQueue(capacity int, bc blockchainer.Blockqueuer, log *zap.Logger, r
|
|||
checkBlocks: make(chan struct{}, 1),
|
||||
chain: bc,
|
||||
relayF: relayer,
|
||||
discarded: atomic.NewBool(false),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,8 +94,10 @@ func (bq *blockQueue) putBlock(block *block.Block) error {
|
|||
}
|
||||
|
||||
func (bq *blockQueue) discard() {
|
||||
close(bq.checkBlocks)
|
||||
bq.queue.Dispose()
|
||||
if bq.discarded.CAS(false, true) {
|
||||
close(bq.checkBlocks)
|
||||
bq.queue.Dispose()
|
||||
}
|
||||
}
|
||||
|
||||
func (bq *blockQueue) length() int {
|
||||
|
|
|
@ -71,6 +71,8 @@ const (
|
|||
CMDBlock = CommandType(payload.BlockType)
|
||||
CMDExtensible = CommandType(payload.ExtensibleType)
|
||||
CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType)
|
||||
CMDGetMPTData CommandType = 0x51 // 0x5.. commands are used for extensions (P2PNotary, state exchange cmds)
|
||||
CMDMPTData CommandType = 0x52
|
||||
CMDReject CommandType = 0x2f
|
||||
|
||||
// SPV protocol.
|
||||
|
@ -136,6 +138,10 @@ func (m *Message) decodePayload() error {
|
|||
p = &payload.Version{}
|
||||
case CMDInv, CMDGetData:
|
||||
p = &payload.Inventory{}
|
||||
case CMDGetMPTData:
|
||||
p = &payload.MPTInventory{}
|
||||
case CMDMPTData:
|
||||
p = &payload.MPTData{}
|
||||
case CMDAddr:
|
||||
p = &payload.AddressList{}
|
||||
case CMDBlock:
|
||||
|
@ -221,7 +227,7 @@ func (m *Message) tryCompressPayload() error {
|
|||
if m.Flags&Compressed == 0 {
|
||||
switch m.Payload.(type) {
|
||||
case *payload.Headers, *payload.MerkleBlock, payload.NullPayload,
|
||||
*payload.Inventory:
|
||||
*payload.Inventory, *payload.MPTInventory:
|
||||
break
|
||||
default:
|
||||
size := len(compressedPayload)
|
||||
|
|
|
@ -26,6 +26,8 @@ func _() {
|
|||
_ = x[CMDBlock-44]
|
||||
_ = x[CMDExtensible-46]
|
||||
_ = x[CMDP2PNotaryRequest-80]
|
||||
_ = x[CMDGetMPTData-81]
|
||||
_ = x[CMDMPTData-82]
|
||||
_ = x[CMDReject-47]
|
||||
_ = x[CMDFilterLoad-48]
|
||||
_ = x[CMDFilterAdd-49]
|
||||
|
@ -44,7 +46,7 @@ const (
|
|||
_CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
|
||||
_CommandType_name_7 = "CMDMerkleBlock"
|
||||
_CommandType_name_8 = "CMDAlert"
|
||||
_CommandType_name_9 = "CMDP2PNotaryRequest"
|
||||
_CommandType_name_9 = "CMDP2PNotaryRequestCMDGetMPTDataCMDMPTData"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -55,6 +57,7 @@ var (
|
|||
_CommandType_index_4 = [...]uint8{0, 12, 22}
|
||||
_CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58}
|
||||
_CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61}
|
||||
_CommandType_index_9 = [...]uint8{0, 19, 32, 42}
|
||||
)
|
||||
|
||||
func (i CommandType) String() string {
|
||||
|
@ -83,8 +86,9 @@ func (i CommandType) String() string {
|
|||
return _CommandType_name_7
|
||||
case i == 64:
|
||||
return _CommandType_name_8
|
||||
case i == 80:
|
||||
return _CommandType_name_9
|
||||
case 80 <= i && i <= 82:
|
||||
i -= 80
|
||||
return _CommandType_name_9[_CommandType_index_9[i]:_CommandType_index_9[i+1]]
|
||||
default:
|
||||
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
|
|
|
@ -258,6 +258,21 @@ func TestEncodeDecodeNotFound(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestEncodeDecodeGetMPTData(t *testing.T) {
|
||||
testEncodeDecode(t, CMDGetMPTData, &payload.MPTInventory{
|
||||
Hashes: []util.Uint256{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestEncodeDecodeMPTData(t *testing.T) {
|
||||
testEncodeDecode(t, CMDMPTData, &payload.MPTData{
|
||||
Nodes: [][]byte{{1, 2, 3}, {4, 5, 6}},
|
||||
})
|
||||
}
|
||||
|
||||
func TestInvalidMessages(t *testing.T) {
|
||||
t.Run("CMDBlock, empty payload", func(t *testing.T) {
|
||||
testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{})
|
||||
|
|
35
pkg/network/payload/mptdata.go
Normal file
35
pkg/network/payload/mptdata.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
)
|
||||
|
||||
// MPTData represents the set of serialized MPT nodes.
|
||||
type MPTData struct {
|
||||
Nodes [][]byte
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (d *MPTData) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteVarUint(uint64(len(d.Nodes)))
|
||||
for _, n := range d.Nodes {
|
||||
w.WriteVarBytes(n)
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (d *MPTData) DecodeBinary(r *io.BinReader) {
|
||||
sz := r.ReadVarUint()
|
||||
if sz == 0 {
|
||||
r.Err = errors.New("empty MPT nodes list")
|
||||
return
|
||||
}
|
||||
for i := uint64(0); i < sz; i++ {
|
||||
d.Nodes = append(d.Nodes, r.ReadVarBytes())
|
||||
if r.Err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
24
pkg/network/payload/mptdata_test.go
Normal file
24
pkg/network/payload/mptdata_test.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMPTData_EncodeDecodeBinary(t *testing.T) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
d := new(MPTData)
|
||||
bytes, err := testserdes.EncodeBinary(d)
|
||||
require.NoError(t, err)
|
||||
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTData)))
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
d := &MPTData{
|
||||
Nodes: [][]byte{{}, {1}, {1, 2, 3}},
|
||||
}
|
||||
testserdes.EncodeDecodeBinary(t, d, new(MPTData))
|
||||
})
|
||||
}
|
32
pkg/network/payload/mptinventory.go
Normal file
32
pkg/network/payload/mptinventory.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// MaxMPTHashesCount is the maximum number of requested MPT nodes hashes.
|
||||
const MaxMPTHashesCount = 32
|
||||
|
||||
// MPTInventory payload.
|
||||
type MPTInventory struct {
|
||||
// A list of requested MPT nodes hashes.
|
||||
Hashes []util.Uint256
|
||||
}
|
||||
|
||||
// NewMPTInventory return a pointer to an MPTInventory.
|
||||
func NewMPTInventory(hashes []util.Uint256) *MPTInventory {
|
||||
return &MPTInventory{
|
||||
Hashes: hashes,
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeBinary implements Serializable interface.
|
||||
func (p *MPTInventory) DecodeBinary(br *io.BinReader) {
|
||||
br.ReadArray(&p.Hashes, MaxMPTHashesCount)
|
||||
}
|
||||
|
||||
// EncodeBinary implements Serializable interface.
|
||||
func (p *MPTInventory) EncodeBinary(bw *io.BinWriter) {
|
||||
bw.WriteArray(p.Hashes)
|
||||
}
|
38
pkg/network/payload/mptinventory_test.go
Normal file
38
pkg/network/payload/mptinventory_test.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMPTInventory_EncodeDecodeBinary(t *testing.T) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
testserdes.EncodeDecodeBinary(t, NewMPTInventory([]util.Uint256{}), new(MPTInventory))
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
inv := NewMPTInventory([]util.Uint256{{1, 2, 3}, {2, 3, 4}})
|
||||
testserdes.EncodeDecodeBinary(t, inv, new(MPTInventory))
|
||||
})
|
||||
|
||||
t.Run("too large", func(t *testing.T) {
|
||||
check := func(t *testing.T, count int, fail bool) {
|
||||
h := make([]util.Uint256, count)
|
||||
for i := range h {
|
||||
h[i] = util.Uint256{1, 2, 3}
|
||||
}
|
||||
if fail {
|
||||
bytes, err := testserdes.EncodeBinary(NewMPTInventory(h))
|
||||
require.NoError(t, err)
|
||||
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTInventory)))
|
||||
} else {
|
||||
testserdes.EncodeDecodeBinary(t, NewMPTInventory(h), new(MPTInventory))
|
||||
}
|
||||
}
|
||||
check(t, MaxMPTHashesCount, false)
|
||||
check(t, MaxMPTHashesCount+1, true)
|
||||
})
|
||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -17,7 +18,9 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/extpool"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||
|
@ -67,6 +70,7 @@ type (
|
|||
discovery Discoverer
|
||||
chain blockchainer.Blockchainer
|
||||
bQueue *blockQueue
|
||||
bSyncQueue *blockQueue
|
||||
consensus consensus.Service
|
||||
mempool *mempool.Pool
|
||||
notaryRequestPool *mempool.Pool
|
||||
|
@ -93,6 +97,7 @@ type (
|
|||
|
||||
oracle *oracle.Oracle
|
||||
stateRoot stateroot.Service
|
||||
stateSync blockchainer.StateSync
|
||||
|
||||
log *zap.Logger
|
||||
}
|
||||
|
@ -191,6 +196,10 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
|
|||
}
|
||||
s.stateRoot = sr
|
||||
|
||||
sSync := chain.GetStateSyncModule()
|
||||
s.stateSync = sSync
|
||||
s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil)
|
||||
|
||||
if config.OracleCfg.Enabled {
|
||||
orcCfg := oracle.Config{
|
||||
Log: log,
|
||||
|
@ -277,6 +286,7 @@ func (s *Server) Start(errChan chan error) {
|
|||
go s.broadcastTxLoop()
|
||||
go s.relayBlocksLoop()
|
||||
go s.bQueue.run()
|
||||
go s.bSyncQueue.run()
|
||||
go s.transport.Accept()
|
||||
setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10))
|
||||
s.run()
|
||||
|
@ -292,6 +302,7 @@ func (s *Server) Shutdown() {
|
|||
p.Disconnect(errServerShutdown)
|
||||
}
|
||||
s.bQueue.discard()
|
||||
s.bSyncQueue.discard()
|
||||
if s.StateRootCfg.Enabled {
|
||||
s.stateRoot.Shutdown()
|
||||
}
|
||||
|
@ -573,6 +584,10 @@ func (s *Server) IsInSync() bool {
|
|||
var peersNumber int
|
||||
var notHigher int
|
||||
|
||||
if s.stateSync.IsActive() {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.MinPeers == 0 {
|
||||
return true
|
||||
}
|
||||
|
@ -630,6 +645,9 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
|
|||
|
||||
// handleBlockCmd processes the received block received from its peer.
|
||||
func (s *Server) handleBlockCmd(p Peer, block *block.Block) error {
|
||||
if s.stateSync.IsActive() {
|
||||
return s.bSyncQueue.putBlock(block)
|
||||
}
|
||||
return s.bQueue.putBlock(block)
|
||||
}
|
||||
|
||||
|
@ -639,25 +657,46 @@ func (s *Server) handlePing(p Peer, ping *payload.Ping) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.chain.BlockHeight() < ping.LastBlockIndex {
|
||||
err = s.requestBlocks(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.requestBlocksOrHeaders(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id)))
|
||||
}
|
||||
|
||||
func (s *Server) requestBlocksOrHeaders(p Peer) error {
|
||||
if s.stateSync.NeedHeaders() {
|
||||
if s.chain.HeaderHeight() < p.LastBlockIndex() {
|
||||
return s.requestHeaders(p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var bq blockchainer.Blockqueuer = s.chain
|
||||
if s.stateSync.IsActive() {
|
||||
bq = s.stateSync
|
||||
}
|
||||
if bq.BlockHeight() < p.LastBlockIndex() {
|
||||
return s.requestBlocks(bq, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// requestHeaders sends a CMDGetHeaders message to the peer to sync up in headers.
|
||||
func (s *Server) requestHeaders(p Peer) error {
|
||||
// TODO: optimize
|
||||
currHeight := s.chain.HeaderHeight()
|
||||
needHeight := currHeight + 1
|
||||
payload := payload.NewGetBlockByIndex(needHeight, -1)
|
||||
return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload))
|
||||
}
|
||||
|
||||
// handlePing processes pong request.
|
||||
func (s *Server) handlePong(p Peer, pong *payload.Ping) error {
|
||||
err := p.HandlePong(pong)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.chain.BlockHeight() < pong.LastBlockIndex {
|
||||
return s.requestBlocks(p)
|
||||
}
|
||||
return nil
|
||||
return s.requestBlocksOrHeaders(p)
|
||||
}
|
||||
|
||||
// handleInvCmd processes the received inventory.
|
||||
|
@ -766,6 +805,50 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// handleGetMPTDataCmd processes the received MPT inventory.
|
||||
func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error {
|
||||
if !s.chain.GetConfig().P2PStateExchangeExtensions {
|
||||
return errors.New("GetMPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
|
||||
}
|
||||
if s.chain.GetConfig().KeepOnlyLatestState {
|
||||
// TODO: implement keeping MPT states for P1 and P2 height (#2095, #2152 related)
|
||||
return errors.New("GetMPTDataCMD was received, but only latest MPT state is supported")
|
||||
}
|
||||
resp := payload.MPTData{}
|
||||
capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes)))
|
||||
for _, h := range inv.Hashes {
|
||||
if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type
|
||||
break
|
||||
}
|
||||
err := s.stateSync.Traverse(h,
|
||||
func(_ mpt.Node, node []byte) bool {
|
||||
l := len(node)
|
||||
size := l + io.GetVarSize(l)
|
||||
if size > capLeft {
|
||||
return true
|
||||
}
|
||||
resp.Nodes = append(resp.Nodes, node)
|
||||
capLeft -= size
|
||||
return false
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to traverse MPT starting from %s: %w", h.StringBE(), err)
|
||||
}
|
||||
}
|
||||
if len(resp.Nodes) > 0 {
|
||||
msg := NewMessage(CMDMPTData, &resp)
|
||||
return p.EnqueueP2PMessage(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error {
|
||||
if !s.chain.GetConfig().P2PStateExchangeExtensions {
|
||||
return errors.New("MPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
|
||||
}
|
||||
return s.stateSync.AddMPTNodes(data.Nodes)
|
||||
}
|
||||
|
||||
// handleGetBlocksCmd processes the getblocks request.
|
||||
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
|
||||
count := gb.Count
|
||||
|
@ -845,6 +928,11 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error
|
|||
return p.EnqueueP2PMessage(msg)
|
||||
}
|
||||
|
||||
// handleHeadersCmd processes headers payload.
|
||||
func (s *Server) handleHeadersCmd(p Peer, h *payload.Headers) error {
|
||||
return s.stateSync.AddHeaders(h.Hdrs...)
|
||||
}
|
||||
|
||||
// handleExtensibleCmd processes received extensible payload.
|
||||
func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
|
||||
if !s.syncReached.Load() {
|
||||
|
@ -993,8 +1081,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error {
|
|||
// 1. Block range is divided into chunks of payload.MaxHashesCount.
|
||||
// 2. Send requests for chunk in increasing order.
|
||||
// 3. After all requests were sent, request random height.
|
||||
func (s *Server) requestBlocks(p Peer) error {
|
||||
var currHeight = s.chain.BlockHeight()
|
||||
func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error {
|
||||
var currHeight = bq.BlockHeight()
|
||||
var peerHeight = p.LastBlockIndex()
|
||||
var needHeight uint32
|
||||
// lastRequestedHeight can only be increased.
|
||||
|
@ -1051,9 +1139,18 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
case CMDGetData:
|
||||
inv := msg.Payload.(*payload.Inventory)
|
||||
return s.handleGetDataCmd(peer, inv)
|
||||
case CMDGetMPTData:
|
||||
inv := msg.Payload.(*payload.MPTInventory)
|
||||
return s.handleGetMPTDataCmd(peer, inv)
|
||||
case CMDMPTData:
|
||||
inv := msg.Payload.(*payload.MPTData)
|
||||
return s.handleMPTDataCmd(peer, inv)
|
||||
case CMDGetHeaders:
|
||||
gh := msg.Payload.(*payload.GetBlockByIndex)
|
||||
return s.handleGetHeadersCmd(peer, gh)
|
||||
case CMDHeaders:
|
||||
h := msg.Payload.(*payload.Headers)
|
||||
return s.handleHeadersCmd(peer, h)
|
||||
case CMDInv:
|
||||
inventory := msg.Payload.(*payload.Inventory)
|
||||
return s.handleInvCmd(peer, inventory)
|
||||
|
@ -1093,6 +1190,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
}
|
||||
go peer.StartProtocol()
|
||||
|
||||
s.tryInitStateSync()
|
||||
s.tryStartServices()
|
||||
default:
|
||||
return fmt.Errorf("received '%s' during handshake", msg.Command.String())
|
||||
|
@ -1101,6 +1199,52 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) tryInitStateSync() {
|
||||
if !s.stateSync.IsActive() {
|
||||
s.bSyncQueue.discard()
|
||||
return
|
||||
}
|
||||
|
||||
if s.stateSync.IsInitialized() {
|
||||
return
|
||||
}
|
||||
|
||||
var peersNumber int
|
||||
s.lock.RLock()
|
||||
heights := make([]uint32, 0)
|
||||
for p := range s.peers {
|
||||
if p.Handshaked() {
|
||||
peersNumber++
|
||||
peerLastBlock := p.LastBlockIndex()
|
||||
i := sort.Search(len(heights), func(i int) bool {
|
||||
return heights[i] >= peerLastBlock
|
||||
})
|
||||
heights = append(heights, peerLastBlock)
|
||||
if i != len(heights)-1 {
|
||||
copy(heights[i+1:], heights[i:])
|
||||
heights[i] = peerLastBlock
|
||||
}
|
||||
}
|
||||
}
|
||||
s.lock.RUnlock()
|
||||
if peersNumber >= s.MinPeers && len(heights) > 0 {
|
||||
// choose the height of the median peer as current chain's height
|
||||
h := heights[len(heights)/2]
|
||||
err := s.stateSync.Init(h)
|
||||
if err != nil {
|
||||
s.log.Fatal("failed to init state sync module",
|
||||
zap.Uint32("evaluated chain's blockHeight", h),
|
||||
zap.Uint32("blockHeight", s.chain.BlockHeight()),
|
||||
zap.Uint32("headerHeight", s.chain.HeaderHeight()),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
// module can be inactive after init (i.e. full state is collected and ordinary block processing is needed)
|
||||
if !s.stateSync.IsActive() {
|
||||
s.bSyncQueue.discard()
|
||||
}
|
||||
}
|
||||
}
|
||||
func (s *Server) handleNewPayload(p *payload.Extensible) {
|
||||
_, err := s.extensiblePool.Add(p)
|
||||
if err != nil {
|
||||
|
|
|
@ -2,6 +2,7 @@ package network
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
|
@ -46,7 +47,10 @@ func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs =
|
|||
func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") }
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
bc := &fakechain.FakeChain{}
|
||||
bc := &fakechain.FakeChain{ProtocolConfiguration: config.ProtocolConfiguration{
|
||||
P2PStateExchangeExtensions: true,
|
||||
StateRootInHeader: true,
|
||||
}}
|
||||
s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery)
|
||||
require.Error(t, err)
|
||||
|
||||
|
@ -899,3 +903,39 @@ func TestVerifyNotaryRequest(t *testing.T) {
|
|||
require.NoError(t, verifyNotaryRequest(bc, nil, r))
|
||||
})
|
||||
}
|
||||
|
||||
func TestTryInitStateSync(t *testing.T) {
|
||||
t.Run("module inactive", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
s.tryInitStateSync()
|
||||
})
|
||||
|
||||
t.Run("module already initialized", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: true}
|
||||
s.tryInitStateSync()
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
for _, h := range []uint32{10, 8, 7, 4, 11, 4} {
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
p.lastBlockIndex = h
|
||||
s.peers[p] = true
|
||||
}
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = false // one disconnected peer to check it won't be taken into attention
|
||||
p.lastBlockIndex = 5
|
||||
s.peers[p] = true
|
||||
var expectedH uint32 = 8 // median peer
|
||||
|
||||
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: false, InitFunc: func(h uint32) error {
|
||||
if h != expectedH {
|
||||
return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
s.tryInitStateSync()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -267,12 +267,10 @@ func (p *TCPPeer) StartProtocol() {
|
|||
zap.Uint32("id", p.Version().Nonce))
|
||||
|
||||
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities)
|
||||
if p.server.chain.BlockHeight() < p.LastBlockIndex() {
|
||||
err = p.server.requestBlocks(p)
|
||||
if err != nil {
|
||||
p.Disconnect(err)
|
||||
return
|
||||
}
|
||||
err = p.server.requestBlocksOrHeaders(p)
|
||||
if err != nil {
|
||||
p.Disconnect(err)
|
||||
return
|
||||
}
|
||||
|
||||
timer := time.NewTimer(p.server.ProtoTickInterval)
|
||||
|
@ -281,10 +279,8 @@ func (p *TCPPeer) StartProtocol() {
|
|||
case <-p.done:
|
||||
return
|
||||
case <-timer.C:
|
||||
// Try to sync in headers and block with the peer if his block height is higher then ours.
|
||||
if p.LastBlockIndex() > p.server.chain.BlockHeight() {
|
||||
err = p.server.requestBlocks(p)
|
||||
}
|
||||
// Try to sync in headers and block with the peer if his block height is higher than ours.
|
||||
err = p.server.requestBlocksOrHeaders(p)
|
||||
if err == nil {
|
||||
timer.Reset(p.server.ProtoTickInterval)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue