mpt: move empty hash node in a separate type

We use them quite frequently (consider children for a new branch
node) and it is better to get rid of unneeded allocations.

Signed-off-by: Evgeniy Stratonikov <evgeniy@nspcc.ru>
This commit is contained in:
Evgeniy Stratonikov 2021-08-03 17:10:46 +03:00
parent f02d8b4ec4
commit db80ef28df
11 changed files with 124 additions and 99 deletions

View file

@ -54,8 +54,8 @@ func (b *BaseNode) getBytes(n Node) []byte {
// updateHash updates hash field for this BaseNode. // updateHash updates hash field for this BaseNode.
func (b *BaseNode) updateHash(n Node) { func (b *BaseNode) updateHash(n Node) {
if n.Type() == HashT { if n.Type() == HashT || n.Type() == EmptyT {
panic("can't update hash for hash node") panic("can't update hash for empty or hash node")
} }
b.hash = hash.DoubleSha256(b.getBytes(n)) b.hash = hash.DoubleSha256(b.getBytes(n))
b.hashValid = true b.hashValid = true
@ -86,17 +86,7 @@ func encodeBinaryAsChild(n Node, w *io.BinWriter) {
// encodeNodeWithType encodes node together with it's type. // encodeNodeWithType encodes node together with it's type.
func encodeNodeWithType(n Node, w *io.BinWriter) { func encodeNodeWithType(n Node, w *io.BinWriter) {
switch t := n.Type(); t { w.WriteB(byte(n.Type()))
case HashT:
hn := n.(*HashNode)
if !hn.hashValid {
w.WriteB(byte(EmptyT))
break
}
fallthrough
default:
w.WriteB(byte(t))
}
n.EncodeBinary(w) n.EncodeBinary(w)
} }
@ -120,11 +110,7 @@ func DecodeNodeWithType(r *io.BinReader) Node {
case LeafT: case LeafT:
n = new(LeafNode) n = new(LeafNode)
case EmptyT: case EmptyT:
n = &HashNode{ n = EmptyNode{}
BaseNode: BaseNode{
hashValid: false,
},
}
default: default:
r.Err = fmt.Errorf("invalid node type: %x", typ) r.Err = fmt.Errorf("invalid node type: %x", typ)
return nil return nil

View file

@ -62,6 +62,8 @@ func (t *Trie) putBatchIntoNode(curr Node, kv []keyValue) (Node, int, error) {
return t.putBatchIntoExtension(n, kv) return t.putBatchIntoExtension(n, kv)
case *HashNode: case *HashNode:
return t.putBatchIntoHash(n, kv) return t.putBatchIntoHash(n, kv)
case EmptyNode:
return t.putBatchIntoEmpty(kv)
default: default:
panic("invalid MPT node type") panic("invalid MPT node type")
} }
@ -84,11 +86,9 @@ func (t *Trie) mergeExtension(prefix []byte, sub Node) (Node, error) {
sn.invalidateCache() sn.invalidateCache()
t.addRef(sn.Hash(), sn.bytes) t.addRef(sn.Hash(), sn.bytes)
return sn, nil return sn, nil
case *HashNode: case EmptyNode:
if sn.IsEmpty() {
return sn, nil return sn, nil
} case *HashNode:
n, err := t.getFromStore(sn.Hash()) n, err := t.getFromStore(sn.Hash())
if err != nil { if err != nil {
return sn, err return sn, err
@ -141,8 +141,8 @@ func (t *Trie) putBatchIntoExtensionNoPrefix(key []byte, next Node, kv []keyValu
} }
func isEmpty(n Node) bool { func isEmpty(n Node) bool {
hn, ok := n.(*HashNode) _, ok := n.(EmptyNode)
return ok && hn.IsEmpty() return ok
} }
// addToBranch puts items into the branch node assuming b is not yet in trie. // addToBranch puts items into the branch node assuming b is not yet in trie.
@ -190,7 +190,7 @@ func (t *Trie) stripBranch(b *BranchNode) (Node, error) {
} }
switch { switch {
case n == 0: case n == 0:
return new(HashNode), nil return EmptyNode{}, nil
case n == 1: case n == 1:
if lastIndex != lastChild { if lastIndex != lastChild {
return t.mergeExtension([]byte{lastIndex}, b.Children[lastIndex]) return t.mergeExtension([]byte{lastIndex}, b.Children[lastIndex])
@ -219,12 +219,13 @@ func (t *Trie) iterateBatch(kv []keyValue, f func(c byte, kv []keyValue) (int, e
return n, nil return n, nil
} }
func (t *Trie) putBatchIntoHash(curr *HashNode, kv []keyValue) (Node, int, error) { func (t *Trie) putBatchIntoEmpty(kv []keyValue) (Node, int, error) {
if curr.IsEmpty() {
common := lcpMany(kv) common := lcpMany(kv)
stripPrefix(len(common), kv) stripPrefix(len(common), kv)
return t.newSubTrieMany(common, kv, nil) return t.newSubTrieMany(common, kv, nil)
} }
func (t *Trie) putBatchIntoHash(curr *HashNode, kv []keyValue) (Node, int, error) {
result, err := t.getFromStore(curr.hash) result, err := t.getFromStore(curr.hash)
if err != nil { if err != nil {
return curr, 0, err return curr, 0, err
@ -242,7 +243,7 @@ func (t *Trie) newSubTrieMany(prefix []byte, kv []keyValue, value []byte) (Node,
if len(kv[0].key) == 0 { if len(kv[0].key) == 0 {
if len(kv[0].value) == 0 { if len(kv[0].value) == 0 {
if len(kv) == 1 { if len(kv) == 1 {
return new(HashNode), 1, nil return EmptyNode{}, 1, nil
} }
node, n, err := t.newSubTrieMany(prefix, kv[1:], nil) node, n, err := t.newSubTrieMany(prefix, kv[1:], nil)
return node, n + 1, err return node, n + 1, err

View file

@ -68,8 +68,8 @@ func testPut(t *testing.T, ps pairs, tr1, tr2 *Trie) {
func TestTrie_PutBatchLeaf(t *testing.T) { func TestTrie_PutBatchLeaf(t *testing.T) {
prepareLeaf := func(t *testing.T) (*Trie, *Trie) { prepareLeaf := func(t *testing.T) (*Trie, *Trie) {
tr1 := NewTrie(new(HashNode), false, newTestStore()) tr1 := NewTrie(EmptyNode{}, false, newTestStore())
tr2 := NewTrie(new(HashNode), false, newTestStore()) tr2 := NewTrie(EmptyNode{}, false, newTestStore())
require.NoError(t, tr1.Put([]byte{0}, []byte("value"))) require.NoError(t, tr1.Put([]byte{0}, []byte("value")))
require.NoError(t, tr2.Put([]byte{0}, []byte("value"))) require.NoError(t, tr2.Put([]byte{0}, []byte("value")))
return tr1, tr2 return tr1, tr2
@ -97,8 +97,8 @@ func TestTrie_PutBatchLeaf(t *testing.T) {
func TestTrie_PutBatchExtension(t *testing.T) { func TestTrie_PutBatchExtension(t *testing.T) {
prepareExtension := func(t *testing.T) (*Trie, *Trie) { prepareExtension := func(t *testing.T) (*Trie, *Trie) {
tr1 := NewTrie(new(HashNode), false, newTestStore()) tr1 := NewTrie(EmptyNode{}, false, newTestStore())
tr2 := NewTrie(new(HashNode), false, newTestStore()) tr2 := NewTrie(EmptyNode{}, false, newTestStore())
require.NoError(t, tr1.Put([]byte{1, 2}, []byte("value1"))) require.NoError(t, tr1.Put([]byte{1, 2}, []byte("value1")))
require.NoError(t, tr2.Put([]byte{1, 2}, []byte("value1"))) require.NoError(t, tr2.Put([]byte{1, 2}, []byte("value1")))
return tr1, tr2 return tr1, tr2
@ -144,8 +144,8 @@ func TestTrie_PutBatchExtension(t *testing.T) {
func TestTrie_PutBatchBranch(t *testing.T) { func TestTrie_PutBatchBranch(t *testing.T) {
prepareBranch := func(t *testing.T) (*Trie, *Trie) { prepareBranch := func(t *testing.T) (*Trie, *Trie) {
tr1 := NewTrie(new(HashNode), false, newTestStore()) tr1 := NewTrie(EmptyNode{}, false, newTestStore())
tr2 := NewTrie(new(HashNode), false, newTestStore()) tr2 := NewTrie(EmptyNode{}, false, newTestStore())
require.NoError(t, tr1.Put([]byte{0x00, 2}, []byte("value1"))) require.NoError(t, tr1.Put([]byte{0x00, 2}, []byte("value1")))
require.NoError(t, tr2.Put([]byte{0x00, 2}, []byte("value1"))) require.NoError(t, tr2.Put([]byte{0x00, 2}, []byte("value1")))
require.NoError(t, tr1.Put([]byte{0x10, 3}, []byte("value2"))) require.NoError(t, tr1.Put([]byte{0x10, 3}, []byte("value2")))
@ -175,8 +175,8 @@ func TestTrie_PutBatchBranch(t *testing.T) {
require.IsType(t, (*ExtensionNode)(nil), tr1.root) require.IsType(t, (*ExtensionNode)(nil), tr1.root)
}) })
t.Run("non-empty child is last node", func(t *testing.T) { t.Run("non-empty child is last node", func(t *testing.T) {
tr1 := NewTrie(new(HashNode), false, newTestStore()) tr1 := NewTrie(EmptyNode{}, false, newTestStore())
tr2 := NewTrie(new(HashNode), false, newTestStore()) tr2 := NewTrie(EmptyNode{}, false, newTestStore())
require.NoError(t, tr1.Put([]byte{0x00, 2}, []byte("value1"))) require.NoError(t, tr1.Put([]byte{0x00, 2}, []byte("value1")))
require.NoError(t, tr2.Put([]byte{0x00, 2}, []byte("value1"))) require.NoError(t, tr2.Put([]byte{0x00, 2}, []byte("value1")))
require.NoError(t, tr1.Put([]byte{0x00}, []byte("value2"))) require.NoError(t, tr1.Put([]byte{0x00}, []byte("value2")))
@ -222,8 +222,8 @@ func TestTrie_PutBatchBranch(t *testing.T) {
func TestTrie_PutBatchHash(t *testing.T) { func TestTrie_PutBatchHash(t *testing.T) {
prepareHash := func(t *testing.T) (*Trie, *Trie) { prepareHash := func(t *testing.T) (*Trie, *Trie) {
tr1 := NewTrie(new(HashNode), false, newTestStore()) tr1 := NewTrie(EmptyNode{}, false, newTestStore())
tr2 := NewTrie(new(HashNode), false, newTestStore()) tr2 := NewTrie(EmptyNode{}, false, newTestStore())
require.NoError(t, tr1.Put([]byte{0x10}, []byte("value1"))) require.NoError(t, tr1.Put([]byte{0x10}, []byte("value1")))
require.NoError(t, tr2.Put([]byte{0x10}, []byte("value1"))) require.NoError(t, tr2.Put([]byte{0x10}, []byte("value1")))
require.NoError(t, tr1.Put([]byte{0x20}, []byte("value2"))) require.NoError(t, tr1.Put([]byte{0x20}, []byte("value2")))
@ -257,8 +257,8 @@ func TestTrie_PutBatchHash(t *testing.T) {
func TestTrie_PutBatchEmpty(t *testing.T) { func TestTrie_PutBatchEmpty(t *testing.T) {
t.Run("good", func(t *testing.T) { t.Run("good", func(t *testing.T) {
tr1 := NewTrie(new(HashNode), false, newTestStore()) tr1 := NewTrie(EmptyNode{}, false, newTestStore())
tr2 := NewTrie(new(HashNode), false, newTestStore()) tr2 := NewTrie(EmptyNode{}, false, newTestStore())
var ps = pairs{ var ps = pairs{
{[]byte{0}, []byte("value0")}, {[]byte{0}, []byte("value0")},
{[]byte{1}, []byte("value1")}, {[]byte{1}, []byte("value1")},
@ -273,15 +273,15 @@ func TestTrie_PutBatchEmpty(t *testing.T) {
{[]byte{2}, nil}, {[]byte{2}, nil},
{[]byte{3}, []byte("replace3")}, {[]byte{3}, []byte("replace3")},
} }
tr1 := NewTrie(new(HashNode), false, newTestStore()) tr1 := NewTrie(EmptyNode{}, false, newTestStore())
tr2 := NewTrie(new(HashNode), false, newTestStore()) tr2 := NewTrie(EmptyNode{}, false, newTestStore())
testIncompletePut(t, ps, 4, tr1, tr2) testIncompletePut(t, ps, 4, tr1, tr2)
}) })
} }
// For the sake of coverage. // For the sake of coverage.
func TestTrie_InvalidNodeType(t *testing.T) { func TestTrie_InvalidNodeType(t *testing.T) {
tr := NewTrie(new(HashNode), false, newTestStore()) tr := NewTrie(EmptyNode{}, false, newTestStore())
var b Batch var b Batch
b.Add([]byte{1}, []byte("value")) b.Add([]byte{1}, []byte("value"))
tr.root = Node(nil) tr.root = Node(nil)
@ -289,8 +289,8 @@ func TestTrie_InvalidNodeType(t *testing.T) {
} }
func TestTrie_PutBatch(t *testing.T) { func TestTrie_PutBatch(t *testing.T) {
tr1 := NewTrie(new(HashNode), false, newTestStore()) tr1 := NewTrie(EmptyNode{}, false, newTestStore())
tr2 := NewTrie(new(HashNode), false, newTestStore()) tr2 := NewTrie(EmptyNode{}, false, newTestStore())
var ps = pairs{ var ps = pairs{
{[]byte{1}, []byte{1}}, {[]byte{1}, []byte{1}},
{[]byte{2}, []byte{3}}, {[]byte{2}, []byte{3}},
@ -312,11 +312,10 @@ var _ = printNode
// `spew.Dump()`. // `spew.Dump()`.
func printNode(prefix string, n Node) { func printNode(prefix string, n Node) {
switch tn := n.(type) { switch tn := n.(type) {
case *HashNode: case EmptyNode:
if tn.IsEmpty() {
fmt.Printf("%s empty\n", prefix) fmt.Printf("%s empty\n", prefix)
return return
} case *HashNode:
fmt.Printf("%s %s\n", prefix, tn.Hash().StringLE()) fmt.Printf("%s %s\n", prefix, tn.Hash().StringLE())
case *BranchNode: case *BranchNode:
for i, c := range tn.Children { for i, c := range tn.Children {

View file

@ -27,7 +27,7 @@ var _ Node = (*BranchNode)(nil)
func NewBranchNode() *BranchNode { func NewBranchNode() *BranchNode {
b := new(BranchNode) b := new(BranchNode)
for i := 0; i < childrenCount; i++ { for i := 0; i < childrenCount; i++ {
b.Children[i] = new(HashNode) b.Children[i] = EmptyNode{}
} }
return b return b
} }

53
pkg/core/mpt/empty.go Normal file
View file

@ -0,0 +1,53 @@
package mpt
import (
"encoding/json"
"errors"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// EmptyNode represents empty node.
type EmptyNode struct{}
// DecodeBinary implements io.Serializable interface.
func (e EmptyNode) DecodeBinary(*io.BinReader) {
}
// EncodeBinary implements io.Serializable interface.
func (e EmptyNode) EncodeBinary(*io.BinWriter) {
}
// MarshalJSON implements Node interface.
func (e EmptyNode) MarshalJSON() ([]byte, error) {
return []byte(`{}`), nil
}
// UnmarshalJSON implements Node interface.
func (e EmptyNode) UnmarshalJSON(bytes []byte) error {
var m map[string]interface{}
err := json.Unmarshal(bytes, &m)
if err != nil {
return err
}
if len(m) != 0 {
return errors.New("expected empty node")
}
return nil
}
// Hash implements Node interface.
func (e EmptyNode) Hash() util.Uint256 {
panic("can't get hash of an EmptyNode")
}
// Type implements Node interface.
func (e EmptyNode) Type() NodeType {
return EmptyT
}
// Bytes implements Node interface.
func (e EmptyNode) Bytes() []byte {
return nil
}

View file

@ -35,9 +35,6 @@ func (h *HashNode) Hash() util.Uint256 {
return h.hash return h.hash
} }
// IsEmpty returns true if h is an empty node i.e. contains no hash.
func (h *HashNode) IsEmpty() bool { return !h.hashValid }
// Bytes returns serialized HashNode. // Bytes returns serialized HashNode.
func (h *HashNode) Bytes() []byte { func (h *HashNode) Bytes() []byte {
return h.getBytes(h) return h.getBytes(h)
@ -60,9 +57,6 @@ func (h HashNode) EncodeBinary(w *io.BinWriter) {
// MarshalJSON implements json.Marshaler. // MarshalJSON implements json.Marshaler.
func (h *HashNode) MarshalJSON() ([]byte, error) { func (h *HashNode) MarshalJSON() ([]byte, error) {
if !h.hashValid {
return []byte(`{}`), nil
}
return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil
} }

View file

@ -68,7 +68,7 @@ func (n *NodeObject) UnmarshalJSON(data []byte) error {
switch len(m) { switch len(m) {
case 0: case 0:
n.Node = new(HashNode) n.Node = EmptyNode{}
case 1: case 1:
if v, ok := m["hash"]; ok { if v, ok := m["hash"]; ok {
var h util.Uint256 var h util.Uint256

View file

@ -78,9 +78,6 @@ func TestNode_Serializable(t *testing.T) {
t.Run("Raw", getTestFuncEncode(true, h, new(HashNode))) t.Run("Raw", getTestFuncEncode(true, h, new(HashNode)))
t.Run("WithType", getTestFuncEncode(true, &NodeObject{h}, new(NodeObject))) t.Run("WithType", getTestFuncEncode(true, &NodeObject{h}, new(NodeObject)))
}) })
t.Run("Empty", func(t *testing.T) { // compare nodes, not hashes
testserdes.EncodeDecodeBinary(t, new(HashNode), new(HashNode))
})
t.Run("InvalidSize", func(t *testing.T) { t.Run("InvalidSize", func(t *testing.T) {
buf := io.NewBufBinWriter() buf := io.NewBufBinWriter()
buf.BinWriter.WriteBytes(make([]byte, 13)) buf.BinWriter.WriteBytes(make([]byte, 13))
@ -111,7 +108,7 @@ func TestInvalidJSON(t *testing.T) {
t.Run("InvalidChildrenCount", func(t *testing.T) { t.Run("InvalidChildrenCount", func(t *testing.T) {
var cs [childrenCount + 1]Node var cs [childrenCount + 1]Node
for i := range cs { for i := range cs {
cs[i] = new(HashNode) cs[i] = EmptyNode{}
} }
data, err := json.Marshal(cs) data, err := json.Marshal(cs)
require.NoError(t, err) require.NoError(t, err)

View file

@ -49,14 +49,12 @@ func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error)
return n, nil return n, nil
} }
case *HashNode: case *HashNode:
if !n.IsEmpty() {
r, err := t.getFromStore(n.Hash()) r, err := t.getFromStore(n.Hash())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return t.getProof(r, path, proofs) return t.getProof(r, path, proofs)
} }
}
return nil, ErrNotFound return nil, ErrNotFound
} }

View file

@ -35,7 +35,7 @@ var ErrNotFound = errors.New("item not found")
// This also has the benefit, that every `Put` can be considered an atomic operation. // This also has the benefit, that every `Put` can be considered an atomic operation.
func NewTrie(root Node, enableRefCount bool, store *storage.MemCachedStore) *Trie { func NewTrie(root Node, enableRefCount bool, store *storage.MemCachedStore) *Trie {
if root == nil { if root == nil {
root = new(HashNode) root = EmptyNode{}
} }
return &Trie{ return &Trie{
@ -75,12 +75,11 @@ func (t *Trie) getWithPath(curr Node, path []byte) (Node, []byte, error) {
} }
n.Children[i] = r n.Children[i] = r
return n, bs, nil return n, bs, nil
case EmptyNode:
case *HashNode: case *HashNode:
if !n.IsEmpty() {
if r, err := t.getFromStore(n.hash); err == nil { if r, err := t.getFromStore(n.hash); err == nil {
return t.getWithPath(r, path) return t.getWithPath(r, path)
} }
}
case *ExtensionNode: case *ExtensionNode:
if bytes.HasPrefix(path, n.key) { if bytes.HasPrefix(path, n.key) {
r, bs, err := t.getWithPath(n.next, path[len(n.key):]) r, bs, err := t.getWithPath(n.next, path[len(n.key):])
@ -187,14 +186,13 @@ func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Nod
return b, nil return b, nil
} }
func (t *Trie) putIntoEmpty(path []byte, val Node) (Node, error) {
return t.newSubTrie(path, val, true), nil
}
// putIntoHash puts val to trie if current node is a HashNode. // putIntoHash puts val to trie if current node is a HashNode.
// It returns Node if curr needs to be replaced and error if any. // It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) {
if curr.IsEmpty() {
hn := t.newSubTrie(path, val, true)
return hn, nil
}
result, err := t.getFromStore(curr.hash) result, err := t.getFromStore(curr.hash)
if err != nil { if err != nil {
return nil, err return nil, err
@ -227,6 +225,8 @@ func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) {
return t.putIntoExtension(n, path, val) return t.putIntoExtension(n, path, val)
case *HashNode: case *HashNode:
return t.putIntoHash(n, path, val) return t.putIntoHash(n, path, val)
case EmptyNode:
return t.putIntoEmpty(path, val)
default: default:
panic("invalid MPT node type") panic("invalid MPT node type")
} }
@ -257,8 +257,7 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) {
b.invalidateCache() b.invalidateCache()
var count, index int var count, index int
for i := range b.Children { for i := range b.Children {
h, ok := b.Children[i].(*HashNode) if !isEmpty(b.Children[i]) {
if !ok || !h.IsEmpty() {
index = i index = i
count++ count++
} }
@ -307,10 +306,9 @@ func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error)
t.removeRef(nxt.Hash(), nxt.bytes) t.removeRef(nxt.Hash(), nxt.bytes)
n.key = append(n.key, nxt.key...) n.key = append(n.key, nxt.key...)
n.next = nxt.next n.next = nxt.next
case *HashNode: case EmptyNode:
if nxt.IsEmpty() {
return nxt, nil return nxt, nil
} case *HashNode:
n.next = nxt n.next = nxt
default: default:
n.next = r n.next = r
@ -327,17 +325,16 @@ func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) {
case *LeafNode: case *LeafNode:
if len(path) == 0 { if len(path) == 0 {
t.removeRef(curr.Hash(), curr.Bytes()) t.removeRef(curr.Hash(), curr.Bytes())
return new(HashNode), nil return EmptyNode{}, nil
} }
return curr, nil return curr, nil
case *BranchNode: case *BranchNode:
return t.deleteFromBranch(n, path) return t.deleteFromBranch(n, path)
case *ExtensionNode: case *ExtensionNode:
return t.deleteFromExtension(n, path) return t.deleteFromExtension(n, path)
case *HashNode: case EmptyNode:
if n.IsEmpty() {
return n, nil return n, nil
} case *HashNode:
newNode, err := t.getFromStore(n.Hash()) newNode, err := t.getFromStore(n.Hash())
if err != nil { if err != nil {
return nil, err return nil, err
@ -350,7 +347,7 @@ func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) {
// StateRoot returns root hash of t. // StateRoot returns root hash of t.
func (t *Trie) StateRoot() util.Uint256 { func (t *Trie) StateRoot() util.Uint256 {
if hn, ok := t.root.(*HashNode); ok && hn.IsEmpty() { if isEmpty(t.root) {
return util.Uint256{} return util.Uint256{}
} }
return t.root.Hash() return t.root.Hash()
@ -486,9 +483,11 @@ func (t *Trie) Collapse(depth int) {
} }
func collapse(depth int, node Node) Node { func collapse(depth int, node Node) Node {
if _, ok := node.(*HashNode); ok { switch node.(type) {
case *HashNode, EmptyNode:
return node return node
} else if depth == 0 { }
if depth == 0 {
return NewHashNode(node.Hash()) return NewHashNode(node.Hash())
} }

View file

@ -239,8 +239,7 @@ func isValid(curr Node) bool {
if !isValid(n.Children[i]) { if !isValid(n.Children[i]) {
return false return false
} }
hn, ok := n.Children[i].(*HashNode) if !isEmpty(n.Children[i]) {
if !ok || !hn.IsEmpty() {
count++ count++
} }
} }
@ -342,7 +341,7 @@ func testTrieDelete(t *testing.T, enableGC bool) {
}) })
require.NoError(t, tr.Delete([]byte{0xAB})) require.NoError(t, tr.Delete([]byte{0xAB}))
require.True(t, tr.root.(*HashNode).IsEmpty()) require.IsType(t, EmptyNode{}, tr.root)
}) })
t.Run("MultipleKeys", func(t *testing.T) { t.Run("MultipleKeys", func(t *testing.T) {
@ -505,12 +504,11 @@ func TestTrie_Collapse(t *testing.T) {
require.Equal(t, NewLeafNode([]byte("value")), tr.root) require.Equal(t, NewLeafNode([]byte("value")), tr.root)
}) })
t.Run("Hash", func(t *testing.T) { t.Run("Hash", func(t *testing.T) {
t.Run("Empty", func(t *testing.T) { t.Run("EmptyNode", func(t *testing.T) {
tr := NewTrie(new(HashNode), false, newTestStore()) tr := NewTrie(EmptyNode{}, false, newTestStore())
require.NotPanics(t, func() { tr.Collapse(1) }) require.NotPanics(t, func() { tr.Collapse(1) })
hn, ok := tr.root.(*HashNode) _, ok := tr.root.(EmptyNode)
require.True(t, ok) require.True(t, ok)
require.True(t, hn.IsEmpty())
}) })
h := random.Uint256() h := random.Uint256()