mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-22 19:29:39 +00:00
core: implement statesync module
And support GetMPTData and MPTData P2P commands.
This commit is contained in:
parent
a22b1caa3e
commit
d67ff30704
24 changed files with 1197 additions and 32 deletions
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services"
|
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/native"
|
"github.com/nspcc-dev/neo-go/pkg/core/native"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
|
@ -42,6 +43,13 @@ type FakeChain struct {
|
||||||
UtilityTokenBalance *big.Int
|
UtilityTokenBalance *big.Int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FakeStateSync implements StateSync interface.
|
||||||
|
type FakeStateSync struct {
|
||||||
|
IsActiveFlag bool
|
||||||
|
IsInitializedFlag bool
|
||||||
|
InitFunc func(h uint32) error
|
||||||
|
}
|
||||||
|
|
||||||
// NewFakeChain returns new FakeChain structure.
|
// NewFakeChain returns new FakeChain structure.
|
||||||
func NewFakeChain() *FakeChain {
|
func NewFakeChain() *FakeChain {
|
||||||
return &FakeChain{
|
return &FakeChain{
|
||||||
|
@ -294,6 +302,16 @@ func (chain *FakeChain) GetStateModule() blockchainer.StateRoot {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStateSyncModule implements Blockchainer interface.
|
||||||
|
func (chain *FakeChain) GetStateSyncModule() blockchainer.StateSync {
|
||||||
|
return &FakeStateSync{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JumpToState implements Blockchainer interface.
|
||||||
|
func (chain *FakeChain) JumpToState(module blockchainer.StateSync) error {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
// GetStorageItem implements Blockchainer interface.
|
// GetStorageItem implements Blockchainer interface.
|
||||||
func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
|
func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
|
||||||
panic("TODO")
|
panic("TODO")
|
||||||
|
@ -436,3 +454,57 @@ func (chain *FakeChain) UnsubscribeFromNotifications(ch chan<- *state.Notificati
|
||||||
func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
|
func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
|
||||||
panic("TODO")
|
panic("TODO")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddBlock implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) AddBlock(block *block.Block) error {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHeaders implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) AddHeaders(...*block.Header) error {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMPTNodes implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) AddMPTNodes([][]byte) error {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockHeight implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) BlockHeight() uint32 {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsActive implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag }
|
||||||
|
|
||||||
|
// IsInitialized implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) IsInitialized() bool {
|
||||||
|
return s.IsInitializedFlag
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) Init(currChainHeight uint32) error {
|
||||||
|
if s.InitFunc != nil {
|
||||||
|
return s.InitFunc(currChainHeight)
|
||||||
|
}
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedHeaders implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) NeedHeaders() bool { return false }
|
||||||
|
|
||||||
|
// NeedMPTNodes implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) NeedMPTNodes() bool {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJumpHeight implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) GetJumpHeight() (uint32, error) {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/native/noderoles"
|
"github.com/nspcc-dev/neo-go/pkg/core/native/noderoles"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/stateroot"
|
"github.com/nspcc-dev/neo-go/pkg/core/stateroot"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/statesync"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||||
|
@ -409,6 +410,64 @@ func (bc *Blockchain) init() error {
|
||||||
return bc.updateExtensibleWhitelist(bHeight)
|
return bc.updateExtensibleWhitelist(bHeight)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JumpToState is an atomic operation that changes Blockchain state to the one
|
||||||
|
// specified by the state sync point p. All the data needed for the jump must be
|
||||||
|
// collected by the state sync module.
|
||||||
|
func (bc *Blockchain) JumpToState(module blockchainer.StateSync) error {
|
||||||
|
bc.lock.Lock()
|
||||||
|
defer bc.lock.Unlock()
|
||||||
|
|
||||||
|
p, err := module.GetJumpHeight()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get jump height: %w", err)
|
||||||
|
}
|
||||||
|
if p+1 >= uint32(len(bc.headerHashes)) {
|
||||||
|
return fmt.Errorf("invalid state sync point")
|
||||||
|
}
|
||||||
|
|
||||||
|
bc.log.Info("jumping to state sync point", zap.Uint32("state sync point", p))
|
||||||
|
|
||||||
|
block, err := bc.dao.GetBlock(bc.headerHashes[p])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get current block: %w", err)
|
||||||
|
}
|
||||||
|
err = bc.dao.StoreAsCurrentBlock(block, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store current block: %w", err)
|
||||||
|
}
|
||||||
|
bc.topBlock.Store(block)
|
||||||
|
atomic.StoreUint32(&bc.blockHeight, p)
|
||||||
|
atomic.StoreUint32(&bc.persistedHeight, p)
|
||||||
|
|
||||||
|
block, err = bc.dao.GetBlock(bc.headerHashes[p+1])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get block to init MPT: %w", err)
|
||||||
|
}
|
||||||
|
if err = bc.stateRoot.JumpToState(&state.MPTRoot{
|
||||||
|
Index: p,
|
||||||
|
Root: block.PrevStateRoot,
|
||||||
|
}, bc.config.KeepOnlyLatestState); err != nil {
|
||||||
|
return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
|
||||||
|
}
|
||||||
|
err = bc.contracts.Management.InitializeCache(bc.dao)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't init cache for Management native contract: %w", err)
|
||||||
|
}
|
||||||
|
bc.contracts.Designate.InitializeCache()
|
||||||
|
|
||||||
|
if err := bc.updateExtensibleWhitelist(p); err != nil {
|
||||||
|
return fmt.Errorf("failed to update extensible whitelist: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updateBlockHeightMetric(p)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Run runs chain loop, it needs to be run as goroutine and executing it is
|
// Run runs chain loop, it needs to be run as goroutine and executing it is
|
||||||
// critical for correct Blockchain operation.
|
// critical for correct Blockchain operation.
|
||||||
func (bc *Blockchain) Run() {
|
func (bc *Blockchain) Run() {
|
||||||
|
@ -696,6 +755,11 @@ func (bc *Blockchain) GetStateModule() blockchainer.StateRoot {
|
||||||
return bc.stateRoot
|
return bc.stateRoot
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStateSyncModule returns new state sync service instance.
|
||||||
|
func (bc *Blockchain) GetStateSyncModule() blockchainer.StateSync {
|
||||||
|
return statesync.NewModule(bc, bc.log, bc.dao)
|
||||||
|
}
|
||||||
|
|
||||||
// storeBlock performs chain update using the block given, it executes all
|
// storeBlock performs chain update using the block given, it executes all
|
||||||
// transactions with all appropriate side-effects and updates Blockchain state.
|
// transactions with all appropriate side-effects and updates Blockchain state.
|
||||||
// This is the only way to change Blockchain state.
|
// This is the only way to change Blockchain state.
|
||||||
|
|
|
@ -21,7 +21,6 @@ import (
|
||||||
type Blockchainer interface {
|
type Blockchainer interface {
|
||||||
ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction
|
ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction
|
||||||
GetConfig() config.ProtocolConfiguration
|
GetConfig() config.ProtocolConfiguration
|
||||||
AddHeaders(...*block.Header) error
|
|
||||||
Blockqueuer // Blockqueuer interface
|
Blockqueuer // Blockqueuer interface
|
||||||
CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error)
|
CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error)
|
||||||
Close()
|
Close()
|
||||||
|
@ -56,10 +55,12 @@ type Blockchainer interface {
|
||||||
GetStandByCommittee() keys.PublicKeys
|
GetStandByCommittee() keys.PublicKeys
|
||||||
GetStandByValidators() keys.PublicKeys
|
GetStandByValidators() keys.PublicKeys
|
||||||
GetStateModule() StateRoot
|
GetStateModule() StateRoot
|
||||||
|
GetStateSyncModule() StateSync
|
||||||
GetStorageItem(id int32, key []byte) state.StorageItem
|
GetStorageItem(id int32, key []byte) state.StorageItem
|
||||||
GetStorageItems(id int32) (map[string]state.StorageItem, error)
|
GetStorageItems(id int32) (map[string]state.StorageItem, error)
|
||||||
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM
|
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM
|
||||||
GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error)
|
GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error)
|
||||||
|
JumpToState(module StateSync) error
|
||||||
SetOracle(service services.Oracle)
|
SetOracle(service services.Oracle)
|
||||||
mempool.Feer // fee interface
|
mempool.Feer // fee interface
|
||||||
ManagementContractHash() util.Uint160
|
ManagementContractHash() util.Uint160
|
||||||
|
|
|
@ -5,5 +5,6 @@ import "github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
// Blockqueuer is an interface for blockqueue.
|
// Blockqueuer is an interface for blockqueue.
|
||||||
type Blockqueuer interface {
|
type Blockqueuer interface {
|
||||||
AddBlock(block *block.Block) error
|
AddBlock(block *block.Block) error
|
||||||
|
AddHeaders(...*block.Header) error
|
||||||
BlockHeight() uint32
|
BlockHeight() uint32
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
// StateRoot represents local state root module.
|
// StateRoot represents local state root module.
|
||||||
type StateRoot interface {
|
type StateRoot interface {
|
||||||
AddStateRoot(root *state.MPTRoot) error
|
AddStateRoot(root *state.MPTRoot) error
|
||||||
|
CurrentLocalHeight() uint32
|
||||||
CurrentLocalStateRoot() util.Uint256
|
CurrentLocalStateRoot() util.Uint256
|
||||||
CurrentValidatedHeight() uint32
|
CurrentValidatedHeight() uint32
|
||||||
GetStateProof(root util.Uint256, key []byte) ([][]byte, error)
|
GetStateProof(root util.Uint256, key []byte) ([][]byte, error)
|
||||||
|
|
19
pkg/core/blockchainer/state_sync.go
Normal file
19
pkg/core/blockchainer/state_sync.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package blockchainer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StateSync represents state sync module.
|
||||||
|
type StateSync interface {
|
||||||
|
AddMPTNodes([][]byte) error
|
||||||
|
Blockqueuer // Blockqueuer interface
|
||||||
|
Init(currChainHeight uint32) error
|
||||||
|
IsActive() bool
|
||||||
|
IsInitialized() bool
|
||||||
|
GetJumpHeight() (uint32, error)
|
||||||
|
NeedHeaders() bool
|
||||||
|
NeedMPTNodes() bool
|
||||||
|
Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
|
||||||
|
}
|
|
@ -1,5 +1,7 @@
|
||||||
package mpt
|
package mpt
|
||||||
|
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
|
||||||
// lcp returns longest common prefix of a and b.
|
// lcp returns longest common prefix of a and b.
|
||||||
// Note: it does no allocations.
|
// Note: it does no allocations.
|
||||||
func lcp(a, b []byte) []byte {
|
func lcp(a, b []byte) []byte {
|
||||||
|
@ -49,3 +51,36 @@ func fromNibbles(path []byte) []byte {
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetChildrenPaths returns a set of paths to node's children who are non-empty HashNodes
|
||||||
|
// based on the node's path.
|
||||||
|
func GetChildrenPaths(path []byte, node Node) map[util.Uint256][][]byte {
|
||||||
|
res := make(map[util.Uint256][][]byte)
|
||||||
|
switch n := node.(type) {
|
||||||
|
case *LeafNode, *HashNode, EmptyNode:
|
||||||
|
return nil
|
||||||
|
case *BranchNode:
|
||||||
|
for i, child := range n.Children {
|
||||||
|
if child.Type() == HashT {
|
||||||
|
cPath := make([]byte, len(path), len(path)+1)
|
||||||
|
copy(cPath, path)
|
||||||
|
if i != lastChild {
|
||||||
|
cPath = append(cPath, byte(i))
|
||||||
|
}
|
||||||
|
paths := res[child.Hash()]
|
||||||
|
paths = append(paths, cPath)
|
||||||
|
res[child.Hash()] = paths
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *ExtensionNode:
|
||||||
|
if n.next.Type() == HashT {
|
||||||
|
cPath := make([]byte, len(path)+len(n.key))
|
||||||
|
copy(cPath, path)
|
||||||
|
copy(cPath[len(path):], n.key)
|
||||||
|
res[n.next.Hash()] = [][]byte{cPath}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("unknown Node type")
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package mpt
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,3 +19,49 @@ func TestToNibblesFromNibbles(t *testing.T) {
|
||||||
check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF})
|
check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetChildrenPaths(t *testing.T) {
|
||||||
|
h1 := NewHashNode(util.Uint256{1, 2, 3})
|
||||||
|
h2 := NewHashNode(util.Uint256{4, 5, 6})
|
||||||
|
h3 := NewHashNode(util.Uint256{7, 8, 9})
|
||||||
|
l := NewLeafNode([]byte{1, 2, 3})
|
||||||
|
ext1 := NewExtensionNode([]byte{8, 9}, h1)
|
||||||
|
ext2 := NewExtensionNode([]byte{7, 6}, l)
|
||||||
|
branch := NewBranchNode()
|
||||||
|
branch.Children[3] = h1
|
||||||
|
branch.Children[5] = l
|
||||||
|
branch.Children[6] = h1 // 3-th and 6-th children have the same hash
|
||||||
|
branch.Children[7] = h3
|
||||||
|
branch.Children[lastChild] = h2
|
||||||
|
testCases := map[string]struct {
|
||||||
|
node Node
|
||||||
|
expected map[util.Uint256][][]byte
|
||||||
|
}{
|
||||||
|
"Hash": {h1, nil},
|
||||||
|
"Leaf": {l, nil},
|
||||||
|
"Extension with next Hash": {ext1, map[util.Uint256][][]byte{h1.Hash(): {ext1.key}}},
|
||||||
|
"Extension with next non-Hash": {ext2, map[util.Uint256][][]byte{}},
|
||||||
|
"Branch": {branch, map[util.Uint256][][]byte{
|
||||||
|
h1.Hash(): {{0x03}, {0x06}},
|
||||||
|
h2.Hash(): {{}},
|
||||||
|
h3.Hash(): {{0x07}},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
parentPath := []byte{4, 5, 6}
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
require.Equal(t, testCase.expected, GetChildrenPaths([]byte{}, testCase.node))
|
||||||
|
if testCase.expected != nil {
|
||||||
|
expectedWithPrefix := make(map[util.Uint256][][]byte, len(testCase.expected))
|
||||||
|
for h, paths := range testCase.expected {
|
||||||
|
var res [][]byte
|
||||||
|
for _, path := range paths {
|
||||||
|
res = append(res, append(parentPath, path...))
|
||||||
|
}
|
||||||
|
expectedWithPrefix[h] = res
|
||||||
|
}
|
||||||
|
require.Equal(t, expectedWithPrefix, GetChildrenPaths(parentPath, testCase.node))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newProofTrie(t *testing.T) *Trie {
|
func newProofTrie(t *testing.T, missingHashNode bool) *Trie {
|
||||||
l := NewLeafNode([]byte("somevalue"))
|
l := NewLeafNode([]byte("somevalue"))
|
||||||
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
|
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
|
||||||
l2 := NewLeafNode([]byte("invalid"))
|
l2 := NewLeafNode([]byte("invalid"))
|
||||||
|
@ -20,11 +20,14 @@ func newProofTrie(t *testing.T) *Trie {
|
||||||
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
|
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
|
||||||
tr.putToStore(l)
|
tr.putToStore(l)
|
||||||
tr.putToStore(e)
|
tr.putToStore(e)
|
||||||
|
if !missingHashNode {
|
||||||
|
tr.putToStore(l2)
|
||||||
|
}
|
||||||
return tr
|
return tr
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrie_GetProof(t *testing.T) {
|
func TestTrie_GetProof(t *testing.T) {
|
||||||
tr := newProofTrie(t)
|
tr := newProofTrie(t, true)
|
||||||
|
|
||||||
t.Run("MissingKey", func(t *testing.T) {
|
t.Run("MissingKey", func(t *testing.T) {
|
||||||
_, err := tr.GetProof([]byte{0x12})
|
_, err := tr.GetProof([]byte{0x12})
|
||||||
|
@ -43,7 +46,7 @@ func TestTrie_GetProof(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVerifyProof(t *testing.T) {
|
func TestVerifyProof(t *testing.T) {
|
||||||
tr := newProofTrie(t)
|
tr := newProofTrie(t, true)
|
||||||
|
|
||||||
t.Run("Simple", func(t *testing.T) {
|
t.Run("Simple", func(t *testing.T) {
|
||||||
proof, err := tr.GetProof([]byte{0x12, 0x32})
|
proof, err := tr.GetProof([]byte{0x12, 0x32})
|
||||||
|
|
|
@ -353,3 +353,8 @@ func (s *Designate) getRole(item stackitem.Item) (noderoles.Role, bool) {
|
||||||
u := bi.Uint64()
|
u := bi.Uint64()
|
||||||
return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u))
|
return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitializeCache invalidates native Designate cache.
|
||||||
|
func (s *Designate) InitializeCache() {
|
||||||
|
s.rolesChangedFlag.Store(true)
|
||||||
|
}
|
||||||
|
|
|
@ -114,6 +114,25 @@ func (s *Module) Init(height uint32, enableRefCount bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JumpToState performs jump to the state specified by given stateroot index.
|
||||||
|
func (s *Module) JumpToState(sr *state.MPTRoot, enableRefCount bool) error {
|
||||||
|
if err := s.addLocalStateRoot(s.Store, sr); err != nil {
|
||||||
|
return fmt.Errorf("failed to store local state root: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]byte, 4)
|
||||||
|
binary.LittleEndian.PutUint32(data, sr.Index)
|
||||||
|
if err := s.Store.Put([]byte{byte(storage.DataMPT), prefixValidated}, data); err != nil {
|
||||||
|
return fmt.Errorf("failed to store validated height: %w", err)
|
||||||
|
}
|
||||||
|
s.validatedHeight.Store(sr.Index)
|
||||||
|
|
||||||
|
s.currentLocal.Store(sr.Root)
|
||||||
|
s.localHeight.Store(sr.Index)
|
||||||
|
s.mpt = mpt.NewTrie(mpt.NewHashNode(sr.Root), enableRefCount, s.Store)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddMPTBatch updates using provided batch.
|
// AddMPTBatch updates using provided batch.
|
||||||
func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) {
|
func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) {
|
||||||
mpt := *s.mpt
|
mpt := *s.mpt
|
||||||
|
|
440
pkg/core/statesync/module.go
Normal file
440
pkg/core/statesync/module.go
Normal file
|
@ -0,0 +1,440 @@
|
||||||
|
/*
|
||||||
|
Package statesync implements module for the P2P state synchronisation process. The
|
||||||
|
module manages state synchronisation for non-archival nodes which are joining the
|
||||||
|
network and don't have the ability to resync from the genesis block.
|
||||||
|
|
||||||
|
Given the currently available state synchronisation point P, sate sync process
|
||||||
|
includes the following stages:
|
||||||
|
|
||||||
|
1. Fetching headers starting from height 0 up to P+1.
|
||||||
|
2. Fetching MPT nodes for height P stating from the corresponding state root.
|
||||||
|
3. Fetching blocks starting from height P-MaxTraceableBlocks (or 0) up to P.
|
||||||
|
|
||||||
|
Steps 2 and 3 are being performed in parallel. Once all the data are collected
|
||||||
|
and stored in the db, an atomic state jump is occurred to the state sync point P.
|
||||||
|
Further node operation process is performed using standard sync mechanism until
|
||||||
|
the node reaches synchronised state.
|
||||||
|
*/
|
||||||
|
package statesync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stateSyncStage is a type of state synchronisation stage.
|
||||||
|
type stateSyncStage uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// inactive means that state exchange is disabled by the protocol configuration.
|
||||||
|
// Can't be combined with other states.
|
||||||
|
inactive stateSyncStage = 1 << iota
|
||||||
|
// none means that state exchange is enabled in the configuration, but
|
||||||
|
// initialisation of the state sync module wasn't yet performed, i.e.
|
||||||
|
// (*Module).Init wasn't called. Can't be combined with other states.
|
||||||
|
none
|
||||||
|
// initialized means that (*Module).Init was called, but other sync stages
|
||||||
|
// are not yet reached (i.e. that headers are requested, but not yet fetched).
|
||||||
|
// Can't be combined with other states.
|
||||||
|
initialized
|
||||||
|
// headersSynced means that headers for the current state sync point are fetched.
|
||||||
|
// May be combined with mptSynced and/or blocksSynced.
|
||||||
|
headersSynced
|
||||||
|
// mptSynced means that MPT nodes for the current state sync point are fetched.
|
||||||
|
// Always combined with headersSynced; may be combined with blocksSynced.
|
||||||
|
mptSynced
|
||||||
|
// blocksSynced means that blocks up to the current state sync point are stored.
|
||||||
|
// Always combined with headersSynced; may be combined with mptSynced.
|
||||||
|
blocksSynced
|
||||||
|
)
|
||||||
|
|
||||||
|
// Module represents state sync module and aimed to gather state-related data to
|
||||||
|
// perform an atomic state jump.
|
||||||
|
type Module struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
log *zap.Logger
|
||||||
|
|
||||||
|
// syncPoint is the state synchronisation point P we're currently working against.
|
||||||
|
syncPoint uint32
|
||||||
|
// syncStage is the stage of the sync process.
|
||||||
|
syncStage stateSyncStage
|
||||||
|
// syncInterval is the delta between two adjacent state sync points.
|
||||||
|
syncInterval uint32
|
||||||
|
// blockHeight is the index of the latest stored block.
|
||||||
|
blockHeight uint32
|
||||||
|
|
||||||
|
dao *dao.Simple
|
||||||
|
bc blockchainer.Blockchainer
|
||||||
|
mptpool *Pool
|
||||||
|
|
||||||
|
billet *mpt.Billet
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewModule returns new instance of statesync module.
|
||||||
|
func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple) *Module {
|
||||||
|
if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) {
|
||||||
|
return &Module{
|
||||||
|
dao: s,
|
||||||
|
bc: bc,
|
||||||
|
syncStage: inactive,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Module{
|
||||||
|
dao: s,
|
||||||
|
bc: bc,
|
||||||
|
log: log,
|
||||||
|
syncInterval: uint32(bc.GetConfig().StateSyncInterval),
|
||||||
|
mptpool: NewPool(),
|
||||||
|
syncStage: none,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init initializes state sync module for the current chain's height with given
|
||||||
|
// callback for MPT nodes requests.
|
||||||
|
func (s *Module) Init(currChainHeight uint32) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
if s.syncStage != none {
|
||||||
|
return errors.New("already initialized or inactive")
|
||||||
|
}
|
||||||
|
|
||||||
|
p := (currChainHeight / s.syncInterval) * s.syncInterval
|
||||||
|
if p < 2*s.syncInterval {
|
||||||
|
// chain is too low to start state exchange process, use the standard sync mechanism
|
||||||
|
s.syncStage = inactive
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pOld, err := s.dao.GetStateSyncPoint()
|
||||||
|
if err == nil && pOld >= p-s.syncInterval {
|
||||||
|
// old point is still valid, so try to resync states for this point.
|
||||||
|
p = pOld
|
||||||
|
} else if s.bc.BlockHeight() > p-2*s.syncInterval {
|
||||||
|
// chain has already been synchronised up to old state sync point and regular blocks processing was started
|
||||||
|
s.syncStage = inactive
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.syncPoint = p
|
||||||
|
err = s.dao.PutStateSyncPoint(p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store state synchronisation point %d: %w", p, err)
|
||||||
|
}
|
||||||
|
s.syncStage = initialized
|
||||||
|
s.log.Info("try to sync state for the latest state synchronisation point",
|
||||||
|
zap.Uint32("point", p),
|
||||||
|
zap.Uint32("evaluated chain's blockHeight", currChainHeight))
|
||||||
|
|
||||||
|
// check headers sync state first
|
||||||
|
ltstHeaderHeight := s.bc.HeaderHeight()
|
||||||
|
if ltstHeaderHeight > p {
|
||||||
|
s.syncStage = headersSynced
|
||||||
|
s.log.Info("headers are in sync",
|
||||||
|
zap.Uint32("headerHeight", s.bc.HeaderHeight()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// check blocks sync state
|
||||||
|
s.blockHeight = s.getLatestSavedBlock(p)
|
||||||
|
if s.blockHeight >= p {
|
||||||
|
s.syncStage |= blocksSynced
|
||||||
|
s.log.Info("blocks are in sync",
|
||||||
|
zap.Uint32("blockHeight", s.blockHeight))
|
||||||
|
}
|
||||||
|
|
||||||
|
// check MPT sync state
|
||||||
|
if s.blockHeight > p {
|
||||||
|
s.syncStage |= mptSynced
|
||||||
|
s.log.Info("MPT is in sync",
|
||||||
|
zap.Uint32("stateroot height", s.bc.GetStateModule().CurrentLocalHeight()))
|
||||||
|
} else if s.syncStage&headersSynced != 0 {
|
||||||
|
header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(p + 1)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get header to initialize MPT billet: %w", err)
|
||||||
|
}
|
||||||
|
s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store)
|
||||||
|
s.log.Info("MPT billet initialized",
|
||||||
|
zap.Uint32("height", s.syncPoint),
|
||||||
|
zap.String("state root", header.PrevStateRoot.StringBE()))
|
||||||
|
pool := NewPool()
|
||||||
|
pool.Add(header.PrevStateRoot, []byte{})
|
||||||
|
err = s.billet.Traverse(func(n mpt.Node, _ []byte) bool {
|
||||||
|
nPaths, ok := pool.TryGet(n.Hash())
|
||||||
|
if !ok {
|
||||||
|
// if this situation occurs, then it's a bug in MPT pool or Traverse.
|
||||||
|
panic("failed to get MPT node from the pool")
|
||||||
|
}
|
||||||
|
pool.Remove(n.Hash())
|
||||||
|
childrenPaths := make(map[util.Uint256][][]byte)
|
||||||
|
for _, path := range nPaths {
|
||||||
|
nChildrenPaths := mpt.GetChildrenPaths(path, n)
|
||||||
|
for hash, paths := range nChildrenPaths {
|
||||||
|
childrenPaths[hash] = append(childrenPaths[hash], paths...) // it's OK to have duplicates, they'll be handled by mempool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pool.Update(nil, childrenPaths)
|
||||||
|
return false
|
||||||
|
}, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to traverse MPT while initialization: %w", err)
|
||||||
|
}
|
||||||
|
s.mptpool.Update(nil, pool.GetAll())
|
||||||
|
if s.mptpool.Count() == 0 {
|
||||||
|
s.syncStage |= mptSynced
|
||||||
|
s.log.Info("MPT is in sync",
|
||||||
|
zap.Uint32("stateroot height", p))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.syncStage == headersSynced|blocksSynced|mptSynced {
|
||||||
|
s.log.Info("state is in sync, starting regular blocks processing")
|
||||||
|
s.syncStage = inactive
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLatestSavedBlock returns either current block index (if it's still relevant
|
||||||
|
// to continue state sync process) or H-1 where H is the index of the earliest
|
||||||
|
// block that should be saved next.
|
||||||
|
func (s *Module) getLatestSavedBlock(p uint32) uint32 {
|
||||||
|
var result uint32
|
||||||
|
mtb := s.bc.GetConfig().MaxTraceableBlocks
|
||||||
|
if p > mtb {
|
||||||
|
result = p - mtb
|
||||||
|
}
|
||||||
|
storedH, err := s.dao.GetStateSyncCurrentBlockHeight()
|
||||||
|
if err == nil && storedH > result {
|
||||||
|
result = storedH
|
||||||
|
}
|
||||||
|
actualH := s.bc.BlockHeight()
|
||||||
|
if actualH > result {
|
||||||
|
result = actualH
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHeaders validates and adds specified headers to the chain.
|
||||||
|
func (s *Module) AddHeaders(hdrs ...*block.Header) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
if s.syncStage != initialized {
|
||||||
|
return errors.New("headers were not requested")
|
||||||
|
}
|
||||||
|
|
||||||
|
hdrsErr := s.bc.AddHeaders(hdrs...)
|
||||||
|
if s.bc.HeaderHeight() > s.syncPoint {
|
||||||
|
s.syncStage = headersSynced
|
||||||
|
s.log.Info("headers for state sync are fetched",
|
||||||
|
zap.Uint32("header height", s.bc.HeaderHeight()))
|
||||||
|
|
||||||
|
header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint) + 1))
|
||||||
|
if err != nil {
|
||||||
|
s.log.Fatal("failed to get header to initialize MPT billet",
|
||||||
|
zap.Uint32("height", s.syncPoint+1),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store)
|
||||||
|
s.mptpool.Add(header.PrevStateRoot, []byte{})
|
||||||
|
s.log.Info("MPT billet initialized",
|
||||||
|
zap.Uint32("height", s.syncPoint),
|
||||||
|
zap.String("state root", header.PrevStateRoot.StringBE()))
|
||||||
|
}
|
||||||
|
return hdrsErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBlock verifies and saves block skipping executable scripts.
|
||||||
|
func (s *Module) AddBlock(block *block.Block) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
if s.syncStage&headersSynced == 0 || s.syncStage&blocksSynced != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.blockHeight == s.syncPoint {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
expectedHeight := s.blockHeight + 1
|
||||||
|
if expectedHeight != block.Index {
|
||||||
|
return fmt.Errorf("expected %d, got %d: invalid block index", expectedHeight, block.Index)
|
||||||
|
}
|
||||||
|
if s.bc.GetConfig().StateRootInHeader != block.StateRootEnabled {
|
||||||
|
return fmt.Errorf("stateroot setting mismatch: %v != %v", s.bc.GetConfig().StateRootInHeader, block.StateRootEnabled)
|
||||||
|
}
|
||||||
|
if s.bc.GetConfig().VerifyBlocks {
|
||||||
|
merkle := block.ComputeMerkleRoot()
|
||||||
|
if !block.MerkleRoot.Equals(merkle) {
|
||||||
|
return errors.New("invalid block: MerkleRoot mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cache := s.dao.GetWrapped()
|
||||||
|
writeBuf := io.NewBufBinWriter()
|
||||||
|
if err := cache.StoreAsBlock(block, writeBuf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
writeBuf.Reset()
|
||||||
|
|
||||||
|
err := cache.PutStateSyncCurrentBlockHeight(block.Index)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store current block height: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tx := range block.Transactions {
|
||||||
|
if err := cache.StoreAsTransaction(tx, block.Index, writeBuf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
writeBuf.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = cache.Persist()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to persist results: %w", err)
|
||||||
|
}
|
||||||
|
s.blockHeight = block.Index
|
||||||
|
if s.blockHeight == s.syncPoint {
|
||||||
|
s.syncStage |= blocksSynced
|
||||||
|
s.log.Info("blocks are in sync",
|
||||||
|
zap.Uint32("blockHeight", s.blockHeight))
|
||||||
|
s.checkSyncIsCompleted()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMPTNodes tries to add provided set of MPT nodes to the MPT billet if they are
|
||||||
|
// not yet collected.
|
||||||
|
func (s *Module) AddMPTNodes(nodes [][]byte) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
if s.syncStage&headersSynced == 0 || s.syncStage&mptSynced != 0 {
|
||||||
|
return errors.New("MPT nodes were not requested")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, nBytes := range nodes {
|
||||||
|
var n mpt.NodeObject
|
||||||
|
r := io.NewBinReaderFromBuf(nBytes)
|
||||||
|
n.DecodeBinary(r)
|
||||||
|
if r.Err != nil {
|
||||||
|
return fmt.Errorf("failed to decode MPT node: %w", r.Err)
|
||||||
|
}
|
||||||
|
nPaths, ok := s.mptpool.TryGet(n.Hash())
|
||||||
|
if !ok {
|
||||||
|
// it can easily happen after receiving the same data from different peers.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var childrenPaths = make(map[util.Uint256][][]byte)
|
||||||
|
for _, path := range nPaths {
|
||||||
|
err := s.billet.RestoreHashNode(path, n.Node)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
|
||||||
|
}
|
||||||
|
for h, paths := range mpt.GetChildrenPaths(path, n.Node) {
|
||||||
|
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
|
||||||
|
}
|
||||||
|
if s.mptpool.Count() == 0 {
|
||||||
|
s.syncStage |= mptSynced
|
||||||
|
s.log.Info("MPT is in sync",
|
||||||
|
zap.Uint32("height", s.syncPoint))
|
||||||
|
s.checkSyncIsCompleted()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1
|
||||||
|
// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored.
|
||||||
|
// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller
|
||||||
|
// should take care of it.
|
||||||
|
func (s *Module) checkSyncIsCompleted() {
|
||||||
|
if s.syncStage != headersSynced|mptSynced|blocksSynced {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.log.Info("state is in sync",
|
||||||
|
zap.Uint32("state sync point", s.syncPoint))
|
||||||
|
err := s.bc.JumpToState(s)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Fatal("failed to jump to the latest state sync point", zap.Error(err))
|
||||||
|
}
|
||||||
|
s.syncStage = inactive
|
||||||
|
s.dispose()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Module) dispose() {
|
||||||
|
s.billet = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockHeight returns index of the last stored block.
|
||||||
|
func (s *Module) BlockHeight() uint32 {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return s.blockHeight
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsActive tells whether state sync module is on and still gathering state
|
||||||
|
// synchronisation data (headers, blocks or MPT nodes).
|
||||||
|
func (s *Module) IsActive() bool {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return !(s.syncStage == inactive || (s.syncStage == headersSynced|mptSynced|blocksSynced))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInitialized tells whether state sync module does not require initialization.
|
||||||
|
// If `false` is returned then Init can be safely called.
|
||||||
|
func (s *Module) IsInitialized() bool {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return s.syncStage != none
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedHeaders tells whether the module hasn't completed headers synchronisation.
|
||||||
|
func (s *Module) NeedHeaders() bool {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return s.syncStage == initialized
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedMPTNodes returns whether the module hasn't completed MPT synchronisation.
|
||||||
|
func (s *Module) NeedMPTNodes() bool {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return s.syncStage&headersSynced != 0 && s.syncStage&mptSynced == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse traverses local MPT nodes starting from the specified root down to its
|
||||||
|
// children calling `process` for each serialised node until stop condition is satisfied.
|
||||||
|
func (s *Module) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
b := mpt.NewBillet(root, s.bc.GetConfig().KeepOnlyLatestState, storage.NewMemCachedStore(s.dao.Store))
|
||||||
|
return b.Traverse(process, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJumpHeight returns state sync point to jump to. It is not protected by mutex and should be called
|
||||||
|
// under the module lock.
|
||||||
|
func (s *Module) GetJumpHeight() (uint32, error) {
|
||||||
|
if s.syncStage != headersSynced|mptSynced|blocksSynced {
|
||||||
|
return 0, errors.New("state sync module has wong state to perform state jump")
|
||||||
|
}
|
||||||
|
return s.syncPoint, nil
|
||||||
|
}
|
119
pkg/core/statesync/mptpool.go
Normal file
119
pkg/core/statesync/mptpool.go
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
package statesync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Pool stores unknown MPT nodes along with the corresponding paths (single node is
|
||||||
|
// allowed to have multiple MPT paths).
|
||||||
|
type Pool struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
hashes map[util.Uint256][][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPool returns new MPT node hashes pool.
|
||||||
|
func NewPool() *Pool {
|
||||||
|
return &Pool{
|
||||||
|
hashes: make(map[util.Uint256][][]byte),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContainsKey checks if MPT node with the specified hash is in the Pool.
|
||||||
|
func (mp *Pool) ContainsKey(hash util.Uint256) bool {
|
||||||
|
mp.lock.RLock()
|
||||||
|
defer mp.lock.RUnlock()
|
||||||
|
|
||||||
|
_, ok := mp.hashes[hash]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryGet returns a set of MPT paths for the specified HashNode.
|
||||||
|
func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) {
|
||||||
|
mp.lock.RLock()
|
||||||
|
defer mp.lock.RUnlock()
|
||||||
|
|
||||||
|
paths, ok := mp.hashes[hash]
|
||||||
|
return paths, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAll returns all MPT nodes with the corresponding paths from the pool.
|
||||||
|
func (mp *Pool) GetAll() map[util.Uint256][][]byte {
|
||||||
|
mp.lock.RLock()
|
||||||
|
defer mp.lock.RUnlock()
|
||||||
|
|
||||||
|
return mp.hashes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes MPT node from the pool by the specified hash.
|
||||||
|
func (mp *Pool) Remove(hash util.Uint256) {
|
||||||
|
mp.lock.Lock()
|
||||||
|
defer mp.lock.Unlock()
|
||||||
|
|
||||||
|
delete(mp.hashes, hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds path to the set of paths for the specified node.
|
||||||
|
func (mp *Pool) Add(hash util.Uint256, path []byte) {
|
||||||
|
mp.lock.Lock()
|
||||||
|
defer mp.lock.Unlock()
|
||||||
|
|
||||||
|
mp.addPaths(hash, [][]byte{path})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update is an atomic operation and removes/adds specified nodes from/to the pool.
|
||||||
|
func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) {
|
||||||
|
mp.lock.Lock()
|
||||||
|
defer mp.lock.Unlock()
|
||||||
|
|
||||||
|
for h, paths := range remove {
|
||||||
|
old := mp.hashes[h]
|
||||||
|
for _, path := range paths {
|
||||||
|
i := sort.Search(len(old), func(i int) bool {
|
||||||
|
return bytes.Compare(old[i], path) >= 0
|
||||||
|
})
|
||||||
|
if i < len(old) && bytes.Equal(old[i], path) {
|
||||||
|
old = append(old[:i], old[i+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(old) == 0 {
|
||||||
|
delete(mp.hashes, h)
|
||||||
|
} else {
|
||||||
|
mp.hashes[h] = old
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for h, paths := range add {
|
||||||
|
mp.addPaths(h, paths)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPaths adds set of the specified node paths to the pool.
|
||||||
|
func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) {
|
||||||
|
old := mp.hashes[nodeHash]
|
||||||
|
for _, path := range paths {
|
||||||
|
i := sort.Search(len(old), func(i int) bool {
|
||||||
|
return bytes.Compare(old[i], path) >= 0
|
||||||
|
})
|
||||||
|
if i < len(old) && bytes.Equal(old[i], path) {
|
||||||
|
// then path is already added
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
old = append(old, path)
|
||||||
|
if i != len(old)-1 {
|
||||||
|
copy(old[i+1:], old[i:])
|
||||||
|
old[i] = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mp.hashes[nodeHash] = old
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of nodes in the pool.
|
||||||
|
func (mp *Pool) Count() int {
|
||||||
|
mp.lock.RLock()
|
||||||
|
defer mp.lock.RUnlock()
|
||||||
|
|
||||||
|
return len(mp.hashes)
|
||||||
|
}
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"github.com/Workiva/go-datastructures/queue"
|
"github.com/Workiva/go-datastructures/queue"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||||
|
"go.uber.org/atomic"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,6 +14,7 @@ type blockQueue struct {
|
||||||
checkBlocks chan struct{}
|
checkBlocks chan struct{}
|
||||||
chain blockchainer.Blockqueuer
|
chain blockchainer.Blockqueuer
|
||||||
relayF func(*block.Block)
|
relayF func(*block.Block)
|
||||||
|
discarded *atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -32,6 +34,7 @@ func newBlockQueue(capacity int, bc blockchainer.Blockqueuer, log *zap.Logger, r
|
||||||
checkBlocks: make(chan struct{}, 1),
|
checkBlocks: make(chan struct{}, 1),
|
||||||
chain: bc,
|
chain: bc,
|
||||||
relayF: relayer,
|
relayF: relayer,
|
||||||
|
discarded: atomic.NewBool(false),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,8 +94,10 @@ func (bq *blockQueue) putBlock(block *block.Block) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bq *blockQueue) discard() {
|
func (bq *blockQueue) discard() {
|
||||||
|
if bq.discarded.CAS(false, true) {
|
||||||
close(bq.checkBlocks)
|
close(bq.checkBlocks)
|
||||||
bq.queue.Dispose()
|
bq.queue.Dispose()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bq *blockQueue) length() int {
|
func (bq *blockQueue) length() int {
|
||||||
|
|
|
@ -71,6 +71,8 @@ const (
|
||||||
CMDBlock = CommandType(payload.BlockType)
|
CMDBlock = CommandType(payload.BlockType)
|
||||||
CMDExtensible = CommandType(payload.ExtensibleType)
|
CMDExtensible = CommandType(payload.ExtensibleType)
|
||||||
CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType)
|
CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType)
|
||||||
|
CMDGetMPTData CommandType = 0x51 // 0x5.. commands are used for extensions (P2PNotary, state exchange cmds)
|
||||||
|
CMDMPTData CommandType = 0x52
|
||||||
CMDReject CommandType = 0x2f
|
CMDReject CommandType = 0x2f
|
||||||
|
|
||||||
// SPV protocol.
|
// SPV protocol.
|
||||||
|
@ -136,6 +138,10 @@ func (m *Message) decodePayload() error {
|
||||||
p = &payload.Version{}
|
p = &payload.Version{}
|
||||||
case CMDInv, CMDGetData:
|
case CMDInv, CMDGetData:
|
||||||
p = &payload.Inventory{}
|
p = &payload.Inventory{}
|
||||||
|
case CMDGetMPTData:
|
||||||
|
p = &payload.MPTInventory{}
|
||||||
|
case CMDMPTData:
|
||||||
|
p = &payload.MPTData{}
|
||||||
case CMDAddr:
|
case CMDAddr:
|
||||||
p = &payload.AddressList{}
|
p = &payload.AddressList{}
|
||||||
case CMDBlock:
|
case CMDBlock:
|
||||||
|
@ -221,7 +227,7 @@ func (m *Message) tryCompressPayload() error {
|
||||||
if m.Flags&Compressed == 0 {
|
if m.Flags&Compressed == 0 {
|
||||||
switch m.Payload.(type) {
|
switch m.Payload.(type) {
|
||||||
case *payload.Headers, *payload.MerkleBlock, payload.NullPayload,
|
case *payload.Headers, *payload.MerkleBlock, payload.NullPayload,
|
||||||
*payload.Inventory:
|
*payload.Inventory, *payload.MPTInventory:
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
size := len(compressedPayload)
|
size := len(compressedPayload)
|
||||||
|
|
|
@ -26,6 +26,8 @@ func _() {
|
||||||
_ = x[CMDBlock-44]
|
_ = x[CMDBlock-44]
|
||||||
_ = x[CMDExtensible-46]
|
_ = x[CMDExtensible-46]
|
||||||
_ = x[CMDP2PNotaryRequest-80]
|
_ = x[CMDP2PNotaryRequest-80]
|
||||||
|
_ = x[CMDGetMPTData-81]
|
||||||
|
_ = x[CMDMPTData-82]
|
||||||
_ = x[CMDReject-47]
|
_ = x[CMDReject-47]
|
||||||
_ = x[CMDFilterLoad-48]
|
_ = x[CMDFilterLoad-48]
|
||||||
_ = x[CMDFilterAdd-49]
|
_ = x[CMDFilterAdd-49]
|
||||||
|
@ -44,7 +46,7 @@ const (
|
||||||
_CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
|
_CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
|
||||||
_CommandType_name_7 = "CMDMerkleBlock"
|
_CommandType_name_7 = "CMDMerkleBlock"
|
||||||
_CommandType_name_8 = "CMDAlert"
|
_CommandType_name_8 = "CMDAlert"
|
||||||
_CommandType_name_9 = "CMDP2PNotaryRequest"
|
_CommandType_name_9 = "CMDP2PNotaryRequestCMDGetMPTDataCMDMPTData"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -55,6 +57,7 @@ var (
|
||||||
_CommandType_index_4 = [...]uint8{0, 12, 22}
|
_CommandType_index_4 = [...]uint8{0, 12, 22}
|
||||||
_CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58}
|
_CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58}
|
||||||
_CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61}
|
_CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61}
|
||||||
|
_CommandType_index_9 = [...]uint8{0, 19, 32, 42}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (i CommandType) String() string {
|
func (i CommandType) String() string {
|
||||||
|
@ -83,8 +86,9 @@ func (i CommandType) String() string {
|
||||||
return _CommandType_name_7
|
return _CommandType_name_7
|
||||||
case i == 64:
|
case i == 64:
|
||||||
return _CommandType_name_8
|
return _CommandType_name_8
|
||||||
case i == 80:
|
case 80 <= i && i <= 82:
|
||||||
return _CommandType_name_9
|
i -= 80
|
||||||
|
return _CommandType_name_9[_CommandType_index_9[i]:_CommandType_index_9[i+1]]
|
||||||
default:
|
default:
|
||||||
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
|
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||||
}
|
}
|
||||||
|
|
|
@ -258,6 +258,21 @@ func TestEncodeDecodeNotFound(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeGetMPTData(t *testing.T) {
|
||||||
|
testEncodeDecode(t, CMDGetMPTData, &payload.MPTInventory{
|
||||||
|
Hashes: []util.Uint256{
|
||||||
|
{1, 2, 3},
|
||||||
|
{4, 5, 6},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeMPTData(t *testing.T) {
|
||||||
|
testEncodeDecode(t, CMDMPTData, &payload.MPTData{
|
||||||
|
Nodes: [][]byte{{1, 2, 3}, {4, 5, 6}},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestInvalidMessages(t *testing.T) {
|
func TestInvalidMessages(t *testing.T) {
|
||||||
t.Run("CMDBlock, empty payload", func(t *testing.T) {
|
t.Run("CMDBlock, empty payload", func(t *testing.T) {
|
||||||
testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{})
|
testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{})
|
||||||
|
|
35
pkg/network/payload/mptdata.go
Normal file
35
pkg/network/payload/mptdata.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package payload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MPTData represents the set of serialized MPT nodes.
|
||||||
|
type MPTData struct {
|
||||||
|
Nodes [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeBinary implements io.Serializable.
|
||||||
|
func (d *MPTData) EncodeBinary(w *io.BinWriter) {
|
||||||
|
w.WriteVarUint(uint64(len(d.Nodes)))
|
||||||
|
for _, n := range d.Nodes {
|
||||||
|
w.WriteVarBytes(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeBinary implements io.Serializable.
|
||||||
|
func (d *MPTData) DecodeBinary(r *io.BinReader) {
|
||||||
|
sz := r.ReadVarUint()
|
||||||
|
if sz == 0 {
|
||||||
|
r.Err = errors.New("empty MPT nodes list")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := uint64(0); i < sz; i++ {
|
||||||
|
d.Nodes = append(d.Nodes, r.ReadVarBytes())
|
||||||
|
if r.Err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
24
pkg/network/payload/mptdata_test.go
Normal file
24
pkg/network/payload/mptdata_test.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package payload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMPTData_EncodeDecodeBinary(t *testing.T) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
d := new(MPTData)
|
||||||
|
bytes, err := testserdes.EncodeBinary(d)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTData)))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("good", func(t *testing.T) {
|
||||||
|
d := &MPTData{
|
||||||
|
Nodes: [][]byte{{}, {1}, {1, 2, 3}},
|
||||||
|
}
|
||||||
|
testserdes.EncodeDecodeBinary(t, d, new(MPTData))
|
||||||
|
})
|
||||||
|
}
|
32
pkg/network/payload/mptinventory.go
Normal file
32
pkg/network/payload/mptinventory.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package payload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MaxMPTHashesCount is the maximum number of requested MPT nodes hashes.
|
||||||
|
const MaxMPTHashesCount = 32
|
||||||
|
|
||||||
|
// MPTInventory payload.
|
||||||
|
type MPTInventory struct {
|
||||||
|
// A list of requested MPT nodes hashes.
|
||||||
|
Hashes []util.Uint256
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMPTInventory return a pointer to an MPTInventory.
|
||||||
|
func NewMPTInventory(hashes []util.Uint256) *MPTInventory {
|
||||||
|
return &MPTInventory{
|
||||||
|
Hashes: hashes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeBinary implements Serializable interface.
|
||||||
|
func (p *MPTInventory) DecodeBinary(br *io.BinReader) {
|
||||||
|
br.ReadArray(&p.Hashes, MaxMPTHashesCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeBinary implements Serializable interface.
|
||||||
|
func (p *MPTInventory) EncodeBinary(bw *io.BinWriter) {
|
||||||
|
bw.WriteArray(p.Hashes)
|
||||||
|
}
|
38
pkg/network/payload/mptinventory_test.go
Normal file
38
pkg/network/payload/mptinventory_test.go
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
package payload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMPTInventory_EncodeDecodeBinary(t *testing.T) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
testserdes.EncodeDecodeBinary(t, NewMPTInventory([]util.Uint256{}), new(MPTInventory))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("good", func(t *testing.T) {
|
||||||
|
inv := NewMPTInventory([]util.Uint256{{1, 2, 3}, {2, 3, 4}})
|
||||||
|
testserdes.EncodeDecodeBinary(t, inv, new(MPTInventory))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("too large", func(t *testing.T) {
|
||||||
|
check := func(t *testing.T, count int, fail bool) {
|
||||||
|
h := make([]util.Uint256, count)
|
||||||
|
for i := range h {
|
||||||
|
h[i] = util.Uint256{1, 2, 3}
|
||||||
|
}
|
||||||
|
if fail {
|
||||||
|
bytes, err := testserdes.EncodeBinary(NewMPTInventory(h))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTInventory)))
|
||||||
|
} else {
|
||||||
|
testserdes.EncodeDecodeBinary(t, NewMPTInventory(h), new(MPTInventory))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
check(t, MaxMPTHashesCount, false)
|
||||||
|
check(t, MaxMPTHashesCount+1, true)
|
||||||
|
})
|
||||||
|
}
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
mrand "math/rand"
|
mrand "math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -17,7 +18,9 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
|
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network/extpool"
|
"github.com/nspcc-dev/neo-go/pkg/network/extpool"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||||
|
@ -67,6 +70,7 @@ type (
|
||||||
discovery Discoverer
|
discovery Discoverer
|
||||||
chain blockchainer.Blockchainer
|
chain blockchainer.Blockchainer
|
||||||
bQueue *blockQueue
|
bQueue *blockQueue
|
||||||
|
bSyncQueue *blockQueue
|
||||||
consensus consensus.Service
|
consensus consensus.Service
|
||||||
mempool *mempool.Pool
|
mempool *mempool.Pool
|
||||||
notaryRequestPool *mempool.Pool
|
notaryRequestPool *mempool.Pool
|
||||||
|
@ -93,6 +97,7 @@ type (
|
||||||
|
|
||||||
oracle *oracle.Oracle
|
oracle *oracle.Oracle
|
||||||
stateRoot stateroot.Service
|
stateRoot stateroot.Service
|
||||||
|
stateSync blockchainer.StateSync
|
||||||
|
|
||||||
log *zap.Logger
|
log *zap.Logger
|
||||||
}
|
}
|
||||||
|
@ -191,6 +196,10 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
|
||||||
}
|
}
|
||||||
s.stateRoot = sr
|
s.stateRoot = sr
|
||||||
|
|
||||||
|
sSync := chain.GetStateSyncModule()
|
||||||
|
s.stateSync = sSync
|
||||||
|
s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil)
|
||||||
|
|
||||||
if config.OracleCfg.Enabled {
|
if config.OracleCfg.Enabled {
|
||||||
orcCfg := oracle.Config{
|
orcCfg := oracle.Config{
|
||||||
Log: log,
|
Log: log,
|
||||||
|
@ -277,6 +286,7 @@ func (s *Server) Start(errChan chan error) {
|
||||||
go s.broadcastTxLoop()
|
go s.broadcastTxLoop()
|
||||||
go s.relayBlocksLoop()
|
go s.relayBlocksLoop()
|
||||||
go s.bQueue.run()
|
go s.bQueue.run()
|
||||||
|
go s.bSyncQueue.run()
|
||||||
go s.transport.Accept()
|
go s.transport.Accept()
|
||||||
setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10))
|
setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10))
|
||||||
s.run()
|
s.run()
|
||||||
|
@ -292,6 +302,7 @@ func (s *Server) Shutdown() {
|
||||||
p.Disconnect(errServerShutdown)
|
p.Disconnect(errServerShutdown)
|
||||||
}
|
}
|
||||||
s.bQueue.discard()
|
s.bQueue.discard()
|
||||||
|
s.bSyncQueue.discard()
|
||||||
if s.StateRootCfg.Enabled {
|
if s.StateRootCfg.Enabled {
|
||||||
s.stateRoot.Shutdown()
|
s.stateRoot.Shutdown()
|
||||||
}
|
}
|
||||||
|
@ -573,6 +584,10 @@ func (s *Server) IsInSync() bool {
|
||||||
var peersNumber int
|
var peersNumber int
|
||||||
var notHigher int
|
var notHigher int
|
||||||
|
|
||||||
|
if s.stateSync.IsActive() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if s.MinPeers == 0 {
|
if s.MinPeers == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -630,6 +645,9 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
|
||||||
|
|
||||||
// handleBlockCmd processes the received block received from its peer.
|
// handleBlockCmd processes the received block received from its peer.
|
||||||
func (s *Server) handleBlockCmd(p Peer, block *block.Block) error {
|
func (s *Server) handleBlockCmd(p Peer, block *block.Block) error {
|
||||||
|
if s.stateSync.IsActive() {
|
||||||
|
return s.bSyncQueue.putBlock(block)
|
||||||
|
}
|
||||||
return s.bQueue.putBlock(block)
|
return s.bQueue.putBlock(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -639,25 +657,46 @@ func (s *Server) handlePing(p Peer, ping *payload.Ping) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if s.chain.BlockHeight() < ping.LastBlockIndex {
|
err = s.requestBlocksOrHeaders(p)
|
||||||
err = s.requestBlocks(p)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id)))
|
return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) requestBlocksOrHeaders(p Peer) error {
|
||||||
|
if s.stateSync.NeedHeaders() {
|
||||||
|
if s.chain.HeaderHeight() < p.LastBlockIndex() {
|
||||||
|
return s.requestHeaders(p)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var bq blockchainer.Blockqueuer = s.chain
|
||||||
|
if s.stateSync.IsActive() {
|
||||||
|
bq = s.stateSync
|
||||||
|
}
|
||||||
|
if bq.BlockHeight() < p.LastBlockIndex() {
|
||||||
|
return s.requestBlocks(bq, p)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestHeaders sends a CMDGetHeaders message to the peer to sync up in headers.
|
||||||
|
func (s *Server) requestHeaders(p Peer) error {
|
||||||
|
// TODO: optimize
|
||||||
|
currHeight := s.chain.HeaderHeight()
|
||||||
|
needHeight := currHeight + 1
|
||||||
|
payload := payload.NewGetBlockByIndex(needHeight, -1)
|
||||||
|
return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload))
|
||||||
|
}
|
||||||
|
|
||||||
// handlePing processes pong request.
|
// handlePing processes pong request.
|
||||||
func (s *Server) handlePong(p Peer, pong *payload.Ping) error {
|
func (s *Server) handlePong(p Peer, pong *payload.Ping) error {
|
||||||
err := p.HandlePong(pong)
|
err := p.HandlePong(pong)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if s.chain.BlockHeight() < pong.LastBlockIndex {
|
return s.requestBlocksOrHeaders(p)
|
||||||
return s.requestBlocks(p)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleInvCmd processes the received inventory.
|
// handleInvCmd processes the received inventory.
|
||||||
|
@ -766,6 +805,50 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleGetMPTDataCmd processes the received MPT inventory.
|
||||||
|
func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error {
|
||||||
|
if !s.chain.GetConfig().P2PStateExchangeExtensions {
|
||||||
|
return errors.New("GetMPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
|
||||||
|
}
|
||||||
|
if s.chain.GetConfig().KeepOnlyLatestState {
|
||||||
|
// TODO: implement keeping MPT states for P1 and P2 height (#2095, #2152 related)
|
||||||
|
return errors.New("GetMPTDataCMD was received, but only latest MPT state is supported")
|
||||||
|
}
|
||||||
|
resp := payload.MPTData{}
|
||||||
|
capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes)))
|
||||||
|
for _, h := range inv.Hashes {
|
||||||
|
if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err := s.stateSync.Traverse(h,
|
||||||
|
func(_ mpt.Node, node []byte) bool {
|
||||||
|
l := len(node)
|
||||||
|
size := l + io.GetVarSize(l)
|
||||||
|
if size > capLeft {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
resp.Nodes = append(resp.Nodes, node)
|
||||||
|
capLeft -= size
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to traverse MPT starting from %s: %w", h.StringBE(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(resp.Nodes) > 0 {
|
||||||
|
msg := NewMessage(CMDMPTData, &resp)
|
||||||
|
return p.EnqueueP2PMessage(msg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error {
|
||||||
|
if !s.chain.GetConfig().P2PStateExchangeExtensions {
|
||||||
|
return errors.New("MPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
|
||||||
|
}
|
||||||
|
return s.stateSync.AddMPTNodes(data.Nodes)
|
||||||
|
}
|
||||||
|
|
||||||
// handleGetBlocksCmd processes the getblocks request.
|
// handleGetBlocksCmd processes the getblocks request.
|
||||||
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
|
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
|
||||||
count := gb.Count
|
count := gb.Count
|
||||||
|
@ -845,6 +928,11 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error
|
||||||
return p.EnqueueP2PMessage(msg)
|
return p.EnqueueP2PMessage(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleHeadersCmd processes headers payload.
|
||||||
|
func (s *Server) handleHeadersCmd(p Peer, h *payload.Headers) error {
|
||||||
|
return s.stateSync.AddHeaders(h.Hdrs...)
|
||||||
|
}
|
||||||
|
|
||||||
// handleExtensibleCmd processes received extensible payload.
|
// handleExtensibleCmd processes received extensible payload.
|
||||||
func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
|
func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
|
||||||
if !s.syncReached.Load() {
|
if !s.syncReached.Load() {
|
||||||
|
@ -993,8 +1081,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error {
|
||||||
// 1. Block range is divided into chunks of payload.MaxHashesCount.
|
// 1. Block range is divided into chunks of payload.MaxHashesCount.
|
||||||
// 2. Send requests for chunk in increasing order.
|
// 2. Send requests for chunk in increasing order.
|
||||||
// 3. After all requests were sent, request random height.
|
// 3. After all requests were sent, request random height.
|
||||||
func (s *Server) requestBlocks(p Peer) error {
|
func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error {
|
||||||
var currHeight = s.chain.BlockHeight()
|
var currHeight = bq.BlockHeight()
|
||||||
var peerHeight = p.LastBlockIndex()
|
var peerHeight = p.LastBlockIndex()
|
||||||
var needHeight uint32
|
var needHeight uint32
|
||||||
// lastRequestedHeight can only be increased.
|
// lastRequestedHeight can only be increased.
|
||||||
|
@ -1051,9 +1139,18 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
||||||
case CMDGetData:
|
case CMDGetData:
|
||||||
inv := msg.Payload.(*payload.Inventory)
|
inv := msg.Payload.(*payload.Inventory)
|
||||||
return s.handleGetDataCmd(peer, inv)
|
return s.handleGetDataCmd(peer, inv)
|
||||||
|
case CMDGetMPTData:
|
||||||
|
inv := msg.Payload.(*payload.MPTInventory)
|
||||||
|
return s.handleGetMPTDataCmd(peer, inv)
|
||||||
|
case CMDMPTData:
|
||||||
|
inv := msg.Payload.(*payload.MPTData)
|
||||||
|
return s.handleMPTDataCmd(peer, inv)
|
||||||
case CMDGetHeaders:
|
case CMDGetHeaders:
|
||||||
gh := msg.Payload.(*payload.GetBlockByIndex)
|
gh := msg.Payload.(*payload.GetBlockByIndex)
|
||||||
return s.handleGetHeadersCmd(peer, gh)
|
return s.handleGetHeadersCmd(peer, gh)
|
||||||
|
case CMDHeaders:
|
||||||
|
h := msg.Payload.(*payload.Headers)
|
||||||
|
return s.handleHeadersCmd(peer, h)
|
||||||
case CMDInv:
|
case CMDInv:
|
||||||
inventory := msg.Payload.(*payload.Inventory)
|
inventory := msg.Payload.(*payload.Inventory)
|
||||||
return s.handleInvCmd(peer, inventory)
|
return s.handleInvCmd(peer, inventory)
|
||||||
|
@ -1093,6 +1190,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
||||||
}
|
}
|
||||||
go peer.StartProtocol()
|
go peer.StartProtocol()
|
||||||
|
|
||||||
|
s.tryInitStateSync()
|
||||||
s.tryStartServices()
|
s.tryStartServices()
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("received '%s' during handshake", msg.Command.String())
|
return fmt.Errorf("received '%s' during handshake", msg.Command.String())
|
||||||
|
@ -1101,6 +1199,52 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) tryInitStateSync() {
|
||||||
|
if !s.stateSync.IsActive() {
|
||||||
|
s.bSyncQueue.discard()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.stateSync.IsInitialized() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var peersNumber int
|
||||||
|
s.lock.RLock()
|
||||||
|
heights := make([]uint32, 0)
|
||||||
|
for p := range s.peers {
|
||||||
|
if p.Handshaked() {
|
||||||
|
peersNumber++
|
||||||
|
peerLastBlock := p.LastBlockIndex()
|
||||||
|
i := sort.Search(len(heights), func(i int) bool {
|
||||||
|
return heights[i] >= peerLastBlock
|
||||||
|
})
|
||||||
|
heights = append(heights, peerLastBlock)
|
||||||
|
if i != len(heights)-1 {
|
||||||
|
copy(heights[i+1:], heights[i:])
|
||||||
|
heights[i] = peerLastBlock
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.lock.RUnlock()
|
||||||
|
if peersNumber >= s.MinPeers && len(heights) > 0 {
|
||||||
|
// choose the height of the median peer as current chain's height
|
||||||
|
h := heights[len(heights)/2]
|
||||||
|
err := s.stateSync.Init(h)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Fatal("failed to init state sync module",
|
||||||
|
zap.Uint32("evaluated chain's blockHeight", h),
|
||||||
|
zap.Uint32("blockHeight", s.chain.BlockHeight()),
|
||||||
|
zap.Uint32("headerHeight", s.chain.HeaderHeight()),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// module can be inactive after init (i.e. full state is collected and ordinary block processing is needed)
|
||||||
|
if !s.stateSync.IsActive() {
|
||||||
|
s.bSyncQueue.discard()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
func (s *Server) handleNewPayload(p *payload.Extensible) {
|
func (s *Server) handleNewPayload(p *payload.Extensible) {
|
||||||
_, err := s.extensiblePool.Add(p)
|
_, err := s.extensiblePool.Add(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -46,7 +47,10 @@ func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs =
|
||||||
func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") }
|
func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") }
|
||||||
|
|
||||||
func TestNewServer(t *testing.T) {
|
func TestNewServer(t *testing.T) {
|
||||||
bc := &fakechain.FakeChain{}
|
bc := &fakechain.FakeChain{ProtocolConfiguration: config.ProtocolConfiguration{
|
||||||
|
P2PStateExchangeExtensions: true,
|
||||||
|
StateRootInHeader: true,
|
||||||
|
}}
|
||||||
s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery)
|
s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
|
@ -899,3 +903,39 @@ func TestVerifyNotaryRequest(t *testing.T) {
|
||||||
require.NoError(t, verifyNotaryRequest(bc, nil, r))
|
require.NoError(t, verifyNotaryRequest(bc, nil, r))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTryInitStateSync(t *testing.T) {
|
||||||
|
t.Run("module inactive", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
s.tryInitStateSync()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("module already initialized", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: true}
|
||||||
|
s.tryInitStateSync()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("good", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
for _, h := range []uint32{10, 8, 7, 4, 11, 4} {
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
p.lastBlockIndex = h
|
||||||
|
s.peers[p] = true
|
||||||
|
}
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = false // one disconnected peer to check it won't be taken into attention
|
||||||
|
p.lastBlockIndex = 5
|
||||||
|
s.peers[p] = true
|
||||||
|
var expectedH uint32 = 8 // median peer
|
||||||
|
|
||||||
|
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: false, InitFunc: func(h uint32) error {
|
||||||
|
if h != expectedH {
|
||||||
|
return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}}
|
||||||
|
s.tryInitStateSync()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -267,13 +267,11 @@ func (p *TCPPeer) StartProtocol() {
|
||||||
zap.Uint32("id", p.Version().Nonce))
|
zap.Uint32("id", p.Version().Nonce))
|
||||||
|
|
||||||
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities)
|
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities)
|
||||||
if p.server.chain.BlockHeight() < p.LastBlockIndex() {
|
err = p.server.requestBlocksOrHeaders(p)
|
||||||
err = p.server.requestBlocks(p)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.Disconnect(err)
|
p.Disconnect(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
timer := time.NewTimer(p.server.ProtoTickInterval)
|
timer := time.NewTimer(p.server.ProtoTickInterval)
|
||||||
for {
|
for {
|
||||||
|
@ -281,10 +279,8 @@ func (p *TCPPeer) StartProtocol() {
|
||||||
case <-p.done:
|
case <-p.done:
|
||||||
return
|
return
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
// Try to sync in headers and block with the peer if his block height is higher then ours.
|
// Try to sync in headers and block with the peer if his block height is higher than ours.
|
||||||
if p.LastBlockIndex() > p.server.chain.BlockHeight() {
|
err = p.server.requestBlocksOrHeaders(p)
|
||||||
err = p.server.requestBlocks(p)
|
|
||||||
}
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
timer.Reset(p.server.ProtoTickInterval)
|
timer.Reset(p.server.ProtoTickInterval)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue