Merge pull request #2019 from nspcc-dev/states-exchange/cmd

core, network: implement P2P state exchange
This commit is contained in:
Roman Khimov 2021-09-08 17:22:19 +03:00 committed by GitHub
commit 752c7f106b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
39 changed files with 3001 additions and 53 deletions

View file

@ -114,7 +114,11 @@ balance won't be shown in the list of NEP17 balances returned by the neo-go node
(unlike the C# node behavior). However, transfer logs of such token are still
available via `getnep17transfers` RPC call.
The behaviour of the `LastUpdatedBlock` tracking matches the C# node's one.
The behaviour of the `LastUpdatedBlock` tracking for archival nodes as far as for
governing token balances matches the C# node's one. For non-archival nodes and
other NEP17-compliant tokens if transfer's `LastUpdatedBlock` is lower than the
latest state synchronization point P the node working against, then
`LastUpdatedBlock` equals P.
### Unsupported methods

View file

@ -13,6 +13,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services"
"github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/native"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
@ -21,6 +22,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
uatomic "go.uber.org/atomic"
)
// FakeChain implements Blockchainer interface, but does not provide real functionality.
@ -42,6 +44,15 @@ type FakeChain struct {
UtilityTokenBalance *big.Int
}
// FakeStateSync implements StateSync interface.
type FakeStateSync struct {
IsActiveFlag uatomic.Bool
IsInitializedFlag uatomic.Bool
InitFunc func(h uint32) error
TraverseFunc func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
AddMPTNodesFunc func(nodes [][]byte) error
}
// NewFakeChain returns new FakeChain structure.
func NewFakeChain() *FakeChain {
return &FakeChain{
@ -294,6 +305,11 @@ func (chain *FakeChain) GetStateModule() blockchainer.StateRoot {
return nil
}
// GetStateSyncModule implements Blockchainer interface.
func (chain *FakeChain) GetStateSyncModule() blockchainer.StateSync {
return &FakeStateSync{}
}
// GetStorageItem implements Blockchainer interface.
func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
panic("TODO")
@ -436,3 +452,63 @@ func (chain *FakeChain) UnsubscribeFromNotifications(ch chan<- *state.Notificati
func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
panic("TODO")
}
// AddBlock implements StateSync interface.
func (s *FakeStateSync) AddBlock(block *block.Block) error {
panic("TODO")
}
// AddHeaders implements StateSync interface.
func (s *FakeStateSync) AddHeaders(...*block.Header) error {
panic("TODO")
}
// AddMPTNodes implements StateSync interface.
func (s *FakeStateSync) AddMPTNodes(nodes [][]byte) error {
if s.AddMPTNodesFunc != nil {
return s.AddMPTNodesFunc(nodes)
}
panic("TODO")
}
// BlockHeight implements StateSync interface.
func (s *FakeStateSync) BlockHeight() uint32 {
panic("TODO")
}
// IsActive implements StateSync interface.
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag.Load() }
// IsInitialized implements StateSync interface.
func (s *FakeStateSync) IsInitialized() bool {
return s.IsInitializedFlag.Load()
}
// Init implements StateSync interface.
func (s *FakeStateSync) Init(currChainHeight uint32) error {
if s.InitFunc != nil {
return s.InitFunc(currChainHeight)
}
panic("TODO")
}
// NeedHeaders implements StateSync interface.
func (s *FakeStateSync) NeedHeaders() bool { return false }
// NeedMPTNodes implements StateSync interface.
func (s *FakeStateSync) NeedMPTNodes() bool {
panic("TODO")
}
// Traverse implements StateSync interface.
func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
if s.TraverseFunc != nil {
return s.TraverseFunc(root, process)
}
panic("TODO")
}
// GetUnknownMPTNodesBatch implements StateSync interface.
func (s *FakeStateSync) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
panic("TODO")
}

View file

@ -23,6 +23,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/native/noderoles"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/stateroot"
"github.com/nspcc-dev/neo-go/pkg/core/statesync"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
@ -34,6 +35,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"go.uber.org/zap"
@ -42,7 +44,7 @@ import (
// Tuning parameters.
const (
headerBatchCount = 2000
version = "0.1.2"
version = "0.1.4"
defaultInitialGAS = 52000000_00000000
defaultMemPoolSize = 50000
@ -54,6 +56,31 @@ const (
// HeaderVerificationGasLimit is the maximum amount of GAS for block header verification.
HeaderVerificationGasLimit = 3_00000000 // 3 GAS
defaultStateSyncInterval = 40000
// maxStorageBatchSize is the number of elements in storage batch expected to fit into the
// storage without delays and problems. Estimated size of batch in case of given number of
// elements does not exceed 1Mb.
maxStorageBatchSize = 10000
)
// stateJumpStage denotes the stage of state jump process.
type stateJumpStage byte
const (
// none means that no state jump process was initiated yet.
none stateJumpStage = 1 << iota
// stateJumpStarted means that state jump was just initiated, but outdated storage items
// were not yet removed.
stateJumpStarted
// oldStorageItemsRemoved means that outdated contract storage items were removed, but
// new storage items were not yet saved.
oldStorageItemsRemoved
// newStorageItemsAdded means that contract storage items are up-to-date with the current
// state.
newStorageItemsAdded
// genesisStateRemoved means that state corresponding to the genesis block was removed
// from the storage.
genesisStateRemoved
)
var (
@ -308,16 +335,6 @@ func (bc *Blockchain) init() error {
// and the genesis block as first block.
bc.log.Info("restoring blockchain", zap.String("version", version))
bHeight, err := bc.dao.GetCurrentBlockHeight()
if err != nil {
return err
}
bc.blockHeight = bHeight
bc.persistedHeight = bHeight
if err = bc.stateRoot.Init(bHeight, bc.config.KeepOnlyLatestState); err != nil {
return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err)
}
bc.headerHashes, err = bc.dao.GetHeaderHashes()
if err != nil {
return err
@ -365,6 +382,34 @@ func (bc *Blockchain) init() error {
}
}
// Check whether StateJump stage is in the storage and continue interrupted state jump if so.
jumpStage, err := bc.dao.Store.Get(storage.SYSStateJumpStage.Bytes())
if err == nil {
if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) {
return errors.New("state jump was not completed, but P2PStateExchangeExtensions are disabled or archival node capability is on. " +
"To start an archival node drop the database manually and restart the node")
}
if len(jumpStage) != 1 {
return fmt.Errorf("invalid state jump stage format")
}
// State jump wasn't finished yet, thus continue it.
stateSyncPoint, err := bc.dao.GetStateSyncPoint()
if err != nil {
return fmt.Errorf("failed to get state sync point from the storage")
}
return bc.jumpToStateInternal(stateSyncPoint, stateJumpStage(jumpStage[0]))
}
bHeight, err := bc.dao.GetCurrentBlockHeight()
if err != nil {
return err
}
bc.blockHeight = bHeight
bc.persistedHeight = bHeight
if err = bc.stateRoot.Init(bHeight, bc.config.KeepOnlyLatestState); err != nil {
return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err)
}
err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
if err != nil {
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
@ -409,6 +454,158 @@ func (bc *Blockchain) init() error {
return bc.updateExtensibleWhitelist(bHeight)
}
// jumpToState is an atomic operation that changes Blockchain state to the one
// specified by the state sync point p. All the data needed for the jump must be
// collected by the state sync module.
func (bc *Blockchain) jumpToState(p uint32) error {
bc.lock.Lock()
defer bc.lock.Unlock()
return bc.jumpToStateInternal(p, none)
}
// jumpToStateInternal is an internal representation of jumpToState callback that
// changes Blockchain state to the one specified by state sync point p and state
// jump stage. All the data needed for the jump must be in the DB, otherwise an
// error is returned. It is not protected by mutex.
func (bc *Blockchain) jumpToStateInternal(p uint32, stage stateJumpStage) error {
if p+1 >= uint32(len(bc.headerHashes)) {
return fmt.Errorf("invalid state sync point %d: headerHeignt is %d", p, len(bc.headerHashes))
}
bc.log.Info("jumping to state sync point", zap.Uint32("state sync point", p))
writeBuf := io.NewBufBinWriter()
jumpStageKey := storage.SYSStateJumpStage.Bytes()
switch stage {
case none:
err := bc.dao.Store.Put(jumpStageKey, []byte{byte(stateJumpStarted)})
if err != nil {
return fmt.Errorf("failed to store state jump stage: %w", err)
}
fallthrough
case stateJumpStarted:
// Replace old storage items by new ones, it should be done step-by step.
// Firstly, remove all old genesis-related items.
b := bc.dao.Store.Batch()
bc.dao.Store.Seek([]byte{byte(storage.STStorage)}, func(k, _ []byte) {
// Must copy here, #1468.
key := slice.Copy(k)
b.Delete(key)
})
b.Put(jumpStageKey, []byte{byte(oldStorageItemsRemoved)})
err := bc.dao.Store.PutBatch(b)
if err != nil {
return fmt.Errorf("failed to store state jump stage: %w", err)
}
fallthrough
case oldStorageItemsRemoved:
// Then change STTempStorage prefix to STStorage. Each replace operation is atomic.
for {
count := 0
b := bc.dao.Store.Batch()
bc.dao.Store.Seek([]byte{byte(storage.STTempStorage)}, func(k, v []byte) {
if count >= maxStorageBatchSize {
return
}
// Must copy here, #1468.
oldKey := slice.Copy(k)
b.Delete(oldKey)
key := make([]byte, len(k))
key[0] = byte(storage.STStorage)
copy(key[1:], k[1:])
value := slice.Copy(v)
b.Put(key, value)
count += 2
})
if count > 0 {
err := bc.dao.Store.PutBatch(b)
if err != nil {
return fmt.Errorf("failed to replace outdated contract storage items with the fresh ones: %w", err)
}
} else {
break
}
}
err := bc.dao.Store.Put(jumpStageKey, []byte{byte(newStorageItemsAdded)})
if err != nil {
return fmt.Errorf("failed to store state jump stage: %w", err)
}
fallthrough
case newStorageItemsAdded:
// After current state is updated, we need to remove outdated state-related data if so.
// The only outdated data we might have is genesis-related data, so check it.
if p-bc.config.MaxTraceableBlocks > 0 {
cache := bc.dao.GetWrapped()
writeBuf.Reset()
err := cache.DeleteBlock(bc.headerHashes[0], writeBuf)
if err != nil {
return fmt.Errorf("failed to remove outdated state data for the genesis block: %w", err)
}
// TODO: remove NEP17 transfers and NEP17 transfer info for genesis block, #2096 related.
_, err = cache.Persist()
if err != nil {
return fmt.Errorf("failed to drop genesis block state: %w", err)
}
}
err := bc.dao.Store.Put(jumpStageKey, []byte{byte(genesisStateRemoved)})
if err != nil {
return fmt.Errorf("failed to store state jump stage: %w", err)
}
case genesisStateRemoved:
// there's nothing to do after that, so just continue with common operations
// and remove state jump stage in the end.
default:
return errors.New("unknown state jump stage")
}
block, err := bc.dao.GetBlock(bc.headerHashes[p])
if err != nil {
return fmt.Errorf("failed to get current block: %w", err)
}
writeBuf.Reset()
err = bc.dao.StoreAsCurrentBlock(block, writeBuf)
if err != nil {
return fmt.Errorf("failed to store current block: %w", err)
}
bc.topBlock.Store(block)
atomic.StoreUint32(&bc.blockHeight, p)
atomic.StoreUint32(&bc.persistedHeight, p)
block, err = bc.dao.GetBlock(bc.headerHashes[p+1])
if err != nil {
return fmt.Errorf("failed to get block to init MPT: %w", err)
}
if err = bc.stateRoot.JumpToState(&state.MPTRoot{
Index: p,
Root: block.PrevStateRoot,
}, bc.config.KeepOnlyLatestState); err != nil {
return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err)
}
err = bc.contracts.NEO.InitializeCache(bc, bc.dao)
if err != nil {
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
}
err = bc.contracts.Management.InitializeCache(bc.dao)
if err != nil {
return fmt.Errorf("can't init cache for Management native contract: %w", err)
}
bc.contracts.Designate.InitializeCache()
if err := bc.updateExtensibleWhitelist(p); err != nil {
return fmt.Errorf("failed to update extensible whitelist: %w", err)
}
updateBlockHeightMetric(p)
err = bc.dao.Store.Delete(jumpStageKey)
if err != nil {
return fmt.Errorf("failed to remove outdated state jump stage: %w", err)
}
return nil
}
// Run runs chain loop, it needs to be run as goroutine and executing it is
// critical for correct Blockchain operation.
func (bc *Blockchain) Run() {
@ -696,6 +893,11 @@ func (bc *Blockchain) GetStateModule() blockchainer.StateRoot {
return bc.stateRoot
}
// GetStateSyncModule returns new state sync service instance.
func (bc *Blockchain) GetStateSyncModule() blockchainer.StateSync {
return statesync.NewModule(bc, bc.log, bc.dao, bc.jumpToState)
}
// storeBlock performs chain update using the block given, it executes all
// transactions with all appropriate side-effects and updates Blockchain state.
// This is the only way to change Blockchain state.
@ -1159,12 +1361,25 @@ func (bc *Blockchain) GetNEP17Contracts() []util.Uint160 {
}
// GetNEP17LastUpdated returns a set of contract ids with the corresponding last updated
// block indexes.
// block indexes. In case of an empty account, latest stored state synchronisation point
// is returned under Math.MinInt32 key.
func (bc *Blockchain) GetNEP17LastUpdated(acc util.Uint160) (map[int32]uint32, error) {
info, err := bc.dao.GetNEP17TransferInfo(acc)
if err != nil {
return nil, err
}
if bc.config.P2PStateExchangeExtensions && bc.config.RemoveUntraceableBlocks {
if _, ok := info.LastUpdated[bc.contracts.NEO.ID]; !ok {
nBalance, lub := bc.contracts.NEO.BalanceOf(bc.dao, acc)
if nBalance.Sign() != 0 {
info.LastUpdated[bc.contracts.NEO.ID] = lub
}
}
}
stateSyncPoint, err := bc.dao.GetStateSyncPoint()
if err == nil {
info.LastUpdated[math.MinInt32] = stateSyncPoint
}
return info.LastUpdated, nil
}

View file

@ -1,6 +1,7 @@
package core
import (
"encoding/binary"
"errors"
"fmt"
"math/big"
@ -34,6 +35,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
@ -41,6 +43,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/wallet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
)
func TestVerifyHeader(t *testing.T) {
@ -1734,3 +1737,88 @@ func TestConfigNativeUpdateHistory(t *testing.T) {
check(t, tc)
}
}
func TestBlockchain_InitWithIncompleteStateJump(t *testing.T) {
var (
stateSyncInterval = 4
maxTraceable uint32 = 6
)
spountCfg := func(c *config.Config) {
c.ProtocolConfiguration.RemoveUntraceableBlocks = true
c.ProtocolConfiguration.StateRootInHeader = true
c.ProtocolConfiguration.P2PStateExchangeExtensions = true
c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
}
bcSpout := newTestChainWithCustomCfg(t, spountCfg)
initBasicChain(t, bcSpout)
// reach next to the latest state sync point and pretend that we've just restored
stateSyncPoint := (int(bcSpout.BlockHeight())/stateSyncInterval + 1) * stateSyncInterval
for i := bcSpout.BlockHeight() + 1; i <= uint32(stateSyncPoint); i++ {
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
}
require.Equal(t, uint32(stateSyncPoint), bcSpout.BlockHeight())
b := bcSpout.newBlock()
require.NoError(t, bcSpout.AddHeaders(&b.Header))
// put storage items with STTemp prefix
batch := bcSpout.dao.Store.Batch()
bcSpout.dao.Store.Seek(storage.STStorage.Bytes(), func(k, v []byte) {
key := slice.Copy(k)
key[0] = storage.STTempStorage.Bytes()[0]
value := slice.Copy(v)
batch.Put(key, value)
})
require.NoError(t, bcSpout.dao.Store.PutBatch(batch))
checkNewBlockchainErr := func(t *testing.T, cfg func(c *config.Config), store storage.Store, shouldFail bool) {
unitTestNetCfg, err := config.Load("../../config", testchain.Network())
require.NoError(t, err)
cfg(&unitTestNetCfg)
log := zaptest.NewLogger(t)
_, err = NewBlockchain(store, unitTestNetCfg.ProtocolConfiguration, log)
if shouldFail {
require.Error(t, err)
} else {
require.NoError(t, err)
}
}
boltCfg := func(c *config.Config) {
spountCfg(c)
c.ProtocolConfiguration.KeepOnlyLatestState = true
}
// manually store statejump stage to check statejump recover process
t.Run("invalid RemoveUntraceableBlocks setting", func(t *testing.T) {
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
checkNewBlockchainErr(t, func(c *config.Config) {
boltCfg(c)
c.ProtocolConfiguration.RemoveUntraceableBlocks = false
}, bcSpout.dao.Store, true)
})
t.Run("invalid state jump stage format", func(t *testing.T) {
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{0x01, 0x02}))
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
})
t.Run("missing state sync point", func(t *testing.T) {
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
})
t.Run("invalid state sync point", func(t *testing.T) {
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)}))
point := make([]byte, 4)
binary.LittleEndian.PutUint32(point, uint32(len(bcSpout.headerHashes)))
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateSyncPoint.Bytes(), point))
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true)
})
for _, stage := range []stateJumpStage{stateJumpStarted, oldStorageItemsRemoved, newStorageItemsAdded, genesisStateRemoved, 0x03} {
t.Run(fmt.Sprintf("state jump stage %d", stage), func(t *testing.T) {
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stage)}))
point := make([]byte, 4)
binary.LittleEndian.PutUint32(point, uint32(stateSyncPoint))
require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateSyncPoint.Bytes(), point))
shouldFail := stage == 0x03 // unknown stage
checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, shouldFail)
})
}
}

