diff --git a/pkg/core/mpt/batch.go b/pkg/core/mpt/batch.go index 1b782498b..a03a6cbc8 100644 --- a/pkg/core/mpt/batch.go +++ b/pkg/core/mpt/batch.go @@ -192,7 +192,10 @@ func (t *Trie) stripBranch(b *BranchNode) (Node, error) { case n == 0: return new(HashNode), nil case n == 1: - return t.mergeExtension([]byte{lastIndex}, b.Children[lastIndex]) + if lastIndex != lastChild { + return t.mergeExtension([]byte{lastIndex}, b.Children[lastIndex]) + } + return b.Children[lastIndex], nil default: t.addRef(b.Hash(), b.bytes) return b, nil diff --git a/pkg/core/mpt/batch_test.go b/pkg/core/mpt/batch_test.go index ee673825a..6585d5204 100644 --- a/pkg/core/mpt/batch_test.go +++ b/pkg/core/mpt/batch_test.go @@ -174,6 +174,22 @@ func TestTrie_PutBatchBranch(t *testing.T) { testPut(t, ps, tr1, tr2) require.IsType(t, (*ExtensionNode)(nil), tr1.root) }) + t.Run("non-empty child is last node", func(t *testing.T) { + tr1 := NewTrie(new(HashNode), false, newTestStore()) + tr2 := NewTrie(new(HashNode), false, newTestStore()) + require.NoError(t, tr1.Put([]byte{0x00, 2}, []byte("value1"))) + require.NoError(t, tr2.Put([]byte{0x00, 2}, []byte("value1"))) + require.NoError(t, tr1.Put([]byte{0x00}, []byte("value2"))) + require.NoError(t, tr2.Put([]byte{0x00}, []byte("value2"))) + + tr1.Flush() + tr1.Collapse(1) + tr2.Flush() + tr2.Collapse(1) + + var ps = pairs{{[]byte{0x00, 2}, nil}} + testPut(t, ps, tr1, tr2) + }) }) t.Run("incomplete put, transform to extension", func(t *testing.T) { tr1, tr2 := prepareBranch(t) diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go index 6f4b452a4..fea6141e6 100644 --- a/pkg/core/mpt/trie.go +++ b/pkg/core/mpt/trie.go @@ -310,6 +310,7 @@ func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) if nxt.IsEmpty() { return nxt, nil } + n.next = nxt default: n.next = r } diff --git a/pkg/core/mpt/trie_test.go b/pkg/core/mpt/trie_test.go index 18d0d43b6..25ddd561c 100644 --- a/pkg/core/mpt/trie_test.go +++ b/pkg/core/mpt/trie_test.go @@ -395,6 +395,24 @@ func testTrieDelete(t *testing.T, enableGC bool) { tr.testHas(t, []byte{}, nil) tr.testHas(t, []byte{0x56}, []byte{0x34}) require.IsType(t, (*ExtensionNode)(nil), tr.root) + + t.Run("WithHash, branch node replaced", func(t *testing.T) { + ch := NewLeafNode([]byte{5, 6}) + h := ch.Hash() + + b := NewBranchNode() + b.Children[3] = NewExtensionNode([]byte{4}, NewLeafNode([]byte{1, 2, 3})) + b.Children[lastChild] = NewHashNode(h) + + tr := NewTrie(NewExtensionNode([]byte{1, 2}, b), enableGC, newTestStore()) + tr.putToStore(ch) + + require.NoError(t, tr.Delete([]byte{0x12, 0x34})) + tr.testHas(t, []byte{0x12, 0x34}, nil) + tr.testHas(t, []byte{0x12}, []byte{5, 6}) + require.IsType(t, (*ExtensionNode)(nil), tr.root) + require.Equal(t, h, tr.root.(*ExtensionNode).next.Hash()) + }) }) t.Run("LeaveLeaf", func(t *testing.T) {