From a22b1caa3e4fddfe2e0ddad18f48edfbc4bb577d Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Thu, 29 Jul 2021 18:00:07 +0300 Subject: [PATCH 01/15] core: implement MPT Billet structure for MPT restore MPT restore process is much simpler then regular MPT maintaining: trie has a fixed structure, we don't need to remove or rebuild MPT nodes. The only thing we should do is to replace Hash nodes to their unhashed counterparts and increment refcount. It's better not to touch the regular MPT code and create a separate structure for this. --- pkg/core/mpt/billet.go | 261 +++++++++++++++++++++++++++++++++++ pkg/core/mpt/helpers.go | 9 ++ pkg/core/mpt/helpers_test.go | 20 +++ 3 files changed, 290 insertions(+) create mode 100644 pkg/core/mpt/billet.go create mode 100644 pkg/core/mpt/helpers_test.go diff --git a/pkg/core/mpt/billet.go b/pkg/core/mpt/billet.go new file mode 100644 index 000000000..b47117e4a --- /dev/null +++ b/pkg/core/mpt/billet.go @@ -0,0 +1,261 @@ +package mpt + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/util/slice" +) + +var ( + // ErrRestoreFailed is returned when replacing HashNode by its "unhashed" + // candidate fails. + ErrRestoreFailed = errors.New("failed to restore MPT node") + errStop = errors.New("stop condition is met") +) + +// Billet is a part of MPT trie with missing hash nodes that need to be restored. +// Billet is based on the following assumptions: +// 1. Refcount can only be incremented (we don't change MPT structure during restore, +// thus don't need to decrease refcount). +// 2. TODO: Each time the part of Billet is completely restored, it is collapsed into HashNode. +type Billet struct { + Store *storage.MemCachedStore + + root Node + refcountEnabled bool +} + +// NewBillet returns new billet for MPT trie restoring. It accepts a MemCachedStore +// to decouple storage errors from logic errors so that all storage errors are +// processed during `store.Persist()` at the caller. This also has the benefit, +// that every `Put` can be considered an atomic operation. +func NewBillet(rootHash util.Uint256, enableRefCount bool, store *storage.MemCachedStore) *Billet { + return &Billet{ + Store: store, + root: NewHashNode(rootHash), + refcountEnabled: enableRefCount, + } +} + +// RestoreHashNode replaces HashNode located at the provided path by the specified Node +// and stores it. +// TODO: It also maintains MPT as small as possible by collapsing those parts +// of MPT that have been completely restored. +func (b *Billet) RestoreHashNode(path []byte, node Node) error { + if _, ok := node.(*HashNode); ok { + return fmt.Errorf("%w: unable to restore node into HashNode", ErrRestoreFailed) + } + if _, ok := node.(EmptyNode); ok { + return fmt.Errorf("%w: unable to restore node into EmptyNode", ErrRestoreFailed) + } + r, err := b.putIntoNode(b.root, path, node) + if err != nil { + return err + } + b.root = r + + // If it's a leaf, then put into contract storage. + if leaf, ok := node.(*LeafNode); ok { + k := append([]byte{byte(storage.STStorage)}, fromNibbles(path)...) + _ = b.Store.Put(k, leaf.value) + } + return nil +} + +// putIntoNode puts val with provided path inside curr and returns updated node. +// Reference counters are updated for both curr and returned value. +func (b *Billet) putIntoNode(curr Node, path []byte, val Node) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + return b.putIntoLeaf(n, path, val) + case *BranchNode: + return b.putIntoBranch(n, path, val) + case *ExtensionNode: + return b.putIntoExtension(n, path, val) + case *HashNode: + return b.putIntoHash(n, path, val) + case EmptyNode: + return nil, fmt.Errorf("%w: can't modify EmptyNode during restore", ErrRestoreFailed) + default: + panic("invalid MPT node type") + } +} + +func (b *Billet) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) { + if len(path) != 0 { + return nil, fmt.Errorf("%w: can't modify LeafNode during restore", ErrRestoreFailed) + } + if curr.Hash() != val.Hash() { + return nil, fmt.Errorf("%w: bad Leaf node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE()) + } + // this node has already been restored, no refcount changes required + return curr, nil +} + +func (b *Billet) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) { + if len(path) == 0 && curr.Hash().Equals(val.Hash()) { + // this node has already been restored, no refcount changes required + return curr, nil + } + i, path := splitPath(path) + r, err := b.putIntoNode(curr.Children[i], path, val) + if err != nil { + return nil, err + } + curr.Children[i] = r + return curr, nil +} + +func (b *Billet) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) { + if len(path) == 0 { + if curr.Hash() != val.Hash() { + return nil, fmt.Errorf("%w: bad Extension node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE()) + } + // this node has already been restored, no refcount changes required + return curr, nil + } + if !bytes.HasPrefix(path, curr.key) { + return nil, fmt.Errorf("%w: can't modify ExtensionNode during restore", ErrRestoreFailed) + } + + r, err := b.putIntoNode(curr.next, path[len(curr.key):], val) + if err != nil { + return nil, err + } + curr.next = r + return curr, nil +} + +func (b *Billet) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { + // Once the part of MPT Billet is completely restored, it will be collapsed forever, so + // it's an MPT pool duty to avoid duplicating restore requests. + if len(path) != 0 { + return nil, fmt.Errorf("%w: node has already been collapsed", ErrRestoreFailed) + } + + // `curr` hash node can be either of + // 1) saved in storage (i.g. if we've already restored node with the same hash from the + // other part of MPT), so just add it to local in-memory MPT. + // 2) missing from the storage. It's OK because we're syncing MPT state, and the purpose + // is to store missing hash nodes. + // both cases are OK, but we still need to validate `val` against `curr`. + if val.Hash() != curr.Hash() { + return nil, fmt.Errorf("%w: can't restore HashNode: expected and actual hashes mismatch (%s vs %s)", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE()) + } + // We also need to increment refcount in both cases. That's the only place where refcount + // is changed during restore process. Also flush right now, because sync process can be + // interrupted at any time. + b.incrementRefAndStore(val.Hash(), val.Bytes()) + return val, nil +} + +func (b *Billet) incrementRefAndStore(h util.Uint256, bs []byte) { + key := makeStorageKey(h.BytesBE()) + if b.refcountEnabled { + var ( + err error + data []byte + cnt int32 + ) + // An item may already be in store. + data, err = b.Store.Get(key) + if err == nil { + cnt = int32(binary.LittleEndian.Uint32(data[len(data)-4:])) + } + cnt++ + if len(data) == 0 { + data = append(bs, 0, 0, 0, 0) + } + binary.LittleEndian.PutUint32(data[len(data)-4:], uint32(cnt)) + _ = b.Store.Put(key, data) + } else { + _ = b.Store.Put(key, bs) + } +} + +// Traverse traverses MPT nodes (pre-order) starting from the billet root down +// to its children calling `process` for each serialised node until true is +// returned from `process` function. It also replaces all HashNodes to their +// "unhashed" counterparts until the stop condition is satisfied. +func (b *Billet) Traverse(process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) error { + r, err := b.traverse(b.root, process, ignoreStorageErr) + if err != nil && !errors.Is(err, errStop) { + return err + } + b.root = r + return nil +} + +func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) (Node, error) { + if _, ok := curr.(EmptyNode); ok { + // We're not interested in EmptyNodes, and they do not affect the + // traversal process, thus remain them untouched. + return curr, nil + } + if hn, ok := curr.(*HashNode); ok { + r, err := b.getFromStore(hn.Hash()) + if err != nil { + if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) { + return hn, nil + } + return nil, err + } + return b.traverse(r, process, ignoreStorageErr) + } + bytes := slice.Copy(curr.Bytes()) + if process(curr, bytes) { + return curr, errStop + } + switch n := curr.(type) { + case *LeafNode: + return n, nil + case *BranchNode: + for i := range n.Children { + r, err := b.traverse(n.Children[i], process, ignoreStorageErr) + if err != nil { + if !errors.Is(err, errStop) { + return nil, err + } + n.Children[i] = r + return n, err + } + n.Children[i] = r + } + return n, nil + case *ExtensionNode: + r, err := b.traverse(n.next, process, ignoreStorageErr) + if err != nil && !errors.Is(err, errStop) { + return nil, err + } + n.next = r + return n, err + default: + return nil, ErrNotFound + } +} + +func (b *Billet) getFromStore(h util.Uint256) (Node, error) { + data, err := b.Store.Get(makeStorageKey(h.BytesBE())) + if err != nil { + return nil, err + } + + var n NodeObject + r := io.NewBinReaderFromBuf(data) + n.DecodeBinary(r) + if r.Err != nil { + return nil, r.Err + } + + if b.refcountEnabled { + data = data[:len(data)-4] + } + n.Node.(flushedNode).setCache(data, h) + return n.Node, nil +} diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go index a7399d37d..ceabb31f7 100644 --- a/pkg/core/mpt/helpers.go +++ b/pkg/core/mpt/helpers.go @@ -40,3 +40,12 @@ func toNibbles(path []byte) []byte { } return result } + +// fromNibbles performs operation opposite to toNibbles and does no path validity checks. +func fromNibbles(path []byte) []byte { + result := make([]byte, len(path)/2) + for i := range result { + result[i] = path[2*i]<<4 + path[2*i+1] + } + return result +} diff --git a/pkg/core/mpt/helpers_test.go b/pkg/core/mpt/helpers_test.go new file mode 100644 index 000000000..ee5634bfd --- /dev/null +++ b/pkg/core/mpt/helpers_test.go @@ -0,0 +1,20 @@ +package mpt + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToNibblesFromNibbles(t *testing.T) { + check := func(t *testing.T, expected []byte) { + actual := fromNibbles(toNibbles(expected)) + require.Equal(t, expected, actual) + } + t.Run("empty path", func(t *testing.T) { + check(t, []byte{}) + }) + t.Run("non-empty path", func(t *testing.T) { + check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF}) + }) +} From d67ff30704521bfda25905965f10099c5cdd8db0 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Fri, 30 Jul 2021 16:57:42 +0300 Subject: [PATCH 02/15] core: implement statesync module And support GetMPTData and MPTData P2P commands. --- internal/fakechain/fakechain.go | 72 ++++ pkg/core/blockchain.go | 64 ++++ pkg/core/blockchainer/blockchainer.go | 3 +- pkg/core/blockchainer/blockqueuer.go | 1 + pkg/core/blockchainer/state_root.go | 1 + pkg/core/blockchainer/state_sync.go | 19 + pkg/core/mpt/helpers.go | 35 ++ pkg/core/mpt/helpers_test.go | 47 +++ pkg/core/mpt/proof_test.go | 9 +- pkg/core/native/designate.go | 5 + pkg/core/stateroot/module.go | 19 + pkg/core/statesync/module.go | 440 +++++++++++++++++++++++ pkg/core/statesync/mptpool.go | 119 ++++++ pkg/network/blockqueue.go | 9 +- pkg/network/message.go | 8 +- pkg/network/message_string.go | 10 +- pkg/network/message_test.go | 15 + pkg/network/payload/mptdata.go | 35 ++ pkg/network/payload/mptdata_test.go | 24 ++ pkg/network/payload/mptinventory.go | 32 ++ pkg/network/payload/mptinventory_test.go | 38 ++ pkg/network/server.go | 166 ++++++++- pkg/network/server_test.go | 42 ++- pkg/network/tcp_peer.go | 16 +- 24 files changed, 1197 insertions(+), 32 deletions(-) create mode 100644 pkg/core/blockchainer/state_sync.go create mode 100644 pkg/core/statesync/module.go create mode 100644 pkg/core/statesync/mptpool.go create mode 100644 pkg/network/payload/mptdata.go create mode 100644 pkg/network/payload/mptdata_test.go create mode 100644 pkg/network/payload/mptinventory.go create mode 100644 pkg/network/payload/mptinventory_test.go diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index ef8145858..ca1f0bc6d 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/blockchainer/services" "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/mempool" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/native" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -42,6 +43,13 @@ type FakeChain struct { UtilityTokenBalance *big.Int } +// FakeStateSync implements StateSync interface. +type FakeStateSync struct { + IsActiveFlag bool + IsInitializedFlag bool + InitFunc func(h uint32) error +} + // NewFakeChain returns new FakeChain structure. func NewFakeChain() *FakeChain { return &FakeChain{ @@ -294,6 +302,16 @@ func (chain *FakeChain) GetStateModule() blockchainer.StateRoot { return nil } +// GetStateSyncModule implements Blockchainer interface. +func (chain *FakeChain) GetStateSyncModule() blockchainer.StateSync { + return &FakeStateSync{} +} + +// JumpToState implements Blockchainer interface. +func (chain *FakeChain) JumpToState(module blockchainer.StateSync) error { + panic("TODO") +} + // GetStorageItem implements Blockchainer interface. func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem { panic("TODO") @@ -436,3 +454,57 @@ func (chain *FakeChain) UnsubscribeFromNotifications(ch chan<- *state.Notificati func (chain *FakeChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) { panic("TODO") } + +// AddBlock implements StateSync interface. +func (s *FakeStateSync) AddBlock(block *block.Block) error { + panic("TODO") +} + +// AddHeaders implements StateSync interface. +func (s *FakeStateSync) AddHeaders(...*block.Header) error { + panic("TODO") +} + +// AddMPTNodes implements StateSync interface. +func (s *FakeStateSync) AddMPTNodes([][]byte) error { + panic("TODO") +} + +// BlockHeight implements StateSync interface. +func (s *FakeStateSync) BlockHeight() uint32 { + panic("TODO") +} + +// IsActive implements StateSync interface. +func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag } + +// IsInitialized implements StateSync interface. +func (s *FakeStateSync) IsInitialized() bool { + return s.IsInitializedFlag +} + +// Init implements StateSync interface. +func (s *FakeStateSync) Init(currChainHeight uint32) error { + if s.InitFunc != nil { + return s.InitFunc(currChainHeight) + } + panic("TODO") +} + +// NeedHeaders implements StateSync interface. +func (s *FakeStateSync) NeedHeaders() bool { return false } + +// NeedMPTNodes implements StateSync interface. +func (s *FakeStateSync) NeedMPTNodes() bool { + panic("TODO") +} + +// Traverse implements StateSync interface. +func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { + panic("TODO") +} + +// GetJumpHeight implements StateSync interface. +func (s *FakeStateSync) GetJumpHeight() (uint32, error) { + panic("TODO") +} diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 1a20308f2..18fdf786f 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -23,6 +23,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/native/noderoles" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/stateroot" + "github.com/nspcc-dev/neo-go/pkg/core/statesync" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -409,6 +410,64 @@ func (bc *Blockchain) init() error { return bc.updateExtensibleWhitelist(bHeight) } +// JumpToState is an atomic operation that changes Blockchain state to the one +// specified by the state sync point p. All the data needed for the jump must be +// collected by the state sync module. +func (bc *Blockchain) JumpToState(module blockchainer.StateSync) error { + bc.lock.Lock() + defer bc.lock.Unlock() + + p, err := module.GetJumpHeight() + if err != nil { + return fmt.Errorf("failed to get jump height: %w", err) + } + if p+1 >= uint32(len(bc.headerHashes)) { + return fmt.Errorf("invalid state sync point") + } + + bc.log.Info("jumping to state sync point", zap.Uint32("state sync point", p)) + + block, err := bc.dao.GetBlock(bc.headerHashes[p]) + if err != nil { + return fmt.Errorf("failed to get current block: %w", err) + } + err = bc.dao.StoreAsCurrentBlock(block, nil) + if err != nil { + return fmt.Errorf("failed to store current block: %w", err) + } + bc.topBlock.Store(block) + atomic.StoreUint32(&bc.blockHeight, p) + atomic.StoreUint32(&bc.persistedHeight, p) + + block, err = bc.dao.GetBlock(bc.headerHashes[p+1]) + if err != nil { + return fmt.Errorf("failed to get block to init MPT: %w", err) + } + if err = bc.stateRoot.JumpToState(&state.MPTRoot{ + Index: p, + Root: block.PrevStateRoot, + }, bc.config.KeepOnlyLatestState); err != nil { + return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err) + } + + err = bc.contracts.NEO.InitializeCache(bc, bc.dao) + if err != nil { + return fmt.Errorf("can't init cache for NEO native contract: %w", err) + } + err = bc.contracts.Management.InitializeCache(bc.dao) + if err != nil { + return fmt.Errorf("can't init cache for Management native contract: %w", err) + } + bc.contracts.Designate.InitializeCache() + + if err := bc.updateExtensibleWhitelist(p); err != nil { + return fmt.Errorf("failed to update extensible whitelist: %w", err) + } + + updateBlockHeightMetric(p) + return nil +} + // Run runs chain loop, it needs to be run as goroutine and executing it is // critical for correct Blockchain operation. func (bc *Blockchain) Run() { @@ -696,6 +755,11 @@ func (bc *Blockchain) GetStateModule() blockchainer.StateRoot { return bc.stateRoot } +// GetStateSyncModule returns new state sync service instance. +func (bc *Blockchain) GetStateSyncModule() blockchainer.StateSync { + return statesync.NewModule(bc, bc.log, bc.dao) +} + // storeBlock performs chain update using the block given, it executes all // transactions with all appropriate side-effects and updates Blockchain state. // This is the only way to change Blockchain state. diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go index a0d5af6f4..9f9ae4e03 100644 --- a/pkg/core/blockchainer/blockchainer.go +++ b/pkg/core/blockchainer/blockchainer.go @@ -21,7 +21,6 @@ import ( type Blockchainer interface { ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction GetConfig() config.ProtocolConfiguration - AddHeaders(...*block.Header) error Blockqueuer // Blockqueuer interface CalculateClaimable(h util.Uint160, endHeight uint32) (*big.Int, error) Close() @@ -56,10 +55,12 @@ type Blockchainer interface { GetStandByCommittee() keys.PublicKeys GetStandByValidators() keys.PublicKeys GetStateModule() StateRoot + GetStateSyncModule() StateSync GetStorageItem(id int32, key []byte) state.StorageItem GetStorageItems(id int32) (map[string]state.StorageItem, error) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) + JumpToState(module StateSync) error SetOracle(service services.Oracle) mempool.Feer // fee interface ManagementContractHash() util.Uint160 diff --git a/pkg/core/blockchainer/blockqueuer.go b/pkg/core/blockchainer/blockqueuer.go index 384ccd489..bae45281f 100644 --- a/pkg/core/blockchainer/blockqueuer.go +++ b/pkg/core/blockchainer/blockqueuer.go @@ -5,5 +5,6 @@ import "github.com/nspcc-dev/neo-go/pkg/core/block" // Blockqueuer is an interface for blockqueue. type Blockqueuer interface { AddBlock(block *block.Block) error + AddHeaders(...*block.Header) error BlockHeight() uint32 } diff --git a/pkg/core/blockchainer/state_root.go b/pkg/core/blockchainer/state_root.go index 979e15963..64f7d8b7c 100644 --- a/pkg/core/blockchainer/state_root.go +++ b/pkg/core/blockchainer/state_root.go @@ -9,6 +9,7 @@ import ( // StateRoot represents local state root module. type StateRoot interface { AddStateRoot(root *state.MPTRoot) error + CurrentLocalHeight() uint32 CurrentLocalStateRoot() util.Uint256 CurrentValidatedHeight() uint32 GetStateProof(root util.Uint256, key []byte) ([][]byte, error) diff --git a/pkg/core/blockchainer/state_sync.go b/pkg/core/blockchainer/state_sync.go new file mode 100644 index 000000000..66643020c --- /dev/null +++ b/pkg/core/blockchainer/state_sync.go @@ -0,0 +1,19 @@ +package blockchainer + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// StateSync represents state sync module. +type StateSync interface { + AddMPTNodes([][]byte) error + Blockqueuer // Blockqueuer interface + Init(currChainHeight uint32) error + IsActive() bool + IsInitialized() bool + GetJumpHeight() (uint32, error) + NeedHeaders() bool + NeedMPTNodes() bool + Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error +} diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go index ceabb31f7..63c02c089 100644 --- a/pkg/core/mpt/helpers.go +++ b/pkg/core/mpt/helpers.go @@ -1,5 +1,7 @@ package mpt +import "github.com/nspcc-dev/neo-go/pkg/util" + // lcp returns longest common prefix of a and b. // Note: it does no allocations. func lcp(a, b []byte) []byte { @@ -49,3 +51,36 @@ func fromNibbles(path []byte) []byte { } return result } + +// GetChildrenPaths returns a set of paths to node's children who are non-empty HashNodes +// based on the node's path. +func GetChildrenPaths(path []byte, node Node) map[util.Uint256][][]byte { + res := make(map[util.Uint256][][]byte) + switch n := node.(type) { + case *LeafNode, *HashNode, EmptyNode: + return nil + case *BranchNode: + for i, child := range n.Children { + if child.Type() == HashT { + cPath := make([]byte, len(path), len(path)+1) + copy(cPath, path) + if i != lastChild { + cPath = append(cPath, byte(i)) + } + paths := res[child.Hash()] + paths = append(paths, cPath) + res[child.Hash()] = paths + } + } + case *ExtensionNode: + if n.next.Type() == HashT { + cPath := make([]byte, len(path)+len(n.key)) + copy(cPath, path) + copy(cPath[len(path):], n.key) + res[n.next.Hash()] = [][]byte{cPath} + } + default: + panic("unknown Node type") + } + return res +} diff --git a/pkg/core/mpt/helpers_test.go b/pkg/core/mpt/helpers_test.go index ee5634bfd..28181dcc2 100644 --- a/pkg/core/mpt/helpers_test.go +++ b/pkg/core/mpt/helpers_test.go @@ -3,6 +3,7 @@ package mpt import ( "testing" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" ) @@ -18,3 +19,49 @@ func TestToNibblesFromNibbles(t *testing.T) { check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF}) }) } + +func TestGetChildrenPaths(t *testing.T) { + h1 := NewHashNode(util.Uint256{1, 2, 3}) + h2 := NewHashNode(util.Uint256{4, 5, 6}) + h3 := NewHashNode(util.Uint256{7, 8, 9}) + l := NewLeafNode([]byte{1, 2, 3}) + ext1 := NewExtensionNode([]byte{8, 9}, h1) + ext2 := NewExtensionNode([]byte{7, 6}, l) + branch := NewBranchNode() + branch.Children[3] = h1 + branch.Children[5] = l + branch.Children[6] = h1 // 3-th and 6-th children have the same hash + branch.Children[7] = h3 + branch.Children[lastChild] = h2 + testCases := map[string]struct { + node Node + expected map[util.Uint256][][]byte + }{ + "Hash": {h1, nil}, + "Leaf": {l, nil}, + "Extension with next Hash": {ext1, map[util.Uint256][][]byte{h1.Hash(): {ext1.key}}}, + "Extension with next non-Hash": {ext2, map[util.Uint256][][]byte{}}, + "Branch": {branch, map[util.Uint256][][]byte{ + h1.Hash(): {{0x03}, {0x06}}, + h2.Hash(): {{}}, + h3.Hash(): {{0x07}}, + }}, + } + parentPath := []byte{4, 5, 6} + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + require.Equal(t, testCase.expected, GetChildrenPaths([]byte{}, testCase.node)) + if testCase.expected != nil { + expectedWithPrefix := make(map[util.Uint256][][]byte, len(testCase.expected)) + for h, paths := range testCase.expected { + var res [][]byte + for _, path := range paths { + res = append(res, append(parentPath, path...)) + } + expectedWithPrefix[h] = res + } + require.Equal(t, expectedWithPrefix, GetChildrenPaths(parentPath, testCase.node)) + } + }) + } +} diff --git a/pkg/core/mpt/proof_test.go b/pkg/core/mpt/proof_test.go index 75a76408f..d733fb6d0 100644 --- a/pkg/core/mpt/proof_test.go +++ b/pkg/core/mpt/proof_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" ) -func newProofTrie(t *testing.T) *Trie { +func newProofTrie(t *testing.T, missingHashNode bool) *Trie { l := NewLeafNode([]byte("somevalue")) e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l) l2 := NewLeafNode([]byte("invalid")) @@ -20,11 +20,14 @@ func newProofTrie(t *testing.T) *Trie { require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2"))) tr.putToStore(l) tr.putToStore(e) + if !missingHashNode { + tr.putToStore(l2) + } return tr } func TestTrie_GetProof(t *testing.T) { - tr := newProofTrie(t) + tr := newProofTrie(t, true) t.Run("MissingKey", func(t *testing.T) { _, err := tr.GetProof([]byte{0x12}) @@ -43,7 +46,7 @@ func TestTrie_GetProof(t *testing.T) { } func TestVerifyProof(t *testing.T) { - tr := newProofTrie(t) + tr := newProofTrie(t, true) t.Run("Simple", func(t *testing.T) { proof, err := tr.GetProof([]byte{0x12, 0x32}) diff --git a/pkg/core/native/designate.go b/pkg/core/native/designate.go index 8c75eebc8..844690a3d 100644 --- a/pkg/core/native/designate.go +++ b/pkg/core/native/designate.go @@ -353,3 +353,8 @@ func (s *Designate) getRole(item stackitem.Item) (noderoles.Role, bool) { u := bi.Uint64() return noderoles.Role(u), u <= math.MaxUint8 && s.isValidRole(noderoles.Role(u)) } + +// InitializeCache invalidates native Designate cache. +func (s *Designate) InitializeCache() { + s.rolesChangedFlag.Store(true) +} diff --git a/pkg/core/stateroot/module.go b/pkg/core/stateroot/module.go index 530dbf583..20a3396e9 100644 --- a/pkg/core/stateroot/module.go +++ b/pkg/core/stateroot/module.go @@ -114,6 +114,25 @@ func (s *Module) Init(height uint32, enableRefCount bool) error { return nil } +// JumpToState performs jump to the state specified by given stateroot index. +func (s *Module) JumpToState(sr *state.MPTRoot, enableRefCount bool) error { + if err := s.addLocalStateRoot(s.Store, sr); err != nil { + return fmt.Errorf("failed to store local state root: %w", err) + } + + data := make([]byte, 4) + binary.LittleEndian.PutUint32(data, sr.Index) + if err := s.Store.Put([]byte{byte(storage.DataMPT), prefixValidated}, data); err != nil { + return fmt.Errorf("failed to store validated height: %w", err) + } + s.validatedHeight.Store(sr.Index) + + s.currentLocal.Store(sr.Root) + s.localHeight.Store(sr.Index) + s.mpt = mpt.NewTrie(mpt.NewHashNode(sr.Root), enableRefCount, s.Store) + return nil +} + // AddMPTBatch updates using provided batch. func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) { mpt := *s.mpt diff --git a/pkg/core/statesync/module.go b/pkg/core/statesync/module.go new file mode 100644 index 000000000..24f38f14e --- /dev/null +++ b/pkg/core/statesync/module.go @@ -0,0 +1,440 @@ +/* +Package statesync implements module for the P2P state synchronisation process. The +module manages state synchronisation for non-archival nodes which are joining the +network and don't have the ability to resync from the genesis block. + +Given the currently available state synchronisation point P, sate sync process +includes the following stages: + +1. Fetching headers starting from height 0 up to P+1. +2. Fetching MPT nodes for height P stating from the corresponding state root. +3. Fetching blocks starting from height P-MaxTraceableBlocks (or 0) up to P. + +Steps 2 and 3 are being performed in parallel. Once all the data are collected +and stored in the db, an atomic state jump is occurred to the state sync point P. +Further node operation process is performed using standard sync mechanism until +the node reaches synchronised state. +*/ +package statesync + +import ( + "encoding/hex" + "errors" + "fmt" + "sync" + + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" + "github.com/nspcc-dev/neo-go/pkg/core/dao" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" + "go.uber.org/zap" +) + +// stateSyncStage is a type of state synchronisation stage. +type stateSyncStage uint8 + +const ( + // inactive means that state exchange is disabled by the protocol configuration. + // Can't be combined with other states. + inactive stateSyncStage = 1 << iota + // none means that state exchange is enabled in the configuration, but + // initialisation of the state sync module wasn't yet performed, i.e. + // (*Module).Init wasn't called. Can't be combined with other states. + none + // initialized means that (*Module).Init was called, but other sync stages + // are not yet reached (i.e. that headers are requested, but not yet fetched). + // Can't be combined with other states. + initialized + // headersSynced means that headers for the current state sync point are fetched. + // May be combined with mptSynced and/or blocksSynced. + headersSynced + // mptSynced means that MPT nodes for the current state sync point are fetched. + // Always combined with headersSynced; may be combined with blocksSynced. + mptSynced + // blocksSynced means that blocks up to the current state sync point are stored. + // Always combined with headersSynced; may be combined with mptSynced. + blocksSynced +) + +// Module represents state sync module and aimed to gather state-related data to +// perform an atomic state jump. +type Module struct { + lock sync.RWMutex + log *zap.Logger + + // syncPoint is the state synchronisation point P we're currently working against. + syncPoint uint32 + // syncStage is the stage of the sync process. + syncStage stateSyncStage + // syncInterval is the delta between two adjacent state sync points. + syncInterval uint32 + // blockHeight is the index of the latest stored block. + blockHeight uint32 + + dao *dao.Simple + bc blockchainer.Blockchainer + mptpool *Pool + + billet *mpt.Billet +} + +// NewModule returns new instance of statesync module. +func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple) *Module { + if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) { + return &Module{ + dao: s, + bc: bc, + syncStage: inactive, + } + } + return &Module{ + dao: s, + bc: bc, + log: log, + syncInterval: uint32(bc.GetConfig().StateSyncInterval), + mptpool: NewPool(), + syncStage: none, + } +} + +// Init initializes state sync module for the current chain's height with given +// callback for MPT nodes requests. +func (s *Module) Init(currChainHeight uint32) error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.syncStage != none { + return errors.New("already initialized or inactive") + } + + p := (currChainHeight / s.syncInterval) * s.syncInterval + if p < 2*s.syncInterval { + // chain is too low to start state exchange process, use the standard sync mechanism + s.syncStage = inactive + return nil + } + pOld, err := s.dao.GetStateSyncPoint() + if err == nil && pOld >= p-s.syncInterval { + // old point is still valid, so try to resync states for this point. + p = pOld + } else if s.bc.BlockHeight() > p-2*s.syncInterval { + // chain has already been synchronised up to old state sync point and regular blocks processing was started + s.syncStage = inactive + return nil + } + + s.syncPoint = p + err = s.dao.PutStateSyncPoint(p) + if err != nil { + return fmt.Errorf("failed to store state synchronisation point %d: %w", p, err) + } + s.syncStage = initialized + s.log.Info("try to sync state for the latest state synchronisation point", + zap.Uint32("point", p), + zap.Uint32("evaluated chain's blockHeight", currChainHeight)) + + // check headers sync state first + ltstHeaderHeight := s.bc.HeaderHeight() + if ltstHeaderHeight > p { + s.syncStage = headersSynced + s.log.Info("headers are in sync", + zap.Uint32("headerHeight", s.bc.HeaderHeight())) + } + + // check blocks sync state + s.blockHeight = s.getLatestSavedBlock(p) + if s.blockHeight >= p { + s.syncStage |= blocksSynced + s.log.Info("blocks are in sync", + zap.Uint32("blockHeight", s.blockHeight)) + } + + // check MPT sync state + if s.blockHeight > p { + s.syncStage |= mptSynced + s.log.Info("MPT is in sync", + zap.Uint32("stateroot height", s.bc.GetStateModule().CurrentLocalHeight())) + } else if s.syncStage&headersSynced != 0 { + header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(p + 1))) + if err != nil { + return fmt.Errorf("failed to get header to initialize MPT billet: %w", err) + } + s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store) + s.log.Info("MPT billet initialized", + zap.Uint32("height", s.syncPoint), + zap.String("state root", header.PrevStateRoot.StringBE())) + pool := NewPool() + pool.Add(header.PrevStateRoot, []byte{}) + err = s.billet.Traverse(func(n mpt.Node, _ []byte) bool { + nPaths, ok := pool.TryGet(n.Hash()) + if !ok { + // if this situation occurs, then it's a bug in MPT pool or Traverse. + panic("failed to get MPT node from the pool") + } + pool.Remove(n.Hash()) + childrenPaths := make(map[util.Uint256][][]byte) + for _, path := range nPaths { + nChildrenPaths := mpt.GetChildrenPaths(path, n) + for hash, paths := range nChildrenPaths { + childrenPaths[hash] = append(childrenPaths[hash], paths...) // it's OK to have duplicates, they'll be handled by mempool + } + } + pool.Update(nil, childrenPaths) + return false + }, true) + if err != nil { + return fmt.Errorf("failed to traverse MPT while initialization: %w", err) + } + s.mptpool.Update(nil, pool.GetAll()) + if s.mptpool.Count() == 0 { + s.syncStage |= mptSynced + s.log.Info("MPT is in sync", + zap.Uint32("stateroot height", p)) + } + } + + if s.syncStage == headersSynced|blocksSynced|mptSynced { + s.log.Info("state is in sync, starting regular blocks processing") + s.syncStage = inactive + } + return nil +} + +// getLatestSavedBlock returns either current block index (if it's still relevant +// to continue state sync process) or H-1 where H is the index of the earliest +// block that should be saved next. +func (s *Module) getLatestSavedBlock(p uint32) uint32 { + var result uint32 + mtb := s.bc.GetConfig().MaxTraceableBlocks + if p > mtb { + result = p - mtb + } + storedH, err := s.dao.GetStateSyncCurrentBlockHeight() + if err == nil && storedH > result { + result = storedH + } + actualH := s.bc.BlockHeight() + if actualH > result { + result = actualH + } + return result +} + +// AddHeaders validates and adds specified headers to the chain. +func (s *Module) AddHeaders(hdrs ...*block.Header) error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.syncStage != initialized { + return errors.New("headers were not requested") + } + + hdrsErr := s.bc.AddHeaders(hdrs...) + if s.bc.HeaderHeight() > s.syncPoint { + s.syncStage = headersSynced + s.log.Info("headers for state sync are fetched", + zap.Uint32("header height", s.bc.HeaderHeight())) + + header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint) + 1)) + if err != nil { + s.log.Fatal("failed to get header to initialize MPT billet", + zap.Uint32("height", s.syncPoint+1), + zap.Error(err)) + } + s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store) + s.mptpool.Add(header.PrevStateRoot, []byte{}) + s.log.Info("MPT billet initialized", + zap.Uint32("height", s.syncPoint), + zap.String("state root", header.PrevStateRoot.StringBE())) + } + return hdrsErr +} + +// AddBlock verifies and saves block skipping executable scripts. +func (s *Module) AddBlock(block *block.Block) error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.syncStage&headersSynced == 0 || s.syncStage&blocksSynced != 0 { + return nil + } + + if s.blockHeight == s.syncPoint { + return nil + } + expectedHeight := s.blockHeight + 1 + if expectedHeight != block.Index { + return fmt.Errorf("expected %d, got %d: invalid block index", expectedHeight, block.Index) + } + if s.bc.GetConfig().StateRootInHeader != block.StateRootEnabled { + return fmt.Errorf("stateroot setting mismatch: %v != %v", s.bc.GetConfig().StateRootInHeader, block.StateRootEnabled) + } + if s.bc.GetConfig().VerifyBlocks { + merkle := block.ComputeMerkleRoot() + if !block.MerkleRoot.Equals(merkle) { + return errors.New("invalid block: MerkleRoot mismatch") + } + } + cache := s.dao.GetWrapped() + writeBuf := io.NewBufBinWriter() + if err := cache.StoreAsBlock(block, writeBuf); err != nil { + return err + } + writeBuf.Reset() + + err := cache.PutStateSyncCurrentBlockHeight(block.Index) + if err != nil { + return fmt.Errorf("failed to store current block height: %w", err) + } + + for _, tx := range block.Transactions { + if err := cache.StoreAsTransaction(tx, block.Index, writeBuf); err != nil { + return err + } + writeBuf.Reset() + } + + _, err = cache.Persist() + if err != nil { + return fmt.Errorf("failed to persist results: %w", err) + } + s.blockHeight = block.Index + if s.blockHeight == s.syncPoint { + s.syncStage |= blocksSynced + s.log.Info("blocks are in sync", + zap.Uint32("blockHeight", s.blockHeight)) + s.checkSyncIsCompleted() + } + return nil +} + +// AddMPTNodes tries to add provided set of MPT nodes to the MPT billet if they are +// not yet collected. +func (s *Module) AddMPTNodes(nodes [][]byte) error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.syncStage&headersSynced == 0 || s.syncStage&mptSynced != 0 { + return errors.New("MPT nodes were not requested") + } + + for _, nBytes := range nodes { + var n mpt.NodeObject + r := io.NewBinReaderFromBuf(nBytes) + n.DecodeBinary(r) + if r.Err != nil { + return fmt.Errorf("failed to decode MPT node: %w", r.Err) + } + nPaths, ok := s.mptpool.TryGet(n.Hash()) + if !ok { + // it can easily happen after receiving the same data from different peers. + return nil + } + + var childrenPaths = make(map[util.Uint256][][]byte) + for _, path := range nPaths { + err := s.billet.RestoreHashNode(path, n.Node) + if err != nil { + return fmt.Errorf("failed to add MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err) + } + for h, paths := range mpt.GetChildrenPaths(path, n.Node) { + childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool + } + } + + s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths) + } + if s.mptpool.Count() == 0 { + s.syncStage |= mptSynced + s.log.Info("MPT is in sync", + zap.Uint32("height", s.syncPoint)) + s.checkSyncIsCompleted() + } + return nil +} + +// checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1 +// height are fetched, blocks up to P height are stored and MPT nodes for P height are stored. +// If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller +// should take care of it. +func (s *Module) checkSyncIsCompleted() { + if s.syncStage != headersSynced|mptSynced|blocksSynced { + return + } + s.log.Info("state is in sync", + zap.Uint32("state sync point", s.syncPoint)) + err := s.bc.JumpToState(s) + if err != nil { + s.log.Fatal("failed to jump to the latest state sync point", zap.Error(err)) + } + s.syncStage = inactive + s.dispose() +} + +func (s *Module) dispose() { + s.billet = nil +} + +// BlockHeight returns index of the last stored block. +func (s *Module) BlockHeight() uint32 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.blockHeight +} + +// IsActive tells whether state sync module is on and still gathering state +// synchronisation data (headers, blocks or MPT nodes). +func (s *Module) IsActive() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return !(s.syncStage == inactive || (s.syncStage == headersSynced|mptSynced|blocksSynced)) +} + +// IsInitialized tells whether state sync module does not require initialization. +// If `false` is returned then Init can be safely called. +func (s *Module) IsInitialized() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.syncStage != none +} + +// NeedHeaders tells whether the module hasn't completed headers synchronisation. +func (s *Module) NeedHeaders() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.syncStage == initialized +} + +// NeedMPTNodes returns whether the module hasn't completed MPT synchronisation. +func (s *Module) NeedMPTNodes() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.syncStage&headersSynced != 0 && s.syncStage&mptSynced == 0 +} + +// Traverse traverses local MPT nodes starting from the specified root down to its +// children calling `process` for each serialised node until stop condition is satisfied. +func (s *Module) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { + s.lock.RLock() + defer s.lock.RUnlock() + + b := mpt.NewBillet(root, s.bc.GetConfig().KeepOnlyLatestState, storage.NewMemCachedStore(s.dao.Store)) + return b.Traverse(process, false) +} + +// GetJumpHeight returns state sync point to jump to. It is not protected by mutex and should be called +// under the module lock. +func (s *Module) GetJumpHeight() (uint32, error) { + if s.syncStage != headersSynced|mptSynced|blocksSynced { + return 0, errors.New("state sync module has wong state to perform state jump") + } + return s.syncPoint, nil +} diff --git a/pkg/core/statesync/mptpool.go b/pkg/core/statesync/mptpool.go new file mode 100644 index 000000000..23610d97a --- /dev/null +++ b/pkg/core/statesync/mptpool.go @@ -0,0 +1,119 @@ +package statesync + +import ( + "bytes" + "sort" + "sync" + + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// Pool stores unknown MPT nodes along with the corresponding paths (single node is +// allowed to have multiple MPT paths). +type Pool struct { + lock sync.RWMutex + hashes map[util.Uint256][][]byte +} + +// NewPool returns new MPT node hashes pool. +func NewPool() *Pool { + return &Pool{ + hashes: make(map[util.Uint256][][]byte), + } +} + +// ContainsKey checks if MPT node with the specified hash is in the Pool. +func (mp *Pool) ContainsKey(hash util.Uint256) bool { + mp.lock.RLock() + defer mp.lock.RUnlock() + + _, ok := mp.hashes[hash] + return ok +} + +// TryGet returns a set of MPT paths for the specified HashNode. +func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) { + mp.lock.RLock() + defer mp.lock.RUnlock() + + paths, ok := mp.hashes[hash] + return paths, ok +} + +// GetAll returns all MPT nodes with the corresponding paths from the pool. +func (mp *Pool) GetAll() map[util.Uint256][][]byte { + mp.lock.RLock() + defer mp.lock.RUnlock() + + return mp.hashes +} + +// Remove removes MPT node from the pool by the specified hash. +func (mp *Pool) Remove(hash util.Uint256) { + mp.lock.Lock() + defer mp.lock.Unlock() + + delete(mp.hashes, hash) +} + +// Add adds path to the set of paths for the specified node. +func (mp *Pool) Add(hash util.Uint256, path []byte) { + mp.lock.Lock() + defer mp.lock.Unlock() + + mp.addPaths(hash, [][]byte{path}) +} + +// Update is an atomic operation and removes/adds specified nodes from/to the pool. +func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) { + mp.lock.Lock() + defer mp.lock.Unlock() + + for h, paths := range remove { + old := mp.hashes[h] + for _, path := range paths { + i := sort.Search(len(old), func(i int) bool { + return bytes.Compare(old[i], path) >= 0 + }) + if i < len(old) && bytes.Equal(old[i], path) { + old = append(old[:i], old[i+1:]...) + } + } + if len(old) == 0 { + delete(mp.hashes, h) + } else { + mp.hashes[h] = old + } + } + for h, paths := range add { + mp.addPaths(h, paths) + } +} + +// addPaths adds set of the specified node paths to the pool. +func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) { + old := mp.hashes[nodeHash] + for _, path := range paths { + i := sort.Search(len(old), func(i int) bool { + return bytes.Compare(old[i], path) >= 0 + }) + if i < len(old) && bytes.Equal(old[i], path) { + // then path is already added + continue + } + old = append(old, path) + if i != len(old)-1 { + copy(old[i+1:], old[i:]) + old[i] = path + } + } + mp.hashes[nodeHash] = old +} + +// Count returns the number of nodes in the pool. +func (mp *Pool) Count() int { + mp.lock.RLock() + defer mp.lock.RUnlock() + + return len(mp.hashes) +} diff --git a/pkg/network/blockqueue.go b/pkg/network/blockqueue.go index 8a21cab1b..7cc796886 100644 --- a/pkg/network/blockqueue.go +++ b/pkg/network/blockqueue.go @@ -4,6 +4,7 @@ import ( "github.com/Workiva/go-datastructures/queue" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" + "go.uber.org/atomic" "go.uber.org/zap" ) @@ -13,6 +14,7 @@ type blockQueue struct { checkBlocks chan struct{} chain blockchainer.Blockqueuer relayF func(*block.Block) + discarded *atomic.Bool } const ( @@ -32,6 +34,7 @@ func newBlockQueue(capacity int, bc blockchainer.Blockqueuer, log *zap.Logger, r checkBlocks: make(chan struct{}, 1), chain: bc, relayF: relayer, + discarded: atomic.NewBool(false), } } @@ -91,8 +94,10 @@ func (bq *blockQueue) putBlock(block *block.Block) error { } func (bq *blockQueue) discard() { - close(bq.checkBlocks) - bq.queue.Dispose() + if bq.discarded.CAS(false, true) { + close(bq.checkBlocks) + bq.queue.Dispose() + } } func (bq *blockQueue) length() int { diff --git a/pkg/network/message.go b/pkg/network/message.go index cbdaa554f..fb605c8a7 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -71,6 +71,8 @@ const ( CMDBlock = CommandType(payload.BlockType) CMDExtensible = CommandType(payload.ExtensibleType) CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType) + CMDGetMPTData CommandType = 0x51 // 0x5.. commands are used for extensions (P2PNotary, state exchange cmds) + CMDMPTData CommandType = 0x52 CMDReject CommandType = 0x2f // SPV protocol. @@ -136,6 +138,10 @@ func (m *Message) decodePayload() error { p = &payload.Version{} case CMDInv, CMDGetData: p = &payload.Inventory{} + case CMDGetMPTData: + p = &payload.MPTInventory{} + case CMDMPTData: + p = &payload.MPTData{} case CMDAddr: p = &payload.AddressList{} case CMDBlock: @@ -221,7 +227,7 @@ func (m *Message) tryCompressPayload() error { if m.Flags&Compressed == 0 { switch m.Payload.(type) { case *payload.Headers, *payload.MerkleBlock, payload.NullPayload, - *payload.Inventory: + *payload.Inventory, *payload.MPTInventory: break default: size := len(compressedPayload) diff --git a/pkg/network/message_string.go b/pkg/network/message_string.go index 7da007079..2ebdacd9b 100644 --- a/pkg/network/message_string.go +++ b/pkg/network/message_string.go @@ -26,6 +26,8 @@ func _() { _ = x[CMDBlock-44] _ = x[CMDExtensible-46] _ = x[CMDP2PNotaryRequest-80] + _ = x[CMDGetMPTData-81] + _ = x[CMDMPTData-82] _ = x[CMDReject-47] _ = x[CMDFilterLoad-48] _ = x[CMDFilterAdd-49] @@ -44,7 +46,7 @@ const ( _CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear" _CommandType_name_7 = "CMDMerkleBlock" _CommandType_name_8 = "CMDAlert" - _CommandType_name_9 = "CMDP2PNotaryRequest" + _CommandType_name_9 = "CMDP2PNotaryRequestCMDGetMPTDataCMDMPTData" ) var ( @@ -55,6 +57,7 @@ var ( _CommandType_index_4 = [...]uint8{0, 12, 22} _CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58} _CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61} + _CommandType_index_9 = [...]uint8{0, 19, 32, 42} ) func (i CommandType) String() string { @@ -83,8 +86,9 @@ func (i CommandType) String() string { return _CommandType_name_7 case i == 64: return _CommandType_name_8 - case i == 80: - return _CommandType_name_9 + case 80 <= i && i <= 82: + i -= 80 + return _CommandType_name_9[_CommandType_index_9[i]:_CommandType_index_9[i+1]] default: return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/network/message_test.go b/pkg/network/message_test.go index c4976a702..38b620189 100644 --- a/pkg/network/message_test.go +++ b/pkg/network/message_test.go @@ -258,6 +258,21 @@ func TestEncodeDecodeNotFound(t *testing.T) { }) } +func TestEncodeDecodeGetMPTData(t *testing.T) { + testEncodeDecode(t, CMDGetMPTData, &payload.MPTInventory{ + Hashes: []util.Uint256{ + {1, 2, 3}, + {4, 5, 6}, + }, + }) +} + +func TestEncodeDecodeMPTData(t *testing.T) { + testEncodeDecode(t, CMDMPTData, &payload.MPTData{ + Nodes: [][]byte{{1, 2, 3}, {4, 5, 6}}, + }) +} + func TestInvalidMessages(t *testing.T) { t.Run("CMDBlock, empty payload", func(t *testing.T) { testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{}) diff --git a/pkg/network/payload/mptdata.go b/pkg/network/payload/mptdata.go new file mode 100644 index 000000000..ba607a86b --- /dev/null +++ b/pkg/network/payload/mptdata.go @@ -0,0 +1,35 @@ +package payload + +import ( + "errors" + + "github.com/nspcc-dev/neo-go/pkg/io" +) + +// MPTData represents the set of serialized MPT nodes. +type MPTData struct { + Nodes [][]byte +} + +// EncodeBinary implements io.Serializable. +func (d *MPTData) EncodeBinary(w *io.BinWriter) { + w.WriteVarUint(uint64(len(d.Nodes))) + for _, n := range d.Nodes { + w.WriteVarBytes(n) + } +} + +// DecodeBinary implements io.Serializable. +func (d *MPTData) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + if sz == 0 { + r.Err = errors.New("empty MPT nodes list") + return + } + for i := uint64(0); i < sz; i++ { + d.Nodes = append(d.Nodes, r.ReadVarBytes()) + if r.Err != nil { + return + } + } +} diff --git a/pkg/network/payload/mptdata_test.go b/pkg/network/payload/mptdata_test.go new file mode 100644 index 000000000..d9db7bff6 --- /dev/null +++ b/pkg/network/payload/mptdata_test.go @@ -0,0 +1,24 @@ +package payload + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/stretchr/testify/require" +) + +func TestMPTData_EncodeDecodeBinary(t *testing.T) { + t.Run("empty", func(t *testing.T) { + d := new(MPTData) + bytes, err := testserdes.EncodeBinary(d) + require.NoError(t, err) + require.Error(t, testserdes.DecodeBinary(bytes, new(MPTData))) + }) + + t.Run("good", func(t *testing.T) { + d := &MPTData{ + Nodes: [][]byte{{}, {1}, {1, 2, 3}}, + } + testserdes.EncodeDecodeBinary(t, d, new(MPTData)) + }) +} diff --git a/pkg/network/payload/mptinventory.go b/pkg/network/payload/mptinventory.go new file mode 100644 index 000000000..66aaf5c02 --- /dev/null +++ b/pkg/network/payload/mptinventory.go @@ -0,0 +1,32 @@ +package payload + +import ( + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MaxMPTHashesCount is the maximum number of requested MPT nodes hashes. +const MaxMPTHashesCount = 32 + +// MPTInventory payload. +type MPTInventory struct { + // A list of requested MPT nodes hashes. + Hashes []util.Uint256 +} + +// NewMPTInventory return a pointer to an MPTInventory. +func NewMPTInventory(hashes []util.Uint256) *MPTInventory { + return &MPTInventory{ + Hashes: hashes, + } +} + +// DecodeBinary implements Serializable interface. +func (p *MPTInventory) DecodeBinary(br *io.BinReader) { + br.ReadArray(&p.Hashes, MaxMPTHashesCount) +} + +// EncodeBinary implements Serializable interface. +func (p *MPTInventory) EncodeBinary(bw *io.BinWriter) { + bw.WriteArray(p.Hashes) +} diff --git a/pkg/network/payload/mptinventory_test.go b/pkg/network/payload/mptinventory_test.go new file mode 100644 index 000000000..a0c052a67 --- /dev/null +++ b/pkg/network/payload/mptinventory_test.go @@ -0,0 +1,38 @@ +package payload + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestMPTInventory_EncodeDecodeBinary(t *testing.T) { + t.Run("empty", func(t *testing.T) { + testserdes.EncodeDecodeBinary(t, NewMPTInventory([]util.Uint256{}), new(MPTInventory)) + }) + + t.Run("good", func(t *testing.T) { + inv := NewMPTInventory([]util.Uint256{{1, 2, 3}, {2, 3, 4}}) + testserdes.EncodeDecodeBinary(t, inv, new(MPTInventory)) + }) + + t.Run("too large", func(t *testing.T) { + check := func(t *testing.T, count int, fail bool) { + h := make([]util.Uint256, count) + for i := range h { + h[i] = util.Uint256{1, 2, 3} + } + if fail { + bytes, err := testserdes.EncodeBinary(NewMPTInventory(h)) + require.NoError(t, err) + require.Error(t, testserdes.DecodeBinary(bytes, new(MPTInventory))) + } else { + testserdes.EncodeDecodeBinary(t, NewMPTInventory(h), new(MPTInventory)) + } + } + check(t, MaxMPTHashesCount, false) + check(t, MaxMPTHashesCount+1, true) + }) +} diff --git a/pkg/network/server.go b/pkg/network/server.go index 457add0b0..b9f53ed0e 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -7,6 +7,7 @@ import ( "fmt" mrand "math/rand" "net" + "sort" "strconv" "sync" "time" @@ -17,7 +18,9 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/extpool" "github.com/nspcc-dev/neo-go/pkg/network/payload" @@ -67,6 +70,7 @@ type ( discovery Discoverer chain blockchainer.Blockchainer bQueue *blockQueue + bSyncQueue *blockQueue consensus consensus.Service mempool *mempool.Pool notaryRequestPool *mempool.Pool @@ -93,6 +97,7 @@ type ( oracle *oracle.Oracle stateRoot stateroot.Service + stateSync blockchainer.StateSync log *zap.Logger } @@ -191,6 +196,10 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai } s.stateRoot = sr + sSync := chain.GetStateSyncModule() + s.stateSync = sSync + s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil) + if config.OracleCfg.Enabled { orcCfg := oracle.Config{ Log: log, @@ -277,6 +286,7 @@ func (s *Server) Start(errChan chan error) { go s.broadcastTxLoop() go s.relayBlocksLoop() go s.bQueue.run() + go s.bSyncQueue.run() go s.transport.Accept() setServerAndNodeVersions(s.UserAgent, strconv.FormatUint(uint64(s.id), 10)) s.run() @@ -292,6 +302,7 @@ func (s *Server) Shutdown() { p.Disconnect(errServerShutdown) } s.bQueue.discard() + s.bSyncQueue.discard() if s.StateRootCfg.Enabled { s.stateRoot.Shutdown() } @@ -573,6 +584,10 @@ func (s *Server) IsInSync() bool { var peersNumber int var notHigher int + if s.stateSync.IsActive() { + return false + } + if s.MinPeers == 0 { return true } @@ -630,6 +645,9 @@ func (s *Server) handleVersionCmd(p Peer, version *payload.Version) error { // handleBlockCmd processes the received block received from its peer. func (s *Server) handleBlockCmd(p Peer, block *block.Block) error { + if s.stateSync.IsActive() { + return s.bSyncQueue.putBlock(block) + } return s.bQueue.putBlock(block) } @@ -639,25 +657,46 @@ func (s *Server) handlePing(p Peer, ping *payload.Ping) error { if err != nil { return err } - if s.chain.BlockHeight() < ping.LastBlockIndex { - err = s.requestBlocks(p) - if err != nil { - return err - } + err = s.requestBlocksOrHeaders(p) + if err != nil { + return err } return p.EnqueueP2PMessage(NewMessage(CMDPong, payload.NewPing(s.chain.BlockHeight(), s.id))) } +func (s *Server) requestBlocksOrHeaders(p Peer) error { + if s.stateSync.NeedHeaders() { + if s.chain.HeaderHeight() < p.LastBlockIndex() { + return s.requestHeaders(p) + } + return nil + } + var bq blockchainer.Blockqueuer = s.chain + if s.stateSync.IsActive() { + bq = s.stateSync + } + if bq.BlockHeight() < p.LastBlockIndex() { + return s.requestBlocks(bq, p) + } + return nil +} + +// requestHeaders sends a CMDGetHeaders message to the peer to sync up in headers. +func (s *Server) requestHeaders(p Peer) error { + // TODO: optimize + currHeight := s.chain.HeaderHeight() + needHeight := currHeight + 1 + payload := payload.NewGetBlockByIndex(needHeight, -1) + return p.EnqueueP2PMessage(NewMessage(CMDGetHeaders, payload)) +} + // handlePing processes pong request. func (s *Server) handlePong(p Peer, pong *payload.Ping) error { err := p.HandlePong(pong) if err != nil { return err } - if s.chain.BlockHeight() < pong.LastBlockIndex { - return s.requestBlocks(p) - } - return nil + return s.requestBlocksOrHeaders(p) } // handleInvCmd processes the received inventory. @@ -766,6 +805,50 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { return nil } +// handleGetMPTDataCmd processes the received MPT inventory. +func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error { + if !s.chain.GetConfig().P2PStateExchangeExtensions { + return errors.New("GetMPTDataCMD was received, but P2PStateExchangeExtensions are disabled") + } + if s.chain.GetConfig().KeepOnlyLatestState { + // TODO: implement keeping MPT states for P1 and P2 height (#2095, #2152 related) + return errors.New("GetMPTDataCMD was received, but only latest MPT state is supported") + } + resp := payload.MPTData{} + capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes))) + for _, h := range inv.Hashes { + if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type + break + } + err := s.stateSync.Traverse(h, + func(_ mpt.Node, node []byte) bool { + l := len(node) + size := l + io.GetVarSize(l) + if size > capLeft { + return true + } + resp.Nodes = append(resp.Nodes, node) + capLeft -= size + return false + }) + if err != nil { + return fmt.Errorf("failed to traverse MPT starting from %s: %w", h.StringBE(), err) + } + } + if len(resp.Nodes) > 0 { + msg := NewMessage(CMDMPTData, &resp) + return p.EnqueueP2PMessage(msg) + } + return nil +} + +func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error { + if !s.chain.GetConfig().P2PStateExchangeExtensions { + return errors.New("MPTDataCMD was received, but P2PStateExchangeExtensions are disabled") + } + return s.stateSync.AddMPTNodes(data.Nodes) +} + // handleGetBlocksCmd processes the getblocks request. func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { count := gb.Count @@ -845,6 +928,11 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error return p.EnqueueP2PMessage(msg) } +// handleHeadersCmd processes headers payload. +func (s *Server) handleHeadersCmd(p Peer, h *payload.Headers) error { + return s.stateSync.AddHeaders(h.Hdrs...) +} + // handleExtensibleCmd processes received extensible payload. func (s *Server) handleExtensibleCmd(e *payload.Extensible) error { if !s.syncReached.Load() { @@ -993,8 +1081,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error { // 1. Block range is divided into chunks of payload.MaxHashesCount. // 2. Send requests for chunk in increasing order. // 3. After all requests were sent, request random height. -func (s *Server) requestBlocks(p Peer) error { - var currHeight = s.chain.BlockHeight() +func (s *Server) requestBlocks(bq blockchainer.Blockqueuer, p Peer) error { + var currHeight = bq.BlockHeight() var peerHeight = p.LastBlockIndex() var needHeight uint32 // lastRequestedHeight can only be increased. @@ -1051,9 +1139,18 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDGetData: inv := msg.Payload.(*payload.Inventory) return s.handleGetDataCmd(peer, inv) + case CMDGetMPTData: + inv := msg.Payload.(*payload.MPTInventory) + return s.handleGetMPTDataCmd(peer, inv) + case CMDMPTData: + inv := msg.Payload.(*payload.MPTData) + return s.handleMPTDataCmd(peer, inv) case CMDGetHeaders: gh := msg.Payload.(*payload.GetBlockByIndex) return s.handleGetHeadersCmd(peer, gh) + case CMDHeaders: + h := msg.Payload.(*payload.Headers) + return s.handleHeadersCmd(peer, h) case CMDInv: inventory := msg.Payload.(*payload.Inventory) return s.handleInvCmd(peer, inventory) @@ -1093,6 +1190,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { } go peer.StartProtocol() + s.tryInitStateSync() s.tryStartServices() default: return fmt.Errorf("received '%s' during handshake", msg.Command.String()) @@ -1101,6 +1199,52 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { return nil } +func (s *Server) tryInitStateSync() { + if !s.stateSync.IsActive() { + s.bSyncQueue.discard() + return + } + + if s.stateSync.IsInitialized() { + return + } + + var peersNumber int + s.lock.RLock() + heights := make([]uint32, 0) + for p := range s.peers { + if p.Handshaked() { + peersNumber++ + peerLastBlock := p.LastBlockIndex() + i := sort.Search(len(heights), func(i int) bool { + return heights[i] >= peerLastBlock + }) + heights = append(heights, peerLastBlock) + if i != len(heights)-1 { + copy(heights[i+1:], heights[i:]) + heights[i] = peerLastBlock + } + } + } + s.lock.RUnlock() + if peersNumber >= s.MinPeers && len(heights) > 0 { + // choose the height of the median peer as current chain's height + h := heights[len(heights)/2] + err := s.stateSync.Init(h) + if err != nil { + s.log.Fatal("failed to init state sync module", + zap.Uint32("evaluated chain's blockHeight", h), + zap.Uint32("blockHeight", s.chain.BlockHeight()), + zap.Uint32("headerHeight", s.chain.HeaderHeight()), + zap.Error(err)) + } + + // module can be inactive after init (i.e. full state is collected and ordinary block processing is needed) + if !s.stateSync.IsActive() { + s.bSyncQueue.discard() + } + } +} func (s *Server) handleNewPayload(p *payload.Extensible) { _, err := s.extensiblePool.Add(p) if err != nil { diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 0c9c9a358..4b1de5b64 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -2,6 +2,7 @@ package network import ( "errors" + "fmt" "math/big" "net" "strconv" @@ -46,7 +47,10 @@ func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") } func TestNewServer(t *testing.T) { - bc := &fakechain.FakeChain{} + bc := &fakechain.FakeChain{ProtocolConfiguration: config.ProtocolConfiguration{ + P2PStateExchangeExtensions: true, + StateRootInHeader: true, + }} s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery) require.Error(t, err) @@ -899,3 +903,39 @@ func TestVerifyNotaryRequest(t *testing.T) { require.NoError(t, verifyNotaryRequest(bc, nil, r)) }) } + +func TestTryInitStateSync(t *testing.T) { + t.Run("module inactive", func(t *testing.T) { + s := startTestServer(t) + s.tryInitStateSync() + }) + + t.Run("module already initialized", func(t *testing.T) { + s := startTestServer(t) + s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: true} + s.tryInitStateSync() + }) + + t.Run("good", func(t *testing.T) { + s := startTestServer(t) + for _, h := range []uint32{10, 8, 7, 4, 11, 4} { + p := newLocalPeer(t, s) + p.handshaked = true + p.lastBlockIndex = h + s.peers[p] = true + } + p := newLocalPeer(t, s) + p.handshaked = false // one disconnected peer to check it won't be taken into attention + p.lastBlockIndex = 5 + s.peers[p] = true + var expectedH uint32 = 8 // median peer + + s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: false, InitFunc: func(h uint32) error { + if h != expectedH { + return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h) + } + return nil + }} + s.tryInitStateSync() + }) +} diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 8ff47a18c..5b159fda1 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -267,12 +267,10 @@ func (p *TCPPeer) StartProtocol() { zap.Uint32("id", p.Version().Nonce)) p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities) - if p.server.chain.BlockHeight() < p.LastBlockIndex() { - err = p.server.requestBlocks(p) - if err != nil { - p.Disconnect(err) - return - } + err = p.server.requestBlocksOrHeaders(p) + if err != nil { + p.Disconnect(err) + return } timer := time.NewTimer(p.server.ProtoTickInterval) @@ -281,10 +279,8 @@ func (p *TCPPeer) StartProtocol() { case <-p.done: return case <-timer.C: - // Try to sync in headers and block with the peer if his block height is higher then ours. - if p.LastBlockIndex() > p.server.chain.BlockHeight() { - err = p.server.requestBlocks(p) - } + // Try to sync in headers and block with the peer if his block height is higher than ours. + err = p.server.requestBlocksOrHeaders(p) if err == nil { timer.Reset(p.server.ProtoTickInterval) } From 74f1848d192d0f54c70d863690ee12d24eb71531 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Wed, 11 Aug 2021 14:29:03 +0300 Subject: [PATCH 03/15] core: adjust LastUpdatedBlock calculation for NEP17 balances ...wrt P2PStateExchange setting. --- docs/rpc.md | 6 +++++- pkg/core/blockchain.go | 15 ++++++++++++++- pkg/rpc/server/server.go | 11 ++++++++++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/docs/rpc.md b/docs/rpc.md index f12a7abca..1f8763f59 100644 --- a/docs/rpc.md +++ b/docs/rpc.md @@ -114,7 +114,11 @@ balance won't be shown in the list of NEP17 balances returned by the neo-go node (unlike the C# node behavior). However, transfer logs of such token are still available via `getnep17transfers` RPC call. -The behaviour of the `LastUpdatedBlock` tracking matches the C# node's one. +The behaviour of the `LastUpdatedBlock` tracking for archival nodes as far as for +governing token balances matches the C# node's one. For non-archival nodes and +other NEP17-compliant tokens if transfer's `LastUpdatedBlock` is lower than the +latest state synchronization point P the node working against, then +`LastUpdatedBlock` equals P. ### Unsupported methods diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 18fdf786f..1b3a99370 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1223,12 +1223,25 @@ func (bc *Blockchain) GetNEP17Contracts() []util.Uint160 { } // GetNEP17LastUpdated returns a set of contract ids with the corresponding last updated -// block indexes. +// block indexes. In case of an empty account, latest stored state synchronisation point +// is returned under Math.MinInt32 key. func (bc *Blockchain) GetNEP17LastUpdated(acc util.Uint160) (map[int32]uint32, error) { info, err := bc.dao.GetNEP17TransferInfo(acc) if err != nil { return nil, err } + if bc.config.P2PStateExchangeExtensions && bc.config.RemoveUntraceableBlocks { + if _, ok := info.LastUpdated[bc.contracts.NEO.ID]; !ok { + nBalance, lub := bc.contracts.NEO.BalanceOf(bc.dao, acc) + if nBalance.Sign() != 0 { + info.LastUpdated[bc.contracts.NEO.ID] = lub + } + } + } + stateSyncPoint, err := bc.dao.GetStateSyncPoint() + if err == nil { + info.LastUpdated[math.MinInt32] = stateSyncPoint + } return info.LastUpdated, nil } diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 3fa82c659..8bd2e26a3 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -679,6 +679,7 @@ func (s *Server) getNEP17Balances(ps request.Params) (interface{}, *response.Err if err != nil { return nil, response.NewRPCError("Failed to get NEP17 last updated block", err.Error(), err) } + stateSyncPoint := lastUpdated[math.MinInt32] bw := io.NewBufBinWriter() for _, h := range s.chain.GetNEP17Contracts() { balance, err := s.getNEP17Balance(h, u, bw) @@ -692,10 +693,18 @@ func (s *Server) getNEP17Balances(ps request.Params) (interface{}, *response.Err if cs == nil { continue } + lub, ok := lastUpdated[cs.ID] + if !ok { + cfg := s.chain.GetConfig() + if !cfg.P2PStateExchangeExtensions && cfg.RemoveUntraceableBlocks { + return nil, response.NewInternalServerError(fmt.Sprintf("failed to get LastUpdatedBlock for balance of %s token", cs.Hash.StringLE()), nil) + } + lub = stateSyncPoint + } bs.Balances = append(bs.Balances, result.NEP17Balance{ Asset: h, Amount: balance.String(), - LastUpdated: lastUpdated[cs.ID], + LastUpdated: lub, }) } return bs, nil From 6a04880b49561de380929d5f36606cf721def480 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Wed, 11 Aug 2021 16:09:50 +0300 Subject: [PATCH 04/15] core: collapse completed parts of Billet Some kind of marker is needed to check whether node has been collapsed or not. So introduce (HashNode).Collapsed --- pkg/core/mpt/billet.go | 82 +++++++++++--- pkg/core/mpt/billet_test.go | 211 ++++++++++++++++++++++++++++++++++++ pkg/core/mpt/hash.go | 1 + 3 files changed, 279 insertions(+), 15 deletions(-) create mode 100644 pkg/core/mpt/billet_test.go diff --git a/pkg/core/mpt/billet.go b/pkg/core/mpt/billet.go index b47117e4a..f66415036 100644 --- a/pkg/core/mpt/billet.go +++ b/pkg/core/mpt/billet.go @@ -23,7 +23,10 @@ var ( // Billet is based on the following assumptions: // 1. Refcount can only be incremented (we don't change MPT structure during restore, // thus don't need to decrease refcount). -// 2. TODO: Each time the part of Billet is completely restored, it is collapsed into HashNode. +// 2. Each time the part of Billet is completely restored, it is collapsed into +// HashNode. +// 3. Pair (node, path) must be restored only once. It's a duty of MPT pool to manage +// MPT paths in order to provide this assumption. type Billet struct { Store *storage.MemCachedStore @@ -44,8 +47,7 @@ func NewBillet(rootHash util.Uint256, enableRefCount bool, store *storage.MemCac } // RestoreHashNode replaces HashNode located at the provided path by the specified Node -// and stores it. -// TODO: It also maintains MPT as small as possible by collapsing those parts +// and stores it. It also maintains MPT as small as possible by collapsing those parts // of MPT that have been completely restored. func (b *Billet) RestoreHashNode(path []byte, node Node) error { if _, ok := node.(*HashNode); ok { @@ -94,14 +96,16 @@ func (b *Billet) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error if curr.Hash() != val.Hash() { return nil, fmt.Errorf("%w: bad Leaf node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE()) } - // this node has already been restored, no refcount changes required - return curr, nil + // Once Leaf node is restored, it will be collapsed into HashNode forever, so + // there shouldn't be such situation when we try to restore Leaf node. + panic("bug: can't restore LeafNode") } func (b *Billet) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) { if len(path) == 0 && curr.Hash().Equals(val.Hash()) { - // this node has already been restored, no refcount changes required - return curr, nil + // This node has already been restored, so it's an MPT pool duty to avoid + // duplicating restore requests. + panic("bug: can't perform restoring of BranchNode twice") } i, path := splitPath(path) r, err := b.putIntoNode(curr.Children[i], path, val) @@ -109,7 +113,7 @@ func (b *Billet) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, e return nil, err } curr.Children[i] = r - return curr, nil + return b.tryCollapseBranch(curr), nil } func (b *Billet) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) { @@ -117,8 +121,9 @@ func (b *Billet) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (N if curr.Hash() != val.Hash() { return nil, fmt.Errorf("%w: bad Extension node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE()) } - // this node has already been restored, no refcount changes required - return curr, nil + // This node has already been restored, so it's an MPT pool duty to avoid + // duplicating restore requests. + panic("bug: can't perform restoring of ExtensionNode twice") } if !bytes.HasPrefix(path, curr.key) { return nil, fmt.Errorf("%w: can't modify ExtensionNode during restore", ErrRestoreFailed) @@ -129,11 +134,11 @@ func (b *Billet) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (N return nil, err } curr.next = r - return curr, nil + return b.tryCollapseExtension(curr), nil } func (b *Billet) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { - // Once the part of MPT Billet is completely restored, it will be collapsed forever, so + // Once a part of MPT Billet is completely restored, it will be collapsed forever, so // it's an MPT pool duty to avoid duplicating restore requests. if len(path) != 0 { return nil, fmt.Errorf("%w: node has already been collapsed", ErrRestoreFailed) @@ -148,10 +153,21 @@ func (b *Billet) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error if val.Hash() != curr.Hash() { return nil, fmt.Errorf("%w: can't restore HashNode: expected and actual hashes mismatch (%s vs %s)", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE()) } + + if curr.Collapsed { + // This node has already been restored and collapsed, so it's an MPT pool duty to avoid + // duplicating restore requests. + panic("bug: can't perform restoring of collapsed node") + } + // We also need to increment refcount in both cases. That's the only place where refcount // is changed during restore process. Also flush right now, because sync process can be // interrupted at any time. b.incrementRefAndStore(val.Hash(), val.Bytes()) + + if val.Type() == LeafT { + return b.tryCollapseLeaf(val.(*LeafNode)), nil + } return val, nil } @@ -214,7 +230,7 @@ func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) b } switch n := curr.(type) { case *LeafNode: - return n, nil + return b.tryCollapseLeaf(n), nil case *BranchNode: for i := range n.Children { r, err := b.traverse(n.Children[i], process, ignoreStorageErr) @@ -227,19 +243,55 @@ func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) b } n.Children[i] = r } - return n, nil + return b.tryCollapseBranch(n), nil case *ExtensionNode: r, err := b.traverse(n.next, process, ignoreStorageErr) if err != nil && !errors.Is(err, errStop) { return nil, err } n.next = r - return n, err + return b.tryCollapseExtension(n), err default: return nil, ErrNotFound } } +func (b *Billet) tryCollapseLeaf(curr *LeafNode) Node { + // Leaf can always be collapsed. + res := NewHashNode(curr.Hash()) + res.Collapsed = true + return res +} + +func (b *Billet) tryCollapseExtension(curr *ExtensionNode) Node { + if !(curr.next.Type() == HashT && curr.next.(*HashNode).Collapsed) { + return curr + } + res := NewHashNode(curr.Hash()) + res.Collapsed = true + return res +} + +func (b *Billet) tryCollapseBranch(curr *BranchNode) Node { + canCollapse := true + for i := 0; i < childrenCount; i++ { + if curr.Children[i].Type() == EmptyT { + continue + } + if curr.Children[i].Type() == HashT && curr.Children[i].(*HashNode).Collapsed { + continue + } + canCollapse = false + break + } + if !canCollapse { + return curr + } + res := NewHashNode(curr.Hash()) + res.Collapsed = true + return res +} + func (b *Billet) getFromStore(h util.Uint256) (Node, error) { data, err := b.Store.Get(makeStorageKey(h.BytesBE())) if err != nil { diff --git a/pkg/core/mpt/billet_test.go b/pkg/core/mpt/billet_test.go new file mode 100644 index 000000000..7850b129e --- /dev/null +++ b/pkg/core/mpt/billet_test.go @@ -0,0 +1,211 @@ +package mpt + +import ( + "encoding/binary" + "errors" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestBillet_RestoreHashNode(t *testing.T) { + check := func(t *testing.T, tr *Billet, expectedRoot Node, expectedNode Node, expectedRefCount uint32) { + _ = expectedRoot.Hash() + _ = tr.root.Hash() + require.Equal(t, expectedRoot, tr.root) + expectedBytes, err := tr.Store.Get(makeStorageKey(expectedNode.Hash().BytesBE())) + if expectedRefCount != 0 { + require.NoError(t, err) + require.Equal(t, expectedRefCount, binary.LittleEndian.Uint32(expectedBytes[len(expectedBytes)-4:])) + } else { + require.True(t, errors.Is(err, storage.ErrKeyNotFound)) + } + } + + t.Run("parent is Extension", func(t *testing.T) { + t.Run("restore Branch", func(t *testing.T) { + b := NewBranchNode() + b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0xAB, 0xCD})) + b.Children[5] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0xAB, 0xDE})) + path := toNibbles([]byte{0xAC}) + e := NewExtensionNode(path, NewHashNode(b.Hash())) + tr := NewBillet(e.Hash(), true, newTestStore()) + tr.root = e + + // OK + n := new(NodeObject) + n.DecodeBinary(io.NewBinReaderFromBuf(b.Bytes())) + require.NoError(t, tr.RestoreHashNode(path, n.Node)) + expected := NewExtensionNode(path, n.Node) + check(t, tr, expected, n.Node, 1) + + // One more time (already restored) => panic expected, no refcount changes + require.Panics(t, func() { + _ = tr.RestoreHashNode(path, n.Node) + }) + check(t, tr, expected, n.Node, 1) + + // Same path, but wrong hash => error expected, no refcount changes + require.True(t, errors.Is(tr.RestoreHashNode(path, NewBranchNode()), ErrRestoreFailed)) + check(t, tr, expected, n.Node, 1) + + // New path (changes in the MPT structure are not allowed) => error expected, no refcount changes + require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAB}), n.Node), ErrRestoreFailed)) + check(t, tr, expected, n.Node, 1) + }) + + t.Run("restore Leaf", func(t *testing.T) { + l := NewLeafNode([]byte{0xAB, 0xCD}) + path := toNibbles([]byte{0xAC}) + e := NewExtensionNode(path, NewHashNode(l.Hash())) + tr := NewBillet(e.Hash(), true, newTestStore()) + tr.root = e + + // OK + require.NoError(t, tr.RestoreHashNode(path, l)) + expected := NewHashNode(e.Hash()) // leaf should be collapsed immediately => extension should also be collapsed + expected.Collapsed = true + check(t, tr, expected, l, 1) + + // One more time (already restored and collapsed) => error expected, no refcount changes + require.Error(t, tr.RestoreHashNode(path, l)) + check(t, tr, expected, l, 1) + + // Same path, but wrong hash => error expected, no refcount changes + require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed)) + check(t, tr, expected, l, 1) + + // New path (changes in the MPT structure are not allowed) => error expected, no refcount changes + require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAB}), l), ErrRestoreFailed)) + check(t, tr, expected, l, 1) + }) + + t.Run("restore Hash", func(t *testing.T) { + h := NewHashNode(util.Uint256{1, 2, 3}) + path := toNibbles([]byte{0xAC}) + e := NewExtensionNode(path, h) + tr := NewBillet(e.Hash(), true, newTestStore()) + tr.root = e + + // no-op + require.True(t, errors.Is(tr.RestoreHashNode(path, h), ErrRestoreFailed)) + check(t, tr, e, h, 0) + }) + }) + + t.Run("parent is Leaf", func(t *testing.T) { + l := NewLeafNode([]byte{0xAB, 0xCD}) + path := []byte{} + tr := NewBillet(l.Hash(), true, newTestStore()) + tr.root = l + + // Already restored => panic expected + require.Panics(t, func() { + _ = tr.RestoreHashNode(path, l) + }) + + // Same path, but wrong hash => error expected, no refcount changes + require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed)) + + // Non-nil path, but MPT structure can't be changed => error expected, no refcount changes + require.True(t, errors.Is(tr.RestoreHashNode(toNibbles([]byte{0xAC}), NewLeafNode([]byte{0xAB, 0xEF})), ErrRestoreFailed)) + }) + + t.Run("parent is Branch", func(t *testing.T) { + t.Run("middle child", func(t *testing.T) { + l1 := NewLeafNode([]byte{0xAB, 0xCD}) + l2 := NewLeafNode([]byte{0xAB, 0xDE}) + b := NewBranchNode() + b.Children[5] = NewHashNode(l1.Hash()) + b.Children[lastChild] = NewHashNode(l2.Hash()) + tr := NewBillet(b.Hash(), true, newTestStore()) + tr.root = b + + // OK + path := []byte{0x05} + require.NoError(t, tr.RestoreHashNode(path, l1)) + check(t, tr, b, l1, 1) + + // One more time (already restored) => panic expected. + // It's an MPT pool duty to avoid such situations during real restore process. + require.Panics(t, func() { + _ = tr.RestoreHashNode(path, l1) + }) + // No refcount changes expected. + check(t, tr, b, l1, 1) + + // Same path, but wrong hash => error expected, no refcount changes + require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAD})), ErrRestoreFailed)) + check(t, tr, b, l1, 1) + + // New path pointing to the empty HashNode (changes in the MPT structure are not allowed) => error expected, no refcount changes + require.True(t, errors.Is(tr.RestoreHashNode([]byte{0x01}, l1), ErrRestoreFailed)) + check(t, tr, b, l1, 1) + }) + + t.Run("last child", func(t *testing.T) { + l1 := NewLeafNode([]byte{0xAB, 0xCD}) + l2 := NewLeafNode([]byte{0xAB, 0xDE}) + b := NewBranchNode() + b.Children[5] = NewHashNode(l1.Hash()) + b.Children[lastChild] = NewHashNode(l2.Hash()) + tr := NewBillet(b.Hash(), true, newTestStore()) + tr.root = b + + // OK + path := []byte{} + require.NoError(t, tr.RestoreHashNode(path, l2)) + check(t, tr, b, l2, 1) + + // One more time (already restored) => panic expected. + // It's an MPT pool duty to avoid such situations during real restore process. + require.Panics(t, func() { + _ = tr.RestoreHashNode(path, l2) + }) + // No refcount changes expected. + check(t, tr, b, l2, 1) + + // Same path, but wrong hash => error expected, no refcount changes + require.True(t, errors.Is(tr.RestoreHashNode(path, NewLeafNode([]byte{0xAD})), ErrRestoreFailed)) + check(t, tr, b, l2, 1) + }) + + t.Run("two children with same hash", func(t *testing.T) { + l := NewLeafNode([]byte{0xAB, 0xCD}) + b := NewBranchNode() + // two same hashnodes => leaf's refcount expected to be 2 in the end. + b.Children[3] = NewHashNode(l.Hash()) + b.Children[4] = NewHashNode(l.Hash()) + tr := NewBillet(b.Hash(), true, newTestStore()) + tr.root = b + + // OK + require.NoError(t, tr.RestoreHashNode([]byte{0x03}, l)) + expected := b + expected.Children[3].(*HashNode).Collapsed = true + check(t, tr, b, l, 1) + + // Restore another node with the same hash => no error expected, refcount should be incremented. + // Branch node should be collapsed. + require.NoError(t, tr.RestoreHashNode([]byte{0x04}, l)) + res := NewHashNode(b.Hash()) + res.Collapsed = true + check(t, tr, res, l, 2) + }) + }) + + t.Run("parent is Hash", func(t *testing.T) { + l := NewLeafNode([]byte{0xAB, 0xCD}) + b := NewBranchNode() + b.Children[3] = NewHashNode(l.Hash()) + b.Children[4] = NewHashNode(l.Hash()) + tr := NewBillet(b.Hash(), true, newTestStore()) + + // Should fail, because if it's a hash node with non-empty path, then the node + // has already been collapsed. + require.Error(t, tr.RestoreHashNode([]byte{0x03}, l)) + }) +} diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go index 05ddbe5f3..6ad66924d 100644 --- a/pkg/core/mpt/hash.go +++ b/pkg/core/mpt/hash.go @@ -10,6 +10,7 @@ import ( // HashNode represents MPT's hash node. type HashNode struct { BaseNode + Collapsed bool } var _ Node = (*HashNode)(nil) From 3b7807e897346cfbf322ce24e73c7fde3e1deb9b Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Fri, 13 Aug 2021 12:46:23 +0300 Subject: [PATCH 05/15] network: request unknown MPT nodes In this commit: 1. Request unknown MPT nodes from peers. Note, that StateSync module itself shouldn't be responsible for nodes requests, that's a server duty. 2. Do not request the same node twice, check if it is in storage already. If so, then the only thing remaining is to update refcounter. --- internal/fakechain/fakechain.go | 5 ++ pkg/core/blockchainer/state_sync.go | 1 + pkg/core/mpt/billet.go | 5 +- pkg/core/statesync/module.go | 59 +++++++++++----- pkg/core/statesync/mptpool.go | 20 ++++++ pkg/core/statesync/mptpool_test.go | 104 ++++++++++++++++++++++++++++ pkg/network/server.go | 31 ++++++++- 7 files changed, 203 insertions(+), 22 deletions(-) create mode 100644 pkg/core/statesync/mptpool_test.go diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index ca1f0bc6d..2d7297349 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -508,3 +508,8 @@ func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, func (s *FakeStateSync) GetJumpHeight() (uint32, error) { panic("TODO") } + +// GetUnknownMPTNodesBatch implements StateSync interface. +func (s *FakeStateSync) GetUnknownMPTNodesBatch(limit int) []util.Uint256 { + panic("TODO") +} diff --git a/pkg/core/blockchainer/state_sync.go b/pkg/core/blockchainer/state_sync.go index 66643020c..9c2161853 100644 --- a/pkg/core/blockchainer/state_sync.go +++ b/pkg/core/blockchainer/state_sync.go @@ -13,6 +13,7 @@ type StateSync interface { IsActive() bool IsInitialized() bool GetJumpHeight() (uint32, error) + GetUnknownMPTNodesBatch(limit int) []util.Uint256 NeedHeaders() bool NeedMPTNodes() bool Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error diff --git a/pkg/core/mpt/billet.go b/pkg/core/mpt/billet.go index f66415036..843a746ee 100644 --- a/pkg/core/mpt/billet.go +++ b/pkg/core/mpt/billet.go @@ -215,7 +215,7 @@ func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) b return curr, nil } if hn, ok := curr.(*HashNode); ok { - r, err := b.getFromStore(hn.Hash()) + r, err := b.GetFromStore(hn.Hash()) if err != nil { if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) { return hn, nil @@ -292,7 +292,8 @@ func (b *Billet) tryCollapseBranch(curr *BranchNode) Node { return res } -func (b *Billet) getFromStore(h util.Uint256) (Node, error) { +// GetFromStore returns MPT node from the storage. +func (b *Billet) GetFromStore(h util.Uint256) (Node, error) { data, err := b.Store.Get(makeStorageKey(h.BytesBE())) if err != nil { return nil, err diff --git a/pkg/core/statesync/module.go b/pkg/core/statesync/module.go index 24f38f14e..e7874bd39 100644 --- a/pkg/core/statesync/module.go +++ b/pkg/core/statesync/module.go @@ -328,24 +328,10 @@ func (s *Module) AddMPTNodes(nodes [][]byte) error { if r.Err != nil { return fmt.Errorf("failed to decode MPT node: %w", r.Err) } - nPaths, ok := s.mptpool.TryGet(n.Hash()) - if !ok { - // it can easily happen after receiving the same data from different peers. - return nil + err := s.restoreNode(n.Node) + if err != nil { + return err } - - var childrenPaths = make(map[util.Uint256][][]byte) - for _, path := range nPaths { - err := s.billet.RestoreHashNode(path, n.Node) - if err != nil { - return fmt.Errorf("failed to add MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err) - } - for h, paths := range mpt.GetChildrenPaths(path, n.Node) { - childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool - } - } - - s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths) } if s.mptpool.Count() == 0 { s.syncStage |= mptSynced @@ -356,6 +342,37 @@ func (s *Module) AddMPTNodes(nodes [][]byte) error { return nil } +func (s *Module) restoreNode(n mpt.Node) error { + nPaths, ok := s.mptpool.TryGet(n.Hash()) + if !ok { + // it can easily happen after receiving the same data from different peers. + return nil + } + var childrenPaths = make(map[util.Uint256][][]byte) + for _, path := range nPaths { + err := s.billet.RestoreHashNode(path, n) + if err != nil { + return fmt.Errorf("failed to restore MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err) + } + for h, paths := range mpt.GetChildrenPaths(path, n) { + childrenPaths[h] = append(childrenPaths[h], paths...) // it's OK to have duplicates, they'll be handled by mempool + } + } + + s.mptpool.Update(map[util.Uint256][][]byte{n.Hash(): nPaths}, childrenPaths) + + for h := range childrenPaths { + if child, err := s.billet.GetFromStore(h); err == nil { + // child is already in the storage, so we don't need to request it one more time. + err = s.restoreNode(child) + if err != nil { + return fmt.Errorf("unable to restore saved children: %w", err) + } + } + } + return nil +} + // checkSyncIsCompleted checks whether state sync process is completed, i.e. headers up to P+1 // height are fetched, blocks up to P height are stored and MPT nodes for P height are stored. // If so, then jumping to P state sync point occurs. It is not protected by lock, thus caller @@ -438,3 +455,11 @@ func (s *Module) GetJumpHeight() (uint32, error) { } return s.syncPoint, nil } + +// GetUnknownMPTNodesBatch returns set of currently unknown MPT nodes (`limit` at max). +func (s *Module) GetUnknownMPTNodesBatch(limit int) []util.Uint256 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.mptpool.GetBatch(limit) +} diff --git a/pkg/core/statesync/mptpool.go b/pkg/core/statesync/mptpool.go index 23610d97a..93bbb41a4 100644 --- a/pkg/core/statesync/mptpool.go +++ b/pkg/core/statesync/mptpool.go @@ -48,6 +48,26 @@ func (mp *Pool) GetAll() map[util.Uint256][][]byte { return mp.hashes } +// GetBatch returns set of unknown MPT nodes hashes (`limit` at max). +func (mp *Pool) GetBatch(limit int) []util.Uint256 { + mp.lock.RLock() + defer mp.lock.RUnlock() + + count := len(mp.hashes) + if count > limit { + count = limit + } + result := make([]util.Uint256, 0, limit) + for h := range mp.hashes { + if count == 0 { + break + } + result = append(result, h) + count-- + } + return result +} + // Remove removes MPT node from the pool by the specified hash. func (mp *Pool) Remove(hash util.Uint256) { mp.lock.Lock() diff --git a/pkg/core/statesync/mptpool_test.go b/pkg/core/statesync/mptpool_test.go new file mode 100644 index 000000000..bab32364b --- /dev/null +++ b/pkg/core/statesync/mptpool_test.go @@ -0,0 +1,104 @@ +package statesync + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/internal/random" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestPool_AddRemoveUpdate(t *testing.T) { + mp := NewPool() + + i1 := []byte{1, 2, 3} + i1h := util.Uint256{1, 2, 3} + i2 := []byte{2, 3, 4} + i2h := util.Uint256{2, 3, 4} + i3 := []byte{4, 5, 6} + i3h := util.Uint256{3, 4, 5} + i4 := []byte{3, 4, 5} // has the same hash as i3 + i5 := []byte{6, 7, 8} // has the same hash as i3 + mapAll := map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}, i3h: {i4, i3}} + + // No items + _, ok := mp.TryGet(i1h) + require.False(t, ok) + require.False(t, mp.ContainsKey(i1h)) + require.Equal(t, 0, mp.Count()) + require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll()) + + // Add i1, i2, check OK + mp.Add(i1h, i1) + mp.Add(i2h, i2) + itm, ok := mp.TryGet(i1h) + require.True(t, ok) + require.Equal(t, [][]byte{i1}, itm) + require.True(t, mp.ContainsKey(i1h)) + require.True(t, mp.ContainsKey(i2h)) + require.Equal(t, map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + + // Remove i1 and unexisting item + mp.Remove(i3h) + mp.Remove(i1h) + require.False(t, mp.ContainsKey(i1h)) + require.True(t, mp.ContainsKey(i2h)) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}}, mp.GetAll()) + require.Equal(t, 1, mp.Count()) + + // Update: remove nothing, add all + mp.Update(nil, mapAll) + require.Equal(t, mapAll, mp.GetAll()) + require.Equal(t, 3, mp.Count()) + // Update: remove all, add all + mp.Update(mapAll, mapAll) + require.Equal(t, mapAll, mp.GetAll()) // deletion first, addition after that + require.Equal(t, 3, mp.Count()) + // Update: remove all, add nothing + mp.Update(mapAll, nil) + require.Equal(t, map[util.Uint256][][]byte{}, mp.GetAll()) + require.Equal(t, 0, mp.Count()) + // Update: remove several, add several + mp.Update(map[util.Uint256][][]byte{i1h: {i1}, i2h: {i2}}, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + + // Update: remove nothing, add several with same hashes + mp.Update(nil, map[util.Uint256][][]byte{i3h: {i5, i4}}) // should be sorted by the pool + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i3, i5}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + // Update: remove several with same hashes, add nothing + mp.Update(map[util.Uint256][][]byte{i3h: {i5, i4}}, nil) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i3}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) + // Update: remove several with same hashes, add several with same hashes + mp.Update(map[util.Uint256][][]byte{i3h: {i5, i3}}, map[util.Uint256][][]byte{i3h: {i5, i4}}) + require.Equal(t, map[util.Uint256][][]byte{i2h: {i2}, i3h: {i4, i5}}, mp.GetAll()) + require.Equal(t, 2, mp.Count()) +} + +func TestPool_GetBatch(t *testing.T) { + check := func(t *testing.T, limit int, itemsCount int) { + mp := NewPool() + for i := 0; i < itemsCount; i++ { + mp.Add(random.Uint256(), []byte{0x01}) + } + batch := mp.GetBatch(limit) + if limit < itemsCount { + require.Equal(t, limit, len(batch)) + } else { + require.Equal(t, itemsCount, len(batch)) + } + } + + t.Run("limit less than items count", func(t *testing.T) { + check(t, 5, 6) + }) + t.Run("limit more than items count", func(t *testing.T) { + check(t, 6, 5) + }) + t.Run("items count limit", func(t *testing.T) { + check(t, 5, 5) + }) +} diff --git a/pkg/network/server.go b/pkg/network/server.go index b9f53ed0e..bf57ade5f 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -671,12 +671,23 @@ func (s *Server) requestBlocksOrHeaders(p Peer) error { } return nil } - var bq blockchainer.Blockqueuer = s.chain + var ( + bq blockchainer.Blockqueuer = s.chain + requestMPTNodes bool + ) if s.stateSync.IsActive() { bq = s.stateSync + requestMPTNodes = s.stateSync.NeedMPTNodes() } - if bq.BlockHeight() < p.LastBlockIndex() { - return s.requestBlocks(bq, p) + if bq.BlockHeight() >= p.LastBlockIndex() { + return nil + } + err := s.requestBlocks(bq, p) + if err != nil { + return err + } + if requestMPTNodes { + return s.requestMPTNodes(p, s.stateSync.GetUnknownMPTNodesBatch(payload.MaxMPTHashesCount)) } return nil } @@ -849,6 +860,20 @@ func (s *Server) handleMPTDataCmd(p Peer, data *payload.MPTData) error { return s.stateSync.AddMPTNodes(data.Nodes) } +// requestMPTNodes requests specified MPT nodes from the peer or broadcasts +// request if peer is not specified. +func (s *Server) requestMPTNodes(p Peer, itms []util.Uint256) error { + if len(itms) == 0 { + return nil + } + if len(itms) > payload.MaxMPTHashesCount { + itms = itms[:payload.MaxMPTHashesCount] + } + pl := payload.NewMPTInventory(itms) + msg := NewMessage(CMDGetMPTData, pl) + return p.EnqueueP2PMessage(msg) +} + // handleGetBlocksCmd processes the getblocks request. func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error { count := gb.Count From a276a85b72e9ea5a64b80f50d95631497cad9553 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Tue, 17 Aug 2021 15:35:20 +0300 Subject: [PATCH 06/15] core: unify code of state sync module initialization --- pkg/core/statesync/module.go | 41 ++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/pkg/core/statesync/module.go b/pkg/core/statesync/module.go index e7874bd39..507fbc373 100644 --- a/pkg/core/statesync/module.go +++ b/pkg/core/statesync/module.go @@ -136,29 +136,35 @@ func (s *Module) Init(currChainHeight uint32) error { zap.Uint32("point", p), zap.Uint32("evaluated chain's blockHeight", currChainHeight)) - // check headers sync state first + return s.defineSyncStage() +} + +// defineSyncStage sequentially checks and sets sync state process stage after Module +// initialization. It also performs initialization of MPT Billet if necessary. +func (s *Module) defineSyncStage() error { + // check headers sync stage first ltstHeaderHeight := s.bc.HeaderHeight() - if ltstHeaderHeight > p { + if ltstHeaderHeight > s.syncPoint { s.syncStage = headersSynced s.log.Info("headers are in sync", zap.Uint32("headerHeight", s.bc.HeaderHeight())) } - // check blocks sync state - s.blockHeight = s.getLatestSavedBlock(p) - if s.blockHeight >= p { + // check blocks sync stage + s.blockHeight = s.getLatestSavedBlock(s.syncPoint) + if s.blockHeight >= s.syncPoint { s.syncStage |= blocksSynced s.log.Info("blocks are in sync", zap.Uint32("blockHeight", s.blockHeight)) } - // check MPT sync state - if s.blockHeight > p { + // check MPT sync stage + if s.blockHeight > s.syncPoint { s.syncStage |= mptSynced s.log.Info("MPT is in sync", zap.Uint32("stateroot height", s.bc.GetStateModule().CurrentLocalHeight())) } else if s.syncStage&headersSynced != 0 { - header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(p + 1))) + header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint + 1))) if err != nil { return fmt.Errorf("failed to get header to initialize MPT billet: %w", err) } @@ -186,13 +192,13 @@ func (s *Module) Init(currChainHeight uint32) error { return false }, true) if err != nil { - return fmt.Errorf("failed to traverse MPT while initialization: %w", err) + return fmt.Errorf("failed to traverse MPT during initialization: %w", err) } s.mptpool.Update(nil, pool.GetAll()) if s.mptpool.Count() == 0 { s.syncStage |= mptSynced s.log.Info("MPT is in sync", - zap.Uint32("stateroot height", p)) + zap.Uint32("stateroot height", s.syncPoint)) } } @@ -234,21 +240,10 @@ func (s *Module) AddHeaders(hdrs ...*block.Header) error { hdrsErr := s.bc.AddHeaders(hdrs...) if s.bc.HeaderHeight() > s.syncPoint { - s.syncStage = headersSynced - s.log.Info("headers for state sync are fetched", - zap.Uint32("header height", s.bc.HeaderHeight())) - - header, err := s.bc.GetHeader(s.bc.GetHeaderHash(int(s.syncPoint) + 1)) + err := s.defineSyncStage() if err != nil { - s.log.Fatal("failed to get header to initialize MPT billet", - zap.Uint32("height", s.syncPoint+1), - zap.Error(err)) + return fmt.Errorf("failed to define current sync stage: %w", err) } - s.billet = mpt.NewBillet(header.PrevStateRoot, s.bc.GetConfig().KeepOnlyLatestState, s.dao.Store) - s.mptpool.Add(header.PrevStateRoot, []byte{}) - s.log.Info("MPT billet initialized", - zap.Uint32("height", s.syncPoint), - zap.String("state root", header.PrevStateRoot.StringBE())) } return hdrsErr } From 51f405471ebb6566bbf33efce8d8e0f0a116a050 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Tue, 17 Aug 2021 18:16:10 +0300 Subject: [PATCH 07/15] core: remove outdated blocks/txs/AERs/MPT nodes during state sync Before state sync process can be started, outdated MPT nodes should be removed from storage. After state sync is completed, outdated blocks/transactions/AERs should also be removed. --- pkg/core/blockchain.go | 15 ++++++++++- pkg/core/blockchainer/state_root.go | 1 + pkg/core/stateroot/module.go | 42 +++++++++++++++++++++++++++++ pkg/core/statesync/module.go | 31 ++++++++++++++++++--- 4 files changed, 84 insertions(+), 5 deletions(-) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 1b3a99370..3aa37943d 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -431,7 +431,8 @@ func (bc *Blockchain) JumpToState(module blockchainer.StateSync) error { if err != nil { return fmt.Errorf("failed to get current block: %w", err) } - err = bc.dao.StoreAsCurrentBlock(block, nil) + writeBuf := io.NewBufBinWriter() + err = bc.dao.StoreAsCurrentBlock(block, writeBuf) if err != nil { return fmt.Errorf("failed to store current block: %w", err) } @@ -464,6 +465,18 @@ func (bc *Blockchain) JumpToState(module blockchainer.StateSync) error { return fmt.Errorf("failed to update extensible whitelist: %w", err) } + // After current state is updated, we need to remove outdated state-related data if so. + // The only outdated data we might have is genesis-related data, so check it. + if p-bc.config.MaxTraceableBlocks > 0 { + cache := bc.dao.GetWrapped() + writeBuf.Reset() + err = cache.DeleteBlock(bc.headerHashes[0], writeBuf) + if err != nil { + return fmt.Errorf("failed to remove outdated state data for the genesis block: %w", err) + } + // TODO: remove NEP17 transfers and NEP17 transfer info for genesis block, #2096 related. + } + updateBlockHeightMetric(p) return nil } diff --git a/pkg/core/blockchainer/state_root.go b/pkg/core/blockchainer/state_root.go index 64f7d8b7c..9a540bda8 100644 --- a/pkg/core/blockchainer/state_root.go +++ b/pkg/core/blockchainer/state_root.go @@ -9,6 +9,7 @@ import ( // StateRoot represents local state root module. type StateRoot interface { AddStateRoot(root *state.MPTRoot) error + CleanStorage() error CurrentLocalHeight() uint32 CurrentLocalStateRoot() util.Uint256 CurrentValidatedHeight() uint32 diff --git a/pkg/core/stateroot/module.go b/pkg/core/stateroot/module.go index 20a3396e9..838e4717a 100644 --- a/pkg/core/stateroot/module.go +++ b/pkg/core/stateroot/module.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/util/slice" "go.uber.org/atomic" "go.uber.org/zap" ) @@ -114,6 +115,47 @@ func (s *Module) Init(height uint32, enableRefCount bool) error { return nil } +// CleanStorage removes all MPT-related data from the storage (MPT nodes, validated stateroots) +// except local stateroot for the current height and GC flag. This method is aimed to clean +// outdated MPT data before state sync process can be started. +// Note: this method is aimed to be called for genesis block only, an error is returned otherwice. +func (s *Module) CleanStorage() error { + if s.localHeight.Load() != 0 { + return fmt.Errorf("can't clean MPT data for non-genesis block: expected local stateroot height 0, got %d", s.localHeight.Load()) + } + gcKey := []byte{byte(storage.DataMPT), prefixGC} + gcVal, err := s.Store.Get(gcKey) + if err != nil { + return fmt.Errorf("failed to get GC flag: %w", err) + } + // + b := s.Store.Batch() + s.Store.Seek([]byte{byte(storage.DataMPT)}, func(k, _ []byte) { + // Must copy here, #1468. + key := slice.Copy(k) + b.Delete(key) + }) + err = s.Store.PutBatch(b) + if err != nil { + return fmt.Errorf("failed to remove outdated MPT-reated items: %w", err) + } + err = s.Store.Put(gcKey, gcVal) + if err != nil { + return fmt.Errorf("failed to store GC flag: %w", err) + } + currentLocal := s.currentLocal.Load().(util.Uint256) + if !currentLocal.Equals(util.Uint256{}) { + err := s.addLocalStateRoot(s.Store, &state.MPTRoot{ + Index: s.localHeight.Load(), + Root: currentLocal, + }) + if err != nil { + return fmt.Errorf("failed to store current local stateroot: %w", err) + } + } + return nil +} + // JumpToState performs jump to the state specified by given stateroot index. func (s *Module) JumpToState(sr *state.MPTRoot, enableRefCount bool) error { if err := s.addLocalStateRoot(s.Store, sr); err != nil { diff --git a/pkg/core/statesync/module.go b/pkg/core/statesync/module.go index 507fbc373..80939de1c 100644 --- a/pkg/core/statesync/module.go +++ b/pkg/core/statesync/module.go @@ -120,10 +120,33 @@ func (s *Module) Init(currChainHeight uint32) error { if err == nil && pOld >= p-s.syncInterval { // old point is still valid, so try to resync states for this point. p = pOld - } else if s.bc.BlockHeight() > p-2*s.syncInterval { - // chain has already been synchronised up to old state sync point and regular blocks processing was started - s.syncStage = inactive - return nil + } else { + if s.bc.BlockHeight() > p-2*s.syncInterval { + // chain has already been synchronised up to old state sync point and regular blocks processing was started. + // Current block height is enough to start regular blocks processing. + s.syncStage = inactive + return nil + } + if err == nil { + // pOld was found, it is outdated, and chain wasn't completely synchronised for pOld. Need to drop the db. + return fmt.Errorf("state sync point %d is found in the storage, "+ + "but sync process wasn't completed and point is outdated. Please, drop the database manually and restart the node to run state sync process", pOld) + } + if s.bc.BlockHeight() != 0 { + // pOld wasn't found, but blocks processing was started in a regular manner and latest stored block is too outdated + // to start regular blocks processing again. Need to drop the db. + return fmt.Errorf("current chain's height is too low to start regular blocks processing from the oldest sync point %d. "+ + "Please, drop the database manually and restart the node to run state sync process", p-s.syncInterval) + } + + // We've reached this point, so chain has genesis block only. As far as we can't ruin + // current chain's state until new state is completely fetched, outdated state-related data + // will be removed from storage during (*Blockchain).JumpToState(...) execution. + // All we need to do right now is to remove genesis-related MPT nodes. + err = s.bc.GetStateModule().CleanStorage() + if err != nil { + return fmt.Errorf("failed to remove outdated MPT data from storage: %w", err) + } } s.syncPoint = p From 6381173293738185631fa4ec3bb8b08ce28e7164 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Mon, 23 Aug 2021 12:02:17 +0300 Subject: [PATCH 08/15] core: store statesync-related storage items under temp prefix State jump should be an atomic operation, we can't modify contract storage items state on-the-fly. Thus, store fresh items under temp prefix and replase the outdated ones after state sync is completed. Related https://github.com/nspcc-dev/neo-go/pull/2019#discussion_r693350460. --- pkg/core/blockchain.go | 24 +++++++++++++++++++++++- pkg/core/mpt/billet.go | 4 ++-- pkg/core/storage/store.go | 19 ++++++++++++------- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 3aa37943d..2e844950c 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -35,6 +35,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/util/slice" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "go.uber.org/zap" @@ -43,7 +44,7 @@ import ( // Tuning parameters. const ( headerBatchCount = 2000 - version = "0.1.2" + version = "0.1.3" defaultInitialGAS = 52000000_00000000 defaultMemPoolSize = 50000 @@ -451,6 +452,27 @@ func (bc *Blockchain) JumpToState(module blockchainer.StateSync) error { return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err) } + b := bc.dao.Store.Batch() + bc.dao.Store.Seek([]byte{byte(storage.STStorage)}, func(k, _ []byte) { + // Must copy here, #1468. + key := slice.Copy(k) + b.Delete(key) + }) + bc.dao.Store.Seek([]byte{byte(storage.STTempStorage)}, func(k, v []byte) { + // Must copy here, #1468. + oldKey := slice.Copy(k) + b.Delete(oldKey) + key := make([]byte, len(k)) + key[0] = byte(storage.STStorage) + copy(key[1:], k[1:]) + value := slice.Copy(v) + b.Put(key, value) + }) + err = bc.dao.Store.PutBatch(b) + if err != nil { + return fmt.Errorf("failed to replace outdated contract storage items with the fresh ones: %w", err) + } + err = bc.contracts.NEO.InitializeCache(bc, bc.dao) if err != nil { return fmt.Errorf("can't init cache for NEO native contract: %w", err) diff --git a/pkg/core/mpt/billet.go b/pkg/core/mpt/billet.go index 843a746ee..b2e19c3c8 100644 --- a/pkg/core/mpt/billet.go +++ b/pkg/core/mpt/billet.go @@ -62,9 +62,9 @@ func (b *Billet) RestoreHashNode(path []byte, node Node) error { } b.root = r - // If it's a leaf, then put into contract storage. + // If it's a leaf, then put into temporary contract storage. if leaf, ok := node.(*LeafNode); ok { - k := append([]byte{byte(storage.STStorage)}, fromNibbles(path)...) + k := append([]byte{byte(storage.STTempStorage)}, fromNibbles(path)...) _ = b.Store.Put(k, leaf.value) } return nil diff --git a/pkg/core/storage/store.go b/pkg/core/storage/store.go index dd2376c63..26f5f5133 100644 --- a/pkg/core/storage/store.go +++ b/pkg/core/storage/store.go @@ -8,13 +8,18 @@ import ( // KeyPrefix constants. const ( - DataBlock KeyPrefix = 0x01 - DataTransaction KeyPrefix = 0x02 - DataMPT KeyPrefix = 0x03 - STAccount KeyPrefix = 0x40 - STNotification KeyPrefix = 0x4d - STContractID KeyPrefix = 0x51 - STStorage KeyPrefix = 0x70 + DataBlock KeyPrefix = 0x01 + DataTransaction KeyPrefix = 0x02 + DataMPT KeyPrefix = 0x03 + STAccount KeyPrefix = 0x40 + STNotification KeyPrefix = 0x4d + STContractID KeyPrefix = 0x51 + STStorage KeyPrefix = 0x70 + // STTempStorage is used to store contract storage items during state sync process + // in order not to mess up the previous state which has its own items stored by + // STStorage prefix. Once state exchange process is completed, all items with + // STStorage prefix will be replaced with STTempStorage-prefixed ones. + STTempStorage KeyPrefix = 0x71 STNEP17Transfers KeyPrefix = 0x72 STNEP17TransferInfo KeyPrefix = 0x73 IXHeaderHashList KeyPrefix = 0x80 From 0e0b55350a72ab7053344d5dca34cb1499f4d72c Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Wed, 25 Aug 2021 16:24:20 +0300 Subject: [PATCH 09/15] core: convert (*Blockchain).JumpToState to a callback We don't need this method to be exposed, the only its user is the StateSync module. At the same time StateSync module manages its state by itself which guarantees that (*Blockchain).jumpToState will be called with proper StateSync stage. --- internal/fakechain/fakechain.go | 10 ---------- pkg/core/blockchain.go | 10 +++------- pkg/core/blockchainer/blockchainer.go | 1 - pkg/core/blockchainer/state_sync.go | 1 - pkg/core/statesync/module.go | 18 ++++++------------ 5 files changed, 9 insertions(+), 31 deletions(-) diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index 2d7297349..49b876c06 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -307,11 +307,6 @@ func (chain *FakeChain) GetStateSyncModule() blockchainer.StateSync { return &FakeStateSync{} } -// JumpToState implements Blockchainer interface. -func (chain *FakeChain) JumpToState(module blockchainer.StateSync) error { - panic("TODO") -} - // GetStorageItem implements Blockchainer interface. func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem { panic("TODO") @@ -504,11 +499,6 @@ func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, panic("TODO") } -// GetJumpHeight implements StateSync interface. -func (s *FakeStateSync) GetJumpHeight() (uint32, error) { - panic("TODO") -} - // GetUnknownMPTNodesBatch implements StateSync interface. func (s *FakeStateSync) GetUnknownMPTNodesBatch(limit int) []util.Uint256 { panic("TODO") diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 2e844950c..2e1374e65 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -411,17 +411,13 @@ func (bc *Blockchain) init() error { return bc.updateExtensibleWhitelist(bHeight) } -// JumpToState is an atomic operation that changes Blockchain state to the one +// jumpToState is an atomic operation that changes Blockchain state to the one // specified by the state sync point p. All the data needed for the jump must be // collected by the state sync module. -func (bc *Blockchain) JumpToState(module blockchainer.StateSync) error { +func (bc *Blockchain) jumpToState(p uint32) error { bc.lock.Lock() defer bc.lock.Unlock() - p, err := module.GetJumpHeight() - if err != nil { - return fmt.Errorf("failed to get jump height: %w", err) - } if p+1 >= uint32(len(bc.headerHashes)) { return fmt.Errorf("invalid state sync point") } @@ -792,7 +788,7 @@ func (bc *Blockchain) GetStateModule() blockchainer.StateRoot { // GetStateSyncModule returns new state sync service instance. func (bc *Blockchain) GetStateSyncModule() blockchainer.StateSync { - return statesync.NewModule(bc, bc.log, bc.dao) + return statesync.NewModule(bc, bc.log, bc.dao, bc.jumpToState) } // storeBlock performs chain update using the block given, it executes all diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go index 9f9ae4e03..998b8a846 100644 --- a/pkg/core/blockchainer/blockchainer.go +++ b/pkg/core/blockchainer/blockchainer.go @@ -60,7 +60,6 @@ type Blockchainer interface { GetStorageItems(id int32) (map[string]state.StorageItem, error) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) - JumpToState(module StateSync) error SetOracle(service services.Oracle) mempool.Feer // fee interface ManagementContractHash() util.Uint160 diff --git a/pkg/core/blockchainer/state_sync.go b/pkg/core/blockchainer/state_sync.go index 9c2161853..a8ff919d9 100644 --- a/pkg/core/blockchainer/state_sync.go +++ b/pkg/core/blockchainer/state_sync.go @@ -12,7 +12,6 @@ type StateSync interface { Init(currChainHeight uint32) error IsActive() bool IsInitialized() bool - GetJumpHeight() (uint32, error) GetUnknownMPTNodesBatch(limit int) []util.Uint256 NeedHeaders() bool NeedMPTNodes() bool diff --git a/pkg/core/statesync/module.go b/pkg/core/statesync/module.go index 80939de1c..fb775c388 100644 --- a/pkg/core/statesync/module.go +++ b/pkg/core/statesync/module.go @@ -79,10 +79,12 @@ type Module struct { mptpool *Pool billet *mpt.Billet + + jumpCallback func(p uint32) error } // NewModule returns new instance of statesync module. -func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple) *Module { +func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple, jumpCallback func(p uint32) error) *Module { if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) { return &Module{ dao: s, @@ -97,6 +99,7 @@ func NewModule(bc blockchainer.Blockchainer, log *zap.Logger, s *dao.Simple) *Mo syncInterval: uint32(bc.GetConfig().StateSyncInterval), mptpool: NewPool(), syncStage: none, + jumpCallback: jumpCallback, } } @@ -141,7 +144,7 @@ func (s *Module) Init(currChainHeight uint32) error { // We've reached this point, so chain has genesis block only. As far as we can't ruin // current chain's state until new state is completely fetched, outdated state-related data - // will be removed from storage during (*Blockchain).JumpToState(...) execution. + // will be removed from storage during (*Blockchain).jumpToState(...) execution. // All we need to do right now is to remove genesis-related MPT nodes. err = s.bc.GetStateModule().CleanStorage() if err != nil { @@ -401,7 +404,7 @@ func (s *Module) checkSyncIsCompleted() { } s.log.Info("state is in sync", zap.Uint32("state sync point", s.syncPoint)) - err := s.bc.JumpToState(s) + err := s.jumpCallback(s.syncPoint) if err != nil { s.log.Fatal("failed to jump to the latest state sync point", zap.Error(err)) } @@ -465,15 +468,6 @@ func (s *Module) Traverse(root util.Uint256, process func(node mpt.Node, nodeByt return b.Traverse(process, false) } -// GetJumpHeight returns state sync point to jump to. It is not protected by mutex and should be called -// under the module lock. -func (s *Module) GetJumpHeight() (uint32, error) { - if s.syncStage != headersSynced|mptSynced|blocksSynced { - return 0, errors.New("state sync module has wong state to perform state jump") - } - return s.syncPoint, nil -} - // GetUnknownMPTNodesBatch returns set of currently unknown MPT nodes (`limit` at max). func (s *Module) GetUnknownMPTNodesBatch(limit int) []util.Uint256 { s.lock.RLock() From 5cda24b3afdc3bdc6c4914e050ba84ca974b31a7 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Thu, 26 Aug 2021 15:52:36 +0300 Subject: [PATCH 10/15] core: initialize headers before current block --- pkg/core/blockchain.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 2e1374e65..9b284c143 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -310,16 +310,6 @@ func (bc *Blockchain) init() error { // and the genesis block as first block. bc.log.Info("restoring blockchain", zap.String("version", version)) - bHeight, err := bc.dao.GetCurrentBlockHeight() - if err != nil { - return err - } - bc.blockHeight = bHeight - bc.persistedHeight = bHeight - if err = bc.stateRoot.Init(bHeight, bc.config.KeepOnlyLatestState); err != nil { - return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err) - } - bc.headerHashes, err = bc.dao.GetHeaderHashes() if err != nil { return err @@ -367,6 +357,16 @@ func (bc *Blockchain) init() error { } } + bHeight, err := bc.dao.GetCurrentBlockHeight() + if err != nil { + return err + } + bc.blockHeight = bHeight + bc.persistedHeight = bHeight + if err = bc.stateRoot.Init(bHeight, bc.config.KeepOnlyLatestState); err != nil { + return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err) + } + err = bc.contracts.NEO.InitializeCache(bc, bc.dao) if err != nil { return fmt.Errorf("can't init cache for NEO native contract: %w", err) From 5cd78c31af46755a61537b80a5744e7dc210c510 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Thu, 26 Aug 2021 17:34:52 +0300 Subject: [PATCH 11/15] core: allow to recover after state jump interruption We need several stages to manage state jump process in order not to mess up old and new contract storage items and to be sure about genesis state data are properly removed from the storage. Other operations do not require separate stage and can be performed each time `jumpToStateInternal` is called. --- pkg/core/blockchain.go | 179 ++++++++++++++++++++++++++++++-------- pkg/core/storage/store.go | 1 + 2 files changed, 144 insertions(+), 36 deletions(-) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 9b284c143..f2e7e14a1 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -44,7 +44,7 @@ import ( // Tuning parameters. const ( headerBatchCount = 2000 - version = "0.1.3" + version = "0.1.4" defaultInitialGAS = 52000000_00000000 defaultMemPoolSize = 50000 @@ -56,6 +56,31 @@ const ( // HeaderVerificationGasLimit is the maximum amount of GAS for block header verification. HeaderVerificationGasLimit = 3_00000000 // 3 GAS defaultStateSyncInterval = 40000 + + // maxStorageBatchSize is the number of elements in storage batch expected to fit into the + // storage without delays and problems. Estimated size of batch in case of given number of + // elements does not exceed 1Mb. + maxStorageBatchSize = 10000 +) + +// stateJumpStage denotes the stage of state jump process. +type stateJumpStage byte + +const ( + // none means that no state jump process was initiated yet. + none stateJumpStage = 1 << iota + // stateJumpStarted means that state jump was just initiated, but outdated storage items + // were not yet removed. + stateJumpStarted + // oldStorageItemsRemoved means that outdated contract storage items were removed, but + // new storage items were not yet saved. + oldStorageItemsRemoved + // newStorageItemsAdded means that contract storage items are up-to-date with the current + // state. + newStorageItemsAdded + // genesisStateRemoved means that state corresponding to the genesis block was removed + // from the storage. + genesisStateRemoved ) var ( @@ -357,6 +382,24 @@ func (bc *Blockchain) init() error { } } + // Check whether StateJump stage is in the storage and continue interrupted state jump if so. + jumpStage, err := bc.dao.Store.Get(storage.SYSStateJumpStage.Bytes()) + if err == nil { + if !(bc.GetConfig().P2PStateExchangeExtensions && bc.GetConfig().RemoveUntraceableBlocks) { + return errors.New("state jump was not completed, but P2PStateExchangeExtensions are disabled or archival node capability is on. " + + "To start an archival node drop the database manually and restart the node") + } + if len(jumpStage) != 1 { + return fmt.Errorf("invalid state jump stage format") + } + // State jump wasn't finished yet, thus continue it. + stateSyncPoint, err := bc.dao.GetStateSyncPoint() + if err != nil { + return fmt.Errorf("failed to get state sync point from the storage") + } + return bc.jumpToStateInternal(stateSyncPoint, stateJumpStage(jumpStage[0])) + } + bHeight, err := bc.dao.GetCurrentBlockHeight() if err != nil { return err @@ -418,17 +461,109 @@ func (bc *Blockchain) jumpToState(p uint32) error { bc.lock.Lock() defer bc.lock.Unlock() + return bc.jumpToStateInternal(p, none) +} + +// jumpToStateInternal is an internal representation of jumpToState callback that +// changes Blockchain state to the one specified by state sync point p and state +// jump stage. All the data needed for the jump must be in the DB, otherwise an +// error is returned. It is not protected by mutex. +func (bc *Blockchain) jumpToStateInternal(p uint32, stage stateJumpStage) error { if p+1 >= uint32(len(bc.headerHashes)) { - return fmt.Errorf("invalid state sync point") + return fmt.Errorf("invalid state sync point %d: headerHeignt is %d", p, len(bc.headerHashes)) } bc.log.Info("jumping to state sync point", zap.Uint32("state sync point", p)) + writeBuf := io.NewBufBinWriter() + jumpStageKey := storage.SYSStateJumpStage.Bytes() + switch stage { + case none: + err := bc.dao.Store.Put(jumpStageKey, []byte{byte(stateJumpStarted)}) + if err != nil { + return fmt.Errorf("failed to store state jump stage: %w", err) + } + fallthrough + case stateJumpStarted: + // Replace old storage items by new ones, it should be done step-by step. + // Firstly, remove all old genesis-related items. + b := bc.dao.Store.Batch() + bc.dao.Store.Seek([]byte{byte(storage.STStorage)}, func(k, _ []byte) { + // Must copy here, #1468. + key := slice.Copy(k) + b.Delete(key) + }) + b.Put(jumpStageKey, []byte{byte(oldStorageItemsRemoved)}) + err := bc.dao.Store.PutBatch(b) + if err != nil { + return fmt.Errorf("failed to store state jump stage: %w", err) + } + fallthrough + case oldStorageItemsRemoved: + // Then change STTempStorage prefix to STStorage. Each replace operation is atomic. + for { + count := 0 + b := bc.dao.Store.Batch() + bc.dao.Store.Seek([]byte{byte(storage.STTempStorage)}, func(k, v []byte) { + if count >= maxStorageBatchSize { + return + } + // Must copy here, #1468. + oldKey := slice.Copy(k) + b.Delete(oldKey) + key := make([]byte, len(k)) + key[0] = byte(storage.STStorage) + copy(key[1:], k[1:]) + value := slice.Copy(v) + b.Put(key, value) + count += 2 + }) + if count > 0 { + err := bc.dao.Store.PutBatch(b) + if err != nil { + return fmt.Errorf("failed to replace outdated contract storage items with the fresh ones: %w", err) + } + } else { + break + } + } + err := bc.dao.Store.Put(jumpStageKey, []byte{byte(newStorageItemsAdded)}) + if err != nil { + return fmt.Errorf("failed to store state jump stage: %w", err) + } + fallthrough + case newStorageItemsAdded: + // After current state is updated, we need to remove outdated state-related data if so. + // The only outdated data we might have is genesis-related data, so check it. + if p-bc.config.MaxTraceableBlocks > 0 { + cache := bc.dao.GetWrapped() + writeBuf.Reset() + err := cache.DeleteBlock(bc.headerHashes[0], writeBuf) + if err != nil { + return fmt.Errorf("failed to remove outdated state data for the genesis block: %w", err) + } + // TODO: remove NEP17 transfers and NEP17 transfer info for genesis block, #2096 related. + _, err = cache.Persist() + if err != nil { + return fmt.Errorf("failed to drop genesis block state: %w", err) + } + } + err := bc.dao.Store.Put(jumpStageKey, []byte{byte(genesisStateRemoved)}) + if err != nil { + return fmt.Errorf("failed to store state jump stage: %w", err) + } + case genesisStateRemoved: + // there's nothing to do after that, so just continue with common operations + // and remove state jump stage in the end. + default: + return errors.New("unknown state jump stage") + } + block, err := bc.dao.GetBlock(bc.headerHashes[p]) if err != nil { return fmt.Errorf("failed to get current block: %w", err) } - writeBuf := io.NewBufBinWriter() + writeBuf.Reset() err = bc.dao.StoreAsCurrentBlock(block, writeBuf) if err != nil { return fmt.Errorf("failed to store current block: %w", err) @@ -448,27 +583,6 @@ func (bc *Blockchain) jumpToState(p uint32) error { return fmt.Errorf("can't perform MPT jump to height %d: %w", p, err) } - b := bc.dao.Store.Batch() - bc.dao.Store.Seek([]byte{byte(storage.STStorage)}, func(k, _ []byte) { - // Must copy here, #1468. - key := slice.Copy(k) - b.Delete(key) - }) - bc.dao.Store.Seek([]byte{byte(storage.STTempStorage)}, func(k, v []byte) { - // Must copy here, #1468. - oldKey := slice.Copy(k) - b.Delete(oldKey) - key := make([]byte, len(k)) - key[0] = byte(storage.STStorage) - copy(key[1:], k[1:]) - value := slice.Copy(v) - b.Put(key, value) - }) - err = bc.dao.Store.PutBatch(b) - if err != nil { - return fmt.Errorf("failed to replace outdated contract storage items with the fresh ones: %w", err) - } - err = bc.contracts.NEO.InitializeCache(bc, bc.dao) if err != nil { return fmt.Errorf("can't init cache for NEO native contract: %w", err) @@ -483,19 +597,12 @@ func (bc *Blockchain) jumpToState(p uint32) error { return fmt.Errorf("failed to update extensible whitelist: %w", err) } - // After current state is updated, we need to remove outdated state-related data if so. - // The only outdated data we might have is genesis-related data, so check it. - if p-bc.config.MaxTraceableBlocks > 0 { - cache := bc.dao.GetWrapped() - writeBuf.Reset() - err = cache.DeleteBlock(bc.headerHashes[0], writeBuf) - if err != nil { - return fmt.Errorf("failed to remove outdated state data for the genesis block: %w", err) - } - // TODO: remove NEP17 transfers and NEP17 transfer info for genesis block, #2096 related. - } - updateBlockHeightMetric(p) + + err = bc.dao.Store.Delete(jumpStageKey) + if err != nil { + return fmt.Errorf("failed to remove outdated state jump stage: %w", err) + } return nil } diff --git a/pkg/core/storage/store.go b/pkg/core/storage/store.go index 26f5f5133..bd62f6001 100644 --- a/pkg/core/storage/store.go +++ b/pkg/core/storage/store.go @@ -27,6 +27,7 @@ const ( SYSCurrentHeader KeyPrefix = 0xc1 SYSStateSyncCurrentBlockHeight KeyPrefix = 0xc2 SYSStateSyncPoint KeyPrefix = 0xc3 + SYSStateJumpStage KeyPrefix = 0xc4 SYSVersion KeyPrefix = 0xf0 ) From 36808b89049cbbcbcd5afb3c6f395ca7fe2c2adc Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Fri, 27 Aug 2021 16:58:27 +0300 Subject: [PATCH 12/15] core: clone MPT node while restoring it multiple times We need this to avoid collapse collisions. Example of such collapse described in https://github.com/nspcc-dev/neo-go/pull/2019#discussion_r689629704. --- pkg/core/mpt/branch.go | 6 ++ pkg/core/mpt/empty.go | 3 + pkg/core/mpt/extension.go | 6 ++ pkg/core/mpt/hash.go | 7 ++ pkg/core/mpt/leaf.go | 6 ++ pkg/core/mpt/node.go | 1 + pkg/core/statesync/module.go | 4 +- pkg/core/statesync/module_test.go | 106 ++++++++++++++++++++++++++++++ 8 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 pkg/core/statesync/module_test.go diff --git a/pkg/core/mpt/branch.go b/pkg/core/mpt/branch.go index 0338ff4a7..d2fc84dfc 100644 --- a/pkg/core/mpt/branch.go +++ b/pkg/core/mpt/branch.go @@ -89,6 +89,12 @@ func (b *BranchNode) UnmarshalJSON(data []byte) error { return errors.New("expected branch node") } +// Clone implements Node interface. +func (b *BranchNode) Clone() Node { + res := *b + return &res +} + // splitPath splits path for a branch node. func splitPath(path []byte) (byte, []byte) { if len(path) != 0 { diff --git a/pkg/core/mpt/empty.go b/pkg/core/mpt/empty.go index 6669ef8c1..bc4f4914a 100644 --- a/pkg/core/mpt/empty.go +++ b/pkg/core/mpt/empty.go @@ -54,3 +54,6 @@ func (e EmptyNode) Type() NodeType { func (e EmptyNode) Bytes() []byte { return nil } + +// Clone implements Node interface. +func (EmptyNode) Clone() Node { return EmptyNode{} } diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go index 2dcbcb66b..1266c6acb 100644 --- a/pkg/core/mpt/extension.go +++ b/pkg/core/mpt/extension.go @@ -98,3 +98,9 @@ func (e *ExtensionNode) UnmarshalJSON(data []byte) error { } return errors.New("expected extension node") } + +// Clone implements Node interface. +func (e *ExtensionNode) Clone() Node { + res := *e + return &res +} diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go index 6ad66924d..03dc47a36 100644 --- a/pkg/core/mpt/hash.go +++ b/pkg/core/mpt/hash.go @@ -77,3 +77,10 @@ func (h *HashNode) UnmarshalJSON(data []byte) error { } return errors.New("expected hash node") } + +// Clone implements Node interface. +func (h *HashNode) Clone() Node { + res := *h + res.Collapsed = false + return &res +} diff --git a/pkg/core/mpt/leaf.go b/pkg/core/mpt/leaf.go index 0f3072b85..0efa45e70 100644 --- a/pkg/core/mpt/leaf.go +++ b/pkg/core/mpt/leaf.go @@ -77,3 +77,9 @@ func (n *LeafNode) UnmarshalJSON(data []byte) error { } return errors.New("expected leaf node") } + +// Clone implements Node interface. +func (n *LeafNode) Clone() Node { + res := *n + return &res +} diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go index af35286c1..a5f6cc814 100644 --- a/pkg/core/mpt/node.go +++ b/pkg/core/mpt/node.go @@ -34,6 +34,7 @@ type Node interface { json.Marshaler json.Unmarshaler Size() int + Clone() Node BaseNodeIface } diff --git a/pkg/core/statesync/module.go b/pkg/core/statesync/module.go index fb775c388..801b51a97 100644 --- a/pkg/core/statesync/module.go +++ b/pkg/core/statesync/module.go @@ -371,7 +371,9 @@ func (s *Module) restoreNode(n mpt.Node) error { } var childrenPaths = make(map[util.Uint256][][]byte) for _, path := range nPaths { - err := s.billet.RestoreHashNode(path, n) + // Must clone here in order to avoid future collapse collisions. If the node's refcount>1 then MPT pool + // will manage all paths for this node and call RestoreHashNode separately for each of the paths. + err := s.billet.RestoreHashNode(path, n.Clone()) if err != nil { return fmt.Errorf("failed to restore MPT node with hash %s and path %s: %w", n.Hash().StringBE(), hex.EncodeToString(path), err) } diff --git a/pkg/core/statesync/module_test.go b/pkg/core/statesync/module_test.go new file mode 100644 index 000000000..c1bf6b117 --- /dev/null +++ b/pkg/core/statesync/module_test.go @@ -0,0 +1,106 @@ +package statesync + +import ( + "fmt" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/dao" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/util/slice" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestModule_PR2019_discussion_r689629704(t *testing.T) { + expectedStorage := storage.NewMemCachedStore(storage.NewMemoryStore()) + tr := mpt.NewTrie(nil, true, expectedStorage) + require.NoError(t, tr.Put([]byte{0x03}, []byte("leaf1"))) + require.NoError(t, tr.Put([]byte{0x01, 0xab, 0x02}, []byte("leaf2"))) + require.NoError(t, tr.Put([]byte{0x01, 0xab, 0x04}, []byte("leaf3"))) + require.NoError(t, tr.Put([]byte{0x06, 0x01, 0xde, 0x02}, []byte("leaf2"))) // <-- the same `leaf2` and `leaf3` values are put in the storage, + require.NoError(t, tr.Put([]byte{0x06, 0x01, 0xde, 0x04}, []byte("leaf3"))) // <-- but the path should differ. + require.NoError(t, tr.Put([]byte{0x06, 0x03}, []byte("leaf4"))) + + sr := tr.StateRoot() + tr.Flush() + + // Keep MPT nodes in a map in order not to repeat them. We'll use `nodes` map to ask + // state sync module to restore the nodes. + var ( + nodes = make(map[util.Uint256][]byte) + expectedItems []storage.KeyValue + ) + expectedStorage.Seek(storage.DataMPT.Bytes(), func(k, v []byte) { + key := slice.Copy(k) + value := slice.Copy(v) + expectedItems = append(expectedItems, storage.KeyValue{ + Key: key, + Value: value, + }) + hash, err := util.Uint256DecodeBytesBE(key[1:]) + require.NoError(t, err) + nodeBytes := value[:len(value)-4] + nodes[hash] = nodeBytes + }) + + actualStorage := storage.NewMemCachedStore(storage.NewMemoryStore()) + // These actions are done in module.Init(), but it's not the point of the test. + // Here we want to test only MPT restoring process. + stateSync := &Module{ + log: zaptest.NewLogger(t), + syncPoint: 1000500, + syncStage: headersSynced, + syncInterval: 100500, + dao: dao.NewSimple(actualStorage, true, false), + mptpool: NewPool(), + } + stateSync.billet = mpt.NewBillet(sr, true, actualStorage) + stateSync.mptpool.Add(sr, []byte{}) + + // The test itself: we'll ask state sync module to restore each node exactly once. + // After that storage content (including storage items and refcounts) must + // match exactly the one got from real MPT trie. MPT pool must be empty. + // State sync module must have mptSynced state in the end. + // MPT Billet root must become a collapsed hashnode (it was checked manually). + requested := make(map[util.Uint256]struct{}) + for { + unknownHashes := stateSync.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one + if len(unknownHashes) == 0 { + break + } + h := unknownHashes[0] + node, ok := nodes[h] + if !ok { + if _, ok = requested[h]; ok { + t.Fatal("node was requested twice") + } + t.Fatal("unknown node was requested") + } + require.NotPanics(t, func() { + err := stateSync.AddMPTNodes([][]byte{node}) + require.NoError(t, err) + }, fmt.Errorf("hash=%s, value=%s", h.StringBE(), string(node))) + requested[h] = struct{}{} + delete(nodes, h) + if len(nodes) == 0 { + break + } + } + require.Equal(t, headersSynced|mptSynced, stateSync.syncStage, "all nodes were sent exactly ones, but MPT wasn't restored") + require.Equal(t, 0, len(nodes), "not all nodes were requested by state sync module") + require.Equal(t, 0, stateSync.mptpool.Count(), "MPT was restored, but MPT pool still contains items") + + // Compare resulting storage items and refcounts. + var actualItems []storage.KeyValue + expectedStorage.Seek(storage.DataMPT.Bytes(), func(k, v []byte) { + key := slice.Copy(k) + value := slice.Copy(v) + actualItems = append(actualItems, storage.KeyValue{ + Key: key, + Value: value, + }) + }) + require.ElementsMatch(t, expectedItems, actualItems) +} From 0aedfd0038c1600815e62f40959698dd9596e207 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Wed, 1 Sep 2021 14:05:57 +0300 Subject: [PATCH 13/15] core: fix bug in MPT pool during Update We need to copy the result of `TryGet` method, otherwice the slice can be modified inside `Add` or `Update` methods, which leads to inconsistent MPT pool state. --- pkg/core/statesync/mptpool.go | 5 ++++- pkg/core/statesync/mptpool_test.go | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/pkg/core/statesync/mptpool.go b/pkg/core/statesync/mptpool.go index 93bbb41a4..819188246 100644 --- a/pkg/core/statesync/mptpool.go +++ b/pkg/core/statesync/mptpool.go @@ -37,7 +37,10 @@ func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) { defer mp.lock.RUnlock() paths, ok := mp.hashes[hash] - return paths, ok + // need to copy here, because we can modify existing array of paths inside the pool. + res := make([][]byte, len(paths)) + copy(res, paths) + return res, ok } // GetAll returns all MPT nodes with the corresponding paths from the pool. diff --git a/pkg/core/statesync/mptpool_test.go b/pkg/core/statesync/mptpool_test.go index bab32364b..2d094025c 100644 --- a/pkg/core/statesync/mptpool_test.go +++ b/pkg/core/statesync/mptpool_test.go @@ -1,6 +1,7 @@ package statesync import ( + "encoding/hex" "testing" "github.com/nspcc-dev/neo-go/internal/random" @@ -102,3 +103,21 @@ func TestPool_GetBatch(t *testing.T) { check(t, 5, 5) }) } + +func TestPool_UpdateUsingSliceFromPool(t *testing.T) { + mp := NewPool() + p1, _ := hex.DecodeString("0f0a0f0f0f0f0f0f0104020b02080c0a06050e070b050404060206060d07080602030b04040b050e040406030f0708060c05") + p2, _ := hex.DecodeString("0f0a0f0f0f0f0f0f01040a0b000f04000b03090b02090b0e040f0d0b060d070e0b0b090b0906080602060c0d0f0e0d04070e") + p3, _ := hex.DecodeString("0f0a0f0f0f0f0f0f01040b010d01080f050f000a0d0e08060c040b050800050904060f050807080a080c07040d0107080007") + h, _ := util.Uint256DecodeStringBE("57e197679ef031bf2f0b466b20afe3f67ac04dcff80a1dc4d12dd98dd21a2511") + mp.Add(h, p1) + mp.Add(h, p2) + mp.Add(h, p3) + + toBeRemoved, ok := mp.TryGet(h) + require.True(t, ok) + + mp.Update(map[util.Uint256][][]byte{h: toBeRemoved}, nil) + // test that all items were successfully removed. + require.Equal(t, 0, len(mp.GetAll())) +} From 51c8c0d82b07fad3a91c9a6de769fd96bab947c2 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Tue, 31 Aug 2021 18:39:19 +0300 Subject: [PATCH 14/15] core: add tests for StateSync module --- pkg/core/blockchain_test.go | 88 ++++++++ pkg/core/statesync_test.go | 435 ++++++++++++++++++++++++++++++++++++ 2 files changed, 523 insertions(+) create mode 100644 pkg/core/statesync_test.go diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 2b0ed4a0b..285eeb379 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -1,6 +1,7 @@ package core import ( + "encoding/binary" "errors" "fmt" "math/big" @@ -34,6 +35,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/util/slice" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" @@ -41,6 +43,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" ) func TestVerifyHeader(t *testing.T) { @@ -1734,3 +1737,88 @@ func TestConfigNativeUpdateHistory(t *testing.T) { check(t, tc) } } + +func TestBlockchain_InitWithIncompleteStateJump(t *testing.T) { + var ( + stateSyncInterval = 4 + maxTraceable uint32 = 6 + ) + spountCfg := func(c *config.Config) { + c.ProtocolConfiguration.RemoveUntraceableBlocks = true + c.ProtocolConfiguration.StateRootInHeader = true + c.ProtocolConfiguration.P2PStateExchangeExtensions = true + c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval + c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable + } + bcSpout := newTestChainWithCustomCfg(t, spountCfg) + initBasicChain(t, bcSpout) + + // reach next to the latest state sync point and pretend that we've just restored + stateSyncPoint := (int(bcSpout.BlockHeight())/stateSyncInterval + 1) * stateSyncInterval + for i := bcSpout.BlockHeight() + 1; i <= uint32(stateSyncPoint); i++ { + require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock())) + } + require.Equal(t, uint32(stateSyncPoint), bcSpout.BlockHeight()) + b := bcSpout.newBlock() + require.NoError(t, bcSpout.AddHeaders(&b.Header)) + + // put storage items with STTemp prefix + batch := bcSpout.dao.Store.Batch() + bcSpout.dao.Store.Seek(storage.STStorage.Bytes(), func(k, v []byte) { + key := slice.Copy(k) + key[0] = storage.STTempStorage.Bytes()[0] + value := slice.Copy(v) + batch.Put(key, value) + }) + require.NoError(t, bcSpout.dao.Store.PutBatch(batch)) + + checkNewBlockchainErr := func(t *testing.T, cfg func(c *config.Config), store storage.Store, shouldFail bool) { + unitTestNetCfg, err := config.Load("../../config", testchain.Network()) + require.NoError(t, err) + cfg(&unitTestNetCfg) + log := zaptest.NewLogger(t) + _, err = NewBlockchain(store, unitTestNetCfg.ProtocolConfiguration, log) + if shouldFail { + require.Error(t, err) + } else { + require.NoError(t, err) + } + } + boltCfg := func(c *config.Config) { + spountCfg(c) + c.ProtocolConfiguration.KeepOnlyLatestState = true + } + // manually store statejump stage to check statejump recover process + t.Run("invalid RemoveUntraceableBlocks setting", func(t *testing.T) { + require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)})) + checkNewBlockchainErr(t, func(c *config.Config) { + boltCfg(c) + c.ProtocolConfiguration.RemoveUntraceableBlocks = false + }, bcSpout.dao.Store, true) + }) + t.Run("invalid state jump stage format", func(t *testing.T) { + require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{0x01, 0x02})) + checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true) + }) + t.Run("missing state sync point", func(t *testing.T) { + require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)})) + checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true) + }) + t.Run("invalid state sync point", func(t *testing.T) { + require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stateJumpStarted)})) + point := make([]byte, 4) + binary.LittleEndian.PutUint32(point, uint32(len(bcSpout.headerHashes))) + require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateSyncPoint.Bytes(), point)) + checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, true) + }) + for _, stage := range []stateJumpStage{stateJumpStarted, oldStorageItemsRemoved, newStorageItemsAdded, genesisStateRemoved, 0x03} { + t.Run(fmt.Sprintf("state jump stage %d", stage), func(t *testing.T) { + require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateJumpStage.Bytes(), []byte{byte(stage)})) + point := make([]byte, 4) + binary.LittleEndian.PutUint32(point, uint32(stateSyncPoint)) + require.NoError(t, bcSpout.dao.Store.Put(storage.SYSStateSyncPoint.Bytes(), point)) + shouldFail := stage == 0x03 // unknown stage + checkNewBlockchainErr(t, boltCfg, bcSpout.dao.Store, shouldFail) + }) + } +} diff --git a/pkg/core/statesync_test.go b/pkg/core/statesync_test.go new file mode 100644 index 000000000..0751ab65e --- /dev/null +++ b/pkg/core/statesync_test.go @@ -0,0 +1,435 @@ +package core + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/config" + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/util/slice" + "github.com/stretchr/testify/require" +) + +func TestStateSyncModule_Init(t *testing.T) { + var ( + stateSyncInterval = 2 + maxTraceable uint32 = 3 + ) + spoutCfg := func(c *config.Config) { + c.ProtocolConfiguration.StateRootInHeader = true + c.ProtocolConfiguration.P2PStateExchangeExtensions = true + c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval + c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable + } + bcSpout := newTestChainWithCustomCfg(t, spoutCfg) + for i := 0; i <= 2*stateSyncInterval+int(maxTraceable)+1; i++ { + require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock())) + } + + boltCfg := func(c *config.Config) { + spoutCfg(c) + c.ProtocolConfiguration.KeepOnlyLatestState = true + c.ProtocolConfiguration.RemoveUntraceableBlocks = true + } + t.Run("error: module disabled by config", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, func(c *config.Config) { + boltCfg(c) + c.ProtocolConfiguration.RemoveUntraceableBlocks = false + }) + module := bcBolt.GetStateSyncModule() + require.Error(t, module.Init(bcSpout.BlockHeight())) // module inactive (non-archival node) + }) + + t.Run("inactive: spout chain is too low to start state sync process", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(uint32(2*stateSyncInterval-1))) + require.False(t, module.IsActive()) + }) + + t.Run("inactive: bolt chain height is close enough to spout chain height", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + for i := 1; i < int(bcSpout.BlockHeight())-stateSyncInterval; i++ { + b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, bcBolt.AddBlock(b)) + } + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.False(t, module.IsActive()) + }) + + t.Run("error: bolt chain is too low to start state sync process", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + require.NoError(t, bcBolt.AddBlock(bcBolt.newBlock())) + + module := bcBolt.GetStateSyncModule() + require.Error(t, module.Init(uint32(3*stateSyncInterval))) + }) + + t.Run("initialized: no previous state sync point", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.True(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + }) + + t.Run("error: outdated state sync point in the storage", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + + module = bcBolt.GetStateSyncModule() + require.Error(t, module.Init(bcSpout.BlockHeight()+2*uint32(stateSyncInterval))) + }) + + t.Run("initialized: valid previous state sync point in the storage", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.True(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + }) + + t.Run("initialization from headers/blocks/mpt synced stages", func(t *testing.T) { + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + module := bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + + // firstly, fetch all headers to create proper DB state (where headers are in sync) + stateSyncPoint := (int(bcSpout.BlockHeight()) / stateSyncInterval) * stateSyncInterval + var expectedHeader *block.Header + for i := 1; i <= int(bcSpout.HeaderHeight()); i++ { + header, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, module.AddHeaders(header)) + if i == stateSyncPoint+1 { + expectedHeader = header + } + } + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + + // then create new statesync module with the same DB and check that state is proper + // (headers are in sync) + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + unknownNodes := module.GetUnknownMPTNodesBatch(2) + require.Equal(t, 1, len(unknownNodes)) + require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0]) + + // add several blocks to create DB state where blocks are not in sync yet, but it's not a genesis. + for i := stateSyncPoint - int(maxTraceable) + 1; i <= stateSyncPoint-stateSyncInterval-1; i++ { + block, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, module.AddBlock(block)) + } + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + require.Equal(t, uint32(stateSyncPoint-stateSyncInterval-1), module.BlockHeight()) + + // then create new statesync module with the same DB and check that state is proper + // (blocks are not in sync yet) + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + unknownNodes = module.GetUnknownMPTNodesBatch(2) + require.Equal(t, 1, len(unknownNodes)) + require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0]) + require.Equal(t, uint32(stateSyncPoint-stateSyncInterval-1), module.BlockHeight()) + + // add rest of blocks to create DB state where blocks are in sync + for i := stateSyncPoint - stateSyncInterval; i <= stateSyncPoint; i++ { + block, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, module.AddBlock(block)) + } + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + lastBlock, err := bcBolt.GetBlock(expectedHeader.PrevHash) + require.NoError(t, err) + require.Equal(t, uint32(stateSyncPoint), lastBlock.Index) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + + // then create new statesync module with the same DB and check that state is proper + // (headers and blocks are in sync) + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + unknownNodes = module.GetUnknownMPTNodesBatch(2) + require.Equal(t, 1, len(unknownNodes)) + require.Equal(t, expectedHeader.PrevStateRoot, unknownNodes[0]) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + + // add a few MPT nodes to create DB state where some of MPT nodes are missing + count := 5 + for { + unknownHashes := module.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one + if len(unknownHashes) == 0 { + break + } + err := bcSpout.GetStateSyncModule().Traverse(unknownHashes[0], func(node mpt.Node, nodeBytes []byte) bool { + require.NoError(t, module.AddMPTNodes([][]byte{nodeBytes})) + return true // add nodes one-by-one + }) + require.NoError(t, err) + count-- + if count < 0 { + break + } + } + + // then create new statesync module with the same DB and check that state is proper + // (headers and blocks are in sync, mpt is not yet synced) + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + unknownNodes = module.GetUnknownMPTNodesBatch(100) + require.True(t, len(unknownNodes) > 0) + require.NotContains(t, unknownNodes, expectedHeader.PrevStateRoot) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + + // add the rest of MPT nodes and jump to state + for { + unknownHashes := module.GetUnknownMPTNodesBatch(1) // restore nodes one-by-one + if len(unknownHashes) == 0 { + break + } + err := bcSpout.GetStateSyncModule().Traverse(unknownHashes[0], func(node mpt.Node, nodeBytes []byte) bool { + require.NoError(t, module.AddMPTNodes([][]byte{slice.Copy(nodeBytes)})) + return true // add nodes one-by-one + }) + require.NoError(t, err) + } + + // check that module is inactive and statejump is completed + require.False(t, module.IsActive()) + require.False(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + unknownNodes = module.GetUnknownMPTNodesBatch(1) + require.True(t, len(unknownNodes) == 0) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight()) + + // create new module from completed state: the module should recognise that state sync is completed + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.False(t, module.IsActive()) + require.False(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + unknownNodes = module.GetUnknownMPTNodesBatch(1) + require.True(t, len(unknownNodes) == 0) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight()) + + // add one more block to the restored chain and start new module: the module should recognise state sync is completed + // and regular blocks processing was started + require.NoError(t, bcBolt.AddBlock(bcBolt.newBlock())) + module = bcBolt.GetStateSyncModule() + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.False(t, module.IsActive()) + require.False(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + unknownNodes = module.GetUnknownMPTNodesBatch(1) + require.True(t, len(unknownNodes) == 0) + require.Equal(t, uint32(stateSyncPoint)+1, module.BlockHeight()) + require.Equal(t, uint32(stateSyncPoint)+1, bcBolt.BlockHeight()) + }) +} + +func TestStateSyncModule_RestoreBasicChain(t *testing.T) { + var ( + stateSyncInterval = 4 + maxTraceable uint32 = 6 + stateSyncPoint = 16 + ) + spoutCfg := func(c *config.Config) { + c.ProtocolConfiguration.StateRootInHeader = true + c.ProtocolConfiguration.P2PStateExchangeExtensions = true + c.ProtocolConfiguration.StateSyncInterval = stateSyncInterval + c.ProtocolConfiguration.MaxTraceableBlocks = maxTraceable + } + bcSpout := newTestChainWithCustomCfg(t, spoutCfg) + initBasicChain(t, bcSpout) + + // make spout chain higher that latest state sync point + require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock())) + require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock())) + require.NoError(t, bcSpout.AddBlock(bcSpout.newBlock())) + require.Equal(t, uint32(stateSyncPoint+2), bcSpout.BlockHeight()) + + boltCfg := func(c *config.Config) { + spoutCfg(c) + c.ProtocolConfiguration.KeepOnlyLatestState = true + c.ProtocolConfiguration.RemoveUntraceableBlocks = true + } + bcBolt := newTestChainWithCustomCfg(t, boltCfg) + module := bcBolt.GetStateSyncModule() + + t.Run("error: add headers before initialisation", func(t *testing.T) { + h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(1)) + require.NoError(t, err) + require.Error(t, module.AddHeaders(h)) + }) + t.Run("no error: add blocks before initialisation", func(t *testing.T) { + b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(1)) + require.NoError(t, err) + require.NoError(t, module.AddBlock(b)) + }) + t.Run("error: add MPT nodes without initialisation", func(t *testing.T) { + require.Error(t, module.AddMPTNodes([][]byte{})) + }) + + require.NoError(t, module.Init(bcSpout.BlockHeight())) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.True(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + + // add headers to module + headers := make([]*block.Header, 0, bcSpout.HeaderHeight()) + for i := uint32(1); i <= bcSpout.HeaderHeight(); i++ { + h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(int(i))) + require.NoError(t, err) + headers = append(headers, h) + } + require.NoError(t, module.AddHeaders(headers...)) + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + require.Equal(t, bcSpout.HeaderHeight(), bcBolt.HeaderHeight()) + + // add blocks + t.Run("error: unexpected block index", func(t *testing.T) { + b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(stateSyncPoint - int(maxTraceable))) + require.NoError(t, err) + require.Error(t, module.AddBlock(b)) + }) + t.Run("error: missing state root in block header", func(t *testing.T) { + b := &block.Block{ + Header: block.Header{ + Index: uint32(stateSyncPoint) - maxTraceable + 1, + StateRootEnabled: false, + }, + } + require.Error(t, module.AddBlock(b)) + }) + t.Run("error: invalid block merkle root", func(t *testing.T) { + b := &block.Block{ + Header: block.Header{ + Index: uint32(stateSyncPoint) - maxTraceable + 1, + StateRootEnabled: true, + MerkleRoot: util.Uint256{1, 2, 3}, + }, + } + require.Error(t, module.AddBlock(b)) + }) + + for i := stateSyncPoint - int(maxTraceable) + 1; i <= stateSyncPoint; i++ { + b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, module.AddBlock(b)) + } + require.True(t, module.IsActive()) + require.True(t, module.IsInitialized()) + require.False(t, module.NeedHeaders()) + require.True(t, module.NeedMPTNodes()) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + + // add MPT nodes in batches + h, err := bcSpout.GetHeader(bcSpout.GetHeaderHash(stateSyncPoint + 1)) + require.NoError(t, err) + unknownHashes := module.GetUnknownMPTNodesBatch(100) + require.Equal(t, 1, len(unknownHashes)) + require.Equal(t, h.PrevStateRoot, unknownHashes[0]) + nodesMap := make(map[util.Uint256][]byte) + err = bcSpout.GetStateSyncModule().Traverse(h.PrevStateRoot, func(n mpt.Node, nodeBytes []byte) bool { + nodesMap[n.Hash()] = nodeBytes + return false + }) + require.NoError(t, err) + for { + need := module.GetUnknownMPTNodesBatch(10) + if len(need) == 0 { + break + } + add := make([][]byte, len(need)) + for i, h := range need { + nodeBytes, ok := nodesMap[h] + if !ok { + t.Fatal("unknown or restored node requested") + } + add[i] = nodeBytes + delete(nodesMap, h) + } + require.NoError(t, module.AddMPTNodes(add)) + } + require.False(t, module.IsActive()) + require.False(t, module.NeedHeaders()) + require.False(t, module.NeedMPTNodes()) + unknownNodes := module.GetUnknownMPTNodesBatch(1) + require.True(t, len(unknownNodes) == 0) + require.Equal(t, uint32(stateSyncPoint), module.BlockHeight()) + require.Equal(t, uint32(stateSyncPoint), bcBolt.BlockHeight()) + + // add missing blocks to bcBolt: should be ok, because state is synced + for i := stateSyncPoint + 1; i <= int(bcSpout.BlockHeight()); i++ { + b, err := bcSpout.GetBlock(bcSpout.GetHeaderHash(i)) + require.NoError(t, err) + require.NoError(t, bcBolt.AddBlock(b)) + } + require.Equal(t, bcSpout.BlockHeight(), bcBolt.BlockHeight()) + + // compare storage states + fetchStorage := func(bc *Blockchain) []storage.KeyValue { + var kv []storage.KeyValue + bc.dao.Store.Seek(storage.STStorage.Bytes(), func(k, v []byte) { + key := slice.Copy(k) + value := slice.Copy(v) + kv = append(kv, storage.KeyValue{ + Key: key, + Value: value, + }) + }) + return kv + } + expected := fetchStorage(bcSpout) + actual := fetchStorage(bcBolt) + require.ElementsMatch(t, expected, actual) + + // no temp items should be left + bcBolt.dao.Store.Seek(storage.STTempStorage.Bytes(), func(k, v []byte) { + t.Fatal("temp storage items are found") + }) +} From 0fa48691f70524bd69333af769c3ce60637d2dba Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Mon, 6 Sep 2021 15:16:47 +0300 Subject: [PATCH 15/15] network: do not duplicate MPT nodes in GetMPTNodes response Also tests are added. --- internal/fakechain/fakechain.go | 19 +++-- pkg/network/server.go | 7 +- pkg/network/server_test.go | 143 +++++++++++++++++++++++++++++++- 3 files changed, 161 insertions(+), 8 deletions(-) diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index 49b876c06..426972c6b 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -22,6 +22,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" + uatomic "go.uber.org/atomic" ) // FakeChain implements Blockchainer interface, but does not provide real functionality. @@ -45,9 +46,11 @@ type FakeChain struct { // FakeStateSync implements StateSync interface. type FakeStateSync struct { - IsActiveFlag bool - IsInitializedFlag bool + IsActiveFlag uatomic.Bool + IsInitializedFlag uatomic.Bool InitFunc func(h uint32) error + TraverseFunc func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error + AddMPTNodesFunc func(nodes [][]byte) error } // NewFakeChain returns new FakeChain structure. @@ -461,7 +464,10 @@ func (s *FakeStateSync) AddHeaders(...*block.Header) error { } // AddMPTNodes implements StateSync interface. -func (s *FakeStateSync) AddMPTNodes([][]byte) error { +func (s *FakeStateSync) AddMPTNodes(nodes [][]byte) error { + if s.AddMPTNodesFunc != nil { + return s.AddMPTNodesFunc(nodes) + } panic("TODO") } @@ -471,11 +477,11 @@ func (s *FakeStateSync) BlockHeight() uint32 { } // IsActive implements StateSync interface. -func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag } +func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag.Load() } // IsInitialized implements StateSync interface. func (s *FakeStateSync) IsInitialized() bool { - return s.IsInitializedFlag + return s.IsInitializedFlag.Load() } // Init implements StateSync interface. @@ -496,6 +502,9 @@ func (s *FakeStateSync) NeedMPTNodes() bool { // Traverse implements StateSync interface. func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { + if s.TraverseFunc != nil { + return s.TraverseFunc(root, process) + } panic("TODO") } diff --git a/pkg/network/server.go b/pkg/network/server.go index bf57ade5f..16a6a46f5 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -827,18 +827,23 @@ func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error { } resp := payload.MPTData{} capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes))) + added := make(map[util.Uint256]struct{}) for _, h := range inv.Hashes { if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type break } err := s.stateSync.Traverse(h, - func(_ mpt.Node, node []byte) bool { + func(n mpt.Node, node []byte) bool { + if _, ok := added[n.Hash()]; ok { + return false + } l := len(node) size := l + io.GetVarSize(l) if size > capLeft { return true } resp.Nodes = append(resp.Nodes, node) + added[n.Hash()] = struct{}{} capLeft -= size return false }) diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 4b1de5b64..69dc0d86b 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -17,6 +17,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" @@ -737,6 +738,139 @@ func TestInv(t *testing.T) { }) } +func TestHandleGetMPTData(t *testing.T) { + t.Run("P2PStateExchange extensions off", func(t *testing.T) { + s := startTestServer(t) + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{ + Hashes: []util.Uint256{{1, 2, 3}}, + }) + require.Error(t, s.handleMessage(p, msg)) + }) + + t.Run("KeepOnlyLatestState on", func(t *testing.T) { + s := startTestServer(t) + s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true + s.chain.(*fakechain.FakeChain).KeepOnlyLatestState = true + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{ + Hashes: []util.Uint256{{1, 2, 3}}, + }) + require.Error(t, s.handleMessage(p, msg)) + }) + + t.Run("good", func(t *testing.T) { + s := startTestServer(t) + s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true + var recvResponse atomic.Bool + r1 := random.Uint256() + r2 := random.Uint256() + r3 := random.Uint256() + node := []byte{1, 2, 3} + s.stateSync.(*fakechain.FakeStateSync).TraverseFunc = func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { + if !(root.Equals(r1) || root.Equals(r2)) { + t.Fatal("unexpected root") + } + require.False(t, process(mpt.NewHashNode(r3), node)) + return nil + } + found := &payload.MPTData{ + Nodes: [][]byte{node}, // no duplicates expected + } + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + switch msg.Command { + case CMDMPTData: + require.Equal(t, found, msg.Payload) + recvResponse.Store(true) + } + } + hs := []util.Uint256{r1, r2} + s.testHandleMessage(t, p, CMDGetMPTData, payload.NewMPTInventory(hs)) + + require.Eventually(t, recvResponse.Load, time.Second, time.Millisecond) + }) +} + +func TestHandleMPTData(t *testing.T) { + t.Run("P2PStateExchange extensions off", func(t *testing.T) { + s := startTestServer(t) + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDMPTData, &payload.MPTData{ + Nodes: [][]byte{{1, 2, 3}}, + }) + require.Error(t, s.handleMessage(p, msg)) + }) + + t.Run("good", func(t *testing.T) { + s := startTestServer(t) + expected := [][]byte{{1, 2, 3}, {2, 3, 4}} + s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true + s.stateSync = &fakechain.FakeStateSync{ + AddMPTNodesFunc: func(nodes [][]byte) error { + require.Equal(t, expected, nodes) + return nil + }, + } + + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDMPTData, &payload.MPTData{ + Nodes: expected, + }) + require.NoError(t, s.handleMessage(p, msg)) + }) +} + +func TestRequestMPTNodes(t *testing.T) { + s := startTestServer(t) + + var actual []util.Uint256 + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + if msg.Command == CMDGetMPTData { + actual = append(actual, msg.Payload.(*payload.MPTInventory).Hashes...) + } + } + s.register <- p + s.register <- p // ensure previous send was handled + + t.Run("no hashes, no message", func(t *testing.T) { + actual = nil + require.NoError(t, s.requestMPTNodes(p, nil)) + require.Nil(t, actual) + }) + t.Run("good, small", func(t *testing.T) { + actual = nil + expected := []util.Uint256{random.Uint256(), random.Uint256()} + require.NoError(t, s.requestMPTNodes(p, expected)) + require.Equal(t, expected, actual) + }) + t.Run("good, exactly one chunk", func(t *testing.T) { + actual = nil + expected := make([]util.Uint256, payload.MaxMPTHashesCount) + for i := range expected { + expected[i] = random.Uint256() + } + require.NoError(t, s.requestMPTNodes(p, expected)) + require.Equal(t, expected, actual) + }) + t.Run("good, too large chunk", func(t *testing.T) { + actual = nil + expected := make([]util.Uint256, payload.MaxMPTHashesCount+1) + for i := range expected { + expected[i] = random.Uint256() + } + require.NoError(t, s.requestMPTNodes(p, expected)) + require.Equal(t, expected[:payload.MaxMPTHashesCount], actual) + }) +} + func TestRequestTx(t *testing.T) { s := startTestServer(t) @@ -912,7 +1046,10 @@ func TestTryInitStateSync(t *testing.T) { t.Run("module already initialized", func(t *testing.T) { s := startTestServer(t) - s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: true} + ss := &fakechain.FakeStateSync{} + ss.IsActiveFlag.Store(true) + ss.IsInitializedFlag.Store(true) + s.stateSync = ss s.tryInitStateSync() }) @@ -930,12 +1067,14 @@ func TestTryInitStateSync(t *testing.T) { s.peers[p] = true var expectedH uint32 = 8 // median peer - s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: false, InitFunc: func(h uint32) error { + ss := &fakechain.FakeStateSync{InitFunc: func(h uint32) error { if h != expectedH { return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h) } return nil }} + ss.IsActiveFlag.Store(true) + s.stateSync = ss s.tryInitStateSync() }) }