diff --git a/docs/rpc.md b/docs/rpc.md
index f12a7abca..1f8763f59 100644
--- a/docs/rpc.md
+++ b/docs/rpc.md
@@ -114,7 +114,11 @@ balance won't be shown in the list of NEP17 balances returned by the neo-go
 node (unlike the C# node behavior). However, transfer logs of such token are
 still available via `getnep17transfers` RPC call.
 
-The behaviour of the `LastUpdatedBlock` tracking matches the C# node's one.
+The behaviour of the `LastUpdatedBlock` tracking for archival nodes, as well as
+for governing token balances, matches the C# node's one. For other NEP17-compliant
+tokens on non-archival nodes, if a transfer's `LastUpdatedBlock` is lower than the
+latest state synchronization point P the node is working against, then
+`LastUpdatedBlock` equals P.
 
 ### Unsupported methods
 
diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go
index ef8145858..426972c6b 100644
--- a/internal/fakechain/fakechain.go
+++ b/internal/fakechain/fakechain.go
@@ -13,6 +13,7 @@ import (
 	"github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services"
 	"github.com/nspcc-dev/neo-go/pkg/core/interop"
 	"github.com/nspcc-dev/neo-go/pkg/core/mempool"
+	"github.com/nspcc-dev/neo-go/pkg/core/mpt"
 	"github.com/nspcc-dev/neo-go/pkg/core/native"
 	"github.com/nspcc-dev/neo-go/pkg/core/state"
 	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
@@ -21,6 +22,7 @@ import (
 	"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
 	"github.com/nspcc-dev/neo-go/pkg/util"
 	"github.com/nspcc-dev/neo-go/pkg/vm"
+	uatomic "go.uber.org/atomic"
 )
 
 // FakeChain implements Blockchainer interface, but does not provide real functionality.
@@ -42,6 +44,15 @@ type FakeChain struct {
 	UtilityTokenBalance *big.Int
 }
 
+// FakeStateSync implements StateSync interface.
+type FakeStateSync struct {
+	IsActiveFlag      uatomic.Bool
+	IsInitializedFlag uatomic.Bool
+	InitFunc          func(h uint32) error
+	TraverseFunc      func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
+	AddMPTNodesFunc   func(nodes [][]byte) error
+}
+
 // NewFakeChain returns new FakeChain structure.
 func NewFakeChain() *FakeChain {
 	return &FakeChain{
@@ -294,6 +305,11 @@ func (chain *FakeChain) GetStateModule() blockchainer.StateRoot {
 	return nil
 }
 
+// GetStateSyncModule implements Blockchainer interface.
+func (chain *FakeChain) GetStateSyncModule() blockchainer.StateSync {
+	return &FakeStateSync{}
+}
+
 // GetStorageItem implements Blockchainer interface.
 func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
 	panic("TODO")
@@ -436,3 +452,63 @@ func (chain *FakeChain) UnsubscribeFromNotifications(ch chan<- *state.Notificati
 func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
 	panic("TODO")
 }
+
+// AddBlock implements StateSync interface.
+func (s *FakeStateSync) AddBlock(block *block.Block) error {
+	panic("TODO")
+}
+
+// AddHeaders implements StateSync interface.
+func (s *FakeStateSync) AddHeaders(...*block.Header) error {
+	panic("TODO")
+}
+
+// AddMPTNodes implements StateSync interface.
+func (s *FakeStateSync) AddMPTNodes(nodes [][]byte) error {
+	if s.AddMPTNodesFunc != nil {
+		return s.AddMPTNodesFunc(nodes)
+	}
+	panic("TODO")
+}
+
+// BlockHeight implements StateSync interface.
+func (s *FakeStateSync) BlockHeight() uint32 {
+	panic("TODO")
+}
+
+// IsActive implements StateSync interface.
+func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag.Load() }
+
+// IsInitialized implements StateSync interface.
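+// A minimal test-side usage sketch (the hook fields above are this fake's own
+// API, not part of the real StateSync module):
+//
+//	ss := &FakeStateSync{InitFunc: func(h uint32) error { return nil }}
+//	ss.IsInitializedFlag.Store(true)
+//	_ = ss.IsInitialized() // true; ss.Init(h) is routed to InitFunc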
+func (s *FakeStateSync) IsInitialized() bool {
+	return s.IsInitializedFlag.Load()
+}
+
+// Init implements StateSync interface.
+func (s *FakeStateSync) Init(currChainHeight uint32) error {
+	if s.InitFunc != nil {
+		return s.InitFunc(currChainHeight)
+	}
+	panic("TODO")
+}
+
+// NeedHeaders implements StateSync interface.
+func (s *FakeStateSync) NeedHeaders() bool { return false }
+
+// NeedMPTNodes implements StateSync interface.
+func (s *FakeStateSync) NeedMPTNodes() bool {
+	panic("TODO")
+}
+
+// Traverse implements StateSync interface.
+func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
+	if s.TraverseFunc != nil {
+		return s.TraverseFunc(root, process)
+	}
+	panic("TODO")
+}
+
+// GetUnknownMPTNodesBatch implements StateSync interface.
+func (s *FakeStateSync) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
+	panic("TODO")
+}
diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go
index 1a20308f2..f2e7e14a1 100644
--- a/pkg/core/blockchain.go
+++ b/pkg/core/blockchain.go
@@ -23,6 +23,7 @@ import (
 	"github.com/nspcc-dev/neo-go/pkg/core/native/noderoles"
 	"github.com/nspcc-dev/neo-go/pkg/core/state"
 	"github.com/nspcc-dev/neo-go/pkg/core/stateroot"
+	"github.com/nspcc-dev/neo-go/pkg/core/statesync"
 	"github.com/nspcc-dev/neo-go/pkg/core/storage"
 	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
 	"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
@@ -34,6 +35,7 @@ import (
 	"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
 	"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
 	"github.com/nspcc-dev/neo-go/pkg/util"
+	"github.com/nspcc-dev/neo-go/pkg/util/slice"
 	"github.com/nspcc-dev/neo-go/pkg/vm"
 	"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
 	"go.uber.org/zap"
@@ -42,7 +44,7 @@ import (
 // Tuning parameters.
 const (
 	headerBatchCount = 2000
-	version          = "0.1.2"
+	version          = "0.1.4"
 
 	defaultInitialGAS    = 52000000_00000000
 	defaultMemPoolSize   = 50000
@@ -54,6 +56,31 @@ const (
 	// HeaderVerificationGasLimit is the maximum amount of GAS for block header verification.
 	HeaderVerificationGasLimit = 3_00000000 // 3 GAS
 	defaultStateSyncInterval   = 40000
+
+	// maxStorageBatchSize is the number of elements in a storage batch expected to fit into the
+	// storage without delays and problems. The estimated size of a batch with this number of
+	// elements does not exceed 1MB.
+	maxStorageBatchSize = 10000
+)
+
+// stateJumpStage denotes the stage of state jump process.
+type stateJumpStage byte
+
+const (
+	// none means that no state jump process was initiated yet.
+	none stateJumpStage = 1 << iota
+	// stateJumpStarted means that state jump was just initiated, but outdated storage items
+	// were not yet removed.
+	stateJumpStarted
+	// oldStorageItemsRemoved means that outdated contract storage items were removed, but
+	// new storage items were not yet saved.
+	oldStorageItemsRemoved
+	// newStorageItemsAdded means that contract storage items are up-to-date with the current
+	// state.
+	newStorageItemsAdded
+	// genesisStateRemoved means that state corresponding to the genesis block was removed
+	// from the storage.
+	genesisStateRemoved
 )
 
 var (
@@ -308,16 +335,6 @@ func (bc *Blockchain) init() error {
 	// and the genesis block as first block.
bc.log.Info("restoring blockchain", zap.String("version", version)) - bHeight, err := bc.dao.GetCurrentBlockHeight() - if err != nil { - return err - } - bc.blockHeight = bHeight - bc.persistedHeight = bHeight - if err = bc.stateRoot.Init(bHeight, bc.config.KeepOnlyLatestState); err != nil { - return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err) - } - bc.headerHashes, err = bc.dao.GetHeaderHashes() if err != nil { return err @@ -365,6 +382,34 @@ func (bc *Blockchain) init() error { } } + // Check whether StateJump stage is in the storage and continue interrupted state jump if so. + jumpStage, err := bc.dao.Store.Get(storage.SYSStateJumpStage.Bytes()) + if err == nil { + if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) { + return errors.New("state jump was not completed, but P2PStateExchangeExtensions are disabled or archival node capability is on. " + + "To start an archival node drop the database manually and restart the node") + } + if len(jumpStage) != 1 { + return fmt.Errorf("invalid state jump stage format") + } + // State jump wasn't finished yet, thus continue it. + stateSyncPoint, err := bc.dao.GetStateSyncPoint() + if err != nil { + return fmt.Errorf("failed to get state sync point from the storage") + } + return bc.jumpToStateInternal(stateSyncPoint, stateJumpStage(jumpStage[0])) + } + + bHeight, err := bc.dao.GetCurrentBlockHeight() + if err != nil { + return err + } + bc.blockHeight = bHeight + bc.persistedHeight = bHeight + if err = bc.stateRoot.Init(bHeight, bc.config.KeepOnlyLatestState); err != nil { + return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err) + } + err = bc.contracts.NEO.InitializeCache(bc, bc.dao) if err != nil { return fmt.Errorf("can't init cache for NEO native contract: %w", err) @@ -409,6 +454,158 @@ func (bc *Blockchain) init() error { return bc.updateExtensibleWhitelist(bHeight) } +// jumpToState is an atomic operation that changes Blockchain state to the one +// specified by the state sync point p. All the data needed for the jump must be +// collected by the state sync module. +func (bc *Blockchain) jumpToState(p uint32) error { + bc.lock.Lock() + defer bc.lock.Unlock() + + return bc.jumpToStateInternal(p, none) +} + +// jumpToStateInternal is an internal representation of jumpToState callback that +// changes Blockchain state to the one specified by state sync point p and state +// jump stage. All the data needed for the jump must be in the DB, otherwise an +// error is returned. It is not protected by mutex. +func (bc *Blockchain) jumpToStateInternal(p uint32, stage stateJumpStage) error { + if p+1 >= uint32(len(bc.headerHashes)) { + return fmt.Errorf("invalid state sync point %d: headerHeignt is %d", p, len(bc.headerHashes)) + } + + bc.log.Info("jumping to state sync point", zap.Uint32("state sync point", p)) + + writeBuf := io.NewBufBinWriter() + jumpStageKey := storage.SYSStateJumpStage.Bytes() + switch stage { + case none: + err := bc.dao.Store.Put(jumpStageKey, []byte{byte(stateJumpStarted)}) + if err != nil { + return fmt.Errorf("failed to store state jump stage: %w", err) + } + fallthrough + case stateJumpStarted: + // Replace old storage items by new ones, it should be done step-by step. + // Firstly, remove all old genesis-related items. + b := bc.dao.Store.Batch() + bc.dao.Store.Seek([]byte{byte(storage.STStorage)}, func(k, _ []byte) { + // Must copy here, #1468. 
+			key := slice.Copy(k)
+			b.Delete(key)
+		})
+		b.Put(jumpStageKey, []byte{byte(oldStorageItemsRemoved)})
+		err := bc.dao.Store.PutBatch(b)
+		if err != nil {
+			return fmt.Errorf("failed to store state jump stage: %w", err)
+		}
+		fallthrough
+	case oldStorageItemsRemoved:
+		// Then change STTempStorage prefix to STStorage. Each replace operation is atomic.
+		for {
+			count := 0
+			b := bc.dao.Store.Batch()
+			bc.dao.Store.Seek([]byte{byte(storage.STTempStorage)}, func(k, v []byte) {
+				if count >= maxStorageBatchSize {
+					return
+				}
+				// Must copy here, #1468.
+				oldKey := slice.Copy(k)
+				b.Delete(oldKey)
+				key := make([]byte, len(k))
+				key[0] = byte(storage.STStorage)
+				copy(key[1:], k[1:])
+				value := slice.Copy(v)
+				b.Put(key, value)
+				count += 2
+			})
+			if count > 0 {
+				err := bc.dao.Store.PutBatch(b)
+				if err != nil {
+					return fmt.Errorf("failed to replace outdated contract storage items with the fresh ones: %w", err)
+				}
+			} else {
+				break
+			}
+		}
+		err := bc.dao.Store.Put(jumpStageKey, []byte{byte(newStorageItemsAdded)})
+		if err != nil {
+			return fmt.Errorf("failed to store state jump stage: %w", err)
+		}
+		fallthrough
+	case newStorageItemsAdded:
+		// After the current state is updated, we need to remove outdated state-related data
+		// if any. The only outdated data we might have is genesis-related data, so check it.
+		if p > bc.config.MaxTraceableBlocks {
+			cache := bc.dao.GetWrapped()
+			writeBuf.Reset()
+			err := cache.DeleteBlock(bc.headerHashes[0], writeBuf)
+			if err != nil {
+				return fmt.Errorf("failed to remove outdated state data for the genesis block: %w", err)
+			}
+			// TODO: remove NEP17 transfers and NEP17 transfer info for genesis block, #2096 related.
+			_, err = cache.Persist()
+			if err != nil {
+				return fmt.Errorf("failed to drop genesis block state: %w", err)
+			}
+		}
+		err := bc.dao.Store.Put(jumpStageKey, []byte{byte(genesisStateRemoved)})
+		if err != nil {
+			return fmt.Errorf("failed to store state jump stage: %w", err)
+		}
+	case genesisStateRemoved:
+		// There's nothing to do after that, so just continue with the common operations
+		// and remove the state jump stage in the end.
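+		// (Recovery note: each completed stage is persisted under jumpStageKey,
+		// so a jump interrupted at any point re-enters this switch at the
+		// matching case and proceeds through the fallthrough chain.)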
+	default:
+		return errors.New("unknown state jump stage")
+	}
+
+	block, err := bc.dao.GetBlock(bc.headerHashes[p])
+	if err != nil {
+		return fmt.Errorf("failed to get current block: %w", err)
+	}
+	writeBuf.Reset()
+	err = bc.dao.StoreAsCurrentBlock(block, writeBuf)
+	if err != nil {
+		return fmt.Errorf("failed to store current block: %w", err)
+	}
+	bc.topBlock.Store(block)
+	atomic.StoreUint32(&bc.blockHeight, p)
+	atomic.StoreUint32(&bc.persistedHeight, p)
+
+	block, err = bc.dao.GetBlock(bc.headerHashes[p+1])
+	if err != nil {
+		return fmt.Errorf("failed to get block to init MPT: %w", err)
+	}
+	if err = bc.stateRoot.JumpToState(&state.MPTRoot{
+		Index: p,
+		Root:  block.PrevStateRoot,
+	}, bc.config.KeepOnlyLatestState); err != nil {
+		return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err)
+	}
+
+	err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
+	if err != nil {
+		return fmt.Errorf("can't init cache for NEO native contract: %w", err)
+	}
+	err = bc.contracts.Management.InitializeCache(bc.dao)
+	if err != nil {
+		return fmt.Errorf("can't init cache for Management native contract: %w", err)
+	}
+	bc.contracts.Designate.InitializeCache()
+
+	if err := bc.updateExtensibleWhitelist(p); err != nil {
+		return fmt.Errorf("failed to update extensible whitelist: %w", err)
+	}
+
+	updateBlockHeightMetric(p)
+
+	err = bc.dao.Store.Delete(jumpStageKey)
+	if err != nil {
+		return fmt.Errorf("failed to remove outdated state jump stage: %w", err)
+	}
+	return nil
+}
+
 // Run runs chain loop, it needs to be run as goroutine and executing it is
 // critical for correct Blockchain operation.
 func (bc *Blockchain) Run() {
@@ -696,6 +893,11 @@ func (bc *Blockchain) GetStateModule() blockchainer.StateRoot {
 	return bc.stateRoot
 }
 
+// GetStateSyncModule returns new state sync service instance.
+func (bc *Blockchain) GetStateSyncModule() blockchainer.StateSync {
+	return statesync.NewModule(bc, bc.log, bc.dao, bc.jumpToState)
+}
+
 // storeBlock performs chain update using the block given, it executes all
 // transactions with all appropriate side-effects and updates Blockchain state.
 // This is the only way to change Blockchain state.
@@ -1159,12 +1361,25 @@ func (bc *Blockchain) GetNEP17Contracts() []util.Uint160 {
 }
 
 // GetNEP17LastUpdated returns a set of contract ids with the corresponding last updated
-// block indexes.
+// block indexes. In case of an empty account, the latest stored state
+// synchronisation point is returned under the math.MinInt32 key.
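+// A consumer-side sketch (bc and acc are hypothetical variables here):
+//
+//	lub, err := bc.GetNEP17LastUpdated(acc)
+//	if err == nil {
+//		if p, ok := lub[math.MinInt32]; ok {
+//			_ = p // p is the state sync point, not a contract ID
+//		}
+//	}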
 func (bc *Blockchain) GetNEP17LastUpdated(acc util.Uint160) (map[int32]uint32, error) {
 	info, err := bc.dao.GetNEP17TransferInfo(acc)
 	if err != nil {
 		return nil, err
 	}
+	if bc.config.P2PStateExchangeExtensions && bc.config.RemoveUntraceableBlocks {
+		if _, ok := info.LastUpdated[bc.contracts.NEO.ID]; !ok {
+			nBalance, lub := bc.contracts.NEO.BalanceOf(bc.dao, acc)
+			if nBalance.Sign() != 0 {
+				info.LastUpdated[bc.contracts.NEO.ID] = lub
+			}
+		}
+	}
+	stateSyncPoint, err := bc.dao.GetStateSyncPoint()
+	if err == nil {
+		info.LastUpdated[math.MinInt32] = stateSyncPoint
+	}
 	return info.LastUpdated, nil
 }
 
diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go
index 2b0ed4a0b..285eeb379 100644
--- a/pkg/core/blockchain_test.go
+++ b/pkg/core/blockchain_test.go
@@ -1,6 +1,7 @@
 package core
 
 import (
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"math/big"
@@ -34,6 +35,7 @@ import (
 	"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
 	"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
 	"github.com/nspcc-dev/neo-go/pkg/util"
+	"github.com/nspcc-dev/neo-go/pkg/util/slice"
 	"github.com/nspcc-dev/neo-go/pkg/vm"
 	"github.com/nspcc-dev/neo-go/pkg/vm/emit"
 	"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
@@ -41,6 +43,7 @@ import (
 	"github.com/nspcc-dev/neo-go/pkg/wallet"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/zap/zaptest"
 )
 
 func TestVerifyHeader(t *testing.T) {
@@ -1734,3 +1737,88 @@ func TestConfigNativeUpdateHistory(t *testing.T) {
 		check(t, tc)
 	}
 }
+
+func TestBlockchain_InitWithIncompleteStateJump(t *testing.T) {
+	var (
+		stateSyncInterval        = 4
+		maxTraceable      uint32 = 6
+	)
+	spoutCfg := func(c *config.Config) {
+		c.ProtocolConfiguration.RemoveUntraceableBlocks = true
+		c.ProtocolConfiguration.StateRootInHeader = true
+		c.ProtocolConfiguration.P2PStateExchangeExtensions = true
+		c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
+		c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
+	}
+	bcSpout := newTestChainWithCustomCfg(t, spoutCfg)
+	initBasicChain(t, bcSpout)
+
+	// Reach the next state sync point and pretend that we've just restored.
+	stateSyncPoint := (int(bcSpout.BlockHeight())/stateSyncInterval + 1) * stateSyncInterval
+	for i := bcSpout.BlockHeight() + 1; i <= uint32(stateSyncPoint); i++ {
+		require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
+	}
+	require.Equal(t, uint32(stateSyncPoint), bcSpout.BlockHeight())
+	b := bcSpout.newBlock()
+	require.NoError(t, bcSpout.AddHeaders(&b.Header))
+
+	// Put storage items with STTempStorage prefix.
+	batch := bcSpout.dao.Store.Batch()
+	bcSpout.dao.Store.Seek(storage.STStorage.Bytes(), func(k, v []byte) {
+		key := slice.Copy(k)
+		key[0] = storage.STTempStorage.Bytes()[0]
+		value := slice.Copy(v)
+		batch.Put(key, value)
+	})
+	require.NoError(t, bcSpout.dao.Store.PutBatch(batch))
+
+	checkNewBlockchainErr := func(t *testing.T, cfg func(c *config.Config), store storage.Store, shouldFail bool) {
+		unitTestNetCfg, err := config.Load("../../config", testchain.Network())
+		require.NoError(t, err)
+		cfg(&unitTestNetCfg)
+		log := zaptest.NewLogger(t)
+		_, err = NewBlockchain(store, unitTestNetCfg.ProtocolConfiguration, log)
+		if shouldFail {
+			require.Error(t, err)
+		} else {
+			require.NoError(t, err)
+		}
+	}
+	boltCfg := func(c *config.Config) {
+		spoutCfg(c)
+		c.ProtocolConfiguration.KeepOnlyLatestState = true
+	}
+	// Manually store the state jump stage to check the state jump recovery process.
+	t.Run("invalid RemoveUntraceableBlocks setting", func(t *testing.T) {
+		require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
+		checkNewBlockchainErr(t, func(c *config.Config) {
+			boltCfg(c)
+			c.ProtocolConfiguration.RemoveUntraceableBlocks = false
+		}, bcSpout.dao.Store, true)
+	})
+	t.Run("invalid state jump stage format", func(t *testing.T) {
+		require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{0x01, 0x02}))
+		checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
+	})
+	t.Run("missing state sync point", func(t *testing.T) {
+		require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
+		checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
+	})
+	t.Run("invalid state sync point", func(t *testing.T) {
+		require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
+		point := make([]byte, 4)
+		binary.LittleEndian.PutUint32(point, uint32(len(bcSpout.headerHashes)))
+		require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateSyncPoint.Bytes(), point))
+		checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
+	})
+	for _, stage := range []stateJumpStage{stateJumpStarted, oldStorageItemsRemoved, newStorageItemsAdded, genesisStateRemoved, 0x03} {
+		t.Run(fmt.Sprintf("state jump stage %d", stage), func(t *testing.T) {
+			require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stage)}))
+			point := make([]byte, 4)
+			binary.LittleEndian.PutUint32(point, uint32(stateSyncPoint))
+			require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateSyncPoint.Bytes(), point))
+			shouldFail := stage == 0x03 // unknown stage
+			checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, shouldFail)
+		})
+	}
+}
diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go
index a0d5af6f4..998b8a846 100644
--- a/pkg/core/blockchainer/blockchainer.go
+++ b/pkg/core/blockchainer/blockchainer.go
@@ -21,7 +21,6 @@ import (
 type Blockchainer interface {
 	ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction
 	GetConfig() config.ProtocolConfiguration
-	AddHeaders(...*block.Header) error
 	Blockqueuer // Blockqueuer interface
 	CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error)
 	Close()
@@ -56,6 +55,7 @@ type Blockchainer interface {
 	GetStandByCommittee() keys.PublicKeys
 	GetStandByValidators() keys.PublicKeys
 	GetStateModule() StateRoot
+	GetStateSyncModule() StateSync
 	GetStorageItem(id int32, key []byte) state.StorageItem
 	GetStorageItems(id int32) (map[string]state.StorageItem, error)
 	GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM
diff --git a/pkg/core/blockchainer/blockqueuer.go b/pkg/core/blockchainer/blockqueuer.go
index 384ccd489..bae45281f 100644
--- a/pkg/core/blockchainer/blockqueuer.go
+++ b/pkg/core/blockchainer/blockqueuer.go
@@ -5,5 +5,6 @@ import "github.com/nspcc-dev/neo-go/pkg/core/block"
 // Blockqueuer is an interface for blockqueue.
 type Blockqueuer interface {
 	AddBlock(block *block.Block) error
+	AddHeaders(...*block.Header) error
 	BlockHeight() uint32
 }
diff --git a/pkg/core/blockchainer/state_root.go b/pkg/core/blockchainer/state_root.go
index 979e15963..9a540bda8 100644
--- a/pkg/core/blockchainer/state_root.go
+++ b/pkg/core/blockchainer/state_root.go
@@ -9,6 +9,8 @@ import (
 // StateRoot represents local state root module.
 type StateRoot interface {
 	AddStateRoot(root *state.MPTRoot) error
+	CleanStorage() error
+	CurrentLocalHeight() uint32
 	CurrentLocalStateRoot() util.Uint256
 	CurrentValidatedHeight() uint32
 	GetStateProof(root util.Uint256, key []byte) ([][]byte, error)
diff --git a/pkg/core/blockchainer/state_sync.go b/pkg/core/blockchainer/state_sync.go
new file mode 100644
index 000000000..a8ff919d9
--- /dev/null
+++ b/pkg/core/blockchainer/state_sync.go
@@ -0,0 +1,19 @@
+package blockchainer
+
+import (
+	"github.com/nspcc-dev/neo-go/pkg/core/mpt"
+	"github.com/nspcc-dev/neo-go/pkg/util"
+)
+
+// StateSync represents state sync module.
+type StateSync interface {
+	AddMPTNodes([][]byte) error
+	Blockqueuer // Blockqueuer interface
+	Init(currChainHeight uint32) error
+	IsActive() bool
+	IsInitialized() bool
+	GetUnknownMPTNodesBatch(limit int) []util.Uint256
+	NeedHeaders() bool
+	NeedMPTNodes() bool
+	Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
+}
diff --git a/pkg/core/mpt/billet.go b/pkg/core/mpt/billet.go
new file mode 100644
index 000000000..b2e19c3c8
--- /dev/null
+++ b/pkg/core/mpt/billet.go
@@ -0,0 +1,314 @@
+package mpt
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"fmt"
+
+	"github.com/nspcc-dev/neo-go/pkg/core/storage"
+	"github.com/nspcc-dev/neo-go/pkg/io"
+	"github.com/nspcc-dev/neo-go/pkg/util"
+	"github.com/nspcc-dev/neo-go/pkg/util/slice"
+)
+
+var (
+	// ErrRestoreFailed is returned when replacing HashNode by its "unhashed"
+	// candidate fails.
+	ErrRestoreFailed = errors.New("failed to restore MPT node")
+	errStop          = errors.New("stop condition is met")
+)
+
+// Billet is a part of MPT trie with missing hash nodes that need to be restored.
+// Billet is based on the following assumptions:
+// 1. Refcount can only be incremented (we don't change MPT structure during restore,
+//    thus don't need to decrease refcount).
+// 2. Each time the part of Billet is completely restored, it is collapsed into
+//    HashNode.
+// 3. Pair (node, path) must be restored only once. It's a duty of MPT pool to manage
+//    MPT paths in order to provide this assumption.
+type Billet struct {
+	Store *storage.MemCachedStore
+
+	root            Node
+	refcountEnabled bool
+}
+
+// NewBillet returns a new billet for MPT trie restoring. It accepts a MemCachedStore
+// to decouple storage errors from logic errors so that all storage errors are
+// processed during `store.Persist()` at the caller. This also has the benefit
+// that every `Put` can be considered an atomic operation.
+func NewBillet(rootHash util.Uint256, enableRefCount bool, store *storage.MemCachedStore) *Billet {
+	return &Billet{
+		Store:           store,
+		root:            NewHashNode(rootHash),
+		refcountEnabled: enableRefCount,
+	}
+}
+
+// RestoreHashNode replaces HashNode located at the provided path by the specified Node
+// and stores it. It also maintains MPT as small as possible by collapsing those parts
+// of MPT that have been completely restored.
+func (b *Billet) RestoreHashNode(path []byte, node Node) error {
+	if _, ok := node.(*HashNode); ok {
+		return fmt.Errorf("%w: unable to restore node into HashNode", ErrRestoreFailed)
+	}
+	if _, ok := node.(EmptyNode); ok {
+		return fmt.Errorf("%w: unable to restore node into EmptyNode", ErrRestoreFailed)
+	}
+	r, err := b.putIntoNode(b.root, path, node)
+	if err != nil {
+		return err
+	}
+	b.root = r
+
+	// If it's a leaf, then put into temporary contract storage.
+	if leaf, ok := node.(*LeafNode); ok {
+		k := append([]byte{byte(storage.STTempStorage)}, fromNibbles(path)...)
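+		// E.g. (sketch): for nibble path {0x0A, 0x0C} the resulting key is
+		// {byte(storage.STTempStorage), 0xAC}: one prefix byte followed by the
+		// de-nibbled path.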
+		_ = b.Store.Put(k, leaf.value)
+	}
+	return nil
+}
+
+// putIntoNode puts val with provided path inside curr and returns updated node.
+// Reference counters are updated for both curr and returned value.
+func (b *Billet) putIntoNode(curr Node, path []byte, val Node) (Node, error) {
+	switch n := curr.(type) {
+	case *LeafNode:
+		return b.putIntoLeaf(n, path, val)
+	case *BranchNode:
+		return b.putIntoBranch(n, path, val)
+	case *ExtensionNode:
+		return b.putIntoExtension(n, path, val)
+	case *HashNode:
+		return b.putIntoHash(n, path, val)
+	case EmptyNode:
+		return nil, fmt.Errorf("%w: can't modify EmptyNode during restore", ErrRestoreFailed)
+	default:
+		panic("invalid MPT node type")
+	}
+}
+
+func (b *Billet) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) {
+	if len(path) != 0 {
+		return nil, fmt.Errorf("%w: can't modify LeafNode during restore", ErrRestoreFailed)
+	}
+	if curr.Hash() != val.Hash() {
+		return nil, fmt.Errorf("%w: bad Leaf node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
+	}
+	// Once Leaf node is restored, it will be collapsed into HashNode forever, so
+	// there shouldn't be such situation when we try to restore Leaf node.
+	panic("bug: can't restore LeafNode")
+}
+
+func (b *Billet) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) {
+	if len(path) == 0 && curr.Hash().Equals(val.Hash()) {
+		// This node has already been restored, so it's an MPT pool duty to avoid
+		// duplicating restore requests.
+		panic("bug: can't perform restoring of BranchNode twice")
+	}
+	i, path := splitPath(path)
+	r, err := b.putIntoNode(curr.Children[i], path, val)
+	if err != nil {
+		return nil, err
+	}
+	curr.Children[i] = r
+	return b.tryCollapseBranch(curr), nil
+}
+
+func (b *Billet) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) {
+	if len(path) == 0 {
+		if curr.Hash() != val.Hash() {
+			return nil, fmt.Errorf("%w: bad Extension node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
+		}
+		// This node has already been restored, so it's an MPT pool duty to avoid
+		// duplicating restore requests.
+		panic("bug: can't perform restoring of ExtensionNode twice")
+	}
+	if !bytes.HasPrefix(path, curr.key) {
+		return nil, fmt.Errorf("%w: can't modify ExtensionNode during restore", ErrRestoreFailed)
+	}
+
+	r, err := b.putIntoNode(curr.next, path[len(curr.key):], val)
+	if err != nil {
+		return nil, err
+	}
+	curr.next = r
+	return b.tryCollapseExtension(curr), nil
+}
+
+func (b *Billet) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) {
+	// Once a part of MPT Billet is completely restored, it will be collapsed forever, so
+	// it's an MPT pool duty to avoid duplicating restore requests.
+	if len(path) != 0 {
+		return nil, fmt.Errorf("%w: node has already been collapsed", ErrRestoreFailed)
+	}
+
+	// `curr` hash node can be either of
+	// 1) saved in storage (e.g. if we've already restored node with the same hash from the
+	//    other part of MPT), so just add it to local in-memory MPT.
+	// 2) missing from the storage. It's OK because we're syncing MPT state, and the purpose
+	//    is to store missing hash nodes.
+	// Both cases are OK, but we still need to validate `val` against `curr`.
+	if val.Hash() != curr.Hash() {
+		return nil, fmt.Errorf("%w: can't restore HashNode: expected and actual hashes mismatch (%s vs %s)", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
+	}
+
+	if curr.Collapsed {
+		// This node has already been restored and collapsed, so it's an MPT pool duty to avoid
+		// duplicating restore requests.
+		panic("bug: can't perform restoring of collapsed node")
+	}
+
+	// We also need to increment refcount in both cases. That's the only place where refcount
+	// is changed during restore process. Also flush right now, because sync process can be
+	// interrupted at any time.
+	b.incrementRefAndStore(val.Hash(), val.Bytes())
+
+	if val.Type() == LeafT {
+		return b.tryCollapseLeaf(val.(*LeafNode)), nil
+	}
+	return val, nil
+}
+
+func (b *Billet) incrementRefAndStore(h util.Uint256, bs []byte) {
+	key := makeStorageKey(h.BytesBE())
+	if b.refcountEnabled {
+		var (
+			err  error
+			data []byte
+			cnt  int32
+		)
+		// An item may already be in store.
+		data, err = b.Store.Get(key)
+		if err == nil {
+			cnt = int32(binary.LittleEndian.Uint32(data[len(data)-4:]))
+		}
+		cnt++
+		if len(data) == 0 {
+			data = append(bs, 0, 0, 0, 0)
+		}
+		binary.LittleEndian.PutUint32(data[len(data)-4:], uint32(cnt))
+		_ = b.Store.Put(key, data)
+	} else {
+		_ = b.Store.Put(key, bs)
+	}
+}
+
+// Traverse traverses MPT nodes (pre-order) starting from the billet root down
+// to its children calling `process` for each serialised node until true is
+// returned from `process` function. It also replaces all HashNodes with their
+// "unhashed" counterparts until the stop condition is satisfied.
+func (b *Billet) Traverse(process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) error {
+	r, err := b.traverse(b.root, process, ignoreStorageErr)
+	if err != nil && !errors.Is(err, errStop) {
+		return err
+	}
+	b.root = r
+	return nil
+}
+
+func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) (Node, error) {
+	if _, ok := curr.(EmptyNode); ok {
+		// We're not interested in EmptyNodes, and they do not affect the
+		// traversal process, thus leave them untouched.
+		return curr, nil
+	}
+	if hn, ok := curr.(*HashNode); ok {
+		r, err := b.GetFromStore(hn.Hash())
+		if err != nil {
+			if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) {
+				return hn, nil
+			}
+			return nil, err
+		}
+		return b.traverse(r, process, ignoreStorageErr)
+	}
+	bytes := slice.Copy(curr.Bytes())
+	if process(curr, bytes) {
+		return curr, errStop
+	}
+	switch n := curr.(type) {
+	case *LeafNode:
+		return b.tryCollapseLeaf(n), nil
+	case *BranchNode:
+		for i := range n.Children {
+			r, err := b.traverse(n.Children[i], process, ignoreStorageErr)
+			if err != nil {
+				if !errors.Is(err, errStop) {
+					return nil, err
+				}
+				n.Children[i] = r
+				return n, err
+			}
+			n.Children[i] = r
+		}
+		return b.tryCollapseBranch(n), nil
+	case *ExtensionNode:
+		r, err := b.traverse(n.next, process, ignoreStorageErr)
+		if err != nil && !errors.Is(err, errStop) {
+			return nil, err
+		}
+		n.next = r
+		return b.tryCollapseExtension(n), err
+	default:
+		return nil, ErrNotFound
+	}
+}
+
+func (b *Billet) tryCollapseLeaf(curr *LeafNode) Node {
+	// Leaf can always be collapsed.
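+	// Collapsing swaps a fully restored subtree for a HashNode with the
+	// Collapsed flag set, which keeps the in-memory billet small
+	// (assumption 2 in the Billet doc above).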
+	res := NewHashNode(curr.Hash())
+	res.Collapsed = true
+	return res
+}
+
+func (b *Billet) tryCollapseExtension(curr *ExtensionNode) Node {
+	if !(curr.next.Type() == HashT && curr.next.(*HashNode).Collapsed) {
+		return curr
+	}
+	res := NewHashNode(curr.Hash())
+	res.Collapsed = true
+	return res
+}
+
+func (b *Billet) tryCollapseBranch(curr *BranchNode) Node {
+	canCollapse := true
+	for i := 0; i < childrenCount; i++ {
+		if curr.Children[i].Type() == EmptyT {
+			continue
+		}
+		if curr.Children[i].Type() == HashT && curr.Children[i].(*HashNode).Collapsed {
+			continue
+		}
+		canCollapse = false
+		break
+	}
+	if !canCollapse {
+		return curr
+	}
+	res := NewHashNode(curr.Hash())
+	res.Collapsed = true
+	return res
+}
+
+// GetFromStore returns MPT node from the storage.
+func (b *Billet) GetFromStore(h util.Uint256) (Node, error) {
+	data, err := b.Store.Get(makeStorageKey(h.BytesBE()))
+	if err != nil {
+		return nil, err
+	}
+
+	var n NodeObject
+	r := io.NewBinReaderFromBuf(data)
+	n.DecodeBinary(r)
+	if r.Err != nil {
+		return nil, r.Err
+	}
+
+	if b.refcountEnabled {
+		data = data[:len(data)-4]
+	}
+	n.Node.(flushedNode).setCache(data, h)
+	return n.Node, nil
+}
diff --git a/pkg/core/mpt/billet_test.go b/pkg/core/mpt/billet_test.go
new file mode 100644
index 000000000..7850b129e
--- /dev/null
+++ b/pkg/core/mpt/billet_test.go
@@ -0,0 +1,211 @@
+package mpt
+
+import (
+	"encoding/binary"
+	"errors"
+	"testing"
+
+	"github.com/nspcc-dev/neo-go/pkg/core/storage"
+	"github.com/nspcc-dev/neo-go/pkg/io"
+	"github.com/nspcc-dev/neo-go/pkg/util"
+	"github.com/stretchr/testify/require"
+)
+
+func TestBillet_RestoreHashNode(t *testing.T) {
+	check := func(t *testing.T, tr *Billet, expectedRoot Node, expectedNode Node, expectedRefCount uint32) {
+		_ = expectedRoot.Hash()
+		_ = tr.root.Hash()
+		require.Equal(t, expectedRoot, tr.root)
+		expectedBytes, err := tr.Store.Get(makeStorageKey(expectedNode.Hash().BytesBE()))
+		if expectedRefCount != 0 {
+			require.NoError(t, err)
+			require.Equal(t, expectedRefCount, binary.LittleEndian.Uint32(expectedBytes[len(expectedBytes)-4:]))
+		} else {
+			require.True(t, errors.Is(err, storage.ErrKeyNotFound))
+		}
+	}
+
+	t.Run("parent is Extension", func(t *testing.T) {
+		t.Run("restore Branch", func(t *testing.T) {
+			b := NewBranchNode()
+			b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0xAB, 0xCD}))
+			b.Children[5] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0xAB, 0xDE}))
+			path := toNibbles([]byte{0xAC})
+			e := NewExtensionNode(path, NewHashNode(b.Hash()))
+			tr := NewBillet(e.Hash(), true, newTestStore())
+			tr.root = e
+
+			// OK
+			n := new(NodeObject)
+			n.DecodeBinary(io.NewBinReaderFromBuf(b.Bytes()))
+			require.NoError(t, tr.RestoreHashNode(path, n.Node))
+			expected := NewExtensionNode(path, n.Node)
+			check(t, tr, expected, n.Node, 1)
+
+			// One more time (already restored) => panic expected, no refcount changes
+			require.Panics(t, func() {
+				_ = tr.RestoreHashNode(path, n.Node)
+			})
+			check(t, tr, expected, n.Node, 1)
+
+			// Same path, but wrong hash => error expected, no refcount changes
+			require.True(t, errors.Is(tr.RestoreHashNode(path, NewBranchNode()), ErrRestoreFailed))
+			check(t, tr, expected, n.Node, 1)
+
+			// New path (changes in the MPT structure are not allowed) => error expected, no refcount changes
+			require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAB}), n.Node), ErrRestoreFailed))
+			check(t, tr, expected, n.Node, 1)
+		})
+
+		t.Run("restore Leaf", func(t *testing.T) {
+			l := NewLeafNode([]byte{0xAB, 0xCD})
+			path := toNibbles([]byte{0xAC})
+			e := NewExtensionNode(path, NewHashNode(l.Hash()))
+			tr := NewBillet(e.Hash(), true, newTestStore())
+			tr.root = e
+
+			// OK
+			require.NoError(t, tr.RestoreHashNode(path, l))
+			expected := NewHashNode(e.Hash()) // leaf should be collapsed immediately => extension should also be collapsed
+			expected.Collapsed = true
+			check(t, tr, expected, l, 1)
+
+			// One more time (already restored and collapsed) => error expected, no refcount changes
+			require.Error(t, tr.RestoreHashNode(path, l))
+			check(t, tr, expected, l, 1)
+
+			// Same path, but wrong hash => error expected, no refcount changes
+			require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
+			check(t, tr, expected, l, 1)
+
+			// New path (changes in the MPT structure are not allowed) => error expected, no refcount changes
+			require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAB}), l), ErrRestoreFailed))
+			check(t, tr, expected, l, 1)
+		})
+
+		t.Run("restore Hash", func(t *testing.T) {
+			h := NewHashNode(util.Uint256{1, 2, 3})
+			path := toNibbles([]byte{0xAC})
+			e := NewExtensionNode(path, h)
+			tr := NewBillet(e.Hash(), true, newTestStore())
+			tr.root = e
+
+			// no-op
+			require.True(t, errors.Is(tr.RestoreHashNode(path, h), ErrRestoreFailed))
+			check(t, tr, e, h, 0)
+		})
+	})
+
+	t.Run("parent is Leaf", func(t *testing.T) {
+		l := NewLeafNode([]byte{0xAB, 0xCD})
+		path := []byte{}
+		tr := NewBillet(l.Hash(), true, newTestStore())
+		tr.root = l
+
+		// Already restored => panic expected
+		require.Panics(t, func() {
+			_ = tr.RestoreHashNode(path, l)
+		})
+
+		// Same path, but wrong hash => error expected, no refcount changes
+		require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
+
+		// Non-nil path, but MPT structure can't be changed => error expected, no refcount changes
+		require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAC}), NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
+	})
+
+	t.Run("parent is Branch", func(t *testing.T) {
+		t.Run("middle child", func(t *testing.T) {
+			l1 := NewLeafNode([]byte{0xAB, 0xCD})
+			l2 := NewLeafNode([]byte{0xAB, 0xDE})
+			b := NewBranchNode()
+			b.Children[5] = NewHashNode(l1.Hash())
+			b.Children[lastChild] = NewHashNode(l2.Hash())
+			tr := NewBillet(b.Hash(), true, newTestStore())
+			tr.root = b
+
+			// OK
+			path := []byte{0x05}
+			require.NoError(t, tr.RestoreHashNode(path, l1))
+			check(t, tr, b, l1, 1)
+
+			// One more time (already restored) => panic expected.
+			// It's an MPT pool duty to avoid such situations during real restore process.
+			require.Panics(t, func() {
+				_ = tr.RestoreHashNode(path, l1)
+			})
+			// No refcount changes expected.
+			check(t, tr, b, l1, 1)
+
+			// Same path, but wrong hash => error expected, no refcount changes
+			require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAD})), ErrRestoreFailed))
+			check(t, tr, b, l1, 1)
+
+			// New path pointing to the empty HashNode (changes in the MPT structure are not allowed) => error expected, no refcount changes
+			require.True(t, errors.Is(tr.RestoreHashNode([]byte{0x01}, l1), ErrRestoreFailed))
+			check(t, tr, b, l1, 1)
+		})
+
+		t.Run("last child", func(t *testing.T) {
+			l1 := NewLeafNode([]byte{0xAB, 0xCD})
+			l2 := NewLeafNode([]byte{0xAB, 0xDE})
+			b := NewBranchNode()
+			b.Children[5] = NewHashNode(l1.Hash())
+			b.Children[lastChild] = NewHashNode(l2.Hash())
+			tr := NewBillet(b.Hash(), true, newTestStore())
+			tr.root = b
+
+			// OK
+			path := []byte{}
+			require.NoError(t, tr.RestoreHashNode(path, l2))
+			check(t, tr, b, l2, 1)
+
+			// One more time (already restored) => panic expected.
+			// It's an MPT pool duty to avoid such situations during real restore process.
+			require.Panics(t, func() {
+				_ = tr.RestoreHashNode(path, l2)
+			})
+			// No refcount changes expected.
+			check(t, tr, b, l2, 1)
+
+			// Same path, but wrong hash => error expected, no refcount changes
+			require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAD})), ErrRestoreFailed))
+			check(t, tr, b, l2, 1)
+		})
+
+		t.Run("two children with same hash", func(t *testing.T) {
+			l := NewLeafNode([]byte{0xAB, 0xCD})
+			b := NewBranchNode()
+			// two same hashnodes => leaf's refcount expected to be 2 in the end.
+			b.Children[3] = NewHashNode(l.Hash())
+			b.Children[4] = NewHashNode(l.Hash())
+			tr := NewBillet(b.Hash(), true, newTestStore())
+			tr.root = b
+
+			// OK
+			require.NoError(t, tr.RestoreHashNode([]byte{0x03}, l))
+			expected := b
+			expected.Children[3].(*HashNode).Collapsed = true
+			check(t, tr, b, l, 1)
+
+			// Restore another node with the same hash => no error expected, refcount should be incremented.
+			// Branch node should be collapsed.
+			require.NoError(t, tr.RestoreHashNode([]byte{0x04}, l))
+			res := NewHashNode(b.Hash())
+			res.Collapsed = true
+			check(t, tr, res, l, 2)
+		})
+	})
+
+	t.Run("parent is Hash", func(t *testing.T) {
+		l := NewLeafNode([]byte{0xAB, 0xCD})
+		b := NewBranchNode()
+		b.Children[3] = NewHashNode(l.Hash())
+		b.Children[4] = NewHashNode(l.Hash())
+		tr := NewBillet(b.Hash(), true, newTestStore())
+
+		// Should fail, because if it's a hash node with non-empty path, then the node
+		// has already been collapsed.
+		require.Error(t, tr.RestoreHashNode([]byte{0x03}, l))
+	})
+}
diff --git a/pkg/core/mpt/branch.go b/pkg/core/mpt/branch.go
index 0338ff4a7..d2fc84dfc 100644
--- a/pkg/core/mpt/branch.go
+++ b/pkg/core/mpt/branch.go
@@ -89,6 +89,12 @@ func (b *BranchNode) UnmarshalJSON(data []byte) error {
 	return errors.New("expected branch node")
 }
 
+// Clone implements Node interface.
+func (b *BranchNode) Clone() Node {
+	res := *b
+	return &res
+}
+
 // splitPath splits path for a branch node.
 func splitPath(path []byte) (byte, []byte) {
 	if len(path) != 0 {
diff --git a/pkg/core/mpt/empty.go b/pkg/core/mpt/empty.go
index 6669ef8c1..bc4f4914a 100644
--- a/pkg/core/mpt/empty.go
+++ b/pkg/core/mpt/empty.go
@@ -54,3 +54,6 @@ func (e EmptyNode) Type() NodeType {
 func (e EmptyNode) Bytes() []byte {
 	return nil
 }
+
+// Clone implements Node interface.
+func (EmptyNode) Clone() Node { return EmptyNode{} }
diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go
index 2dcbcb66b..1266c6acb 100644
--- a/pkg/core/mpt/extension.go
+++ b/pkg/core/mpt/extension.go
@@ -98,3 +98,9 @@ func (e *ExtensionNode) UnmarshalJSON(data []byte) error {
 	}
 	return errors.New("expected extension node")
 }
+
+// Clone implements Node interface.
+func (e *ExtensionNode) Clone() Node {
+	res := *e
+	return &res
+}
diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go
index 05ddbe5f3..03dc47a36 100644
--- a/pkg/core/mpt/hash.go
+++ b/pkg/core/mpt/hash.go
@@ -10,6 +10,7 @@ import (
 // HashNode represents MPT's hash node.
 type HashNode struct {
 	BaseNode
+	Collapsed bool
 }
 
 var _ Node = (*HashNode)(nil)
@@ -76,3 +77,10 @@ func (h *HashNode) UnmarshalJSON(data []byte) error {
 	}
 	return errors.New("expected hash node")
 }
+
+// Clone implements Node interface.
+func (h *HashNode) Clone() Node {
+	res := *h
+	res.Collapsed = false
+	return &res
+}
diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go
index a7399d37d..63c02c089 100644
--- a/pkg/core/mpt/helpers.go
+++ b/pkg/core/mpt/helpers.go
@@ -1,5 +1,7 @@
 package mpt
 
+import "github.com/nspcc-dev/neo-go/pkg/util"
+
 // lcp returns longest common prefix of a and b.
 // Note: it does no allocations.
 func lcp(a, b []byte) []byte {
@@ -40,3 +42,45 @@ func toNibbles(path []byte) []byte {
 	}
 	return result
 }
+
+// fromNibbles performs operation opposite to toNibbles and does no path validity checks.
+func fromNibbles(path []byte) []byte {
+	result := make([]byte, len(path)/2)
+	for i := range result {
+		result[i] = path[2*i]<<4 + path[2*i+1]
+	}
+	return result
+}
+
+// GetChildrenPaths returns a set of paths to the node's children that are non-empty
+// HashNodes, based on the node's path.
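+// For example (see the tests below): for a BranchNode b whose Children[3] is a
+// HashNode h, GetChildrenPaths([]byte{0x04}, b) returns
+// {h.Hash(): {{0x04, 0x03}}}; the lastChild position adds no extra nibble.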
+func GetChildrenPaths(path []byte, node Node) map[util.Uint256][][]byte {
+	res := make(map[util.Uint256][][]byte)
+	switch n := node.(type) {
+	case *LeafNode, *HashNode, EmptyNode:
+		return nil
+	case *BranchNode:
+		for i, child := range n.Children {
+			if child.Type() == HashT {
+				cPath := make([]byte, len(path), len(path)+1)
+				copy(cPath, path)
+				if i != lastChild {
+					cPath = append(cPath, byte(i))
+				}
+				paths := res[child.Hash()]
+				paths = append(paths, cPath)
+				res[child.Hash()] = paths
+			}
+		}
+	case *ExtensionNode:
+		if n.next.Type() == HashT {
+			cPath := make([]byte, len(path)+len(n.key))
+			copy(cPath, path)
+			copy(cPath[len(path):], n.key)
+			res[n.next.Hash()] = [][]byte{cPath}
+		}
+	default:
+		panic("unknown Node type")
+	}
+	return res
+}
diff --git a/pkg/core/mpt/helpers_test.go b/pkg/core/mpt/helpers_test.go
new file mode 100644
index 000000000..28181dcc2
--- /dev/null
+++ b/pkg/core/mpt/helpers_test.go
@@ -0,0 +1,67 @@
+package mpt
+
+import (
+	"testing"
+
+	"github.com/nspcc-dev/neo-go/pkg/util"
+	"github.com/stretchr/testify/require"
+)
+
+func TestToNibblesFromNibbles(t *testing.T) {
+	check := func(t *testing.T, expected []byte) {
+		actual := fromNibbles(toNibbles(expected))
+		require.Equal(t, expected, actual)
+	}
+	t.Run("empty path", func(t *testing.T) {
+		check(t, []byte{})
+	})
+	t.Run("non-empty path", func(t *testing.T) {
+		check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF})
+	})
+}
+
+func TestGetChildrenPaths(t *testing.T) {
+	h1 := NewHashNode(util.Uint256{1, 2, 3})
+	h2 := NewHashNode(util.Uint256{4, 5, 6})
+	h3 := NewHashNode(util.Uint256{7, 8, 9})
+	l := NewLeafNode([]byte{1, 2, 3})
+	ext1 := NewExtensionNode([]byte{8, 9}, h1)
+	ext2 := NewExtensionNode([]byte{7, 6}, l)
+	branch := NewBranchNode()
+	branch.Children[3] = h1
+	branch.Children[5] = l
+	branch.Children[6] = h1 // 3rd and 6th children have the same hash
+	branch.Children[7] = h3
+	branch.Children[lastChild] = h2
+	testCases := map[string]struct {
+		node     Node
+		expected map[util.Uint256][][]byte
+	}{
+		"Hash":                         {h1, nil},
+		"Leaf":                         {l, nil},
+		"Extension with next Hash":     {ext1, map[util.Uint256][][]byte{h1.Hash(): {ext1.key}}},
+		"Extension with next non-Hash": {ext2, map[util.Uint256][][]byte{}},
+		"Branch": {branch, map[util.Uint256][][]byte{
+			h1.Hash(): {{0x03}, {0x06}},
+			h2.Hash(): {{}},
+			h3.Hash(): {{0x07}},
+		}},
+	}
+	parentPath := []byte{4, 5, 6}
+	for name, testCase := range testCases {
+		t.Run(name, func(t *testing.T) {
+			require.Equal(t, testCase.expected, GetChildrenPaths([]byte{}, testCase.node))
+			if testCase.expected != nil {
+				expectedWithPrefix := make(map[util.Uint256][][]byte, len(testCase.expected))
+				for h, paths := range testCase.expected {
+					var res [][]byte
+					for _, path := range paths {
+						res = append(res, append(parentPath, path...))
+					}
+					expectedWithPrefix[h] = res
+				}
+				require.Equal(t, expectedWithPrefix, GetChildrenPaths(parentPath, testCase.node))
+			}
+		})
+	}
+}
diff --git a/pkg/core/mpt/leaf.go b/pkg/core/mpt/leaf.go
index 0f3072b85..0efa45e70 100644
--- a/pkg/core/mpt/leaf.go
+++ b/pkg/core/mpt/leaf.go
@@ -77,3 +77,9 @@ func (n *LeafNode) UnmarshalJSON(data []byte) error {
 	}
 	return errors.New("expected leaf node")
 }
+
+// Clone implements Node interface.
+func (n *LeafNode) Clone() Node {
+	res := *n
+	return &res
+}
diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go
index af35286c1..a5f6cc814 100644
--- a/pkg/core/mpt/node.go
+++ b/pkg/core/mpt/node.go
@@ -34,6 +34,7 @@ type Node interface {
 	json.Marshaler
 	json.Unmarshaler
 	Size() int
+	Clone() Node
 	BaseNodeIface
 }
diff --git a/pkg/core/mpt/proof_test.go b/pkg/core/mpt/proof_test.go
index 75a76408f..d733fb6d0 100644
--- a/pkg/core/mpt/proof_test.go
+++ b/pkg/core/mpt/proof_test.go
@@ -6,7 +6,7 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
-func newProofTrie(t *testing.T) *Trie {
+func newProofTrie(t *testing.T, missingHashNode bool) *Trie {
 	l := NewLeafNode([]byte("somevalue"))
 	e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
 	l2 := NewLeafNode([]byte("invalid"))
@@ -20,11 +20,14 @@ func newProofTrie(t *testing.T) *Trie {
 	require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
 	tr.putToStore(l)
 	tr.putToStore(e)
+	if !missingHashNode {
+		tr.putToStore(l2)
+	}
 	return tr
 }
 
 func TestTrie_GetProof(t *testing.T) {
-	tr := newProofTrie(t)
+	tr := newProofTrie(t, true)
 
 	t.Run("MissingKey", func(t *testing.T) {
 		_, err := tr.GetProof([]byte{0x12})
@@ -43,7 +46,7 @@ func TestTrie_GetProof(t *testing.T) {
 }
 
 func TestVerifyProof(t *testing.T) {
-	tr := newProofTrie(t)
+	tr := newProofTrie(t, true)
 
 	t.Run("Simple", func(t *testing.T) {
 		proof, err := tr.GetProof([]byte{0x12, 0x32})
diff --git a/pkg/core/native/designate.go b/pkg/core/native/designate.go
index 8c75eebc8..844690a3d 100644
--- a/pkg/core/native/designate.go
+++ b/pkg/core/native/designate.go
@@ -353,3 +353,8 @@ func (s *Designate) getRole(item stackitem.Item) (noderoles.Role, bool) {
 	u := bi.Uint64()
 	return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u))
 }
+
+// InitializeCache invalidates native Designate cache.
+func (s *Designate) InitializeCache() {
+	s.rolesChangedFlag.Store(true)
+}
diff --git a/pkg/core/stateroot/module.go b/pkg/core/stateroot/module.go
index 530dbf583..838e4717a 100644
--- a/pkg/core/stateroot/module.go
+++ b/pkg/core/stateroot/module.go
@@ -13,6 +13,7 @@ import (
 	"github.com/nspcc-dev/neo-go/pkg/core/storage"
 	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
 	"github.com/nspcc-dev/neo-go/pkg/util"
+	"github.com/nspcc-dev/neo-go/pkg/util/slice"
 	"go.uber.org/atomic"
 	"go.uber.org/zap"
 )
@@ -114,6 +115,66 @@ func (s *Module) Init(height uint32, enableRefCount bool) error {
 	return nil
 }
 
+// CleanStorage removes all MPT-related data from the storage (MPT nodes, validated stateroots)
+// except the local stateroot for the current height and the GC flag. This method is intended
+// to clean outdated MPT data before the state sync process starts.
+// Note: this method is intended to be called for the genesis block only, an error is returned otherwise.
+func (s *Module) CleanStorage() error {
+	if s.localHeight.Load() != 0 {
+		return fmt.Errorf("can't clean MPT data for non-genesis block: expected local stateroot height 0, got %d", s.localHeight.Load())
+	}
+	gcKey := []byte{byte(storage.DataMPT), prefixGC}
+	gcVal, err := s.Store.Get(gcKey)
+	if err != nil {
+		return fmt.Errorf("failed to get GC flag: %w", err)
+	}
+	b := s.Store.Batch()
+	s.Store.Seek([]byte{byte(storage.DataMPT)}, func(k, _ []byte) {
+		// Must copy here, #1468.
+		key := slice.Copy(k)
+		b.Delete(key)
+	})
+	err = s.Store.PutBatch(b)
+	if err != nil {
+		return fmt.Errorf("failed to remove outdated MPT-related items: %w", err)
+	}
+	err = s.Store.Put(gcKey, gcVal)
+	if err != nil {
+		return fmt.Errorf("failed to store GC flag: %w", err)
+	}
+	currentLocal := s.currentLocal.Load().(util.Uint256)
+	if !currentLocal.Equals(util.Uint256{}) {
+		err := s.addLocalStateRoot(s.Store, &state.MPTRoot{
+			Index: s.localHeight.Load(),
+			Root:  currentLocal,
+		})
+		if err != nil {
+			return fmt.Errorf("failed to store current local stateroot: %w", err)
+		}
+	}
+	return nil
+}
+
+// JumpToState performs jump to the state specified by given stateroot index.
+func (s *Module) JumpToState(sr *state.MPTRoot, enableRefCount bool) error {
+	if err := s.addLocalStateRoot(s.Store, sr); err != nil {
+		return fmt.Errorf("failed to store local state root: %w", err)
+	}
+
+	data := make([]byte, 4)
+	binary.LittleEndian.PutUint32(data, sr.Index)
+	if err := s.Store.Put([]byte{byte(storage.DataMPT), prefixValidated}, data); err != nil {
+		return fmt.Errorf("failed to store validated height: %w", err)
+	}
+	s.validatedHeight.Store(sr.Index)
+
+	s.currentLocal.Store(sr.Root)
+	s.localHeight.Store(sr.Index)
+	s.mpt = mpt.NewTrie(mpt.NewHashNode(sr.Root), enableRefCount, s.Store)
+	return nil
+}
+
 // AddMPTBatch updates using provided batch.
 func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) {
 	mpt := *s.mpt
diff --git a/pkg/core/statesync/module.go b/pkg/core/statesync/module.go
new file mode 100644
index 000000000..801b51a97
--- /dev/null
+++ b/pkg/core/statesync/module.go
@@ -0,0 +1,479 @@
+/*
+Package statesync implements a module for the P2P state synchronisation process. The
+module manages state synchronisation for non-archival nodes which are joining the
+network and don't have the ability to resync from the genesis block.
+
+Given the currently available state synchronisation point P, the state sync process
+includes the following stages:
+
+1. Fetching headers starting from height 0 up to P+1.
+2. Fetching MPT nodes for height P starting from the corresponding state root.
+3. Fetching blocks starting from height P-MaxTraceableBlocks (or 0) up to P.
+
+Steps 2 and 3 are performed in parallel. Once all the data are collected
+and stored in the db, an atomic state jump to the state sync point P occurs.
+Further node operation is performed using the standard sync mechanism until
+the node reaches a synchronised state.
+*/
+package statesync
+
+import (
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"sync"
+
+	"github.com/nspcc-dev/neo-go/pkg/core/block"
+	"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
+	"github.com/nspcc-dev/neo-go/pkg/core/dao"
+	"github.com/nspcc-dev/neo-go/pkg/core/mpt"
+	"github.com/nspcc-dev/neo-go/pkg/core/storage"
+	"github.com/nspcc-dev/neo-go/pkg/io"
+	"github.com/nspcc-dev/neo-go/pkg/util"
+	"go.uber.org/zap"
+)
+
+// stateSyncStage is a type of state synchronisation stage.
+type stateSyncStage uint8
+
+const (
+	// inactive means that state exchange is disabled by the protocol configuration.
+	// Can't be combined with other states.
+	inactive stateSyncStage = 1 << iota
+	// none means that state exchange is enabled in the configuration, but
+	// initialisation of the state sync module wasn't yet performed, i.e.
+	// (*Module).Init wasn't called. Can't be combined with other states.
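+	// Unlike the three single-valued states, the later stages are bit flags
+	// (declared with 1 << iota) and do combine: e.g. a node that has fetched
+	// all headers and MPT nodes, but not all blocks yet, is in
+	// headersSynced|mptSynced.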
+	none
+	// initialized means that (*Module).Init was called, but other sync stages
+	// are not yet reached (i.e. that headers are requested, but not yet fetched).
+	// Can't be combined with other states.
+	initialized
+	// headersSynced means that headers for the current state sync point are fetched.
+	// May be combined with mptSynced and/or blocksSynced.
+	headersSynced
+	// mptSynced means that MPT nodes for the current state sync point are fetched.
+	// Always combined with headersSynced; may be combined with blocksSynced.
+	mptSynced
+	// blocksSynced means that blocks up to the current state sync point are stored.
+	// Always combined with headersSynced; may be combined with mptSynced.
+	blocksSynced
+)
+
+// Module represents state sync module and is aimed at gathering state-related data to
+// perform an atomic state jump.
+type Module struct {
+	lock sync.RWMutex
+	log  *zap.Logger
+
+	// syncPoint is the state synchronisation point P we're currently working against.
+	syncPoint uint32
+	// syncStage is the stage of the sync process.
+	syncStage stateSyncStage
+	// syncInterval is the delta between two adjacent state sync points.
+	syncInterval uint32
+	// blockHeight is the index of the latest stored block.
+	blockHeight uint32
+
+	dao     *dao.Simple
+	bc      blockchainer.Blockchainer
+	mptpool *Pool
+
+	billet *mpt.Billet
+
+	jumpCallback func(p uint32) error
+}
+
+// NewModule returns new instance of statesync module.
+func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple, jumpCallback func(p uint32) error) *Module {
+	if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) {
+		return &Module{
+			dao:       s,
+			bc:        bc,
+			syncStage: inactive,
+		}
+	}
+	return &Module{
+		dao:          s,
+		bc:           bc,
+		log:          log,
+		syncInterval: uint32(bc.GetConfig().StateSyncInterval),
+		mptpool:      NewPool(),
+		syncStage:    none,
+		jumpCallback: jumpCallback,
+	}
+}
+
+// Init initializes state sync module for the current chain's height with given
+// callback for MPT nodes requests.
+func (s *Module) Init(currChainHeight uint32) error {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	if s.syncStage != none {
+		return errors.New("already initialized or inactive")
+	}
+
+	p := (currChainHeight / s.syncInterval) * s.syncInterval
+	if p < 2*s.syncInterval {
+		// chain is too low to start state exchange process, use the standard sync mechanism
+		s.syncStage = inactive
+		return nil
+	}
+	pOld, err := s.dao.GetStateSyncPoint()
+	if err == nil && pOld >= p-s.syncInterval {
+		// old point is still valid, so try to resync states for this point.
+		p = pOld
+	} else {
+		if s.bc.BlockHeight() > p-2*s.syncInterval {
+			// chain has already been synchronised up to old state sync point and regular blocks processing was started.
+			// Current block height is enough to start regular blocks processing.
+			s.syncStage = inactive
+			return nil
+		}
+		if err == nil {
+			// pOld was found, it is outdated, and chain wasn't completely synchronised for pOld. Need to drop the db.
+			return fmt.Errorf("state sync point %d is found in the storage, "+
+				"but sync process wasn't completed and point is outdated. Please, drop the database manually and restart the node to run state sync process", pOld)
+		}
+		if s.bc.BlockHeight() != 0 {
+			// pOld wasn't found, but blocks processing was started in a regular manner and latest stored block is too outdated
+			// to start regular blocks processing again. Need to drop the db.
+ return fmt.Errorf("current chain's height is too low to start regular blocks processing from the oldest sync point %d. "+ + "Please, drop the database manually and restart the node to run state sync process", p-s.syncInterval) + } + + // We've reached this point, so chain has genesis block only. As far as we can't ruin + // current chain's state until new state is completely fetched, outdated state-related data + // will be removed from storage during (*Blockchain).jumpToState(...) execution. + // All we need to do right now is to remove genesis-related MPT nodes. + err = s.bc.GetStateModule().CleanStorage() + if err != nil { + return fmt.Errorf("failed to remove outdated MPT data from storage: %w", err) + } + } + + s.syncPoint = p + err = s.dao.PutStateSyncPoint(p) + if err != nil { + return fmt.Errorf("failed to store state synchronisation point %d: %w", p, err) + } + s.syncStage = initialized + s.log.Info("try to sync state for the latest state synchronisation point", + zap.Uint32("point", p), + zap.Uint32("evaluated chain's blockHeight", currChainHeight)) + + return s.defineSyncStage() +} + +// defineSyncStage sequentially checks and sets sync state process stage after Module +// initialization. It also performs initialization of MPT Billet if necessary. +func (s *Module) defineSyncStage() error { + // check headers sync stage first + ltstHeaderHeight := s.bc.HeaderHeight() + if ltstHeaderHeight > s.syncPoint { + s.syncStage = headersSynced + s.log.Info("headers are in sync", + zap.Uint32("headerHeight", s.bc.HeaderHeight())) + } + + // check blocks sync stage + s.blockHeight = s.getLatestSavedBlock(s.syncPoint) + if s.blockHeight >= s.syncPoint { + s.syncStage |= blocksSynced + s.log.Info("blocks are in sync", + zap.Uint32("blockHeight", s.blockHeight)) + } + + // check MPT sync stage + if s.blockHeight > s.syncPoint { + s.syncStage |= mptSynced + s.log.Info("MPT is in sync", + zap.Uint32("stateroot height", s.bc.GetStateModule().CurrentLocalHeight())) + } else if s.syncStage&headersSynced != 0 { + header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint + 1))) + if err != nil { + return fmt.Errorf("failed to get header to initialize MPT billet: %w", err) + } + s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store) + s.log.Info("MPT billet initialized", + zap.Uint32("height", s.syncPoint), + zap.String("state root", header.PrevStateRoot.StringBE())) + pool := NewPool() + pool.Add(header.PrevStateRoot, []byte{}) + err = s.billet.Traverse(func(n mpt.Node, _ []byte) bool { + nPaths, ok := pool.TryGet(n.Hash()) + if !ok { + // if this situation occurs, then it's a bug in MPT pool or Traverse. + panic("failed to get MPT node from the pool") + } + pool.Remove(n.Hash()) + childrenPaths := make(map[util.Uint256][][]byte) + for _, path := range nPaths { + nChildrenPaths := mpt.GetChildrenPaths(path, n) + for hash, paths := range nChildrenPaths { + childrenPaths[hash] = append(childrenPaths[hash], paths...) 
+				}
+			}
+			pool.Update(nil, childrenPaths)
+			return false
+		}, true)
+		if err != nil {
+			return fmt.Errorf("failed to traverse MPT during initialization: %w", err)
+		}
+		s.mptpool.Update(nil, pool.GetAll())
+		if s.mptpool.Count() == 0 {
+			s.syncStage |= mptSynced
+			s.log.Info("MPT is in sync",
+				zap.Uint32("stateroot height", s.syncPoint))
+		}
+	}
+
+	if s.syncStage == headersSynced|blocksSynced|mptSynced {
+		s.log.Info("state is in sync, starting regular blocks processing")
+		s.syncStage = inactive
+	}
+	return nil
+}
+
+// getLatestSavedBlock returns either the current block index (if it's still relevant
+// to continue the state sync process) or H-1, where H is the index of the earliest
+// block that should be saved next.
+func (s *Module) getLatestSavedBlock(p uint32) uint32 {
+	var result uint32
+	mtb := s.bc.GetConfig().MaxTraceableBlocks
+	if p > mtb {
+		result = p - mtb
+	}
+	storedH, err := s.dao.GetStateSyncCurrentBlockHeight()
+	if err == nil && storedH > result {
+		result = storedH
+	}
+	actualH := s.bc.BlockHeight()
+	if actualH > result {
+		result = actualH
+	}
+	return result
+}
+
+// AddHeaders validates and adds the specified headers to the chain.
+func (s *Module) AddHeaders(hdrs ...*block.Header) error {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	if s.syncStage != initialized {
+		return errors.New("headers were not requested")
+	}
+
+	hdrsErr := s.bc.AddHeaders(hdrs...)
+	if s.bc.HeaderHeight() > s.syncPoint {
+		err := s.defineSyncStage()
+		if err != nil {
+			return fmt.Errorf("failed to define current sync stage: %w", err)
+		}
+	}
+	return hdrsErr
+}
+
+// AddBlock verifies and saves the given block skipping execution of its scripts.
+func (s *Module) AddBlock(block *block.Block) error {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	if s.syncStage&headersSynced == 0 || s.syncStage&blocksSynced != 0 {
+		return nil
+	}
+
+	if s.blockHeight == s.syncPoint {
+		return nil
+	}
+	expectedHeight := s.blockHeight + 1
+	if expectedHeight != block.Index {
+		return fmt.Errorf("expected %d, got %d: invalid block index", expectedHeight, block.Index)
+	}
+	if s.bc.GetConfig().StateRootInHeader != block.StateRootEnabled {
+		return fmt.Errorf("stateroot setting mismatch: %v != %v", s.bc.GetConfig().StateRootInHeader, block.StateRootEnabled)
+	}
+	if s.bc.GetConfig().VerifyBlocks {
+		merkle := block.ComputeMerkleRoot()
+		if !block.MerkleRoot.Equals(merkle) {
+			return errors.New("invalid block: MerkleRoot mismatch")
+		}
+	}
+	cache := s.dao.GetWrapped()
+	writeBuf := io.NewBufBinWriter()
+	if err := cache.StoreAsBlock(block, writeBuf); err != nil {
+		return err
+	}
+	writeBuf.Reset()
+
+	err := cache.PutStateSyncCurrentBlockHeight(block.Index)
+	if err != nil {
+		return fmt.Errorf("failed to store current block height: %w", err)
+	}
+
+	for _, tx := range block.Transactions {
+		if err := cache.StoreAsTransaction(tx, block.Index, writeBuf); err != nil {
+			return err
+		}
+		writeBuf.Reset()
+	}
+
+	_, err = cache.Persist()
+	if err != nil {
+		return fmt.Errorf("failed to persist results: %w", err)
+	}
+	s.blockHeight = block.Index
+	if s.blockHeight == s.syncPoint {
+		s.syncStage |= blocksSynced
+		s.log.Info("blocks are in sync",
+			zap.Uint32("blockHeight", s.blockHeight))
+		s.checkSyncIsCompleted()
+	}
+	return nil
+}
+
+// AddMPTNodes tries to add the provided set of MPT nodes to the MPT billet if they are
+// not yet collected.
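+// A minimal caller-side sketch (mirroring what the server's MPTData handler
+// does; `data` is assumed to be a received CMDMPTData payload):
+//
+//	if err := module.AddMPTNodes(data.Nodes); err != nil {
+//		// the nodes are malformed or weren't requested
+//		return err
+//	}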
+func (s *Module) AddMPTNodes(nodes [][]byte) error {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	if s.syncStage&headersSynced == 0 || s.syncStage&mptSynced != 0 {
+		return errors.New("MPT nodes were not requested")
+	}
+
+	for _, nBytes := range nodes {
+		var n mpt.NodeObject
+		r := io.NewBinReaderFromBuf(nBytes)
+		n.DecodeBinary(r)
+		if r.Err != nil {
+			return fmt.Errorf("failed to decode MPT node: %w", r.Err)
+		}
+		err := s.restoreNode(n.Node)
+		if err != nil {
+			return err
+		}
+	}
+	if s.mptpool.Count() == 0 {
+		s.syncStage |= mptSynced
+		s.log.Info("MPT is in sync",
+			zap.Uint32("height", s.syncPoint))
+		s.checkSyncIsCompleted()
+	}
+	return nil
+}
+
+func (s *Module) restoreNode(n mpt.Node) error {
+	nPaths, ok := s.mptpool.TryGet(n.Hash())
+	if !ok {
+		// it can easily happen after receiving the same data from different peers.
+		return nil
+	}
+	var childrenPaths = make(map[util.Uint256][][]byte)
+	for _, path := range nPaths {
+		// Must clone here in order to avoid future collapse collisions. If the node's refcount>1 then MPT pool
+		// will manage all paths for this node and call RestoreHashNode separately for each of the paths.
+		err := s.billet.RestoreHashNode(path, n.Clone())
+		if err != nil {
+			return fmt.Errorf("failed to restore MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
+		}
+		for h, paths := range mpt.GetChildrenPaths(path, n) {
+			childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by the MPT pool
+		}
+	}
+
+	s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
+
+	for h := range childrenPaths {
+		if child, err := s.billet.GetFromStore(h); err == nil {
+			// child is already in the storage, so we don't need to request it one more time.
+			err = s.restoreNode(child)
+			if err != nil {
+				return fmt.Errorf("unable to restore saved children: %w", err)
+			}
+		}
+	}
+	return nil
+}
+
+// checkSyncIsCompleted checks whether the state sync process is completed, i.e. headers up to P+1
+// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored.
+// If so, then jumping to the P state sync point occurs. It is not protected by lock, thus the
+// caller should take care of it.
+func (s *Module) checkSyncIsCompleted() {
+	if s.syncStage != headersSynced|mptSynced|blocksSynced {
+		return
+	}
+	s.log.Info("state is in sync",
+		zap.Uint32("state sync point", s.syncPoint))
+	err := s.jumpCallback(s.syncPoint)
+	if err != nil {
+		s.log.Fatal("failed to jump to the latest state sync point", zap.Error(err))
+	}
+	s.syncStage = inactive
+	s.dispose()
+}
+
+func (s *Module) dispose() {
+	s.billet = nil
+}
+
+// BlockHeight returns the index of the last stored block.
+func (s *Module) BlockHeight() uint32 {
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+
+	return s.blockHeight
+}
+
+// IsActive tells whether the state sync module is on and still gathering state
+// synchronisation data (headers, blocks or MPT nodes).
+func (s *Module) IsActive() bool {
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+
+	return !(s.syncStage == inactive || (s.syncStage == headersSynced|mptSynced|blocksSynced))
+}
+
+// IsInitialized tells whether the state sync module does not require initialization.
+// If `false` is returned then Init can be safely called.
+func (s *Module) IsInitialized() bool {
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+
+	return s.syncStage != none
+}
+
+// NeedHeaders tells whether the module hasn't completed headers synchronisation.
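+// (Headers are the first thing an active module needs: the typical stage
+// progression is none -> initialized -> headersSynced ->
+// headersSynced|mptSynced|blocksSynced -> inactive, with mptSynced and
+// blocksSynced reached in either order.)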
+func (s *Module) NeedHeaders() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.syncStage == initialized +} + +// NeedMPTNodes returns whether the module hasn't completed MPT synchronisation. +func (s *Module) NeedMPTNodes() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.syncStage&headersSynced != 0 && s.syncStage&mptSynced == 0 +} + +// Traverse traverses local MPT nodes starting from the specified root down to its +// children calling `process` for each serialised node until stop condition is satisfied. +func (s *Module) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { + s.lock.RLock() + defer s.lock.RUnlock() + + b := mpt.NewBillet(root, s.bc.GetConfig().KeepOnlyLatestState, storage.NewMemCachedStore(s.dao.Store)) + return b.Traverse(process, false) +} + +// GetUnknownMPTNodesBatch returns set of currently unknown MPT nodes (`limit` at max). +func (s *Module) GetUnknownMPTNodesBatch(limit int) []util.Uint256 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.mptpool.GetBatch(limit) +} diff --git a/pkg/core/statesync/module_test.go b/pkg/core/statesync/module_test.go new file mode 100644 index 000000000..c1bf6b117 --- /dev/null +++ b/pkg/core/statesync/module_test.go @@ -0,0 +1,106 @@ +package statesync + +import ( + "fmt" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/dao" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/util/slice" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestModule_PR2019_discussion_r689629704(t *testing.T) { + expectedStorage := storage.NewMemCachedStore(storage.NewMemoryStore()) + tr := mpt.NewTrie(nil, true, expectedStorage) + require.NoError(t, tr.Put([]byte{0x03}, []byte("leaf1"))) + require.NoError(t, tr.Put([]byte{0x01, 0xab, 0x02}, []byte("leaf2"))) + require.NoError(t, tr.Put([]byte{0x01, 0xab, 0x04}, []byte("leaf3"))) + require.NoError(t, tr.Put([]byte{0x06, 0x01, 0xde, 0x02}, []byte("leaf2"))) // <-- the same `leaf2` and `leaf3` values are put in the storage, + require.NoError(t, tr.Put([]byte{0x06, 0x01, 0xde, 0x04}, []byte("leaf3"))) // <-- but the path should differ. + require.NoError(t, tr.Put([]byte{0x06, 0x03}, []byte("leaf4"))) + + sr := tr.StateRoot() + tr.Flush() + + // Keep MPT nodes in a map in order not to repeat them. We'll use `nodes` map to ask + // state sync module to restore the nodes. + var ( + nodes = make(map[util.Uint256][]byte) + expectedItems []storage.KeyValue + ) + expectedStorage.Seek(storage.DataMPT.Bytes(), func(k, v []byte) { + key := slice.Copy(k) + value := slice.Copy(v) + expectedItems = append(expectedItems, storage.KeyValue{ + Key: key, + Value: value, + }) + hash, err := util.Uint256DecodeBytesBE(key[1:]) + require.NoError(t, err) + nodeBytes := value[:len(value)-4] + nodes[hash] = nodeBytes + }) + + actualStorage := storage.NewMemCachedStore(storage.NewMemoryStore()) + // These actions are done in module.Init(), but it's not the point of the test. + // Here we want to test only MPT restoring process. + stateSync := &Module{ + log: zaptest.NewLogger(t), + syncPoint: 1000500, + syncStage: headersSynced, + syncInterval: 100500, + dao: dao.NewSimple(actualStorage, true, false), + mptpool: NewPool(), + } + stateSync.billet = mpt.NewBillet(sr, true, actualStorage) + stateSync.mptpool.Add(sr, []byte{}) + + // The test itself: we'll ask state sync module to restore each node exactly once. 
+	// After that the storage content (including storage items and refcounts) must
+	// exactly match the one got from the real MPT trie. The MPT pool must be empty.
+	// State sync module must have mptSynced state in the end.
+	// MPT Billet root must become a collapsed hashnode (it was checked manually).
+	requested := make(map[util.Uint256]struct{})
+	for {
+		unknownHashes := stateSync.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
+		if len(unknownHashes) == 0 {
+			break
+		}
+		h := unknownHashes[0]
+		node, ok := nodes[h]
+		if !ok {
+			if _, ok = requested[h]; ok {
+				t.Fatal("node was requested twice")
+			}
+			t.Fatal("unknown node was requested")
+		}
+		require.NotPanics(t, func() {
+			err := stateSync.AddMPTNodes([][]byte{node})
+			require.NoError(t, err)
+		}, fmt.Errorf("hash=%s, value=%s", h.StringBE(), string(node)))
+		requested[h] = struct{}{}
+		delete(nodes, h)
+		if len(nodes) == 0 {
+			break
+		}
+	}
+	require.Equal(t, headersSynced|mptSynced, stateSync.syncStage, "all nodes were sent exactly once, but MPT wasn't restored")
+	require.Equal(t, 0, len(nodes), "not all nodes were requested by state sync module")
+	require.Equal(t, 0, stateSync.mptpool.Count(), "MPT was restored, but MPT pool still contains items")
+
+	// Compare resulting storage items and refcounts.
+	var actualItems []storage.KeyValue
+	actualStorage.Seek(storage.DataMPT.Bytes(), func(k, v []byte) {
+		key := slice.Copy(k)
+		value := slice.Copy(v)
+		actualItems = append(actualItems, storage.KeyValue{
+			Key:   key,
+			Value: value,
+		})
+	})
+	require.ElementsMatch(t, expectedItems, actualItems)
+}
diff --git a/pkg/core/statesync/mptpool.go b/pkg/core/statesync/mptpool.go
new file mode 100644
index 000000000..819188246
--- /dev/null
+++ b/pkg/core/statesync/mptpool.go
@@ -0,0 +1,142 @@
+package statesync
+
+import (
+	"bytes"
+	"sort"
+	"sync"
+
+	"github.com/nspcc-dev/neo-go/pkg/util"
+)
+
+// Pool stores unknown MPT nodes along with the corresponding paths (a single node is
+// allowed to have multiple MPT paths).
+type Pool struct {
+	lock   sync.RWMutex
+	hashes map[util.Uint256][][]byte
+}
+
+// NewPool returns a new MPT node hashes pool.
+func NewPool() *Pool {
+	return &Pool{
+		hashes: make(map[util.Uint256][][]byte),
+	}
+}
+
+// ContainsKey checks if an MPT node with the specified hash is in the Pool.
+func (mp *Pool) ContainsKey(hash util.Uint256) bool {
+	mp.lock.RLock()
+	defer mp.lock.RUnlock()
+
+	_, ok := mp.hashes[hash]
+	return ok
+}
+
+// TryGet returns a set of MPT paths for the specified HashNode.
+func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) {
+	mp.lock.RLock()
+	defer mp.lock.RUnlock()
+
+	paths, ok := mp.hashes[hash]
+	// need to copy here, because we can modify the existing array of paths inside the pool.
+	res := make([][]byte, len(paths))
+	copy(res, paths)
+	return res, ok
+}
+
+// GetAll returns all MPT nodes with the corresponding paths from the pool.
+func (mp *Pool) GetAll() map[util.Uint256][][]byte {
+	mp.lock.RLock()
+	defer mp.lock.RUnlock()
+
+	return mp.hashes
+}
+
+// GetBatch returns a set of unknown MPT node hashes (`limit` at max).
+func (mp *Pool) GetBatch(limit int) []util.Uint256 {
+	mp.lock.RLock()
+	defer mp.lock.RUnlock()
+
+	count := len(mp.hashes)
+	if count > limit {
+		count = limit
+	}
+	result := make([]util.Uint256, 0, limit)
+	for h := range mp.hashes {
+		if count == 0 {
+			break
+		}
+		result = append(result, h)
+		count--
+	}
+	return result
+}
+
+// Remove removes an MPT node from the pool by the specified hash.
+func (mp *Pool) Remove(hash util.Uint256) { + mp.lock.Lock() + defer mp.lock.Unlock() + + delete(mp.hashes, hash) +} + +// Add adds path to the set of paths for the specified node. +func (mp *Pool) Add(hash util.Uint256, path []byte) { + mp.lock.Lock() + defer mp.lock.Unlock() + + mp.addPaths(hash, [][]byte{path}) +} + +// Update is an atomic operation and removes/adds specified nodes from/to the pool. +func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) { + mp.lock.Lock() + defer mp.lock.Unlock() + + for h, paths := range remove { + old := mp.hashes[h] + for _, path := range paths { + i := sort.Search(len(old), func(i int) bool { + return bytes.Compare(old[i], path) >= 0 + }) + if i < len(old) && bytes.Equal(old[i], path) { + old = append(old[:i], old[i+1:]...) + } + } + if len(old) == 0 { + delete(mp.hashes, h) + } else { + mp.hashes[h] = old + } + } + for h, paths := range add { + mp.addPaths(h, paths) + } +} + +// addPaths adds set of the specified node paths to the pool. +func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) { + old := mp.hashes[nodeHash] + for _, path := range paths { + i := sort.Search(len(old), func(i int) bool { + return bytes.Compare(old[i], path) >= 0 + }) + if i < len(old) && bytes.Equal(old[i], path) { + // then path is already added + continue + } + old = append(old, path) + if i != len(old)-1 { + copy(old[i+1:], old[i:]) + old[i] = path + } + } + mp.hashes[nodeHash] = old +} + +// Count returns the number of nodes in the pool. +func (mp *Pool) Count() int { + mp.lock.RLock() + defer mp.lock.RUnlock() + + return len(mp.hashes) +} diff --git a/pkg/core/statesync/mptpool_test.go b/pkg/core/statesync/mptpool_test.go new file mode 100644 index 000000000..2d094025c --- /dev/null +++ b/pkg/core/statesync/mptpool_test.go @@ -0,0 +1,123 @@ +package statesync + +import ( + "encoding/hex" + "testing" + + "github.com/nspcc-dev/neo-go/internal/random" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestPool_AddRemoveUpdate(t *testing.T) { + mp := NewPool() + + i1 := []byte{1, 2, 3} + i1h := util.Uint256{1, 2, 3} + i2 := []byte{2, 3, 4} + i2h := util.Uint256{2, 3, 4} + i3 := []byte{4, 5, 6} + i3h := util.Uint256{3, 4, 5} + i4 := []byte{3, 4, 5} // has the same hash as i3 + i5 := []byte{6, 7, 8} // has the same hash as i3 + mapAll := map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}, i3h: {i4, i3}} + + // No items + _, ok := mp.TryGet(i1h) + require.False(t, ok) + require.False(t, mp.ContainsKey(i1h)) + require.Equal(t, 0, mp.Count()) + require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll()) + + // Add i1, i2, check OK + mp.Add(i1h, i1) + mp.Add(i2h, i2) + itm, ok := mp.TryGet(i1h) + require.True(t, ok) + require.Equal(t, [][]byte{i1}, itm) + require.True(t, mp.ContainsKey(i1h)) + require.True(t, mp.ContainsKey(i2h)) + require.Equal(t, map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + + // Remove i1 and unexisting item + mp.Remove(i3h) + mp.Remove(i1h) + require.False(t, mp.ContainsKey(i1h)) + require.True(t, mp.ContainsKey(i2h)) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}}, mp.GetAll()) + require.Equal(t, 1, mp.Count()) + + // Update: remove nothing, add all + mp.Update(nil, mapAll) + require.Equal(t, mapAll, mp.GetAll()) + require.Equal(t, 3, mp.Count()) + // Update: remove all, add all + mp.Update(mapAll, mapAll) + require.Equal(t, mapAll, mp.GetAll()) // deletion first, addition after that + require.Equal(t, 3, 
mp.Count()) + // Update: remove all, add nothing + mp.Update(mapAll, nil) + require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll()) + require.Equal(t, 0, mp.Count()) + // Update: remove several, add several + mp.Update(map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + + // Update: remove nothing, add several with same hashes + mp.Update(nil, map[util.Uint256][][]byte{i3h: {i5, i4}}) // should be sorted by the pool + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i3, i5}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + // Update: remove several with same hashes, add nothing + mp.Update(map[util.Uint256][][]byte{i3h: {i5, i4}}, nil) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + // Update: remove several with same hashes, add several with same hashes + mp.Update(map[util.Uint256][][]byte{i3h: {i5, i3}}, map[util.Uint256][][]byte{i3h: {i5, i4}}) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i5}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) +} + +func TestPool_GetBatch(t *testing.T) { + check := func(t *testing.T, limit int, itemsCount int) { + mp := NewPool() + for i := 0; i < itemsCount; i++ { + mp.Add(random.Uint256(), []byte{0x01}) + } + batch := mp.GetBatch(limit) + if limit < itemsCount { + require.Equal(t, limit, len(batch)) + } else { + require.Equal(t, itemsCount, len(batch)) + } + } + + t.Run("limit less than items count", func(t *testing.T) { + check(t, 5, 6) + }) + t.Run("limit more than items count", func(t *testing.T) { + check(t, 6, 5) + }) + t.Run("items count limit", func(t *testing.T) { + check(t, 5, 5) + }) +} + +func TestPool_UpdateUsingSliceFromPool(t *testing.T) { + mp := NewPool() + p1, _ := hex.DecodeString("0f0a0f0f0f0f0f0f0104020b02080c0a06050e070b050404060206060d07080602030b04040b050e040406030f0708060c05") + p2, _ := hex.DecodeString("0f0a0f0f0f0f0f0f01040a0b000f04000b03090b02090b0e040f0d0b060d070e0b0b090b0906080602060c0d0f0e0d04070e") + p3, _ := hex.DecodeString("0f0a0f0f0f0f0f0f01040b010d01080f050f000a0d0e08060c040b050800050904060f050807080a080c07040d0107080007") + h, _ := util.Uint256DecodeStringBE("57e197679ef031bf2f0b466b20afe3f67ac04dcff80a1dc4d12dd98dd21a2511") + mp.Add(h, p1) + mp.Add(h, p2) + mp.Add(h, p3) + + toBeRemoved, ok := mp.TryGet(h) + require.True(t, ok) + + mp.Update(map[util.Uint256][][]byte{h: toBeRemoved}, nil) + // test that all items were successfully removed. 
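+	// (TryGet above returned a copy of the pool's paths slice, which is what
+	// keeps the `toBeRemoved` argument valid while Update rewrites the pool's
+	// own sorted slices underneath.)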
+ require.Equal(t, 0, len(mp.GetAll())) +} diff --git a/pkg/core/statesync_test.go b/pkg/core/statesync_test.go new file mode 100644 index 000000000..0751ab65e --- /dev/null +++ b/pkg/core/statesync_test.go @@ -0,0 +1,435 @@ +package core + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/config" + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/util/slice" + "github.com/stretchr/testify/require" +) + +func TestStateSyncModule_Init(t *testing.T) { + var ( + stateSyncInterval = 2 + maxTraceable uint32 = 3 + ) + spoutCfg := func(c *config.Config) { + c.ProtocolConfiguration.StateRootInHeader = true + c.ProtocolConfiguration.P2PStateExchangeExtensions = true + c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval + c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable + } + bcSpout := newTestChainWithCustomCfg(t, spoutCfg) + for i := 0; i <= 2*stateSyncInterval+int(maxTraceable)+1; i++ { + require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock())) + } + + boltCfg := func(c *config.Config) { + spoutCfg(c) + c.ProtocolConfiguration.KeepOnlyLatestState = true + c.ProtocolConfiguration.RemoveUntraceableBlocks = true + } + t.Run("error: module disabled by config", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, func(c *config.Config) { + boltCfg(c) + c.ProtocolConfiguration.RemoveUntraceableBlocks = false + }) + module := bcBolt.GetStateSyncModule() + require.Error(t, module.Init(bcSpout.BlockHeight())) // module inactive (non-archival node) + }) + + t.Run("inactive: spout chain is too low to start state sync process", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(uint32(2*stateSyncInterval-1))) + require.False(t, module.IsActive()) + }) + + t.Run("inactive: bolt chain height is close enough to spout chain height", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + for i := 1; i < int(bcSpout.BlockHeight())-stateSyncInterval; i++ { + b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, bcBolt.AddBlock(b)) + } + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.False(t, module.IsActive()) + }) + + t.Run("error: bolt chain is too low to start state sync process", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + require.NoError(t, bcBolt.AddBlock(bcBolt.newBlock())) + + module := bcBolt.GetStateSyncModule() + require.Error(t, module.Init(uint32(3*stateSyncInterval))) + }) + + t.Run("initialized: no previous state sync point", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.True(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + }) + + t.Run("error: outdated state sync point in the storage", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + + module = bcBolt.GetStateSyncModule() + require.Error(t, module.Init(bcSpout.BlockHeight()+2*uint32(stateSyncInterval))) + }) + + t.Run("initialized: valid previous state sync point 
in the storage", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.True(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + }) + + t.Run("initialization from headers/blocks/mpt synced stages", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + + // firstly, fetch all headers to create proper DB state (where headers are in sync) + stateSyncPoint := (int(bcSpout.BlockHeight()) / stateSyncInterval) * stateSyncInterval + var expectedHeader *block.Header + for i := 1; i <= int(bcSpout.HeaderHeight()); i++ { + header, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, module.AddHeaders(header)) + if i == stateSyncPoint+1 { + expectedHeader = header + } + } + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + + // then create new statesync module with the same DB and check that state is proper + // (headers are in sync) + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + unknownNodes := module.GetUnknownMPTNodesBatch(2) + require.Equal(t, 1, len(unknownNodes)) + require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0]) + + // add several blocks to create DB state where blocks are not in sync yet, but it's not a genesis. 
+ for i := stateSyncPoint - int(maxTraceable) + 1; i <= stateSyncPoint-stateSyncInterval-1; i++ { + block, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, module.AddBlock(block)) + } + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + require.Equal(t, uint32(stateSyncPoint-stateSyncInterval-1), module.BlockHeight()) + + // then create new statesync module with the same DB and check that state is proper + // (blocks are not in sync yet) + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + unknownNodes = module.GetUnknownMPTNodesBatch(2) + require.Equal(t, 1, len(unknownNodes)) + require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0]) + require.Equal(t, uint32(stateSyncPoint-stateSyncInterval-1), module.BlockHeight()) + + // add rest of blocks to create DB state where blocks are in sync + for i := stateSyncPoint - stateSyncInterval; i <= stateSyncPoint; i++ { + block, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, module.AddBlock(block)) + } + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + lastBlock, err := bcBolt.GetBlock(expectedHeader.PrevHash) + require.NoError(t, err) + require.Equal(t, uint32(stateSyncPoint), lastBlock.Index) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + + // then create new statesync module with the same DB and check that state is proper + // (headers and blocks are in sync) + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + unknownNodes = module.GetUnknownMPTNodesBatch(2) + require.Equal(t, 1, len(unknownNodes)) + require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0]) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + + // add a few MPT nodes to create DB state where some of MPT nodes are missing + count := 5 + for { + unknownHashes := module.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one + if len(unknownHashes) == 0 { + break + } + err := bcSpout.GetStateSyncModule().Traverse(unknownHashes[0], func(node mpt.Node, nodeBytes []byte) bool { + require.NoError(t, module.AddMPTNodes([][]byte{nodeBytes})) + return true // add nodes one-by-one + }) + require.NoError(t, err) + count-- + if count < 0 { + break + } + } + + // then create new statesync module with the same DB and check that state is proper + // (headers and blocks are in sync, mpt is not yet synced) + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + unknownNodes = module.GetUnknownMPTNodesBatch(100) + require.True(t, len(unknownNodes) > 0) + require.NotContains(t, unknownNodes, expectedHeader.PrevStateRoot) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + + // add the rest of MPT nodes and jump to state + 
for {
+		unknownHashes := module.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
+		if len(unknownHashes) == 0 {
+			break
+		}
+		err := bcSpout.GetStateSyncModule().Traverse(unknownHashes[0], func(node mpt.Node, nodeBytes []byte) bool {
+			require.NoError(t, module.AddMPTNodes([][]byte{slice.Copy(nodeBytes)}))
+			return true // add nodes one-by-one
+		})
+		require.NoError(t, err)
+	}
+
+	// check that module is inactive and the state jump is completed
+	require.False(t, module.IsActive())
+	require.False(t, module.NeedHeaders())
+	require.False(t, module.NeedMPTNodes())
+	unknownNodes = module.GetUnknownMPTNodesBatch(1)
+	require.True(t, len(unknownNodes) == 0)
+	require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
+	require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
+
+	// create new module from the completed state: the module should recognise that state sync is completed
+	module = bcBolt.GetStateSyncModule()
+	require.NoError(t, module.Init(bcSpout.BlockHeight()))
+	require.False(t, module.IsActive())
+	require.False(t, module.NeedHeaders())
+	require.False(t, module.NeedMPTNodes())
+	unknownNodes = module.GetUnknownMPTNodesBatch(1)
+	require.True(t, len(unknownNodes) == 0)
+	require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
+	require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
+
+	// add one more block to the restored chain and start a new module: the module should recognise
+	// that state sync is completed and regular blocks processing was started
+	require.NoError(t, bcBolt.AddBlock(bcBolt.newBlock()))
+	module = bcBolt.GetStateSyncModule()
+	require.NoError(t, module.Init(bcSpout.BlockHeight()))
+	require.False(t, module.IsActive())
+	require.False(t, module.NeedHeaders())
+	require.False(t, module.NeedMPTNodes())
+	unknownNodes = module.GetUnknownMPTNodesBatch(1)
+	require.True(t, len(unknownNodes) == 0)
+	require.Equal(t, uint32(stateSyncPoint)+1, module.BlockHeight())
+	require.Equal(t, uint32(stateSyncPoint)+1, bcBolt.BlockHeight())
+	})
+}
+
+func TestStateSyncModule_RestoreBasicChain(t *testing.T) {
+	var (
+		stateSyncInterval        = 4
+		maxTraceable      uint32 = 6
+		stateSyncPoint           = 16
+	)
+	spoutCfg := func(c *config.Config) {
+		c.ProtocolConfiguration.StateRootInHeader = true
+		c.ProtocolConfiguration.P2PStateExchangeExtensions = true
+		c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
+		c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
+	}
+	bcSpout := newTestChainWithCustomCfg(t, spoutCfg)
+	initBasicChain(t, bcSpout)
+
+	// make the spout chain higher than the latest state sync point
+	require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
+	require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
+	require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
+	require.Equal(t, uint32(stateSyncPoint+2), bcSpout.BlockHeight())
+
+	boltCfg := func(c *config.Config) {
+		spoutCfg(c)
+		c.ProtocolConfiguration.KeepOnlyLatestState = true
+		c.ProtocolConfiguration.RemoveUntraceableBlocks = true
+	}
+	bcBolt := newTestChainWithCustomCfg(t, boltCfg)
+	module := bcBolt.GetStateSyncModule()
+
+	t.Run("error: add headers before initialisation", func(t *testing.T) {
+		h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(1))
+		require.NoError(t, err)
+		require.Error(t, module.AddHeaders(h))
+	})
+	t.Run("no error: add blocks before initialisation", func(t *testing.T) {
+		b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(1))
+		require.NoError(t, err)
+		require.NoError(t, module.AddBlock(b))
+	})
+	t.Run("error: add MPT nodes without initialisation", func(t
*testing.T) { + require.Error(t, module.AddMPTNodes([][]byte{})) + }) + + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.True(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + + // add headers to module + headers := make([]*block.Header, 0, bcSpout.HeaderHeight()) + for i := uint32(1); i <= bcSpout.HeaderHeight(); i++ { + h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(int(i))) + require.NoError(t, err) + headers = append(headers, h) + } + require.NoError(t, module.AddHeaders(headers...)) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + require.Equal(t, bcSpout.HeaderHeight(), bcBolt.HeaderHeight()) + + // add blocks + t.Run("error: unexpected block index", func(t *testing.T) { + b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(stateSyncPoint - int(maxTraceable))) + require.NoError(t, err) + require.Error(t, module.AddBlock(b)) + }) + t.Run("error: missing state root in block header", func(t *testing.T) { + b := &block.Block{ + Header: block.Header{ + Index: uint32(stateSyncPoint) - maxTraceable + 1, + StateRootEnabled: false, + }, + } + require.Error(t, module.AddBlock(b)) + }) + t.Run("error: invalid block merkle root", func(t *testing.T) { + b := &block.Block{ + Header: block.Header{ + Index: uint32(stateSyncPoint) - maxTraceable + 1, + StateRootEnabled: true, + MerkleRoot: util.Uint256{1, 2, 3}, + }, + } + require.Error(t, module.AddBlock(b)) + }) + + for i := stateSyncPoint - int(maxTraceable) + 1; i <= stateSyncPoint; i++ { + b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, module.AddBlock(b)) + } + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + + // add MPT nodes in batches + h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(stateSyncPoint + 1)) + require.NoError(t, err) + unknownHashes := module.GetUnknownMPTNodesBatch(100) + require.Equal(t, 1, len(unknownHashes)) + require.Equal(t, h.PrevStateRoot, unknownHashes[0]) + nodesMap := make(map[util.Uint256][]byte) + err = bcSpout.GetStateSyncModule().Traverse(h.PrevStateRoot, func(n mpt.Node, nodeBytes []byte) bool { + nodesMap[n.Hash()] = nodeBytes + return false + }) + require.NoError(t, err) + for { + need := module.GetUnknownMPTNodesBatch(10) + if len(need) == 0 { + break + } + add := make([][]byte, len(need)) + for i, h := range need { + nodeBytes, ok := nodesMap[h] + if !ok { + t.Fatal("unknown or restored node requested") + } + add[i] = nodeBytes + delete(nodesMap, h) + } + require.NoError(t, module.AddMPTNodes(add)) + } + require.False(t, module.IsActive()) + require.False(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + unknownNodes := module.GetUnknownMPTNodesBatch(1) + require.True(t, len(unknownNodes) == 0) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight()) + + // add missing blocks to bcBolt: should be ok, because state is synced + for i := stateSyncPoint + 1; i <= int(bcSpout.BlockHeight()); i++ { + b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, bcBolt.AddBlock(b)) + } + require.Equal(t, bcSpout.BlockHeight(), 
bcBolt.BlockHeight()) + + // compare storage states + fetchStorage := func(bc *Blockchain) []storage.KeyValue { + var kv []storage.KeyValue + bc.dao.Store.Seek(storage.STStorage.Bytes(), func(k, v []byte) { + key := slice.Copy(k) + value := slice.Copy(v) + kv = append(kv, storage.KeyValue{ + Key: key, + Value: value, + }) + }) + return kv + } + expected := fetchStorage(bcSpout) + actual := fetchStorage(bcBolt) + require.ElementsMatch(t, expected, actual) + + // no temp items should be left + bcBolt.dao.Store.Seek(storage.STTempStorage.Bytes(), func(k, v []byte) { + t.Fatal("temp storage items are found") + }) +} diff --git a/pkg/core/storage/store.go b/pkg/core/storage/store.go index dd2376c63..bd62f6001 100644 --- a/pkg/core/storage/store.go +++ b/pkg/core/storage/store.go @@ -8,13 +8,18 @@ import ( // KeyPrefix constants. const ( - DataBlock KeyPrefix = 0x01 - DataTransaction KeyPrefix = 0x02 - DataMPT KeyPrefix = 0x03 - STAccount KeyPrefix = 0x40 - STNotification KeyPrefix = 0x4d - STContractID KeyPrefix = 0x51 - STStorage KeyPrefix = 0x70 + DataBlock KeyPrefix = 0x01 + DataTransaction KeyPrefix = 0x02 + DataMPT KeyPrefix = 0x03 + STAccount KeyPrefix = 0x40 + STNotification KeyPrefix = 0x4d + STContractID KeyPrefix = 0x51 + STStorage KeyPrefix = 0x70 + // STTempStorage is used to store contract storage items during state sync process + // in order not to mess up the previous state which has its own items stored by + // STStorage prefix. Once state exchange process is completed, all items with + // STStorage prefix will be replaced with STTempStorage-prefixed ones. + STTempStorage KeyPrefix = 0x71 STNEP17Transfers KeyPrefix = 0x72 STNEP17TransferInfo KeyPrefix = 0x73 IXHeaderHashList KeyPrefix = 0x80 @@ -22,6 +27,7 @@ const ( SYSCurrentHeader KeyPrefix = 0xc1 SYSStateSyncCurrentBlockHeight KeyPrefix = 0xc2 SYSStateSyncPoint KeyPrefix = 0xc3 + SYSStateJumpStage KeyPrefix = 0xc4 SYSVersion KeyPrefix = 0xf0 ) diff --git a/pkg/network/blockqueue.go b/pkg/network/blockqueue.go index 8a21cab1b..7cc796886 100644 --- a/pkg/network/blockqueue.go +++ b/pkg/network/blockqueue.go @@ -4,6 +4,7 @@ import ( "github.com/Workiva/go-datastructures/queue" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" + "go.uber.org/atomic" "go.uber.org/zap" ) @@ -13,6 +14,7 @@ type blockQueue struct { checkBlocks chan struct{} chain blockchainer.Blockqueuer relayF func(*block.Block) + discarded *atomic.Bool } const ( @@ -32,6 +34,7 @@ func newBlockQueue(capacity int, bc blockchainer.Blockqueuer, log *zap.Logger, r checkBlocks: make(chan struct{}, 1), chain: bc, relayF: relayer, + discarded: atomic.NewBool(false), } } @@ -91,8 +94,10 @@ func (bq *blockQueue) putBlock(block *block.Block) error { } func (bq *blockQueue) discard() { - close(bq.checkBlocks) - bq.queue.Dispose() + if bq.discarded.CAS(false, true) { + close(bq.checkBlocks) + bq.queue.Dispose() + } } func (bq *blockQueue) length() int { diff --git a/pkg/network/message.go b/pkg/network/message.go index cbdaa554f..fb605c8a7 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -71,6 +71,8 @@ const ( CMDBlock = CommandType(payload.BlockType) CMDExtensible = CommandType(payload.ExtensibleType) CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType) + CMDGetMPTData CommandType = 0x51 // 0x5.. commands are used for extensions (P2PNotary, state exchange cmds) + CMDMPTData CommandType = 0x52 CMDReject CommandType = 0x2f // SPV protocol. 
@@ -136,6 +138,10 @@ func (m *Message) decodePayload() error { p = &payload.Version{} case CMDInv, CMDGetData: p = &payload.Inventory{} + case CMDGetMPTData: + p = &payload.MPTInventory{} + case CMDMPTData: + p = &payload.MPTData{} case CMDAddr: p = &payload.AddressList{} case CMDBlock: @@ -221,7 +227,7 @@ func (m *Message) tryCompressPayload() error { if m.Flags&Compressed == 0 { switch m.Payload.(type) { case *payload.Headers, *payload.MerkleBlock, payload.NullPayload, - *payload.Inventory: + *payload.Inventory, *payload.MPTInventory: break default: size := len(compressedPayload) diff --git a/pkg/network/message_string.go b/pkg/network/message_string.go index 7da007079..2ebdacd9b 100644 --- a/pkg/network/message_string.go +++ b/pkg/network/message_string.go @@ -26,6 +26,8 @@ func _() { _ = x[CMDBlock-44] _ = x[CMDExtensible-46] _ = x[CMDP2PNotaryRequest-80] + _ = x[CMDGetMPTData-81] + _ = x[CMDMPTData-82] _ = x[CMDReject-47] _ = x[CMDFilterLoad-48] _ = x[CMDFilterAdd-49] @@ -44,7 +46,7 @@ const ( _CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear" _CommandType_name_7 = "CMDMerkleBlock" _CommandType_name_8 = "CMDAlert" - _CommandType_name_9 = "CMDP2PNotaryRequest" + _CommandType_name_9 = "CMDP2PNotaryRequestCMDGetMPTDataCMDMPTData" ) var ( @@ -55,6 +57,7 @@ var ( _CommandType_index_4 = [...]uint8{0, 12, 22} _CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58} _CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61} + _CommandType_index_9 = [...]uint8{0, 19, 32, 42} ) func (i CommandType) String() string { @@ -83,8 +86,9 @@ func (i CommandType) String() string { return _CommandType_name_7 case i == 64: return _CommandType_name_8 - case i == 80: - return _CommandType_name_9 + case 80 <= i && i <= 82: + i -= 80 + return _CommandType_name_9[_CommandType_index_9[i]:_CommandType_index_9[i+1]] default: return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/network/message_test.go b/pkg/network/message_test.go index c4976a702..38b620189 100644 --- a/pkg/network/message_test.go +++ b/pkg/network/message_test.go @@ -258,6 +258,21 @@ func TestEncodeDecodeNotFound(t *testing.T) { }) } +func TestEncodeDecodeGetMPTData(t *testing.T) { + testEncodeDecode(t, CMDGetMPTData, &payload.MPTInventory{ + Hashes: []util.Uint256{ + {1, 2, 3}, + {4, 5, 6}, + }, + }) +} + +func TestEncodeDecodeMPTData(t *testing.T) { + testEncodeDecode(t, CMDMPTData, &payload.MPTData{ + Nodes: [][]byte{{1, 2, 3}, {4, 5, 6}}, + }) +} + func TestInvalidMessages(t *testing.T) { t.Run("CMDBlock, empty payload", func(t *testing.T) { testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{}) diff --git a/pkg/network/payload/mptdata.go b/pkg/network/payload/mptdata.go new file mode 100644 index 000000000..ba607a86b --- /dev/null +++ b/pkg/network/payload/mptdata.go @@ -0,0 +1,35 @@ +package payload + +import ( + "errors" + + "github.com/nspcc-dev/neo-go/pkg/io" +) + +// MPTData represents the set of serialized MPT nodes. +type MPTData struct { + Nodes [][]byte +} + +// EncodeBinary implements io.Serializable. +func (d *MPTData) EncodeBinary(w *io.BinWriter) { + w.WriteVarUint(uint64(len(d.Nodes))) + for _, n := range d.Nodes { + w.WriteVarBytes(n) + } +} + +// DecodeBinary implements io.Serializable. 
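+// The wire format is a var-uint count followed by that many var-sized byte
+// strings; a zero count is rejected, so a CMDMPTData message always carries
+// at least one node.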
+func (d *MPTData) DecodeBinary(r *io.BinReader) {
+	sz := r.ReadVarUint()
+	if sz == 0 {
+		r.Err = errors.New("empty MPT nodes list")
+		return
+	}
+	for i := uint64(0); i < sz; i++ {
+		d.Nodes = append(d.Nodes, r.ReadVarBytes())
+		if r.Err != nil {
+			return
+		}
+	}
+}
diff --git a/pkg/network/payload/mptdata_test.go b/pkg/network/payload/mptdata_test.go
new file mode 100644
index 000000000..d9db7bff6
--- /dev/null
+++ b/pkg/network/payload/mptdata_test.go
@@ -0,0 +1,24 @@
+package payload
+
+import (
+	"testing"
+
+	"github.com/nspcc-dev/neo-go/internal/testserdes"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMPTData_EncodeDecodeBinary(t *testing.T) {
+	t.Run("empty", func(t *testing.T) {
+		d := new(MPTData)
+		bytes, err := testserdes.EncodeBinary(d)
+		require.NoError(t, err)
+		require.Error(t, testserdes.DecodeBinary(bytes, new(MPTData)))
+	})
+
+	t.Run("good", func(t *testing.T) {
+		d := &MPTData{
+			Nodes: [][]byte{{}, {1}, {1, 2, 3}},
+		}
+		testserdes.EncodeDecodeBinary(t, d, new(MPTData))
+	})
+}
diff --git a/pkg/network/payload/mptinventory.go b/pkg/network/payload/mptinventory.go
new file mode 100644
index 000000000..66aaf5c02
--- /dev/null
+++ b/pkg/network/payload/mptinventory.go
@@ -0,0 +1,32 @@
+package payload
+
+import (
+	"github.com/nspcc-dev/neo-go/pkg/io"
+	"github.com/nspcc-dev/neo-go/pkg/util"
+)
+
+// MaxMPTHashesCount is the maximum number of requested MPT node hashes.
+const MaxMPTHashesCount = 32
+
+// MPTInventory payload.
+type MPTInventory struct {
+	// A list of requested MPT node hashes.
+	Hashes []util.Uint256
+}
+
+// NewMPTInventory returns a pointer to an MPTInventory.
+func NewMPTInventory(hashes []util.Uint256) *MPTInventory {
+	return &MPTInventory{
+		Hashes: hashes,
+	}
+}
+
+// DecodeBinary implements Serializable interface.
+func (p *MPTInventory) DecodeBinary(br *io.BinReader) {
+	br.ReadArray(&p.Hashes, MaxMPTHashesCount)
+}
+
+// EncodeBinary implements Serializable interface.
+func (p *MPTInventory) EncodeBinary(bw *io.BinWriter) {
+	bw.WriteArray(p.Hashes)
+}
diff --git a/pkg/network/payload/mptinventory_test.go b/pkg/network/payload/mptinventory_test.go
new file mode 100644
index 000000000..a0c052a67
--- /dev/null
+++ b/pkg/network/payload/mptinventory_test.go
@@ -0,0 +1,38 @@
+package payload
+
+import (
+	"testing"
+
+	"github.com/nspcc-dev/neo-go/internal/testserdes"
+	"github.com/nspcc-dev/neo-go/pkg/util"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMPTInventory_EncodeDecodeBinary(t *testing.T) {
+	t.Run("empty", func(t *testing.T) {
+		testserdes.EncodeDecodeBinary(t, NewMPTInventory([]util.Uint256{}), new(MPTInventory))
+	})
+
+	t.Run("good", func(t *testing.T) {
+		inv := NewMPTInventory([]util.Uint256{{1, 2, 3}, {2, 3, 4}})
+		testserdes.EncodeDecodeBinary(t, inv, new(MPTInventory))
+	})
+
+	t.Run("too large", func(t *testing.T) {
+		check := func(t *testing.T, count int, fail bool) {
+			h := make([]util.Uint256, count)
+			for i := range h {
+				h[i] = util.Uint256{1, 2, 3}
+			}
+			if fail {
+				bytes, err := testserdes.EncodeBinary(NewMPTInventory(h))
+				require.NoError(t, err)
+				require.Error(t, testserdes.DecodeBinary(bytes, new(MPTInventory)))
+			} else {
+				testserdes.EncodeDecodeBinary(t, NewMPTInventory(h), new(MPTInventory))
+			}
+		}
+		check(t, MaxMPTHashesCount, false)
+		check(t, MaxMPTHashesCount+1, true)
+	})
+}
diff --git a/pkg/network/server.go b/pkg/network/server.go
index 457add0b0..16a6a46f5 100644
--- a/pkg/network/server.go
+++ b/pkg/network/server.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	mrand "math/rand"
 	"net"
+	"sort"
 	"strconv"
 	"sync"
 	"time"
@@ -17,7 +18,9 @@ import (
 	"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
 	"github.com/nspcc-dev/neo-go/pkg/core/mempool"
 	"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
+	"github.com/nspcc-dev/neo-go/pkg/core/mpt"
 	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
+	"github.com/nspcc-dev/neo-go/pkg/io"
 	"github.com/nspcc-dev/neo-go/pkg/network/capability"
 	"github.com/nspcc-dev/neo-go/pkg/network/extpool"
 	"github.com/nspcc-dev/neo-go/pkg/network/payload"
@@ -67,6 +70,7 @@ type (
 		discovery Discoverer
 		chain     blockchainer.Blockchainer
 		bQueue    *blockQueue
+		bSyncQueue *blockQueue
 		consensus consensus.Service
 		mempool   *mempool.Pool
 		notaryRequestPool *mempool.Pool
@@ -93,6 +97,7 @@ type (
 		oracle    *oracle.Oracle
 		stateRoot stateroot.Service
+		stateSync blockchainer.StateSync
 
 		log *zap.Logger
 	}
@@ -191,6 +196,10 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
 	}
 	s.stateRoot = sr
 
+	sSync := chain.GetStateSyncModule()
+	s.stateSync = sSync
+	s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil)
+
 	if config.OracleCfg.Enabled {
 		orcCfg := oracle.Config{
 			Log: log,
@@ -277,6 +286,7 @@ func (s *Server) Start(errChan chan error) {
 	go s.broadcastTxLoop()
 	go s.relayBlocksLoop()
 	go s.bQueue.run()
+	go s.bSyncQueue.run()
 	go s.transport.Accept()
 	setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10))
 	s.run()
@@ -292,6 +302,7 @@ func (s *Server) Shutdown() {
 		p.Disconnect(errServerShutdown)
 	}
 	s.bQueue.discard()
+	s.bSyncQueue.discard()
 	if s.StateRootCfg.Enabled {
 		s.stateRoot.Shutdown()
 	}
@@ -573,6 +584,10 @@ func (s *Server) IsInSync() bool {
 	var peersNumber int
 	var notHigher int
 
+	if s.stateSync.IsActive() {
+		return false
+	}
+
 	if s.MinPeers == 0 {
 		return true
 	}
@@ -630,6 +645,9 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
 // handleBlockCmd processes the block received from its peer.
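+// While the state sync module is active, incoming blocks feed the state sync
+// queue instead of the regular chain queue.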
func (s *Server) handleBlockCmd(p Peer, block *block.Block) error {
+	if s.stateSync.IsActive() {
+		return s.bSyncQueue.putBlock(block)
+	}
 	return s.bQueue.putBlock(block)
 }
 
@@ -639,25 +657,57 @@ func (s *Server) handlePing(p Peer, ping *payload.Ping) error {
 	if err != nil {
 		return err
 	}
-	if s.chain.BlockHeight() < ping.LastBlockIndex {
-		err = s.requestBlocks(p)
-		if err != nil {
-			return err
-		}
+	err = s.requestBlocksOrHeaders(p)
+	if err != nil {
+		return err
 	}
 	return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id)))
 }
 
+func (s *Server) requestBlocksOrHeaders(p Peer) error {
+	if s.stateSync.NeedHeaders() {
+		if s.chain.HeaderHeight() < p.LastBlockIndex() {
+			return s.requestHeaders(p)
+		}
+		return nil
+	}
+	var (
+		bq              blockchainer.Blockqueuer = s.chain
+		requestMPTNodes bool
+	)
+	if s.stateSync.IsActive() {
+		bq = s.stateSync
+		requestMPTNodes = s.stateSync.NeedMPTNodes()
+	}
+	if bq.BlockHeight() >= p.LastBlockIndex() {
+		return nil
+	}
+	err := s.requestBlocks(bq, p)
+	if err != nil {
+		return err
+	}
+	if requestMPTNodes {
+		return s.requestMPTNodes(p, s.stateSync.GetUnknownMPTNodesBatch(payload.MaxMPTHashesCount))
+	}
+	return nil
+}
+
+// requestHeaders sends a CMDGetHeaders message to the peer to sync up in headers.
+func (s *Server) requestHeaders(p Peer) error {
+	// TODO: optimize
+	currHeight := s.chain.HeaderHeight()
+	needHeight := currHeight + 1
+	payload := payload.NewGetBlockByIndex(needHeight, -1)
+	return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload))
+}
+
 // handlePong processes the pong request.
 func (s *Server) handlePong(p Peer, pong *payload.Ping) error {
 	err := p.HandlePong(pong)
 	if err != nil {
 		return err
 	}
-	if s.chain.BlockHeight() < pong.LastBlockIndex {
-		return s.requestBlocks(p)
-	}
-	return nil
+	return s.requestBlocksOrHeaders(p)
 }
 
 // handleInvCmd processes the received inventory.
@@ -766,6 +816,69 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
 	return nil
 }
 
+// handleGetMPTDataCmd processes the received MPT inventory.
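+// The reply is size-capped: serialised nodes are appended while they fit into
+// payload.MaxSize, each node is sent at most once and whatever didn't fit is
+// left for the peer's follow-up requests.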
+func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error { + if !s.chain.GetConfig().P2PStateExchangeExtensions { + return errors.New("GetMPTDataCMD was received, but P2PStateExchangeExtensions are disabled") + } + if s.chain.GetConfig().KeepOnlyLatestState { + // TODO: implement keeping MPT states for P1 and P2 height (#2095, #2152 related) + return errors.New("GetMPTDataCMD was received, but only latest MPT state is supported") + } + resp := payload.MPTData{} + capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes))) + added := make(map[util.Uint256]struct{}) + for _, h := range inv.Hashes { + if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type + break + } + err := s.stateSync.Traverse(h, + func(n mpt.Node, node []byte) bool { + if _, ok := added[n.Hash()]; ok { + return false + } + l := len(node) + size := l + io.GetVarSize(l) + if size > capLeft { + return true + } + resp.Nodes = append(resp.Nodes, node) + added[n.Hash()] = struct{}{} + capLeft -= size + return false + }) + if err != nil { + return fmt.Errorf("failed to traverse MPT starting from %s: %w", h.StringBE(), err) + } + } + if len(resp.Nodes) > 0 { + msg := NewMessage(CMDMPTData, &resp) + return p.EnqueueP2PMessage(msg) + } + return nil +} + +func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error { + if !s.chain.GetConfig().P2PStateExchangeExtensions { + return errors.New("MPTDataCMD was received, but P2PStateExchangeExtensions are disabled") + } + return s.stateSync.AddMPTNodes(data.Nodes) +} + +// requestMPTNodes requests specified MPT nodes from the peer or broadcasts +// request if peer is not specified. +func (s *Server) requestMPTNodes(p Peer, itms []util.Uint256) error { + if len(itms) == 0 { + return nil + } + if len(itms) > payload.MaxMPTHashesCount { + itms = itms[:payload.MaxMPTHashesCount] + } + pl := payload.NewMPTInventory(itms) + msg := NewMessage(CMDGetMPTData, pl) + return p.EnqueueP2PMessage(msg) +} + // handleGetBlocksCmd processes the getblocks request. func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { count := gb.Count @@ -845,6 +958,11 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error return p.EnqueueP2PMessage(msg) } +// handleHeadersCmd processes headers payload. +func (s *Server) handleHeadersCmd(p Peer, h *payload.Headers) error { + return s.stateSync.AddHeaders(h.Hdrs...) +} + // handleExtensibleCmd processes received extensible payload. func (s *Server) handleExtensibleCmd(e *payload.Extensible) error { if !s.syncReached.Load() { @@ -993,8 +1111,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error { // 1. Block range is divided into chunks of payload.MaxHashesCount. // 2. Send requests for chunk in increasing order. // 3. After all requests were sent, request random height. -func (s *Server) requestBlocks(p Peer) error { - var currHeight = s.chain.BlockHeight() +func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error { + var currHeight = bq.BlockHeight() var peerHeight = p.LastBlockIndex() var needHeight uint32 // lastRequestedHeight can only be increased. 
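+	// (bq is either the chain itself or the state sync module, so block
+	// requests follow whichever height is actually being synchronised.)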
@@ -1051,9 +1169,18 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDGetData: inv := msg.Payload.(*payload.Inventory) return s.handleGetDataCmd(peer, inv) + case CMDGetMPTData: + inv := msg.Payload.(*payload.MPTInventory) + return s.handleGetMPTDataCmd(peer, inv) + case CMDMPTData: + inv := msg.Payload.(*payload.MPTData) + return s.handleMPTDataCmd(peer, inv) case CMDGetHeaders: gh := msg.Payload.(*payload.GetBlockByIndex) return s.handleGetHeadersCmd(peer, gh) + case CMDHeaders: + h := msg.Payload.(*payload.Headers) + return s.handleHeadersCmd(peer, h) case CMDInv: inventory := msg.Payload.(*payload.Inventory) return s.handleInvCmd(peer, inventory) @@ -1093,6 +1220,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { } go peer.StartProtocol() + s.tryInitStateSync() s.tryStartServices() default: return fmt.Errorf("received '%s' during handshake", msg.Command.String()) @@ -1101,6 +1229,52 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { return nil } +func (s *Server) tryInitStateSync() { + if !s.stateSync.IsActive() { + s.bSyncQueue.discard() + return + } + + if s.stateSync.IsInitialized() { + return + } + + var peersNumber int + s.lock.RLock() + heights := make([]uint32, 0) + for p := range s.peers { + if p.Handshaked() { + peersNumber++ + peerLastBlock := p.LastBlockIndex() + i := sort.Search(len(heights), func(i int) bool { + return heights[i] >= peerLastBlock + }) + heights = append(heights, peerLastBlock) + if i != len(heights)-1 { + copy(heights[i+1:], heights[i:]) + heights[i] = peerLastBlock + } + } + } + s.lock.RUnlock() + if peersNumber >= s.MinPeers && len(heights) > 0 { + // choose the height of the median peer as current chain's height + h := heights[len(heights)/2] + err := s.stateSync.Init(h) + if err != nil { + s.log.Fatal("failed to init state sync module", + zap.Uint32("evaluated chain's blockHeight", h), + zap.Uint32("blockHeight", s.chain.BlockHeight()), + zap.Uint32("headerHeight", s.chain.HeaderHeight()), + zap.Error(err)) + } + + // module can be inactive after init (i.e. 
full state is collected and ordinary block processing is needed) + if !s.stateSync.IsActive() { + s.bSyncQueue.discard() + } + } +} func (s *Server) handleNewPayload(p *payload.Extensible) { _, err := s.extensiblePool.Add(p) if err != nil { diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 0c9c9a358..69dc0d86b 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -2,6 +2,7 @@ package network import ( "errors" + "fmt" "math/big" "net" "strconv" @@ -16,6 +17,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" @@ -46,7 +48,10 @@ func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") } func TestNewServer(t *testing.T) { - bc := &fakechain.FakeChain{} + bc := &fakechain.FakeChain{ProtocolConfiguration: config.ProtocolConfiguration{ + P2PStateExchangeExtensions: true, + StateRootInHeader: true, + }} s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery) require.Error(t, err) @@ -733,6 +738,139 @@ func TestInv(t *testing.T) { }) } +func TestHandleGetMPTData(t *testing.T) { + t.Run("P2PStateExchange extensions off", func(t *testing.T) { + s := startTestServer(t) + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{ + Hashes: []util.Uint256{{1, 2, 3}}, + }) + require.Error(t, s.handleMessage(p, msg)) + }) + + t.Run("KeepOnlyLatestState on", func(t *testing.T) { + s := startTestServer(t) + s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true + s.chain.(*fakechain.FakeChain).KeepOnlyLatestState = true + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{ + Hashes: []util.Uint256{{1, 2, 3}}, + }) + require.Error(t, s.handleMessage(p, msg)) + }) + + t.Run("good", func(t *testing.T) { + s := startTestServer(t) + s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true + var recvResponse atomic.Bool + r1 := random.Uint256() + r2 := random.Uint256() + r3 := random.Uint256() + node := []byte{1, 2, 3} + s.stateSync.(*fakechain.FakeStateSync).TraverseFunc = func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { + if !(root.Equals(r1) || root.Equals(r2)) { + t.Fatal("unexpected root") + } + require.False(t, process(mpt.NewHashNode(r3), node)) + return nil + } + found := &payload.MPTData{ + Nodes: [][]byte{node}, // no duplicates expected + } + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + switch msg.Command { + case CMDMPTData: + require.Equal(t, found, msg.Payload) + recvResponse.Store(true) + } + } + hs := []util.Uint256{r1, r2} + s.testHandleMessage(t, p, CMDGetMPTData, payload.NewMPTInventory(hs)) + + require.Eventually(t, recvResponse.Load, time.Second, time.Millisecond) + }) +} + +func TestHandleMPTData(t *testing.T) { + t.Run("P2PStateExchange extensions off", func(t *testing.T) { + s := startTestServer(t) + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDMPTData, &payload.MPTData{ + Nodes: [][]byte{{1, 2, 3}}, + }) + require.Error(t, s.handleMessage(p, msg)) + }) + + 
t.Run("good", func(t *testing.T) { + s := startTestServer(t) + expected := [][]byte{{1, 2, 3}, {2, 3, 4}} + s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true + s.stateSync = &fakechain.FakeStateSync{ + AddMPTNodesFunc: func(nodes [][]byte) error { + require.Equal(t, expected, nodes) + return nil + }, + } + + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDMPTData, &payload.MPTData{ + Nodes: expected, + }) + require.NoError(t, s.handleMessage(p, msg)) + }) +} + +func TestRequestMPTNodes(t *testing.T) { + s := startTestServer(t) + + var actual []util.Uint256 + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + if msg.Command == CMDGetMPTData { + actual = append(actual, msg.Payload.(*payload.MPTInventory).Hashes...) + } + } + s.register <- p + s.register <- p // ensure previous send was handled + + t.Run("no hashes, no message", func(t *testing.T) { + actual = nil + require.NoError(t, s.requestMPTNodes(p, nil)) + require.Nil(t, actual) + }) + t.Run("good, small", func(t *testing.T) { + actual = nil + expected := []util.Uint256{random.Uint256(), random.Uint256()} + require.NoError(t, s.requestMPTNodes(p, expected)) + require.Equal(t, expected, actual) + }) + t.Run("good, exactly one chunk", func(t *testing.T) { + actual = nil + expected := make([]util.Uint256, payload.MaxMPTHashesCount) + for i := range expected { + expected[i] = random.Uint256() + } + require.NoError(t, s.requestMPTNodes(p, expected)) + require.Equal(t, expected, actual) + }) + t.Run("good, too large chunk", func(t *testing.T) { + actual = nil + expected := make([]util.Uint256, payload.MaxMPTHashesCount+1) + for i := range expected { + expected[i] = random.Uint256() + } + require.NoError(t, s.requestMPTNodes(p, expected)) + require.Equal(t, expected[:payload.MaxMPTHashesCount], actual) + }) +} + func TestRequestTx(t *testing.T) { s := startTestServer(t) @@ -899,3 +1037,44 @@ func TestVerifyNotaryRequest(t *testing.T) { require.NoError(t, verifyNotaryRequest(bc, nil, r)) }) } + +func TestTryInitStateSync(t *testing.T) { + t.Run("module inactive", func(t *testing.T) { + s := startTestServer(t) + s.tryInitStateSync() + }) + + t.Run("module already initialized", func(t *testing.T) { + s := startTestServer(t) + ss := &fakechain.FakeStateSync{} + ss.IsActiveFlag.Store(true) + ss.IsInitializedFlag.Store(true) + s.stateSync = ss + s.tryInitStateSync() + }) + + t.Run("good", func(t *testing.T) { + s := startTestServer(t) + for _, h := range []uint32{10, 8, 7, 4, 11, 4} { + p := newLocalPeer(t, s) + p.handshaked = true + p.lastBlockIndex = h + s.peers[p] = true + } + p := newLocalPeer(t, s) + p.handshaked = false // one disconnected peer to check it won't be taken into attention + p.lastBlockIndex = 5 + s.peers[p] = true + var expectedH uint32 = 8 // median peer + + ss := &fakechain.FakeStateSync{InitFunc: func(h uint32) error { + if h != expectedH { + return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h) + } + return nil + }} + ss.IsActiveFlag.Store(true) + s.stateSync = ss + s.tryInitStateSync() + }) +} diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 8ff47a18c..5b159fda1 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -267,12 +267,10 @@ func (p *TCPPeer) StartProtocol() { zap.Uint32("id", p.Version().Nonce)) p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities) - if p.server.chain.BlockHeight() < p.LastBlockIndex() { - err = p.server.requestBlocks(p) - if err 
+	err = p.server.requestBlocksOrHeaders(p)
+	if err != nil {
+		p.Disconnect(err)
+		return
 	}

 	timer := time.NewTimer(p.server.ProtoTickInterval)
@@ -281,10 +279,8 @@ func (p *TCPPeer) StartProtocol() {
 		select {
 		case <-p.done:
 			return
 		case <-timer.C:
-			// Try to sync in headers and block with the peer if his block height is higher then ours.
-			if p.LastBlockIndex() > p.server.chain.BlockHeight() {
-				err = p.server.requestBlocks(p)
-			}
+			// Try to sync headers and blocks with the peer if it's ahead of us.
+			err = p.server.requestBlocksOrHeaders(p)
 			if err == nil {
 				timer.Reset(p.server.ProtoTickInterval)
 			}
diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go
index 3fa82c659..8bd2e26a3 100644
--- a/pkg/rpc/server/server.go
+++ b/pkg/rpc/server/server.go
@@ -679,6 +679,8 @@ func (s *Server) getNEP17Balances(ps request.Params) (interface{}, *response.Err
 	if err != nil {
 		return nil, response.NewRPCError("Failed to get NEP17 last updated block", err.Error(), err)
 	}
+	// The state synchronization point is stored under the math.MinInt32 key.
+	stateSyncPoint := lastUpdated[math.MinInt32]
 	bw := io.NewBufBinWriter()
 	for _, h := range s.chain.GetNEP17Contracts() {
 		balance, err := s.getNEP17Balance(h, u, bw)
@@ -692,10 +693,19 @@ func (s *Server) getNEP17Balances(ps request.Params) (interface{}, *response.Err
 		if cs == nil {
 			continue
 		}
+		// Tokens without tracked transfers fall back to the state sync point (see docs/rpc.md).
+		lub, ok := lastUpdated[cs.ID]
+		if !ok {
+			cfg := s.chain.GetConfig()
+			if !cfg.P2PStateExchangeExtensions && cfg.RemoveUntraceableBlocks {
+				return nil, response.NewInternalServerError(fmt.Sprintf("failed to get LastUpdatedBlock for balance of %s token", cs.Hash.StringLE()), nil)
+			}
+			lub = stateSyncPoint
+		}
 		bs.Balances = append(bs.Balances, result.NEP17Balance{
 			Asset:  h,
 			Amount: balance.String(),
-			LastUpdated: lastUpdated[cs.ID],
+			LastUpdated: lub,
 		})
 	}
 	return bs, nil