From a22b1caa3e4fddfe2e0ddad18f48edfbc4bb577d Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Thu, 29 Jul 2021 18:00:07 +0300 Subject: [PATCH] core: implement MPT Billet structure for MPT restore MPT restore process is much simpler then regular MPT maintaining: trie has a fixed structure, we don't need to remove or rebuild MPT nodes. The only thing we should do is to replace Hash nodes to their unhashed counterparts and increment refcount. It's better not to touch the regular MPT code and create a separate structure for this. --- pkg/core/mpt/billet.go | 261 +++++++++++++++++++++++++++++++++++ pkg/core/mpt/helpers.go | 9 ++ pkg/core/mpt/helpers_test.go | 20 +++ 3 files changed, 290 insertions(+) create mode 100644 pkg/core/mpt/billet.go create mode 100644 pkg/core/mpt/helpers_test.go diff --git a/pkg/core/mpt/billet.go b/pkg/core/mpt/billet.go new file mode 100644 index 000000000..b47117e4a --- /dev/null +++ b/pkg/core/mpt/billet.go @@ -0,0 +1,261 @@ +package mpt + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/util/slice" +) + +var ( + // ErrRestoreFailed is returned when replacing HashNode by its "unhashed" + // candidate fails. + ErrRestoreFailed = errors.New("failed to restore MPT node") + errStop = errors.New("stop condition is met") +) + +// Billet is a part of MPT trie with missing hash nodes that need to be restored. +// Billet is based on the following assumptions: +// 1. Refcount can only be incremented (we don't change MPT structure during restore, +// thus don't need to decrease refcount). +// 2. TODO: Each time the part of Billet is completely restored, it is collapsed into HashNode. +type Billet struct { + Store *storage.MemCachedStore + + root Node + refcountEnabled bool +} + +// NewBillet returns new billet for MPT trie restoring. It accepts a MemCachedStore +// to decouple storage errors from logic errors so that all storage errors are +// processed during `store.Persist()` at the caller. This also has the benefit, +// that every `Put` can be considered an atomic operation. +func NewBillet(rootHash util.Uint256, enableRefCount bool, store *storage.MemCachedStore) *Billet { + return &Billet{ + Store: store, + root: NewHashNode(rootHash), + refcountEnabled: enableRefCount, + } +} + +// RestoreHashNode replaces HashNode located at the provided path by the specified Node +// and stores it. +// TODO: It also maintains MPT as small as possible by collapsing those parts +// of MPT that have been completely restored. +func (b *Billet) RestoreHashNode(path []byte, node Node) error { + if _, ok := node.(*HashNode); ok { + return fmt.Errorf("%w: unable to restore node into HashNode", ErrRestoreFailed) + } + if _, ok := node.(EmptyNode); ok { + return fmt.Errorf("%w: unable to restore node into EmptyNode", ErrRestoreFailed) + } + r, err := b.putIntoNode(b.root, path, node) + if err != nil { + return err + } + b.root = r + + // If it's a leaf, then put into contract storage. + if leaf, ok := node.(*LeafNode); ok { + k := append([]byte{byte(storage.STStorage)}, fromNibbles(path)...) + _ = b.Store.Put(k, leaf.value) + } + return nil +} + +// putIntoNode puts val with provided path inside curr and returns updated node. +// Reference counters are updated for both curr and returned value. +func (b *Billet) putIntoNode(curr Node, path []byte, val Node) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + return b.putIntoLeaf(n, path, val) + case *BranchNode: + return b.putIntoBranch(n, path, val) + case *ExtensionNode: + return b.putIntoExtension(n, path, val) + case *HashNode: + return b.putIntoHash(n, path, val) + case EmptyNode: + return nil, fmt.Errorf("%w: can't modify EmptyNode during restore", ErrRestoreFailed) + default: + panic("invalid MPT node type") + } +} + +func (b *Billet) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) { + if len(path) != 0 { + return nil, fmt.Errorf("%w: can't modify LeafNode during restore", ErrRestoreFailed) + } + if curr.Hash() != val.Hash() { + return nil, fmt.Errorf("%w: bad Leaf node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE()) + } + // this node has already been restored, no refcount changes required + return curr, nil +} + +func (b *Billet) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) { + if len(path) == 0 && curr.Hash().Equals(val.Hash()) { + // this node has already been restored, no refcount changes required + return curr, nil + } + i, path := splitPath(path) + r, err := b.putIntoNode(curr.Children[i], path, val) + if err != nil { + return nil, err + } + curr.Children[i] = r + return curr, nil +} + +func (b *Billet) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) { + if len(path) == 0 { + if curr.Hash() != val.Hash() { + return nil, fmt.Errorf("%w: bad Extension node hash: expected %s, got %s", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE()) + } + // this node has already been restored, no refcount changes required + return curr, nil + } + if !bytes.HasPrefix(path, curr.key) { + return nil, fmt.Errorf("%w: can't modify ExtensionNode during restore", ErrRestoreFailed) + } + + r, err := b.putIntoNode(curr.next, path[len(curr.key):], val) + if err != nil { + return nil, err + } + curr.next = r + return curr, nil +} + +func (b *Billet) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { + // Once the part of MPT Billet is completely restored, it will be collapsed forever, so + // it's an MPT pool duty to avoid duplicating restore requests. + if len(path) != 0 { + return nil, fmt.Errorf("%w: node has already been collapsed", ErrRestoreFailed) + } + + // `curr` hash node can be either of + // 1) saved in storage (i.g. if we've already restored node with the same hash from the + // other part of MPT), so just add it to local in-memory MPT. + // 2) missing from the storage. It's OK because we're syncing MPT state, and the purpose + // is to store missing hash nodes. + // both cases are OK, but we still need to validate `val` against `curr`. + if val.Hash() != curr.Hash() { + return nil, fmt.Errorf("%w: can't restore HashNode: expected and actual hashes mismatch (%s vs %s)", ErrRestoreFailed, curr.Hash().StringBE(), val.Hash().StringBE()) + } + // We also need to increment refcount in both cases. That's the only place where refcount + // is changed during restore process. Also flush right now, because sync process can be + // interrupted at any time. + b.incrementRefAndStore(val.Hash(), val.Bytes()) + return val, nil +} + +func (b *Billet) incrementRefAndStore(h util.Uint256, bs []byte) { + key := makeStorageKey(h.BytesBE()) + if b.refcountEnabled { + var ( + err error + data []byte + cnt int32 + ) + // An item may already be in store. + data, err = b.Store.Get(key) + if err == nil { + cnt = int32(binary.LittleEndian.Uint32(data[len(data)-4:])) + } + cnt++ + if len(data) == 0 { + data = append(bs, 0, 0, 0, 0) + } + binary.LittleEndian.PutUint32(data[len(data)-4:], uint32(cnt)) + _ = b.Store.Put(key, data) + } else { + _ = b.Store.Put(key, bs) + } +} + +// Traverse traverses MPT nodes (pre-order) starting from the billet root down +// to its children calling `process` for each serialised node until true is +// returned from `process` function. It also replaces all HashNodes to their +// "unhashed" counterparts until the stop condition is satisfied. +func (b *Billet) Traverse(process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) error { + r, err := b.traverse(b.root, process, ignoreStorageErr) + if err != nil && !errors.Is(err, errStop) { + return err + } + b.root = r + return nil +} + +func (b *Billet) traverse(curr Node, process func(node Node, nodeBytes []byte) bool, ignoreStorageErr bool) (Node, error) { + if _, ok := curr.(EmptyNode); ok { + // We're not interested in EmptyNodes, and they do not affect the + // traversal process, thus remain them untouched. + return curr, nil + } + if hn, ok := curr.(*HashNode); ok { + r, err := b.getFromStore(hn.Hash()) + if err != nil { + if ignoreStorageErr && errors.Is(err, storage.ErrKeyNotFound) { + return hn, nil + } + return nil, err + } + return b.traverse(r, process, ignoreStorageErr) + } + bytes := slice.Copy(curr.Bytes()) + if process(curr, bytes) { + return curr, errStop + } + switch n := curr.(type) { + case *LeafNode: + return n, nil + case *BranchNode: + for i := range n.Children { + r, err := b.traverse(n.Children[i], process, ignoreStorageErr) + if err != nil { + if !errors.Is(err, errStop) { + return nil, err + } + n.Children[i] = r + return n, err + } + n.Children[i] = r + } + return n, nil + case *ExtensionNode: + r, err := b.traverse(n.next, process, ignoreStorageErr) + if err != nil && !errors.Is(err, errStop) { + return nil, err + } + n.next = r + return n, err + default: + return nil, ErrNotFound + } +} + +func (b *Billet) getFromStore(h util.Uint256) (Node, error) { + data, err := b.Store.Get(makeStorageKey(h.BytesBE())) + if err != nil { + return nil, err + } + + var n NodeObject + r := io.NewBinReaderFromBuf(data) + n.DecodeBinary(r) + if r.Err != nil { + return nil, r.Err + } + + if b.refcountEnabled { + data = data[:len(data)-4] + } + n.Node.(flushedNode).setCache(data, h) + return n.Node, nil +} diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go index a7399d37d..ceabb31f7 100644 --- a/pkg/core/mpt/helpers.go +++ b/pkg/core/mpt/helpers.go @@ -40,3 +40,12 @@ func toNibbles(path []byte) []byte { } return result } + +// fromNibbles performs operation opposite to toNibbles and does no path validity checks. +func fromNibbles(path []byte) []byte { + result := make([]byte, len(path)/2) + for i := range result { + result[i] = path[2*i]<<4 + path[2*i+1] + } + return result +} diff --git a/pkg/core/mpt/helpers_test.go b/pkg/core/mpt/helpers_test.go new file mode 100644 index 000000000..ee5634bfd --- /dev/null +++ b/pkg/core/mpt/helpers_test.go @@ -0,0 +1,20 @@ +package mpt + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToNibblesFromNibbles(t *testing.T) { + check := func(t *testing.T, expected []byte) { + actual := fromNibbles(toNibbles(expected)) + require.Equal(t, expected, actual) + } + t.Run("empty path", func(t *testing.T) { + check(t, []byte{}) + }) + t.Run("non-empty path", func(t *testing.T) { + check(t, []byte{0x01, 0xAC, 0x8d, 0x04, 0xFF}) + }) +}