Merge pull request #2019 from nspcc-dev/states-exchange/cmd
core, network: implement P2P state exchange
This commit is contained in:
commit
752c7f106b
39 changed files with 3001 additions and 53 deletions
|
@ -114,7 +114,11 @@ balance won't be shown in the list of NEP17 balances returned by the neo-go node
|
||||||
(unlike the C# node behavior). However, transfer logs of such token are still
|
(unlike the C# node behavior). However, transfer logs of such token are still
|
||||||
available via `getnep17transfers` RPC call.
|
available via `getnep17transfers` RPC call.
|
||||||
|
|
||||||
The behaviour of the `LastUpdatedBlock` tracking matches the C# node's one.
|
The behaviour of the `LastUpdatedBlock` tracking for archival nodes as far as for
|
||||||
|
governing token balances matches the C# node's one. For non-archival nodes and
|
||||||
|
other NEP17-compliant tokens if transfer's `LastUpdatedBlock` is lower than the
|
||||||
|
latest state synchronization point P the node working against, then
|
||||||
|
`LastUpdatedBlock` equals P.
|
||||||
|
|
||||||
### Unsupported methods
|
### Unsupported methods
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services"
|
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/native"
|
"github.com/nspcc-dev/neo-go/pkg/core/native"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
|
@ -21,6 +22,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||||
|
uatomic "go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FakeChain implements Blockchainer interface, but does not provide real functionality.
|
// FakeChain implements Blockchainer interface, but does not provide real functionality.
|
||||||
|
@ -42,6 +44,15 @@ type FakeChain struct {
|
||||||
UtilityTokenBalance *big.Int
|
UtilityTokenBalance *big.Int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FakeStateSync implements StateSync interface.
|
||||||
|
type FakeStateSync struct {
|
||||||
|
IsActiveFlag uatomic.Bool
|
||||||
|
IsInitializedFlag uatomic.Bool
|
||||||
|
InitFunc func(h uint32) error
|
||||||
|
TraverseFunc func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
|
||||||
|
AddMPTNodesFunc func(nodes [][]byte) error
|
||||||
|
}
|
||||||
|
|
||||||
// NewFakeChain returns new FakeChain structure.
|
// NewFakeChain returns new FakeChain structure.
|
||||||
func NewFakeChain() *FakeChain {
|
func NewFakeChain() *FakeChain {
|
||||||
return &FakeChain{
|
return &FakeChain{
|
||||||
|
@ -294,6 +305,11 @@ func (chain *FakeChain) GetStateModule() blockchainer.StateRoot {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStateSyncModule implements Blockchainer interface.
|
||||||
|
func (chain *FakeChain) GetStateSyncModule() blockchainer.StateSync {
|
||||||
|
return &FakeStateSync{}
|
||||||
|
}
|
||||||
|
|
||||||
// GetStorageItem implements Blockchainer interface.
|
// GetStorageItem implements Blockchainer interface.
|
||||||
func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
|
func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
|
||||||
panic("TODO")
|
panic("TODO")
|
||||||
|
@ -436,3 +452,63 @@ func (chain *FakeChain) UnsubscribeFromNotifications(ch chan<- *state.Notificati
|
||||||
func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
|
func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
|
||||||
panic("TODO")
|
panic("TODO")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddBlock implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) AddBlock(block *block.Block) error {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHeaders implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) AddHeaders(...*block.Header) error {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMPTNodes implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) AddMPTNodes(nodes [][]byte) error {
|
||||||
|
if s.AddMPTNodesFunc != nil {
|
||||||
|
return s.AddMPTNodesFunc(nodes)
|
||||||
|
}
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockHeight implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) BlockHeight() uint32 {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsActive implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag.Load() }
|
||||||
|
|
||||||
|
// IsInitialized implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) IsInitialized() bool {
|
||||||
|
return s.IsInitializedFlag.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) Init(currChainHeight uint32) error {
|
||||||
|
if s.InitFunc != nil {
|
||||||
|
return s.InitFunc(currChainHeight)
|
||||||
|
}
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedHeaders implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) NeedHeaders() bool { return false }
|
||||||
|
|
||||||
|
// NeedMPTNodes implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) NeedMPTNodes() bool {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||||
|
if s.TraverseFunc != nil {
|
||||||
|
return s.TraverseFunc(root, process)
|
||||||
|
}
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnknownMPTNodesBatch implements StateSync interface.
|
||||||
|
func (s *FakeStateSync) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
|
||||||
|
panic("TODO")
|
||||||
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/native/noderoles"
|
"github.com/nspcc-dev/neo-go/pkg/core/native/noderoles"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/stateroot"
|
"github.com/nspcc-dev/neo-go/pkg/core/stateroot"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/statesync"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||||
|
@ -34,6 +35,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
@ -42,7 +44,7 @@ import (
|
||||||
// Tuning parameters.
|
// Tuning parameters.
|
||||||
const (
|
const (
|
||||||
headerBatchCount = 2000
|
headerBatchCount = 2000
|
||||||
version = "0.1.2"
|
version = "0.1.4"
|
||||||
|
|
||||||
defaultInitialGAS = 52000000_00000000
|
defaultInitialGAS = 52000000_00000000
|
||||||
defaultMemPoolSize = 50000
|
defaultMemPoolSize = 50000
|
||||||
|
@ -54,6 +56,31 @@ const (
|
||||||
// HeaderVerificationGasLimit is the maximum amount of GAS for block header verification.
|
// HeaderVerificationGasLimit is the maximum amount of GAS for block header verification.
|
||||||
HeaderVerificationGasLimit = 3_00000000 // 3 GAS
|
HeaderVerificationGasLimit = 3_00000000 // 3 GAS
|
||||||
defaultStateSyncInterval = 40000
|
defaultStateSyncInterval = 40000
|
||||||
|
|
||||||
|
// maxStorageBatchSize is the number of elements in storage batch expected to fit into the
|
||||||
|
// storage without delays and problems. Estimated size of batch in case of given number of
|
||||||
|
// elements does not exceed 1Mb.
|
||||||
|
maxStorageBatchSize = 10000
|
||||||
|
)
|
||||||
|
|
||||||
|
// stateJumpStage denotes the stage of state jump process.
|
||||||
|
type stateJumpStage byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
// none means that no state jump process was initiated yet.
|
||||||
|
none stateJumpStage = 1 << iota
|
||||||
|
// stateJumpStarted means that state jump was just initiated, but outdated storage items
|
||||||
|
// were not yet removed.
|
||||||
|
stateJumpStarted
|
||||||
|
// oldStorageItemsRemoved means that outdated contract storage items were removed, but
|
||||||
|
// new storage items were not yet saved.
|
||||||
|
oldStorageItemsRemoved
|
||||||
|
// newStorageItemsAdded means that contract storage items are up-to-date with the current
|
||||||
|
// state.
|
||||||
|
newStorageItemsAdded
|
||||||
|
// genesisStateRemoved means that state corresponding to the genesis block was removed
|
||||||
|
// from the storage.
|
||||||
|
genesisStateRemoved
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -308,16 +335,6 @@ func (bc *Blockchain) init() error {
|
||||||
// and the genesis block as first block.
|
// and the genesis block as first block.
|
||||||
bc.log.Info("restoring blockchain", zap.String("version", version))
|
bc.log.Info("restoring blockchain", zap.String("version", version))
|
||||||
|
|
||||||
bHeight, err := bc.dao.GetCurrentBlockHeight()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
bc.blockHeight = bHeight
|
|
||||||
bc.persistedHeight = bHeight
|
|
||||||
if err = bc.stateRoot.Init(bHeight, bc.config.KeepOnlyLatestState); err != nil {
|
|
||||||
return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bc.headerHashes, err = bc.dao.GetHeaderHashes()
|
bc.headerHashes, err = bc.dao.GetHeaderHashes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -365,6 +382,34 @@ func (bc *Blockchain) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check whether StateJump stage is in the storage and continue interrupted state jump if so.
|
||||||
|
jumpStage, err := bc.dao.Store.Get(storage.SYSStateJumpStage.Bytes())
|
||||||
|
if err == nil {
|
||||||
|
if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) {
|
||||||
|
return errors.New("state jump was not completed, but P2PStateExchangeExtensions are disabled or archival node capability is on. " +
|
||||||
|
"To start an archival node drop the database manually and restart the node")
|
||||||
|
}
|
||||||
|
if len(jumpStage) != 1 {
|
||||||
|
return fmt.Errorf("invalid state jump stage format")
|
||||||
|
}
|
||||||
|
// State jump wasn't finished yet, thus continue it.
|
||||||
|
stateSyncPoint, err := bc.dao.GetStateSyncPoint()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get state sync point from the storage")
|
||||||
|
}
|
||||||
|
return bc.jumpToStateInternal(stateSyncPoint, stateJumpStage(jumpStage[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
bHeight, err := bc.dao.GetCurrentBlockHeight()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
bc.blockHeight = bHeight
|
||||||
|
bc.persistedHeight = bHeight
|
||||||
|
if err = bc.stateRoot.Init(bHeight, bc.config.KeepOnlyLatestState); err != nil {
|
||||||
|
return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err)
|
||||||
|
}
|
||||||
|
|
||||||
err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
|
err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
|
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
|
||||||
|
@ -409,6 +454,158 @@ func (bc *Blockchain) init() error {
|
||||||
return bc.updateExtensibleWhitelist(bHeight)
|
return bc.updateExtensibleWhitelist(bHeight)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// jumpToState is an atomic operation that changes Blockchain state to the one
|
||||||
|
// specified by the state sync point p. All the data needed for the jump must be
|
||||||
|
// collected by the state sync module.
|
||||||
|
func (bc *Blockchain) jumpToState(p uint32) error {
|
||||||
|
bc.lock.Lock()
|
||||||
|
defer bc.lock.Unlock()
|
||||||
|
|
||||||
|
return bc.jumpToStateInternal(p, none)
|
||||||
|
}
|
||||||
|
|
||||||
|
// jumpToStateInternal is an internal representation of jumpToState callback that
|
||||||
|
// changes Blockchain state to the one specified by state sync point p and state
|
||||||
|
// jump stage. All the data needed for the jump must be in the DB, otherwise an
|
||||||
|
// error is returned. It is not protected by mutex.
|
||||||
|
func (bc *Blockchain) jumpToStateInternal(p uint32, stage stateJumpStage) error {
|
||||||
|
if p+1 >= uint32(len(bc.headerHashes)) {
|
||||||
|
return fmt.Errorf("invalid state sync point %d: headerHeignt is %d", p, len(bc.headerHashes))
|
||||||
|
}
|
||||||
|
|
||||||
|
bc.log.Info("jumping to state sync point", zap.Uint32("state sync point", p))
|
||||||
|
|
||||||
|
writeBuf := io.NewBufBinWriter()
|
||||||
|
jumpStageKey := storage.SYSStateJumpStage.Bytes()
|
||||||
|
switch stage {
|
||||||
|
case none:
|
||||||
|
err := bc.dao.Store.Put(jumpStageKey, []byte{byte(stateJumpStarted)})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store state jump stage: %w", err)
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
case stateJumpStarted:
|
||||||
|
// Replace old storage items by new ones, it should be done step-by step.
|
||||||
|
// Firstly, remove all old genesis-related items.
|
||||||
|
b := bc.dao.Store.Batch()
|
||||||
|
bc.dao.Store.Seek([]byte{byte(storage.STStorage)}, func(k, _ []byte) {
|
||||||
|
// Must copy here, #1468.
|
||||||
|
key := slice.Copy(k)
|
||||||
|
b.Delete(key)
|
||||||
|
})
|
||||||
|
b.Put(jumpStageKey, []byte{byte(oldStorageItemsRemoved)})
|
||||||
|
err := bc.dao.Store.PutBatch(b)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store state jump stage: %w", err)
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
case oldStorageItemsRemoved:
|
||||||
|
// Then change STTempStorage prefix to STStorage. Each replace operation is atomic.
|
||||||
|
for {
|
||||||
|
count := 0
|
||||||
|
b := bc.dao.Store.Batch()
|
||||||
|
bc.dao.Store.Seek([]byte{byte(storage.STTempStorage)}, func(k, v []byte) {
|
||||||
|
if count >= maxStorageBatchSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Must copy here, #1468.
|
||||||
|
oldKey := slice.Copy(k)
|
||||||
|
b.Delete(oldKey)
|
||||||
|
key := make([]byte, len(k))
|
||||||
|
key[0] = byte(storage.STStorage)
|
||||||
|
copy(key[1:], k[1:])
|
||||||
|
value := slice.Copy(v)
|
||||||
|
b.Put(key, value)
|
||||||
|
count += 2
|
||||||
|
})
|
||||||
|
if count > 0 {
|
||||||
|
err := bc.dao.Store.PutBatch(b)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to replace outdated contract storage items with the fresh ones: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := bc.dao.Store.Put(jumpStageKey, []byte{byte(newStorageItemsAdded)})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store state jump stage: %w", err)
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
case newStorageItemsAdded:
|
||||||
|
// After current state is updated, we need to remove outdated state-related data if so.
|
||||||
|
// The only outdated data we might have is genesis-related data, so check it.
|
||||||
|
if p-bc.config.MaxTraceableBlocks > 0 {
|
||||||
|
cache := bc.dao.GetWrapped()
|
||||||
|
writeBuf.Reset()
|
||||||
|
err := cache.DeleteBlock(bc.headerHashes[0], writeBuf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove outdated state data for the genesis block: %w", err)
|
||||||
|
}
|
||||||
|
// TODO: remove NEP17 transfers and NEP17 transfer info for genesis block, #2096 related.
|
||||||
|
_, err = cache.Persist()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to drop genesis block state: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := bc.dao.Store.Put(jumpStageKey, []byte{byte(genesisStateRemoved)})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store state jump stage: %w", err)
|
||||||
|
}
|
||||||
|
case genesisStateRemoved:
|
||||||
|
// there's nothing to do after that, so just continue with common operations
|
||||||
|
// and remove state jump stage in the end.
|
||||||
|
default:
|
||||||
|
return errors.New("unknown state jump stage")
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := bc.dao.GetBlock(bc.headerHashes[p])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get current block: %w", err)
|
||||||
|
}
|
||||||
|
writeBuf.Reset()
|
||||||
|
err = bc.dao.StoreAsCurrentBlock(block, writeBuf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store current block: %w", err)
|
||||||
|
}
|
||||||
|
bc.topBlock.Store(block)
|
||||||
|
atomic.StoreUint32(&bc.blockHeight, p)
|
||||||
|
atomic.StoreUint32(&bc.persistedHeight, p)
|
||||||
|
|
||||||
|
block, err = bc.dao.GetBlock(bc.headerHashes[p+1])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get block to init MPT: %w", err)
|
||||||
|
}
|
||||||
|
if err = bc.stateRoot.JumpToState(&state.MPTRoot{
|
||||||
|
Index: p,
|
||||||
|
Root: block.PrevStateRoot,
|
||||||
|
}, bc.config.KeepOnlyLatestState); err != nil {
|
||||||
|
return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
|
||||||
|
}
|
||||||
|
err = bc.contracts.Management.InitializeCache(bc.dao)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't init cache for Management native contract: %w", err)
|
||||||
|
}
|
||||||
|
bc.contracts.Designate.InitializeCache()
|
||||||
|
|
||||||
|
if err := bc.updateExtensibleWhitelist(p); err != nil {
|
||||||
|
return fmt.Errorf("failed to update extensible whitelist: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updateBlockHeightMetric(p)
|
||||||
|
|
||||||
|
err = bc.dao.Store.Delete(jumpStageKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove outdated state jump stage: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Run runs chain loop, it needs to be run as goroutine and executing it is
|
// Run runs chain loop, it needs to be run as goroutine and executing it is
|
||||||
// critical for correct Blockchain operation.
|
// critical for correct Blockchain operation.
|
||||||
func (bc *Blockchain) Run() {
|
func (bc *Blockchain) Run() {
|
||||||
|
@ -696,6 +893,11 @@ func (bc *Blockchain) GetStateModule() blockchainer.StateRoot {
|
||||||
return bc.stateRoot
|
return bc.stateRoot
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStateSyncModule returns new state sync service instance.
|
||||||
|
func (bc *Blockchain) GetStateSyncModule() blockchainer.StateSync {
|
||||||
|
return statesync.NewModule(bc, bc.log, bc.dao, bc.jumpToState)
|
||||||
|
}
|
||||||
|
|
||||||
// storeBlock performs chain update using the block given, it executes all
|
// storeBlock performs chain update using the block given, it executes all
|
||||||
// transactions with all appropriate side-effects and updates Blockchain state.
|
// transactions with all appropriate side-effects and updates Blockchain state.
|
||||||
// This is the only way to change Blockchain state.
|
// This is the only way to change Blockchain state.
|
||||||
|
@ -1159,12 +1361,25 @@ func (bc *Blockchain) GetNEP17Contracts() []util.Uint160 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNEP17LastUpdated returns a set of contract ids with the corresponding last updated
|
// GetNEP17LastUpdated returns a set of contract ids with the corresponding last updated
|
||||||
// block indexes.
|
// block indexes. In case of an empty account, latest stored state synchronisation point
|
||||||
|
// is returned under Math.MinInt32 key.
|
||||||
func (bc *Blockchain) GetNEP17LastUpdated(acc util.Uint160) (map[int32]uint32, error) {
|
func (bc *Blockchain) GetNEP17LastUpdated(acc util.Uint160) (map[int32]uint32, error) {
|
||||||
info, err := bc.dao.GetNEP17TransferInfo(acc)
|
info, err := bc.dao.GetNEP17TransferInfo(acc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if bc.config.P2PStateExchangeExtensions && bc.config.RemoveUntraceableBlocks {
|
||||||
|
if _, ok := info.LastUpdated[bc.contracts.NEO.ID]; !ok {
|
||||||
|
nBalance, lub := bc.contracts.NEO.BalanceOf(bc.dao, acc)
|
||||||
|
if nBalance.Sign() != 0 {
|
||||||
|
info.LastUpdated[bc.contracts.NEO.ID] = lub
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stateSyncPoint, err := bc.dao.GetStateSyncPoint()
|
||||||
|
if err == nil {
|
||||||
|
info.LastUpdated[math.MinInt32] = stateSyncPoint
|
||||||
|
}
|
||||||
return info.LastUpdated, nil
|
return info.LastUpdated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
@ -34,6 +35,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
|
@ -41,6 +43,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/wallet"
|
"github.com/nspcc-dev/neo-go/pkg/wallet"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/zap/zaptest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestVerifyHeader(t *testing.T) {
|
func TestVerifyHeader(t *testing.T) {
|
||||||
|
@ -1734,3 +1737,88 @@ func TestConfigNativeUpdateHistory(t *testing.T) {
|
||||||
check(t, tc)
|
check(t, tc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBlockchain_InitWithIncompleteStateJump(t *testing.T) {
|
||||||
|
var (
|
||||||
|
stateSyncInterval = 4
|
||||||
|
maxTraceable uint32 = 6
|
||||||
|
)
|
||||||
|
spountCfg := func(c *config.Config) {
|
||||||
|
c.ProtocolConfiguration.RemoveUntraceableBlocks = true
|
||||||
|
c.ProtocolConfiguration.StateRootInHeader = true
|
||||||
|
c.ProtocolConfiguration.P2PStateExchangeExtensions = true
|
||||||
|
c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
|
||||||
|
c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
|
||||||
|
}
|
||||||
|
bcSpout := newTestChainWithCustomCfg(t, spountCfg)
|
||||||
|
initBasicChain(t, bcSpout)
|
||||||
|
|
||||||
|
// reach next to the latest state sync point and pretend that we've just restored
|
||||||
|
stateSyncPoint := (int(bcSpout.BlockHeight())/stateSyncInterval + 1) * stateSyncInterval
|
||||||
|
for i := bcSpout.BlockHeight() + 1; i <= uint32(stateSyncPoint); i++ {
|
||||||
|
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
|
||||||
|
}
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), bcSpout.BlockHeight())
|
||||||
|
b := bcSpout.newBlock()
|
||||||
|
require.NoError(t, bcSpout.AddHeaders(&b.Header))
|
||||||
|
|
||||||
|
// put storage items with STTemp prefix
|
||||||
|
batch := bcSpout.dao.Store.Batch()
|
||||||
|
bcSpout.dao.Store.Seek(storage.STStorage.Bytes(), func(k, v []byte) {
|
||||||
|
key := slice.Copy(k)
|
||||||
|
key[0] = storage.STTempStorage.Bytes()[0]
|
||||||
|
value := slice.Copy(v)
|
||||||
|
batch.Put(key, value)
|
||||||
|
})
|
||||||
|
require.NoError(t, bcSpout.dao.Store.PutBatch(batch))
|
||||||
|
|
||||||
|
checkNewBlockchainErr := func(t *testing.T, cfg func(c *config.Config), store storage.Store, shouldFail bool) {
|
||||||
|
unitTestNetCfg, err := config.Load("../../config", testchain.Network())
|
||||||
|
require.NoError(t, err)
|
||||||
|
cfg(&unitTestNetCfg)
|
||||||
|
log := zaptest.NewLogger(t)
|
||||||
|
_, err = NewBlockchain(store, unitTestNetCfg.ProtocolConfiguration, log)
|
||||||
|
if shouldFail {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
boltCfg := func(c *config.Config) {
|
||||||
|
spountCfg(c)
|
||||||
|
c.ProtocolConfiguration.KeepOnlyLatestState = true
|
||||||
|
}
|
||||||
|
// manually store statejump stage to check statejump recover process
|
||||||
|
t.Run("invalid RemoveUntraceableBlocks setting", func(t *testing.T) {
|
||||||
|
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
|
||||||
|
checkNewBlockchainErr(t, func(c *config.Config) {
|
||||||
|
boltCfg(c)
|
||||||
|
c.ProtocolConfiguration.RemoveUntraceableBlocks = false
|
||||||
|
}, bcSpout.dao.Store, true)
|
||||||
|
})
|
||||||
|
t.Run("invalid state jump stage format", func(t *testing.T) {
|
||||||
|
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{0x01, 0x02}))
|
||||||
|
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
|
||||||
|
})
|
||||||
|
t.Run("missing state sync point", func(t *testing.T) {
|
||||||
|
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
|
||||||
|
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
|
||||||
|
})
|
||||||
|
t.Run("invalid state sync point", func(t *testing.T) {
|
||||||
|
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
|
||||||
|
point := make([]byte, 4)
|
||||||
|
binary.LittleEndian.PutUint32(point, uint32(len(bcSpout.headerHashes)))
|
||||||
|
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateSyncPoint.Bytes(), point))
|
||||||
|
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
|
||||||
|
})
|
||||||
|
for _, stage := range []stateJumpStage{stateJumpStarted, oldStorageItemsRemoved, newStorageItemsAdded, genesisStateRemoved, 0x03} {
|
||||||
|
t.Run(fmt.Sprintf("state jump stage %d", stage), func(t *testing.T) {
|
||||||
|
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stage)}))
|
||||||
|
point := make([]byte, 4)
|
||||||
|
binary.LittleEndian.PutUint32(point, uint32(stateSyncPoint))
|
||||||
|
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateSyncPoint.Bytes(), point))
|
||||||
|
shouldFail := stage == 0x03 // unknown stage
|
||||||
|
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, shouldFail)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -21,7 +21,6 @@ import (
|
||||||
type Blockchainer interface {
|
type Blockchainer interface {
|
||||||
ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction
|
ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction
|
||||||
GetConfig() config.ProtocolConfiguration
|
GetConfig() config.ProtocolConfiguration
|
||||||
AddHeaders(...*block.Header) error
|
|
||||||
Blockqueuer // Blockqueuer interface
|
Blockqueuer // Blockqueuer interface
|
||||||
CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error)
|
CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error)
|
||||||
Close()
|
Close()
|
||||||
|
@ -56,6 +55,7 @@ type Blockchainer interface {
|
||||||
GetStandByCommittee() keys.PublicKeys
|
GetStandByCommittee() keys.PublicKeys
|
||||||
GetStandByValidators() keys.PublicKeys
|
GetStandByValidators() keys.PublicKeys
|
||||||
GetStateModule() StateRoot
|
GetStateModule() StateRoot
|
||||||
|
GetStateSyncModule() StateSync
|
||||||
GetStorageItem(id int32, key []byte) state.StorageItem
|
GetStorageItem(id int32, key []byte) state.StorageItem
|
||||||
GetStorageItems(id int32) (map[string]state.StorageItem, error)
|
GetStorageItems(id int32) (map[string]state.StorageItem, error)
|
||||||
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM
|
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM
|
||||||
|
|
|
@ -5,5 +5,6 @@ import "github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
// Blockqueuer is an interface for blockqueue.
|
// Blockqueuer is an interface for blockqueue.
|
||||||
type Blockqueuer interface {
|
type Blockqueuer interface {
|
||||||
AddBlock(block *block.Block) error
|
AddBlock(block *block.Block) error
|
||||||
|
AddHeaders(...*block.Header) error
|
||||||
BlockHeight() uint32
|
BlockHeight() uint32
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
// StateRoot represents local state root module.
|
// StateRoot represents local state root module.
|
||||||
type StateRoot interface {
|
type StateRoot interface {
|
||||||
AddStateRoot(root *state.MPTRoot) error
|
AddStateRoot(root *state.MPTRoot) error
|
||||||
|
CleanStorage() error
|
||||||
|
CurrentLocalHeight() uint32
|
||||||
CurrentLocalStateRoot() util.Uint256
|
CurrentLocalStateRoot() util.Uint256
|
||||||
CurrentValidatedHeight() uint32
|
CurrentValidatedHeight() uint32
|
||||||
GetStateProof(root util.Uint256, key []byte) ([][]byte, error)
|
GetStateProof(root util.Uint256, key []byte) ([][]byte, error)
|
||||||
|
|
19
pkg/core/blockchainer/state_sync.go
Normal file
19
pkg/core/blockchainer/state_sync.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package blockchainer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StateSync represents state sync module.
|
||||||
|
type StateSync interface {
|
||||||
|
AddMPTNodes([][]byte) error
|
||||||
|
Blockqueuer // Blockqueuer interface
|
||||||
|
Init(currChainHeight uint32) error
|
||||||
|
IsActive() bool
|
||||||
|
IsInitialized() bool
|
||||||
|
GetUnknownMPTNodesBatch(limit int) []util.Uint256
|
||||||
|
NeedHeaders() bool
|
||||||
|
NeedMPTNodes() bool
|
||||||
|
Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
|
||||||
|
}
|
314
pkg/core/mpt/billet.go
Normal file
314
pkg/core/mpt/billet.go
Normal file
|
@ -0,0 +1,314 @@
|
||||||
|
package mpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrRestoreFailed is returned when replacing HashNode by its "unhashed"
|
||||||
|
// candidate fails.
|
||||||
|
ErrRestoreFailed = errors.New("failed to restore MPT node")
|
||||||
|
errStop = errors.New("stop condition is met")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Billet is a part of MPT trie with missing hash nodes that need to be restored.
|
||||||
|
// Billet is based on the following assumptions:
|
||||||
|
// 1. Refcount can only be incremented (we don't change MPT structure during restore,
|
||||||
|
// thus don't need to decrease refcount).
|
||||||
|
// 2. Each time the part of Billet is completely restored, it is collapsed into
|
||||||
|
// HashNode.
|
||||||
|
// 3. Pair (node, path) must be restored only once. It's a duty of MPT pool to manage
|
||||||
|
// MPT paths in order to provide this assumption.
|
||||||
|
type Billet struct {
|
||||||
|
Store *storage.MemCachedStore
|
||||||
|
|
||||||
|
root Node
|
||||||
|
refcountEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBillet returns new billet for MPT trie restoring. It accepts a MemCachedStore
|
||||||
|
// to decouple storage errors from logic errors so that all storage errors are
|
||||||
|
// processed during `store.Persist()` at the caller. This also has the benefit,
|
||||||
|
// that every `Put` can be considered an atomic operation.
|
||||||
|
func NewBillet(rootHash util.Uint256, enableRefCount bool, store *storage.MemCachedStore) *Billet {
|
||||||
|
return &Billet{
|
||||||
|
Store: store,
|
||||||
|
root: NewHashNode(rootHash),
|
||||||
|
refcountEnabled: enableRefCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreHashNode replaces HashNode located at the provided path by the specified Node
|
||||||
|
// and stores it. It also maintains MPT as small as possible by collapsing those parts
|
||||||
|
// of MPT that have been completely restored.
|
||||||
|
func (b *Billet) RestoreHashNode(path []byte, node Node) error {
|
||||||
|
if _, ok := node.(*HashNode); ok {
|
||||||
|
return fmt.Errorf("%w: unable to restore node into HashNode", ErrRestoreFailed)
|
||||||
|
}
|
||||||
|
if _, ok := node.(EmptyNode); ok {
|
||||||
|
return fmt.Errorf("%w: unable to restore node into EmptyNode", ErrRestoreFailed)
|
||||||
|
}
|
||||||
|
r, err := b.putIntoNode(b.root, path, node)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.root = r
|
||||||
|
|
||||||
|
// If it's a leaf, then put into temporary contract storage.
|
||||||
|
if leaf, ok := node.(*LeafNode); ok {
|
||||||
|
k := append([]byte{byte(storage.STTempStorage)}, fromNibbles(path)...)
|
||||||
|
_ = b.Store.Put(k, leaf.value)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// putIntoNode puts val with provided path inside curr and returns updated node.
|
||||||
|
// Reference counters are updated for both curr and returned value.
|
||||||
|
func (b *Billet) putIntoNode(curr Node, path []byte, val Node) (Node, error) {
|
||||||
|
switch n := curr.(type) {
|
||||||
|
case *LeafNode:
|
||||||
|
return b.putIntoLeaf(n, path, val)
|
||||||
|
case *BranchNode:
|
||||||
|
return b.putIntoBranch(n, path, val)
|
||||||
|
case *ExtensionNode:
|
||||||
|
return b.putIntoExtension(n, path, val)
|
||||||
|
case *HashNode:
|
||||||
|
return b.putIntoHash(n, path, val)
|
||||||
|
case EmptyNode:
|
||||||
|
return nil, fmt.Errorf("%w: can't modify EmptyNode during restore", ErrRestoreFailed)
|
||||||
|
default:
|
||||||
|
panic("invalid MPT node type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Billet) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) {
|
||||||
|
if len(path) != 0 {
|
||||||
|
return nil, fmt.Errorf("%w: can't modify LeafNode during restore", ErrRestoreFailed)
|
||||||
|
}
|
||||||
|
if curr.Hash() != val.Hash() {
|
||||||
|
return nil, fmt.Errorf("%w: bad Leaf node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
|
||||||
|
}
|
||||||
|
// Once Leaf node is restored, it will be collapsed into HashNode forever, so
|
||||||
|
// there shouldn't be such situation when we try to restore Leaf node.
|
||||||
|
panic("bug: can't restore LeafNode")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Billet) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) {
|
||||||
|
if len(path) == 0 && curr.Hash().Equals(val.Hash()) {
|
||||||
|
// This node has already been restored, so it's an MPT pool duty to avoid
|
||||||
|
// duplicating restore requests.
|
||||||
|
panic("bug: can't perform restoring of BranchNode twice")
|
||||||
|
}
|
||||||
|
i, path := splitPath(path)
|
||||||
|
r, err := b.putIntoNode(curr.Children[i], path, val)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
curr.Children[i] = r
|
||||||
|
return b.tryCollapseBranch(curr), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Billet) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) {
|
||||||
|
if len(path) == 0 {
|
||||||
|
if curr.Hash() != val.Hash() {
|
||||||
|
return nil, fmt.Errorf("%w: bad Extension node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
|
||||||
|
}
|
||||||
|
// This node has already been restored, so it's an MPT pool duty to avoid
|
||||||
|
// duplicating restore requests.
|
||||||
|
panic("bug: can't perform restoring of ExtensionNode twice")
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(path, curr.key) {
|
||||||
|
return nil, fmt.Errorf("%w: can't modify ExtensionNode during restore", ErrRestoreFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := b.putIntoNode(curr.next, path[len(curr.key):], val)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
curr.next = r
|
||||||
|
return b.tryCollapseExtension(curr), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Billet) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) {
|
||||||
|
// Once a part of MPT Billet is completely restored, it will be collapsed forever, so
|
||||||
|
// it's an MPT pool duty to avoid duplicating restore requests.
|
||||||
|
if len(path) != 0 {
|
||||||
|
return nil, fmt.Errorf("%w: node has already been collapsed", ErrRestoreFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// `curr` hash node can be either of
|
||||||
|
// 1) saved in storage (i.g. if we've already restored node with the same hash from the
|
||||||
|
// other part of MPT), so just add it to local in-memory MPT.
|
||||||
|
// 2) missing from the storage. It's OK because we're syncing MPT state, and the purpose
|
||||||
|
// is to store missing hash nodes.
|
||||||
|
// both cases are OK, but we still need to validate `val` against `curr`.
|
||||||
|
if val.Hash() != curr.Hash() {
|
||||||
|
return nil, fmt.Errorf("%w: can't restore HashNode: expected and actual hashes mismatch (%s vs %s)", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
|
||||||
|
}
|
||||||
|
|
||||||
|
if curr.Collapsed {
|
||||||
|
// This node has already been restored and collapsed, so it's an MPT pool duty to avoid
|
||||||
|
// duplicating restore requests.
|
||||||
|
panic("bug: can't perform restoring of collapsed node")
|
||||||
|
}
|
||||||
|
|
||||||
|
// We also need to increment refcount in both cases. That's the only place where refcount
|
||||||
|
// is changed during restore process. Also flush right now, because sync process can be
|
||||||
|
// interrupted at any time.
|
||||||
|
b.incrementRefAndStore(val.Hash(), val.Bytes())
|
||||||
|
|
||||||
|
if val.Type() == LeafT {
|
||||||
|
return b.tryCollapseLeaf(val.(*LeafNode)), nil
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Billet) incrementRefAndStore(h util.Uint256, bs []byte) {
|
||||||
|
key := makeStorageKey(h.BytesBE())
|
||||||
|
if b.refcountEnabled {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
data []byte
|
||||||
|
cnt int32
|
||||||
|
)
|
||||||
|
// An item may already be in store.
|
||||||
|
data, err = b.Store.Get(key)
|
||||||
|
if err == nil {
|
||||||
|
cnt = int32(binary.LittleEndian.Uint32(data[len(data)-4:]))
|
||||||
|
}
|
||||||
|
cnt++
|
||||||
|
if len(data) == 0 {
|
||||||
|
data = append(bs, 0, 0, 0, 0)
|
||||||
|
}
|
||||||
|
binary.LittleEndian.PutUint32(data[len(data)-4:], uint32(cnt))
|
||||||
|
_ = b.Store.Put(key, data)
|
||||||
|
} else {
|
||||||
|
_ = b.Store.Put(key, bs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse traverses MPT nodes (pre-order) starting from the billet root down
|
||||||
|
// to its children calling `process` for each serialised node until true is
|
||||||
|
// returned from `process` function. It also replaces all HashNodes to their
|
||||||
|
// "unhashed" counterparts until the stop condition is satisfied.
|
||||||
|
func (b *Billet) Traverse(process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) error {
|
||||||
|
r, err := b.traverse(b.root, process, ignoreStorageErr)
|
||||||
|
if err != nil && !errors.Is(err, errStop) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.root = r
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) (Node, error) {
|
||||||
|
if _, ok := curr.(EmptyNode); ok {
|
||||||
|
// We're not interested in EmptyNodes, and they do not affect the
|
||||||
|
// traversal process, thus remain them untouched.
|
||||||
|
return curr, nil
|
||||||
|
}
|
||||||
|
if hn, ok := curr.(*HashNode); ok {
|
||||||
|
r, err := b.GetFromStore(hn.Hash())
|
||||||
|
if err != nil {
|
||||||
|
if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) {
|
||||||
|
return hn, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return b.traverse(r, process, ignoreStorageErr)
|
||||||
|
}
|
||||||
|
bytes := slice.Copy(curr.Bytes())
|
||||||
|
if process(curr, bytes) {
|
||||||
|
return curr, errStop
|
||||||
|
}
|
||||||
|
switch n := curr.(type) {
|
||||||
|
case *LeafNode:
|
||||||
|
return b.tryCollapseLeaf(n), nil
|
||||||
|
case *BranchNode:
|
||||||
|
for i := range n.Children {
|
||||||
|
r, err := b.traverse(n.Children[i], process, ignoreStorageErr)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, errStop) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
n.Children[i] = r
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
n.Children[i] = r
|
||||||
|
}
|
||||||
|
return b.tryCollapseBranch(n), nil
|
||||||
|
case *ExtensionNode:
|
||||||
|
r, err := b.traverse(n.next, process, ignoreStorageErr)
|
||||||
|
if err != nil && !errors.Is(err, errStop) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
n.next = r
|
||||||
|
return b.tryCollapseExtension(n), err
|
||||||
|
default:
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Billet) tryCollapseLeaf(curr *LeafNode) Node {
|
||||||
|
// Leaf can always be collapsed.
|
||||||
|
res := NewHashNode(curr.Hash())
|
||||||
|
res.Collapsed = true
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Billet) tryCollapseExtension(curr *ExtensionNode) Node {
|
||||||
|
if !(curr.next.Type() == HashT && curr.next.(*HashNode).Collapsed) {
|
||||||
|
return curr
|
||||||
|
}
|
||||||
|
res := NewHashNode(curr.Hash())
|
||||||
|
res.Collapsed = true
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Billet) tryCollapseBranch(curr *BranchNode) Node {
|
||||||
|
canCollapse := true
|
||||||
|
for i := 0; i < childrenCount; i++ {
|
||||||
|
if curr.Children[i].Type() == EmptyT {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if curr.Children[i].Type() == HashT && curr.Children[i].(*HashNode).Collapsed {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
canCollapse = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !canCollapse {
|
||||||
|
return curr
|
||||||
|
}
|
||||||
|
res := NewHashNode(curr.Hash())
|
||||||
|
res.Collapsed = true
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFromStore returns MPT node from the storage.
|
||||||
|
func (b *Billet) GetFromStore(h util.Uint256) (Node, error) {
|
||||||
|
data, err := b.Store.Get(makeStorageKey(h.BytesBE()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var n NodeObject
|
||||||
|
r := io.NewBinReaderFromBuf(data)
|
||||||
|
n.DecodeBinary(r)
|
||||||
|
if r.Err != nil {
|
||||||
|
return nil, r.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.refcountEnabled {
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
}
|
||||||
|
n.Node.(flushedNode).setCache(data, h)
|
||||||
|
return n.Node, nil
|
||||||
|
}
|
211
pkg/core/mpt/billet_test.go
Normal file
211
pkg/core/mpt/billet_test.go
Normal file
|
@ -0,0 +1,211 @@
|
||||||
|
package mpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBillet_RestoreHashNode(t *testing.T) {
|
||||||
|
check := func(t *testing.T, tr *Billet, expectedRoot Node, expectedNode Node, expectedRefCount uint32) {
|
||||||
|
_ = expectedRoot.Hash()
|
||||||
|
_ = tr.root.Hash()
|
||||||
|
require.Equal(t, expectedRoot, tr.root)
|
||||||
|
expectedBytes, err := tr.Store.Get(makeStorageKey(expectedNode.Hash().BytesBE()))
|
||||||
|
if expectedRefCount != 0 {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, expectedRefCount, binary.LittleEndian.Uint32(expectedBytes[len(expectedBytes)-4:]))
|
||||||
|
} else {
|
||||||
|
require.True(t, errors.Is(err, storage.ErrKeyNotFound))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("parent is Extension", func(t *testing.T) {
|
||||||
|
t.Run("restore Branch", func(t *testing.T) {
|
||||||
|
b := NewBranchNode()
|
||||||
|
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0xAB, 0xCD}))
|
||||||
|
b.Children[5] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0xAB, 0xDE}))
|
||||||
|
path := toNibbles([]byte{0xAC})
|
||||||
|
e := NewExtensionNode(path, NewHashNode(b.Hash()))
|
||||||
|
tr := NewBillet(e.Hash(), true, newTestStore())
|
||||||
|
tr.root = e
|
||||||
|
|
||||||
|
// OK
|
||||||
|
n := new(NodeObject)
|
||||||
|
n.DecodeBinary(io.NewBinReaderFromBuf(b.Bytes()))
|
||||||
|
require.NoError(t, tr.RestoreHashNode(path, n.Node))
|
||||||
|
expected := NewExtensionNode(path, n.Node)
|
||||||
|
check(t, tr, expected, n.Node, 1)
|
||||||
|
|
||||||
|
// One more time (already restored) => panic expected, no refcount changes
|
||||||
|
require.Panics(t, func() {
|
||||||
|
_ = tr.RestoreHashNode(path, n.Node)
|
||||||
|
})
|
||||||
|
check(t, tr, expected, n.Node, 1)
|
||||||
|
|
||||||
|
// Same path, but wrong hash => error expected, no refcount changes
|
||||||
|
require.True(t, errors.Is(tr.RestoreHashNode(path, NewBranchNode()), ErrRestoreFailed))
|
||||||
|
check(t, tr, expected, n.Node, 1)
|
||||||
|
|
||||||
|
// New path (changes in the MPT structure are not allowed) => error expected, no refcount changes
|
||||||
|
require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAB}), n.Node), ErrRestoreFailed))
|
||||||
|
check(t, tr, expected, n.Node, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("restore Leaf", func(t *testing.T) {
|
||||||
|
l := NewLeafNode([]byte{0xAB, 0xCD})
|
||||||
|
path := toNibbles([]byte{0xAC})
|
||||||
|
e := NewExtensionNode(path, NewHashNode(l.Hash()))
|
||||||
|
tr := NewBillet(e.Hash(), true, newTestStore())
|
||||||
|
tr.root = e
|
||||||
|
|
||||||
|
// OK
|
||||||
|
require.NoError(t, tr.RestoreHashNode(path, l))
|
||||||
|
expected := NewHashNode(e.Hash()) // leaf should be collapsed immediately => extension should also be collapsed
|
||||||
|
expected.Collapsed = true
|
||||||
|
check(t, tr, expected, l, 1)
|
||||||
|
|
||||||
|
// One more time (already restored and collapsed) => error expected, no refcount changes
|
||||||
|
require.Error(t, tr.RestoreHashNode(path, l))
|
||||||
|
check(t, tr, expected, l, 1)
|
||||||
|
|
||||||
|
// Same path, but wrong hash => error expected, no refcount changes
|
||||||
|
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
|
||||||
|
check(t, tr, expected, l, 1)
|
||||||
|
|
||||||
|
// New path (changes in the MPT structure are not allowed) => error expected, no refcount changes
|
||||||
|
require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAB}), l), ErrRestoreFailed))
|
||||||
|
check(t, tr, expected, l, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("restore Hash", func(t *testing.T) {
|
||||||
|
h := NewHashNode(util.Uint256{1, 2, 3})
|
||||||
|
path := toNibbles([]byte{0xAC})
|
||||||
|
e := NewExtensionNode(path, h)
|
||||||
|
tr := NewBillet(e.Hash(), true, newTestStore())
|
||||||
|
tr.root = e
|
||||||
|
|
||||||
|
// no-op
|
||||||
|
require.True(t, errors.Is(tr.RestoreHashNode(path, h), ErrRestoreFailed))
|
||||||
|
check(t, tr, e, h, 0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("parent is Leaf", func(t *testing.T) {
|
||||||
|
l := NewLeafNode([]byte{0xAB, 0xCD})
|
||||||
|
path := []byte{}
|
||||||
|
tr := NewBillet(l.Hash(), true, newTestStore())
|
||||||
|
tr.root = l
|
||||||
|
|
||||||
|
// Already restored => panic expected
|
||||||
|
require.Panics(t, func() {
|
||||||
|
_ = tr.RestoreHashNode(path, l)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Same path, but wrong hash => error expected, no refcount changes
|
||||||
|
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
|
||||||
|
|
||||||
|
// Non-nil path, but MPT structure can't be changed => error expected, no refcount changes
|
||||||
|
require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAC}), NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("parent is Branch", func(t *testing.T) {
|
||||||
|
t.Run("middle child", func(t *testing.T) {
|
||||||
|
l1 := NewLeafNode([]byte{0xAB, 0xCD})
|
||||||
|
l2 := NewLeafNode([]byte{0xAB, 0xDE})
|
||||||
|
b := NewBranchNode()
|
||||||
|
b.Children[5] = NewHashNode(l1.Hash())
|
||||||
|
b.Children[lastChild] = NewHashNode(l2.Hash())
|
||||||
|
tr := NewBillet(b.Hash(), true, newTestStore())
|
||||||
|
tr.root = b
|
||||||
|
|
||||||
|
// OK
|
||||||
|
path := []byte{0x05}
|
||||||
|
require.NoError(t, tr.RestoreHashNode(path, l1))
|
||||||
|
check(t, tr, b, l1, 1)
|
||||||
|
|
||||||
|
// One more time (already restored) => panic expected.
|
||||||
|
// It's an MPT pool duty to avoid such situations during real restore process.
|
||||||
|
require.Panics(t, func() {
|
||||||
|
_ = tr.RestoreHashNode(path, l1)
|
||||||
|
})
|
||||||
|
// No refcount changes expected.
|
||||||
|
check(t, tr, b, l1, 1)
|
||||||
|
|
||||||
|
// Same path, but wrong hash => error expected, no refcount changes
|
||||||
|
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAD})), ErrRestoreFailed))
|
||||||
|
check(t, tr, b, l1, 1)
|
||||||
|
|
||||||
|
// New path pointing to the empty HashNode (changes in the MPT structure are not allowed) => error expected, no refcount changes
|
||||||
|
require.True(t, errors.Is(tr.RestoreHashNode([]byte{0x01}, l1), ErrRestoreFailed))
|
||||||
|
check(t, tr, b, l1, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("last child", func(t *testing.T) {
|
||||||
|
l1 := NewLeafNode([]byte{0xAB, 0xCD})
|
||||||
|
l2 := NewLeafNode([]byte{0xAB, 0xDE})
|
||||||
|
b := NewBranchNode()
|
||||||
|
b.Children[5] = NewHashNode(l1.Hash())
|
||||||
|
b.Children[lastChild] = NewHashNode(l2.Hash())
|
||||||
|
tr := NewBillet(b.Hash(), true, newTestStore())
|
||||||
|
tr.root = b
|
||||||
|
|
||||||
|
// OK
|
||||||
|
path := []byte{}
|
||||||
|
require.NoError(t, tr.RestoreHashNode(path, l2))
|
||||||
|
check(t, tr, b, l2, 1)
|
||||||
|
|
||||||
|
// One more time (already restored) => panic expected.
|
||||||
|
// It's an MPT pool duty to avoid such situations during real restore process.
|
||||||
|
require.Panics(t, func() {
|
||||||
|
_ = tr.RestoreHashNode(path, l2)
|
||||||
|
})
|
||||||
|
// No refcount changes expected.
|
||||||
|
check(t, tr, b, l2, 1)
|
||||||
|
|
||||||
|
// Same path, but wrong hash => error expected, no refcount changes
|
||||||
|
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAD})), ErrRestoreFailed))
|
||||||
|
check(t, tr, b, l2, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("two children with same hash", func(t *testing.T) {
|
||||||
|
l := NewLeafNode([]byte{0xAB, 0xCD})
|
||||||
|
b := NewBranchNode()
|
||||||
|
// two same hashnodes => leaf's refcount expected to be 2 in the end.
|
||||||
|
b.Children[3] = NewHashNode(l.Hash())
|
||||||
|
b.Children[4] = NewHashNode(l.Hash())
|
||||||
|
tr := NewBillet(b.Hash(), true, newTestStore())
|
||||||
|
tr.root = b
|
||||||
|
|
||||||
|
// OK
|
||||||
|
require.NoError(t, tr.RestoreHashNode([]byte{0x03}, l))
|
||||||
|
expected := b
|
||||||
|
expected.Children[3].(*HashNode).Collapsed = true
|
||||||
|
check(t, tr, b, l, 1)
|
||||||
|
|
||||||
|
// Restore another node with the same hash => no error expected, refcount should be incremented.
|
||||||
|
// Branch node should be collapsed.
|
||||||
|
require.NoError(t, tr.RestoreHashNode([]byte{0x04}, l))
|
||||||
|
res := NewHashNode(b.Hash())
|
||||||
|
res.Collapsed = true
|
||||||
|
check(t, tr, res, l, 2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("parent is Hash", func(t *testing.T) {
|
||||||
|
l := NewLeafNode([]byte{0xAB, 0xCD})
|
||||||
|
b := NewBranchNode()
|
||||||
|
b.Children[3] = NewHashNode(l.Hash())
|
||||||
|
b.Children[4] = NewHashNode(l.Hash())
|
||||||
|
tr := NewBillet(b.Hash(), true, newTestStore())
|
||||||
|
|
||||||
|
// Should fail, because if it's a hash node with non-empty path, then the node
|
||||||
|
// has already been collapsed.
|
||||||
|
require.Error(t, tr.RestoreHashNode([]byte{0x03}, l))
|
||||||
|
})
|
||||||
|
}
|
|
@ -89,6 +89,12 @@ func (b *BranchNode) UnmarshalJSON(data []byte) error {
|
||||||
return errors.New("expected branch node")
|
return errors.New("expected branch node")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clone implements Node interface.
|
||||||
|
func (b *BranchNode) Clone() Node {
|
||||||
|
res := *b
|
||||||
|
return &res
|
||||||
|
}
|
||||||
|
|
||||||
// splitPath splits path for a branch node.
|
// splitPath splits path for a branch node.
|
||||||
func splitPath(path []byte) (byte, []byte) {
|
func splitPath(path []byte) (byte, []byte) {
|
||||||
if len(path) != 0 {
|
if len(path) != 0 {
|
||||||
|
|
|
@ -54,3 +54,6 @@ func (e EmptyNode) Type() NodeType {
|
||||||
func (e EmptyNode) Bytes() []byte {
|
func (e EmptyNode) Bytes() []byte {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clone implements Node interface.
|
||||||
|
func (EmptyNode) Clone() Node { return EmptyNode{} }
|
||||||
|
|
|
@ -98,3 +98,9 @@ func (e *ExtensionNode) UnmarshalJSON(data []byte) error {
|
||||||
}
|
}
|
||||||
return errors.New("expected extension node")
|
return errors.New("expected extension node")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clone implements Node interface.
|
||||||
|
func (e *ExtensionNode) Clone() Node {
|
||||||
|
res := *e
|
||||||
|
return &res
|
||||||
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
// HashNode represents MPT's hash node.
|
// HashNode represents MPT's hash node.
|
||||||
type HashNode struct {
|
type HashNode struct {
|
||||||
BaseNode
|
BaseNode
|
||||||
|
Collapsed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Node = (*HashNode)(nil)
|
var _ Node = (*HashNode)(nil)
|
||||||
|
@ -76,3 +77,10 @@ func (h *HashNode) UnmarshalJSON(data []byte) error {
|
||||||
}
|
}
|
||||||
return errors.New("expected hash node")
|
return errors.New("expected hash node")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clone implements Node interface.
|
||||||
|
func (h *HashNode) Clone() Node {
|
||||||
|
res := *h
|
||||||
|
res.Collapsed = false
|
||||||
|
return &res
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package mpt
|
package mpt
|
||||||
|
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
|
||||||
// lcp returns longest common prefix of a and b.
|
// lcp returns longest common prefix of a and b.
|
||||||
// Note: it does no allocations.
|
// Note: it does no allocations.
|
||||||
func lcp(a, b []byte) []byte {
|
func lcp(a, b []byte) []byte {
|
||||||
|
@ -40,3 +42,45 @@ func toNibbles(path []byte) []byte {
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fromNibbles performs operation opposite to toNibbles and does no path validity checks.
|
||||||
|
func fromNibbles(path []byte) []byte {
|
||||||
|
result := make([]byte, len(path)/2)
|
||||||
|
for i := range result {
|
||||||
|
result[i] = path[2*i]<<4 + path[2*i+1]
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChildrenPaths returns a set of paths to node's children who are non-empty HashNodes
|
||||||
|
// based on the node's path.
|
||||||
|
func GetChildrenPaths(path []byte, node Node) map[util.Uint256][][]byte {
|
||||||
|
res := make(map[util.Uint256][][]byte)
|
||||||
|
switch n := node.(type) {
|
||||||
|
case *LeafNode, *HashNode, EmptyNode:
|
||||||
|
return nil
|
||||||
|
case *BranchNode:
|
||||||
|
for i, child := range n.Children {
|
||||||
|
if child.Type() == HashT {
|
||||||
|
cPath := make([]byte, len(path), len(path)+1)
|
||||||
|
copy(cPath, path)
|
||||||
|
if i != lastChild {
|
||||||
|
cPath = append(cPath, byte(i))
|
||||||
|
}
|
||||||
|
paths := res[child.Hash()]
|
||||||
|
paths = append(paths, cPath)
|
||||||
|
res[child.Hash()] = paths
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *ExtensionNode:
|
||||||
|
if n.next.Type() == HashT {
|
||||||
|
cPath := make([]byte, len(path)+len(n.key))
|
||||||
|
copy(cPath, path)
|
||||||
|
copy(cPath[len(path):], n.key)
|
||||||
|
res[n.next.Hash()] = [][]byte{cPath}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("unknown Node type")
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
67
pkg/core/mpt/helpers_test.go
Normal file
67
pkg/core/mpt/helpers_test.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package mpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestToNibblesFromNibbles(t *testing.T) {
|
||||||
|
check := func(t *testing.T, expected []byte) {
|
||||||
|
actual := fromNibbles(toNibbles(expected))
|
||||||
|
require.Equal(t, expected, actual)
|
||||||
|
}
|
||||||
|
t.Run("empty path", func(t *testing.T) {
|
||||||
|
check(t, []byte{})
|
||||||
|
})
|
||||||
|
t.Run("non-empty path", func(t *testing.T) {
|
||||||
|
check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetChildrenPaths(t *testing.T) {
|
||||||
|
h1 := NewHashNode(util.Uint256{1, 2, 3})
|
||||||
|
h2 := NewHashNode(util.Uint256{4, 5, 6})
|
||||||
|
h3 := NewHashNode(util.Uint256{7, 8, 9})
|
||||||
|
l := NewLeafNode([]byte{1, 2, 3})
|
||||||
|
ext1 := NewExtensionNode([]byte{8, 9}, h1)
|
||||||
|
ext2 := NewExtensionNode([]byte{7, 6}, l)
|
||||||
|
branch := NewBranchNode()
|
||||||
|
branch.Children[3] = h1
|
||||||
|
branch.Children[5] = l
|
||||||
|
branch.Children[6] = h1 // 3-th and 6-th children have the same hash
|
||||||
|
branch.Children[7] = h3
|
||||||
|
branch.Children[lastChild] = h2
|
||||||
|
testCases := map[string]struct {
|
||||||
|
node Node
|
||||||
|
expected map[util.Uint256][][]byte
|
||||||
|
}{
|
||||||
|
"Hash": {h1, nil},
|
||||||
|
"Leaf": {l, nil},
|
||||||
|
"Extension with next Hash": {ext1, map[util.Uint256][][]byte{h1.Hash(): {ext1.key}}},
|
||||||
|
"Extension with next non-Hash": {ext2, map[util.Uint256][][]byte{}},
|
||||||
|
"Branch": {branch, map[util.Uint256][][]byte{
|
||||||
|
h1.Hash(): {{0x03}, {0x06}},
|
||||||
|
h2.Hash(): {{}},
|
||||||
|
h3.Hash(): {{0x07}},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
parentPath := []byte{4, 5, 6}
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
require.Equal(t, testCase.expected, GetChildrenPaths([]byte{}, testCase.node))
|
||||||
|
if testCase.expected != nil {
|
||||||
|
expectedWithPrefix := make(map[util.Uint256][][]byte, len(testCase.expected))
|
||||||
|
for h, paths := range testCase.expected {
|
||||||
|
var res [][]byte
|
||||||
|
for _, path := range paths {
|
||||||
|
res = append(res, append(parentPath, path...))
|
||||||
|
}
|
||||||
|
expectedWithPrefix[h] = res
|
||||||
|
}
|
||||||
|
require.Equal(t, expectedWithPrefix, GetChildrenPaths(parentPath, testCase.node))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -77,3 +77,9 @@ func (n *LeafNode) UnmarshalJSON(data []byte) error {
|
||||||
}
|
}
|
||||||
return errors.New("expected leaf node")
|
return errors.New("expected leaf node")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clone implements Node interface.
|
||||||
|
func (n *LeafNode) Clone() Node {
|
||||||
|
res := *n
|
||||||
|
return &res
|
||||||
|
}
|
||||||
|
|
|
@ -34,6 +34,7 @@ type Node interface {
|
||||||
json.Marshaler
|
json.Marshaler
|
||||||
json.Unmarshaler
|
json.Unmarshaler
|
||||||
Size() int
|
Size() int
|
||||||
|
Clone() Node
|
||||||
BaseNodeIface
|
BaseNodeIface
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newProofTrie(t *testing.T) *Trie {
|
func newProofTrie(t *testing.T, missingHashNode bool) *Trie {
|
||||||
l := NewLeafNode([]byte("somevalue"))
|
l := NewLeafNode([]byte("somevalue"))
|
||||||
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
|
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
|
||||||
l2 := NewLeafNode([]byte("invalid"))
|
l2 := NewLeafNode([]byte("invalid"))
|
||||||
|
@ -20,11 +20,14 @@ func newProofTrie(t *testing.T) *Trie {
|
||||||
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
|
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
|
||||||
tr.putToStore(l)
|
tr.putToStore(l)
|
||||||
tr.putToStore(e)
|
tr.putToStore(e)
|
||||||
|
if !missingHashNode {
|
||||||
|
tr.putToStore(l2)
|
||||||
|
}
|
||||||
return tr
|
return tr
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrie_GetProof(t *testing.T) {
|
func TestTrie_GetProof(t *testing.T) {
|
||||||
tr := newProofTrie(t)
|
tr := newProofTrie(t, true)
|
||||||
|
|
||||||
t.Run("MissingKey", func(t *testing.T) {
|
t.Run("MissingKey", func(t *testing.T) {
|
||||||
_, err := tr.GetProof([]byte{0x12})
|
_, err := tr.GetProof([]byte{0x12})
|
||||||
|
@ -43,7 +46,7 @@ func TestTrie_GetProof(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVerifyProof(t *testing.T) {
|
func TestVerifyProof(t *testing.T) {
|
||||||
tr := newProofTrie(t)
|
tr := newProofTrie(t, true)
|
||||||
|
|
||||||
t.Run("Simple", func(t *testing.T) {
|
t.Run("Simple", func(t *testing.T) {
|
||||||
proof, err := tr.GetProof([]byte{0x12, 0x32})
|
proof, err := tr.GetProof([]byte{0x12, 0x32})
|
||||||
|
|
|
@ -353,3 +353,8 @@ func (s *Designate) getRole(item stackitem.Item) (noderoles.Role, bool) {
|
||||||
u := bi.Uint64()
|
u := bi.Uint64()
|
||||||
return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u))
|
return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitializeCache invalidates native Designate cache.
|
||||||
|
func (s *Designate) InitializeCache() {
|
||||||
|
s.rolesChangedFlag.Store(true)
|
||||||
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
@ -114,6 +115,66 @@ func (s *Module) Init(height uint32, enableRefCount bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CleanStorage removes all MPT-related data from the storage (MPT nodes, validated stateroots)
|
||||||
|
// except local stateroot for the current height and GC flag. This method is aimed to clean
|
||||||
|
// outdated MPT data before state sync process can be started.
|
||||||
|
// Note: this method is aimed to be called for genesis block only, an error is returned otherwice.
|
||||||
|
func (s *Module) CleanStorage() error {
|
||||||
|
if s.localHeight.Load() != 0 {
|
||||||
|
return fmt.Errorf("can't clean MPT data for non-genesis block: expected local stateroot height 0, got %d", s.localHeight.Load())
|
||||||
|
}
|
||||||
|
gcKey := []byte{byte(storage.DataMPT), prefixGC}
|
||||||
|
gcVal, err := s.Store.Get(gcKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get GC flag: %w", err)
|
||||||
|
}
|
||||||
|
//
|
||||||
|
b := s.Store.Batch()
|
||||||
|
s.Store.Seek([]byte{byte(storage.DataMPT)}, func(k, _ []byte) {
|
||||||
|
// Must copy here, #1468.
|
||||||
|
key := slice.Copy(k)
|
||||||
|
b.Delete(key)
|
||||||
|
})
|
||||||
|
err = s.Store.PutBatch(b)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove outdated MPT-reated items: %w", err)
|
||||||
|
}
|
||||||
|
err = s.Store.Put(gcKey, gcVal)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store GC flag: %w", err)
|
||||||
|
}
|
||||||
|
currentLocal := s.currentLocal.Load().(util.Uint256)
|
||||||
|
if !currentLocal.Equals(util.Uint256{}) {
|
||||||
|
err := s.addLocalStateRoot(s.Store, &state.MPTRoot{
|
||||||
|
Index: s.localHeight.Load(),
|
||||||
|
Root: currentLocal,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store current local stateroot: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// JumpToState performs jump to the state specified by given stateroot index.
|
||||||
|
func (s *Module) JumpToState(sr *state.MPTRoot, enableRefCount bool) error {
|
||||||
|
if err := s.addLocalStateRoot(s.Store, sr); err != nil {
|
||||||
|
return fmt.Errorf("failed to store local state root: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]byte, 4)
|
||||||
|
binary.LittleEndian.PutUint32(data, sr.Index)
|
||||||
|
if err := s.Store.Put([]byte{byte(storage.DataMPT), prefixValidated}, data); err != nil {
|
||||||
|
return fmt.Errorf("failed to store validated height: %w", err)
|
||||||
|
}
|
||||||
|
s.validatedHeight.Store(sr.Index)
|
||||||
|
|
||||||
|
s.currentLocal.Store(sr.Root)
|
||||||
|
s.localHeight.Store(sr.Index)
|
||||||
|
s.mpt = mpt.NewTrie(mpt.NewHashNode(sr.Root), enableRefCount, s.Store)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddMPTBatch updates using provided batch.
|
// AddMPTBatch updates using provided batch.
|
||||||
func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) {
|
func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) {
|
||||||
mpt := *s.mpt
|
mpt := *s.mpt
|
||||||
|
|
479
pkg/core/statesync/module.go
Normal file
479
pkg/core/statesync/module.go
Normal file
|
@ -0,0 +1,479 @@
|
||||||
|
/*
|
||||||
|
Package statesync implements module for the P2P state synchronisation process. The
|
||||||
|
module manages state synchronisation for non-archival nodes which are joining the
|
||||||
|
network and don't have the ability to resync from the genesis block.
|
||||||
|
|
||||||
|
Given the currently available state synchronisation point P, sate sync process
|
||||||
|
includes the following stages:
|
||||||
|
|
||||||
|
1. Fetching headers starting from height 0 up to P+1.
|
||||||
|
2. Fetching MPT nodes for height P stating from the corresponding state root.
|
||||||
|
3. Fetching blocks starting from height P-MaxTraceableBlocks (or 0) up to P.
|
||||||
|
|
||||||
|
Steps 2 and 3 are being performed in parallel. Once all the data are collected
|
||||||
|
and stored in the db, an atomic state jump is occurred to the state sync point P.
|
||||||
|
Further node operation process is performed using standard sync mechanism until
|
||||||
|
the node reaches synchronised state.
|
||||||
|
*/
|
||||||
|
package statesync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stateSyncStage is a type of state synchronisation stage.
|
||||||
|
type stateSyncStage uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// inactive means that state exchange is disabled by the protocol configuration.
|
||||||
|
// Can't be combined with other states.
|
||||||
|
inactive stateSyncStage = 1 << iota
|
||||||
|
// none means that state exchange is enabled in the configuration, but
|
||||||
|
// initialisation of the state sync module wasn't yet performed, i.e.
|
||||||
|
// (*Module).Init wasn't called. Can't be combined with other states.
|
||||||
|
none
|
||||||
|
// initialized means that (*Module).Init was called, but other sync stages
|
||||||
|
// are not yet reached (i.e. that headers are requested, but not yet fetched).
|
||||||
|
// Can't be combined with other states.
|
||||||
|
initialized
|
||||||
|
// headersSynced means that headers for the current state sync point are fetched.
|
||||||
|
// May be combined with mptSynced and/or blocksSynced.
|
||||||
|
headersSynced
|
||||||
|
// mptSynced means that MPT nodes for the current state sync point are fetched.
|
||||||
|
// Always combined with headersSynced; may be combined with blocksSynced.
|
||||||
|
mptSynced
|
||||||
|
// blocksSynced means that blocks up to the current state sync point are stored.
|
||||||
|
// Always combined with headersSynced; may be combined with mptSynced.
|
||||||
|
blocksSynced
|
||||||
|
)
|
||||||
|
|
||||||
|
// Module represents state sync module and aimed to gather state-related data to
|
||||||
|
// perform an atomic state jump.
|
||||||
|
type Module struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
log *zap.Logger
|
||||||
|
|
||||||
|
// syncPoint is the state synchronisation point P we're currently working against.
|
||||||
|
syncPoint uint32
|
||||||
|
// syncStage is the stage of the sync process.
|
||||||
|
syncStage stateSyncStage
|
||||||
|
// syncInterval is the delta between two adjacent state sync points.
|
||||||
|
syncInterval uint32
|
||||||
|
// blockHeight is the index of the latest stored block.
|
||||||
|
blockHeight uint32
|
||||||
|
|
||||||
|
dao *dao.Simple
|
||||||
|
bc blockchainer.Blockchainer
|
||||||
|
mptpool *Pool
|
||||||
|
|
||||||
|
billet *mpt.Billet
|
||||||
|
|
||||||
|
jumpCallback func(p uint32) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewModule returns new instance of statesync module.
|
||||||
|
func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple, jumpCallback func(p uint32) error) *Module {
|
||||||
|
if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) {
|
||||||
|
return &Module{
|
||||||
|
dao: s,
|
||||||
|
bc: bc,
|
||||||
|
syncStage: inactive,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Module{
|
||||||
|
dao: s,
|
||||||
|
bc: bc,
|
||||||
|
log: log,
|
||||||
|
syncInterval: uint32(bc.GetConfig().StateSyncInterval),
|
||||||
|
mptpool: NewPool(),
|
||||||
|
syncStage: none,
|
||||||
|
jumpCallback: jumpCallback,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init initializes state sync module for the current chain's height with given
|
||||||
|
// callback for MPT nodes requests.
|
||||||
|
func (s *Module) Init(currChainHeight uint32) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
if s.syncStage != none {
|
||||||
|
return errors.New("already initialized or inactive")
|
||||||
|
}
|
||||||
|
|
||||||
|
p := (currChainHeight / s.syncInterval) * s.syncInterval
|
||||||
|
if p < 2*s.syncInterval {
|
||||||
|
// chain is too low to start state exchange process, use the standard sync mechanism
|
||||||
|
s.syncStage = inactive
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pOld, err := s.dao.GetStateSyncPoint()
|
||||||
|
if err == nil && pOld >= p-s.syncInterval {
|
||||||
|
// old point is still valid, so try to resync states for this point.
|
||||||
|
p = pOld
|
||||||
|
} else {
|
||||||
|
if s.bc.BlockHeight() > p-2*s.syncInterval {
|
||||||
|
// chain has already been synchronised up to old state sync point and regular blocks processing was started.
|
||||||
|
// Current block height is enough to start regular blocks processing.
|
||||||
|
s.syncStage = inactive
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
// pOld was found, it is outdated, and chain wasn't completely synchronised for pOld. Need to drop the db.
|
||||||
|
return fmt.Errorf("state sync point %d is found in the storage, "+
|
||||||
|
"but sync process wasn't completed and point is outdated. Please, drop the database manually and restart the node to run state sync process", pOld)
|
||||||
|
}
|
||||||
|
if s.bc.BlockHeight() != 0 {
|
||||||
|
// pOld wasn't found, but blocks processing was started in a regular manner and latest stored block is too outdated
|
||||||
|
// to start regular blocks processing again. Need to drop the db.
|
||||||
|
return fmt.Errorf("current chain's height is too low to start regular blocks processing from the oldest sync point %d. "+
|
||||||
|
"Please, drop the database manually and restart the node to run state sync process", p-s.syncInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We've reached this point, so chain has genesis block only. As far as we can't ruin
|
||||||
|
// current chain's state until new state is completely fetched, outdated state-related data
|
||||||
|
// will be removed from storage during (*Blockchain).jumpToState(...) execution.
|
||||||
|
// All we need to do right now is to remove genesis-related MPT nodes.
|
||||||
|
err = s.bc.GetStateModule().CleanStorage()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove outdated MPT data from storage: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.syncPoint = p
|
||||||
|
err = s.dao.PutStateSyncPoint(p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store state synchronisation point %d: %w", p, err)
|
||||||
|
}
|
||||||
|
s.syncStage = initialized
|
||||||
|
s.log.Info("try to sync state for the latest state synchronisation point",
|
||||||
|
zap.Uint32("point", p),
|
||||||
|
zap.Uint32("evaluated chain's blockHeight", currChainHeight))
|
||||||
|
|
||||||
|
return s.defineSyncStage()
|
||||||
|
}
|
||||||
|
|
||||||
|
// defineSyncStage sequentially checks and sets sync state process stage after Module
|
||||||
|
// initialization. It also performs initialization of MPT Billet if necessary.
|
||||||
|
func (s *Module) defineSyncStage() error {
|
||||||
|
// check headers sync stage first
|
||||||
|
ltstHeaderHeight := s.bc.HeaderHeight()
|
||||||
|
if ltstHeaderHeight > s.syncPoint {
|
||||||
|
s.syncStage = headersSynced
|
||||||
|
s.log.Info("headers are in sync",
|
||||||
|
zap.Uint32("headerHeight", s.bc.HeaderHeight()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// check blocks sync stage
|
||||||
|
s.blockHeight = s.getLatestSavedBlock(s.syncPoint)
|
||||||
|
if s.blockHeight >= s.syncPoint {
|
||||||
|
s.syncStage |= blocksSynced
|
||||||
|
s.log.Info("blocks are in sync",
|
||||||
|
zap.Uint32("blockHeight", s.blockHeight))
|
||||||
|
}
|
||||||
|
|
||||||
|
// check MPT sync stage
|
||||||
|
if s.blockHeight > s.syncPoint {
|
||||||
|
s.syncStage |= mptSynced
|
||||||
|
s.log.Info("MPT is in sync",
|
||||||
|
zap.Uint32("stateroot height", s.bc.GetStateModule().CurrentLocalHeight()))
|
||||||
|
} else if s.syncStage&headersSynced != 0 {
|
||||||
|
header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint + 1)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get header to initialize MPT billet: %w", err)
|
||||||
|
}
|
||||||
|
s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store)
|
||||||
|
s.log.Info("MPT billet initialized",
|
||||||
|
zap.Uint32("height", s.syncPoint),
|
||||||
|
zap.String("state root", header.PrevStateRoot.StringBE()))
|
||||||
|
pool := NewPool()
|
||||||
|
pool.Add(header.PrevStateRoot, []byte{})
|
||||||
|
err = s.billet.Traverse(func(n mpt.Node, _ []byte) bool {
|
||||||
|
nPaths, ok := pool.TryGet(n.Hash())
|
||||||
|
if !ok {
|
||||||
|
// if this situation occurs, then it's a bug in MPT pool or Traverse.
|
||||||
|
panic("failed to get MPT node from the pool")
|
||||||
|
}
|
||||||
|
pool.Remove(n.Hash())
|
||||||
|
childrenPaths := make(map[util.Uint256][][]byte)
|
||||||
|
for _, path := range nPaths {
|
||||||
|
nChildrenPaths := mpt.GetChildrenPaths(path, n)
|
||||||
|
for hash, paths := range nChildrenPaths {
|
||||||
|
childrenPaths[hash] = append(childrenPaths[hash], paths...) // it's OK to have duplicates, they'll be handled by mempool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pool.Update(nil, childrenPaths)
|
||||||
|
return false
|
||||||
|
}, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to traverse MPT during initialization: %w", err)
|
||||||
|
}
|
||||||
|
s.mptpool.Update(nil, pool.GetAll())
|
||||||
|
if s.mptpool.Count() == 0 {
|
||||||
|
s.syncStage |= mptSynced
|
||||||
|
s.log.Info("MPT is in sync",
|
||||||
|
zap.Uint32("stateroot height", s.syncPoint))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.syncStage == headersSynced|blocksSynced|mptSynced {
|
||||||
|
s.log.Info("state is in sync, starting regular blocks processing")
|
||||||
|
s.syncStage = inactive
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLatestSavedBlock returns either current block index (if it's still relevant
|
||||||
|
// to continue state sync process) or H-1 where H is the index of the earliest
|
||||||
|
// block that should be saved next.
|
||||||
|
func (s *Module) getLatestSavedBlock(p uint32) uint32 {
|
||||||
|
var result uint32
|
||||||
|
mtb := s.bc.GetConfig().MaxTraceableBlocks
|
||||||
|
if p > mtb {
|
||||||
|
result = p - mtb
|
||||||
|
}
|
||||||
|
storedH, err := s.dao.GetStateSyncCurrentBlockHeight()
|
||||||
|
if err == nil && storedH > result {
|
||||||
|
result = storedH
|
||||||
|
}
|
||||||
|
actualH := s.bc.BlockHeight()
|
||||||
|
if actualH > result {
|
||||||
|
result = actualH
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHeaders validates and adds specified headers to the chain.
|
||||||
|
func (s *Module) AddHeaders(hdrs ...*block.Header) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
if s.syncStage != initialized {
|
||||||
|
return errors.New("headers were not requested")
|
||||||
|
}
|
||||||
|
|
||||||
|
hdrsErr := s.bc.AddHeaders(hdrs...)
|
||||||
|
if s.bc.HeaderHeight() > s.syncPoint {
|
||||||
|
err := s.defineSyncStage()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to define current sync stage: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hdrsErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBlock verifies and saves block skipping executable scripts.
|
||||||
|
func (s *Module) AddBlock(block *block.Block) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
if s.syncStage&headersSynced == 0 || s.syncStage&blocksSynced != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.blockHeight == s.syncPoint {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
expectedHeight := s.blockHeight + 1
|
||||||
|
if expectedHeight != block.Index {
|
||||||
|
return fmt.Errorf("expected %d, got %d: invalid block index", expectedHeight, block.Index)
|
||||||
|
}
|
||||||
|
if s.bc.GetConfig().StateRootInHeader != block.StateRootEnabled {
|
||||||
|
return fmt.Errorf("stateroot setting mismatch: %v != %v", s.bc.GetConfig().StateRootInHeader, block.StateRootEnabled)
|
||||||
|
}
|
||||||
|
if s.bc.GetConfig().VerifyBlocks {
|
||||||
|
merkle := block.ComputeMerkleRoot()
|
||||||
|
if !block.MerkleRoot.Equals(merkle) {
|
||||||
|
return errors.New("invalid block: MerkleRoot mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cache := s.dao.GetWrapped()
|
||||||
|
writeBuf := io.NewBufBinWriter()
|
||||||
|
if err := cache.StoreAsBlock(block, writeBuf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
writeBuf.Reset()
|
||||||
|
|
||||||
|
err := cache.PutStateSyncCurrentBlockHeight(block.Index)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to store current block height: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tx := range block.Transactions {
|
||||||
|
if err := cache.StoreAsTransaction(tx, block.Index, writeBuf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
writeBuf.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = cache.Persist()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to persist results: %w", err)
|
||||||
|
}
|
||||||
|
s.blockHeight = block.Index
|
||||||
|
if s.blockHeight == s.syncPoint {
|
||||||
|
s.syncStage |= blocksSynced
|
||||||
|
s.log.Info("blocks are in sync",
|
||||||
|
zap.Uint32("blockHeight", s.blockHeight))
|
||||||
|
s.checkSyncIsCompleted()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMPTNodes tries to add provided set of MPT nodes to the MPT billet if they are
|
||||||
|
// not yet collected.
|
||||||
|
func (s *Module) AddMPTNodes(nodes [][]byte) error {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
if s.syncStage&headersSynced == 0 || s.syncStage&mptSynced != 0 {
|
||||||
|
return errors.New("MPT nodes were not requested")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, nBytes := range nodes {
|
||||||
|
var n mpt.NodeObject
|
||||||
|
r := io.NewBinReaderFromBuf(nBytes)
|
||||||
|
n.DecodeBinary(r)
|
||||||
|
if r.Err != nil {
|
||||||
|
return fmt.Errorf("failed to decode MPT node: %w", r.Err)
|
||||||
|
}
|
||||||
|
err := s.restoreNode(n.Node)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.mptpool.Count() == 0 {
|
||||||
|
s.syncStage |= mptSynced
|
||||||
|
s.log.Info("MPT is in sync",
|
||||||
|
zap.Uint32("height", s.syncPoint))
|
||||||
|
s.checkSyncIsCompleted()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Module) restoreNode(n mpt.Node) error {
|
||||||
|
nPaths, ok := s.mptpool.TryGet(n.Hash())
|
||||||
|
if !ok {
|
||||||
|
// it can easily happen after receiving the same data from different peers.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var childrenPaths = make(map[util.Uint256][][]byte)
|
||||||
|
for _, path := range nPaths {
|
||||||
|
// Must clone here in order to avoid future collapse collisions. If the node's refcount>1 then MPT pool
|
||||||
|
// will manage all paths for this node and call RestoreHashNode separately for each of the paths.
|
||||||
|
err := s.billet.RestoreHashNode(path, n.Clone())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to restore MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
|
||||||
|
}
|
||||||
|
for h, paths := range mpt.GetChildrenPaths(path, n) {
|
||||||
|
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
|
||||||
|
|
||||||
|
for h := range childrenPaths {
|
||||||
|
if child, err := s.billet.GetFromStore(h); err == nil {
|
||||||
|
// child is already in the storage, so we don't need to request it one more time.
|
||||||
|
err = s.restoreNode(child)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to restore saved children: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1
|
||||||
|
// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored.
|
||||||
|
// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller
|
||||||
|
// should take care of it.
|
||||||
|
func (s *Module) checkSyncIsCompleted() {
|
||||||
|
if s.syncStage != headersSynced|mptSynced|blocksSynced {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.log.Info("state is in sync",
|
||||||
|
zap.Uint32("state sync point", s.syncPoint))
|
||||||
|
err := s.jumpCallback(s.syncPoint)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Fatal("failed to jump to the latest state sync point", zap.Error(err))
|
||||||
|
}
|
||||||
|
s.syncStage = inactive
|
||||||
|
s.dispose()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Module) dispose() {
|
||||||
|
s.billet = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockHeight returns index of the last stored block.
|
||||||
|
func (s *Module) BlockHeight() uint32 {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return s.blockHeight
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsActive tells whether state sync module is on and still gathering state
|
||||||
|
// synchronisation data (headers, blocks or MPT nodes).
|
||||||
|
func (s *Module) IsActive() bool {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return !(s.syncStage == inactive || (s.syncStage == headersSynced|mptSynced|blocksSynced))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInitialized tells whether state sync module does not require initialization.
|
||||||
|
// If `false` is returned then Init can be safely called.
|
||||||
|
func (s *Module) IsInitialized() bool {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return s.syncStage != none
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedHeaders tells whether the module hasn't completed headers synchronisation.
|
||||||
|
func (s *Module) NeedHeaders() bool {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return s.syncStage == initialized
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedMPTNodes returns whether the module hasn't completed MPT synchronisation.
|
||||||
|
func (s *Module) NeedMPTNodes() bool {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return s.syncStage&headersSynced != 0 && s.syncStage&mptSynced == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse traverses local MPT nodes starting from the specified root down to its
|
||||||
|
// children calling `process` for each serialised node until stop condition is satisfied.
|
||||||
|
func (s *Module) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
b := mpt.NewBillet(root, s.bc.GetConfig().KeepOnlyLatestState, storage.NewMemCachedStore(s.dao.Store))
|
||||||
|
return b.Traverse(process, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnknownMPTNodesBatch returns set of currently unknown MPT nodes (`limit` at max).
|
||||||
|
func (s *Module) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
|
return s.mptpool.GetBatch(limit)
|
||||||
|
}
|
106
pkg/core/statesync/module_test.go
Normal file
106
pkg/core/statesync/module_test.go
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
package statesync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/zap/zaptest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModule_PR2019_discussion_r689629704(t *testing.T) {
|
||||||
|
expectedStorage := storage.NewMemCachedStore(storage.NewMemoryStore())
|
||||||
|
tr := mpt.NewTrie(nil, true, expectedStorage)
|
||||||
|
require.NoError(t, tr.Put([]byte{0x03}, []byte("leaf1")))
|
||||||
|
require.NoError(t, tr.Put([]byte{0x01, 0xab, 0x02}, []byte("leaf2")))
|
||||||
|
require.NoError(t, tr.Put([]byte{0x01, 0xab, 0x04}, []byte("leaf3")))
|
||||||
|
require.NoError(t, tr.Put([]byte{0x06, 0x01, 0xde, 0x02}, []byte("leaf2"))) // <-- the same `leaf2` and `leaf3` values are put in the storage,
|
||||||
|
require.NoError(t, tr.Put([]byte{0x06, 0x01, 0xde, 0x04}, []byte("leaf3"))) // <-- but the path should differ.
|
||||||
|
require.NoError(t, tr.Put([]byte{0x06, 0x03}, []byte("leaf4")))
|
||||||
|
|
||||||
|
sr := tr.StateRoot()
|
||||||
|
tr.Flush()
|
||||||
|
|
||||||
|
// Keep MPT nodes in a map in order not to repeat them. We'll use `nodes` map to ask
|
||||||
|
// state sync module to restore the nodes.
|
||||||
|
var (
|
||||||
|
nodes = make(map[util.Uint256][]byte)
|
||||||
|
expectedItems []storage.KeyValue
|
||||||
|
)
|
||||||
|
expectedStorage.Seek(storage.DataMPT.Bytes(), func(k, v []byte) {
|
||||||
|
key := slice.Copy(k)
|
||||||
|
value := slice.Copy(v)
|
||||||
|
expectedItems = append(expectedItems, storage.KeyValue{
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
})
|
||||||
|
hash, err := util.Uint256DecodeBytesBE(key[1:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
nodeBytes := value[:len(value)-4]
|
||||||
|
nodes[hash] = nodeBytes
|
||||||
|
})
|
||||||
|
|
||||||
|
actualStorage := storage.NewMemCachedStore(storage.NewMemoryStore())
|
||||||
|
// These actions are done in module.Init(), but it's not the point of the test.
|
||||||
|
// Here we want to test only MPT restoring process.
|
||||||
|
stateSync := &Module{
|
||||||
|
log: zaptest.NewLogger(t),
|
||||||
|
syncPoint: 1000500,
|
||||||
|
syncStage: headersSynced,
|
||||||
|
syncInterval: 100500,
|
||||||
|
dao: dao.NewSimple(actualStorage, true, false),
|
||||||
|
mptpool: NewPool(),
|
||||||
|
}
|
||||||
|
stateSync.billet = mpt.NewBillet(sr, true, actualStorage)
|
||||||
|
stateSync.mptpool.Add(sr, []byte{})
|
||||||
|
|
||||||
|
// The test itself: we'll ask state sync module to restore each node exactly once.
|
||||||
|
// After that storage content (including storage items and refcounts) must
|
||||||
|
// match exactly the one got from real MPT trie. MPT pool must be empty.
|
||||||
|
// State sync module must have mptSynced state in the end.
|
||||||
|
// MPT Billet root must become a collapsed hashnode (it was checked manually).
|
||||||
|
requested := make(map[util.Uint256]struct{})
|
||||||
|
for {
|
||||||
|
unknownHashes := stateSync.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
|
||||||
|
if len(unknownHashes) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
h := unknownHashes[0]
|
||||||
|
node, ok := nodes[h]
|
||||||
|
if !ok {
|
||||||
|
if _, ok = requested[h]; ok {
|
||||||
|
t.Fatal("node was requested twice")
|
||||||
|
}
|
||||||
|
t.Fatal("unknown node was requested")
|
||||||
|
}
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
err := stateSync.AddMPTNodes([][]byte{node})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}, fmt.Errorf("hash=%s, value=%s", h.StringBE(), string(node)))
|
||||||
|
requested[h] = struct{}{}
|
||||||
|
delete(nodes, h)
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Equal(t, headersSynced|mptSynced, stateSync.syncStage, "all nodes were sent exactly ones, but MPT wasn't restored")
|
||||||
|
require.Equal(t, 0, len(nodes), "not all nodes were requested by state sync module")
|
||||||
|
require.Equal(t, 0, stateSync.mptpool.Count(), "MPT was restored, but MPT pool still contains items")
|
||||||
|
|
||||||
|
// Compare resulting storage items and refcounts.
|
||||||
|
var actualItems []storage.KeyValue
|
||||||
|
expectedStorage.Seek(storage.DataMPT.Bytes(), func(k, v []byte) {
|
||||||
|
key := slice.Copy(k)
|
||||||
|
value := slice.Copy(v)
|
||||||
|
actualItems = append(actualItems, storage.KeyValue{
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
require.ElementsMatch(t, expectedItems, actualItems)
|
||||||
|
}
|
142
pkg/core/statesync/mptpool.go
Normal file
142
pkg/core/statesync/mptpool.go
Normal file
|
@ -0,0 +1,142 @@
|
||||||
|
package statesync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Pool stores unknown MPT nodes along with the corresponding paths (single node is
|
||||||
|
// allowed to have multiple MPT paths).
|
||||||
|
type Pool struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
hashes map[util.Uint256][][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPool returns new MPT node hashes pool.
|
||||||
|
func NewPool() *Pool {
|
||||||
|
return &Pool{
|
||||||
|
hashes: make(map[util.Uint256][][]byte),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContainsKey checks if MPT node with the specified hash is in the Pool.
|
||||||
|
func (mp *Pool) ContainsKey(hash util.Uint256) bool {
|
||||||
|
mp.lock.RLock()
|
||||||
|
defer mp.lock.RUnlock()
|
||||||
|
|
||||||
|
_, ok := mp.hashes[hash]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryGet returns a set of MPT paths for the specified HashNode.
|
||||||
|
func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) {
|
||||||
|
mp.lock.RLock()
|
||||||
|
defer mp.lock.RUnlock()
|
||||||
|
|
||||||
|
paths, ok := mp.hashes[hash]
|
||||||
|
// need to copy here, because we can modify existing array of paths inside the pool.
|
||||||
|
res := make([][]byte, len(paths))
|
||||||
|
copy(res, paths)
|
||||||
|
return res, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAll returns all MPT nodes with the corresponding paths from the pool.
|
||||||
|
func (mp *Pool) GetAll() map[util.Uint256][][]byte {
|
||||||
|
mp.lock.RLock()
|
||||||
|
defer mp.lock.RUnlock()
|
||||||
|
|
||||||
|
return mp.hashes
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBatch returns set of unknown MPT nodes hashes (`limit` at max).
|
||||||
|
func (mp *Pool) GetBatch(limit int) []util.Uint256 {
|
||||||
|
mp.lock.RLock()
|
||||||
|
defer mp.lock.RUnlock()
|
||||||
|
|
||||||
|
count := len(mp.hashes)
|
||||||
|
if count > limit {
|
||||||
|
count = limit
|
||||||
|
}
|
||||||
|
result := make([]util.Uint256, 0, limit)
|
||||||
|
for h := range mp.hashes {
|
||||||
|
if count == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result = append(result, h)
|
||||||
|
count--
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes MPT node from the pool by the specified hash.
|
||||||
|
func (mp *Pool) Remove(hash util.Uint256) {
|
||||||
|
mp.lock.Lock()
|
||||||
|
defer mp.lock.Unlock()
|
||||||
|
|
||||||
|
delete(mp.hashes, hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds path to the set of paths for the specified node.
|
||||||
|
func (mp *Pool) Add(hash util.Uint256, path []byte) {
|
||||||
|
mp.lock.Lock()
|
||||||
|
defer mp.lock.Unlock()
|
||||||
|
|
||||||
|
mp.addPaths(hash, [][]byte{path})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update is an atomic operation and removes/adds specified nodes from/to the pool.
|
||||||
|
func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) {
|
||||||
|
mp.lock.Lock()
|
||||||
|
defer mp.lock.Unlock()
|
||||||
|
|
||||||
|
for h, paths := range remove {
|
||||||
|
old := mp.hashes[h]
|
||||||
|
for _, path := range paths {
|
||||||
|
i := sort.Search(len(old), func(i int) bool {
|
||||||
|
return bytes.Compare(old[i], path) >= 0
|
||||||
|
})
|
||||||
|
if i < len(old) && bytes.Equal(old[i], path) {
|
||||||
|
old = append(old[:i], old[i+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(old) == 0 {
|
||||||
|
delete(mp.hashes, h)
|
||||||
|
} else {
|
||||||
|
mp.hashes[h] = old
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for h, paths := range add {
|
||||||
|
mp.addPaths(h, paths)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPaths adds set of the specified node paths to the pool.
|
||||||
|
func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) {
|
||||||
|
old := mp.hashes[nodeHash]
|
||||||
|
for _, path := range paths {
|
||||||
|
i := sort.Search(len(old), func(i int) bool {
|
||||||
|
return bytes.Compare(old[i], path) >= 0
|
||||||
|
})
|
||||||
|
if i < len(old) && bytes.Equal(old[i], path) {
|
||||||
|
// then path is already added
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
old = append(old, path)
|
||||||
|
if i != len(old)-1 {
|
||||||
|
copy(old[i+1:], old[i:])
|
||||||
|
old[i] = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mp.hashes[nodeHash] = old
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of nodes in the pool.
|
||||||
|
func (mp *Pool) Count() int {
|
||||||
|
mp.lock.RLock()
|
||||||
|
defer mp.lock.RUnlock()
|
||||||
|
|
||||||
|
return len(mp.hashes)
|
||||||
|
}
|
123
pkg/core/statesync/mptpool_test.go
Normal file
123
pkg/core/statesync/mptpool_test.go
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
package statesync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/internal/random"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPool_AddRemoveUpdate(t *testing.T) {
|
||||||
|
mp := NewPool()
|
||||||
|
|
||||||
|
i1 := []byte{1, 2, 3}
|
||||||
|
i1h := util.Uint256{1, 2, 3}
|
||||||
|
i2 := []byte{2, 3, 4}
|
||||||
|
i2h := util.Uint256{2, 3, 4}
|
||||||
|
i3 := []byte{4, 5, 6}
|
||||||
|
i3h := util.Uint256{3, 4, 5}
|
||||||
|
i4 := []byte{3, 4, 5} // has the same hash as i3
|
||||||
|
i5 := []byte{6, 7, 8} // has the same hash as i3
|
||||||
|
mapAll := map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}, i3h: {i4, i3}}
|
||||||
|
|
||||||
|
// No items
|
||||||
|
_, ok := mp.TryGet(i1h)
|
||||||
|
require.False(t, ok)
|
||||||
|
require.False(t, mp.ContainsKey(i1h))
|
||||||
|
require.Equal(t, 0, mp.Count())
|
||||||
|
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
|
||||||
|
|
||||||
|
// Add i1, i2, check OK
|
||||||
|
mp.Add(i1h, i1)
|
||||||
|
mp.Add(i2h, i2)
|
||||||
|
itm, ok := mp.TryGet(i1h)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, [][]byte{i1}, itm)
|
||||||
|
require.True(t, mp.ContainsKey(i1h))
|
||||||
|
require.True(t, mp.ContainsKey(i2h))
|
||||||
|
require.Equal(t, map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, mp.GetAll())
|
||||||
|
require.Equal(t, 2, mp.Count())
|
||||||
|
|
||||||
|
// Remove i1 and unexisting item
|
||||||
|
mp.Remove(i3h)
|
||||||
|
mp.Remove(i1h)
|
||||||
|
require.False(t, mp.ContainsKey(i1h))
|
||||||
|
require.True(t, mp.ContainsKey(i2h))
|
||||||
|
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}}, mp.GetAll())
|
||||||
|
require.Equal(t, 1, mp.Count())
|
||||||
|
|
||||||
|
// Update: remove nothing, add all
|
||||||
|
mp.Update(nil, mapAll)
|
||||||
|
require.Equal(t, mapAll, mp.GetAll())
|
||||||
|
require.Equal(t, 3, mp.Count())
|
||||||
|
// Update: remove all, add all
|
||||||
|
mp.Update(mapAll, mapAll)
|
||||||
|
require.Equal(t, mapAll, mp.GetAll()) // deletion first, addition after that
|
||||||
|
require.Equal(t, 3, mp.Count())
|
||||||
|
// Update: remove all, add nothing
|
||||||
|
mp.Update(mapAll, nil)
|
||||||
|
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
|
||||||
|
require.Equal(t, 0, mp.Count())
|
||||||
|
// Update: remove several, add several
|
||||||
|
mp.Update(map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}})
|
||||||
|
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
|
||||||
|
require.Equal(t, 2, mp.Count())
|
||||||
|
|
||||||
|
// Update: remove nothing, add several with same hashes
|
||||||
|
mp.Update(nil, map[util.Uint256][][]byte{i3h: {i5, i4}}) // should be sorted by the pool
|
||||||
|
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i3, i5}}, mp.GetAll())
|
||||||
|
require.Equal(t, 2, mp.Count())
|
||||||
|
// Update: remove several with same hashes, add nothing
|
||||||
|
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i4}}, nil)
|
||||||
|
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
|
||||||
|
require.Equal(t, 2, mp.Count())
|
||||||
|
// Update: remove several with same hashes, add several with same hashes
|
||||||
|
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i3}}, map[util.Uint256][][]byte{i3h: {i5, i4}})
|
||||||
|
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i5}}, mp.GetAll())
|
||||||
|
require.Equal(t, 2, mp.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPool_GetBatch(t *testing.T) {
|
||||||
|
check := func(t *testing.T, limit int, itemsCount int) {
|
||||||
|
mp := NewPool()
|
||||||
|
for i := 0; i < itemsCount; i++ {
|
||||||
|
mp.Add(random.Uint256(), []byte{0x01})
|
||||||
|
}
|
||||||
|
batch := mp.GetBatch(limit)
|
||||||
|
if limit < itemsCount {
|
||||||
|
require.Equal(t, limit, len(batch))
|
||||||
|
} else {
|
||||||
|
require.Equal(t, itemsCount, len(batch))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("limit less than items count", func(t *testing.T) {
|
||||||
|
check(t, 5, 6)
|
||||||
|
})
|
||||||
|
t.Run("limit more than items count", func(t *testing.T) {
|
||||||
|
check(t, 6, 5)
|
||||||
|
})
|
||||||
|
t.Run("items count limit", func(t *testing.T) {
|
||||||
|
check(t, 5, 5)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPool_UpdateUsingSliceFromPool(t *testing.T) {
|
||||||
|
mp := NewPool()
|
||||||
|
p1, _ := hex.DecodeString("0f0a0f0f0f0f0f0f0104020b02080c0a06050e070b050404060206060d07080602030b04040b050e040406030f0708060c05")
|
||||||
|
p2, _ := hex.DecodeString("0f0a0f0f0f0f0f0f01040a0b000f04000b03090b02090b0e040f0d0b060d070e0b0b090b0906080602060c0d0f0e0d04070e")
|
||||||
|
p3, _ := hex.DecodeString("0f0a0f0f0f0f0f0f01040b010d01080f050f000a0d0e08060c040b050800050904060f050807080a080c07040d0107080007")
|
||||||
|
h, _ := util.Uint256DecodeStringBE("57e197679ef031bf2f0b466b20afe3f67ac04dcff80a1dc4d12dd98dd21a2511")
|
||||||
|
mp.Add(h, p1)
|
||||||
|
mp.Add(h, p2)
|
||||||
|
mp.Add(h, p3)
|
||||||
|
|
||||||
|
toBeRemoved, ok := mp.TryGet(h)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
mp.Update(map[util.Uint256][][]byte{h: toBeRemoved}, nil)
|
||||||
|
// test that all items were successfully removed.
|
||||||
|
require.Equal(t, 0, len(mp.GetAll()))
|
||||||
|
}
|
435
pkg/core/statesync_test.go
Normal file
435
pkg/core/statesync_test.go
Normal file
|
@ -0,0 +1,435 @@
|
||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/config"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStateSyncModule_Init(t *testing.T) {
|
||||||
|
var (
|
||||||
|
stateSyncInterval = 2
|
||||||
|
maxTraceable uint32 = 3
|
||||||
|
)
|
||||||
|
spoutCfg := func(c *config.Config) {
|
||||||
|
c.ProtocolConfiguration.StateRootInHeader = true
|
||||||
|
c.ProtocolConfiguration.P2PStateExchangeExtensions = true
|
||||||
|
c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
|
||||||
|
c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
|
||||||
|
}
|
||||||
|
bcSpout := newTestChainWithCustomCfg(t, spoutCfg)
|
||||||
|
for i := 0; i <= 2*stateSyncInterval+int(maxTraceable)+1; i++ {
|
||||||
|
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
|
||||||
|
}
|
||||||
|
|
||||||
|
boltCfg := func(c *config.Config) {
|
||||||
|
spoutCfg(c)
|
||||||
|
c.ProtocolConfiguration.KeepOnlyLatestState = true
|
||||||
|
c.ProtocolConfiguration.RemoveUntraceableBlocks = true
|
||||||
|
}
|
||||||
|
t.Run("error: module disabled by config", func(t *testing.T) {
|
||||||
|
bcBolt := newTestChainWithCustomCfg(t, func(c *config.Config) {
|
||||||
|
boltCfg(c)
|
||||||
|
c.ProtocolConfiguration.RemoveUntraceableBlocks = false
|
||||||
|
})
|
||||||
|
module := bcBolt.GetStateSyncModule()
|
||||||
|
require.Error(t, module.Init(bcSpout.BlockHeight())) // module inactive (non-archival node)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inactive: spout chain is too low to start state sync process", func(t *testing.T) {
|
||||||
|
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||||
|
module := bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(uint32(2*stateSyncInterval-1)))
|
||||||
|
require.False(t, module.IsActive())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inactive: bolt chain height is close enough to spout chain height", func(t *testing.T) {
|
||||||
|
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||||
|
for i := 1; i < int(bcSpout.BlockHeight())-stateSyncInterval; i++ {
|
||||||
|
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, bcBolt.AddBlock(b))
|
||||||
|
}
|
||||||
|
module := bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
require.False(t, module.IsActive())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("error: bolt chain is too low to start state sync process", func(t *testing.T) {
|
||||||
|
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||||
|
require.NoError(t, bcBolt.AddBlock(bcBolt.newBlock()))
|
||||||
|
|
||||||
|
module := bcBolt.GetStateSyncModule()
|
||||||
|
require.Error(t, module.Init(uint32(3*stateSyncInterval)))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("initialized: no previous state sync point", func(t *testing.T) {
|
||||||
|
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||||
|
|
||||||
|
module := bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.True(t, module.NeedHeaders())
|
||||||
|
require.False(t, module.NeedMPTNodes())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("error: outdated state sync point in the storage", func(t *testing.T) {
|
||||||
|
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||||
|
module := bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
|
||||||
|
module = bcBolt.GetStateSyncModule()
|
||||||
|
require.Error(t, module.Init(bcSpout.BlockHeight()+2*uint32(stateSyncInterval)))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("initialized: valid previous state sync point in the storage", func(t *testing.T) {
|
||||||
|
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||||
|
module := bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
|
||||||
|
module = bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.True(t, module.NeedHeaders())
|
||||||
|
require.False(t, module.NeedMPTNodes())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("initialization from headers/blocks/mpt synced stages", func(t *testing.T) {
|
||||||
|
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||||
|
module := bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
|
||||||
|
// firstly, fetch all headers to create proper DB state (where headers are in sync)
|
||||||
|
stateSyncPoint := (int(bcSpout.BlockHeight()) / stateSyncInterval) * stateSyncInterval
|
||||||
|
var expectedHeader *block.Header
|
||||||
|
for i := 1; i <= int(bcSpout.HeaderHeight()); i++ {
|
||||||
|
header, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, module.AddHeaders(header))
|
||||||
|
if i == stateSyncPoint+1 {
|
||||||
|
expectedHeader = header
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.True(t, module.NeedMPTNodes())
|
||||||
|
|
||||||
|
// then create new statesync module with the same DB and check that state is proper
|
||||||
|
// (headers are in sync)
|
||||||
|
module = bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.True(t, module.NeedMPTNodes())
|
||||||
|
unknownNodes := module.GetUnknownMPTNodesBatch(2)
|
||||||
|
require.Equal(t, 1, len(unknownNodes))
|
||||||
|
require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0])
|
||||||
|
|
||||||
|
// add several blocks to create DB state where blocks are not in sync yet, but it's not a genesis.
|
||||||
|
for i := stateSyncPoint - int(maxTraceable) + 1; i <= stateSyncPoint-stateSyncInterval-1; i++ {
|
||||||
|
block, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, module.AddBlock(block))
|
||||||
|
}
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.True(t, module.NeedMPTNodes())
|
||||||
|
require.Equal(t, uint32(stateSyncPoint-stateSyncInterval-1), module.BlockHeight())
|
||||||
|
|
||||||
|
// then create new statesync module with the same DB and check that state is proper
|
||||||
|
// (blocks are not in sync yet)
|
||||||
|
module = bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.True(t, module.NeedMPTNodes())
|
||||||
|
unknownNodes = module.GetUnknownMPTNodesBatch(2)
|
||||||
|
require.Equal(t, 1, len(unknownNodes))
|
||||||
|
require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0])
|
||||||
|
require.Equal(t, uint32(stateSyncPoint-stateSyncInterval-1), module.BlockHeight())
|
||||||
|
|
||||||
|
// add rest of blocks to create DB state where blocks are in sync
|
||||||
|
for i := stateSyncPoint - stateSyncInterval; i <= stateSyncPoint; i++ {
|
||||||
|
block, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, module.AddBlock(block))
|
||||||
|
}
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.True(t, module.NeedMPTNodes())
|
||||||
|
lastBlock, err := bcBolt.GetBlock(expectedHeader.PrevHash)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), lastBlock.Index)
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||||
|
|
||||||
|
// then create new statesync module with the same DB and check that state is proper
|
||||||
|
// (headers and blocks are in sync)
|
||||||
|
module = bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.True(t, module.NeedMPTNodes())
|
||||||
|
unknownNodes = module.GetUnknownMPTNodesBatch(2)
|
||||||
|
require.Equal(t, 1, len(unknownNodes))
|
||||||
|
require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0])
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||||
|
|
||||||
|
// add a few MPT nodes to create DB state where some of MPT nodes are missing
|
||||||
|
count := 5
|
||||||
|
for {
|
||||||
|
unknownHashes := module.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
|
||||||
|
if len(unknownHashes) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err := bcSpout.GetStateSyncModule().Traverse(unknownHashes[0], func(node mpt.Node, nodeBytes []byte) bool {
|
||||||
|
require.NoError(t, module.AddMPTNodes([][]byte{nodeBytes}))
|
||||||
|
return true // add nodes one-by-one
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
count--
|
||||||
|
if count < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// then create new statesync module with the same DB and check that state is proper
|
||||||
|
// (headers and blocks are in sync, mpt is not yet synced)
|
||||||
|
module = bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.True(t, module.NeedMPTNodes())
|
||||||
|
unknownNodes = module.GetUnknownMPTNodesBatch(100)
|
||||||
|
require.True(t, len(unknownNodes) > 0)
|
||||||
|
require.NotContains(t, unknownNodes, expectedHeader.PrevStateRoot)
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||||
|
|
||||||
|
// add the rest of MPT nodes and jump to state
|
||||||
|
for {
|
||||||
|
unknownHashes := module.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
|
||||||
|
if len(unknownHashes) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err := bcSpout.GetStateSyncModule().Traverse(unknownHashes[0], func(node mpt.Node, nodeBytes []byte) bool {
|
||||||
|
require.NoError(t, module.AddMPTNodes([][]byte{slice.Copy(nodeBytes)}))
|
||||||
|
return true // add nodes one-by-one
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that module is inactive and statejump is completed
|
||||||
|
require.False(t, module.IsActive())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.False(t, module.NeedMPTNodes())
|
||||||
|
unknownNodes = module.GetUnknownMPTNodesBatch(1)
|
||||||
|
require.True(t, len(unknownNodes) == 0)
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
|
||||||
|
|
||||||
|
// create new module from completed state: the module should recognise that state sync is completed
|
||||||
|
module = bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
require.False(t, module.IsActive())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.False(t, module.NeedMPTNodes())
|
||||||
|
unknownNodes = module.GetUnknownMPTNodesBatch(1)
|
||||||
|
require.True(t, len(unknownNodes) == 0)
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
|
||||||
|
|
||||||
|
// add one more block to the restored chain and start new module: the module should recognise state sync is completed
|
||||||
|
// and regular blocks processing was started
|
||||||
|
require.NoError(t, bcBolt.AddBlock(bcBolt.newBlock()))
|
||||||
|
module = bcBolt.GetStateSyncModule()
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
require.False(t, module.IsActive())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.False(t, module.NeedMPTNodes())
|
||||||
|
unknownNodes = module.GetUnknownMPTNodesBatch(1)
|
||||||
|
require.True(t, len(unknownNodes) == 0)
|
||||||
|
require.Equal(t, uint32(stateSyncPoint)+1, module.BlockHeight())
|
||||||
|
require.Equal(t, uint32(stateSyncPoint)+1, bcBolt.BlockHeight())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateSyncModule_RestoreBasicChain(t *testing.T) {
|
||||||
|
var (
|
||||||
|
stateSyncInterval = 4
|
||||||
|
maxTraceable uint32 = 6
|
||||||
|
stateSyncPoint = 16
|
||||||
|
)
|
||||||
|
spoutCfg := func(c *config.Config) {
|
||||||
|
c.ProtocolConfiguration.StateRootInHeader = true
|
||||||
|
c.ProtocolConfiguration.P2PStateExchangeExtensions = true
|
||||||
|
c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
|
||||||
|
c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
|
||||||
|
}
|
||||||
|
bcSpout := newTestChainWithCustomCfg(t, spoutCfg)
|
||||||
|
initBasicChain(t, bcSpout)
|
||||||
|
|
||||||
|
// make spout chain higher that latest state sync point
|
||||||
|
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
|
||||||
|
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
|
||||||
|
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
|
||||||
|
require.Equal(t, uint32(stateSyncPoint+2), bcSpout.BlockHeight())
|
||||||
|
|
||||||
|
boltCfg := func(c *config.Config) {
|
||||||
|
spoutCfg(c)
|
||||||
|
c.ProtocolConfiguration.KeepOnlyLatestState = true
|
||||||
|
c.ProtocolConfiguration.RemoveUntraceableBlocks = true
|
||||||
|
}
|
||||||
|
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||||
|
module := bcBolt.GetStateSyncModule()
|
||||||
|
|
||||||
|
t.Run("error: add headers before initialisation", func(t *testing.T) {
|
||||||
|
h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Error(t, module.AddHeaders(h))
|
||||||
|
})
|
||||||
|
t.Run("no error: add blocks before initialisation", func(t *testing.T) {
|
||||||
|
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, module.AddBlock(b))
|
||||||
|
})
|
||||||
|
t.Run("error: add MPT nodes without initialisation", func(t *testing.T) {
|
||||||
|
require.Error(t, module.AddMPTNodes([][]byte{}))
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.True(t, module.NeedHeaders())
|
||||||
|
require.False(t, module.NeedMPTNodes())
|
||||||
|
|
||||||
|
// add headers to module
|
||||||
|
headers := make([]*block.Header, 0, bcSpout.HeaderHeight())
|
||||||
|
for i := uint32(1); i <= bcSpout.HeaderHeight(); i++ {
|
||||||
|
h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(int(i)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
headers = append(headers, h)
|
||||||
|
}
|
||||||
|
require.NoError(t, module.AddHeaders(headers...))
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.True(t, module.NeedMPTNodes())
|
||||||
|
require.Equal(t, bcSpout.HeaderHeight(), bcBolt.HeaderHeight())
|
||||||
|
|
||||||
|
// add blocks
|
||||||
|
t.Run("error: unexpected block index", func(t *testing.T) {
|
||||||
|
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(stateSyncPoint - int(maxTraceable)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Error(t, module.AddBlock(b))
|
||||||
|
})
|
||||||
|
t.Run("error: missing state root in block header", func(t *testing.T) {
|
||||||
|
b := &block.Block{
|
||||||
|
Header: block.Header{
|
||||||
|
Index: uint32(stateSyncPoint) - maxTraceable + 1,
|
||||||
|
StateRootEnabled: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Error(t, module.AddBlock(b))
|
||||||
|
})
|
||||||
|
t.Run("error: invalid block merkle root", func(t *testing.T) {
|
||||||
|
b := &block.Block{
|
||||||
|
Header: block.Header{
|
||||||
|
Index: uint32(stateSyncPoint) - maxTraceable + 1,
|
||||||
|
StateRootEnabled: true,
|
||||||
|
MerkleRoot: util.Uint256{1, 2, 3},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Error(t, module.AddBlock(b))
|
||||||
|
})
|
||||||
|
|
||||||
|
for i := stateSyncPoint - int(maxTraceable) + 1; i <= stateSyncPoint; i++ {
|
||||||
|
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, module.AddBlock(b))
|
||||||
|
}
|
||||||
|
require.True(t, module.IsActive())
|
||||||
|
require.True(t, module.IsInitialized())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.True(t, module.NeedMPTNodes())
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||||
|
|
||||||
|
// add MPT nodes in batches
|
||||||
|
h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(stateSyncPoint + 1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
unknownHashes := module.GetUnknownMPTNodesBatch(100)
|
||||||
|
require.Equal(t, 1, len(unknownHashes))
|
||||||
|
require.Equal(t, h.PrevStateRoot, unknownHashes[0])
|
||||||
|
nodesMap := make(map[util.Uint256][]byte)
|
||||||
|
err = bcSpout.GetStateSyncModule().Traverse(h.PrevStateRoot, func(n mpt.Node, nodeBytes []byte) bool {
|
||||||
|
nodesMap[n.Hash()] = nodeBytes
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
for {
|
||||||
|
need := module.GetUnknownMPTNodesBatch(10)
|
||||||
|
if len(need) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
add := make([][]byte, len(need))
|
||||||
|
for i, h := range need {
|
||||||
|
nodeBytes, ok := nodesMap[h]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("unknown or restored node requested")
|
||||||
|
}
|
||||||
|
add[i] = nodeBytes
|
||||||
|
delete(nodesMap, h)
|
||||||
|
}
|
||||||
|
require.NoError(t, module.AddMPTNodes(add))
|
||||||
|
}
|
||||||
|
require.False(t, module.IsActive())
|
||||||
|
require.False(t, module.NeedHeaders())
|
||||||
|
require.False(t, module.NeedMPTNodes())
|
||||||
|
unknownNodes := module.GetUnknownMPTNodesBatch(1)
|
||||||
|
require.True(t, len(unknownNodes) == 0)
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||||
|
require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
|
||||||
|
|
||||||
|
// add missing blocks to bcBolt: should be ok, because state is synced
|
||||||
|
for i := stateSyncPoint + 1; i <= int(bcSpout.BlockHeight()); i++ {
|
||||||
|
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, bcBolt.AddBlock(b))
|
||||||
|
}
|
||||||
|
require.Equal(t, bcSpout.BlockHeight(), bcBolt.BlockHeight())
|
||||||
|
|
||||||
|
// compare storage states
|
||||||
|
fetchStorage := func(bc *Blockchain) []storage.KeyValue {
|
||||||
|
var kv []storage.KeyValue
|
||||||
|
bc.dao.Store.Seek(storage.STStorage.Bytes(), func(k, v []byte) {
|
||||||
|
key := slice.Copy(k)
|
||||||
|
value := slice.Copy(v)
|
||||||
|
kv = append(kv, storage.KeyValue{
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
return kv
|
||||||
|
}
|
||||||
|
expected := fetchStorage(bcSpout)
|
||||||
|
actual := fetchStorage(bcBolt)
|
||||||
|
require.ElementsMatch(t, expected, actual)
|
||||||
|
|
||||||
|
// no temp items should be left
|
||||||
|
bcBolt.dao.Store.Seek(storage.STTempStorage.Bytes(), func(k, v []byte) {
|
||||||
|
t.Fatal("temp storage items are found")
|
||||||
|
})
|
||||||
|
}
|
|
@ -8,13 +8,18 @@ import (
|
||||||
|
|
||||||
// KeyPrefix constants.
|
// KeyPrefix constants.
|
||||||
const (
|
const (
|
||||||
DataBlock KeyPrefix = 0x01
|
DataBlock KeyPrefix = 0x01
|
||||||
DataTransaction KeyPrefix = 0x02
|
DataTransaction KeyPrefix = 0x02
|
||||||
DataMPT KeyPrefix = 0x03
|
DataMPT KeyPrefix = 0x03
|
||||||
STAccount KeyPrefix = 0x40
|
STAccount KeyPrefix = 0x40
|
||||||
STNotification KeyPrefix = 0x4d
|
STNotification KeyPrefix = 0x4d
|
||||||
STContractID KeyPrefix = 0x51
|
STContractID KeyPrefix = 0x51
|
||||||
STStorage KeyPrefix = 0x70
|
STStorage KeyPrefix = 0x70
|
||||||
|
// STTempStorage is used to store contract storage items during state sync process
|
||||||
|
// in order not to mess up the previous state which has its own items stored by
|
||||||
|
// STStorage prefix. Once state exchange process is completed, all items with
|
||||||
|
// STStorage prefix will be replaced with STTempStorage-prefixed ones.
|
||||||
|
STTempStorage KeyPrefix = 0x71
|
||||||
STNEP17Transfers KeyPrefix = 0x72
|
STNEP17Transfers KeyPrefix = 0x72
|
||||||
STNEP17TransferInfo KeyPrefix = 0x73
|
STNEP17TransferInfo KeyPrefix = 0x73
|
||||||
IXHeaderHashList KeyPrefix = 0x80
|
IXHeaderHashList KeyPrefix = 0x80
|
||||||
|
@ -22,6 +27,7 @@ const (
|
||||||
SYSCurrentHeader KeyPrefix = 0xc1
|
SYSCurrentHeader KeyPrefix = 0xc1
|
||||||
SYSStateSyncCurrentBlockHeight KeyPrefix = 0xc2
|
SYSStateSyncCurrentBlockHeight KeyPrefix = 0xc2
|
||||||
SYSStateSyncPoint KeyPrefix = 0xc3
|
SYSStateSyncPoint KeyPrefix = 0xc3
|
||||||
|
SYSStateJumpStage KeyPrefix = 0xc4
|
||||||
SYSVersion KeyPrefix = 0xf0
|
SYSVersion KeyPrefix = 0xf0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"github.com/Workiva/go-datastructures/queue"
|
"github.com/Workiva/go-datastructures/queue"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||||
|
"go.uber.org/atomic"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,6 +14,7 @@ type blockQueue struct {
|
||||||
checkBlocks chan struct{}
|
checkBlocks chan struct{}
|
||||||
chain blockchainer.Blockqueuer
|
chain blockchainer.Blockqueuer
|
||||||
relayF func(*block.Block)
|
relayF func(*block.Block)
|
||||||
|
discarded *atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -32,6 +34,7 @@ func newBlockQueue(capacity int, bc blockchainer.Blockqueuer, log *zap.Logger, r
|
||||||
checkBlocks: make(chan struct{}, 1),
|
checkBlocks: make(chan struct{}, 1),
|
||||||
chain: bc,
|
chain: bc,
|
||||||
relayF: relayer,
|
relayF: relayer,
|
||||||
|
discarded: atomic.NewBool(false),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,8 +94,10 @@ func (bq *blockQueue) putBlock(block *block.Block) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bq *blockQueue) discard() {
|
func (bq *blockQueue) discard() {
|
||||||
close(bq.checkBlocks)
|
if bq.discarded.CAS(false, true) {
|
||||||
bq.queue.Dispose()
|
close(bq.checkBlocks)
|
||||||
|
bq.queue.Dispose()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bq *blockQueue) length() int {
|
func (bq *blockQueue) length() int {
|
||||||
|
|
|
@ -71,6 +71,8 @@ const (
|
||||||
CMDBlock = CommandType(payload.BlockType)
|
CMDBlock = CommandType(payload.BlockType)
|
||||||
CMDExtensible = CommandType(payload.ExtensibleType)
|
CMDExtensible = CommandType(payload.ExtensibleType)
|
||||||
CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType)
|
CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType)
|
||||||
|
CMDGetMPTData CommandType = 0x51 // 0x5.. commands are used for extensions (P2PNotary, state exchange cmds)
|
||||||
|
CMDMPTData CommandType = 0x52
|
||||||
CMDReject CommandType = 0x2f
|
CMDReject CommandType = 0x2f
|
||||||
|
|
||||||
// SPV protocol.
|
// SPV protocol.
|
||||||
|
@ -136,6 +138,10 @@ func (m *Message) decodePayload() error {
|
||||||
p = &payload.Version{}
|
p = &payload.Version{}
|
||||||
case CMDInv, CMDGetData:
|
case CMDInv, CMDGetData:
|
||||||
p = &payload.Inventory{}
|
p = &payload.Inventory{}
|
||||||
|
case CMDGetMPTData:
|
||||||
|
p = &payload.MPTInventory{}
|
||||||
|
case CMDMPTData:
|
||||||
|
p = &payload.MPTData{}
|
||||||
case CMDAddr:
|
case CMDAddr:
|
||||||
p = &payload.AddressList{}
|
p = &payload.AddressList{}
|
||||||
case CMDBlock:
|
case CMDBlock:
|
||||||
|
@ -221,7 +227,7 @@ func (m *Message) tryCompressPayload() error {
|
||||||
if m.Flags&Compressed == 0 {
|
if m.Flags&Compressed == 0 {
|
||||||
switch m.Payload.(type) {
|
switch m.Payload.(type) {
|
||||||
case *payload.Headers, *payload.MerkleBlock, payload.NullPayload,
|
case *payload.Headers, *payload.MerkleBlock, payload.NullPayload,
|
||||||
*payload.Inventory:
|
*payload.Inventory, *payload.MPTInventory:
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
size := len(compressedPayload)
|
size := len(compressedPayload)
|
||||||
|
|
|
@ -26,6 +26,8 @@ func _() {
|
||||||
_ = x[CMDBlock-44]
|
_ = x[CMDBlock-44]
|
||||||
_ = x[CMDExtensible-46]
|
_ = x[CMDExtensible-46]
|
||||||
_ = x[CMDP2PNotaryRequest-80]
|
_ = x[CMDP2PNotaryRequest-80]
|
||||||
|
_ = x[CMDGetMPTData-81]
|
||||||
|
_ = x[CMDMPTData-82]
|
||||||
_ = x[CMDReject-47]
|
_ = x[CMDReject-47]
|
||||||
_ = x[CMDFilterLoad-48]
|
_ = x[CMDFilterLoad-48]
|
||||||
_ = x[CMDFilterAdd-49]
|
_ = x[CMDFilterAdd-49]
|
||||||
|
@ -44,7 +46,7 @@ const (
|
||||||
_CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
|
_CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
|
||||||
_CommandType_name_7 = "CMDMerkleBlock"
|
_CommandType_name_7 = "CMDMerkleBlock"
|
||||||
_CommandType_name_8 = "CMDAlert"
|
_CommandType_name_8 = "CMDAlert"
|
||||||
_CommandType_name_9 = "CMDP2PNotaryRequest"
|
_CommandType_name_9 = "CMDP2PNotaryRequestCMDGetMPTDataCMDMPTData"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -55,6 +57,7 @@ var (
|
||||||
_CommandType_index_4 = [...]uint8{0, 12, 22}
|
_CommandType_index_4 = [...]uint8{0, 12, 22}
|
||||||
_CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58}
|
_CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58}
|
||||||
_CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61}
|
_CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61}
|
||||||
|
_CommandType_index_9 = [...]uint8{0, 19, 32, 42}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (i CommandType) String() string {
|
func (i CommandType) String() string {
|
||||||
|
@ -83,8 +86,9 @@ func (i CommandType) String() string {
|
||||||
return _CommandType_name_7
|
return _CommandType_name_7
|
||||||
case i == 64:
|
case i == 64:
|
||||||
return _CommandType_name_8
|
return _CommandType_name_8
|
||||||
case i == 80:
|
case 80 <= i && i <= 82:
|
||||||
return _CommandType_name_9
|
i -= 80
|
||||||
|
return _CommandType_name_9[_CommandType_index_9[i]:_CommandType_index_9[i+1]]
|
||||||
default:
|
default:
|
||||||
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
|
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||||
}
|
}
|
||||||
|
|
|
@ -258,6 +258,21 @@ func TestEncodeDecodeNotFound(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeGetMPTData(t *testing.T) {
|
||||||
|
testEncodeDecode(t, CMDGetMPTData, &payload.MPTInventory{
|
||||||
|
Hashes: []util.Uint256{
|
||||||
|
{1, 2, 3},
|
||||||
|
{4, 5, 6},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeMPTData(t *testing.T) {
|
||||||
|
testEncodeDecode(t, CMDMPTData, &payload.MPTData{
|
||||||
|
Nodes: [][]byte{{1, 2, 3}, {4, 5, 6}},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestInvalidMessages(t *testing.T) {
|
func TestInvalidMessages(t *testing.T) {
|
||||||
t.Run("CMDBlock, empty payload", func(t *testing.T) {
|
t.Run("CMDBlock, empty payload", func(t *testing.T) {
|
||||||
testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{})
|
testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{})
|
||||||
|
|
35
pkg/network/payload/mptdata.go
Normal file
35
pkg/network/payload/mptdata.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package payload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MPTData represents the set of serialized MPT nodes.
|
||||||
|
type MPTData struct {
|
||||||
|
Nodes [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeBinary implements io.Serializable.
|
||||||
|
func (d *MPTData) EncodeBinary(w *io.BinWriter) {
|
||||||
|
w.WriteVarUint(uint64(len(d.Nodes)))
|
||||||
|
for _, n := range d.Nodes {
|
||||||
|
w.WriteVarBytes(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeBinary implements io.Serializable.
|
||||||
|
func (d *MPTData) DecodeBinary(r *io.BinReader) {
|
||||||
|
sz := r.ReadVarUint()
|
||||||
|
if sz == 0 {
|
||||||
|
r.Err = errors.New("empty MPT nodes list")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := uint64(0); i < sz; i++ {
|
||||||
|
d.Nodes = append(d.Nodes, r.ReadVarBytes())
|
||||||
|
if r.Err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
24
pkg/network/payload/mptdata_test.go
Normal file
24
pkg/network/payload/mptdata_test.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package payload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMPTData_EncodeDecodeBinary(t *testing.T) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
d := new(MPTData)
|
||||||
|
bytes, err := testserdes.EncodeBinary(d)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTData)))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("good", func(t *testing.T) {
|
||||||
|
d := &MPTData{
|
||||||
|
Nodes: [][]byte{{}, {1}, {1, 2, 3}},
|
||||||
|
}
|
||||||
|
testserdes.EncodeDecodeBinary(t, d, new(MPTData))
|
||||||
|
})
|
||||||
|
}
|
32
pkg/network/payload/mptinventory.go
Normal file
32
pkg/network/payload/mptinventory.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package payload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MaxMPTHashesCount is the maximum number of requested MPT nodes hashes.
|
||||||
|
const MaxMPTHashesCount = 32
|
||||||
|
|
||||||
|
// MPTInventory payload.
|
||||||
|
type MPTInventory struct {
|
||||||
|
// A list of requested MPT nodes hashes.
|
||||||
|
Hashes []util.Uint256
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMPTInventory return a pointer to an MPTInventory.
|
||||||
|
func NewMPTInventory(hashes []util.Uint256) *MPTInventory {
|
||||||
|
return &MPTInventory{
|
||||||
|
Hashes: hashes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeBinary implements Serializable interface.
|
||||||
|
func (p *MPTInventory) DecodeBinary(br *io.BinReader) {
|
||||||
|
br.ReadArray(&p.Hashes, MaxMPTHashesCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeBinary implements Serializable interface.
|
||||||
|
func (p *MPTInventory) EncodeBinary(bw *io.BinWriter) {
|
||||||
|
bw.WriteArray(p.Hashes)
|
||||||
|
}
|
38
pkg/network/payload/mptinventory_test.go
Normal file
38
pkg/network/payload/mptinventory_test.go
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
package payload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMPTInventory_EncodeDecodeBinary(t *testing.T) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
testserdes.EncodeDecodeBinary(t, NewMPTInventory([]util.Uint256{}), new(MPTInventory))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("good", func(t *testing.T) {
|
||||||
|
inv := NewMPTInventory([]util.Uint256{{1, 2, 3}, {2, 3, 4}})
|
||||||
|
testserdes.EncodeDecodeBinary(t, inv, new(MPTInventory))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("too large", func(t *testing.T) {
|
||||||
|
check := func(t *testing.T, count int, fail bool) {
|
||||||
|
h := make([]util.Uint256, count)
|
||||||
|
for i := range h {
|
||||||
|
h[i] = util.Uint256{1, 2, 3}
|
||||||
|
}
|
||||||
|
if fail {
|
||||||
|
bytes, err := testserdes.EncodeBinary(NewMPTInventory(h))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTInventory)))
|
||||||
|
} else {
|
||||||
|
testserdes.EncodeDecodeBinary(t, NewMPTInventory(h), new(MPTInventory))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
check(t, MaxMPTHashesCount, false)
|
||||||
|
check(t, MaxMPTHashesCount+1, true)
|
||||||
|
})
|
||||||
|
}
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
mrand "math/rand"
|
mrand "math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -17,7 +18,9 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
|
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network/extpool"
|
"github.com/nspcc-dev/neo-go/pkg/network/extpool"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||||
|
@ -67,6 +70,7 @@ type (
|
||||||
discovery Discoverer
|
discovery Discoverer
|
||||||
chain blockchainer.Blockchainer
|
chain blockchainer.Blockchainer
|
||||||
bQueue *blockQueue
|
bQueue *blockQueue
|
||||||
|
bSyncQueue *blockQueue
|
||||||
consensus consensus.Service
|
consensus consensus.Service
|
||||||
mempool *mempool.Pool
|
mempool *mempool.Pool
|
||||||
notaryRequestPool *mempool.Pool
|
notaryRequestPool *mempool.Pool
|
||||||
|
@ -93,6 +97,7 @@ type (
|
||||||
|
|
||||||
oracle *oracle.Oracle
|
oracle *oracle.Oracle
|
||||||
stateRoot stateroot.Service
|
stateRoot stateroot.Service
|
||||||
|
stateSync blockchainer.StateSync
|
||||||
|
|
||||||
log *zap.Logger
|
log *zap.Logger
|
||||||
}
|
}
|
||||||
|
@ -191,6 +196,10 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
|
||||||
}
|
}
|
||||||
s.stateRoot = sr
|
s.stateRoot = sr
|
||||||
|
|
||||||
|
sSync := chain.GetStateSyncModule()
|
||||||
|
s.stateSync = sSync
|
||||||
|
s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil)
|
||||||
|
|
||||||
if config.OracleCfg.Enabled {
|
if config.OracleCfg.Enabled {
|
||||||
orcCfg := oracle.Config{
|
orcCfg := oracle.Config{
|
||||||
Log: log,
|
Log: log,
|
||||||
|
@ -277,6 +286,7 @@ func (s *Server) Start(errChan chan error) {
|
||||||
go s.broadcastTxLoop()
|
go s.broadcastTxLoop()
|
||||||
go s.relayBlocksLoop()
|
go s.relayBlocksLoop()
|
||||||
go s.bQueue.run()
|
go s.bQueue.run()
|
||||||
|
go s.bSyncQueue.run()
|
||||||
go s.transport.Accept()
|
go s.transport.Accept()
|
||||||
setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10))
|
setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10))
|
||||||
s.run()
|
s.run()
|
||||||
|
@ -292,6 +302,7 @@ func (s *Server) Shutdown() {
|
||||||
p.Disconnect(errServerShutdown)
|
p.Disconnect(errServerShutdown)
|
||||||
}
|
}
|
||||||
s.bQueue.discard()
|
s.bQueue.discard()
|
||||||
|
s.bSyncQueue.discard()
|
||||||
if s.StateRootCfg.Enabled {
|
if s.StateRootCfg.Enabled {
|
||||||
s.stateRoot.Shutdown()
|
s.stateRoot.Shutdown()
|
||||||
}
|
}
|
||||||
|
@ -573,6 +584,10 @@ func (s *Server) IsInSync() bool {
|
||||||
var peersNumber int
|
var peersNumber int
|
||||||
var notHigher int
|
var notHigher int
|
||||||
|
|
||||||
|
if s.stateSync.IsActive() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if s.MinPeers == 0 {
|
if s.MinPeers == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -630,6 +645,9 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
|
||||||
|
|
||||||
// handleBlockCmd processes the received block received from its peer.
|
// handleBlockCmd processes the received block received from its peer.
|
||||||
func (s *Server) handleBlockCmd(p Peer, block *block.Block) error {
|
func (s *Server) handleBlockCmd(p Peer, block *block.Block) error {
|
||||||
|
if s.stateSync.IsActive() {
|
||||||
|
return s.bSyncQueue.putBlock(block)
|
||||||
|
}
|
||||||
return s.bQueue.putBlock(block)
|
return s.bQueue.putBlock(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -639,25 +657,57 @@ func (s *Server) handlePing(p Peer, ping *payload.Ping) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if s.chain.BlockHeight() < ping.LastBlockIndex {
|
err = s.requestBlocksOrHeaders(p)
|
||||||
err = s.requestBlocks(p)
|
if err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id)))
|
return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) requestBlocksOrHeaders(p Peer) error {
|
||||||
|
if s.stateSync.NeedHeaders() {
|
||||||
|
if s.chain.HeaderHeight() < p.LastBlockIndex() {
|
||||||
|
return s.requestHeaders(p)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
bq blockchainer.Blockqueuer = s.chain
|
||||||
|
requestMPTNodes bool
|
||||||
|
)
|
||||||
|
if s.stateSync.IsActive() {
|
||||||
|
bq = s.stateSync
|
||||||
|
requestMPTNodes = s.stateSync.NeedMPTNodes()
|
||||||
|
}
|
||||||
|
if bq.BlockHeight() >= p.LastBlockIndex() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := s.requestBlocks(bq, p)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if requestMPTNodes {
|
||||||
|
return s.requestMPTNodes(p, s.stateSync.GetUnknownMPTNodesBatch(payload.MaxMPTHashesCount))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestHeaders sends a CMDGetHeaders message to the peer to sync up in headers.
|
||||||
|
func (s *Server) requestHeaders(p Peer) error {
|
||||||
|
// TODO: optimize
|
||||||
|
currHeight := s.chain.HeaderHeight()
|
||||||
|
needHeight := currHeight + 1
|
||||||
|
payload := payload.NewGetBlockByIndex(needHeight, -1)
|
||||||
|
return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload))
|
||||||
|
}
|
||||||
|
|
||||||
// handlePing processes pong request.
|
// handlePing processes pong request.
|
||||||
func (s *Server) handlePong(p Peer, pong *payload.Ping) error {
|
func (s *Server) handlePong(p Peer, pong *payload.Ping) error {
|
||||||
err := p.HandlePong(pong)
|
err := p.HandlePong(pong)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if s.chain.BlockHeight() < pong.LastBlockIndex {
|
return s.requestBlocksOrHeaders(p)
|
||||||
return s.requestBlocks(p)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleInvCmd processes the received inventory.
|
// handleInvCmd processes the received inventory.
|
||||||
|
@ -766,6 +816,69 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleGetMPTDataCmd processes the received MPT inventory.
|
||||||
|
func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error {
|
||||||
|
if !s.chain.GetConfig().P2PStateExchangeExtensions {
|
||||||
|
return errors.New("GetMPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
|
||||||
|
}
|
||||||
|
if s.chain.GetConfig().KeepOnlyLatestState {
|
||||||
|
// TODO: implement keeping MPT states for P1 and P2 height (#2095, #2152 related)
|
||||||
|
return errors.New("GetMPTDataCMD was received, but only latest MPT state is supported")
|
||||||
|
}
|
||||||
|
resp := payload.MPTData{}
|
||||||
|
capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes)))
|
||||||
|
added := make(map[util.Uint256]struct{})
|
||||||
|
for _, h := range inv.Hashes {
|
||||||
|
if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err := s.stateSync.Traverse(h,
|
||||||
|
func(n mpt.Node, node []byte) bool {
|
||||||
|
if _, ok := added[n.Hash()]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
l := len(node)
|
||||||
|
size := l + io.GetVarSize(l)
|
||||||
|
if size > capLeft {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
resp.Nodes = append(resp.Nodes, node)
|
||||||
|
added[n.Hash()] = struct{}{}
|
||||||
|
capLeft -= size
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to traverse MPT starting from %s: %w", h.StringBE(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(resp.Nodes) > 0 {
|
||||||
|
msg := NewMessage(CMDMPTData, &resp)
|
||||||
|
return p.EnqueueP2PMessage(msg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error {
|
||||||
|
if !s.chain.GetConfig().P2PStateExchangeExtensions {
|
||||||
|
return errors.New("MPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
|
||||||
|
}
|
||||||
|
return s.stateSync.AddMPTNodes(data.Nodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestMPTNodes requests specified MPT nodes from the peer or broadcasts
|
||||||
|
// request if peer is not specified.
|
||||||
|
func (s *Server) requestMPTNodes(p Peer, itms []util.Uint256) error {
|
||||||
|
if len(itms) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(itms) > payload.MaxMPTHashesCount {
|
||||||
|
itms = itms[:payload.MaxMPTHashesCount]
|
||||||
|
}
|
||||||
|
pl := payload.NewMPTInventory(itms)
|
||||||
|
msg := NewMessage(CMDGetMPTData, pl)
|
||||||
|
return p.EnqueueP2PMessage(msg)
|
||||||
|
}
|
||||||
|
|
||||||
// handleGetBlocksCmd processes the getblocks request.
|
// handleGetBlocksCmd processes the getblocks request.
|
||||||
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
|
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
|
||||||
count := gb.Count
|
count := gb.Count
|
||||||
|
@ -845,6 +958,11 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error
|
||||||
return p.EnqueueP2PMessage(msg)
|
return p.EnqueueP2PMessage(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleHeadersCmd processes headers payload.
|
||||||
|
func (s *Server) handleHeadersCmd(p Peer, h *payload.Headers) error {
|
||||||
|
return s.stateSync.AddHeaders(h.Hdrs...)
|
||||||
|
}
|
||||||
|
|
||||||
// handleExtensibleCmd processes received extensible payload.
|
// handleExtensibleCmd processes received extensible payload.
|
||||||
func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
|
func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
|
||||||
if !s.syncReached.Load() {
|
if !s.syncReached.Load() {
|
||||||
|
@ -993,8 +1111,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error {
|
||||||
// 1. Block range is divided into chunks of payload.MaxHashesCount.
|
// 1. Block range is divided into chunks of payload.MaxHashesCount.
|
||||||
// 2. Send requests for chunk in increasing order.
|
// 2. Send requests for chunk in increasing order.
|
||||||
// 3. After all requests were sent, request random height.
|
// 3. After all requests were sent, request random height.
|
||||||
func (s *Server) requestBlocks(p Peer) error {
|
func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error {
|
||||||
var currHeight = s.chain.BlockHeight()
|
var currHeight = bq.BlockHeight()
|
||||||
var peerHeight = p.LastBlockIndex()
|
var peerHeight = p.LastBlockIndex()
|
||||||
var needHeight uint32
|
var needHeight uint32
|
||||||
// lastRequestedHeight can only be increased.
|
// lastRequestedHeight can only be increased.
|
||||||
|
@ -1051,9 +1169,18 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
||||||
case CMDGetData:
|
case CMDGetData:
|
||||||
inv := msg.Payload.(*payload.Inventory)
|
inv := msg.Payload.(*payload.Inventory)
|
||||||
return s.handleGetDataCmd(peer, inv)
|
return s.handleGetDataCmd(peer, inv)
|
||||||
|
case CMDGetMPTData:
|
||||||
|
inv := msg.Payload.(*payload.MPTInventory)
|
||||||
|
return s.handleGetMPTDataCmd(peer, inv)
|
||||||
|
case CMDMPTData:
|
||||||
|
inv := msg.Payload.(*payload.MPTData)
|
||||||
|
return s.handleMPTDataCmd(peer, inv)
|
||||||
case CMDGetHeaders:
|
case CMDGetHeaders:
|
||||||
gh := msg.Payload.(*payload.GetBlockByIndex)
|
gh := msg.Payload.(*payload.GetBlockByIndex)
|
||||||
return s.handleGetHeadersCmd(peer, gh)
|
return s.handleGetHeadersCmd(peer, gh)
|
||||||
|
case CMDHeaders:
|
||||||
|
h := msg.Payload.(*payload.Headers)
|
||||||
|
return s.handleHeadersCmd(peer, h)
|
||||||
case CMDInv:
|
case CMDInv:
|
||||||
inventory := msg.Payload.(*payload.Inventory)
|
inventory := msg.Payload.(*payload.Inventory)
|
||||||
return s.handleInvCmd(peer, inventory)
|
return s.handleInvCmd(peer, inventory)
|
||||||
|
@ -1093,6 +1220,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
||||||
}
|
}
|
||||||
go peer.StartProtocol()
|
go peer.StartProtocol()
|
||||||
|
|
||||||
|
s.tryInitStateSync()
|
||||||
s.tryStartServices()
|
s.tryStartServices()
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("received '%s' during handshake", msg.Command.String())
|
return fmt.Errorf("received '%s' during handshake", msg.Command.String())
|
||||||
|
@ -1101,6 +1229,52 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) tryInitStateSync() {
|
||||||
|
if !s.stateSync.IsActive() {
|
||||||
|
s.bSyncQueue.discard()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.stateSync.IsInitialized() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var peersNumber int
|
||||||
|
s.lock.RLock()
|
||||||
|
heights := make([]uint32, 0)
|
||||||
|
for p := range s.peers {
|
||||||
|
if p.Handshaked() {
|
||||||
|
peersNumber++
|
||||||
|
peerLastBlock := p.LastBlockIndex()
|
||||||
|
i := sort.Search(len(heights), func(i int) bool {
|
||||||
|
return heights[i] >= peerLastBlock
|
||||||
|
})
|
||||||
|
heights = append(heights, peerLastBlock)
|
||||||
|
if i != len(heights)-1 {
|
||||||
|
copy(heights[i+1:], heights[i:])
|
||||||
|
heights[i] = peerLastBlock
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.lock.RUnlock()
|
||||||
|
if peersNumber >= s.MinPeers && len(heights) > 0 {
|
||||||
|
// choose the height of the median peer as current chain's height
|
||||||
|
h := heights[len(heights)/2]
|
||||||
|
err := s.stateSync.Init(h)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Fatal("failed to init state sync module",
|
||||||
|
zap.Uint32("evaluated chain's blockHeight", h),
|
||||||
|
zap.Uint32("blockHeight", s.chain.BlockHeight()),
|
||||||
|
zap.Uint32("headerHeight", s.chain.HeaderHeight()),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// module can be inactive after init (i.e. full state is collected and ordinary block processing is needed)
|
||||||
|
if !s.stateSync.IsActive() {
|
||||||
|
s.bSyncQueue.discard()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
func (s *Server) handleNewPayload(p *payload.Extensible) {
|
func (s *Server) handleNewPayload(p *payload.Extensible) {
|
||||||
_, err := s.extensiblePool.Add(p)
|
_, err := s.extensiblePool.Add(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -16,6 +17,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||||
|
@ -46,7 +48,10 @@ func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs =
|
||||||
func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") }
|
func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") }
|
||||||
|
|
||||||
func TestNewServer(t *testing.T) {
|
func TestNewServer(t *testing.T) {
|
||||||
bc := &fakechain.FakeChain{}
|
bc := &fakechain.FakeChain{ProtocolConfiguration: config.ProtocolConfiguration{
|
||||||
|
P2PStateExchangeExtensions: true,
|
||||||
|
StateRootInHeader: true,
|
||||||
|
}}
|
||||||
s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery)
|
s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
|
@ -733,6 +738,139 @@ func TestInv(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleGetMPTData(t *testing.T) {
|
||||||
|
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
|
||||||
|
Hashes: []util.Uint256{{1, 2, 3}},
|
||||||
|
})
|
||||||
|
require.Error(t, s.handleMessage(p, msg))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("KeepOnlyLatestState on", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||||
|
s.chain.(*fakechain.FakeChain).KeepOnlyLatestState = true
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
|
||||||
|
Hashes: []util.Uint256{{1, 2, 3}},
|
||||||
|
})
|
||||||
|
require.Error(t, s.handleMessage(p, msg))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("good", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||||
|
var recvResponse atomic.Bool
|
||||||
|
r1 := random.Uint256()
|
||||||
|
r2 := random.Uint256()
|
||||||
|
r3 := random.Uint256()
|
||||||
|
node := []byte{1, 2, 3}
|
||||||
|
s.stateSync.(*fakechain.FakeStateSync).TraverseFunc = func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||||
|
if !(root.Equals(r1) || root.Equals(r2)) {
|
||||||
|
t.Fatal("unexpected root")
|
||||||
|
}
|
||||||
|
require.False(t, process(mpt.NewHashNode(r3), node))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
found := &payload.MPTData{
|
||||||
|
Nodes: [][]byte{node}, // no duplicates expected
|
||||||
|
}
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
p.messageHandler = func(t *testing.T, msg *Message) {
|
||||||
|
switch msg.Command {
|
||||||
|
case CMDMPTData:
|
||||||
|
require.Equal(t, found, msg.Payload)
|
||||||
|
recvResponse.Store(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hs := []util.Uint256{r1, r2}
|
||||||
|
s.testHandleMessage(t, p, CMDGetMPTData, payload.NewMPTInventory(hs))
|
||||||
|
|
||||||
|
require.Eventually(t, recvResponse.Load, time.Second, time.Millisecond)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleMPTData(t *testing.T) {
|
||||||
|
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
msg := NewMessage(CMDMPTData, &payload.MPTData{
|
||||||
|
Nodes: [][]byte{{1, 2, 3}},
|
||||||
|
})
|
||||||
|
require.Error(t, s.handleMessage(p, msg))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("good", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
expected := [][]byte{{1, 2, 3}, {2, 3, 4}}
|
||||||
|
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||||
|
s.stateSync = &fakechain.FakeStateSync{
|
||||||
|
AddMPTNodesFunc: func(nodes [][]byte) error {
|
||||||
|
require.Equal(t, expected, nodes)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
msg := NewMessage(CMDMPTData, &payload.MPTData{
|
||||||
|
Nodes: expected,
|
||||||
|
})
|
||||||
|
require.NoError(t, s.handleMessage(p, msg))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestMPTNodes(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
|
||||||
|
var actual []util.Uint256
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
p.messageHandler = func(t *testing.T, msg *Message) {
|
||||||
|
if msg.Command == CMDGetMPTData {
|
||||||
|
actual = append(actual, msg.Payload.(*payload.MPTInventory).Hashes...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.register <- p
|
||||||
|
s.register <- p // ensure previous send was handled
|
||||||
|
|
||||||
|
t.Run("no hashes, no message", func(t *testing.T) {
|
||||||
|
actual = nil
|
||||||
|
require.NoError(t, s.requestMPTNodes(p, nil))
|
||||||
|
require.Nil(t, actual)
|
||||||
|
})
|
||||||
|
t.Run("good, small", func(t *testing.T) {
|
||||||
|
actual = nil
|
||||||
|
expected := []util.Uint256{random.Uint256(), random.Uint256()}
|
||||||
|
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||||
|
require.Equal(t, expected, actual)
|
||||||
|
})
|
||||||
|
t.Run("good, exactly one chunk", func(t *testing.T) {
|
||||||
|
actual = nil
|
||||||
|
expected := make([]util.Uint256, payload.MaxMPTHashesCount)
|
||||||
|
for i := range expected {
|
||||||
|
expected[i] = random.Uint256()
|
||||||
|
}
|
||||||
|
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||||
|
require.Equal(t, expected, actual)
|
||||||
|
})
|
||||||
|
t.Run("good, too large chunk", func(t *testing.T) {
|
||||||
|
actual = nil
|
||||||
|
expected := make([]util.Uint256, payload.MaxMPTHashesCount+1)
|
||||||
|
for i := range expected {
|
||||||
|
expected[i] = random.Uint256()
|
||||||
|
}
|
||||||
|
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||||
|
require.Equal(t, expected[:payload.MaxMPTHashesCount], actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestRequestTx(t *testing.T) {
|
func TestRequestTx(t *testing.T) {
|
||||||
s := startTestServer(t)
|
s := startTestServer(t)
|
||||||
|
|
||||||
|
@ -899,3 +1037,44 @@ func TestVerifyNotaryRequest(t *testing.T) {
|
||||||
require.NoError(t, verifyNotaryRequest(bc, nil, r))
|
require.NoError(t, verifyNotaryRequest(bc, nil, r))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTryInitStateSync(t *testing.T) {
|
||||||
|
t.Run("module inactive", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
s.tryInitStateSync()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("module already initialized", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
ss := &fakechain.FakeStateSync{}
|
||||||
|
ss.IsActiveFlag.Store(true)
|
||||||
|
ss.IsInitializedFlag.Store(true)
|
||||||
|
s.stateSync = ss
|
||||||
|
s.tryInitStateSync()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("good", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
for _, h := range []uint32{10, 8, 7, 4, 11, 4} {
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
p.lastBlockIndex = h
|
||||||
|
s.peers[p] = true
|
||||||
|
}
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = false // one disconnected peer to check it won't be taken into attention
|
||||||
|
p.lastBlockIndex = 5
|
||||||
|
s.peers[p] = true
|
||||||
|
var expectedH uint32 = 8 // median peer
|
||||||
|
|
||||||
|
ss := &fakechain.FakeStateSync{InitFunc: func(h uint32) error {
|
||||||
|
if h != expectedH {
|
||||||
|
return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}}
|
||||||
|
ss.IsActiveFlag.Store(true)
|
||||||
|
s.stateSync = ss
|
||||||
|
s.tryInitStateSync()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -267,12 +267,10 @@ func (p *TCPPeer) StartProtocol() {
|
||||||
zap.Uint32("id", p.Version().Nonce))
|
zap.Uint32("id", p.Version().Nonce))
|
||||||
|
|
||||||
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities)
|
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities)
|
||||||
if p.server.chain.BlockHeight() < p.LastBlockIndex() {
|
err = p.server.requestBlocksOrHeaders(p)
|
||||||
err = p.server.requestBlocks(p)
|
if err != nil {
|
||||||
if err != nil {
|
p.Disconnect(err)
|
||||||
p.Disconnect(err)
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
timer := time.NewTimer(p.server.ProtoTickInterval)
|
timer := time.NewTimer(p.server.ProtoTickInterval)
|
||||||
|
@ -281,10 +279,8 @@ func (p *TCPPeer) StartProtocol() {
|
||||||
case <-p.done:
|
case <-p.done:
|
||||||
return
|
return
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
// Try to sync in headers and block with the peer if his block height is higher then ours.
|
// Try to sync in headers and block with the peer if his block height is higher than ours.
|
||||||
if p.LastBlockIndex() > p.server.chain.BlockHeight() {
|
err = p.server.requestBlocksOrHeaders(p)
|
||||||
err = p.server.requestBlocks(p)
|
|
||||||
}
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
timer.Reset(p.server.ProtoTickInterval)
|
timer.Reset(p.server.ProtoTickInterval)
|
||||||
}
|
}
|
||||||
|
|
|
@ -679,6 +679,7 @@ func (s *Server) getNEP17Balances(ps request.Params) (interface{}, *response.Err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, response.NewRPCError("Failed to get NEP17 last updated block", err.Error(), err)
|
return nil, response.NewRPCError("Failed to get NEP17 last updated block", err.Error(), err)
|
||||||
}
|
}
|
||||||
|
stateSyncPoint := lastUpdated[math.MinInt32]
|
||||||
bw := io.NewBufBinWriter()
|
bw := io.NewBufBinWriter()
|
||||||
for _, h := range s.chain.GetNEP17Contracts() {
|
for _, h := range s.chain.GetNEP17Contracts() {
|
||||||
balance, err := s.getNEP17Balance(h, u, bw)
|
balance, err := s.getNEP17Balance(h, u, bw)
|
||||||
|
@ -692,10 +693,18 @@ func (s *Server) getNEP17Balances(ps request.Params) (interface{}, *response.Err
|
||||||
if cs == nil {
|
if cs == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
lub, ok := lastUpdated[cs.ID]
|
||||||
|
if !ok {
|
||||||
|
cfg := s.chain.GetConfig()
|
||||||
|
if !cfg.P2PStateExchangeExtensions && cfg.RemoveUntraceableBlocks {
|
||||||
|
return nil, response.NewInternalServerError(fmt.Sprintf("failed to get LastUpdatedBlock for balance of %s token", cs.Hash.StringLE()), nil)
|
||||||
|
}
|
||||||
|
lub = stateSyncPoint
|
||||||
|
}
|
||||||
bs.Balances = append(bs.Balances, result.NEP17Balance{
|
bs.Balances = append(bs.Balances, result.NEP17Balance{
|
||||||
Asset: h,
|
Asset: h,
|
||||||
Amount: balance.String(),
|
Amount: balance.String(),
|
||||||
LastUpdated: lastUpdated[cs.ID],
|
LastUpdated: lub,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return bs, nil
|
return bs, nil
|
||||||
|
|
Loading…
Reference in a new issue