forked from TrueCloudLab/neoneo-go
Merge pull request #2019 from nspcc-dev/states-exchange/cmd
core, network: implement P2P state exchange
This commit is contained in:
commit
752c7f106b
39 changed files with 3001 additions and 53 deletions
|
@ -114,7 +114,11 @@ balance won't be shown in the list of NEP17 balances returned by the neo-go node
|
|||
(unlike the C# node behavior). However, transfer logs of such token are still
|
||||
available via `getnep17transfers` RPC call.
|
||||
|
||||
The behaviour of the `LastUpdatedBlock` tracking matches the C# node's one.
|
||||
The behaviour of the `LastUpdatedBlock` tracking for archival nodes as far as for
|
||||
governing token balances matches the C# node's one. For non-archival nodes and
|
||||
other NEP17-compliant tokens if transfer's `LastUpdatedBlock` is lower than the
|
||||
latest state synchronization point P the node working against, then
|
||||
`LastUpdatedBlock` equals P.
|
||||
|
||||
### Unsupported methods
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/native"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
|
@ -21,6 +22,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||
uatomic "go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// FakeChain implements Blockchainer interface, but does not provide real functionality.
|
||||
|
@ -42,6 +44,15 @@ type FakeChain struct {
|
|||
UtilityTokenBalance *big.Int
|
||||
}
|
||||
|
||||
// FakeStateSync implements StateSync interface.
|
||||
type FakeStateSync struct {
|
||||
IsActiveFlag uatomic.Bool
|
||||
IsInitializedFlag uatomic.Bool
|
||||
InitFunc func(h uint32) error
|
||||
TraverseFunc func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
|
||||
AddMPTNodesFunc func(nodes [][]byte) error
|
||||
}
|
||||
|
||||
// NewFakeChain returns new FakeChain structure.
|
||||
func NewFakeChain() *FakeChain {
|
||||
return &FakeChain{
|
||||
|
@ -294,6 +305,11 @@ func (chain *FakeChain) GetStateModule() blockchainer.StateRoot {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetStateSyncModule implements Blockchainer interface.
|
||||
func (chain *FakeChain) GetStateSyncModule() blockchainer.StateSync {
|
||||
return &FakeStateSync{}
|
||||
}
|
||||
|
||||
// GetStorageItem implements Blockchainer interface.
|
||||
func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
|
||||
panic("TODO")
|
||||
|
@ -436,3 +452,63 @@ func (chain *FakeChain) UnsubscribeFromNotifications(ch chan<- *state.Notificati
|
|||
func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// AddBlock implements StateSync interface.
|
||||
func (s *FakeStateSync) AddBlock(block *block.Block) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// AddHeaders implements StateSync interface.
|
||||
func (s *FakeStateSync) AddHeaders(...*block.Header) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// AddMPTNodes implements StateSync interface.
|
||||
func (s *FakeStateSync) AddMPTNodes(nodes [][]byte) error {
|
||||
if s.AddMPTNodesFunc != nil {
|
||||
return s.AddMPTNodesFunc(nodes)
|
||||
}
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// BlockHeight implements StateSync interface.
|
||||
func (s *FakeStateSync) BlockHeight() uint32 {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// IsActive implements StateSync interface.
|
||||
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag.Load() }
|
||||
|
||||
// IsInitialized implements StateSync interface.
|
||||
func (s *FakeStateSync) IsInitialized() bool {
|
||||
return s.IsInitializedFlag.Load()
|
||||
}
|
||||
|
||||
// Init implements StateSync interface.
|
||||
func (s *FakeStateSync) Init(currChainHeight uint32) error {
|
||||
if s.InitFunc != nil {
|
||||
return s.InitFunc(currChainHeight)
|
||||
}
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// NeedHeaders implements StateSync interface.
|
||||
func (s *FakeStateSync) NeedHeaders() bool { return false }
|
||||
|
||||
// NeedMPTNodes implements StateSync interface.
|
||||
func (s *FakeStateSync) NeedMPTNodes() bool {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// Traverse implements StateSync interface.
|
||||
func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||
if s.TraverseFunc != nil {
|
||||
return s.TraverseFunc(root, process)
|
||||
}
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// GetUnknownMPTNodesBatch implements StateSync interface.
|
||||
func (s *FakeStateSync) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
|
||||
panic("TODO")
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/native/noderoles"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/stateroot"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/statesync"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
|
@ -34,6 +35,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||
"go.uber.org/zap"
|
||||
|
@ -42,7 +44,7 @@ import (
|
|||
// Tuning parameters.
|
||||
const (
|
||||
headerBatchCount = 2000
|
||||
version = "0.1.2"
|
||||
version = "0.1.4"
|
||||
|
||||
defaultInitialGAS = 52000000_00000000
|
||||
defaultMemPoolSize = 50000
|
||||
|
@ -54,6 +56,31 @@ const (
|
|||
// HeaderVerificationGasLimit is the maximum amount of GAS for block header verification.
|
||||
HeaderVerificationGasLimit = 3_00000000 // 3 GAS
|
||||
defaultStateSyncInterval = 40000
|
||||
|
||||
// maxStorageBatchSize is the number of elements in storage batch expected to fit into the
|
||||
// storage without delays and problems. Estimated size of batch in case of given number of
|
||||
// elements does not exceed 1Mb.
|
||||
maxStorageBatchSize = 10000
|
||||
)
|
||||
|
||||
// stateJumpStage denotes the stage of state jump process.
|
||||
type stateJumpStage byte
|
||||
|
||||
const (
|
||||
// none means that no state jump process was initiated yet.
|
||||
none stateJumpStage = 1 << iota
|
||||
// stateJumpStarted means that state jump was just initiated, but outdated storage items
|
||||
// were not yet removed.
|
||||
stateJumpStarted
|
||||
// oldStorageItemsRemoved means that outdated contract storage items were removed, but
|
||||
// new storage items were not yet saved.
|
||||
oldStorageItemsRemoved
|
||||
// newStorageItemsAdded means that contract storage items are up-to-date with the current
|
||||
// state.
|
||||
newStorageItemsAdded
|
||||
// genesisStateRemoved means that state corresponding to the genesis block was removed
|
||||
// from the storage.
|
||||
genesisStateRemoved
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -308,16 +335,6 @@ func (bc *Blockchain) init() error {
|
|||
// and the genesis block as first block.
|
||||
bc.log.Info("restoring blockchain", zap.String("version", version))
|
||||
|
||||
bHeight, err := bc.dao.GetCurrentBlockHeight()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bc.blockHeight = bHeight
|
||||
bc.persistedHeight = bHeight
|
||||
if err = bc.stateRoot.Init(bHeight, bc.config.KeepOnlyLatestState); err != nil {
|
||||
return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err)
|
||||
}
|
||||
|
||||
bc.headerHashes, err = bc.dao.GetHeaderHashes()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -365,6 +382,34 @@ func (bc *Blockchain) init() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Check whether StateJump stage is in the storage and continue interrupted state jump if so.
|
||||
jumpStage, err := bc.dao.Store.Get(storage.SYSStateJumpStage.Bytes())
|
||||
if err == nil {
|
||||
if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) {
|
||||
return errors.New("state jump was not completed, but P2PStateExchangeExtensions are disabled or archival node capability is on. " +
|
||||
"To start an archival node drop the database manually and restart the node")
|
||||
}
|
||||
if len(jumpStage) != 1 {
|
||||
return fmt.Errorf("invalid state jump stage format")
|
||||
}
|
||||
// State jump wasn't finished yet, thus continue it.
|
||||
stateSyncPoint, err := bc.dao.GetStateSyncPoint()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get state sync point from the storage")
|
||||
}
|
||||
return bc.jumpToStateInternal(stateSyncPoint, stateJumpStage(jumpStage[0]))
|
||||
}
|
||||
|
||||
bHeight, err := bc.dao.GetCurrentBlockHeight()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bc.blockHeight = bHeight
|
||||
bc.persistedHeight = bHeight
|
||||
if err = bc.stateRoot.Init(bHeight, bc.config.KeepOnlyLatestState); err != nil {
|
||||
return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err)
|
||||
}
|
||||
|
||||
err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
|
||||
|
@ -409,6 +454,158 @@ func (bc *Blockchain) init() error {
|
|||
return bc.updateExtensibleWhitelist(bHeight)
|
||||
}
|
||||
|
||||
// jumpToState is an atomic operation that changes Blockchain state to the one
|
||||
// specified by the state sync point p. All the data needed for the jump must be
|
||||
// collected by the state sync module.
|
||||
func (bc *Blockchain) jumpToState(p uint32) error {
|
||||
bc.lock.Lock()
|
||||
defer bc.lock.Unlock()
|
||||
|
||||
return bc.jumpToStateInternal(p, none)
|
||||
}
|
||||
|
||||
// jumpToStateInternal is an internal representation of jumpToState callback that
|
||||
// changes Blockchain state to the one specified by state sync point p and state
|
||||
// jump stage. All the data needed for the jump must be in the DB, otherwise an
|
||||
// error is returned. It is not protected by mutex.
|
||||
func (bc *Blockchain) jumpToStateInternal(p uint32, stage stateJumpStage) error {
|
||||
if p+1 >= uint32(len(bc.headerHashes)) {
|
||||
return fmt.Errorf("invalid state sync point %d: headerHeignt is %d", p, len(bc.headerHashes))
|
||||
}
|
||||
|
||||
bc.log.Info("jumping to state sync point", zap.Uint32("state sync point", p))
|
||||
|
||||
writeBuf := io.NewBufBinWriter()
|
||||
jumpStageKey := storage.SYSStateJumpStage.Bytes()
|
||||
switch stage {
|
||||
case none:
|
||||
err := bc.dao.Store.Put(jumpStageKey, []byte{byte(stateJumpStarted)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store state jump stage: %w", err)
|
||||
}
|
||||
fallthrough
|
||||
case stateJumpStarted:
|
||||
// Replace old storage items by new ones, it should be done step-by step.
|
||||
// Firstly, remove all old genesis-related items.
|
||||
b := bc.dao.Store.Batch()
|
||||
bc.dao.Store.Seek([]byte{byte(storage.STStorage)}, func(k, _ []byte) {
|
||||
// Must copy here, #1468.
|
||||
key := slice.Copy(k)
|
||||
b.Delete(key)
|
||||
})
|
||||
b.Put(jumpStageKey, []byte{byte(oldStorageItemsRemoved)})
|
||||
err := bc.dao.Store.PutBatch(b)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store state jump stage: %w", err)
|
||||
}
|
||||
fallthrough
|
||||
case oldStorageItemsRemoved:
|
||||
// Then change STTempStorage prefix to STStorage. Each replace operation is atomic.
|
||||
for {
|
||||
count := 0
|
||||
b := bc.dao.Store.Batch()
|
||||
bc.dao.Store.Seek([]byte{byte(storage.STTempStorage)}, func(k, v []byte) {
|
||||
if count >= maxStorageBatchSize {
|
||||
return
|
||||
}
|
||||
// Must copy here, #1468.
|
||||
oldKey := slice.Copy(k)
|
||||
b.Delete(oldKey)
|
||||
key := make([]byte, len(k))
|
||||
key[0] = byte(storage.STStorage)
|
||||
copy(key[1:], k[1:])
|
||||
value := slice.Copy(v)
|
||||
b.Put(key, value)
|
||||
count += 2
|
||||
})
|
||||
if count > 0 {
|
||||
err := bc.dao.Store.PutBatch(b)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace outdated contract storage items with the fresh ones: %w", err)
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
err := bc.dao.Store.Put(jumpStageKey, []byte{byte(newStorageItemsAdded)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store state jump stage: %w", err)
|
||||
}
|
||||
fallthrough
|
||||
case newStorageItemsAdded:
|
||||
// After current state is updated, we need to remove outdated state-related data if so.
|
||||
// The only outdated data we might have is genesis-related data, so check it.
|
||||
if p-bc.config.MaxTraceableBlocks > 0 {
|
||||
cache := bc.dao.GetWrapped()
|
||||
writeBuf.Reset()
|
||||
err := cache.DeleteBlock(bc.headerHashes[0], writeBuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove outdated state data for the genesis block: %w", err)
|
||||
}
|
||||
// TODO: remove NEP17 transfers and NEP17 transfer info for genesis block, #2096 related.
|
||||
_, err = cache.Persist()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to drop genesis block state: %w", err)
|
||||
}
|
||||
}
|
||||
err := bc.dao.Store.Put(jumpStageKey, []byte{byte(genesisStateRemoved)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store state jump stage: %w", err)
|
||||
}
|
||||
case genesisStateRemoved:
|
||||
// there's nothing to do after that, so just continue with common operations
|
||||
// and remove state jump stage in the end.
|
||||
default:
|
||||
return errors.New("unknown state jump stage")
|
||||
}
|
||||
|
||||
block, err := bc.dao.GetBlock(bc.headerHashes[p])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current block: %w", err)
|
||||
}
|
||||
writeBuf.Reset()
|
||||
err = bc.dao.StoreAsCurrentBlock(block, writeBuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store current block: %w", err)
|
||||
}
|
||||
bc.topBlock.Store(block)
|
||||
atomic.StoreUint32(&bc.blockHeight, p)
|
||||
atomic.StoreUint32(&bc.persistedHeight, p)
|
||||
|
||||
block, err = bc.dao.GetBlock(bc.headerHashes[p+1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get block to init MPT: %w", err)
|
||||
}
|
||||
if err = bc.stateRoot.JumpToState(&state.MPTRoot{
|
||||
Index: p,
|
||||
Root: block.PrevStateRoot,
|
||||
}, bc.config.KeepOnlyLatestState); err != nil {
|
||||
return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err)
|
||||
}
|
||||
|
||||
err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
|
||||
}
|
||||
err = bc.contracts.Management.InitializeCache(bc.dao)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't init cache for Management native contract: %w", err)
|
||||
}
|
||||
bc.contracts.Designate.InitializeCache()
|
||||
|
||||
if err := bc.updateExtensibleWhitelist(p); err != nil {
|
||||
return fmt.Errorf("failed to update extensible whitelist: %w", err)
|
||||
}
|
||||
|
||||
updateBlockHeightMetric(p)
|
||||
|
||||
err = bc.dao.Store.Delete(jumpStageKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove outdated state jump stage: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run runs chain loop, it needs to be run as goroutine and executing it is
|
||||
// critical for correct Blockchain operation.
|
||||
func (bc *Blockchain) Run() {
|
||||
|
@ -696,6 +893,11 @@ func (bc *Blockchain) GetStateModule() blockchainer.StateRoot {
|
|||
return bc.stateRoot
|
||||
}
|
||||
|
||||
// GetStateSyncModule returns new state sync service instance.
|
||||
func (bc *Blockchain) GetStateSyncModule() blockchainer.StateSync {
|
||||
return statesync.NewModule(bc, bc.log, bc.dao, bc.jumpToState)
|
||||
}
|
||||
|
||||
// storeBlock performs chain update using the block given, it executes all
|
||||
// transactions with all appropriate side-effects and updates Blockchain state.
|
||||
// This is the only way to change Blockchain state.
|
||||
|
@ -1159,12 +1361,25 @@ func (bc *Blockchain) GetNEP17Contracts() []util.Uint160 {
|
|||
}
|
||||
|
||||
// GetNEP17LastUpdated returns a set of contract ids with the corresponding last updated
|
||||
// block indexes.
|
||||
// block indexes. In case of an empty account, latest stored state synchronisation point
|
||||
// is returned under Math.MinInt32 key.
|
||||
func (bc *Blockchain) GetNEP17LastUpdated(acc util.Uint160) (map[int32]uint32, error) {
|
||||
info, err := bc.dao.GetNEP17TransferInfo(acc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if bc.config.P2PStateExchangeExtensions && bc.config.RemoveUntraceableBlocks {
|
||||
if _, ok := info.LastUpdated[bc.contracts.NEO.ID]; !ok {
|
||||
nBalance, lub := bc.contracts.NEO.BalanceOf(bc.dao, acc)
|
||||
if nBalance.Sign() != 0 {
|
||||
info.LastUpdated[bc.contracts.NEO.ID] = lub
|
||||
}
|
||||
}
|
||||
}
|
||||
stateSyncPoint, err := bc.dao.GetStateSyncPoint()
|
||||
if err == nil {
|
||||
info.LastUpdated[math.MinInt32] = stateSyncPoint
|
||||
}
|
||||
return info.LastUpdated, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
|
@ -34,6 +35,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||
|
@ -41,6 +43,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/wallet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func TestVerifyHeader(t *testing.T) {
|
||||
|
@ -1734,3 +1737,88 @@ func TestConfigNativeUpdateHistory(t *testing.T) {
|
|||
check(t, tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlockchain_InitWithIncompleteStateJump(t *testing.T) {
|
||||
var (
|
||||
stateSyncInterval = 4
|
||||
maxTraceable uint32 = 6
|
||||
)
|
||||
spountCfg := func(c *config.Config) {
|
||||
c.ProtocolConfiguration.RemoveUntraceableBlocks = true
|
||||
c.ProtocolConfiguration.StateRootInHeader = true
|
||||
c.ProtocolConfiguration.P2PStateExchangeExtensions = true
|
||||
c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
|
||||
c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
|
||||
}
|
||||
bcSpout := newTestChainWithCustomCfg(t, spountCfg)
|
||||
initBasicChain(t, bcSpout)
|
||||
|
||||
// reach next to the latest state sync point and pretend that we've just restored
|
||||
stateSyncPoint := (int(bcSpout.BlockHeight())/stateSyncInterval + 1) * stateSyncInterval
|
||||
for i := bcSpout.BlockHeight() + 1; i <= uint32(stateSyncPoint); i++ {
|
||||
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
|
||||
}
|
||||
require.Equal(t, uint32(stateSyncPoint), bcSpout.BlockHeight())
|
||||
b := bcSpout.newBlock()
|
||||
require.NoError(t, bcSpout.AddHeaders(&b.Header))
|
||||
|
||||
// put storage items with STTemp prefix
|
||||
batch := bcSpout.dao.Store.Batch()
|
||||
bcSpout.dao.Store.Seek(storage.STStorage.Bytes(), func(k, v []byte) {
|
||||
key := slice.Copy(k)
|
||||
key[0] = storage.STTempStorage.Bytes()[0]
|
||||
value := slice.Copy(v)
|
||||
batch.Put(key, value)
|
||||
})
|
||||
require.NoError(t, bcSpout.dao.Store.PutBatch(batch))
|
||||
|
||||
checkNewBlockchainErr := func(t *testing.T, cfg func(c *config.Config), store storage.Store, shouldFail bool) {
|
||||
unitTestNetCfg, err := config.Load("../../config", testchain.Network())
|
||||
require.NoError(t, err)
|
||||
cfg(&unitTestNetCfg)
|
||||
log := zaptest.NewLogger(t)
|
||||
_, err = NewBlockchain(store, unitTestNetCfg.ProtocolConfiguration, log)
|
||||
if shouldFail {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
boltCfg := func(c *config.Config) {
|
||||
spountCfg(c)
|
||||
c.ProtocolConfiguration.KeepOnlyLatestState = true
|
||||
}
|
||||
// manually store statejump stage to check statejump recover process
|
||||
t.Run("invalid RemoveUntraceableBlocks setting", func(t *testing.T) {
|
||||
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
|
||||
checkNewBlockchainErr(t, func(c *config.Config) {
|
||||
boltCfg(c)
|
||||
c.ProtocolConfiguration.RemoveUntraceableBlocks = false
|
||||
}, bcSpout.dao.Store, true)
|
||||
})
|
||||
t.Run("invalid state jump stage format", func(t *testing.T) {
|
||||
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{0x01, 0x02}))
|
||||
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
|
||||
})
|
||||
t.Run("missing state sync point", func(t *testing.T) {
|
||||
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
|
||||
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
|
||||
})
|
||||
t.Run("invalid state sync point", func(t *testing.T) {
|
||||
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
|
||||
point := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(point, uint32(len(bcSpout.headerHashes)))
|
||||
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateSyncPoint.Bytes(), point))
|
||||
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
|
||||
})
|
||||
for _, stage := range []stateJumpStage{stateJumpStarted, oldStorageItemsRemoved, newStorageItemsAdded, genesisStateRemoved, 0x03} {
|
||||
t.Run(fmt.Sprintf("state jump stage %d", stage), func(t *testing.T) {
|
||||
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stage)}))
|
||||
point := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(point, uint32(stateSyncPoint))
|
||||
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateSyncPoint.Bytes(), point))
|
||||
shouldFail := stage == 0x03 // unknown stage
|
||||
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, shouldFail)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import (
|
|||
type Blockchainer interface {
|
||||
ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction
|
||||
GetConfig() config.ProtocolConfiguration
|
||||
AddHeaders(...*block.Header) error
|
||||
Blockqueuer // Blockqueuer interface
|
||||
CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error)
|
||||
Close()
|
||||
|
@ -56,6 +55,7 @@ type Blockchainer interface {
|
|||
GetStandByCommittee() keys.PublicKeys
|
||||
GetStandByValidators() keys.PublicKeys
|
||||
GetStateModule() StateRoot
|
||||
GetStateSyncModule() StateSync
|
||||
GetStorageItem(id int32, key []byte) state.StorageItem
|
||||
GetStorageItems(id int32) (map[string]state.StorageItem, error)
|
||||
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM
|
||||
|
|
|
@ -5,5 +5,6 @@ import "github.com/nspcc-dev/neo-go/pkg/core/block"
|
|||
// Blockqueuer is an interface for blockqueue.
|
||||
type Blockqueuer interface {
|
||||
AddBlock(block *block.Block) error
|
||||
AddHeaders(...*block.Header) error
|
||||
BlockHeight() uint32
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
// StateRoot represents local state root module.
|
||||
type StateRoot interface {
|
||||
AddStateRoot(root *state.MPTRoot) error
|
||||
CleanStorage() error
|
||||
CurrentLocalHeight() uint32
|
||||
CurrentLocalStateRoot() util.Uint256
|
||||
CurrentValidatedHeight() uint32
|
||||
GetStateProof(root util.Uint256, key []byte) ([][]byte, error)
|
||||
|
|
19
pkg/core/blockchainer/state_sync.go
Normal file
19
pkg/core/blockchainer/state_sync.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package blockchainer
|
||||
|
||||
import (
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// StateSync represents state sync module.
|
||||
type StateSync interface {
|
||||
AddMPTNodes([][]byte) error
|
||||
Blockqueuer // Blockqueuer interface
|
||||
Init(currChainHeight uint32) error
|
||||
IsActive() bool
|
||||
IsInitialized() bool
|
||||
GetUnknownMPTNodesBatch(limit int) []util.Uint256
|
||||
NeedHeaders() bool
|
||||
NeedMPTNodes() bool
|
||||
Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
|
||||
}
|
314
pkg/core/mpt/billet.go
Normal file
314
pkg/core/mpt/billet.go
Normal file
|
@ -0,0 +1,314 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrRestoreFailed is returned when replacing HashNode by its "unhashed"
|
||||
// candidate fails.
|
||||
ErrRestoreFailed = errors.New("failed to restore MPT node")
|
||||
errStop = errors.New("stop condition is met")
|
||||
)
|
||||
|
||||
// Billet is a part of MPT trie with missing hash nodes that need to be restored.
|
||||
// Billet is based on the following assumptions:
|
||||
// 1. Refcount can only be incremented (we don't change MPT structure during restore,
|
||||
// thus don't need to decrease refcount).
|
||||
// 2. Each time the part of Billet is completely restored, it is collapsed into
|
||||
// HashNode.
|
||||
// 3. Pair (node, path) must be restored only once. It's a duty of MPT pool to manage
|
||||
// MPT paths in order to provide this assumption.
|
||||
type Billet struct {
|
||||
Store *storage.MemCachedStore
|
||||
|
||||
root Node
|
||||
refcountEnabled bool
|
||||
}
|
||||
|
||||
// NewBillet returns new billet for MPT trie restoring. It accepts a MemCachedStore
|
||||
// to decouple storage errors from logic errors so that all storage errors are
|
||||
// processed during `store.Persist()` at the caller. This also has the benefit,
|
||||
// that every `Put` can be considered an atomic operation.
|
||||
func NewBillet(rootHash util.Uint256, enableRefCount bool, store *storage.MemCachedStore) *Billet {
|
||||
return &Billet{
|
||||
Store: store,
|
||||
root: NewHashNode(rootHash),
|
||||
refcountEnabled: enableRefCount,
|
||||
}
|
||||
}
|
||||
|
||||
// RestoreHashNode replaces HashNode located at the provided path by the specified Node
|
||||
// and stores it. It also maintains MPT as small as possible by collapsing those parts
|
||||
// of MPT that have been completely restored.
|
||||
func (b *Billet) RestoreHashNode(path []byte, node Node) error {
|
||||
if _, ok := node.(*HashNode); ok {
|
||||
return fmt.Errorf("%w: unable to restore node into HashNode", ErrRestoreFailed)
|
||||
}
|
||||
if _, ok := node.(EmptyNode); ok {
|
||||
return fmt.Errorf("%w: unable to restore node into EmptyNode", ErrRestoreFailed)
|
||||
}
|
||||
r, err := b.putIntoNode(b.root, path, node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.root = r
|
||||
|
||||
// If it's a leaf, then put into temporary contract storage.
|
||||
if leaf, ok := node.(*LeafNode); ok {
|
||||
k := append([]byte{byte(storage.STTempStorage)}, fromNibbles(path)...)
|
||||
_ = b.Store.Put(k, leaf.value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// putIntoNode puts val with provided path inside curr and returns updated node.
|
||||
// Reference counters are updated for both curr and returned value.
|
||||
func (b *Billet) putIntoNode(curr Node, path []byte, val Node) (Node, error) {
|
||||
switch n := curr.(type) {
|
||||
case *LeafNode:
|
||||
return b.putIntoLeaf(n, path, val)
|
||||
case *BranchNode:
|
||||
return b.putIntoBranch(n, path, val)
|
||||
case *ExtensionNode:
|
||||
return b.putIntoExtension(n, path, val)
|
||||
case *HashNode:
|
||||
return b.putIntoHash(n, path, val)
|
||||
case EmptyNode:
|
||||
return nil, fmt.Errorf("%w: can't modify EmptyNode during restore", ErrRestoreFailed)
|
||||
default:
|
||||
panic("invalid MPT node type")
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Billet) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) {
|
||||
if len(path) != 0 {
|
||||
return nil, fmt.Errorf("%w: can't modify LeafNode during restore", ErrRestoreFailed)
|
||||
}
|
||||
if curr.Hash() != val.Hash() {
|
||||
return nil, fmt.Errorf("%w: bad Leaf node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
|
||||
}
|
||||
// Once Leaf node is restored, it will be collapsed into HashNode forever, so
|
||||
// there shouldn't be such situation when we try to restore Leaf node.
|
||||
panic("bug: can't restore LeafNode")
|
||||
}
|
||||
|
||||
func (b *Billet) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) {
|
||||
if len(path) == 0 && curr.Hash().Equals(val.Hash()) {
|
||||
// This node has already been restored, so it's an MPT pool duty to avoid
|
||||
// duplicating restore requests.
|
||||
panic("bug: can't perform restoring of BranchNode twice")
|
||||
}
|
||||
i, path := splitPath(path)
|
||||
r, err := b.putIntoNode(curr.Children[i], path, val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
curr.Children[i] = r
|
||||
return b.tryCollapseBranch(curr), nil
|
||||
}
|
||||
|
||||
func (b *Billet) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) {
|
||||
if len(path) == 0 {
|
||||
if curr.Hash() != val.Hash() {
|
||||
return nil, fmt.Errorf("%w: bad Extension node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
|
||||
}
|
||||
// This node has already been restored, so it's an MPT pool duty to avoid
|
||||
// duplicating restore requests.
|
||||
panic("bug: can't perform restoring of ExtensionNode twice")
|
||||
}
|
||||
if !bytes.HasPrefix(path, curr.key) {
|
||||
return nil, fmt.Errorf("%w: can't modify ExtensionNode during restore", ErrRestoreFailed)
|
||||
}
|
||||
|
||||
r, err := b.putIntoNode(curr.next, path[len(curr.key):], val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
curr.next = r
|
||||
return b.tryCollapseExtension(curr), nil
|
||||
}
|
||||
|
||||
func (b *Billet) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) {
|
||||
// Once a part of MPT Billet is completely restored, it will be collapsed forever, so
|
||||
// it's an MPT pool duty to avoid duplicating restore requests.
|
||||
if len(path) != 0 {
|
||||
return nil, fmt.Errorf("%w: node has already been collapsed", ErrRestoreFailed)
|
||||
}
|
||||
|
||||
// `curr` hash node can be either of
|
||||
// 1) saved in storage (i.g. if we've already restored node with the same hash from the
|
||||
// other part of MPT), so just add it to local in-memory MPT.
|
||||
// 2) missing from the storage. It's OK because we're syncing MPT state, and the purpose
|
||||
// is to store missing hash nodes.
|
||||
// both cases are OK, but we still need to validate `val` against `curr`.
|
||||
if val.Hash() != curr.Hash() {
|
||||
return nil, fmt.Errorf("%w: can't restore HashNode: expected and actual hashes mismatch (%s vs %s)", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
|
||||
}
|
||||
|
||||
if curr.Collapsed {
|
||||
// This node has already been restored and collapsed, so it's an MPT pool duty to avoid
|
||||
// duplicating restore requests.
|
||||
panic("bug: can't perform restoring of collapsed node")
|
||||
}
|
||||
|
||||
// We also need to increment refcount in both cases. That's the only place where refcount
|
||||
// is changed during restore process. Also flush right now, because sync process can be
|
||||
// interrupted at any time.
|
||||
b.incrementRefAndStore(val.Hash(), val.Bytes())
|
||||
|
||||
if val.Type() == LeafT {
|
||||
return b.tryCollapseLeaf(val.(*LeafNode)), nil
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (b *Billet) incrementRefAndStore(h util.Uint256, bs []byte) {
|
||||
key := makeStorageKey(h.BytesBE())
|
||||
if b.refcountEnabled {
|
||||
var (
|
||||
err error
|
||||
data []byte
|
||||
cnt int32
|
||||
)
|
||||
// An item may already be in store.
|
||||
data, err = b.Store.Get(key)
|
||||
if err == nil {
|
||||
cnt = int32(binary.LittleEndian.Uint32(data[len(data)-4:]))
|
||||
}
|
||||
cnt++
|
||||
if len(data) == 0 {
|
||||
data = append(bs, 0, 0, 0, 0)
|
||||
}
|
||||
binary.LittleEndian.PutUint32(data[len(data)-4:], uint32(cnt))
|
||||
_ = b.Store.Put(key, data)
|
||||
} else {
|
||||
_ = b.Store.Put(key, bs)
|
||||
}
|
||||
}
|
||||
|
||||
// Traverse traverses MPT nodes (pre-order) starting from the billet root down
|
||||
// to its children calling `process` for each serialised node until true is
|
||||
// returned from `process` function. It also replaces all HashNodes to their
|
||||
// "unhashed" counterparts until the stop condition is satisfied.
|
||||
func (b *Billet) Traverse(process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) error {
|
||||
r, err := b.traverse(b.root, process, ignoreStorageErr)
|
||||
if err != nil && !errors.Is(err, errStop) {
|
||||
return err
|
||||
}
|
||||
b.root = r
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) (Node, error) {
|
||||
if _, ok := curr.(EmptyNode); ok {
|
||||
// We're not interested in EmptyNodes, and they do not affect the
|
||||
// traversal process, thus remain them untouched.
|
||||
return curr, nil
|
||||
}
|
||||
if hn, ok := curr.(*HashNode); ok {
|
||||
r, err := b.GetFromStore(hn.Hash())
|
||||
if err != nil {
|
||||
if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) {
|
||||
return hn, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return b.traverse(r, process, ignoreStorageErr)
|
||||
}
|
||||
bytes := slice.Copy(curr.Bytes())
|
||||
if process(curr, bytes) {
|
||||
return curr, errStop
|
||||
}
|
||||
switch n := curr.(type) {
|
||||
case *LeafNode:
|
||||
return b.tryCollapseLeaf(n), nil
|
||||
case *BranchNode:
|
||||
for i := range n.Children {
|
||||
r, err := b.traverse(n.Children[i], process, ignoreStorageErr)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errStop) {
|
||||
return nil, err
|
||||
}
|
||||
n.Children[i] = r
|
||||
return n, err
|
||||
}
|
||||
n.Children[i] = r
|
||||
}
|
||||
return b.tryCollapseBranch(n), nil
|
||||
case *ExtensionNode:
|
||||
r, err := b.traverse(n.next, process, ignoreStorageErr)
|
||||
if err != nil && !errors.Is(err, errStop) {
|
||||
return nil, err
|
||||
}
|
||||
n.next = r
|
||||
return b.tryCollapseExtension(n), err
|
||||
default:
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Billet) tryCollapseLeaf(curr *LeafNode) Node {
|
||||
// Leaf can always be collapsed.
|
||||
res := NewHashNode(curr.Hash())
|
||||
res.Collapsed = true
|
||||
return res
|
||||
}
|
||||
|
||||
func (b *Billet) tryCollapseExtension(curr *ExtensionNode) Node {
|
||||
if !(curr.next.Type() == HashT && curr.next.(*HashNode).Collapsed) {
|
||||
return curr
|
||||
}
|
||||
res := NewHashNode(curr.Hash())
|
||||
res.Collapsed = true
|
||||
return res
|
||||
}
|
||||
|
||||
func (b *Billet) tryCollapseBranch(curr *BranchNode) Node {
|
||||
canCollapse := true
|
||||
for i := 0; i < childrenCount; i++ {
|
||||
if curr.Children[i].Type() == EmptyT {
|
||||
continue
|
||||
}
|
||||
if curr.Children[i].Type() == HashT && curr.Children[i].(*HashNode).Collapsed {
|
||||
continue
|
||||
}
|
||||
canCollapse = false
|
||||
break
|
||||
}
|
||||
if !canCollapse {
|
||||
return curr
|
||||
}
|
||||
res := NewHashNode(curr.Hash())
|
||||
res.Collapsed = true
|
||||
return res
|
||||
}
|
||||
|
||||
// GetFromStore returns MPT node from the storage.
|
||||
func (b *Billet) GetFromStore(h util.Uint256) (Node, error) {
|
||||
data, err := b.Store.Get(makeStorageKey(h.BytesBE()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var n NodeObject
|
||||
r := io.NewBinReaderFromBuf(data)
|
||||
n.DecodeBinary(r)
|
||||
if r.Err != nil {
|
||||
return nil, r.Err
|
||||
}
|
||||
|
||||
if b.refcountEnabled {
|
||||
data = data[:len(data)-4]
|
||||
}
|
||||
n.Node.(flushedNode).setCache(data, h)
|
||||
return n.Node, nil
|
||||
}
|
211
pkg/core/mpt/billet_test.go
Normal file
211
pkg/core/mpt/billet_test.go
Normal file
|
@ -0,0 +1,211 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBillet_RestoreHashNode(t *testing.T) {
|
||||
check := func(t *testing.T, tr *Billet, expectedRoot Node, expectedNode Node, expectedRefCount uint32) {
|
||||
_ = expectedRoot.Hash()
|
||||
_ = tr.root.Hash()
|
||||
require.Equal(t, expectedRoot, tr.root)
|
||||
expectedBytes, err := tr.Store.Get(makeStorageKey(expectedNode.Hash().BytesBE()))
|
||||
if expectedRefCount != 0 {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedRefCount, binary.LittleEndian.Uint32(expectedBytes[len(expectedBytes)-4:]))
|
||||
} else {
|
||||
require.True(t, errors.Is(err, storage.ErrKeyNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("parent is Extension", func(t *testing.T) {
|
||||
t.Run("restore Branch", func(t *testing.T) {
|
||||
b := NewBranchNode()
|
||||
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0xAB, 0xCD}))
|
||||
b.Children[5] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0xAB, 0xDE}))
|
||||
path := toNibbles([]byte{0xAC})
|
||||
e := NewExtensionNode(path, NewHashNode(b.Hash()))
|
||||
tr := NewBillet(e.Hash(), true, newTestStore())
|
||||
tr.root = e
|
||||
|
||||
// OK
|
||||
n := new(NodeObject)
|
||||
n.DecodeBinary(io.NewBinReaderFromBuf(b.Bytes()))
|
||||
require.NoError(t, tr.RestoreHashNode(path, n.Node))
|
||||
expected := NewExtensionNode(path, n.Node)
|
||||
check(t, tr, expected, n.Node, 1)
|
||||
|
||||
// One more time (already restored) => panic expected, no refcount changes
|
||||
require.Panics(t, func() {
|
||||
_ = tr.RestoreHashNode(path, n.Node)
|
||||
})
|
||||
check(t, tr, expected, n.Node, 1)
|
||||
|
||||
// Same path, but wrong hash => error expected, no refcount changes
|
||||
require.True(t, errors.Is(tr.RestoreHashNode(path, NewBranchNode()), ErrRestoreFailed))
|
||||
check(t, tr, expected, n.Node, 1)
|
||||
|
||||
// New path (changes in the MPT structure are not allowed) => error expected, no refcount changes
|
||||
require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAB}), n.Node), ErrRestoreFailed))
|
||||
check(t, tr, expected, n.Node, 1)
|
||||
})
|
||||
|
||||
t.Run("restore Leaf", func(t *testing.T) {
|
||||
l := NewLeafNode([]byte{0xAB, 0xCD})
|
||||
path := toNibbles([]byte{0xAC})
|
||||
e := NewExtensionNode(path, NewHashNode(l.Hash()))
|
||||
tr := NewBillet(e.Hash(), true, newTestStore())
|
||||
tr.root = e
|
||||
|
||||
// OK
|
||||
require.NoError(t, tr.RestoreHashNode(path, l))
|
||||
expected := NewHashNode(e.Hash()) // leaf should be collapsed immediately => extension should also be collapsed
|
||||
expected.Collapsed = true
|
||||
check(t, tr, expected, l, 1)
|
||||
|
||||
// One more time (already restored and collapsed) => error expected, no refcount changes
|
||||
require.Error(t, tr.RestoreHashNode(path, l))
|
||||
check(t, tr, expected, l, 1)
|
||||
|
||||
// Same path, but wrong hash => error expected, no refcount changes
|
||||
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
|
||||
check(t, tr, expected, l, 1)
|
||||
|
||||
// New path (changes in the MPT structure are not allowed) => error expected, no refcount changes
|
||||
require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAB}), l), ErrRestoreFailed))
|
||||
check(t, tr, expected, l, 1)
|
||||
})
|
||||
|
||||
t.Run("restore Hash", func(t *testing.T) {
|
||||
h := NewHashNode(util.Uint256{1, 2, 3})
|
||||
path := toNibbles([]byte{0xAC})
|
||||
e := NewExtensionNode(path, h)
|
||||
tr := NewBillet(e.Hash(), true, newTestStore())
|
||||
tr.root = e
|
||||
|
||||
// no-op
|
||||
require.True(t, errors.Is(tr.RestoreHashNode(path, h), ErrRestoreFailed))
|
||||
check(t, tr, e, h, 0)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("parent is Leaf", func(t *testing.T) {
|
||||
l := NewLeafNode([]byte{0xAB, 0xCD})
|
||||
path := []byte{}
|
||||
tr := NewBillet(l.Hash(), true, newTestStore())
|
||||
tr.root = l
|
||||
|
||||
// Already restored => panic expected
|
||||
require.Panics(t, func() {
|
||||
_ = tr.RestoreHashNode(path, l)
|
||||
})
|
||||
|
||||
// Same path, but wrong hash => error expected, no refcount changes
|
||||
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
|
||||
|
||||
// Non-nil path, but MPT structure can't be changed => error expected, no refcount changes
|
||||
require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAC}), NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
|
||||
})
|
||||
|
||||
t.Run("parent is Branch", func(t *testing.T) {
|
||||
t.Run("middle child", func(t *testing.T) {
|
||||
l1 := NewLeafNode([]byte{0xAB, 0xCD})
|
||||
l2 := NewLeafNode([]byte{0xAB, 0xDE})
|
||||
b := NewBranchNode()
|
||||
b.Children[5] = NewHashNode(l1.Hash())
|
||||
b.Children[lastChild] = NewHashNode(l2.Hash())
|
||||
tr := NewBillet(b.Hash(), true, newTestStore())
|
||||
tr.root = b
|
||||
|
||||
// OK
|
||||
path := []byte{0x05}
|
||||
require.NoError(t, tr.RestoreHashNode(path, l1))
|
||||
check(t, tr, b, l1, 1)
|
||||
|
||||
// One more time (already restored) => panic expected.
|
||||
// It's an MPT pool duty to avoid such situations during real restore process.
|
||||
require.Panics(t, func() {
|
||||
_ = tr.RestoreHashNode(path, l1)
|
||||
})
|
||||
// No refcount changes expected.
|
||||
check(t, tr, b, l1, 1)
|
||||
|
||||
// Same path, but wrong hash => error expected, no refcount changes
|
||||
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAD})), ErrRestoreFailed))
|
||||
check(t, tr, b, l1, 1)
|
||||
|
||||
// New path pointing to the empty HashNode (changes in the MPT structure are not allowed) => error expected, no refcount changes
|
||||
require.True(t, errors.Is(tr.RestoreHashNode([]byte{0x01}, l1), ErrRestoreFailed))
|
||||
check(t, tr, b, l1, 1)
|
||||
})
|
||||
|
||||
t.Run("last child", func(t *testing.T) {
|
||||
l1 := NewLeafNode([]byte{0xAB, 0xCD})
|
||||
l2 := NewLeafNode([]byte{0xAB, 0xDE})
|
||||
b := NewBranchNode()
|
||||
b.Children[5] = NewHashNode(l1.Hash())
|
||||
b.Children[lastChild] = NewHashNode(l2.Hash())
|
||||
tr := NewBillet(b.Hash(), true, newTestStore())
|
||||
tr.root = b
|
||||
|
||||
// OK
|
||||
path := []byte{}
|
||||
require.NoError(t, tr.RestoreHashNode(path, l2))
|
||||
check(t, tr, b, l2, 1)
|
||||
|
||||
// One more time (already restored) => panic expected.
|
||||
// It's an MPT pool duty to avoid such situations during real restore process.
|
||||
require.Panics(t, func() {
|
||||
_ = tr.RestoreHashNode(path, l2)
|
||||
})
|
||||
// No refcount changes expected.
|
||||
check(t, tr, b, l2, 1)
|
||||
|
||||
// Same path, but wrong hash => error expected, no refcount changes
|
||||
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAD})), ErrRestoreFailed))
|
||||
check(t, tr, b, l2, 1)
|
||||
})
|
||||
|
||||
t.Run("two children with same hash", func(t *testing.T) {
|
||||
l := NewLeafNode([]byte{0xAB, 0xCD})
|
||||
b := NewBranchNode()
|
||||
// two same hashnodes => leaf's refcount expected to be 2 in the end.
|
||||
b.Children[3] = NewHashNode(l.Hash())
|
||||
b.Children[4] = NewHashNode(l.Hash())
|
||||
tr := NewBillet(b.Hash(), true, newTestStore())
|
||||
tr.root = b
|
||||
|
||||
// OK
|
||||
require.NoError(t, tr.RestoreHashNode([]byte{0x03}, l))
|
||||
expected := b
|
||||
expected.Children[3].(*HashNode).Collapsed = true
|
||||
check(t, tr, b, l, 1)
|
||||
|
||||
// Restore another node with the same hash => no error expected, refcount should be incremented.
|
||||
// Branch node should be collapsed.
|
||||
require.NoError(t, tr.RestoreHashNode([]byte{0x04}, l))
|
||||
res := NewHashNode(b.Hash())
|
||||
res.Collapsed = true
|
||||
check(t, tr, res, l, 2)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("parent is Hash", func(t *testing.T) {
|
||||
l := NewLeafNode([]byte{0xAB, 0xCD})
|
||||
b := NewBranchNode()
|
||||
b.Children[3] = NewHashNode(l.Hash())
|
||||
b.Children[4] = NewHashNode(l.Hash())
|
||||
tr := NewBillet(b.Hash(), true, newTestStore())
|
||||
|
||||
// Should fail, because if it's a hash node with non-empty path, then the node
|
||||
// has already been collapsed.
|
||||
require.Error(t, tr.RestoreHashNode([]byte{0x03}, l))
|
||||
})
|
||||
}
|
|
@ -89,6 +89,12 @@ func (b *BranchNode) UnmarshalJSON(data []byte) error {
|
|||
return errors.New("expected branch node")
|
||||
}
|
||||
|
||||
// Clone implements Node interface.
|
||||
func (b *BranchNode) Clone() Node {
|
||||
res := *b
|
||||
return &res
|
||||
}
|
||||
|
||||
// splitPath splits path for a branch node.
|
||||
func splitPath(path []byte) (byte, []byte) {
|
||||
if len(path) != 0 {
|
||||
|
|
|
@ -54,3 +54,6 @@ func (e EmptyNode) Type() NodeType {
|
|||
func (e EmptyNode) Bytes() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clone implements Node interface.
|
||||
func (EmptyNode) Clone() Node { return EmptyNode{} }
|
||||
|
|
|
@ -98,3 +98,9 @@ func (e *ExtensionNode) UnmarshalJSON(data []byte) error {
|
|||
}
|
||||
return errors.New("expected extension node")
|
||||
}
|
||||
|
||||
// Clone implements Node interface.
|
||||
func (e *ExtensionNode) Clone() Node {
|
||||
res := *e
|
||||
return &res
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
// HashNode represents MPT's hash node.
|
||||
type HashNode struct {
|
||||
BaseNode
|
||||
Collapsed bool
|
||||
}
|
||||
|
||||
var _ Node = (*HashNode)(nil)
|
||||
|
@ -76,3 +77,10 @@ func (h *HashNode) UnmarshalJSON(data []byte) error {
|
|||
}
|
||||
return errors.New("expected hash node")
|
||||
}
|
||||
|
||||
// Clone implements Node interface.
|
||||
func (h *HashNode) Clone() Node {
|
||||
res := *h
|
||||
res.Collapsed = false
|
||||
return &res
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package mpt
|
||||
|
||||
import "github.com/nspcc-dev/neo-go/pkg/util"
|
||||
|
||||
// lcp returns longest common prefix of a and b.
|
||||
// Note: it does no allocations.
|
||||
func lcp(a, b []byte) []byte {
|
||||
|
@ -40,3 +42,45 @@ func toNibbles(path []byte) []byte {
|
|||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// fromNibbles performs operation opposite to toNibbles and does no path validity checks.
|
||||
func fromNibbles(path []byte) []byte {
|
||||
result := make([]byte, len(path)/2)
|
||||
for i := range result {
|
||||
result[i] = path[2*i]<<4 + path[2*i+1]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetChildrenPaths returns a set of paths to node's children who are non-empty HashNodes
|
||||
// based on the node's path.
|
||||
func GetChildrenPaths(path []byte, node Node) map[util.Uint256][][]byte {
|
||||
res := make(map[util.Uint256][][]byte)
|
||||
switch n := node.(type) {
|
||||
case *LeafNode, *HashNode, EmptyNode:
|
||||
return nil
|
||||
case *BranchNode:
|
||||
for i, child := range n.Children {
|
||||
if child.Type() == HashT {
|
||||
cPath := make([]byte, len(path), len(path)+1)
|
||||
copy(cPath, path)
|
||||
if i != lastChild {
|
||||
cPath = append(cPath, byte(i))
|
||||
}
|
||||
paths := res[child.Hash()]
|
||||
paths = append(paths, cPath)
|
||||
res[child.Hash()] = paths
|
||||
}
|
||||
}
|
||||
case *ExtensionNode:
|
||||
if n.next.Type() == HashT {
|
||||
cPath := make([]byte, len(path)+len(n.key))
|
||||
copy(cPath, path)
|
||||
copy(cPath[len(path):], n.key)
|
||||
res[n.next.Hash()] = [][]byte{cPath}
|
||||
}
|
||||
default:
|
||||
panic("unknown Node type")
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
|
67
pkg/core/mpt/helpers_test.go
Normal file
67
pkg/core/mpt/helpers_test.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToNibblesFromNibbles(t *testing.T) {
|
||||
check := func(t *testing.T, expected []byte) {
|
||||
actual := fromNibbles(toNibbles(expected))
|
||||
require.Equal(t, expected, actual)
|
||||
}
|
||||
t.Run("empty path", func(t *testing.T) {
|
||||
check(t, []byte{})
|
||||
})
|
||||
t.Run("non-empty path", func(t *testing.T) {
|
||||
check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetChildrenPaths(t *testing.T) {
|
||||
h1 := NewHashNode(util.Uint256{1, 2, 3})
|
||||
h2 := NewHashNode(util.Uint256{4, 5, 6})
|
||||
h3 := NewHashNode(util.Uint256{7, 8, 9})
|
||||
l := NewLeafNode([]byte{1, 2, 3})
|
||||
ext1 := NewExtensionNode([]byte{8, 9}, h1)
|
||||
ext2 := NewExtensionNode([]byte{7, 6}, l)
|
||||
branch := NewBranchNode()
|
||||
branch.Children[3] = h1
|
||||
branch.Children[5] = l
|
||||
branch.Children[6] = h1 // 3-th and 6-th children have the same hash
|
||||
branch.Children[7] = h3
|
||||
branch.Children[lastChild] = h2
|
||||
testCases := map[string]struct {
|
||||
node Node
|
||||
expected map[util.Uint256][][]byte
|
||||
}{
|
||||
"Hash": {h1, nil},
|
||||
"Leaf": {l, nil},
|
||||
"Extension with next Hash": {ext1, map[util.Uint256][][]byte{h1.Hash(): {ext1.key}}},
|
||||
"Extension with next non-Hash": {ext2, map[util.Uint256][][]byte{}},
|
||||
"Branch": {branch, map[util.Uint256][][]byte{
|
||||
h1.Hash(): {{0x03}, {0x06}},
|
||||
h2.Hash(): {{}},
|
||||
h3.Hash(): {{0x07}},
|
||||
}},
|
||||
}
|
||||
parentPath := []byte{4, 5, 6}
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
require.Equal(t, testCase.expected, GetChildrenPaths([]byte{}, testCase.node))
|
||||
if testCase.expected != nil {
|
||||
expectedWithPrefix := make(map[util.Uint256][][]byte, len(testCase.expected))
|
||||
for h, paths := range testCase.expected {
|
||||
var res [][]byte
|
||||
for _, path := range paths {
|
||||
res = append(res, append(parentPath, path...))
|
||||
}
|
||||
expectedWithPrefix[h] = res
|
||||
}
|
||||
require.Equal(t, expectedWithPrefix, GetChildrenPaths(parentPath, testCase.node))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -77,3 +77,9 @@ func (n *LeafNode) UnmarshalJSON(data []byte) error {
|
|||
}
|
||||
return errors.New("expected leaf node")
|
||||
}
|
||||
|
||||
// Clone implements Node interface.
|
||||
func (n *LeafNode) Clone() Node {
|
||||
res := *n
|
||||
return &res
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ type Node interface {
|
|||
json.Marshaler
|
||||
json.Unmarshaler
|
||||
Size() int
|
||||
Clone() Node
|
||||
BaseNodeIface
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newProofTrie(t *testing.T) *Trie {
|
||||
func newProofTrie(t *testing.T, missingHashNode bool) *Trie {
|
||||
l := NewLeafNode([]byte("somevalue"))
|
||||
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
|
||||
l2 := NewLeafNode([]byte("invalid"))
|
||||
|
@ -20,11 +20,14 @@ func newProofTrie(t *testing.T) *Trie {
|
|||
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
|
||||
tr.putToStore(l)
|
||||
tr.putToStore(e)
|
||||
if !missingHashNode {
|
||||
tr.putToStore(l2)
|
||||
}
|
||||
return tr
|
||||
}
|
||||
|
||||
func TestTrie_GetProof(t *testing.T) {
|
||||
tr := newProofTrie(t)
|
||||
tr := newProofTrie(t, true)
|
||||
|
||||
t.Run("MissingKey", func(t *testing.T) {
|
||||
_, err := tr.GetProof([]byte{0x12})
|
||||
|
@ -43,7 +46,7 @@ func TestTrie_GetProof(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestVerifyProof(t *testing.T) {
|
||||
tr := newProofTrie(t)
|
||||
tr := newProofTrie(t, true)
|
||||
|
||||
t.Run("Simple", func(t *testing.T) {
|
||||
proof, err := tr.GetProof([]byte{0x12, 0x32})
|
||||
|
|
|
@ -353,3 +353,8 @@ func (s *Designate) getRole(item stackitem.Item) (noderoles.Role, bool) {
|
|||
u := bi.Uint64()
|
||||
return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u))
|
||||
}
|
||||
|
||||
// InitializeCache invalidates native Designate cache.
|
||||
func (s *Designate) InitializeCache() {
|
||||
s.rolesChangedFlag.Store(true)
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
@ -114,6 +115,66 @@ func (s *Module) Init(height uint32, enableRefCount bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// CleanStorage removes all MPT-related data from the storage (MPT nodes, validated stateroots)
|
||||
// except local stateroot for the current height and GC flag. This method is aimed to clean
|
||||
// outdated MPT data before state sync process can be started.
|
||||
// Note: this method is aimed to be called for genesis block only, an error is returned otherwice.
|
||||
func (s *Module) CleanStorage() error {
|
||||
if s.localHeight.Load() != 0 {
|
||||
return fmt.Errorf("can't clean MPT data for non-genesis block: expected local stateroot height 0, got %d", s.localHeight.Load())
|
||||
}
|
||||
gcKey := []byte{byte(storage.DataMPT), prefixGC}
|
||||
gcVal, err := s.Store.Get(gcKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get GC flag: %w", err)
|
||||
}
|
||||
//
|
||||
b := s.Store.Batch()
|
||||
s.Store.Seek([]byte{byte(storage.DataMPT)}, func(k, _ []byte) {
|
||||
// Must copy here, #1468.
|
||||
key := slice.Copy(k)
|
||||
b.Delete(key)
|
||||
})
|
||||
err = s.Store.PutBatch(b)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove outdated MPT-reated items: %w", err)
|
||||
}
|
||||
err = s.Store.Put(gcKey, gcVal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store GC flag: %w", err)
|
||||
}
|
||||
currentLocal := s.currentLocal.Load().(util.Uint256)
|
||||
if !currentLocal.Equals(util.Uint256{}) {
|
||||
err := s.addLocalStateRoot(s.Store, &state.MPTRoot{
|
||||
Index: s.localHeight.Load(),
|
||||
Root: currentLocal,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store current local stateroot: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// JumpToState performs jump to the state specified by given stateroot index.
|
||||
func (s *Module) JumpToState(sr *state.MPTRoot, enableRefCount bool) error {
|
||||
if err := s.addLocalStateRoot(s.Store, sr); err != nil {
|
||||
return fmt.Errorf("failed to store local state root: %w", err)
|
||||
}
|
||||
|
||||
data := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(data, sr.Index)
|
||||
if err := s.Store.Put([]byte{byte(storage.DataMPT), prefixValidated}, data); err != nil {
|
||||
return fmt.Errorf("failed to store validated height: %w", err)
|
||||
}
|
||||
s.validatedHeight.Store(sr.Index)
|
||||
|
||||
s.currentLocal.Store(sr.Root)
|
||||
s.localHeight.Store(sr.Index)
|
||||
s.mpt = mpt.NewTrie(mpt.NewHashNode(sr.Root), enableRefCount, s.Store)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMPTBatch updates using provided batch.
|
||||
func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) {
|
||||
mpt := *s.mpt
|
||||
|
|
479
pkg/core/statesync/module.go
Normal file
479
pkg/core/statesync/module.go
Normal file
|
@ -0,0 +1,479 @@
|
|||
/*
|
||||
Package statesync implements module for the P2P state synchronisation process. The
|
||||
module manages state synchronisation for non-archival nodes which are joining the
|
||||
network and don't have the ability to resync from the genesis block.
|
||||
|
||||
Given the currently available state synchronisation point P, sate sync process
|
||||
includes the following stages:
|
||||
|
||||
1. Fetching headers starting from height 0 up to P+1.
|
||||
2. Fetching MPT nodes for height P stating from the corresponding state root.
|
||||
3. Fetching blocks starting from height P-MaxTraceableBlocks (or 0) up to P.
|
||||
|
||||
Steps 2 and 3 are being performed in parallel. Once all the data are collected
|
||||
and stored in the db, an atomic state jump is occurred to the state sync point P.
|
||||
Further node operation process is performed using standard sync mechanism until
|
||||
the node reaches synchronised state.
|
||||
*/
|
||||
package statesync
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// stateSyncStage is a type of state synchronisation stage.
|
||||
type stateSyncStage uint8
|
||||
|
||||
const (
|
||||
// inactive means that state exchange is disabled by the protocol configuration.
|
||||
// Can't be combined with other states.
|
||||
inactive stateSyncStage = 1 << iota
|
||||
// none means that state exchange is enabled in the configuration, but
|
||||
// initialisation of the state sync module wasn't yet performed, i.e.
|
||||
// (*Module).Init wasn't called. Can't be combined with other states.
|
||||
none
|
||||
// initialized means that (*Module).Init was called, but other sync stages
|
||||
// are not yet reached (i.e. that headers are requested, but not yet fetched).
|
||||
// Can't be combined with other states.
|
||||
initialized
|
||||
// headersSynced means that headers for the current state sync point are fetched.
|
||||
// May be combined with mptSynced and/or blocksSynced.
|
||||
headersSynced
|
||||
// mptSynced means that MPT nodes for the current state sync point are fetched.
|
||||
// Always combined with headersSynced; may be combined with blocksSynced.
|
||||
mptSynced
|
||||
// blocksSynced means that blocks up to the current state sync point are stored.
|
||||
// Always combined with headersSynced; may be combined with mptSynced.
|
||||
blocksSynced
|
||||
)
|
||||
|
||||
// Module represents state sync module and aimed to gather state-related data to
|
||||
// perform an atomic state jump.
|
||||
type Module struct {
|
||||
lock sync.RWMutex
|
||||
log *zap.Logger
|
||||
|
||||
// syncPoint is the state synchronisation point P we're currently working against.
|
||||
syncPoint uint32
|
||||
// syncStage is the stage of the sync process.
|
||||
syncStage stateSyncStage
|
||||
// syncInterval is the delta between two adjacent state sync points.
|
||||
syncInterval uint32
|
||||
// blockHeight is the index of the latest stored block.
|
||||
blockHeight uint32
|
||||
|
||||
dao *dao.Simple
|
||||
bc blockchainer.Blockchainer
|
||||
mptpool *Pool
|
||||
|
||||
billet *mpt.Billet
|
||||
|
||||
jumpCallback func(p uint32) error
|
||||
}
|
||||
|
||||
// NewModule returns new instance of statesync module.
|
||||
func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple, jumpCallback func(p uint32) error) *Module {
|
||||
if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) {
|
||||
return &Module{
|
||||
dao: s,
|
||||
bc: bc,
|
||||
syncStage: inactive,
|
||||
}
|
||||
}
|
||||
return &Module{
|
||||
dao: s,
|
||||
bc: bc,
|
||||
log: log,
|
||||
syncInterval: uint32(bc.GetConfig().StateSyncInterval),
|
||||
mptpool: NewPool(),
|
||||
syncStage: none,
|
||||
jumpCallback: jumpCallback,
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes state sync module for the current chain's height with given
|
||||
// callback for MPT nodes requests.
|
||||
func (s *Module) Init(currChainHeight uint32) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.syncStage != none {
|
||||
return errors.New("already initialized or inactive")
|
||||
}
|
||||
|
||||
p := (currChainHeight / s.syncInterval) * s.syncInterval
|
||||
if p < 2*s.syncInterval {
|
||||
// chain is too low to start state exchange process, use the standard sync mechanism
|
||||
s.syncStage = inactive
|
||||
return nil
|
||||
}
|
||||
pOld, err := s.dao.GetStateSyncPoint()
|
||||
if err == nil && pOld >= p-s.syncInterval {
|
||||
// old point is still valid, so try to resync states for this point.
|
||||
p = pOld
|
||||
} else {
|
||||
if s.bc.BlockHeight() > p-2*s.syncInterval {
|
||||
// chain has already been synchronised up to old state sync point and regular blocks processing was started.
|
||||
// Current block height is enough to start regular blocks processing.
|
||||
s.syncStage = inactive
|
||||
return nil
|
||||
}
|
||||
if err == nil {
|
||||
// pOld was found, it is outdated, and chain wasn't completely synchronised for pOld. Need to drop the db.
|
||||
return fmt.Errorf("state sync point %d is found in the storage, "+
|
||||
"but sync process wasn't completed and point is outdated. Please, drop the database manually and restart the node to run state sync process", pOld)
|
||||
}
|
||||
if s.bc.BlockHeight() != 0 {
|
||||
// pOld wasn't found, but blocks processing was started in a regular manner and latest stored block is too outdated
|
||||
// to start regular blocks processing again. Need to drop the db.
|
||||
return fmt.Errorf("current chain's height is too low to start regular blocks processing from the oldest sync point %d. "+
|
||||
"Please, drop the database manually and restart the node to run state sync process", p-s.syncInterval)
|
||||
}
|
||||
|
||||
// We've reached this point, so chain has genesis block only. As far as we can't ruin
|
||||
// current chain's state until new state is completely fetched, outdated state-related data
|
||||
// will be removed from storage during (*Blockchain).jumpToState(...) execution.
|
||||
// All we need to do right now is to remove genesis-related MPT nodes.
|
||||
err = s.bc.GetStateModule().CleanStorage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove outdated MPT data from storage: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.syncPoint = p
|
||||
err = s.dao.PutStateSyncPoint(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store state synchronisation point %d: %w", p, err)
|
||||
}
|
||||
s.syncStage = initialized
|
||||
s.log.Info("try to sync state for the latest state synchronisation point",
|
||||
zap.Uint32("point", p),
|
||||
zap.Uint32("evaluated chain's blockHeight", currChainHeight))
|
||||
|
||||
return s.defineSyncStage()
|
||||
}
|
||||
|
||||
// defineSyncStage sequentially checks and sets sync state process stage after Module
|
||||
// initialization. It also performs initialization of MPT Billet if necessary.
|
||||
func (s *Module) defineSyncStage() error {
|
||||
// check headers sync stage first
|
||||
ltstHeaderHeight := s.bc.HeaderHeight()
|
||||
if ltstHeaderHeight > s.syncPoint {
|
||||
s.syncStage = headersSynced
|
||||
s.log.Info("headers are in sync",
|
||||
zap.Uint32("headerHeight", s.bc.HeaderHeight()))
|
||||
}
|
||||
|
||||
// check blocks sync stage
|
||||
s.blockHeight = s.getLatestSavedBlock(s.syncPoint)
|
||||
if s.blockHeight >= s.syncPoint {
|
||||
s.syncStage |= blocksSynced
|
||||
s.log.Info("blocks are in sync",
|
||||
zap.Uint32("blockHeight", s.blockHeight))
|
||||
}
|
||||
|
||||
// check MPT sync stage
|
||||
if s.blockHeight > s.syncPoint {
|
||||
s.syncStage |= mptSynced
|
||||
s.log.Info("MPT is in sync",
|
||||
zap.Uint32("stateroot height", s.bc.GetStateModule().CurrentLocalHeight()))
|
||||
} else if s.syncStage&headersSynced != 0 {
|
||||
header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint + 1)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get header to initialize MPT billet: %w", err)
|
||||
}
|
||||
s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store)
|
||||
s.log.Info("MPT billet initialized",
|
||||
zap.Uint32("height", s.syncPoint),
|
||||
zap.String("state root", header.PrevStateRoot.StringBE()))
|
||||
pool := NewPool()
|
||||
pool.Add(header.PrevStateRoot, []byte{})
|
||||
err = s.billet.Traverse(func(n mpt.Node, _ []byte) bool {
|
||||
nPaths, ok := pool.TryGet(n.Hash())
|
||||
if !ok {
|
||||
// if this situation occurs, then it's a bug in MPT pool or Traverse.
|
||||
panic("failed to get MPT node from the pool")
|
||||
}
|
||||
pool.Remove(n.Hash())
|
||||
childrenPaths := make(map[util.Uint256][][]byte)
|
||||
for _, path := range nPaths {
|
||||
nChildrenPaths := mpt.GetChildrenPaths(path, n)
|
||||
for hash, paths := range nChildrenPaths {
|
||||
childrenPaths[hash] = append(childrenPaths[hash], paths...) // it's OK to have duplicates, they'll be handled by mempool
|
||||
}
|
||||
}
|
||||
pool.Update(nil, childrenPaths)
|
||||
return false
|
||||
}, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to traverse MPT during initialization: %w", err)
|
||||
}
|
||||
s.mptpool.Update(nil, pool.GetAll())
|
||||
if s.mptpool.Count() == 0 {
|
||||
s.syncStage |= mptSynced
|
||||
s.log.Info("MPT is in sync",
|
||||
zap.Uint32("stateroot height", s.syncPoint))
|
||||
}
|
||||
}
|
||||
|
||||
if s.syncStage == headersSynced|blocksSynced|mptSynced {
|
||||
s.log.Info("state is in sync, starting regular blocks processing")
|
||||
s.syncStage = inactive
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLatestSavedBlock returns either current block index (if it's still relevant
|
||||
// to continue state sync process) or H-1 where H is the index of the earliest
|
||||
// block that should be saved next.
|
||||
func (s *Module) getLatestSavedBlock(p uint32) uint32 {
|
||||
var result uint32
|
||||
mtb := s.bc.GetConfig().MaxTraceableBlocks
|
||||
if p > mtb {
|
||||
result = p - mtb
|
||||
}
|
||||
storedH, err := s.dao.GetStateSyncCurrentBlockHeight()
|
||||
if err == nil && storedH > result {
|
||||
result = storedH
|
||||
}
|
||||
actualH := s.bc.BlockHeight()
|
||||
if actualH > result {
|
||||
result = actualH
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AddHeaders validates and adds specified headers to the chain.
|
||||
func (s *Module) AddHeaders(hdrs ...*block.Header) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.syncStage != initialized {
|
||||
return errors.New("headers were not requested")
|
||||
}
|
||||
|
||||
hdrsErr := s.bc.AddHeaders(hdrs...)
|
||||
if s.bc.HeaderHeight() > s.syncPoint {
|
||||
err := s.defineSyncStage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to define current sync stage: %w", err)
|
||||
}
|
||||
}
|
||||
return hdrsErr
|
||||
}
|
||||
|
||||
// AddBlock verifies and saves block skipping executable scripts.
|
||||
func (s *Module) AddBlock(block *block.Block) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.syncStage&headersSynced == 0 || s.syncStage&blocksSynced != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.blockHeight == s.syncPoint {
|
||||
return nil
|
||||
}
|
||||
expectedHeight := s.blockHeight + 1
|
||||
if expectedHeight != block.Index {
|
||||
return fmt.Errorf("expected %d, got %d: invalid block index", expectedHeight, block.Index)
|
||||
}
|
||||
if s.bc.GetConfig().StateRootInHeader != block.StateRootEnabled {
|
||||
return fmt.Errorf("stateroot setting mismatch: %v != %v", s.bc.GetConfig().StateRootInHeader, block.StateRootEnabled)
|
||||
}
|
||||
if s.bc.GetConfig().VerifyBlocks {
|
||||
merkle := block.ComputeMerkleRoot()
|
||||
if !block.MerkleRoot.Equals(merkle) {
|
||||
return errors.New("invalid block: MerkleRoot mismatch")
|
||||
}
|
||||
}
|
||||
cache := s.dao.GetWrapped()
|
||||
writeBuf := io.NewBufBinWriter()
|
||||
if err := cache.StoreAsBlock(block, writeBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
writeBuf.Reset()
|
||||
|
||||
err := cache.PutStateSyncCurrentBlockHeight(block.Index)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store current block height: %w", err)
|
||||
}
|
||||
|
||||
for _, tx := range block.Transactions {
|
||||
if err := cache.StoreAsTransaction(tx, block.Index, writeBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
writeBuf.Reset()
|
||||
}
|
||||
|
||||
_, err = cache.Persist()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to persist results: %w", err)
|
||||
}
|
||||
s.blockHeight = block.Index
|
||||
if s.blockHeight == s.syncPoint {
|
||||
s.syncStage |= blocksSynced
|
||||
s.log.Info("blocks are in sync",
|
||||
zap.Uint32("blockHeight", s.blockHeight))
|
||||
s.checkSyncIsCompleted()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMPTNodes tries to add provided set of MPT nodes to the MPT billet if they are
|
||||
// not yet collected.
|
||||
func (s *Module) AddMPTNodes(nodes [][]byte) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.syncStage&headersSynced == 0 || s.syncStage&mptSynced != 0 {
|
||||
return errors.New("MPT nodes were not requested")
|
||||
}
|
||||
|
||||
for _, nBytes := range nodes {
|
||||
var n mpt.NodeObject
|
||||
r := io.NewBinReaderFromBuf(nBytes)
|
||||
n.DecodeBinary(r)
|
||||
if r.Err != nil {
|
||||
return fmt.Errorf("failed to decode MPT node: %w", r.Err)
|
||||
}
|
||||
err := s.restoreNode(n.Node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if s.mptpool.Count() == 0 {
|
||||
s.syncStage |= mptSynced
|
||||
s.log.Info("MPT is in sync",
|
||||
zap.Uint32("height", s.syncPoint))
|
||||
s.checkSyncIsCompleted()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Module) restoreNode(n mpt.Node) error {
|
||||
nPaths, ok := s.mptpool.TryGet(n.Hash())
|
||||
if !ok {
|
||||
// it can easily happen after receiving the same data from different peers.
|
||||
return nil
|
||||
}
|
||||
var childrenPaths = make(map[util.Uint256][][]byte)
|
||||
for _, path := range nPaths {
|
||||
// Must clone here in order to avoid future collapse collisions. If the node's refcount>1 then MPT pool
|
||||
// will manage all paths for this node and call RestoreHashNode separately for each of the paths.
|
||||
err := s.billet.RestoreHashNode(path, n.Clone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to restore MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
|
||||
}
|
||||
for h, paths := range mpt.GetChildrenPaths(path, n) {
|
||||
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
|
||||
}
|
||||
}
|
||||
|
||||
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
|
||||
|
||||
for h := range childrenPaths {
|
||||
if child, err := s.billet.GetFromStore(h); err == nil {
|
||||
// child is already in the storage, so we don't need to request it one more time.
|
||||
err = s.restoreNode(child)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to restore saved children: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1
|
||||
// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored.
|
||||
// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller
|
||||
// should take care of it.
|
||||
func (s *Module) checkSyncIsCompleted() {
|
||||
if s.syncStage != headersSynced|mptSynced|blocksSynced {
|
||||
return
|
||||
}
|
||||
s.log.Info("state is in sync",
|
||||
zap.Uint32("state sync point", s.syncPoint))
|
||||
err := s.jumpCallback(s.syncPoint)
|
||||
if err != nil {
|
||||
s.log.Fatal("failed to jump to the latest state sync point", zap.Error(err))
|
||||
}
|
||||
s.syncStage = inactive
|
||||
s.dispose()
|
||||
}
|
||||
|
||||
func (s *Module) dispose() {
|
||||
s.billet = nil
|
||||
}
|
||||
|
||||
// BlockHeight returns index of the last stored block.
|
||||
func (s *Module) BlockHeight() uint32 {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.blockHeight
|
||||
}
|
||||
|
||||
// IsActive tells whether state sync module is on and still gathering state
|
||||
// synchronisation data (headers, blocks or MPT nodes).
|
||||
func (s *Module) IsActive() bool {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return !(s.syncStage == inactive || (s.syncStage == headersSynced|mptSynced|blocksSynced))
|
||||
}
|
||||
|
||||
// IsInitialized tells whether state sync module does not require initialization.
|
||||
// If `false` is returned then Init can be safely called.
|
||||
func (s *Module) IsInitialized() bool {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.syncStage != none
|
||||
}
|
||||
|
||||
// NeedHeaders tells whether the module hasn't completed headers synchronisation.
|
||||
func (s *Module) NeedHeaders() bool {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.syncStage == initialized
|
||||
}
|
||||
|
||||
// NeedMPTNodes returns whether the module hasn't completed MPT synchronisation.
|
||||
func (s *Module) NeedMPTNodes() bool {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.syncStage&headersSynced != 0 && s.syncStage&mptSynced == 0
|
||||
}
|
||||
|
||||
// Traverse traverses local MPT nodes starting from the specified root down to its
|
||||
// children calling `process` for each serialised node until stop condition is satisfied.
|
||||
func (s *Module) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
b := mpt.NewBillet(root, s.bc.GetConfig().KeepOnlyLatestState, storage.NewMemCachedStore(s.dao.Store))
|
||||
return b.Traverse(process, false)
|
||||
}
|
||||
|
||||
// GetUnknownMPTNodesBatch returns set of currently unknown MPT nodes (`limit` at max).
|
||||
func (s *Module) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.mptpool.GetBatch(limit)
|
||||
}
|
106
pkg/core/statesync/module_test.go
Normal file
106
pkg/core/statesync/module_test.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package statesync
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func TestModule_PR2019_discussion_r689629704(t *testing.T) {
|
||||
expectedStorage := storage.NewMemCachedStore(storage.NewMemoryStore())
|
||||
tr := mpt.NewTrie(nil, true, expectedStorage)
|
||||
require.NoError(t, tr.Put([]byte{0x03}, []byte("leaf1")))
|
||||
require.NoError(t, tr.Put([]byte{0x01, 0xab, 0x02}, []byte("leaf2")))
|
||||
require.NoError(t, tr.Put([]byte{0x01, 0xab, 0x04}, []byte("leaf3")))
|
||||
require.NoError(t, tr.Put([]byte{0x06, 0x01, 0xde, 0x02}, []byte("leaf2"))) // <-- the same `leaf2` and `leaf3` values are put in the storage,
|
||||
require.NoError(t, tr.Put([]byte{0x06, 0x01, 0xde, 0x04}, []byte("leaf3"))) // <-- but the path should differ.
|
||||
require.NoError(t, tr.Put([]byte{0x06, 0x03}, []byte("leaf4")))
|
||||
|
||||
sr := tr.StateRoot()
|
||||
tr.Flush()
|
||||
|
||||
// Keep MPT nodes in a map in order not to repeat them. We'll use `nodes` map to ask
|
||||
// state sync module to restore the nodes.
|
||||
var (
|
||||
nodes = make(map[util.Uint256][]byte)
|
||||
expectedItems []storage.KeyValue
|
||||
)
|
||||
expectedStorage.Seek(storage.DataMPT.Bytes(), func(k, v []byte) {
|
||||
key := slice.Copy(k)
|
||||
value := slice.Copy(v)
|
||||
expectedItems = append(expectedItems, storage.KeyValue{
|
||||
Key: key,
|
||||
Value: value,
|
||||
})
|
||||
hash, err := util.Uint256DecodeBytesBE(key[1:])
|
||||
require.NoError(t, err)
|
||||
nodeBytes := value[:len(value)-4]
|
||||
nodes[hash] = nodeBytes
|
||||
})
|
||||
|
||||
actualStorage := storage.NewMemCachedStore(storage.NewMemoryStore())
|
||||
// These actions are done in module.Init(), but it's not the point of the test.
|
||||
// Here we want to test only MPT restoring process.
|
||||
stateSync := &Module{
|
||||
log: zaptest.NewLogger(t),
|
||||
syncPoint: 1000500,
|
||||
syncStage: headersSynced,
|
||||
syncInterval: 100500,
|
||||
dao: dao.NewSimple(actualStorage, true, false),
|
||||
mptpool: NewPool(),
|
||||
}
|
||||
stateSync.billet = mpt.NewBillet(sr, true, actualStorage)
|
||||
stateSync.mptpool.Add(sr, []byte{})
|
||||
|
||||
// The test itself: we'll ask state sync module to restore each node exactly once.
|
||||
// After that storage content (including storage items and refcounts) must
|
||||
// match exactly the one got from real MPT trie. MPT pool must be empty.
|
||||
// State sync module must have mptSynced state in the end.
|
||||
// MPT Billet root must become a collapsed hashnode (it was checked manually).
|
||||
requested := make(map[util.Uint256]struct{})
|
||||
for {
|
||||
unknownHashes := stateSync.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
|
||||
if len(unknownHashes) == 0 {
|
||||
break
|
||||
}
|
||||
h := unknownHashes[0]
|
||||
node, ok := nodes[h]
|
||||
if !ok {
|
||||
if _, ok = requested[h]; ok {
|
||||
t.Fatal("node was requested twice")
|
||||
}
|
||||
t.Fatal("unknown node was requested")
|
||||
}
|
||||
require.NotPanics(t, func() {
|
||||
err := stateSync.AddMPTNodes([][]byte{node})
|
||||
require.NoError(t, err)
|
||||
}, fmt.Errorf("hash=%s, value=%s", h.StringBE(), string(node)))
|
||||
requested[h] = struct{}{}
|
||||
delete(nodes, h)
|
||||
if len(nodes) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Equal(t, headersSynced|mptSynced, stateSync.syncStage, "all nodes were sent exactly ones, but MPT wasn't restored")
|
||||
require.Equal(t, 0, len(nodes), "not all nodes were requested by state sync module")
|
||||
require.Equal(t, 0, stateSync.mptpool.Count(), "MPT was restored, but MPT pool still contains items")
|
||||
|
||||
// Compare resulting storage items and refcounts.
|
||||
var actualItems []storage.KeyValue
|
||||
expectedStorage.Seek(storage.DataMPT.Bytes(), func(k, v []byte) {
|
||||
key := slice.Copy(k)
|
||||
value := slice.Copy(v)
|
||||
actualItems = append(actualItems, storage.KeyValue{
|
||||
Key: key,
|
||||
Value: value,
|
||||
})
|
||||
})
|
||||
require.ElementsMatch(t, expectedItems, actualItems)
|
||||
}
|
142
pkg/core/statesync/mptpool.go
Normal file
142
pkg/core/statesync/mptpool.go
Normal file
|
@ -0,0 +1,142 @@
|
|||
package statesync
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// Pool stores unknown MPT nodes along with the corresponding paths (single node is
|
||||
// allowed to have multiple MPT paths).
|
||||
type Pool struct {
|
||||
lock sync.RWMutex
|
||||
hashes map[util.Uint256][][]byte
|
||||
}
|
||||
|
||||
// NewPool returns new MPT node hashes pool.
|
||||
func NewPool() *Pool {
|
||||
return &Pool{
|
||||
hashes: make(map[util.Uint256][][]byte),
|
||||
}
|
||||
}
|
||||
|
||||
// ContainsKey checks if MPT node with the specified hash is in the Pool.
|
||||
func (mp *Pool) ContainsKey(hash util.Uint256) bool {
|
||||
mp.lock.RLock()
|
||||
defer mp.lock.RUnlock()
|
||||
|
||||
_, ok := mp.hashes[hash]
|
||||
return ok
|
||||
}
|
||||
|
||||
// TryGet returns a set of MPT paths for the specified HashNode.
|
||||
func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) {
|
||||
mp.lock.RLock()
|
||||
defer mp.lock.RUnlock()
|
||||
|
||||
paths, ok := mp.hashes[hash]
|
||||
// need to copy here, because we can modify existing array of paths inside the pool.
|
||||
res := make([][]byte, len(paths))
|
||||
copy(res, paths)
|
||||
return res, ok
|
||||
}
|
||||
|
||||
// GetAll returns all MPT nodes with the corresponding paths from the pool.
|
||||
func (mp *Pool) GetAll() map[util.Uint256][][]byte {
|
||||
mp.lock.RLock()
|
||||
defer mp.lock.RUnlock()
|
||||
|
||||
return mp.hashes
|
||||
}
|
||||
|
||||
// GetBatch returns set of unknown MPT nodes hashes (`limit` at max).
|
||||
func (mp *Pool) GetBatch(limit int) []util.Uint256 {
|
||||
mp.lock.RLock()
|
||||
defer mp.lock.RUnlock()
|
||||
|
||||
count := len(mp.hashes)
|
||||
if count > limit {
|
||||
count = limit
|
||||
}
|
||||
result := make([]util.Uint256, 0, limit)
|
||||
for h := range mp.hashes {
|
||||
if count == 0 {
|
||||
break
|
||||
}
|
||||
result = append(result, h)
|
||||
count--
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Remove removes MPT node from the pool by the specified hash.
|
||||
func (mp *Pool) Remove(hash util.Uint256) {
|
||||
mp.lock.Lock()
|
||||
defer mp.lock.Unlock()
|
||||
|
||||
delete(mp.hashes, hash)
|
||||
}
|
||||
|
||||
// Add adds path to the set of paths for the specified node.
|
||||
func (mp *Pool) Add(hash util.Uint256, path []byte) {
|
||||
mp.lock.Lock()
|
||||
defer mp.lock.Unlock()
|
||||
|
||||
mp.addPaths(hash, [][]byte{path})
|
||||
}
|
||||
|
||||
// Update is an atomic operation and removes/adds specified nodes from/to the pool.
|
||||
func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) {
|
||||
mp.lock.Lock()
|
||||
defer mp.lock.Unlock()
|
||||
|
||||
for h, paths := range remove {
|
||||
old := mp.hashes[h]
|
||||
for _, path := range paths {
|
||||
i := sort.Search(len(old), func(i int) bool {
|
||||
return bytes.Compare(old[i], path) >= 0
|
||||
})
|
||||
if i < len(old) && bytes.Equal(old[i], path) {
|
||||
old = append(old[:i], old[i+1:]...)
|
||||
}
|
||||
}
|
||||
if len(old) == 0 {
|
||||
delete(mp.hashes, h)
|
||||
} else {
|
||||
mp.hashes[h] = old
|
||||
}
|
||||
}
|
||||
for h, paths := range add {
|
||||
mp.addPaths(h, paths)
|
||||
}
|
||||
}
|
||||
|
||||
// addPaths adds set of the specified node paths to the pool.
|
||||
func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) {
|
||||
old := mp.hashes[nodeHash]
|
||||
for _, path := range paths {
|
||||
i := sort.Search(len(old), func(i int) bool {
|
||||
return bytes.Compare(old[i], path) >= 0
|
||||
})
|
||||
if i < len(old) && bytes.Equal(old[i], path) {
|
||||
// then path is already added
|
||||
continue
|
||||
}
|
||||
old = append(old, path)
|
||||
if i != len(old)-1 {
|
||||
copy(old[i+1:], old[i:])
|
||||
old[i] = path
|
||||
}
|
||||
}
|
||||
mp.hashes[nodeHash] = old
|
||||
}
|
||||
|
||||
// Count returns the number of nodes in the pool.
|
||||
func (mp *Pool) Count() int {
|
||||
mp.lock.RLock()
|
||||
defer mp.lock.RUnlock()
|
||||
|
||||
return len(mp.hashes)
|
||||
}
|
123
pkg/core/statesync/mptpool_test.go
Normal file
123
pkg/core/statesync/mptpool_test.go
Normal file
|
@ -0,0 +1,123 @@
|
|||
package statesync
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPool_AddRemoveUpdate(t *testing.T) {
|
||||
mp := NewPool()
|
||||
|
||||
i1 := []byte{1, 2, 3}
|
||||
i1h := util.Uint256{1, 2, 3}
|
||||
i2 := []byte{2, 3, 4}
|
||||
i2h := util.Uint256{2, 3, 4}
|
||||
i3 := []byte{4, 5, 6}
|
||||
i3h := util.Uint256{3, 4, 5}
|
||||
i4 := []byte{3, 4, 5} // has the same hash as i3
|
||||
i5 := []byte{6, 7, 8} // has the same hash as i3
|
||||
mapAll := map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}, i3h: {i4, i3}}
|
||||
|
||||
// No items
|
||||
_, ok := mp.TryGet(i1h)
|
||||
require.False(t, ok)
|
||||
require.False(t, mp.ContainsKey(i1h))
|
||||
require.Equal(t, 0, mp.Count())
|
||||
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
|
||||
|
||||
// Add i1, i2, check OK
|
||||
mp.Add(i1h, i1)
|
||||
mp.Add(i2h, i2)
|
||||
itm, ok := mp.TryGet(i1h)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, [][]byte{i1}, itm)
|
||||
require.True(t, mp.ContainsKey(i1h))
|
||||
require.True(t, mp.ContainsKey(i2h))
|
||||
require.Equal(t, map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, mp.GetAll())
|
||||
require.Equal(t, 2, mp.Count())
|
||||
|
||||
// Remove i1 and unexisting item
|
||||
mp.Remove(i3h)
|
||||
mp.Remove(i1h)
|
||||
require.False(t, mp.ContainsKey(i1h))
|
||||
require.True(t, mp.ContainsKey(i2h))
|
||||
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}}, mp.GetAll())
|
||||
require.Equal(t, 1, mp.Count())
|
||||
|
||||
// Update: remove nothing, add all
|
||||
mp.Update(nil, mapAll)
|
||||
require.Equal(t, mapAll, mp.GetAll())
|
||||
require.Equal(t, 3, mp.Count())
|
||||
// Update: remove all, add all
|
||||
mp.Update(mapAll, mapAll)
|
||||
require.Equal(t, mapAll, mp.GetAll()) // deletion first, addition after that
|
||||
require.Equal(t, 3, mp.Count())
|
||||
// Update: remove all, add nothing
|
||||
mp.Update(mapAll, nil)
|
||||
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
|
||||
require.Equal(t, 0, mp.Count())
|
||||
// Update: remove several, add several
|
||||
mp.Update(map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}})
|
||||
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
|
||||
require.Equal(t, 2, mp.Count())
|
||||
|
||||
// Update: remove nothing, add several with same hashes
|
||||
mp.Update(nil, map[util.Uint256][][]byte{i3h: {i5, i4}}) // should be sorted by the pool
|
||||
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i3, i5}}, mp.GetAll())
|
||||
require.Equal(t, 2, mp.Count())
|
||||
// Update: remove several with same hashes, add nothing
|
||||
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i4}}, nil)
|
||||
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
|
||||
require.Equal(t, 2, mp.Count())
|
||||
// Update: remove several with same hashes, add several with same hashes
|
||||
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i3}}, map[util.Uint256][][]byte{i3h: {i5, i4}})
|
||||
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i5}}, mp.GetAll())
|
||||
require.Equal(t, 2, mp.Count())
|
||||
}
|
||||
|
||||
func TestPool_GetBatch(t *testing.T) {
|
||||
check := func(t *testing.T, limit int, itemsCount int) {
|
||||
mp := NewPool()
|
||||
for i := 0; i < itemsCount; i++ {
|
||||
mp.Add(random.Uint256(), []byte{0x01})
|
||||
}
|
||||
batch := mp.GetBatch(limit)
|
||||
if limit < itemsCount {
|
||||
require.Equal(t, limit, len(batch))
|
||||
} else {
|
||||
require.Equal(t, itemsCount, len(batch))
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("limit less than items count", func(t *testing.T) {
|
||||
check(t, 5, 6)
|
||||
})
|
||||
t.Run("limit more than items count", func(t *testing.T) {
|
||||
check(t, 6, 5)
|
||||
})
|
||||
t.Run("items count limit", func(t *testing.T) {
|
||||
check(t, 5, 5)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPool_UpdateUsingSliceFromPool(t *testing.T) {
|
||||
mp := NewPool()
|
||||
p1, _ := hex.DecodeString("0f0a0f0f0f0f0f0f0104020b02080c0a06050e070b050404060206060d07080602030b04040b050e040406030f0708060c05")
|
||||
p2, _ := hex.DecodeString("0f0a0f0f0f0f0f0f01040a0b000f04000b03090b02090b0e040f0d0b060d070e0b0b090b0906080602060c0d0f0e0d04070e")
|
||||
p3, _ := hex.DecodeString("0f0a0f0f0f0f0f0f01040b010d01080f050f000a0d0e08060c040b050800050904060f050807080a080c07040d0107080007")
|
||||
h, _ := util.Uint256DecodeStringBE("57e197679ef031bf2f0b466b20afe3f67ac04dcff80a1dc4d12dd98dd21a2511")
|
||||
mp.Add(h, p1)
|
||||
mp.Add(h, p2)
|
||||
mp.Add(h, p3)
|
||||
|
||||
toBeRemoved, ok := mp.TryGet(h)
|
||||
require.True(t, ok)
|
||||
|
||||
mp.Update(map[util.Uint256][][]byte{h: toBeRemoved}, nil)
|
||||
// test that all items were successfully removed.
|
||||
require.Equal(t, 0, len(mp.GetAll()))
|
||||
}
|
435
pkg/core/statesync_test.go
Normal file
435
pkg/core/statesync_test.go
Normal file
|
@ -0,0 +1,435 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/config"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStateSyncModule_Init(t *testing.T) {
|
||||
var (
|
||||
stateSyncInterval = 2
|
||||
maxTraceable uint32 = 3
|
||||
)
|
||||
spoutCfg := func(c *config.Config) {
|
||||
c.ProtocolConfiguration.StateRootInHeader = true
|
||||
c.ProtocolConfiguration.P2PStateExchangeExtensions = true
|
||||
c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
|
||||
c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
|
||||
}
|
||||
bcSpout := newTestChainWithCustomCfg(t, spoutCfg)
|
||||
for i := 0; i <= 2*stateSyncInterval+int(maxTraceable)+1; i++ {
|
||||
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
|
||||
}
|
||||
|
||||
boltCfg := func(c *config.Config) {
|
||||
spoutCfg(c)
|
||||
c.ProtocolConfiguration.KeepOnlyLatestState = true
|
||||
c.ProtocolConfiguration.RemoveUntraceableBlocks = true
|
||||
}
|
||||
t.Run("error: module disabled by config", func(t *testing.T) {
|
||||
bcBolt := newTestChainWithCustomCfg(t, func(c *config.Config) {
|
||||
boltCfg(c)
|
||||
c.ProtocolConfiguration.RemoveUntraceableBlocks = false
|
||||
})
|
||||
module := bcBolt.GetStateSyncModule()
|
||||
require.Error(t, module.Init(bcSpout.BlockHeight())) // module inactive (non-archival node)
|
||||
})
|
||||
|
||||
t.Run("inactive: spout chain is too low to start state sync process", func(t *testing.T) {
|
||||
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||
module := bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(uint32(2*stateSyncInterval-1)))
|
||||
require.False(t, module.IsActive())
|
||||
})
|
||||
|
||||
t.Run("inactive: bolt chain height is close enough to spout chain height", func(t *testing.T) {
|
||||
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||
for i := 1; i < int(bcSpout.BlockHeight())-stateSyncInterval; i++ {
|
||||
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, bcBolt.AddBlock(b))
|
||||
}
|
||||
module := bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
require.False(t, module.IsActive())
|
||||
})
|
||||
|
||||
t.Run("error: bolt chain is too low to start state sync process", func(t *testing.T) {
|
||||
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||
require.NoError(t, bcBolt.AddBlock(bcBolt.newBlock()))
|
||||
|
||||
module := bcBolt.GetStateSyncModule()
|
||||
require.Error(t, module.Init(uint32(3*stateSyncInterval)))
|
||||
})
|
||||
|
||||
t.Run("initialized: no previous state sync point", func(t *testing.T) {
|
||||
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||
|
||||
module := bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.True(t, module.NeedHeaders())
|
||||
require.False(t, module.NeedMPTNodes())
|
||||
})
|
||||
|
||||
t.Run("error: outdated state sync point in the storage", func(t *testing.T) {
|
||||
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||
module := bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
|
||||
module = bcBolt.GetStateSyncModule()
|
||||
require.Error(t, module.Init(bcSpout.BlockHeight()+2*uint32(stateSyncInterval)))
|
||||
})
|
||||
|
||||
t.Run("initialized: valid previous state sync point in the storage", func(t *testing.T) {
|
||||
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||
module := bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
|
||||
module = bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.True(t, module.NeedHeaders())
|
||||
require.False(t, module.NeedMPTNodes())
|
||||
})
|
||||
|
||||
t.Run("initialization from headers/blocks/mpt synced stages", func(t *testing.T) {
|
||||
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||
module := bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
|
||||
// firstly, fetch all headers to create proper DB state (where headers are in sync)
|
||||
stateSyncPoint := (int(bcSpout.BlockHeight()) / stateSyncInterval) * stateSyncInterval
|
||||
var expectedHeader *block.Header
|
||||
for i := 1; i <= int(bcSpout.HeaderHeight()); i++ {
|
||||
header, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(i))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, module.AddHeaders(header))
|
||||
if i == stateSyncPoint+1 {
|
||||
expectedHeader = header
|
||||
}
|
||||
}
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.True(t, module.NeedMPTNodes())
|
||||
|
||||
// then create new statesync module with the same DB and check that state is proper
|
||||
// (headers are in sync)
|
||||
module = bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.True(t, module.NeedMPTNodes())
|
||||
unknownNodes := module.GetUnknownMPTNodesBatch(2)
|
||||
require.Equal(t, 1, len(unknownNodes))
|
||||
require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0])
|
||||
|
||||
// add several blocks to create DB state where blocks are not in sync yet, but it's not a genesis.
|
||||
for i := stateSyncPoint - int(maxTraceable) + 1; i <= stateSyncPoint-stateSyncInterval-1; i++ {
|
||||
block, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, module.AddBlock(block))
|
||||
}
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.True(t, module.NeedMPTNodes())
|
||||
require.Equal(t, uint32(stateSyncPoint-stateSyncInterval-1), module.BlockHeight())
|
||||
|
||||
// then create new statesync module with the same DB and check that state is proper
|
||||
// (blocks are not in sync yet)
|
||||
module = bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.True(t, module.NeedMPTNodes())
|
||||
unknownNodes = module.GetUnknownMPTNodesBatch(2)
|
||||
require.Equal(t, 1, len(unknownNodes))
|
||||
require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0])
|
||||
require.Equal(t, uint32(stateSyncPoint-stateSyncInterval-1), module.BlockHeight())
|
||||
|
||||
// add rest of blocks to create DB state where blocks are in sync
|
||||
for i := stateSyncPoint - stateSyncInterval; i <= stateSyncPoint; i++ {
|
||||
block, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, module.AddBlock(block))
|
||||
}
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.True(t, module.NeedMPTNodes())
|
||||
lastBlock, err := bcBolt.GetBlock(expectedHeader.PrevHash)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(stateSyncPoint), lastBlock.Index)
|
||||
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||
|
||||
// then create new statesync module with the same DB and check that state is proper
|
||||
// (headers and blocks are in sync)
|
||||
module = bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.True(t, module.NeedMPTNodes())
|
||||
unknownNodes = module.GetUnknownMPTNodesBatch(2)
|
||||
require.Equal(t, 1, len(unknownNodes))
|
||||
require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0])
|
||||
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||
|
||||
// add a few MPT nodes to create DB state where some of MPT nodes are missing
|
||||
count := 5
|
||||
for {
|
||||
unknownHashes := module.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
|
||||
if len(unknownHashes) == 0 {
|
||||
break
|
||||
}
|
||||
err := bcSpout.GetStateSyncModule().Traverse(unknownHashes[0], func(node mpt.Node, nodeBytes []byte) bool {
|
||||
require.NoError(t, module.AddMPTNodes([][]byte{nodeBytes}))
|
||||
return true // add nodes one-by-one
|
||||
})
|
||||
require.NoError(t, err)
|
||||
count--
|
||||
if count < 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// then create new statesync module with the same DB and check that state is proper
|
||||
// (headers and blocks are in sync, mpt is not yet synced)
|
||||
module = bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.True(t, module.NeedMPTNodes())
|
||||
unknownNodes = module.GetUnknownMPTNodesBatch(100)
|
||||
require.True(t, len(unknownNodes) > 0)
|
||||
require.NotContains(t, unknownNodes, expectedHeader.PrevStateRoot)
|
||||
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||
|
||||
// add the rest of MPT nodes and jump to state
|
||||
for {
|
||||
unknownHashes := module.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
|
||||
if len(unknownHashes) == 0 {
|
||||
break
|
||||
}
|
||||
err := bcSpout.GetStateSyncModule().Traverse(unknownHashes[0], func(node mpt.Node, nodeBytes []byte) bool {
|
||||
require.NoError(t, module.AddMPTNodes([][]byte{slice.Copy(nodeBytes)}))
|
||||
return true // add nodes one-by-one
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// check that module is inactive and statejump is completed
|
||||
require.False(t, module.IsActive())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.False(t, module.NeedMPTNodes())
|
||||
unknownNodes = module.GetUnknownMPTNodesBatch(1)
|
||||
require.True(t, len(unknownNodes) == 0)
|
||||
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||
require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
|
||||
|
||||
// create new module from completed state: the module should recognise that state sync is completed
|
||||
module = bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
require.False(t, module.IsActive())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.False(t, module.NeedMPTNodes())
|
||||
unknownNodes = module.GetUnknownMPTNodesBatch(1)
|
||||
require.True(t, len(unknownNodes) == 0)
|
||||
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||
require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
|
||||
|
||||
// add one more block to the restored chain and start new module: the module should recognise state sync is completed
|
||||
// and regular blocks processing was started
|
||||
require.NoError(t, bcBolt.AddBlock(bcBolt.newBlock()))
|
||||
module = bcBolt.GetStateSyncModule()
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
require.False(t, module.IsActive())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.False(t, module.NeedMPTNodes())
|
||||
unknownNodes = module.GetUnknownMPTNodesBatch(1)
|
||||
require.True(t, len(unknownNodes) == 0)
|
||||
require.Equal(t, uint32(stateSyncPoint)+1, module.BlockHeight())
|
||||
require.Equal(t, uint32(stateSyncPoint)+1, bcBolt.BlockHeight())
|
||||
})
|
||||
}
|
||||
|
||||
func TestStateSyncModule_RestoreBasicChain(t *testing.T) {
|
||||
var (
|
||||
stateSyncInterval = 4
|
||||
maxTraceable uint32 = 6
|
||||
stateSyncPoint = 16
|
||||
)
|
||||
spoutCfg := func(c *config.Config) {
|
||||
c.ProtocolConfiguration.StateRootInHeader = true
|
||||
c.ProtocolConfiguration.P2PStateExchangeExtensions = true
|
||||
c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
|
||||
c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
|
||||
}
|
||||
bcSpout := newTestChainWithCustomCfg(t, spoutCfg)
|
||||
initBasicChain(t, bcSpout)
|
||||
|
||||
// make spout chain higher that latest state sync point
|
||||
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
|
||||
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
|
||||
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
|
||||
require.Equal(t, uint32(stateSyncPoint+2), bcSpout.BlockHeight())
|
||||
|
||||
boltCfg := func(c *config.Config) {
|
||||
spoutCfg(c)
|
||||
c.ProtocolConfiguration.KeepOnlyLatestState = true
|
||||
c.ProtocolConfiguration.RemoveUntraceableBlocks = true
|
||||
}
|
||||
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
|
||||
module := bcBolt.GetStateSyncModule()
|
||||
|
||||
t.Run("error: add headers before initialisation", func(t *testing.T) {
|
||||
h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(1))
|
||||
require.NoError(t, err)
|
||||
require.Error(t, module.AddHeaders(h))
|
||||
})
|
||||
t.Run("no error: add blocks before initialisation", func(t *testing.T) {
|
||||
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(1))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, module.AddBlock(b))
|
||||
})
|
||||
t.Run("error: add MPT nodes without initialisation", func(t *testing.T) {
|
||||
require.Error(t, module.AddMPTNodes([][]byte{}))
|
||||
})
|
||||
|
||||
require.NoError(t, module.Init(bcSpout.BlockHeight()))
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.True(t, module.NeedHeaders())
|
||||
require.False(t, module.NeedMPTNodes())
|
||||
|
||||
// add headers to module
|
||||
headers := make([]*block.Header, 0, bcSpout.HeaderHeight())
|
||||
for i := uint32(1); i <= bcSpout.HeaderHeight(); i++ {
|
||||
h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(int(i)))
|
||||
require.NoError(t, err)
|
||||
headers = append(headers, h)
|
||||
}
|
||||
require.NoError(t, module.AddHeaders(headers...))
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.True(t, module.NeedMPTNodes())
|
||||
require.Equal(t, bcSpout.HeaderHeight(), bcBolt.HeaderHeight())
|
||||
|
||||
// add blocks
|
||||
t.Run("error: unexpected block index", func(t *testing.T) {
|
||||
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(stateSyncPoint - int(maxTraceable)))
|
||||
require.NoError(t, err)
|
||||
require.Error(t, module.AddBlock(b))
|
||||
})
|
||||
t.Run("error: missing state root in block header", func(t *testing.T) {
|
||||
b := &block.Block{
|
||||
Header: block.Header{
|
||||
Index: uint32(stateSyncPoint) - maxTraceable + 1,
|
||||
StateRootEnabled: false,
|
||||
},
|
||||
}
|
||||
require.Error(t, module.AddBlock(b))
|
||||
})
|
||||
t.Run("error: invalid block merkle root", func(t *testing.T) {
|
||||
b := &block.Block{
|
||||
Header: block.Header{
|
||||
Index: uint32(stateSyncPoint) - maxTraceable + 1,
|
||||
StateRootEnabled: true,
|
||||
MerkleRoot: util.Uint256{1, 2, 3},
|
||||
},
|
||||
}
|
||||
require.Error(t, module.AddBlock(b))
|
||||
})
|
||||
|
||||
for i := stateSyncPoint - int(maxTraceable) + 1; i <= stateSyncPoint; i++ {
|
||||
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, module.AddBlock(b))
|
||||
}
|
||||
require.True(t, module.IsActive())
|
||||
require.True(t, module.IsInitialized())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.True(t, module.NeedMPTNodes())
|
||||
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||
|
||||
// add MPT nodes in batches
|
||||
h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(stateSyncPoint + 1))
|
||||
require.NoError(t, err)
|
||||
unknownHashes := module.GetUnknownMPTNodesBatch(100)
|
||||
require.Equal(t, 1, len(unknownHashes))
|
||||
require.Equal(t, h.PrevStateRoot, unknownHashes[0])
|
||||
nodesMap := make(map[util.Uint256][]byte)
|
||||
err = bcSpout.GetStateSyncModule().Traverse(h.PrevStateRoot, func(n mpt.Node, nodeBytes []byte) bool {
|
||||
nodesMap[n.Hash()] = nodeBytes
|
||||
return false
|
||||
})
|
||||
require.NoError(t, err)
|
||||
for {
|
||||
need := module.GetUnknownMPTNodesBatch(10)
|
||||
if len(need) == 0 {
|
||||
break
|
||||
}
|
||||
add := make([][]byte, len(need))
|
||||
for i, h := range need {
|
||||
nodeBytes, ok := nodesMap[h]
|
||||
if !ok {
|
||||
t.Fatal("unknown or restored node requested")
|
||||
}
|
||||
add[i] = nodeBytes
|
||||
delete(nodesMap, h)
|
||||
}
|
||||
require.NoError(t, module.AddMPTNodes(add))
|
||||
}
|
||||
require.False(t, module.IsActive())
|
||||
require.False(t, module.NeedHeaders())
|
||||
require.False(t, module.NeedMPTNodes())
|
||||
unknownNodes := module.GetUnknownMPTNodesBatch(1)
|
||||
require.True(t, len(unknownNodes) == 0)
|
||||
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
|
||||
require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
|
||||
|
||||
// add missing blocks to bcBolt: should be ok, because state is synced
|
||||
for i := stateSyncPoint + 1; i <= int(bcSpout.BlockHeight()); i++ {
|
||||
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, bcBolt.AddBlock(b))
|
||||
}
|
||||
require.Equal(t, bcSpout.BlockHeight(), bcBolt.BlockHeight())
|
||||
|
||||
// compare storage states
|
||||
fetchStorage := func(bc *Blockchain) []storage.KeyValue {
|
||||
var kv []storage.KeyValue
|
||||
bc.dao.Store.Seek(storage.STStorage.Bytes(), func(k, v []byte) {
|
||||
key := slice.Copy(k)
|
||||
value := slice.Copy(v)
|
||||
kv = append(kv, storage.KeyValue{
|
||||
Key: key,
|
||||
Value: value,
|
||||
})
|
||||
})
|
||||
return kv
|
||||
}
|
||||
expected := fetchStorage(bcSpout)
|
||||
actual := fetchStorage(bcBolt)
|
||||
require.ElementsMatch(t, expected, actual)
|
||||
|
||||
// no temp items should be left
|
||||
bcBolt.dao.Store.Seek(storage.STTempStorage.Bytes(), func(k, v []byte) {
|
||||
t.Fatal("temp storage items are found")
|
||||
})
|
||||
}
|
|
@ -8,13 +8,18 @@ import (
|
|||
|
||||
// KeyPrefix constants.
|
||||
const (
|
||||
DataBlock KeyPrefix = 0x01
|
||||
DataTransaction KeyPrefix = 0x02
|
||||
DataMPT KeyPrefix = 0x03
|
||||
STAccount KeyPrefix = 0x40
|
||||
STNotification KeyPrefix = 0x4d
|
||||
STContractID KeyPrefix = 0x51
|
||||
STStorage KeyPrefix = 0x70
|
||||
DataBlock KeyPrefix = 0x01
|
||||
DataTransaction KeyPrefix = 0x02
|
||||
DataMPT KeyPrefix = 0x03
|
||||
STAccount KeyPrefix = 0x40
|
||||
STNotification KeyPrefix = 0x4d
|
||||
STContractID KeyPrefix = 0x51
|
||||
STStorage KeyPrefix = 0x70
|
||||
// STTempStorage is used to store contract storage items during state sync process
|
||||
// in order not to mess up the previous state which has its own items stored by
|
||||
// STStorage prefix. Once state exchange process is completed, all items with
|
||||
// STStorage prefix will be replaced with STTempStorage-prefixed ones.
|
||||
STTempStorage KeyPrefix = 0x71
|
||||
STNEP17Transfers KeyPrefix = 0x72
|
||||
STNEP17TransferInfo KeyPrefix = 0x73
|
||||
IXHeaderHashList KeyPrefix = 0x80
|
||||
|
@ -22,6 +27,7 @@ const (
|
|||
SYSCurrentHeader KeyPrefix = 0xc1
|
||||
SYSStateSyncCurrentBlockHeight KeyPrefix = 0xc2
|
||||
SYSStateSyncPoint KeyPrefix = 0xc3
|
||||
SYSStateJumpStage KeyPrefix = 0xc4
|
||||
SYSVersion KeyPrefix = 0xf0
|
||||
)
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"github.com/Workiva/go-datastructures/queue"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
@ -13,6 +14,7 @@ type blockQueue struct {
|
|||
checkBlocks chan struct{}
|
||||
chain blockchainer.Blockqueuer
|
||||
relayF func(*block.Block)
|
||||
discarded *atomic.Bool
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -32,6 +34,7 @@ func newBlockQueue(capacity int, bc blockchainer.Blockqueuer, log *zap.Logger, r
|
|||
checkBlocks: make(chan struct{}, 1),
|
||||
chain: bc,
|
||||
relayF: relayer,
|
||||
discarded: atomic.NewBool(false),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,8 +94,10 @@ func (bq *blockQueue) putBlock(block *block.Block) error {
|
|||
}
|
||||
|
||||
func (bq *blockQueue) discard() {
|
||||
close(bq.checkBlocks)
|
||||
bq.queue.Dispose()
|
||||
if bq.discarded.CAS(false, true) {
|
||||
close(bq.checkBlocks)
|
||||
bq.queue.Dispose()
|
||||
}
|
||||
}
|
||||
|
||||
func (bq *blockQueue) length() int {
|
||||
|
|
|
@ -71,6 +71,8 @@ const (
|
|||
CMDBlock = CommandType(payload.BlockType)
|
||||
CMDExtensible = CommandType(payload.ExtensibleType)
|
||||
CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType)
|
||||
CMDGetMPTData CommandType = 0x51 // 0x5.. commands are used for extensions (P2PNotary, state exchange cmds)
|
||||
CMDMPTData CommandType = 0x52
|
||||
CMDReject CommandType = 0x2f
|
||||
|
||||
// SPV protocol.
|
||||
|
@ -136,6 +138,10 @@ func (m *Message) decodePayload() error {
|
|||
p = &payload.Version{}
|
||||
case CMDInv, CMDGetData:
|
||||
p = &payload.Inventory{}
|
||||
case CMDGetMPTData:
|
||||
p = &payload.MPTInventory{}
|
||||
case CMDMPTData:
|
||||
p = &payload.MPTData{}
|
||||
case CMDAddr:
|
||||
p = &payload.AddressList{}
|
||||
case CMDBlock:
|
||||
|
@ -221,7 +227,7 @@ func (m *Message) tryCompressPayload() error {
|
|||
if m.Flags&Compressed == 0 {
|
||||
switch m.Payload.(type) {
|
||||
case *payload.Headers, *payload.MerkleBlock, payload.NullPayload,
|
||||
*payload.Inventory:
|
||||
*payload.Inventory, *payload.MPTInventory:
|
||||
break
|
||||
default:
|
||||
size := len(compressedPayload)
|
||||
|
|
|
@ -26,6 +26,8 @@ func _() {
|
|||
_ = x[CMDBlock-44]
|
||||
_ = x[CMDExtensible-46]
|
||||
_ = x[CMDP2PNotaryRequest-80]
|
||||
_ = x[CMDGetMPTData-81]
|
||||
_ = x[CMDMPTData-82]
|
||||
_ = x[CMDReject-47]
|
||||
_ = x[CMDFilterLoad-48]
|
||||
_ = x[CMDFilterAdd-49]
|
||||
|
@ -44,7 +46,7 @@ const (
|
|||
_CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
|
||||
_CommandType_name_7 = "CMDMerkleBlock"
|
||||
_CommandType_name_8 = "CMDAlert"
|
||||
_CommandType_name_9 = "CMDP2PNotaryRequest"
|
||||
_CommandType_name_9 = "CMDP2PNotaryRequestCMDGetMPTDataCMDMPTData"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -55,6 +57,7 @@ var (
|
|||
_CommandType_index_4 = [...]uint8{0, 12, 22}
|
||||
_CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58}
|
||||
_CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61}
|
||||
_CommandType_index_9 = [...]uint8{0, 19, 32, 42}
|
||||
)
|
||||
|
||||
func (i CommandType) String() string {
|
||||
|
@ -83,8 +86,9 @@ func (i CommandType) String() string {
|
|||
return _CommandType_name_7
|
||||
case i == 64:
|
||||
return _CommandType_name_8
|
||||
case i == 80:
|
||||
return _CommandType_name_9
|
||||
case 80 <= i && i <= 82:
|
||||
i -= 80
|
||||
return _CommandType_name_9[_CommandType_index_9[i]:_CommandType_index_9[i+1]]
|
||||
default:
|
||||
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
|
|
|
@ -258,6 +258,21 @@ func TestEncodeDecodeNotFound(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestEncodeDecodeGetMPTData(t *testing.T) {
|
||||
testEncodeDecode(t, CMDGetMPTData, &payload.MPTInventory{
|
||||
Hashes: []util.Uint256{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestEncodeDecodeMPTData(t *testing.T) {
|
||||
testEncodeDecode(t, CMDMPTData, &payload.MPTData{
|
||||
Nodes: [][]byte{{1, 2, 3}, {4, 5, 6}},
|
||||
})
|
||||
}
|
||||
|
||||
func TestInvalidMessages(t *testing.T) {
|
||||
t.Run("CMDBlock, empty payload", func(t *testing.T) {
|
||||
testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{})
|
||||
|
|
35
pkg/network/payload/mptdata.go
Normal file
35
pkg/network/payload/mptdata.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
)
|
||||
|
||||
// MPTData represents the set of serialized MPT nodes.
|
||||
type MPTData struct {
|
||||
Nodes [][]byte
|
||||
}
|
||||
|
||||
// EncodeBinary implements io.Serializable.
|
||||
func (d *MPTData) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteVarUint(uint64(len(d.Nodes)))
|
||||
for _, n := range d.Nodes {
|
||||
w.WriteVarBytes(n)
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeBinary implements io.Serializable.
|
||||
func (d *MPTData) DecodeBinary(r *io.BinReader) {
|
||||
sz := r.ReadVarUint()
|
||||
if sz == 0 {
|
||||
r.Err = errors.New("empty MPT nodes list")
|
||||
return
|
||||
}
|
||||
for i := uint64(0); i < sz; i++ {
|
||||
d.Nodes = append(d.Nodes, r.ReadVarBytes())
|
||||
if r.Err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
24
pkg/network/payload/mptdata_test.go
Normal file
24
pkg/network/payload/mptdata_test.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMPTData_EncodeDecodeBinary(t *testing.T) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
d := new(MPTData)
|
||||
bytes, err := testserdes.EncodeBinary(d)
|
||||
require.NoError(t, err)
|
||||
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTData)))
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
d := &MPTData{
|
||||
Nodes: [][]byte{{}, {1}, {1, 2, 3}},
|
||||
}
|
||||
testserdes.EncodeDecodeBinary(t, d, new(MPTData))
|
||||
})
|
||||
}
|
32
pkg/network/payload/mptinventory.go
Normal file
32
pkg/network/payload/mptinventory.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
)
|
||||
|
||||
// MaxMPTHashesCount is the maximum number of requested MPT nodes hashes.
|
||||
const MaxMPTHashesCount = 32
|
||||
|
||||
// MPTInventory payload.
|
||||
type MPTInventory struct {
|
||||
// A list of requested MPT nodes hashes.
|
||||
Hashes []util.Uint256
|
||||
}
|
||||
|
||||
// NewMPTInventory return a pointer to an MPTInventory.
|
||||
func NewMPTInventory(hashes []util.Uint256) *MPTInventory {
|
||||
return &MPTInventory{
|
||||
Hashes: hashes,
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeBinary implements Serializable interface.
|
||||
func (p *MPTInventory) DecodeBinary(br *io.BinReader) {
|
||||
br.ReadArray(&p.Hashes, MaxMPTHashesCount)
|
||||
}
|
||||
|
||||
// EncodeBinary implements Serializable interface.
|
||||
func (p *MPTInventory) EncodeBinary(bw *io.BinWriter) {
|
||||
bw.WriteArray(p.Hashes)
|
||||
}
|
38
pkg/network/payload/mptinventory_test.go
Normal file
38
pkg/network/payload/mptinventory_test.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package payload
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMPTInventory_EncodeDecodeBinary(t *testing.T) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
testserdes.EncodeDecodeBinary(t, NewMPTInventory([]util.Uint256{}), new(MPTInventory))
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
inv := NewMPTInventory([]util.Uint256{{1, 2, 3}, {2, 3, 4}})
|
||||
testserdes.EncodeDecodeBinary(t, inv, new(MPTInventory))
|
||||
})
|
||||
|
||||
t.Run("too large", func(t *testing.T) {
|
||||
check := func(t *testing.T, count int, fail bool) {
|
||||
h := make([]util.Uint256, count)
|
||||
for i := range h {
|
||||
h[i] = util.Uint256{1, 2, 3}
|
||||
}
|
||||
if fail {
|
||||
bytes, err := testserdes.EncodeBinary(NewMPTInventory(h))
|
||||
require.NoError(t, err)
|
||||
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTInventory)))
|
||||
} else {
|
||||
testserdes.EncodeDecodeBinary(t, NewMPTInventory(h), new(MPTInventory))
|
||||
}
|
||||
}
|
||||
check(t, MaxMPTHashesCount, false)
|
||||
check(t, MaxMPTHashesCount+1, true)
|
||||
})
|
||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -17,7 +18,9 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/extpool"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||
|
@ -67,6 +70,7 @@ type (
|
|||
discovery Discoverer
|
||||
chain blockchainer.Blockchainer
|
||||
bQueue *blockQueue
|
||||
bSyncQueue *blockQueue
|
||||
consensus consensus.Service
|
||||
mempool *mempool.Pool
|
||||
notaryRequestPool *mempool.Pool
|
||||
|
@ -93,6 +97,7 @@ type (
|
|||
|
||||
oracle *oracle.Oracle
|
||||
stateRoot stateroot.Service
|
||||
stateSync blockchainer.StateSync
|
||||
|
||||
log *zap.Logger
|
||||
}
|
||||
|
@ -191,6 +196,10 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
|
|||
}
|
||||
s.stateRoot = sr
|
||||
|
||||
sSync := chain.GetStateSyncModule()
|
||||
s.stateSync = sSync
|
||||
s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil)
|
||||
|
||||
if config.OracleCfg.Enabled {
|
||||
orcCfg := oracle.Config{
|
||||
Log: log,
|
||||
|
@ -277,6 +286,7 @@ func (s *Server) Start(errChan chan error) {
|
|||
go s.broadcastTxLoop()
|
||||
go s.relayBlocksLoop()
|
||||
go s.bQueue.run()
|
||||
go s.bSyncQueue.run()
|
||||
go s.transport.Accept()
|
||||
setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10))
|
||||
s.run()
|
||||
|
@ -292,6 +302,7 @@ func (s *Server) Shutdown() {
|
|||
p.Disconnect(errServerShutdown)
|
||||
}
|
||||
s.bQueue.discard()
|
||||
s.bSyncQueue.discard()
|
||||
if s.StateRootCfg.Enabled {
|
||||
s.stateRoot.Shutdown()
|
||||
}
|
||||
|
@ -573,6 +584,10 @@ func (s *Server) IsInSync() bool {
|
|||
var peersNumber int
|
||||
var notHigher int
|
||||
|
||||
if s.stateSync.IsActive() {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.MinPeers == 0 {
|
||||
return true
|
||||
}
|
||||
|
@ -630,6 +645,9 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
|
|||
|
||||
// handleBlockCmd processes the received block received from its peer.
|
||||
func (s *Server) handleBlockCmd(p Peer, block *block.Block) error {
|
||||
if s.stateSync.IsActive() {
|
||||
return s.bSyncQueue.putBlock(block)
|
||||
}
|
||||
return s.bQueue.putBlock(block)
|
||||
}
|
||||
|
||||
|
@ -639,25 +657,57 @@ func (s *Server) handlePing(p Peer, ping *payload.Ping) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.chain.BlockHeight() < ping.LastBlockIndex {
|
||||
err = s.requestBlocks(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.requestBlocksOrHeaders(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id)))
|
||||
}
|
||||
|
||||
func (s *Server) requestBlocksOrHeaders(p Peer) error {
|
||||
if s.stateSync.NeedHeaders() {
|
||||
if s.chain.HeaderHeight() < p.LastBlockIndex() {
|
||||
return s.requestHeaders(p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
bq blockchainer.Blockqueuer = s.chain
|
||||
requestMPTNodes bool
|
||||
)
|
||||
if s.stateSync.IsActive() {
|
||||
bq = s.stateSync
|
||||
requestMPTNodes = s.stateSync.NeedMPTNodes()
|
||||
}
|
||||
if bq.BlockHeight() >= p.LastBlockIndex() {
|
||||
return nil
|
||||
}
|
||||
err := s.requestBlocks(bq, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if requestMPTNodes {
|
||||
return s.requestMPTNodes(p, s.stateSync.GetUnknownMPTNodesBatch(payload.MaxMPTHashesCount))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// requestHeaders sends a CMDGetHeaders message to the peer to sync up in headers.
|
||||
func (s *Server) requestHeaders(p Peer) error {
|
||||
// TODO: optimize
|
||||
currHeight := s.chain.HeaderHeight()
|
||||
needHeight := currHeight + 1
|
||||
payload := payload.NewGetBlockByIndex(needHeight, -1)
|
||||
return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload))
|
||||
}
|
||||
|
||||
// handlePing processes pong request.
|
||||
func (s *Server) handlePong(p Peer, pong *payload.Ping) error {
|
||||
err := p.HandlePong(pong)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.chain.BlockHeight() < pong.LastBlockIndex {
|
||||
return s.requestBlocks(p)
|
||||
}
|
||||
return nil
|
||||
return s.requestBlocksOrHeaders(p)
|
||||
}
|
||||
|
||||
// handleInvCmd processes the received inventory.
|
||||
|
@ -766,6 +816,69 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// handleGetMPTDataCmd processes the received MPT inventory.
|
||||
func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error {
|
||||
if !s.chain.GetConfig().P2PStateExchangeExtensions {
|
||||
return errors.New("GetMPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
|
||||
}
|
||||
if s.chain.GetConfig().KeepOnlyLatestState {
|
||||
// TODO: implement keeping MPT states for P1 and P2 height (#2095, #2152 related)
|
||||
return errors.New("GetMPTDataCMD was received, but only latest MPT state is supported")
|
||||
}
|
||||
resp := payload.MPTData{}
|
||||
capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes)))
|
||||
added := make(map[util.Uint256]struct{})
|
||||
for _, h := range inv.Hashes {
|
||||
if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type
|
||||
break
|
||||
}
|
||||
err := s.stateSync.Traverse(h,
|
||||
func(n mpt.Node, node []byte) bool {
|
||||
if _, ok := added[n.Hash()]; ok {
|
||||
return false
|
||||
}
|
||||
l := len(node)
|
||||
size := l + io.GetVarSize(l)
|
||||
if size > capLeft {
|
||||
return true
|
||||
}
|
||||
resp.Nodes = append(resp.Nodes, node)
|
||||
added[n.Hash()] = struct{}{}
|
||||
capLeft -= size
|
||||
return false
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to traverse MPT starting from %s: %w", h.StringBE(), err)
|
||||
}
|
||||
}
|
||||
if len(resp.Nodes) > 0 {
|
||||
msg := NewMessage(CMDMPTData, &resp)
|
||||
return p.EnqueueP2PMessage(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error {
|
||||
if !s.chain.GetConfig().P2PStateExchangeExtensions {
|
||||
return errors.New("MPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
|
||||
}
|
||||
return s.stateSync.AddMPTNodes(data.Nodes)
|
||||
}
|
||||
|
||||
// requestMPTNodes requests specified MPT nodes from the peer or broadcasts
|
||||
// request if peer is not specified.
|
||||
func (s *Server) requestMPTNodes(p Peer, itms []util.Uint256) error {
|
||||
if len(itms) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(itms) > payload.MaxMPTHashesCount {
|
||||
itms = itms[:payload.MaxMPTHashesCount]
|
||||
}
|
||||
pl := payload.NewMPTInventory(itms)
|
||||
msg := NewMessage(CMDGetMPTData, pl)
|
||||
return p.EnqueueP2PMessage(msg)
|
||||
}
|
||||
|
||||
// handleGetBlocksCmd processes the getblocks request.
|
||||
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
|
||||
count := gb.Count
|
||||
|
@ -845,6 +958,11 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error
|
|||
return p.EnqueueP2PMessage(msg)
|
||||
}
|
||||
|
||||
// handleHeadersCmd processes headers payload.
|
||||
func (s *Server) handleHeadersCmd(p Peer, h *payload.Headers) error {
|
||||
return s.stateSync.AddHeaders(h.Hdrs...)
|
||||
}
|
||||
|
||||
// handleExtensibleCmd processes received extensible payload.
|
||||
func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
|
||||
if !s.syncReached.Load() {
|
||||
|
@ -993,8 +1111,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error {
|
|||
// 1. Block range is divided into chunks of payload.MaxHashesCount.
|
||||
// 2. Send requests for chunk in increasing order.
|
||||
// 3. After all requests were sent, request random height.
|
||||
func (s *Server) requestBlocks(p Peer) error {
|
||||
var currHeight = s.chain.BlockHeight()
|
||||
func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error {
|
||||
var currHeight = bq.BlockHeight()
|
||||
var peerHeight = p.LastBlockIndex()
|
||||
var needHeight uint32
|
||||
// lastRequestedHeight can only be increased.
|
||||
|
@ -1051,9 +1169,18 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
case CMDGetData:
|
||||
inv := msg.Payload.(*payload.Inventory)
|
||||
return s.handleGetDataCmd(peer, inv)
|
||||
case CMDGetMPTData:
|
||||
inv := msg.Payload.(*payload.MPTInventory)
|
||||
return s.handleGetMPTDataCmd(peer, inv)
|
||||
case CMDMPTData:
|
||||
inv := msg.Payload.(*payload.MPTData)
|
||||
return s.handleMPTDataCmd(peer, inv)
|
||||
case CMDGetHeaders:
|
||||
gh := msg.Payload.(*payload.GetBlockByIndex)
|
||||
return s.handleGetHeadersCmd(peer, gh)
|
||||
case CMDHeaders:
|
||||
h := msg.Payload.(*payload.Headers)
|
||||
return s.handleHeadersCmd(peer, h)
|
||||
case CMDInv:
|
||||
inventory := msg.Payload.(*payload.Inventory)
|
||||
return s.handleInvCmd(peer, inventory)
|
||||
|
@ -1093,6 +1220,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
}
|
||||
go peer.StartProtocol()
|
||||
|
||||
s.tryInitStateSync()
|
||||
s.tryStartServices()
|
||||
default:
|
||||
return fmt.Errorf("received '%s' during handshake", msg.Command.String())
|
||||
|
@ -1101,6 +1229,52 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) tryInitStateSync() {
|
||||
if !s.stateSync.IsActive() {
|
||||
s.bSyncQueue.discard()
|
||||
return
|
||||
}
|
||||
|
||||
if s.stateSync.IsInitialized() {
|
||||
return
|
||||
}
|
||||
|
||||
var peersNumber int
|
||||
s.lock.RLock()
|
||||
heights := make([]uint32, 0)
|
||||
for p := range s.peers {
|
||||
if p.Handshaked() {
|
||||
peersNumber++
|
||||
peerLastBlock := p.LastBlockIndex()
|
||||
i := sort.Search(len(heights), func(i int) bool {
|
||||
return heights[i] >= peerLastBlock
|
||||
})
|
||||
heights = append(heights, peerLastBlock)
|
||||
if i != len(heights)-1 {
|
||||
copy(heights[i+1:], heights[i:])
|
||||
heights[i] = peerLastBlock
|
||||
}
|
||||
}
|
||||
}
|
||||
s.lock.RUnlock()
|
||||
if peersNumber >= s.MinPeers && len(heights) > 0 {
|
||||
// choose the height of the median peer as current chain's height
|
||||
h := heights[len(heights)/2]
|
||||
err := s.stateSync.Init(h)
|
||||
if err != nil {
|
||||
s.log.Fatal("failed to init state sync module",
|
||||
zap.Uint32("evaluated chain's blockHeight", h),
|
||||
zap.Uint32("blockHeight", s.chain.BlockHeight()),
|
||||
zap.Uint32("headerHeight", s.chain.HeaderHeight()),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
// module can be inactive after init (i.e. full state is collected and ordinary block processing is needed)
|
||||
if !s.stateSync.IsActive() {
|
||||
s.bSyncQueue.discard()
|
||||
}
|
||||
}
|
||||
}
|
||||
func (s *Server) handleNewPayload(p *payload.Extensible) {
|
||||
_, err := s.extensiblePool.Add(p)
|
||||
if err != nil {
|
||||
|
|
|
@ -2,6 +2,7 @@ package network
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
|
@ -16,6 +17,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||
|
@ -46,7 +48,10 @@ func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs =
|
|||
func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") }
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
bc := &fakechain.FakeChain{}
|
||||
bc := &fakechain.FakeChain{ProtocolConfiguration: config.ProtocolConfiguration{
|
||||
P2PStateExchangeExtensions: true,
|
||||
StateRootInHeader: true,
|
||||
}}
|
||||
s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery)
|
||||
require.Error(t, err)
|
||||
|
||||
|
@ -733,6 +738,139 @@ func TestInv(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestHandleGetMPTData(t *testing.T) {
|
||||
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
|
||||
Hashes: []util.Uint256{{1, 2, 3}},
|
||||
})
|
||||
require.Error(t, s.handleMessage(p, msg))
|
||||
})
|
||||
|
||||
t.Run("KeepOnlyLatestState on", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||
s.chain.(*fakechain.FakeChain).KeepOnlyLatestState = true
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
|
||||
Hashes: []util.Uint256{{1, 2, 3}},
|
||||
})
|
||||
require.Error(t, s.handleMessage(p, msg))
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||
var recvResponse atomic.Bool
|
||||
r1 := random.Uint256()
|
||||
r2 := random.Uint256()
|
||||
r3 := random.Uint256()
|
||||
node := []byte{1, 2, 3}
|
||||
s.stateSync.(*fakechain.FakeStateSync).TraverseFunc = func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||
if !(root.Equals(r1) || root.Equals(r2)) {
|
||||
t.Fatal("unexpected root")
|
||||
}
|
||||
require.False(t, process(mpt.NewHashNode(r3), node))
|
||||
return nil
|
||||
}
|
||||
found := &payload.MPTData{
|
||||
Nodes: [][]byte{node}, // no duplicates expected
|
||||
}
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
p.messageHandler = func(t *testing.T, msg *Message) {
|
||||
switch msg.Command {
|
||||
case CMDMPTData:
|
||||
require.Equal(t, found, msg.Payload)
|
||||
recvResponse.Store(true)
|
||||
}
|
||||
}
|
||||
hs := []util.Uint256{r1, r2}
|
||||
s.testHandleMessage(t, p, CMDGetMPTData, payload.NewMPTInventory(hs))
|
||||
|
||||
require.Eventually(t, recvResponse.Load, time.Second, time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleMPTData(t *testing.T) {
|
||||
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
msg := NewMessage(CMDMPTData, &payload.MPTData{
|
||||
Nodes: [][]byte{{1, 2, 3}},
|
||||
})
|
||||
require.Error(t, s.handleMessage(p, msg))
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
expected := [][]byte{{1, 2, 3}, {2, 3, 4}}
|
||||
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||
s.stateSync = &fakechain.FakeStateSync{
|
||||
AddMPTNodesFunc: func(nodes [][]byte) error {
|
||||
require.Equal(t, expected, nodes)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
msg := NewMessage(CMDMPTData, &payload.MPTData{
|
||||
Nodes: expected,
|
||||
})
|
||||
require.NoError(t, s.handleMessage(p, msg))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestMPTNodes(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
|
||||
var actual []util.Uint256
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
p.messageHandler = func(t *testing.T, msg *Message) {
|
||||
if msg.Command == CMDGetMPTData {
|
||||
actual = append(actual, msg.Payload.(*payload.MPTInventory).Hashes...)
|
||||
}
|
||||
}
|
||||
s.register <- p
|
||||
s.register <- p // ensure previous send was handled
|
||||
|
||||
t.Run("no hashes, no message", func(t *testing.T) {
|
||||
actual = nil
|
||||
require.NoError(t, s.requestMPTNodes(p, nil))
|
||||
require.Nil(t, actual)
|
||||
})
|
||||
t.Run("good, small", func(t *testing.T) {
|
||||
actual = nil
|
||||
expected := []util.Uint256{random.Uint256(), random.Uint256()}
|
||||
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||
require.Equal(t, expected, actual)
|
||||
})
|
||||
t.Run("good, exactly one chunk", func(t *testing.T) {
|
||||
actual = nil
|
||||
expected := make([]util.Uint256, payload.MaxMPTHashesCount)
|
||||
for i := range expected {
|
||||
expected[i] = random.Uint256()
|
||||
}
|
||||
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||
require.Equal(t, expected, actual)
|
||||
})
|
||||
t.Run("good, too large chunk", func(t *testing.T) {
|
||||
actual = nil
|
||||
expected := make([]util.Uint256, payload.MaxMPTHashesCount+1)
|
||||
for i := range expected {
|
||||
expected[i] = random.Uint256()
|
||||
}
|
||||
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||
require.Equal(t, expected[:payload.MaxMPTHashesCount], actual)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestTx(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
|
||||
|
@ -899,3 +1037,44 @@ func TestVerifyNotaryRequest(t *testing.T) {
|
|||
require.NoError(t, verifyNotaryRequest(bc, nil, r))
|
||||
})
|
||||
}
|
||||
|
||||
func TestTryInitStateSync(t *testing.T) {
|
||||
t.Run("module inactive", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
s.tryInitStateSync()
|
||||
})
|
||||
|
||||
t.Run("module already initialized", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
ss := &fakechain.FakeStateSync{}
|
||||
ss.IsActiveFlag.Store(true)
|
||||
ss.IsInitializedFlag.Store(true)
|
||||
s.stateSync = ss
|
||||
s.tryInitStateSync()
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
for _, h := range []uint32{10, 8, 7, 4, 11, 4} {
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
p.lastBlockIndex = h
|
||||
s.peers[p] = true
|
||||
}
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = false // one disconnected peer to check it won't be taken into attention
|
||||
p.lastBlockIndex = 5
|
||||
s.peers[p] = true
|
||||
var expectedH uint32 = 8 // median peer
|
||||
|
||||
ss := &fakechain.FakeStateSync{InitFunc: func(h uint32) error {
|
||||
if h != expectedH {
|
||||
return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
ss.IsActiveFlag.Store(true)
|
||||
s.stateSync = ss
|
||||
s.tryInitStateSync()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -267,12 +267,10 @@ func (p *TCPPeer) StartProtocol() {
|
|||
zap.Uint32("id", p.Version().Nonce))
|
||||
|
||||
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities)
|
||||
if p.server.chain.BlockHeight() < p.LastBlockIndex() {
|
||||
err = p.server.requestBlocks(p)
|
||||
if err != nil {
|
||||
p.Disconnect(err)
|
||||
return
|
||||
}
|
||||
err = p.server.requestBlocksOrHeaders(p)
|
||||
if err != nil {
|
||||
p.Disconnect(err)
|
||||
return
|
||||
}
|
||||
|
||||
timer := time.NewTimer(p.server.ProtoTickInterval)
|
||||
|
@ -281,10 +279,8 @@ func (p *TCPPeer) StartProtocol() {
|
|||
case <-p.done:
|
||||
return
|
||||
case <-timer.C:
|
||||
// Try to sync in headers and block with the peer if his block height is higher then ours.
|
||||
if p.LastBlockIndex() > p.server.chain.BlockHeight() {
|
||||
err = p.server.requestBlocks(p)
|
||||
}
|
||||
// Try to sync in headers and block with the peer if his block height is higher than ours.
|
||||
err = p.server.requestBlocksOrHeaders(p)
|
||||
if err == nil {
|
||||
timer.Reset(p.server.ProtoTickInterval)
|
||||
}
|
||||
|
|
|
@ -679,6 +679,7 @@ func (s *Server) getNEP17Balances(ps request.Params) (interface{}, *response.Err
|
|||
if err != nil {
|
||||
return nil, response.NewRPCError("Failed to get NEP17 last updated block", err.Error(), err)
|
||||
}
|
||||
stateSyncPoint := lastUpdated[math.MinInt32]
|
||||
bw := io.NewBufBinWriter()
|
||||
for _, h := range s.chain.GetNEP17Contracts() {
|
||||
balance, err := s.getNEP17Balance(h, u, bw)
|
||||
|
@ -692,10 +693,18 @@ func (s *Server) getNEP17Balances(ps request.Params) (interface{}, *response.Err
|
|||
if cs == nil {
|
||||
continue
|
||||
}
|
||||
lub, ok := lastUpdated[cs.ID]
|
||||
if !ok {
|
||||
cfg := s.chain.GetConfig()
|
||||
if !cfg.P2PStateExchangeExtensions && cfg.RemoveUntraceableBlocks {
|
||||
return nil, response.NewInternalServerError(fmt.Sprintf("failed to get LastUpdatedBlock for balance of %s token", cs.Hash.StringLE()), nil)
|
||||
}
|
||||
lub = stateSyncPoint
|
||||
}
|
||||
bs.Balances = append(bs.Balances, result.NEP17Balance{
|
||||
Asset: h,
|
||||
Amount: balance.String(),
|
||||
LastUpdated: lastUpdated[cs.ID],
|
||||
LastUpdated: lub,
|
||||
})
|
||||
}
|
||||
return bs, nil
|
||||
|
|
Loading…
Reference in a new issue