View file

@ -21,7 +21,6 @@ import (
type Blockchainer interface {
ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction
GetConfig() config.ProtocolConfiguration
AddHeaders(...*block.Header) error
Blockqueuer // Blockqueuer interface
CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error)
Close()
@ -56,6 +55,7 @@ type Blockchainer interface {
GetStandByCommittee() keys.PublicKeys
GetStandByValidators() keys.PublicKeys
GetStateModule() StateRoot
GetStateSyncModule() StateSync
GetStorageItem(id int32, key []byte) state.StorageItem
GetStorageItems(id int32) (map[string]state.StorageItem, error)
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM

View file

@ -5,5 +5,6 @@ import "github.com/nspcc-dev/neo-go/pkg/core/block"
// Blockqueuer is an interface for blockqueue.
type Blockqueuer interface {
AddBlock(block *block.Block) error
AddHeaders(...*block.Header) error
BlockHeight() uint32
}

View file

@ -9,6 +9,8 @@ import (
// StateRoot represents local state root module.
type StateRoot interface {
AddStateRoot(root *state.MPTRoot) error
CleanStorage() error
CurrentLocalHeight() uint32
CurrentLocalStateRoot() util.Uint256
CurrentValidatedHeight() uint32
GetStateProof(root util.Uint256, key []byte) ([][]byte, error)

View file

@ -0,0 +1,19 @@
package blockchainer
import (
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// StateSync represents state sync module.
type StateSync interface {
AddMPTNodes([][]byte) error
Blockqueuer // Blockqueuer interface
Init(currChainHeight uint32) error
IsActive() bool
IsInitialized() bool
GetUnknownMPTNodesBatch(limit int) []util.Uint256
NeedHeaders() bool
NeedMPTNodes() bool
Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
}

314
pkg/core/mpt/billet.go Normal file
View file

@ -0,0 +1,314 @@
package mpt
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
)
var (
// ErrRestoreFailed is returned when replacing HashNode by its "unhashed"
// candidate fails.
ErrRestoreFailed = errors.New("failed to restore MPT node")
errStop = errors.New("stop condition is met")
)
// Billet is a part of MPT trie with missing hash nodes that need to be restored.
// Billet is based on the following assumptions:
// 1. Refcount can only be incremented (we don't change MPT structure during restore,
// thus don't need to decrease refcount).
// 2. Each time the part of Billet is completely restored, it is collapsed into
// HashNode.
// 3. Pair (node, path) must be restored only once. It's a duty of MPT pool to manage
// MPT paths in order to provide this assumption.
type Billet struct {
Store *storage.MemCachedStore
root Node
refcountEnabled bool
}
// NewBillet returns new billet for MPT trie restoring. It accepts a MemCachedStore
// to decouple storage errors from logic errors so that all storage errors are
// processed during `store.Persist()` at the caller. This also has the benefit,
// that every `Put` can be considered an atomic operation.
func NewBillet(rootHash util.Uint256, enableRefCount bool, store *storage.MemCachedStore) *Billet {
return &Billet{
Store: store,
root: NewHashNode(rootHash),
refcountEnabled: enableRefCount,
}
}
// RestoreHashNode replaces HashNode located at the provided path by the specified Node
// and stores it. It also maintains MPT as small as possible by collapsing those parts
// of MPT that have been completely restored.
func (b *Billet) RestoreHashNode(path []byte, node Node) error {
if _, ok := node.(*HashNode); ok {
return fmt.Errorf("%w: unable to restore node into HashNode", ErrRestoreFailed)
}
if _, ok := node.(EmptyNode); ok {
return fmt.Errorf("%w: unable to restore node into EmptyNode", ErrRestoreFailed)
}
r, err := b.putIntoNode(b.root, path, node)
if err != nil {
return err
}
b.root = r
// If it's a leaf, then put into temporary contract storage.
if leaf, ok := node.(*LeafNode); ok {
k := append([]byte{byte(storage.STTempStorage)}, fromNibbles(path)...)
_ = b.Store.Put(k, leaf.value)
}
return nil
}
// putIntoNode puts val with provided path inside curr and returns updated node.
// Reference counters are updated for both curr and returned value.
func (b *Billet) putIntoNode(curr Node, path []byte, val Node) (Node, error) {
switch n := curr.(type) {
case *LeafNode:
return b.putIntoLeaf(n, path, val)
case *BranchNode:
return b.putIntoBranch(n, path, val)
case *ExtensionNode:
return b.putIntoExtension(n, path, val)
case *HashNode:
return b.putIntoHash(n, path, val)
case EmptyNode:
return nil, fmt.Errorf("%w: can't modify EmptyNode during restore", ErrRestoreFailed)
default:
panic("invalid MPT node type")
}
}
func (b *Billet) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) {
if len(path) != 0 {
return nil, fmt.Errorf("%w: can't modify LeafNode during restore", ErrRestoreFailed)
}
if curr.Hash() != val.Hash() {
return nil, fmt.Errorf("%w: bad Leaf node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
}
// Once Leaf node is restored, it will be collapsed into HashNode forever, so
// there shouldn't be such situation when we try to restore Leaf node.
panic("bug: can't restore LeafNode")
}
func (b *Billet) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) {
if len(path) == 0 && curr.Hash().Equals(val.Hash()) {
// This node has already been restored, so it's an MPT pool duty to avoid
// duplicating restore requests.
panic("bug: can't perform restoring of BranchNode twice")
}
i, path := splitPath(path)
r, err := b.putIntoNode(curr.Children[i], path, val)
if err != nil {
return nil, err
}
curr.Children[i] = r
return b.tryCollapseBranch(curr), nil
}
func (b *Billet) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) {
if len(path) == 0 {
if curr.Hash() != val.Hash() {
return nil, fmt.Errorf("%w: bad Extension node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
}
// This node has already been restored, so it's an MPT pool duty to avoid
// duplicating restore requests.
panic("bug: can't perform restoring of ExtensionNode twice")
}
if !bytes.HasPrefix(path, curr.key) {
return nil, fmt.Errorf("%w: can't modify ExtensionNode during restore", ErrRestoreFailed)
}
r, err := b.putIntoNode(curr.next, path[len(curr.key):], val)
if err != nil {
return nil, err
}
curr.next = r
return b.tryCollapseExtension(curr), nil
}
func (b *Billet) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) {
// Once a part of MPT Billet is completely restored, it will be collapsed forever, so
// it's an MPT pool duty to avoid duplicating restore requests.
if len(path) != 0 {
return nil, fmt.Errorf("%w: node has already been collapsed", ErrRestoreFailed)
}
// `curr` hash node can be either of
// 1) saved in storage (i.g. if we've already restored node with the same hash from the
// other part of MPT), so just add it to local in-memory MPT.
// 2) missing from the storage. It's OK because we're syncing MPT state, and the purpose
// is to store missing hash nodes.
// both cases are OK, but we still need to validate `val` against `curr`.
if val.Hash() != curr.Hash() {
return nil, fmt.Errorf("%w: can't restore HashNode: expected and actual hashes mismatch (%s vs %s)", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE())
}
if curr.Collapsed {
// This node has already been restored and collapsed, so it's an MPT pool duty to avoid
// duplicating restore requests.
panic("bug: can't perform restoring of collapsed node")
}
// We also need to increment refcount in both cases. That's the only place where refcount
// is changed during restore process. Also flush right now, because sync process can be
// interrupted at any time.
b.incrementRefAndStore(val.Hash(), val.Bytes())
if val.Type() == LeafT {
return b.tryCollapseLeaf(val.(*LeafNode)), nil
}
return val, nil
}
func (b *Billet) incrementRefAndStore(h util.Uint256, bs []byte) {
key := makeStorageKey(h.BytesBE())
if b.refcountEnabled {
var (
err error
data []byte
cnt int32
)
// An item may already be in store.
data, err = b.Store.Get(key)
if err == nil {
cnt = int32(binary.LittleEndian.Uint32(data[len(data)-4:]))
}
cnt++
if len(data) == 0 {
data = append(bs, 0, 0, 0, 0)
}
binary.LittleEndian.PutUint32(data[len(data)-4:], uint32(cnt))
_ = b.Store.Put(key, data)
} else {
_ = b.Store.Put(key, bs)
}
}
// Traverse traverses MPT nodes (pre-order) starting from the billet root down
// to its children calling `process` for each serialised node until true is
// returned from `process` function. It also replaces all HashNodes to their
// "unhashed" counterparts until the stop condition is satisfied.
func (b *Billet) Traverse(process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) error {
r, err := b.traverse(b.root, process, ignoreStorageErr)
if err != nil && !errors.Is(err, errStop) {
return err
}
b.root = r
return nil
}
func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) (Node, error) {
if _, ok := curr.(EmptyNode); ok {
// We're not interested in EmptyNodes, and they do not affect the
// traversal process, thus remain them untouched.
return curr, nil
}
if hn, ok := curr.(*HashNode); ok {
r, err := b.GetFromStore(hn.Hash())
if err != nil {
if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) {
return hn, nil
}
return nil, err
}
return b.traverse(r, process, ignoreStorageErr)
}
bytes := slice.Copy(curr.Bytes())
if process(curr, bytes) {
return curr, errStop
}
switch n := curr.(type) {
case *LeafNode:
return b.tryCollapseLeaf(n), nil
case *BranchNode:
for i := range n.Children {
r, err := b.traverse(n.Children[i], process, ignoreStorageErr)
if err != nil {
if !errors.Is(err, errStop) {
return nil, err
}
n.Children[i] = r
return n, err
}
n.Children[i] = r
}
return b.tryCollapseBranch(n), nil
case *ExtensionNode:
r, err := b.traverse(n.next, process, ignoreStorageErr)
if err != nil && !errors.Is(err, errStop) {
return nil, err
}
n.next = r
return b.tryCollapseExtension(n), err
default:
return nil, ErrNotFound
}
}
func (b *Billet) tryCollapseLeaf(curr *LeafNode) Node {
// Leaf can always be collapsed.
res := NewHashNode(curr.Hash())
res.Collapsed = true
return res
}
func (b *Billet) tryCollapseExtension(curr *ExtensionNode) Node {
if !(curr.next.Type() == HashT && curr.next.(*HashNode).Collapsed) {
return curr
}
res := NewHashNode(curr.Hash())
res.Collapsed = true
return res
}
func (b *Billet) tryCollapseBranch(curr *BranchNode) Node {
canCollapse := true
for i := 0; i < childrenCount; i++ {
if curr.Children[i].Type() == EmptyT {
continue
}
if curr.Children[i].Type() == HashT && curr.Children[i].(*HashNode).Collapsed {
continue
}
canCollapse = false
break
}
if !canCollapse {
return curr
}
res := NewHashNode(curr.Hash())
res.Collapsed = true
return res
}
// GetFromStore returns MPT node from the storage.
func (b *Billet) GetFromStore(h util.Uint256) (Node, error) {
data, err := b.Store.Get(makeStorageKey(h.BytesBE()))
if err != nil {
return nil, err
}
var n NodeObject
r := io.NewBinReaderFromBuf(data)
n.DecodeBinary(r)
if r.Err != nil {
return nil, r.Err
}
if b.refcountEnabled {
data = data[:len(data)-4]
}
n.Node.(flushedNode).setCache(data, h)
return n.Node, nil
}

