From 0cf525d62ecbc7cf6a4cf3d3643738eaaa679525 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Thu, 7 Apr 2022 18:11:05 +0300 Subject: [PATCH] core: add ability to traverse backwards for Billet --- pkg/core/mpt/billet.go | 30 ++++++++++++++++++++++-------- pkg/core/mpt/trie.go | 2 +- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/pkg/core/mpt/billet.go b/pkg/core/mpt/billet.go index af7edf2be..40ee35de0 100644 --- a/pkg/core/mpt/billet.go +++ b/pkg/core/mpt/billet.go @@ -205,7 +205,7 @@ func (b *Billet) incrementRefAndStore(h util.Uint256, bs []byte) { // returned from `process` function. It also replaces all HashNodes to their // "unhashed" counterparts until the stop condition is satisfied. func (b *Billet) Traverse(process func(pathToNode []byte, node Node, nodeBytes []byte) bool, ignoreStorageErr bool) error { - r, err := b.traverse(b.root, []byte{}, []byte{}, process, ignoreStorageErr) + r, err := b.traverse(b.root, []byte{}, []byte{}, process, ignoreStorageErr, false) if err != nil && !errors.Is(err, errStop) { return err } @@ -213,7 +213,7 @@ func (b *Billet) Traverse(process func(pathToNode []byte, node Node, nodeBytes [ return nil } -func (b *Billet) traverse(curr Node, path, from []byte, process func(pathToNode []byte, node Node, nodeBytes []byte) bool, ignoreStorageErr bool) (Node, error) { +func (b *Billet) traverse(curr Node, path, from []byte, process func(pathToNode []byte, node Node, nodeBytes []byte) bool, ignoreStorageErr bool, backwards bool) (Node, error) { if _, ok := curr.(EmptyNode); ok { // We're not interested in EmptyNodes, and they do not affect the // traversal process, thus remain them untouched. @@ -227,7 +227,7 @@ func (b *Billet) traverse(curr Node, path, from []byte, process func(pathToNode } return nil, err } - return b.traverse(r, path, from, process, ignoreStorageErr) + return b.traverse(r, path, from, process, ignoreStorageErr, backwards) } if len(from) == 0 { bytes := slice.Copy(curr.Bytes()) @@ -242,22 +242,36 @@ func (b *Billet) traverse(curr Node, path, from []byte, process func(pathToNode var ( startIndex byte endIndex byte = childrenCount + cmp = func(i int) bool { + return i < int(endIndex) + } + step = 1 ) + if backwards { + startIndex, endIndex = lastChild, startIndex + cmp = func(i int) bool { + return i >= int(endIndex) + } + step = -1 + } if len(from) != 0 { endIndex = lastChild + if backwards { + endIndex = 0 + } startIndex, from = splitPath(from) } - for i := startIndex; i < endIndex; i++ { + for i := int(startIndex); cmp(i); i += step { var newPath []byte if i == lastChild { newPath = path } else { - newPath = append(path, i) + newPath = append(path, byte(i)) } - if i != startIndex { + if byte(i) != startIndex { from = []byte{} } - r, err := b.traverse(n.Children[i], newPath, from, process, ignoreStorageErr) + r, err := b.traverse(n.Children[i], newPath, from, process, ignoreStorageErr, backwards) if err != nil { if !errors.Is(err, errStop) { return nil, err @@ -276,7 +290,7 @@ func (b *Billet) traverse(curr Node, path, from []byte, process func(pathToNode } else { return b.tryCollapseExtension(n), nil } - r, err := b.traverse(n.next, append(path, n.key...), from, process, ignoreStorageErr) + r, err := b.traverse(n.next, append(path, n.key...), from, process, ignoreStorageErr, backwards) if err != nil && !errors.Is(err, errStop) { return nil, err } diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go index 608a39cfb..90352af3d 100644 --- a/pkg/core/mpt/trie.go +++ b/pkg/core/mpt/trie.go @@ -625,7 +625,7 @@ func (t *Trie) Find(prefix, from []byte, max int) ([]storage.KeyValue, error) { } return count >= max } - _, err = b.traverse(start, path, fromP, process, false) + _, err = b.traverse(start, path, fromP, process, false, false) if err != nil && !errors.Is(err, errStop) { return nil, err }