From fb88d4f3a0562663bcd348c94ab3ab1b1a975ee6 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Sat, 26 Dec 2020 13:27:59 +0300 Subject: [PATCH] mpt: support put in batches --- pkg/core/blockchain.go | 3 + pkg/core/dao/dao.go | 11 +- pkg/core/mpt/batch.go | 253 ++++++++++++++++++++++++++++++ pkg/core/mpt/batch_test.go | 305 +++++++++++++++++++++++++++++++++++++ pkg/core/mpt/helpers.go | 14 ++ 5 files changed, 578 insertions(+), 8 deletions(-) create mode 100644 pkg/core/mpt/batch.go create mode 100644 pkg/core/mpt/batch_test.go diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 8ee28a6f4..59de9c928 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -689,6 +689,9 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error d := cache.DAO.(*dao.Simple) if err := d.UpdateMPT(); err != nil { + // Here MPT can be left in a half-applied state. + // However if this error occurs, this is a bug somewhere in code + // because changes applied are the ones from HALTed transactions. return fmt.Errorf("error while trying to apply MPT changes: %w", err) } diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index 524204ed8..a0c9480c9 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -694,15 +694,10 @@ func (dao *Simple) Persist() (int, error) { // UpdateMPT updates MPT using storage items from the underlying memcached store. func (dao *Simple) UpdateMPT() error { - var err error + var b mpt.Batch dao.Store.MemoryStore.SeekAll([]byte{byte(storage.STStorage)}, func(k, v []byte) { - if err != nil { - return - } else if v != nil { - err = dao.MPT.Put(k[1:], v) - } else { - err = dao.MPT.Delete(k[1:]) - } + b.Add(k[1:], v) }) + _, err := dao.MPT.PutBatch(b) return err } diff --git a/pkg/core/mpt/batch.go b/pkg/core/mpt/batch.go new file mode 100644 index 000000000..6a25dae89 --- /dev/null +++ b/pkg/core/mpt/batch.go @@ -0,0 +1,253 @@ +package mpt + +import ( + "bytes" + "sort" +) + +// Batch is batch of storage changes. +// It stores key-value pairs in a sorted state. +type Batch struct { + kv []keyValue +} + +type keyValue struct { + key []byte + value []byte +} + +// Add adds key-value pair to batch. +// If there is an item with the specified key, it is replaced. +func (b *Batch) Add(key []byte, value []byte) { + path := toNibbles(key) + i := sort.Search(len(b.kv), func(i int) bool { + return bytes.Compare(path, b.kv[i].key) <= 0 + }) + if i == len(b.kv) { + b.kv = append(b.kv, keyValue{path, value}) + } else if bytes.Equal(b.kv[i].key, path) { + b.kv[i].value = value + } else { + b.kv = append(b.kv, keyValue{}) + copy(b.kv[i+1:], b.kv[i:]) + b.kv[i].key = path + b.kv[i].value = value + } +} + +// PutBatch puts batch to trie. +// It is not atomic (and probably cannot be without substantial slow-down) +// and returns number of elements processed. +// However each element is being put atomically, so Trie is always in a valid state. +// It is used mostly after the block processing to update MPT and error is not expected. +func (t *Trie) PutBatch(b Batch) (int, error) { + r, n, err := t.putBatch(b.kv) + t.root = r + return n, err +} + +func (t *Trie) putBatch(kv []keyValue) (Node, int, error) { + return t.putBatchIntoNode(t.root, kv) +} + +func (t *Trie) putBatchIntoNode(curr Node, kv []keyValue) (Node, int, error) { + switch n := curr.(type) { + case *LeafNode: + return t.putBatchIntoLeaf(n, kv) + case *BranchNode: + return t.putBatchIntoBranch(n, kv) + case *ExtensionNode: + return t.putBatchIntoExtension(n, kv) + case *HashNode: + return t.putBatchIntoHash(n, kv) + default: + panic("invalid MPT node type") + } +} + +func (t *Trie) putBatchIntoLeaf(curr *LeafNode, kv []keyValue) (Node, int, error) { + t.removeRef(curr.Hash(), curr.Bytes()) + return t.newSubTrieMany(nil, kv, curr.value) +} + +func (t *Trie) putBatchIntoBranch(curr *BranchNode, kv []keyValue) (Node, int, error) { + return t.addToBranch(curr, kv, true) +} + +func (t *Trie) mergeExtension(prefix []byte, sub Node) Node { + switch sn := sub.(type) { + case *ExtensionNode: + t.removeRef(sn.Hash(), sn.bytes) + sn.key = append(prefix, sn.key...) + sn.invalidateCache() + t.addRef(sn.Hash(), sn.bytes) + return sn + case *HashNode: + return sn + default: + if len(prefix) != 0 { + e := NewExtensionNode(prefix, sub) + t.addRef(e.Hash(), e.bytes) + return e + } + return sub + } +} + +func (t *Trie) putBatchIntoExtension(curr *ExtensionNode, kv []keyValue) (Node, int, error) { + t.removeRef(curr.Hash(), curr.bytes) + + common := lcpMany(kv) + pref := lcp(common, curr.key) + if len(pref) == len(curr.key) { + // Extension must be split into new nodes. + stripPrefix(len(curr.key), kv) + sub, n, err := t.putBatchIntoNode(curr.next, kv) + return t.mergeExtension(pref, sub), n, err + } + + if len(pref) != 0 { + stripPrefix(len(pref), kv) + sub, n, err := t.putBatchIntoExtensionNoPrefix(curr.key[len(pref):], curr.next, kv) + return t.mergeExtension(pref, sub), n, err + } + return t.putBatchIntoExtensionNoPrefix(curr.key, curr.next, kv) +} + +func (t *Trie) putBatchIntoExtensionNoPrefix(key []byte, next Node, kv []keyValue) (Node, int, error) { + b := NewBranchNode() + if len(key) > 1 { + b.Children[key[0]] = t.newSubTrie(key[1:], next, false) + } else { + b.Children[key[0]] = next + } + return t.addToBranch(b, kv, false) +} + +func isEmpty(n Node) bool { + hn, ok := n.(*HashNode) + return ok && hn.IsEmpty() +} + +// addToBranch puts items into the branch node assuming b is not yet in trie. +func (t *Trie) addToBranch(b *BranchNode, kv []keyValue, inTrie bool) (Node, int, error) { + if inTrie { + t.removeRef(b.Hash(), b.bytes) + } + n, err := t.iterateBatch(kv, func(c byte, kv []keyValue) (int, error) { + child, n, err := t.putBatchIntoNode(b.Children[c], kv) + b.Children[c] = child + return n, err + }) + if inTrie && n != 0 { + b.invalidateCache() + } + return t.stripBranch(b), n, err +} + +// stripsBranch strips branch node after incomplete batch put. +// It assumes there is no reference to b in trie. +func (t *Trie) stripBranch(b *BranchNode) Node { + var n int + var lastIndex byte + for i := range b.Children { + if !isEmpty(b.Children[i]) { + n++ + lastIndex = byte(i) + } + } + switch { + case n == 0: + return new(HashNode) + case n == 1: + return t.mergeExtension([]byte{lastIndex}, b.Children[lastIndex]) + default: + t.addRef(b.Hash(), b.bytes) + return b + } +} + +func (t *Trie) iterateBatch(kv []keyValue, f func(c byte, kv []keyValue) (int, error)) (int, error) { + var n int + for len(kv) != 0 { + c, i := getLastIndex(kv) + if c != lastChild { + stripPrefix(1, kv[:i]) + } + sub, err := f(c, kv[:i]) + n += sub + if err != nil { + return n, err + } + kv = kv[i:] + } + return n, nil +} + +func (t *Trie) putBatchIntoHash(curr *HashNode, kv []keyValue) (Node, int, error) { + if curr.IsEmpty() { + common := lcpMany(kv) + stripPrefix(len(common), kv) + return t.newSubTrieMany(common, kv, nil) + } + result, err := t.getFromStore(curr.hash) + if err != nil { + return curr, 0, err + } + return t.putBatchIntoNode(result, kv) +} + +// Creates new subtrie from provided key-value pairs. +// Items in kv must have no common prefix. +// If there are any deletions in kv, return error. +// kv is not empty. +// kv is sorted by key. +// value is current value stored by prefix. +func (t *Trie) newSubTrieMany(prefix []byte, kv []keyValue, value []byte) (Node, int, error) { + if len(kv[0].key) == 0 { + if len(kv[0].value) == 0 { + if len(kv) == 1 { + if len(value) != 0 { + return new(HashNode), 1, nil + } + return new(HashNode), 0, ErrNotFound + } + node, n, err := t.newSubTrieMany(prefix, kv[1:], nil) + return node, n + 1, err + } + if len(kv) == 1 { + return t.newSubTrie(prefix, NewLeafNode(kv[0].value), true), 1, nil + } + value = kv[0].value + } + + // Prefix is empty and we have at least 2 children. + b := NewBranchNode() + if len(value) != 0 { + // Empty key is always first. + leaf := NewLeafNode(value) + t.addRef(leaf.Hash(), leaf.bytes) + b.Children[lastChild] = leaf + } + nd, n, err := t.addToBranch(b, kv, false) + return t.mergeExtension(prefix, nd), n, err +} + +func stripPrefix(n int, kv []keyValue) { + for i := range kv { + kv[i].key = kv[i].key[n:] + } +} + +func getLastIndex(kv []keyValue) (byte, int) { + if len(kv[0].key) == 0 { + return lastChild, 1 + } + c := kv[0].key[0] + for i := range kv[1:] { + if kv[i+1].key[0] != c { + return c, i + 1 + } + } + return c, len(kv) +} diff --git a/pkg/core/mpt/batch_test.go b/pkg/core/mpt/batch_test.go new file mode 100644 index 000000000..fcf612ae2 --- /dev/null +++ b/pkg/core/mpt/batch_test.go @@ -0,0 +1,305 @@ +package mpt + +import ( + "encoding/hex" + "fmt" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/stretchr/testify/require" +) + +func TestBatchAdd(t *testing.T) { + b := new(Batch) + b.Add([]byte{1}, []byte{2}) + b.Add([]byte{2, 16}, []byte{3}) + b.Add([]byte{2, 0}, []byte{4}) + b.Add([]byte{0, 1}, []byte{5}) + b.Add([]byte{2, 0}, []byte{6}) + expected := []keyValue{ + {[]byte{0, 0, 0, 1}, []byte{5}}, + {[]byte{0, 1}, []byte{2}}, + {[]byte{0, 2, 0, 0}, []byte{6}}, + {[]byte{0, 2, 1, 0}, []byte{3}}, + } + require.Equal(t, expected, b.kv) +} + +type pairs = [][2][]byte + +func testIncompletePut(t *testing.T, ps pairs, n int, tr1, tr2 *Trie) { + var b Batch + for i, p := range ps { + if i < n { + require.NoError(t, tr1.Put(p[0], p[1]), "item %d", i) + } else if i == n { + require.Error(t, tr1.Put(p[0], p[1]), "item %d", i) + } + b.Add(p[0], p[1]) + } + + num, err := tr2.PutBatch(b) + if n == len(ps) { + require.NoError(t, err) + } else { + require.Error(t, err) + } + require.Equal(t, n, num) + require.Equal(t, tr1.StateRoot(), tr2.StateRoot()) + + t.Run("test restore", func(t *testing.T) { + tr2.Flush() + tr3 := NewTrie(NewHashNode(tr2.StateRoot()), false, storage.NewMemCachedStore(tr2.Store)) + for _, p := range ps[:n] { + val, err := tr3.Get(p[0]) + if p[1] == nil { + require.Error(t, err) + continue + } + require.NoError(t, err, "key: %s", hex.EncodeToString(p[0])) + require.Equal(t, p[1], val) + } + }) +} + +func testPut(t *testing.T, ps pairs, tr1, tr2 *Trie) { + testIncompletePut(t, ps, len(ps), tr1, tr2) +} + +func TestTrie_PutBatchLeaf(t *testing.T) { + prepareLeaf := func(t *testing.T) (*Trie, *Trie) { + tr1 := NewTrie(new(HashNode), false, newTestStore()) + tr2 := NewTrie(new(HashNode), false, newTestStore()) + require.NoError(t, tr1.Put([]byte{}, []byte("value"))) + require.NoError(t, tr2.Put([]byte{}, []byte("value"))) + return tr1, tr2 + } + + t.Run("remove", func(t *testing.T) { + tr1, tr2 := prepareLeaf(t) + var ps = pairs{{[]byte{}, nil}} + testPut(t, ps, tr1, tr2) + }) + t.Run("replace", func(t *testing.T) { + tr1, tr2 := prepareLeaf(t) + var ps = pairs{{[]byte{}, []byte("replace")}} + testPut(t, ps, tr1, tr2) + }) + t.Run("remove and replace", func(t *testing.T) { + tr1, tr2 := prepareLeaf(t) + var ps = pairs{ + {[]byte{}, nil}, + {[]byte{2}, []byte("replace2")}, + } + testPut(t, ps, tr1, tr2) + }) +} + +func TestTrie_PutBatchExtension(t *testing.T) { + prepareExtension := func(t *testing.T) (*Trie, *Trie) { + tr1 := NewTrie(new(HashNode), false, newTestStore()) + tr2 := NewTrie(new(HashNode), false, newTestStore()) + require.NoError(t, tr1.Put([]byte{1, 2}, []byte("value1"))) + require.NoError(t, tr2.Put([]byte{1, 2}, []byte("value1"))) + return tr1, tr2 + } + + t.Run("split, key len > 1", func(t *testing.T) { + tr1, tr2 := prepareExtension(t) + var ps = pairs{{[]byte{2, 3}, []byte("value2")}} + testPut(t, ps, tr1, tr2) + }) + t.Run("split, key len = 1", func(t *testing.T) { + tr1, tr2 := prepareExtension(t) + var ps = pairs{{[]byte{1, 3}, []byte("value2")}} + testPut(t, ps, tr1, tr2) + }) + t.Run("add to next", func(t *testing.T) { + tr1, tr2 := prepareExtension(t) + var ps = pairs{{[]byte{1, 2, 3}, []byte("value2")}} + testPut(t, ps, tr1, tr2) + }) + t.Run("add to next with leaf", func(t *testing.T) { + tr1, tr2 := prepareExtension(t) + var ps = pairs{ + {[]byte{}, []byte("value3")}, + {[]byte{1, 2, 3}, []byte("value2")}, + } + testPut(t, ps, tr1, tr2) + }) + t.Run("remove value", func(t *testing.T) { + tr1, tr2 := prepareExtension(t) + var ps = pairs{{[]byte{1, 2}, nil}} + testPut(t, ps, tr1, tr2) + }) + t.Run("add to next, merge extension", func(t *testing.T) { + tr1, tr2 := prepareExtension(t) + var ps = pairs{ + {[]byte{1, 2}, nil}, + {[]byte{1, 2, 3}, []byte("value2")}, + } + testPut(t, ps, tr1, tr2) + }) +} + +func TestTrie_PutBatchBranch(t *testing.T) { + prepareBranch := func(t *testing.T) (*Trie, *Trie) { + tr1 := NewTrie(new(HashNode), false, newTestStore()) + tr2 := NewTrie(new(HashNode), false, newTestStore()) + require.NoError(t, tr1.Put([]byte{0x00, 2}, []byte("value1"))) + require.NoError(t, tr2.Put([]byte{0x00, 2}, []byte("value1"))) + require.NoError(t, tr1.Put([]byte{0x10, 3}, []byte("value2"))) + require.NoError(t, tr2.Put([]byte{0x10, 3}, []byte("value2"))) + return tr1, tr2 + } + + t.Run("simple add", func(t *testing.T) { + tr1, tr2 := prepareBranch(t) + var ps = pairs{{[]byte{0x20, 4}, []byte("value3")}} + testPut(t, ps, tr1, tr2) + }) + t.Run("remove 1, transform to extension", func(t *testing.T) { + tr1, tr2 := prepareBranch(t) + var ps = pairs{{[]byte{0x00, 2}, nil}} + testPut(t, ps, tr1, tr2) + }) + t.Run("incomplete put, transform to extension", func(t *testing.T) { + tr1, tr2 := prepareBranch(t) + var ps = pairs{ + {[]byte{0x00, 2}, nil}, + {[]byte{0x20, 2}, nil}, + {[]byte{0x30, 3}, []byte("won't be put")}, + } + testIncompletePut(t, ps, 1, tr1, tr2) + }) + t.Run("incomplete put, transform to empty", func(t *testing.T) { + tr1, tr2 := prepareBranch(t) + var ps = pairs{ + {[]byte{0x00, 2}, nil}, + {[]byte{0x10, 3}, nil}, + {[]byte{0x20, 2}, nil}, + {[]byte{0x30, 3}, []byte("won't be put")}, + } + testIncompletePut(t, ps, 2, tr1, tr2) + }) + t.Run("remove 2, become empty", func(t *testing.T) { + tr1, tr2 := prepareBranch(t) + var ps = pairs{ + {[]byte{0x00, 2}, nil}, + {[]byte{0x10, 3}, nil}, + } + testPut(t, ps, tr1, tr2) + }) +} + +func TestTrie_PutBatchHash(t *testing.T) { + prepareHash := func(t *testing.T) (*Trie, *Trie) { + tr1 := NewTrie(new(HashNode), false, newTestStore()) + tr2 := NewTrie(new(HashNode), false, newTestStore()) + require.NoError(t, tr1.Put([]byte{0x10}, []byte("value1"))) + require.NoError(t, tr2.Put([]byte{0x10}, []byte("value1"))) + require.NoError(t, tr1.Put([]byte{0x20}, []byte("value2"))) + require.NoError(t, tr2.Put([]byte{0x20}, []byte("value2"))) + tr1.Flush() + tr2.Flush() + return tr1, tr2 + } + + t.Run("good", func(t *testing.T) { + tr1, tr2 := prepareHash(t) + var ps = pairs{{[]byte{2}, []byte("value2")}} + tr1.Collapse(0) + tr1.Collapse(0) + testPut(t, ps, tr1, tr2) + }) + t.Run("incomplete, second hash not found", func(t *testing.T) { + tr1, tr2 := prepareHash(t) + var ps = pairs{ + {[]byte{0x10}, []byte("replace1")}, + {[]byte{0x20}, []byte("replace2")}, + } + tr1.Collapse(1) + tr2.Collapse(1) + key := makeStorageKey(tr1.root.(*BranchNode).Children[2].Hash().BytesBE()) + require.NoError(t, tr1.Store.Delete(key)) + require.NoError(t, tr2.Store.Delete(key)) + testIncompletePut(t, ps, 1, tr1, tr2) + }) +} + +func TestTrie_PutBatchEmpty(t *testing.T) { + t.Run("good", func(t *testing.T) { + tr1 := NewTrie(new(HashNode), false, newTestStore()) + tr2 := NewTrie(new(HashNode), false, newTestStore()) + var ps = pairs{ + {[]byte{}, []byte("value0")}, + {[]byte{1}, []byte("value1")}, + {[]byte{3}, []byte("value3")}, + } + testPut(t, ps, tr1, tr2) + }) + t.Run("incomplete", func(t *testing.T) { + var ps = pairs{ + {[]byte{}, []byte("replace0")}, + {[]byte{1}, []byte("replace1")}, + {[]byte{2}, nil}, + {[]byte{3}, []byte("replace3")}, + } + tr1 := NewTrie(new(HashNode), false, newTestStore()) + tr2 := NewTrie(new(HashNode), false, newTestStore()) + testIncompletePut(t, ps, 2, tr1, tr2) + }) +} + +// For the sake of coverage. +func TestTrie_InvalidNodeType(t *testing.T) { + tr := NewTrie(new(HashNode), false, newTestStore()) + var b Batch + b.Add([]byte{1}, []byte("value")) + tr.root = Node(nil) + require.Panics(t, func() { _, _ = tr.PutBatch(b) }) +} + +func TestTrie_PutBatch(t *testing.T) { + tr1 := NewTrie(new(HashNode), false, newTestStore()) + tr2 := NewTrie(new(HashNode), false, newTestStore()) + var ps = pairs{ + {[]byte{1}, []byte{1}}, + {[]byte{2}, []byte{3}}, + {[]byte{4}, []byte{5}}, + } + testPut(t, ps, tr1, tr2) + + ps = pairs{[2][]byte{{4}, {6}}} + testPut(t, ps, tr1, tr2) + + ps = pairs{[2][]byte{{4}, nil}} + testPut(t, ps, tr1, tr2) +} + +// This function is unused, but is helpful for debugging +// as it provides more readable Trie representation compared to +// `spew.Dump()` +func printNode(prefix string, n Node) { + switch tn := n.(type) { + case *HashNode: + if tn.IsEmpty() { + fmt.Printf("%s empty\n", prefix) + return + } + fmt.Printf("%s %s\n", prefix, tn.Hash().StringLE()) + case *BranchNode: + for i, c := range tn.Children { + if isEmpty(c) { + continue + } + fmt.Printf("%s [%2d] ->\n", prefix, i) + printNode(prefix+" ", c) + } + case *ExtensionNode: + fmt.Printf("%s extension-> %s\n", prefix, hex.EncodeToString(tn.key)) + printNode(prefix+" ", tn.next) + case *LeafNode: + fmt.Printf("%s leaf-> %s\n", prefix, hex.EncodeToString(tn.value)) + } +} diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go index 1c67c6c59..03b4e3337 100644 --- a/pkg/core/mpt/helpers.go +++ b/pkg/core/mpt/helpers.go @@ -17,6 +17,20 @@ func lcp(a, b []byte) []byte { return a[:i] } +func lcpMany(kv []keyValue) []byte { + if len(kv) == 1 { + return kv[0].key + } + p := lcp(kv[0].key, kv[1].key) + if len(p) == 0 { + return p + } + for i := range kv[2:] { + p = lcp(p, kv[2+i].key) + } + return p +} + // copySlice is a helper for copying slice if needed. func copySlice(a []byte) []byte { b := make([]byte, len(a))