211
pkg/core/mpt/billet_test.go Normal file
View file

@ -0,0 +1,211 @@
package mpt
import (
"encoding/binary"
"errors"
"testing"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func TestBillet_RestoreHashNode(t *testing.T) {
check := func(t *testing.T, tr *Billet, expectedRoot Node, expectedNode Node, expectedRefCount uint32) {
_ = expectedRoot.Hash()
_ = tr.root.Hash()
require.Equal(t, expectedRoot, tr.root)
expectedBytes, err := tr.Store.Get(makeStorageKey(expectedNode.Hash().BytesBE()))
if expectedRefCount != 0 {
require.NoError(t, err)
require.Equal(t, expectedRefCount, binary.LittleEndian.Uint32(expectedBytes[len(expectedBytes)-4:]))
} else {
require.True(t, errors.Is(err, storage.ErrKeyNotFound))
}
}
t.Run("parent is Extension", func(t *testing.T) {
t.Run("restore Branch", func(t *testing.T) {
b := NewBranchNode()
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0xAB, 0xCD}))
b.Children[5] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0xAB, 0xDE}))
path := toNibbles([]byte{0xAC})
e := NewExtensionNode(path, NewHashNode(b.Hash()))
tr := NewBillet(e.Hash(), true, newTestStore())
tr.root = e
// OK
n := new(NodeObject)
n.DecodeBinary(io.NewBinReaderFromBuf(b.Bytes()))
require.NoError(t, tr.RestoreHashNode(path, n.Node))
expected := NewExtensionNode(path, n.Node)
check(t, tr, expected, n.Node, 1)
// One more time (already restored) => panic expected, no refcount changes
require.Panics(t, func() {
_ = tr.RestoreHashNode(path, n.Node)
})
check(t, tr, expected, n.Node, 1)
// Same path, but wrong hash => error expected, no refcount changes
require.True(t, errors.Is(tr.RestoreHashNode(path, NewBranchNode()), ErrRestoreFailed))
check(t, tr, expected, n.Node, 1)
// New path (changes in the MPT structure are not allowed) => error expected, no refcount changes
require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAB}), n.Node), ErrRestoreFailed))
check(t, tr, expected, n.Node, 1)
})
t.Run("restore Leaf", func(t *testing.T) {
l := NewLeafNode([]byte{0xAB, 0xCD})
path := toNibbles([]byte{0xAC})
e := NewExtensionNode(path, NewHashNode(l.Hash()))
tr := NewBillet(e.Hash(), true, newTestStore())
tr.root = e
// OK
require.NoError(t, tr.RestoreHashNode(path, l))
expected := NewHashNode(e.Hash()) // leaf should be collapsed immediately => extension should also be collapsed
expected.Collapsed = true
check(t, tr, expected, l, 1)
// One more time (already restored and collapsed) => error expected, no refcount changes
require.Error(t, tr.RestoreHashNode(path, l))
check(t, tr, expected, l, 1)
// Same path, but wrong hash => error expected, no refcount changes
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
check(t, tr, expected, l, 1)
// New path (changes in the MPT structure are not allowed) => error expected, no refcount changes
require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAB}), l), ErrRestoreFailed))
check(t, tr, expected, l, 1)
})
t.Run("restore Hash", func(t *testing.T) {
h := NewHashNode(util.Uint256{1, 2, 3})
path := toNibbles([]byte{0xAC})
e := NewExtensionNode(path, h)
tr := NewBillet(e.Hash(), true, newTestStore())
tr.root = e
// no-op
require.True(t, errors.Is(tr.RestoreHashNode(path, h), ErrRestoreFailed))
check(t, tr, e, h, 0)
})
})
t.Run("parent is Leaf", func(t *testing.T) {
l := NewLeafNode([]byte{0xAB, 0xCD})
path := []byte{}
tr := NewBillet(l.Hash(), true, newTestStore())
tr.root = l
// Already restored => panic expected
require.Panics(t, func() {
_ = tr.RestoreHashNode(path, l)
})
// Same path, but wrong hash => error expected, no refcount changes
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
// Non-nil path, but MPT structure can't be changed => error expected, no refcount changes
require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAC}), NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed))
})
t.Run("parent is Branch", func(t *testing.T) {
t.Run("middle child", func(t *testing.T) {
l1 := NewLeafNode([]byte{0xAB, 0xCD})
l2 := NewLeafNode([]byte{0xAB, 0xDE})
b := NewBranchNode()
b.Children[5] = NewHashNode(l1.Hash())
b.Children[lastChild] = NewHashNode(l2.Hash())
tr := NewBillet(b.Hash(), true, newTestStore())
tr.root = b
// OK
path := []byte{0x05}
require.NoError(t, tr.RestoreHashNode(path, l1))
check(t, tr, b, l1, 1)
// One more time (already restored) => panic expected.
// It's an MPT pool duty to avoid such situations during real restore process.
require.Panics(t, func() {
_ = tr.RestoreHashNode(path, l1)
})
// No refcount changes expected.
check(t, tr, b, l1, 1)
// Same path, but wrong hash => error expected, no refcount changes
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAD})), ErrRestoreFailed))
check(t, tr, b, l1, 1)
// New path pointing to the empty HashNode (changes in the MPT structure are not allowed) => error expected, no refcount changes
require.True(t, errors.Is(tr.RestoreHashNode([]byte{0x01}, l1), ErrRestoreFailed))
check(t, tr, b, l1, 1)
})
t.Run("last child", func(t *testing.T) {
l1 := NewLeafNode([]byte{0xAB, 0xCD})
l2 := NewLeafNode([]byte{0xAB, 0xDE})
b := NewBranchNode()
b.Children[5] = NewHashNode(l1.Hash())
b.Children[lastChild] = NewHashNode(l2.Hash())
tr := NewBillet(b.Hash(), true, newTestStore())
tr.root = b
// OK
path := []byte{}
require.NoError(t, tr.RestoreHashNode(path, l2))
check(t, tr, b, l2, 1)
// One more time (already restored) => panic expected.
// It's an MPT pool duty to avoid such situations during real restore process.
require.Panics(t, func() {
_ = tr.RestoreHashNode(path, l2)
})
// No refcount changes expected.
check(t, tr, b, l2, 1)
// Same path, but wrong hash => error expected, no refcount changes
require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAD})), ErrRestoreFailed))
check(t, tr, b, l2, 1)
})
t.Run("two children with same hash", func(t *testing.T) {
l := NewLeafNode([]byte{0xAB, 0xCD})
b := NewBranchNode()
// two same hashnodes => leaf's refcount expected to be 2 in the end.
b.Children[3] = NewHashNode(l.Hash())
b.Children[4] = NewHashNode(l.Hash())
tr := NewBillet(b.Hash(), true, newTestStore())
tr.root = b
// OK
require.NoError(t, tr.RestoreHashNode([]byte{0x03}, l))
expected := b
expected.Children[3].(*HashNode).Collapsed = true
check(t, tr, b, l, 1)
// Restore another node with the same hash => no error expected, refcount should be incremented.
// Branch node should be collapsed.
require.NoError(t, tr.RestoreHashNode([]byte{0x04}, l))
res := NewHashNode(b.Hash())
res.Collapsed = true
check(t, tr, res, l, 2)
})
})
t.Run("parent is Hash", func(t *testing.T) {
l := NewLeafNode([]byte{0xAB, 0xCD})
b := NewBranchNode()
b.Children[3] = NewHashNode(l.Hash())
b.Children[4] = NewHashNode(l.Hash())
tr := NewBillet(b.Hash(), true, newTestStore())
// Should fail, because if it's a hash node with non-empty path, then the node
// has already been collapsed.
require.Error(t, tr.RestoreHashNode([]byte{0x03}, l))
})
}

View file

@ -89,6 +89,12 @@ func (b *BranchNode) UnmarshalJSON(data []byte) error {
return errors.New("expected branch node")
}
// Clone implements Node interface.
func (b *BranchNode) Clone() Node {
res := *b
return &res
}
// splitPath splits path for a branch node.
func splitPath(path []byte) (byte, []byte) {
if len(path) != 0 {

View file

@ -54,3 +54,6 @@ func (e EmptyNode) Type() NodeType {
func (e EmptyNode) Bytes() []byte {
return nil
}
// Clone implements Node interface.
func (EmptyNode) Clone() Node { return EmptyNode{} }

View file

@ -98,3 +98,9 @@ func (e *ExtensionNode) UnmarshalJSON(data []byte) error {
}
return errors.New("expected extension node")
}
// Clone implements Node interface.
func (e *ExtensionNode) Clone() Node {
res := *e
return &res
}

View file

@ -10,6 +10,7 @@ import (
// HashNode represents MPT's hash node.
type HashNode struct {
BaseNode
Collapsed bool
}
var _ Node = (*HashNode)(nil)
@ -76,3 +77,10 @@ func (h *HashNode) UnmarshalJSON(data []byte) error {
}
return errors.New("expected hash node")
}
// Clone implements Node interface.
func (h *HashNode) Clone() Node {
res := *h
res.Collapsed = false
return &res
}

View file

@ -1,5 +1,7 @@
package mpt
import "github.com/nspcc-dev/neo-go/pkg/util"
// lcp returns longest common prefix of a and b.
// Note: it does no allocations.
func lcp(a, b []byte) []byte {
@ -40,3 +42,45 @@ func toNibbles(path []byte) []byte {
}
return result
}
// fromNibbles performs operation opposite to toNibbles and does no path validity checks.
func fromNibbles(path []byte) []byte {
result := make([]byte, len(path)/2)
for i := range result {
result[i] = path[2*i]<<4 + path[2*i+1]
}
return result
}
// GetChildrenPaths returns a set of paths to node's children who are non-empty HashNodes
// based on the node's path.
func GetChildrenPaths(path []byte, node Node) map[util.Uint256][][]byte {
res := make(map[util.Uint256][][]byte)
switch n := node.(type) {
case *LeafNode, *HashNode, EmptyNode:
return nil
case *BranchNode:
for i, child := range n.Children {
if child.Type() == HashT {
cPath := make([]byte, len(path), len(path)+1)
copy(cPath, path)
if i != lastChild {
cPath = append(cPath, byte(i))
}
paths := res[child.Hash()]
paths = append(paths, cPath)
res[child.Hash()] = paths
}
}
case *ExtensionNode:
if n.next.Type() == HashT {
cPath := make([]byte, len(path)+len(n.key))
copy(cPath, path)
copy(cPath[len(path):], n.key)
res[n.next.Hash()] = [][]byte{cPath}
}
default:
panic("unknown Node type")
}
return res
}

View file

@ -0,0 +1,67 @@
package mpt
import (
"testing"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func TestToNibblesFromNibbles(t *testing.T) {
check := func(t *testing.T, expected []byte) {
actual := fromNibbles(toNibbles(expected))
require.Equal(t, expected, actual)
}
t.Run("empty path", func(t *testing.T) {
check(t, []byte{})
})
t.Run("non-empty path", func(t *testing.T) {
check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF})
})
}
func TestGetChildrenPaths(t *testing.T) {
h1 := NewHashNode(util.Uint256{1, 2, 3})
h2 := NewHashNode(util.Uint256{4, 5, 6})
h3 := NewHashNode(util.Uint256{7, 8, 9})
l := NewLeafNode([]byte{1, 2, 3})
ext1 := NewExtensionNode([]byte{8, 9}, h1)
ext2 := NewExtensionNode([]byte{7, 6}, l)
branch := NewBranchNode()
branch.Children[3] = h1
branch.Children[5] = l
branch.Children[6] = h1 // 3-th and 6-th children have the same hash
branch.Children[7] = h3
branch.Children[lastChild] = h2
testCases := map[string]struct {
node Node
expected map[util.Uint256][][]byte
}{
"Hash": {h1, nil},
"Leaf": {l, nil},
"Extension with next Hash": {ext1, map[util.Uint256][][]byte{h1.Hash(): {ext1.key}}},
"Extension with next non-Hash": {ext2, map[util.Uint256][][]byte{}},
"Branch": {branch, map[util.Uint256][][]byte{
h1.Hash(): {{0x03}, {0x06}},
h2.Hash(): {{}},
h3.Hash(): {{0x07}},
}},
}
parentPath := []byte{4, 5, 6}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
require.Equal(t, testCase.expected, GetChildrenPaths([]byte{}, testCase.node))
if testCase.expected != nil {
expectedWithPrefix := make(map[util.Uint256][][]byte, len(testCase.expected))
for h, paths := range testCase.expected {
var res [][]byte
for _, path := range paths {
res = append(res, append(parentPath, path...))
}
expectedWithPrefix[h] = res
}
require.Equal(t, expectedWithPrefix, GetChildrenPaths(parentPath, testCase.node))
}
})
}
}

View file

@ -77,3 +77,9 @@ func (n *LeafNode) UnmarshalJSON(data []byte) error {
}
return errors.New("expected leaf node")
}
// Clone implements Node interface.
func (n *LeafNode) Clone() Node {
res := *n
return &res
}

View file

@ -34,6 +34,7 @@ type Node interface {
json.Marshaler
json.Unmarshaler
Size() int
Clone() Node
BaseNodeIface
}

View file

@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/require"
)
func newProofTrie(t *testing.T) *Trie {
func newProofTrie(t *testing.T, missingHashNode bool) *Trie {
l := NewLeafNode([]byte("somevalue"))
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
l2 := NewLeafNode([]byte("invalid"))
@ -20,11 +20,14 @@ func newProofTrie(t *testing.T) *Trie {
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
tr.putToStore(l)
tr.putToStore(e)
if !missingHashNode {
tr.putToStore(l2)
}
return tr
}
func TestTrie_GetProof(t *testing.T) {
tr := newProofTrie(t)
tr := newProofTrie(t, true)
t.Run("MissingKey", func(t *testing.T) {
_, err := tr.GetProof([]byte{0x12})
@ -43,7 +46,7 @@ func TestTrie_GetProof(t *testing.T) {
}
func TestVerifyProof(t *testing.T) {
tr := newProofTrie(t)
tr := newProofTrie(t, true)
t.Run("Simple", func(t *testing.T) {
proof, err := tr.GetProof([]byte{0x12, 0x32})

View file

@ -353,3 +353,8 @@ func (s *Designate) getRole(item stackitem.Item) (noderoles.Role, bool) {
u := bi.Uint64()
return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u))
}
// InitializeCache invalidates native Designate cache.
func (s *Designate) InitializeCache() {
s.rolesChangedFlag.Store(true)
}

View file

@ -13,6 +13,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
"go.uber.org/atomic"
"go.uber.org/zap"
)
@ -114,6 +115,66 @@ func (s *Module) Init(height uint32, enableRefCount bool) error {
return nil
}
// CleanStorage removes all MPT-related data from the storage (MPT nodes, validated stateroots)
// except local stateroot for the current height and GC flag. This method is aimed to clean
// outdated MPT data before state sync process can be started.
// Note: this method is aimed to be called for genesis block only, an error is returned otherwice.
func (s *Module) CleanStorage() error {
if s.localHeight.Load() != 0 {
return fmt.Errorf("can't clean MPT data for non-genesis block: expected local stateroot height 0, got %d", s.localHeight.Load())
}
gcKey := []byte{byte(storage.DataMPT), prefixGC}
gcVal, err := s.Store.Get(gcKey)
if err != nil {
return fmt.Errorf("failed to get GC flag: %w", err)
}
//
b := s.Store.Batch()
s.Store.Seek([]byte{byte(storage.DataMPT)}, func(k, _ []byte) {
// Must copy here, #1468.
key := slice.Copy(k)
b.Delete(key)
})
err = s.Store.PutBatch(b)
if err != nil {
return fmt.Errorf("failed to remove outdated MPT-reated items: %w", err)
}
err = s.Store.Put(gcKey, gcVal)
if err != nil {
return fmt.Errorf("failed to store GC flag: %w", err)
}
currentLocal := s.currentLocal.Load().(util.Uint256)
if !currentLocal.Equals(util.Uint256{}) {
err := s.addLocalStateRoot(s.Store, &state.MPTRoot{
Index: s.localHeight.Load(),
Root: currentLocal,
})
if err != nil {
return fmt.Errorf("failed to store current local stateroot: %w", err)
}
}
return nil
}
// JumpToState performs jump to the state specified by given stateroot index.
func (s *Module) JumpToState(sr *state.MPTRoot, enableRefCount bool) error {
if err := s.addLocalStateRoot(s.Store, sr); err != nil {
return fmt.Errorf("failed to store local state root: %w", err)
}
data := make([]byte, 4)
binary.LittleEndian.PutUint32(data, sr.Index)
if err := s.Store.Put([]byte{byte(storage.DataMPT), prefixValidated}, data); err != nil {
return fmt.Errorf("failed to store validated height: %w", err)
}
s.validatedHeight.Store(sr.Index)
s.currentLocal.Store(sr.Root)
s.localHeight.Store(sr.Index)
s.mpt = mpt.NewTrie(mpt.NewHashNode(sr.Root), enableRefCount, s.Store)
return nil
}
// AddMPTBatch updates using provided batch.
func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) {
mpt := *s.mpt

View file

@ -0,0 +1,479 @@
/*
Package statesync implements module for the P2P state synchronisation process. The
module manages state synchronisation for non-archival nodes which are joining the
network and don't have the ability to resync from the genesis block.
Given the currently available state synchronisation point P, sate sync process
includes the following stages:
1. Fetching headers starting from height 0 up to P+1.
2. Fetching MPT nodes for height P stating from the corresponding state root.
3. Fetching blocks starting from height P-MaxTraceableBlocks (or 0) up to P.
Steps 2 and 3 are being performed in parallel. Once all the data are collected
and stored in the db, an atomic state jump is occurred to the state sync point P.
Further node operation process is performed using standard sync mechanism until
the node reaches synchronised state.
*/
package statesync
import (
"encoding/hex"
"errors"
"fmt"
"sync"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
"go.uber.org/zap"
)
// stateSyncStage is a type of state synchronisation stage.
type stateSyncStage uint8
const (
// inactive means that state exchange is disabled by the protocol configuration.
// Can't be combined with other states.
inactive stateSyncStage = 1 << iota
// none means that state exchange is enabled in the configuration, but
// initialisation of the state sync module wasn't yet performed, i.e.
// (*Module).Init wasn't called. Can't be combined with other states.
none
// initialized means that (*Module).Init was called, but other sync stages
// are not yet reached (i.e. that headers are requested, but not yet fetched).
// Can't be combined with other states.
initialized
// headersSynced means that headers for the current state sync point are fetched.
// May be combined with mptSynced and/or blocksSynced.
headersSynced
// mptSynced means that MPT nodes for the current state sync point are fetched.
// Always combined with headersSynced; may be combined with blocksSynced.
mptSynced
// blocksSynced means that blocks up to the current state sync point are stored.
// Always combined with headersSynced; may be combined with mptSynced.
blocksSynced
)
// Module represents state sync module and aimed to gather state-related data to
// perform an atomic state jump.
type Module struct {
lock sync.RWMutex
log *zap.Logger
// syncPoint is the state synchronisation point P we're currently working against.
syncPoint uint32
// syncStage is the stage of the sync process.
syncStage stateSyncStage
// syncInterval is the delta between two adjacent state sync points.
syncInterval uint32
// blockHeight is the index of the latest stored block.
blockHeight uint32
dao *dao.Simple
bc blockchainer.Blockchainer
mptpool *Pool
billet *mpt.Billet
jumpCallback func(p uint32) error
}
// NewModule returns new instance of statesync module.
func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple, jumpCallback func(p uint32) error) *Module {
if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) {
return &Module{
dao: s,
bc: bc,
syncStage: inactive,
}
}
return &Module{
dao: s,
bc: bc,
log: log,
syncInterval: uint32(bc.GetConfig().StateSyncInterval),
mptpool: NewPool(),
syncStage: none,
jumpCallback: jumpCallback,
}
}
// Init initializes state sync module for the current chain's height with given
// callback for MPT nodes requests.
func (s *Module) Init(currChainHeight uint32) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.syncStage != none {
return errors.New("already initialized or inactive")
}
p := (currChainHeight / s.syncInterval) * s.syncInterval
if p < 2*s.syncInterval {
// chain is too low to start state exchange process, use the standard sync mechanism
s.syncStage = inactive
return nil
}
pOld, err := s.dao.GetStateSyncPoint()
if err == nil && pOld >= p-s.syncInterval {
// old point is still valid, so try to resync states for this point.
p = pOld
} else {
if s.bc.BlockHeight() > p-2*s.syncInterval {
// chain has already been synchronised up to old state sync point and regular blocks processing was started.
// Current block height is enough to start regular blocks processing.
s.syncStage = inactive
return nil
}
if err == nil {
// pOld was found, it is outdated, and chain wasn't completely synchronised for pOld. Need to drop the db.
return fmt.Errorf("state sync point %d is found in the storage, "+
"but sync process wasn't completed and point is outdated. Please, drop the database manually and restart the node to run state sync process", pOld)
}
if s.bc.BlockHeight() != 0 {
// pOld wasn't found, but blocks processing was started in a regular manner and latest stored block is too outdated
// to start regular blocks processing again. Need to drop the db.
return fmt.Errorf("current chain's height is too low to start regular blocks processing from the oldest sync point %d. "+
"Please, drop the database manually and restart the node to run state sync process", p-s.syncInterval)
}
// We've reached this point, so chain has genesis block only. As far as we can't ruin
// current chain's state until new state is completely fetched, outdated state-related data
// will be removed from storage during (*Blockchain).jumpToState(...) execution.
// All we need to do right now is to remove genesis-related MPT nodes.
err = s.bc.GetStateModule().CleanStorage()
if err != nil {
return fmt.Errorf("failed to remove outdated MPT data from storage: %w", err)
}
}
s.syncPoint = p
err = s.dao.PutStateSyncPoint(p)
if err != nil {
return fmt.Errorf("failed to store state synchronisation point %d: %w", p, err)
}
s.syncStage = initialized
s.log.Info("try to sync state for the latest state synchronisation point",
zap.Uint32("point", p),
zap.Uint32("evaluated chain's blockHeight", currChainHeight))
return s.defineSyncStage()
}
// defineSyncStage sequentially checks and sets sync state process stage after Module
// initialization. It also performs initialization of MPT Billet if necessary.
func (s *Module) defineSyncStage() error {
// check headers sync stage first
ltstHeaderHeight := s.bc.HeaderHeight()
if ltstHeaderHeight > s.syncPoint {
s.syncStage = headersSynced
s.log.Info("headers are in sync",
zap.Uint32("headerHeight", s.bc.HeaderHeight()))
}
// check blocks sync stage
s.blockHeight = s.getLatestSavedBlock(s.syncPoint)
if s.blockHeight >= s.syncPoint {
s.syncStage |= blocksSynced
s.log.Info("blocks are in sync",
zap.Uint32("blockHeight", s.blockHeight))
}
// check MPT sync stage
if s.blockHeight > s.syncPoint {
s.syncStage |= mptSynced
s.log.Info("MPT is in sync",
zap.Uint32("stateroot height", s.bc.GetStateModule().CurrentLocalHeight()))
} else if s.syncStage&headersSynced != 0 {
header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint + 1)))
if err != nil {
return fmt.Errorf("failed to get header to initialize MPT billet: %w", err)
}
s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store)
s.log.Info("MPT billet initialized",
zap.Uint32("height", s.syncPoint),
zap.String("state root", header.PrevStateRoot.StringBE()))
pool := NewPool()
pool.Add(header.PrevStateRoot, []byte{})
err = s.billet.Traverse(func(n mpt.Node, _ []byte) bool {
nPaths, ok := pool.TryGet(n.Hash())
if !ok {
// if this situation occurs, then it's a bug in MPT pool or Traverse.
panic("failed to get MPT node from the pool")
}
pool.Remove(n.Hash())
childrenPaths := make(map[util.Uint256][][]byte)
for _, path := range nPaths {
nChildrenPaths := mpt.GetChildrenPaths(path, n)
for hash, paths := range nChildrenPaths {
childrenPaths[hash] = append(childrenPaths[hash], paths...) // it's OK to have duplicates, they'll be handled by mempool
}
}
pool.Update(nil, childrenPaths)
return false
}, true)
if err != nil {
return fmt.Errorf("failed to traverse MPT during initialization: %w", err)
}
s.mptpool.Update(nil, pool.GetAll())
if s.mptpool.Count() == 0 {
s.syncStage |= mptSynced
s.log.Info("MPT is in sync",
zap.Uint32("stateroot height", s.syncPoint))
}
}
if s.syncStage == headersSynced|blocksSynced|mptSynced {
s.log.Info("state is in sync, starting regular blocks processing")
s.syncStage = inactive
}
return nil
}
// getLatestSavedBlock returns either current block index (if it's still relevant
// to continue state sync process) or H-1 where H is the index of the earliest
// block that should be saved next.
func (s *Module) getLatestSavedBlock(p uint32) uint32 {
var result uint32
mtb := s.bc.GetConfig().MaxTraceableBlocks
if p > mtb {
result = p - mtb
}
storedH, err := s.dao.GetStateSyncCurrentBlockHeight()
if err == nil && storedH > result {
result = storedH
}
actualH := s.bc.BlockHeight()
if actualH > result {
result = actualH
}
return result
}
// AddHeaders validates and adds specified headers to the chain.
func (s *Module) AddHeaders(hdrs ...*block.Header) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.syncStage != initialized {
return errors.New("headers were not requested")
}
hdrsErr := s.bc.AddHeaders(hdrs...)
if s.bc.HeaderHeight() > s.syncPoint {
err := s.defineSyncStage()
if err != nil {
return fmt.Errorf("failed to define current sync stage: %w", err)
}
}
return hdrsErr
}
// AddBlock verifies and saves block skipping executable scripts.
func (s *Module) AddBlock(block *block.Block) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.syncStage&headersSynced == 0 || s.syncStage&blocksSynced != 0 {
return nil
}
if s.blockHeight == s.syncPoint {
return nil
}
expectedHeight := s.blockHeight + 1
if expectedHeight != block.Index {
return fmt.Errorf("expected %d, got %d: invalid block index", expectedHeight, block.Index)
}
if s.bc.GetConfig().StateRootInHeader != block.StateRootEnabled {
return fmt.Errorf("stateroot setting mismatch: %v != %v", s.bc.GetConfig().StateRootInHeader, block.StateRootEnabled)
}
if s.bc.GetConfig().VerifyBlocks {
merkle := block.ComputeMerkleRoot()
if !block.MerkleRoot.Equals(merkle) {
return errors.New("invalid block: MerkleRoot mismatch")
}
}
cache := s.dao.GetWrapped()
writeBuf := io.NewBufBinWriter()
if err := cache.StoreAsBlock(block, writeBuf); err != nil {
return err
}
writeBuf.Reset()
err := cache.PutStateSyncCurrentBlockHeight(block.Index)
if err != nil {
return fmt.Errorf("failed to store current block height: %w", err)
}
for _, tx := range block.Transactions {
if err := cache.StoreAsTransaction(tx, block.Index, writeBuf); err != nil {
return err
}
writeBuf.Reset()
}
_, err = cache.Persist()
if err != nil {
return fmt.Errorf("failed to persist results: %w", err)
}
s.blockHeight = block.Index
if s.blockHeight == s.syncPoint {
s.syncStage |= blocksSynced
s.log.Info("blocks are in sync",
zap.Uint32("blockHeight", s.blockHeight))
s.checkSyncIsCompleted()
}
return nil
}
// AddMPTNodes tries to add provided set of MPT nodes to the MPT billet if they are
// not yet collected.
func (s *Module) AddMPTNodes(nodes [][]byte) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.syncStage&headersSynced == 0 || s.syncStage&mptSynced != 0 {
return errors.New("MPT nodes were not requested")
}
for _, nBytes := range nodes {
var n mpt.NodeObject
r := io.NewBinReaderFromBuf(nBytes)
n.DecodeBinary(r)
if r.Err != nil {
return fmt.Errorf("failed to decode MPT node: %w", r.Err)
}
err := s.restoreNode(n.Node)
if err != nil {
return err
}
}
if s.mptpool.Count() == 0 {
s.syncStage |= mptSynced
s.log.Info("MPT is in sync",
zap.Uint32("height", s.syncPoint))
s.checkSyncIsCompleted()
}
return nil
}
func (s *Module) restoreNode(n mpt.Node) error {
nPaths, ok := s.mptpool.TryGet(n.Hash())
if !ok {
// it can easily happen after receiving the same data from different peers.
return nil
}
var childrenPaths = make(map[util.Uint256][][]byte)
for _, path := range nPaths {
// Must clone here in order to avoid future collapse collisions. If the node's refcount>1 then MPT pool
// will manage all paths for this node and call RestoreHashNode separately for each of the paths.
err := s.billet.RestoreHashNode(path, n.Clone())
if err != nil {
return fmt.Errorf("failed to restore MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err)
}
for h, paths := range mpt.GetChildrenPaths(path, n) {
childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool
}
}
s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths)
for h := range childrenPaths {
if child, err := s.billet.GetFromStore(h); err == nil {
// child is already in the storage, so we don't need to request it one more time.
err = s.restoreNode(child)
if err != nil {
return fmt.Errorf("unable to restore saved children: %w", err)
}
}
}
return nil
}
// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1
// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored.
// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller
// should take care of it.
func (s *Module) checkSyncIsCompleted() {
if s.syncStage != headersSynced|mptSynced|blocksSynced {
return
}
s.log.Info("state is in sync",
zap.Uint32("state sync point", s.syncPoint))
err := s.jumpCallback(s.syncPoint)
if err != nil {
s.log.Fatal("failed to jump to the latest state sync point", zap.Error(err))
}
s.syncStage = inactive
s.dispose()
}
func (s *Module) dispose() {
s.billet = nil
}
// BlockHeight returns index of the last stored block.
func (s *Module) BlockHeight() uint32 {
s.lock.RLock()
defer s.lock.RUnlock()
return s.blockHeight
}
// IsActive tells whether state sync module is on and still gathering state
// synchronisation data (headers, blocks or MPT nodes).
func (s *Module) IsActive() bool {
s.lock.RLock()
defer s.lock.RUnlock()
return !(s.syncStage == inactive || (s.syncStage == headersSynced|mptSynced|blocksSynced))
}
// IsInitialized tells whether state sync module does not require initialization.
// If `false` is returned then Init can be safely called.
func (s *Module) IsInitialized() bool {
s.lock.RLock()
defer s.lock.RUnlock()
return s.syncStage != none
}
// NeedHeaders tells whether the module hasn't completed headers synchronisation.
func (s *Module) NeedHeaders() bool {
s.lock.RLock()
defer s.lock.RUnlock()
return s.syncStage == initialized
}
// NeedMPTNodes returns whether the module hasn't completed MPT synchronisation.
func (s *Module) NeedMPTNodes() bool {
s.lock.RLock()
defer s.lock.RUnlock()
return s.syncStage&headersSynced != 0 && s.syncStage&mptSynced == 0
}
// Traverse traverses local MPT nodes starting from the specified root down to its
// children calling `process` for each serialised node until stop condition is satisfied.
func (s *Module) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
s.lock.RLock()
defer s.lock.RUnlock()
b := mpt.NewBillet(root, s.bc.GetConfig().KeepOnlyLatestState, storage.NewMemCachedStore(s.dao.Store))
return b.Traverse(process, false)
}
// GetUnknownMPTNodesBatch returns set of currently unknown MPT nodes (`limit` at max).
func (s *Module) GetUnknownMPTNodesBatch(limit int) []util.Uint256 {
s.lock.RLock()
defer s.lock.RUnlock()
return s.mptpool.GetBatch(limit)
}

View file

@ -0,0 +1,106 @@
package statesync
import (
"fmt"
"testing"
"github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
)
func TestModule_PR2019_discussion_r689629704(t *testing.T) {
expectedStorage := storage.NewMemCachedStore(storage.NewMemoryStore())
tr := mpt.NewTrie(nil, true, expectedStorage)
require.NoError(t, tr.Put([]byte{0x03}, []byte("leaf1")))
require.NoError(t, tr.Put([]byte{0x01, 0xab, 0x02}, []byte("leaf2")))
require.NoError(t, tr.Put([]byte{0x01, 0xab, 0x04}, []byte("leaf3")))
require.NoError(t, tr.Put([]byte{0x06, 0x01, 0xde, 0x02}, []byte("leaf2"))) // <-- the same `leaf2` and `leaf3` values are put in the storage,
require.NoError(t, tr.Put([]byte{0x06, 0x01, 0xde, 0x04}, []byte("leaf3"))) // <-- but the path should differ.
require.NoError(t, tr.Put([]byte{0x06, 0x03}, []byte("leaf4")))
sr := tr.StateRoot()
tr.Flush()
// Keep MPT nodes in a map in order not to repeat them. We'll use `nodes` map to ask
// state sync module to restore the nodes.
var (
nodes = make(map[util.Uint256][]byte)
expectedItems []storage.KeyValue
)
expectedStorage.Seek(storage.DataMPT.Bytes(), func(k, v []byte) {
key := slice.Copy(k)
value := slice.Copy(v)
expectedItems = append(expectedItems, storage.KeyValue{
Key: key,
Value: value,
})
hash, err := util.Uint256DecodeBytesBE(key[1:])
require.NoError(t, err)
nodeBytes := value[:len(value)-4]
nodes[hash] = nodeBytes
})
actualStorage := storage.NewMemCachedStore(storage.NewMemoryStore())
// These actions are done in module.Init(), but it's not the point of the test.
// Here we want to test only MPT restoring process.
stateSync := &Module{
log: zaptest.NewLogger(t),
syncPoint: 1000500,
syncStage: headersSynced,
syncInterval: 100500,
dao: dao.NewSimple(actualStorage, true, false),
mptpool: NewPool(),
}
stateSync.billet = mpt.NewBillet(sr, true, actualStorage)
stateSync.mptpool.Add(sr, []byte{})
// The test itself: we'll ask state sync module to restore each node exactly once.
// After that storage content (including storage items and refcounts) must
// match exactly the one got from real MPT trie. MPT pool must be empty.
// State sync module must have mptSynced state in the end.
// MPT Billet root must become a collapsed hashnode (it was checked manually).
requested := make(map[util.Uint256]struct{})
for {
unknownHashes := stateSync.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
if len(unknownHashes) == 0 {
break
}
h := unknownHashes[0]
node, ok := nodes[h]
if !ok {
if _, ok = requested[h]; ok {
t.Fatal("node was requested twice")
}
t.Fatal("unknown node was requested")
}
require.NotPanics(t, func() {
err := stateSync.AddMPTNodes([][]byte{node})
require.NoError(t, err)
}, fmt.Errorf("hash=%s, value=%s", h.StringBE(), string(node)))
requested[h] = struct{}{}
delete(nodes, h)
if len(nodes) == 0 {
break
}
}
require.Equal(t, headersSynced|mptSynced, stateSync.syncStage, "all nodes were sent exactly ones, but MPT wasn't restored")
require.Equal(t, 0, len(nodes), "not all nodes were requested by state sync module")
require.Equal(t, 0, stateSync.mptpool.Count(), "MPT was restored, but MPT pool still contains items")
// Compare resulting storage items and refcounts.
var actualItems []storage.KeyValue
expectedStorage.Seek(storage.DataMPT.Bytes(), func(k, v []byte) {
key := slice.Copy(k)
value := slice.Copy(v)
actualItems = append(actualItems, storage.KeyValue{
Key: key,
Value: value,
})
})
require.ElementsMatch(t, expectedItems, actualItems)
}

View file

@ -0,0 +1,142 @@
package statesync
import (
"bytes"
"sort"
"sync"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// Pool stores unknown MPT nodes along with the corresponding paths (single node is
// allowed to have multiple MPT paths).
type Pool struct {
lock sync.RWMutex
hashes map[util.Uint256][][]byte
}
// NewPool returns new MPT node hashes pool.
func NewPool() *Pool {
return &Pool{
hashes: make(map[util.Uint256][][]byte),
}
}
// ContainsKey checks if MPT node with the specified hash is in the Pool.
func (mp *Pool) ContainsKey(hash util.Uint256) bool {
mp.lock.RLock()
defer mp.lock.RUnlock()
_, ok := mp.hashes[hash]
return ok
}
// TryGet returns a set of MPT paths for the specified HashNode.
func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) {
mp.lock.RLock()
defer mp.lock.RUnlock()
paths, ok := mp.hashes[hash]
// need to copy here, because we can modify existing array of paths inside the pool.
res := make([][]byte, len(paths))
copy(res, paths)
return res, ok
}
// GetAll returns all MPT nodes with the corresponding paths from the pool.
func (mp *Pool) GetAll() map[util.Uint256][][]byte {
mp.lock.RLock()
defer mp.lock.RUnlock()
return mp.hashes
}
// GetBatch returns set of unknown MPT nodes hashes (`limit` at max).
func (mp *Pool) GetBatch(limit int) []util.Uint256 {
mp.lock.RLock()
defer mp.lock.RUnlock()
count := len(mp.hashes)
if count > limit {
count = limit
}
result := make([]util.Uint256, 0, limit)
for h := range mp.hashes {
if count == 0 {
break
}
result = append(result, h)
count--
}
return result
}
// Remove removes MPT node from the pool by the specified hash.
func (mp *Pool) Remove(hash util.Uint256) {
mp.lock.Lock()
defer mp.lock.Unlock()
delete(mp.hashes, hash)
}
// Add adds path to the set of paths for the specified node.
func (mp *Pool) Add(hash util.Uint256, path []byte) {
mp.lock.Lock()
defer mp.lock.Unlock()
mp.addPaths(hash, [][]byte{path})
}
// Update is an atomic operation and removes/adds specified nodes from/to the pool.
func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) {
mp.lock.Lock()
defer mp.lock.Unlock()
for h, paths := range remove {
old := mp.hashes[h]
for _, path := range paths {
i := sort.Search(len(old), func(i int) bool {
return bytes.Compare(old[i], path) >= 0
})
if i < len(old) && bytes.Equal(old[i], path) {
old = append(old[:i], old[i+1:]...)
}
}
if len(old) == 0 {
delete(mp.hashes, h)
} else {
mp.hashes[h] = old
}
}
for h, paths := range add {
mp.addPaths(h, paths)
}
}
// addPaths adds set of the specified node paths to the pool.
func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) {
old := mp.hashes[nodeHash]
for _, path := range paths {
i := sort.Search(len(old), func(i int) bool {
return bytes.Compare(old[i], path) >= 0
})
if i < len(old) && bytes.Equal(old[i], path) {
// then path is already added
continue
}
old = append(old, path)
if i != len(old)-1 {
copy(old[i+1:], old[i:])
old[i] = path
}
}
mp.hashes[nodeHash] = old
}
// Count returns the number of nodes in the pool.
func (mp *Pool) Count() int {
mp.lock.RLock()
defer mp.lock.RUnlock()
return len(mp.hashes)
}

View file

@ -0,0 +1,123 @@
package statesync
import (
"encoding/hex"
"testing"
"github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func TestPool_AddRemoveUpdate(t *testing.T) {
mp := NewPool()
i1 := []byte{1, 2, 3}
i1h := util.Uint256{1, 2, 3}
i2 := []byte{2, 3, 4}
i2h := util.Uint256{2, 3, 4}
i3 := []byte{4, 5, 6}
i3h := util.Uint256{3, 4, 5}
i4 := []byte{3, 4, 5} // has the same hash as i3
i5 := []byte{6, 7, 8} // has the same hash as i3
mapAll := map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}, i3h: {i4, i3}}
// No items
_, ok := mp.TryGet(i1h)
require.False(t, ok)
require.False(t, mp.ContainsKey(i1h))
require.Equal(t, 0, mp.Count())
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
// Add i1, i2, check OK
mp.Add(i1h, i1)
mp.Add(i2h, i2)
itm, ok := mp.TryGet(i1h)
require.True(t, ok)
require.Equal(t, [][]byte{i1}, itm)
require.True(t, mp.ContainsKey(i1h))
require.True(t, mp.ContainsKey(i2h))
require.Equal(t, map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Remove i1 and unexisting item
mp.Remove(i3h)
mp.Remove(i1h)
require.False(t, mp.ContainsKey(i1h))
require.True(t, mp.ContainsKey(i2h))
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}}, mp.GetAll())
require.Equal(t, 1, mp.Count())
// Update: remove nothing, add all
mp.Update(nil, mapAll)
require.Equal(t, mapAll, mp.GetAll())
require.Equal(t, 3, mp.Count())
// Update: remove all, add all
mp.Update(mapAll, mapAll)
require.Equal(t, mapAll, mp.GetAll()) // deletion first, addition after that
require.Equal(t, 3, mp.Count())
// Update: remove all, add nothing
mp.Update(mapAll, nil)
require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll())
require.Equal(t, 0, mp.Count())
// Update: remove several, add several
mp.Update(map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}})
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Update: remove nothing, add several with same hashes
mp.Update(nil, map[util.Uint256][][]byte{i3h: {i5, i4}}) // should be sorted by the pool
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i3, i5}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Update: remove several with same hashes, add nothing
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i4}}, nil)
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
// Update: remove several with same hashes, add several with same hashes
mp.Update(map[util.Uint256][][]byte{i3h: {i5, i3}}, map[util.Uint256][][]byte{i3h: {i5, i4}})
require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i5}}, mp.GetAll())
require.Equal(t, 2, mp.Count())
}
func TestPool_GetBatch(t *testing.T) {
check := func(t *testing.T, limit int, itemsCount int) {
mp := NewPool()
for i := 0; i < itemsCount; i++ {
mp.Add(random.Uint256(), []byte{0x01})
}
batch := mp.GetBatch(limit)
if limit < itemsCount {
require.Equal(t, limit, len(batch))
} else {
require.Equal(t, itemsCount, len(batch))
}
}
t.Run("limit less than items count", func(t *testing.T) {
check(t, 5, 6)
})
t.Run("limit more than items count", func(t *testing.T) {
check(t, 6, 5)
})
t.Run("items count limit", func(t *testing.T) {
check(t, 5, 5)
})
}
func TestPool_UpdateUsingSliceFromPool(t *testing.T) {
mp := NewPool()
p1, _ := hex.DecodeString("0f0a0f0f0f0f0f0f0104020b02080c0a06050e070b050404060206060d07080602030b04040b050e040406030f0708060c05")
p2, _ := hex.DecodeString("0f0a0f0f0f0f0f0f01040a0b000f04000b03090b02090b0e040f0d0b060d070e0b0b090b0906080602060c0d0f0e0d04070e")
p3, _ := hex.DecodeString("0f0a0f0f0f0f0f0f01040b010d01080f050f000a0d0e08060c040b050800050904060f050807080a080c07040d0107080007")
h, _ := util.Uint256DecodeStringBE("57e197679ef031bf2f0b466b20afe3f67ac04dcff80a1dc4d12dd98dd21a2511")
mp.Add(h, p1)
mp.Add(h, p2)
mp.Add(h, p3)
toBeRemoved, ok := mp.TryGet(h)
require.True(t, ok)
mp.Update(map[util.Uint256][][]byte{h: toBeRemoved}, nil)
// test that all items were successfully removed.
require.Equal(t, 0, len(mp.GetAll()))
}

435
pkg/core/statesync_test.go Normal file
View file

@ -0,0 +1,435 @@
package core
import (
"testing"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
"github.com/stretchr/testify/require"
)
func TestStateSyncModule_Init(t *testing.T) {
var (
stateSyncInterval = 2
maxTraceable uint32 = 3
)
spoutCfg := func(c *config.Config) {
c.ProtocolConfiguration.StateRootInHeader = true
c.ProtocolConfiguration.P2PStateExchangeExtensions = true
c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
}
bcSpout := newTestChainWithCustomCfg(t, spoutCfg)
for i := 0; i <= 2*stateSyncInterval+int(maxTraceable)+1; i++ {
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
}
boltCfg := func(c *config.Config) {
spoutCfg(c)
c.ProtocolConfiguration.KeepOnlyLatestState = true
c.ProtocolConfiguration.RemoveUntraceableBlocks = true
}
t.Run("error: module disabled by config", func(t *testing.T) {
bcBolt := newTestChainWithCustomCfg(t, func(c *config.Config) {
boltCfg(c)
c.ProtocolConfiguration.RemoveUntraceableBlocks = false
})
module := bcBolt.GetStateSyncModule()
require.Error(t, module.Init(bcSpout.BlockHeight())) // module inactive (non-archival node)
})
t.Run("inactive: spout chain is too low to start state sync process", func(t *testing.T) {
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
module := bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(uint32(2*stateSyncInterval-1)))
require.False(t, module.IsActive())
})
t.Run("inactive: bolt chain height is close enough to spout chain height", func(t *testing.T) {
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
for i := 1; i < int(bcSpout.BlockHeight())-stateSyncInterval; i++ {
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
require.NoError(t, err)
require.NoError(t, bcBolt.AddBlock(b))
}
module := bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
require.False(t, module.IsActive())
})
t.Run("error: bolt chain is too low to start state sync process", func(t *testing.T) {
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
require.NoError(t, bcBolt.AddBlock(bcBolt.newBlock()))
module := bcBolt.GetStateSyncModule()
require.Error(t, module.Init(uint32(3*stateSyncInterval)))
})
t.Run("initialized: no previous state sync point", func(t *testing.T) {
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
module := bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.True(t, module.NeedHeaders())
require.False(t, module.NeedMPTNodes())
})
t.Run("error: outdated state sync point in the storage", func(t *testing.T) {
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
module := bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
module = bcBolt.GetStateSyncModule()
require.Error(t, module.Init(bcSpout.BlockHeight()+2*uint32(stateSyncInterval)))
})
t.Run("initialized: valid previous state sync point in the storage", func(t *testing.T) {
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
module := bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
module = bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.True(t, module.NeedHeaders())
require.False(t, module.NeedMPTNodes())
})
t.Run("initialization from headers/blocks/mpt synced stages", func(t *testing.T) {
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
module := bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
// firstly, fetch all headers to create proper DB state (where headers are in sync)
stateSyncPoint := (int(bcSpout.BlockHeight()) / stateSyncInterval) * stateSyncInterval
var expectedHeader *block.Header
for i := 1; i <= int(bcSpout.HeaderHeight()); i++ {
header, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(i))
require.NoError(t, err)
require.NoError(t, module.AddHeaders(header))
if i == stateSyncPoint+1 {
expectedHeader = header
}
}
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.False(t, module.NeedHeaders())
require.True(t, module.NeedMPTNodes())
// then create new statesync module with the same DB and check that state is proper
// (headers are in sync)
module = bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.False(t, module.NeedHeaders())
require.True(t, module.NeedMPTNodes())
unknownNodes := module.GetUnknownMPTNodesBatch(2)
require.Equal(t, 1, len(unknownNodes))
require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0])
// add several blocks to create DB state where blocks are not in sync yet, but it's not a genesis.
for i := stateSyncPoint - int(maxTraceable) + 1; i <= stateSyncPoint-stateSyncInterval-1; i++ {
block, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
require.NoError(t, err)
require.NoError(t, module.AddBlock(block))
}
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.False(t, module.NeedHeaders())
require.True(t, module.NeedMPTNodes())
require.Equal(t, uint32(stateSyncPoint-stateSyncInterval-1), module.BlockHeight())
// then create new statesync module with the same DB and check that state is proper
// (blocks are not in sync yet)
module = bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.False(t, module.NeedHeaders())
require.True(t, module.NeedMPTNodes())
unknownNodes = module.GetUnknownMPTNodesBatch(2)
require.Equal(t, 1, len(unknownNodes))
require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0])
require.Equal(t, uint32(stateSyncPoint-stateSyncInterval-1), module.BlockHeight())
// add rest of blocks to create DB state where blocks are in sync
for i := stateSyncPoint - stateSyncInterval; i <= stateSyncPoint; i++ {
block, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
require.NoError(t, err)
require.NoError(t, module.AddBlock(block))
}
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.False(t, module.NeedHeaders())
require.True(t, module.NeedMPTNodes())
lastBlock, err := bcBolt.GetBlock(expectedHeader.PrevHash)
require.NoError(t, err)
require.Equal(t, uint32(stateSyncPoint), lastBlock.Index)
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
// then create new statesync module with the same DB and check that state is proper
// (headers and blocks are in sync)
module = bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.False(t, module.NeedHeaders())
require.True(t, module.NeedMPTNodes())
unknownNodes = module.GetUnknownMPTNodesBatch(2)
require.Equal(t, 1, len(unknownNodes))
require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0])
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
// add a few MPT nodes to create DB state where some of MPT nodes are missing
count := 5
for {
unknownHashes := module.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
if len(unknownHashes) == 0 {
break
}
err := bcSpout.GetStateSyncModule().Traverse(unknownHashes[0], func(node mpt.Node, nodeBytes []byte) bool {
require.NoError(t, module.AddMPTNodes([][]byte{nodeBytes}))
return true // add nodes one-by-one
})
require.NoError(t, err)
count--
if count < 0 {
break
}
}
// then create new statesync module with the same DB and check that state is proper
// (headers and blocks are in sync, mpt is not yet synced)
module = bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.False(t, module.NeedHeaders())
require.True(t, module.NeedMPTNodes())
unknownNodes = module.GetUnknownMPTNodesBatch(100)
require.True(t, len(unknownNodes) > 0)
require.NotContains(t, unknownNodes, expectedHeader.PrevStateRoot)
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
// add the rest of MPT nodes and jump to state
for {
unknownHashes := module.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one
if len(unknownHashes) == 0 {
break
}
err := bcSpout.GetStateSyncModule().Traverse(unknownHashes[0], func(node mpt.Node, nodeBytes []byte) bool {
require.NoError(t, module.AddMPTNodes([][]byte{slice.Copy(nodeBytes)}))
return true // add nodes one-by-one
})
require.NoError(t, err)
}
// check that module is inactive and statejump is completed
require.False(t, module.IsActive())
require.False(t, module.NeedHeaders())
require.False(t, module.NeedMPTNodes())
unknownNodes = module.GetUnknownMPTNodesBatch(1)
require.True(t, len(unknownNodes) == 0)
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
// create new module from completed state: the module should recognise that state sync is completed
module = bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
require.False(t, module.IsActive())
require.False(t, module.NeedHeaders())
require.False(t, module.NeedMPTNodes())
unknownNodes = module.GetUnknownMPTNodesBatch(1)
require.True(t, len(unknownNodes) == 0)
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
// add one more block to the restored chain and start new module: the module should recognise state sync is completed
// and regular blocks processing was started
require.NoError(t, bcBolt.AddBlock(bcBolt.newBlock()))
module = bcBolt.GetStateSyncModule()
require.NoError(t, module.Init(bcSpout.BlockHeight()))
require.False(t, module.IsActive())
require.False(t, module.NeedHeaders())
require.False(t, module.NeedMPTNodes())
unknownNodes = module.GetUnknownMPTNodesBatch(1)
require.True(t, len(unknownNodes) == 0)
require.Equal(t, uint32(stateSyncPoint)+1, module.BlockHeight())
require.Equal(t, uint32(stateSyncPoint)+1, bcBolt.BlockHeight())
})
}
func TestStateSyncModule_RestoreBasicChain(t *testing.T) {
var (
stateSyncInterval = 4
maxTraceable uint32 = 6
stateSyncPoint = 16
)
spoutCfg := func(c *config.Config) {
c.ProtocolConfiguration.StateRootInHeader = true
c.ProtocolConfiguration.P2PStateExchangeExtensions = true
c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval
c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable
}
bcSpout := newTestChainWithCustomCfg(t, spoutCfg)
initBasicChain(t, bcSpout)
// make spout chain higher that latest state sync point
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock()))
require.Equal(t, uint32(stateSyncPoint+2), bcSpout.BlockHeight())
boltCfg := func(c *config.Config) {
spoutCfg(c)
c.ProtocolConfiguration.KeepOnlyLatestState = true
c.ProtocolConfiguration.RemoveUntraceableBlocks = true
}
bcBolt := newTestChainWithCustomCfg(t, boltCfg)
module := bcBolt.GetStateSyncModule()
t.Run("error: add headers before initialisation", func(t *testing.T) {
h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(1))
require.NoError(t, err)
require.Error(t, module.AddHeaders(h))
})
t.Run("no error: add blocks before initialisation", func(t *testing.T) {
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(1))
require.NoError(t, err)
require.NoError(t, module.AddBlock(b))
})
t.Run("error: add MPT nodes without initialisation", func(t *testing.T) {
require.Error(t, module.AddMPTNodes([][]byte{}))
})
require.NoError(t, module.Init(bcSpout.BlockHeight()))
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.True(t, module.NeedHeaders())
require.False(t, module.NeedMPTNodes())
// add headers to module
headers := make([]*block.Header, 0, bcSpout.HeaderHeight())
for i := uint32(1); i <= bcSpout.HeaderHeight(); i++ {
h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(int(i)))
require.NoError(t, err)
headers = append(headers, h)
}
require.NoError(t, module.AddHeaders(headers...))
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.False(t, module.NeedHeaders())
require.True(t, module.NeedMPTNodes())
require.Equal(t, bcSpout.HeaderHeight(), bcBolt.HeaderHeight())
// add blocks
t.Run("error: unexpected block index", func(t *testing.T) {
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(stateSyncPoint - int(maxTraceable)))
require.NoError(t, err)
require.Error(t, module.AddBlock(b))
})
t.Run("error: missing state root in block header", func(t *testing.T) {
b := &block.Block{
Header: block.Header{
Index: uint32(stateSyncPoint) - maxTraceable + 1,
StateRootEnabled: false,
},
}
require.Error(t, module.AddBlock(b))
})
t.Run("error: invalid block merkle root", func(t *testing.T) {
b := &block.Block{
Header: block.Header{
Index: uint32(stateSyncPoint) - maxTraceable + 1,
StateRootEnabled: true,
MerkleRoot: util.Uint256{1, 2, 3},
},
}
require.Error(t, module.AddBlock(b))
})
for i := stateSyncPoint - int(maxTraceable) + 1; i <= stateSyncPoint; i++ {
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
require.NoError(t, err)
require.NoError(t, module.AddBlock(b))
}
require.True(t, module.IsActive())
require.True(t, module.IsInitialized())
require.False(t, module.NeedHeaders())
require.True(t, module.NeedMPTNodes())
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
// add MPT nodes in batches
h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(stateSyncPoint + 1))
require.NoError(t, err)
unknownHashes := module.GetUnknownMPTNodesBatch(100)
require.Equal(t, 1, len(unknownHashes))
require.Equal(t, h.PrevStateRoot, unknownHashes[0])
nodesMap := make(map[util.Uint256][]byte)
err = bcSpout.GetStateSyncModule().Traverse(h.PrevStateRoot, func(n mpt.Node, nodeBytes []byte) bool {
nodesMap[n.Hash()] = nodeBytes
return false
})
require.NoError(t, err)
for {
need := module.GetUnknownMPTNodesBatch(10)
if len(need) == 0 {
break
}
add := make([][]byte, len(need))
for i, h := range need {
nodeBytes, ok := nodesMap[h]
if !ok {
t.Fatal("unknown or restored node requested")
}
add[i] = nodeBytes
delete(nodesMap, h)
}
require.NoError(t, module.AddMPTNodes(add))
}
require.False(t, module.IsActive())
require.False(t, module.NeedHeaders())
require.False(t, module.NeedMPTNodes())
unknownNodes := module.GetUnknownMPTNodesBatch(1)
require.True(t, len(unknownNodes) == 0)
require.Equal(t, uint32(stateSyncPoint), module.BlockHeight())
require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight())
// add missing blocks to bcBolt: should be ok, because state is synced
for i := stateSyncPoint + 1; i <= int(bcSpout.BlockHeight()); i++ {
b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i))
require.NoError(t, err)
require.NoError(t, bcBolt.AddBlock(b))
}
require.Equal(t, bcSpout.BlockHeight(), bcBolt.BlockHeight())
// compare storage states
fetchStorage := func(bc *Blockchain) []storage.KeyValue {
var kv []storage.KeyValue
bc.dao.Store.Seek(storage.STStorage.Bytes(), func(k, v []byte) {
key := slice.Copy(k)
value := slice.Copy(v)
kv = append(kv, storage.KeyValue{
Key: key,
Value: value,
})
})
return kv
}
expected := fetchStorage(bcSpout)
actual := fetchStorage(bcBolt)
require.ElementsMatch(t, expected, actual)
// no temp items should be left
bcBolt.dao.Store.Seek(storage.STTempStorage.Bytes(), func(k, v []byte) {
t.Fatal("temp storage items are found")
})
}

View file

@ -8,13 +8,18 @@ import (
// KeyPrefix constants.
const (
DataBlock KeyPrefix = 0x01
DataTransaction KeyPrefix = 0x02
DataMPT KeyPrefix = 0x03
STAccount KeyPrefix = 0x40
STNotification KeyPrefix = 0x4d
STContractID KeyPrefix = 0x51
STStorage KeyPrefix = 0x70
DataBlock KeyPrefix = 0x01
DataTransaction KeyPrefix = 0x02
DataMPT KeyPrefix = 0x03
STAccount KeyPrefix = 0x40
STNotification KeyPrefix = 0x4d
STContractID KeyPrefix = 0x51
STStorage KeyPrefix = 0x70
// STTempStorage is used to store contract storage items during state sync process
// in order not to mess up the previous state which has its own items stored by
// STStorage prefix. Once state exchange process is completed, all items with
// STStorage prefix will be replaced with STTempStorage-prefixed ones.
STTempStorage KeyPrefix = 0x71
STNEP17Transfers KeyPrefix = 0x72
STNEP17TransferInfo KeyPrefix = 0x73
IXHeaderHashList KeyPrefix = 0x80
@ -22,6 +27,7 @@ const (
SYSCurrentHeader KeyPrefix = 0xc1
SYSStateSyncCurrentBlockHeight KeyPrefix = 0xc2
SYSStateSyncPoint KeyPrefix = 0xc3
SYSStateJumpStage KeyPrefix = 0xc4
SYSVersion KeyPrefix = 0xf0
)

View file

@ -4,6 +4,7 @@ import (
"github.com/Workiva/go-datastructures/queue"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"go.uber.org/atomic"
"go.uber.org/zap"
)
@ -13,6 +14,7 @@ type blockQueue struct {
checkBlocks chan struct{}
chain blockchainer.Blockqueuer
relayF func(*block.Block)
discarded *atomic.Bool
}
const (
@ -32,6 +34,7 @@ func newBlockQueue(capacity int, bc blockchainer.Blockqueuer, log *zap.Logger, r
checkBlocks: make(chan struct{}, 1),
chain: bc,
relayF: relayer,
discarded: atomic.NewBool(false),
}
}
@ -91,8 +94,10 @@ func (bq *blockQueue) putBlock(block *block.Block) error {
}
func (bq *blockQueue) discard() {
close(bq.checkBlocks)
bq.queue.Dispose()
if bq.discarded.CAS(false, true) {
close(bq.checkBlocks)
bq.queue.Dispose()
}
}
func (bq *blockQueue) length() int {

View file

@ -71,6 +71,8 @@ const (
CMDBlock = CommandType(payload.BlockType)
CMDExtensible = CommandType(payload.ExtensibleType)
CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType)
CMDGetMPTData CommandType = 0x51 // 0x5.. commands are used for extensions (P2PNotary, state exchange cmds)
CMDMPTData CommandType = 0x52
CMDReject CommandType = 0x2f
// SPV protocol.
@ -136,6 +138,10 @@ func (m *Message) decodePayload() error {
p = &payload.Version{}
case CMDInv, CMDGetData:
p = &payload.Inventory{}
case CMDGetMPTData:
p = &payload.MPTInventory{}
case CMDMPTData:
p = &payload.MPTData{}
case CMDAddr:
p = &payload.AddressList{}
case CMDBlock:
@ -221,7 +227,7 @@ func (m *Message) tryCompressPayload() error {
if m.Flags&Compressed == 0 {
switch m.Payload.(type) {
case *payload.Headers, *payload.MerkleBlock, payload.NullPayload,
*payload.Inventory:
*payload.Inventory, *payload.MPTInventory:
break
default:
size := len(compressedPayload)

View file

@ -26,6 +26,8 @@ func _() {
_ = x[CMDBlock-44]
_ = x[CMDExtensible-46]
_ = x[CMDP2PNotaryRequest-80]
_ = x[CMDGetMPTData-81]
_ = x[CMDMPTData-82]
_ = x[CMDReject-47]
_ = x[CMDFilterLoad-48]
_ = x[CMDFilterAdd-49]
@ -44,7 +46,7 @@ const (
_CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
_CommandType_name_7 = "CMDMerkleBlock"
_CommandType_name_8 = "CMDAlert"
_CommandType_name_9 = "CMDP2PNotaryRequest"
_CommandType_name_9 = "CMDP2PNotaryRequestCMDGetMPTDataCMDMPTData"
)
var (
@ -55,6 +57,7 @@ var (
_CommandType_index_4 = [...]uint8{0, 12, 22}
_CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58}
_CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61}
_CommandType_index_9 = [...]uint8{0, 19, 32, 42}
)
func (i CommandType) String() string {
@ -83,8 +86,9 @@ func (i CommandType) String() string {
return _CommandType_name_7
case i == 64:
return _CommandType_name_8
case i == 80:
return _CommandType_name_9
case 80 <= i && i <= 82:
i -= 80
return _CommandType_name_9[_CommandType_index_9[i]:_CommandType_index_9[i+1]]
default:
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
}

View file

@ -258,6 +258,21 @@ func TestEncodeDecodeNotFound(t *testing.T) {
})
}
func TestEncodeDecodeGetMPTData(t *testing.T) {
testEncodeDecode(t, CMDGetMPTData, &payload.MPTInventory{
Hashes: []util.Uint256{
{1, 2, 3},
{4, 5, 6},
},
})
}
func TestEncodeDecodeMPTData(t *testing.T) {
testEncodeDecode(t, CMDMPTData, &payload.MPTData{
Nodes: [][]byte{{1, 2, 3}, {4, 5, 6}},
})
}
func TestInvalidMessages(t *testing.T) {
t.Run("CMDBlock, empty payload", func(t *testing.T) {
testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{})

View file

@ -0,0 +1,35 @@
package payload
import (
"errors"
"github.com/nspcc-dev/neo-go/pkg/io"
)
// MPTData represents the set of serialized MPT nodes.
type MPTData struct {
Nodes [][]byte
}
// EncodeBinary implements io.Serializable.
func (d *MPTData) EncodeBinary(w *io.BinWriter) {
w.WriteVarUint(uint64(len(d.Nodes)))
for _, n := range d.Nodes {
w.WriteVarBytes(n)
}
}
// DecodeBinary implements io.Serializable.
func (d *MPTData) DecodeBinary(r *io.BinReader) {
sz := r.ReadVarUint()
if sz == 0 {
r.Err = errors.New("empty MPT nodes list")
return
}
for i := uint64(0); i < sz; i++ {
d.Nodes = append(d.Nodes, r.ReadVarBytes())
if r.Err != nil {
return
}
}
}

View file

@ -0,0 +1,24 @@
package payload
import (
"testing"
"github.com/nspcc-dev/neo-go/internal/testserdes"
"github.com/stretchr/testify/require"
)
func TestMPTData_EncodeDecodeBinary(t *testing.T) {
t.Run("empty", func(t *testing.T) {
d := new(MPTData)
bytes, err := testserdes.EncodeBinary(d)
require.NoError(t, err)
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTData)))
})
t.Run("good", func(t *testing.T) {
d := &MPTData{
Nodes: [][]byte{{}, {1}, {1, 2, 3}},
}
testserdes.EncodeDecodeBinary(t, d, new(MPTData))
})
}

View file

@ -0,0 +1,32 @@
package payload
import (
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// MaxMPTHashesCount is the maximum number of requested MPT nodes hashes.
const MaxMPTHashesCount = 32
// MPTInventory payload.
type MPTInventory struct {
// A list of requested MPT nodes hashes.
Hashes []util.Uint256
}
// NewMPTInventory return a pointer to an MPTInventory.
func NewMPTInventory(hashes []util.Uint256) *MPTInventory {
return &MPTInventory{
Hashes: hashes,
}
}
// DecodeBinary implements Serializable interface.
func (p *MPTInventory) DecodeBinary(br *io.BinReader) {
br.ReadArray(&p.Hashes, MaxMPTHashesCount)
}
// EncodeBinary implements Serializable interface.
func (p *MPTInventory) EncodeBinary(bw *io.BinWriter) {
bw.WriteArray(p.Hashes)
}

View file

@ -0,0 +1,38 @@
package payload
import (
"testing"
"github.com/nspcc-dev/neo-go/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func TestMPTInventory_EncodeDecodeBinary(t *testing.T) {
t.Run("empty", func(t *testing.T) {
testserdes.EncodeDecodeBinary(t, NewMPTInventory([]util.Uint256{}), new(MPTInventory))
})
t.Run("good", func(t *testing.T) {
inv := NewMPTInventory([]util.Uint256{{1, 2, 3}, {2, 3, 4}})
testserdes.EncodeDecodeBinary(t, inv, new(MPTInventory))
})
t.Run("too large", func(t *testing.T) {
check := func(t *testing.T, count int, fail bool) {
h := make([]util.Uint256, count)
for i := range h {
h[i] = util.Uint256{1, 2, 3}
}
if fail {
bytes, err := testserdes.EncodeBinary(NewMPTInventory(h))
require.NoError(t, err)
require.Error(t, testserdes.DecodeBinary(bytes, new(MPTInventory)))
} else {
testserdes.EncodeDecodeBinary(t, NewMPTInventory(h), new(MPTInventory))
}
}
check(t, MaxMPTHashesCount, false)
check(t, MaxMPTHashesCount+1, true)
})
}

View file

@ -7,6 +7,7 @@ import (
"fmt"
mrand "math/rand"
"net"
"sort"
"strconv"
"sync"
"time"
@ -17,7 +18,9 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
"github.com/nspcc-dev/neo-go/pkg/core/mempoolevent"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/extpool"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
@ -67,6 +70,7 @@ type (
discovery Discoverer
chain blockchainer.Blockchainer
bQueue *blockQueue
bSyncQueue *blockQueue
consensus consensus.Service
mempool *mempool.Pool
notaryRequestPool *mempool.Pool
@ -93,6 +97,7 @@ type (
oracle *oracle.Oracle
stateRoot stateroot.Service
stateSync blockchainer.StateSync
log *zap.Logger
}
@ -191,6 +196,10 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
}
s.stateRoot = sr
sSync := chain.GetStateSyncModule()
s.stateSync = sSync
s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil)
if config.OracleCfg.Enabled {
orcCfg := oracle.Config{
Log: log,
@ -277,6 +286,7 @@ func (s *Server) Start(errChan chan error) {
go s.broadcastTxLoop()
go s.relayBlocksLoop()
go s.bQueue.run()
go s.bSyncQueue.run()
go s.transport.Accept()
setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10))
s.run()
@ -292,6 +302,7 @@ func (s *Server) Shutdown() {
p.Disconnect(errServerShutdown)
}
s.bQueue.discard()
s.bSyncQueue.discard()
if s.StateRootCfg.Enabled {
s.stateRoot.Shutdown()
}
@ -573,6 +584,10 @@ func (s *Server) IsInSync() bool {
var peersNumber int
var notHigher int
if s.stateSync.IsActive() {
return false
}
if s.MinPeers == 0 {
return true
}
@ -630,6 +645,9 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error {
// handleBlockCmd processes the received block received from its peer.
func (s *Server) handleBlockCmd(p Peer, block *block.Block) error {
if s.stateSync.IsActive() {
return s.bSyncQueue.putBlock(block)
}
return s.bQueue.putBlock(block)
}
@ -639,25 +657,57 @@ func (s *Server) handlePing(p Peer, ping *payload.Ping) error {
if err != nil {
return err
}
if s.chain.BlockHeight() < ping.LastBlockIndex {
err = s.requestBlocks(p)
if err != nil {
return err
}
err = s.requestBlocksOrHeaders(p)
if err != nil {
return err
}
return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id)))
}
func (s *Server) requestBlocksOrHeaders(p Peer) error {
if s.stateSync.NeedHeaders() {
if s.chain.HeaderHeight() < p.LastBlockIndex() {
return s.requestHeaders(p)
}
return nil
}
var (
bq blockchainer.Blockqueuer = s.chain
requestMPTNodes bool
)
if s.stateSync.IsActive() {
bq = s.stateSync
requestMPTNodes = s.stateSync.NeedMPTNodes()
}
if bq.BlockHeight() >= p.LastBlockIndex() {
return nil
}
err := s.requestBlocks(bq, p)
if err != nil {
return err
}
if requestMPTNodes {
return s.requestMPTNodes(p, s.stateSync.GetUnknownMPTNodesBatch(payload.MaxMPTHashesCount))
}
return nil
}
// requestHeaders sends a CMDGetHeaders message to the peer to sync up in headers.
func (s *Server) requestHeaders(p Peer) error {
// TODO: optimize
currHeight := s.chain.HeaderHeight()
needHeight := currHeight + 1
payload := payload.NewGetBlockByIndex(needHeight, -1)
return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload))
}
// handlePing processes pong request.
func (s *Server) handlePong(p Peer, pong *payload.Ping) error {
err := p.HandlePong(pong)
if err != nil {
return err
}
if s.chain.BlockHeight() < pong.LastBlockIndex {
return s.requestBlocks(p)
}
return nil
return s.requestBlocksOrHeaders(p)
}
// handleInvCmd processes the received inventory.
@ -766,6 +816,69 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
return nil
}
// handleGetMPTDataCmd processes the received MPT inventory.
func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error {
if !s.chain.GetConfig().P2PStateExchangeExtensions {
return errors.New("GetMPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
}
if s.chain.GetConfig().KeepOnlyLatestState {
// TODO: implement keeping MPT states for P1 and P2 height (#2095, #2152 related)
return errors.New("GetMPTDataCMD was received, but only latest MPT state is supported")
}
resp := payload.MPTData{}
capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes)))
added := make(map[util.Uint256]struct{})
for _, h := range inv.Hashes {
if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type
break
}
err := s.stateSync.Traverse(h,
func(n mpt.Node, node []byte) bool {
if _, ok := added[n.Hash()]; ok {
return false
}
l := len(node)
size := l + io.GetVarSize(l)
if size > capLeft {
return true
}
resp.Nodes = append(resp.Nodes, node)
added[n.Hash()] = struct{}{}
capLeft -= size
return false
})
if err != nil {
return fmt.Errorf("failed to traverse MPT starting from %s: %w", h.StringBE(), err)
}
}
if len(resp.Nodes) > 0 {
msg := NewMessage(CMDMPTData, &resp)
return p.EnqueueP2PMessage(msg)
}
return nil
}
func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error {
if !s.chain.GetConfig().P2PStateExchangeExtensions {
return errors.New("MPTDataCMD was received, but P2PStateExchangeExtensions are disabled")
}
return s.stateSync.AddMPTNodes(data.Nodes)
}
// requestMPTNodes requests specified MPT nodes from the peer or broadcasts
// request if peer is not specified.
func (s *Server) requestMPTNodes(p Peer, itms []util.Uint256) error {
if len(itms) == 0 {
return nil
}
if len(itms) > payload.MaxMPTHashesCount {
itms = itms[:payload.MaxMPTHashesCount]
}
pl := payload.NewMPTInventory(itms)
msg := NewMessage(CMDGetMPTData, pl)
return p.EnqueueP2PMessage(msg)
}
// handleGetBlocksCmd processes the getblocks request.
func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
count := gb.Count
@ -845,6 +958,11 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error
return p.EnqueueP2PMessage(msg)
}
// handleHeadersCmd processes headers payload.
func (s *Server) handleHeadersCmd(p Peer, h *payload.Headers) error {
return s.stateSync.AddHeaders(h.Hdrs...)
}
// handleExtensibleCmd processes received extensible payload.
func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
if !s.syncReached.Load() {
@ -993,8 +1111,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error {
// 1. Block range is divided into chunks of payload.MaxHashesCount.
// 2. Send requests for chunk in increasing order.
// 3. After all requests were sent, request random height.
func (s *Server) requestBlocks(p Peer) error {
var currHeight = s.chain.BlockHeight()
func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error {
var currHeight = bq.BlockHeight()
var peerHeight = p.LastBlockIndex()
var needHeight uint32
// lastRequestedHeight can only be increased.
@ -1051,9 +1169,18 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
case CMDGetData:
inv := msg.Payload.(*payload.Inventory)
return s.handleGetDataCmd(peer, inv)
case CMDGetMPTData:
inv := msg.Payload.(*payload.MPTInventory)
return s.handleGetMPTDataCmd(peer, inv)
case CMDMPTData:
inv := msg.Payload.(*payload.MPTData)
return s.handleMPTDataCmd(peer, inv)
case CMDGetHeaders:
gh := msg.Payload.(*payload.GetBlockByIndex)
return s.handleGetHeadersCmd(peer, gh)
case CMDHeaders:
h := msg.Payload.(*payload.Headers)
return s.handleHeadersCmd(peer, h)
case CMDInv:
inventory := msg.Payload.(*payload.Inventory)
return s.handleInvCmd(peer, inventory)
@ -1093,6 +1220,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
}
go peer.StartProtocol()
s.tryInitStateSync()
s.tryStartServices()
default:
return fmt.Errorf("received '%s' during handshake", msg.Command.String())
@ -1101,6 +1229,52 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return nil
}
func (s *Server) tryInitStateSync() {
if !s.stateSync.IsActive() {
s.bSyncQueue.discard()
return
}
if s.stateSync.IsInitialized() {
return
}
var peersNumber int
s.lock.RLock()
heights := make([]uint32, 0)
for p := range s.peers {
if p.Handshaked() {
peersNumber++
peerLastBlock := p.LastBlockIndex()
i := sort.Search(len(heights), func(i int) bool {
return heights[i] >= peerLastBlock
})
heights = append(heights, peerLastBlock)
if i != len(heights)-1 {
copy(heights[i+1:], heights[i:])
heights[i] = peerLastBlock
}
}
}
s.lock.RUnlock()
if peersNumber >= s.MinPeers && len(heights) > 0 {
// choose the height of the median peer as current chain's height
h := heights[len(heights)/2]
err := s.stateSync.Init(h)
if err != nil {
s.log.Fatal("failed to init state sync module",
zap.Uint32("evaluated chain's blockHeight", h),
zap.Uint32("blockHeight", s.chain.BlockHeight()),
zap.Uint32("headerHeight", s.chain.HeaderHeight()),
zap.Error(err))
}
// module can be inactive after init (i.e. full state is collected and ordinary block processing is needed)
if !s.stateSync.IsActive() {
s.bSyncQueue.discard()
}
}
}
func (s *Server) handleNewPayload(p *payload.Extensible) {
_, err := s.extensiblePool.Add(p)
if err != nil {

View file

@ -2,6 +2,7 @@ package network
import (
"errors"
"fmt"
"math/big"
"net"
"strconv"
@ -16,6 +17,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
@ -46,7 +48,10 @@ func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs =
func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") }
func TestNewServer(t *testing.T) {
bc := &fakechain.FakeChain{}
bc := &fakechain.FakeChain{ProtocolConfiguration: config.ProtocolConfiguration{
P2PStateExchangeExtensions: true,
StateRootInHeader: true,
}}
s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery)
require.Error(t, err)
@ -733,6 +738,139 @@ func TestInv(t *testing.T) {
})
}
func TestHandleGetMPTData(t *testing.T) {
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
s := startTestServer(t)
p := newLocalPeer(t, s)
p.handshaked = true
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
Hashes: []util.Uint256{{1, 2, 3}},
})
require.Error(t, s.handleMessage(p, msg))
})
t.Run("KeepOnlyLatestState on", func(t *testing.T) {
s := startTestServer(t)
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
s.chain.(*fakechain.FakeChain).KeepOnlyLatestState = true
p := newLocalPeer(t, s)
p.handshaked = true
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
Hashes: []util.Uint256{{1, 2, 3}},
})
require.Error(t, s.handleMessage(p, msg))
})
t.Run("good", func(t *testing.T) {
s := startTestServer(t)
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
var recvResponse atomic.Bool
r1 := random.Uint256()
r2 := random.Uint256()
r3 := random.Uint256()
node := []byte{1, 2, 3}
s.stateSync.(*fakechain.FakeStateSync).TraverseFunc = func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
if !(root.Equals(r1) || root.Equals(r2)) {
t.Fatal("unexpected root")
}
require.False(t, process(mpt.NewHashNode(r3), node))
return nil
}
found := &payload.MPTData{
Nodes: [][]byte{node}, // no duplicates expected
}
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
switch msg.Command {
case CMDMPTData:
require.Equal(t, found, msg.Payload)
recvResponse.Store(true)
}
}
hs := []util.Uint256{r1, r2}
s.testHandleMessage(t, p, CMDGetMPTData, payload.NewMPTInventory(hs))
require.Eventually(t, recvResponse.Load, time.Second, time.Millisecond)
})
}
func TestHandleMPTData(t *testing.T) {
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
s := startTestServer(t)
p := newLocalPeer(t, s)
p.handshaked = true
msg := NewMessage(CMDMPTData, &payload.MPTData{
Nodes: [][]byte{{1, 2, 3}},
})
require.Error(t, s.handleMessage(p, msg))
})
t.Run("good", func(t *testing.T) {
s := startTestServer(t)
expected := [][]byte{{1, 2, 3}, {2, 3, 4}}
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
s.stateSync = &fakechain.FakeStateSync{
AddMPTNodesFunc: func(nodes [][]byte) error {
require.Equal(t, expected, nodes)
return nil
},
}
p := newLocalPeer(t, s)
p.handshaked = true
msg := NewMessage(CMDMPTData, &payload.MPTData{
Nodes: expected,
})
require.NoError(t, s.handleMessage(p, msg))
})
}
func TestRequestMPTNodes(t *testing.T) {
s := startTestServer(t)
var actual []util.Uint256
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDGetMPTData {
actual = append(actual, msg.Payload.(*payload.MPTInventory).Hashes...)
}
}
s.register <- p
s.register <- p // ensure previous send was handled
t.Run("no hashes, no message", func(t *testing.T) {
actual = nil
require.NoError(t, s.requestMPTNodes(p, nil))
require.Nil(t, actual)
})
t.Run("good, small", func(t *testing.T) {
actual = nil
expected := []util.Uint256{random.Uint256(), random.Uint256()}
require.NoError(t, s.requestMPTNodes(p, expected))
require.Equal(t, expected, actual)
})
t.Run("good, exactly one chunk", func(t *testing.T) {
actual = nil
expected := make([]util.Uint256, payload.MaxMPTHashesCount)
for i := range expected {
expected[i] = random.Uint256()
}
require.NoError(t, s.requestMPTNodes(p, expected))
require.Equal(t, expected, actual)
})
t.Run("good, too large chunk", func(t *testing.T) {
actual = nil
expected := make([]util.Uint256, payload.MaxMPTHashesCount+1)
for i := range expected {
expected[i] = random.Uint256()
}
require.NoError(t, s.requestMPTNodes(p, expected))
require.Equal(t, expected[:payload.MaxMPTHashesCount], actual)
})
}
func TestRequestTx(t *testing.T) {
s := startTestServer(t)
@ -899,3 +1037,44 @@ func TestVerifyNotaryRequest(t *testing.T) {
require.NoError(t, verifyNotaryRequest(bc, nil, r))
})
}
func TestTryInitStateSync(t *testing.T) {
t.Run("module inactive", func(t *testing.T) {
s := startTestServer(t)
s.tryInitStateSync()
})
t.Run("module already initialized", func(t *testing.T) {
s := startTestServer(t)
ss := &fakechain.FakeStateSync{}
ss.IsActiveFlag.Store(true)
ss.IsInitializedFlag.Store(true)
s.stateSync = ss
s.tryInitStateSync()
})
t.Run("good", func(t *testing.T) {
s := startTestServer(t)
for _, h := range []uint32{10, 8, 7, 4, 11, 4} {
p := newLocalPeer(t, s)
p.handshaked = true
p.lastBlockIndex = h
s.peers[p] = true
}
p := newLocalPeer(t, s)
p.handshaked = false // one disconnected peer to check it won't be taken into attention
p.lastBlockIndex = 5
s.peers[p] = true
var expectedH uint32 = 8 // median peer
ss := &fakechain.FakeStateSync{InitFunc: func(h uint32) error {
if h != expectedH {
return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)
}
return nil
}}
ss.IsActiveFlag.Store(true)
s.stateSync = ss
s.tryInitStateSync()
})
}

View file

@ -267,12 +267,10 @@ func (p *TCPPeer) StartProtocol() {
zap.Uint32("id", p.Version().Nonce))
p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities)
if p.server.chain.BlockHeight() < p.LastBlockIndex() {
err = p.server.requestBlocks(p)
if err != nil {
p.Disconnect(err)
return
}
err = p.server.requestBlocksOrHeaders(p)
if err != nil {
p.Disconnect(err)
return
}
timer := time.NewTimer(p.server.ProtoTickInterval)
@ -281,10 +279,8 @@ func (p *TCPPeer) StartProtocol() {
case <-p.done:
return
case <-timer.C:
// Try to sync in headers and block with the peer if his block height is higher then ours.
if p.LastBlockIndex() > p.server.chain.BlockHeight() {
err = p.server.requestBlocks(p)
}
// Try to sync in headers and block with the peer if his block height is higher than ours.
err = p.server.requestBlocksOrHeaders(p)
if err == nil {
timer.Reset(p.server.ProtoTickInterval)
}

View file

@ -679,6 +679,7 @@ func (s *Server) getNEP17Balances(ps request.Params) (interface{}, *response.Err
if err != nil {
return nil, response.NewRPCError("Failed to get NEP17 last updated block", err.Error(), err)
}
stateSyncPoint := lastUpdated[math.MinInt32]
bw := io.NewBufBinWriter()
for _, h := range s.chain.GetNEP17Contracts() {
balance, err := s.getNEP17Balance(h, u, bw)
@ -692,10 +693,18 @@ func (s *Server) getNEP17Balances(ps request.Params) (interface{}, *response.Err
if cs == nil {
continue
}
lub, ok := lastUpdated[cs.ID]
if !ok {
cfg := s.chain.GetConfig()
if !cfg.P2PStateExchangeExtensions && cfg.RemoveUntraceableBlocks {
return nil, response.NewInternalServerError(fmt.Sprintf("failed to get LastUpdatedBlock for balance of %s token", cs.Hash.StringLE()), nil)
}
lub = stateSyncPoint
}
bs.Balances = append(bs.Balances, result.NEP17Balance{
Asset: h,
Amount: balance.String(),
LastUpdated: lastUpdated[cs.ID],
LastUpdated: lub,
})
}
return bs, nil