From dc6741bce77aaf85115d96331bb9653db1f0afe4 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 22 May 2020 10:37:07 +0300 Subject: [PATCH] mpt: implement MPT trie MPT is a trie with a branching factor = 16, i.e. it consists of sequences in 16-element alphabet. Signed-off-by: Evgenii Stratonikov --- pkg/core/mpt/branch.go | 78 ++++++++ pkg/core/mpt/doc.go | 45 +++++ pkg/core/mpt/extension.go | 70 +++++++ pkg/core/mpt/hash.go | 61 +++++++ pkg/core/mpt/helpers.go | 35 ++++ pkg/core/mpt/leaf.go | 56 ++++++ pkg/core/mpt/node.go | 70 +++++++ pkg/core/mpt/node_test.go | 94 ++++++++++ pkg/core/mpt/trie.go | 357 ++++++++++++++++++++++++++++++++++++ pkg/core/mpt/trie_test.go | 373 ++++++++++++++++++++++++++++++++++++++ pkg/core/storage/store.go | 1 + 11 files changed, 1240 insertions(+) create mode 100644 pkg/core/mpt/branch.go create mode 100644 pkg/core/mpt/doc.go create mode 100644 pkg/core/mpt/extension.go create mode 100644 pkg/core/mpt/hash.go create mode 100644 pkg/core/mpt/helpers.go create mode 100644 pkg/core/mpt/leaf.go create mode 100644 pkg/core/mpt/node.go create mode 100644 pkg/core/mpt/node_test.go create mode 100644 pkg/core/mpt/trie.go create mode 100644 pkg/core/mpt/trie_test.go diff --git a/pkg/core/mpt/branch.go b/pkg/core/mpt/branch.go new file mode 100644 index 000000000..ac3f2400f --- /dev/null +++ b/pkg/core/mpt/branch.go @@ -0,0 +1,78 @@ +package mpt + +import ( + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +const ( + // childrenCount represents a number of children of a branch node. + childrenCount = 17 + // lastChild is the index of the last child. + lastChild = childrenCount - 1 +) + +// BranchNode represents MPT's branch node. +type BranchNode struct { + hash util.Uint256 + valid bool + + Children [childrenCount]Node +} + +var _ Node = (*BranchNode)(nil) + +// NewBranchNode returns new branch node. +func NewBranchNode() *BranchNode { + b := new(BranchNode) + for i := 0; i < childrenCount; i++ { + b.Children[i] = new(HashNode) + } + return b +} + +// Type implements Node interface. +func (b *BranchNode) Type() NodeType { return BranchT } + +// Hash implements Node interface. +func (b *BranchNode) Hash() util.Uint256 { + if !b.valid { + b.hash = hash.DoubleSha256(toBytes(b)) + b.valid = true + } + return b.hash +} + +// invalidateHash invalidates node hash. +func (b *BranchNode) invalidateHash() { + b.valid = false +} + +// EncodeBinary implements io.Serializable. +func (b *BranchNode) EncodeBinary(w *io.BinWriter) { + for i := 0; i < childrenCount; i++ { + if hn, ok := b.Children[i].(*HashNode); ok { + hn.EncodeBinary(w) + continue + } + n := NewHashNode(b.Children[i].Hash()) + n.EncodeBinary(w) + } +} + +// DecodeBinary implements io.Serializable. +func (b *BranchNode) DecodeBinary(r *io.BinReader) { + for i := 0; i < childrenCount; i++ { + b.Children[i] = new(HashNode) + b.Children[i].DecodeBinary(r) + } +} + +// splitPath splits path for a branch node. +func splitPath(path []byte) (byte, []byte) { + if len(path) != 0 { + return path[0], path[1:] + } + return lastChild, path +} diff --git a/pkg/core/mpt/doc.go b/pkg/core/mpt/doc.go new file mode 100644 index 000000000..c307665b3 --- /dev/null +++ b/pkg/core/mpt/doc.go @@ -0,0 +1,45 @@ +/* +Package mpt implements MPT (Merkle-Patricia Tree). + +MPT stores key-value pairs and is a trie over 16-symbol alphabet. https://en.wikipedia.org/wiki/Trie +Trie is a tree where values are stored in leafs and keys are paths from root to the leaf node. +MPT consists of 4 type of nodes: +- Leaf node contains only value. +- Extension node contains both key and value. +- Branch node contains 2 or more children. +- Hash node is a compressed node and contains only actual node's hash. + The actual node must be retrieved from storage or over the network. + +As an example here is a trie containing 3 pairs: +- 0x1201 -> val1 +- 0x1203 -> val2 +- 0x1224 -> val3 +- 0x12 -> val4 + +ExtensionNode(0x0102), Next + _______________________| + | +BranchNode [0, 1, 2, ...], Last -> Leaf(val4) + | | + | ExtensionNode [0x04], Next -> Leaf(val3) + | + BranchNode [0, 1, 2, 3, ...], Last -> HashNode(nil) + | | + | Leaf(val2) + | + Leaf(val1) + +There are 3 invariants that this implementation has: +- Branch node cannot have <= 1 children +- Extension node cannot have zero-length key +- Extension node cannot have another Extension node in it's next field + +Thank to these restrictions, there is a single root hash for every set of key-value pairs +irregardless of the order they were added/removed with. +The actual trie structure can vary because of node -> HashNode compressing. + +There is also one optimization which cost us almost nothing in terms of complexity but is very beneficial: +When we perform get/put/delete on a speficic path, every Hash node which was retreived from storage is +replaced by its uncompressed form, so that subsequent hits of this not don't use storage. +*/ +package mpt diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go new file mode 100644 index 000000000..775078827 --- /dev/null +++ b/pkg/core/mpt/extension.go @@ -0,0 +1,70 @@ +package mpt + +import ( + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MaxKeyLength is the max length of the extension node key. +const MaxKeyLength = 1125 + +// ExtensionNode represents MPT's extension node. +type ExtensionNode struct { + hash util.Uint256 + valid bool + + key []byte + next Node +} + +var _ Node = (*ExtensionNode)(nil) + +// NewExtensionNode returns hash node with the specified key and next node. +// Note: because it is a part of Trie, key must be mangled, i.e. must contain only bytes with high half = 0. +func NewExtensionNode(key []byte, next Node) *ExtensionNode { + return &ExtensionNode{ + key: key, + next: next, + } +} + +// Type implements Node interface. +func (e ExtensionNode) Type() NodeType { return ExtensionT } + +// Hash implements Node interface. +func (e *ExtensionNode) Hash() util.Uint256 { + if !e.valid { + e.hash = hash.DoubleSha256(toBytes(e)) + e.valid = true + } + return e.hash +} + +// invalidateHash invalidates node hash. +func (e *ExtensionNode) invalidateHash() { + e.valid = false +} + +// DecodeBinary implements io.Serializable. +func (e *ExtensionNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + if sz > MaxKeyLength { + r.Err = fmt.Errorf("extension node key is too big: %d", sz) + return + } + e.valid = false + e.key = make([]byte, sz) + r.ReadBytes(e.key) + e.next = new(HashNode) + e.next.DecodeBinary(r) +} + +// EncodeBinary implements io.Serializable. +func (e ExtensionNode) EncodeBinary(w *io.BinWriter) { + w.WriteVarBytes(e.key) + n := NewHashNode(e.next.Hash()) + n.EncodeBinary(w) +} diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go new file mode 100644 index 000000000..a14dec879 --- /dev/null +++ b/pkg/core/mpt/hash.go @@ -0,0 +1,61 @@ +package mpt + +import ( + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// HashNode represents MPT's hash node. +type HashNode struct { + hash util.Uint256 + valid bool +} + +var _ Node = (*HashNode)(nil) + +// NewHashNode returns hash node with the specified hash. +func NewHashNode(h util.Uint256) *HashNode { + return &HashNode{ + hash: h, + valid: true, + } +} + +// Type implements Node interface. +func (h *HashNode) Type() NodeType { return HashT } + +// Hash implements Node interface. +func (h *HashNode) Hash() util.Uint256 { + if !h.valid { + panic("can't get hash of an empty HashNode") + } + return h.hash +} + +// IsEmpty returns true iff h is an empty node i.e. contains no hash. +func (h *HashNode) IsEmpty() bool { return !h.valid } + +// DecodeBinary implements io.Serializable. +func (h *HashNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + switch sz { + case 0: + h.valid = false + case util.Uint256Size: + h.valid = true + r.ReadBytes(h.hash[:]) + default: + r.Err = fmt.Errorf("invalid hash node size: %d", sz) + } +} + +// EncodeBinary implements io.Serializable. +func (h HashNode) EncodeBinary(w *io.BinWriter) { + if !h.valid { + w.WriteVarUint(0) + return + } + w.WriteVarBytes(h.hash[:]) +} diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go new file mode 100644 index 000000000..1c67c6c59 --- /dev/null +++ b/pkg/core/mpt/helpers.go @@ -0,0 +1,35 @@ +package mpt + +// lcp returns longest common prefix of a and b. +// Note: it does no allocations. +func lcp(a, b []byte) []byte { + if len(a) < len(b) { + return lcp(b, a) + } + + var i int + for i = 0; i < len(b); i++ { + if a[i] != b[i] { + break + } + } + + return a[:i] +} + +// copySlice is a helper for copying slice if needed. +func copySlice(a []byte) []byte { + b := make([]byte, len(a)) + copy(b, a) + return b +} + +// toNibbles mangles path by splitting every byte into 2 containing low- and high- 4-byte part. +func toNibbles(path []byte) []byte { + result := make([]byte, len(path)*2) + for i := range path { + result[i*2] = path[i] >> 4 + result[i*2+1] = path[i] & 0x0F + } + return result +} diff --git a/pkg/core/mpt/leaf.go b/pkg/core/mpt/leaf.go new file mode 100644 index 000000000..455ae3feb --- /dev/null +++ b/pkg/core/mpt/leaf.go @@ -0,0 +1,56 @@ +package mpt + +import ( + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MaxValueLength is a max length of a leaf node value. +const MaxValueLength = 1024 * 1024 + +// LeafNode represents MPT's leaf node. +type LeafNode struct { + hash util.Uint256 + valid bool + + value []byte +} + +var _ Node = (*LeafNode)(nil) + +// NewLeafNode returns hash node with the specified value. +func NewLeafNode(value []byte) *LeafNode { + return &LeafNode{value: value} +} + +// Type implements Node interface. +func (n LeafNode) Type() NodeType { return LeafT } + +// Hash implements Node interface. +func (n *LeafNode) Hash() util.Uint256 { + if !n.valid { + n.hash = hash.DoubleSha256(toBytes(n)) + n.valid = true + } + return n.hash +} + +// DecodeBinary implements io.Serializable. +func (n *LeafNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + if sz > MaxValueLength { + r.Err = fmt.Errorf("leaf node value is too big: %d", sz) + return + } + n.valid = false + n.value = make([]byte, sz) + r.ReadBytes(n.value) +} + +// EncodeBinary implements io.Serializable. +func (n LeafNode) EncodeBinary(w *io.BinWriter) { + w.WriteVarBytes(n.value) +} diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go new file mode 100644 index 000000000..83abb2e60 --- /dev/null +++ b/pkg/core/mpt/node.go @@ -0,0 +1,70 @@ +package mpt + +import ( + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// NodeType represents node type.. +type NodeType byte + +// Node types definitions. +const ( + BranchT NodeType = 0x00 + ExtensionT NodeType = 0x01 + HashT NodeType = 0x02 + LeafT NodeType = 0x03 +) + +// NodeObject represents Node together with it's type. +// It is used for serialization/deserialization where type info +// is also expected. +type NodeObject struct { + Node +} + +// Node represents common interface of all MPT nodes. +type Node interface { + io.Serializable + Hash() util.Uint256 + Type() NodeType +} + +// EncodeBinary implements io.Serializable. +func (n NodeObject) EncodeBinary(w *io.BinWriter) { + encodeNodeWithType(n.Node, w) +} + +// DecodeBinary implements io.Serializable. +func (n *NodeObject) DecodeBinary(r *io.BinReader) { + typ := NodeType(r.ReadB()) + switch typ { + case BranchT: + n.Node = new(BranchNode) + case ExtensionT: + n.Node = new(ExtensionNode) + case HashT: + n.Node = new(HashNode) + case LeafT: + n.Node = new(LeafNode) + default: + r.Err = fmt.Errorf("invalid node type: %x", typ) + return + } + n.Node.DecodeBinary(r) +} + +// encodeNodeWithType encodes node together with it's type. +func encodeNodeWithType(n Node, w *io.BinWriter) { + w.WriteB(byte(n.Type())) + n.EncodeBinary(w) +} + +// toBytes is a helper for serializing node. +func toBytes(n Node) []byte { + buf := io.NewBufBinWriter() + encodeNodeWithType(n, buf.BinWriter) + return buf.Bytes() +} diff --git a/pkg/core/mpt/node_test.go b/pkg/core/mpt/node_test.go new file mode 100644 index 000000000..0e2c17c96 --- /dev/null +++ b/pkg/core/mpt/node_test.go @@ -0,0 +1,94 @@ +package mpt + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/stretchr/testify/require" +) + +func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) { + return func(t *testing.T) { + bs, err := testserdes.EncodeBinary(expected) + require.NoError(t, err) + err = testserdes.DecodeBinary(bs, actual) + if !ok { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, expected.Type(), actual.Type()) + require.Equal(t, expected.Hash(), actual.Hash()) + } +} + +func TestNode_Serializable(t *testing.T) { + t.Run("Leaf", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + l := NewLeafNode(random.Bytes(123)) + t.Run("Raw", getTestFuncEncode(true, l, new(LeafNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{l}, new(NodeObject))) + }) + t.Run("BigValue", getTestFuncEncode(false, + NewLeafNode(random.Bytes(MaxValueLength+1)), new(LeafNode))) + }) + + t.Run("Extension", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + e := NewExtensionNode(random.Bytes(42), NewLeafNode(random.Bytes(10))) + t.Run("Raw", getTestFuncEncode(true, e, new(ExtensionNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{e}, new(NodeObject))) + }) + t.Run("BigKey", getTestFuncEncode(false, + NewExtensionNode(random.Bytes(MaxKeyLength+1), NewLeafNode(random.Bytes(10))), new(ExtensionNode))) + }) + + t.Run("Branch", func(t *testing.T) { + b := NewBranchNode() + b.Children[0] = NewLeafNode(random.Bytes(10)) + b.Children[lastChild] = NewHashNode(random.Uint256()) + t.Run("Raw", getTestFuncEncode(true, b, new(BranchNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{b}, new(NodeObject))) + }) + + t.Run("Hash", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + h := NewHashNode(random.Uint256()) + t.Run("Raw", getTestFuncEncode(true, h, new(HashNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{h}, new(NodeObject))) + }) + t.Run("Empty", func(t *testing.T) { // compare nodes, not hashes + testserdes.EncodeDecodeBinary(t, new(HashNode), new(HashNode)) + }) + t.Run("InvalidSize", func(t *testing.T) { + buf := io.NewBufBinWriter() + buf.BinWriter.WriteVarBytes(make([]byte, 13)) + require.Error(t, testserdes.DecodeBinary(buf.Bytes(), new(HashNode))) + }) + }) + + t.Run("Invalid", func(t *testing.T) { + require.Error(t, testserdes.DecodeBinary([]byte{0xFF}, new(NodeObject))) + }) +} + +// C# interoperability test +// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135 +func TestRootHash(t *testing.T) { + b := NewBranchNode() + r := NewExtensionNode([]byte{0x0A, 0x0C}, b) + + v1 := NewLeafNode([]byte{0xAB, 0xCD}) + l1 := NewExtensionNode([]byte{0x01}, v1) + b.Children[0] = l1 + + v2 := NewLeafNode([]byte{0x22, 0x22}) + l2 := NewExtensionNode([]byte{0x09}, v2) + b.Children[9] = l2 + + r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1) + require.Equal(t, "dea3ab46e9461e885ed7091c1e533e0a8030b248d39cbc638962394eaca0fbb3", r1.Hash().StringLE()) + require.Equal(t, "93e8e1ffe2f83dd92fca67330e273bcc811bf64b8f8d9d1b25d5e7366b47d60d", r.Hash().StringLE()) +} diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go new file mode 100644 index 000000000..f9589fde3 --- /dev/null +++ b/pkg/core/mpt/trie.go @@ -0,0 +1,357 @@ +package mpt + +import ( + "bytes" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// Trie is an MPT trie storing all key-value pairs. +type Trie struct { + Store *storage.MemCachedStore + + root Node +} + +// ErrNotFound is returned when requested trie item is missing. +var ErrNotFound = errors.New("item not found") + +// NewTrie returns new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors +// so that all storage errors are processed during `store.Persist()` at the caller. +// This also has the benefit, that every `Put` can be considered an atomic operation. +func NewTrie(root Node, store *storage.MemCachedStore) *Trie { + if root == nil { + root = new(HashNode) + } + + return &Trie{ + Store: store, + root: root, + } +} + +// Get returns value for the provided key in t. +func (t *Trie) Get(key []byte) ([]byte, error) { + path := toNibbles(key) + r, bs, err := t.getWithPath(t.root, path) + if err != nil { + return nil, err + } + t.root = r + return bs, nil +} + +// getWithPath returns value the provided path in a subtrie rooting in curr. +// It also returns a current node with all hash nodes along the path +// replaced to their "unhashed" counterparts. +func (t *Trie) getWithPath(curr Node, path []byte) (Node, []byte, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + return curr, copySlice(n.value), nil + } + case *BranchNode: + i, path := splitPath(path) + r, bs, err := t.getWithPath(n.Children[i], path) + if err != nil { + return nil, nil, err + } + n.Children[i] = r + return n, bs, nil + case *HashNode: + if !n.IsEmpty() { + if r, err := t.getFromStore(n.hash); err == nil { + return t.getWithPath(r, path) + } + } + case *ExtensionNode: + if bytes.HasPrefix(path, n.key) { + r, bs, err := t.getWithPath(n.next, path[len(n.key):]) + if err != nil { + return nil, nil, err + } + n.next = r + return curr, bs, err + } + default: + panic("invalid MPT node type") + } + return curr, nil, ErrNotFound +} + +// Put puts key-value pair in t. +func (t *Trie) Put(key, value []byte) error { + if len(key) > MaxKeyLength { + return errors.New("key is too big") + } else if len(value) > MaxValueLength { + return errors.New("value is too big") + } + if len(value) == 0 { + return t.Delete(key) + } + path := toNibbles(key) + n := NewLeafNode(value) + r, err := t.putIntoNode(t.root, path, n) + if err != nil { + return err + } + t.root = r + return nil +} + +// putIntoLeaf puts val to trie if current node is a Leaf. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) { + v := val.(*LeafNode) + if len(path) == 0 { + return v, nil + } + + b := NewBranchNode() + b.Children[path[0]] = newSubTrie(path[1:], v) + b.Children[lastChild] = curr + return b, nil +} + +// putIntoBranch puts val to trie if current node is a Branch. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) { + i, path := splitPath(path) + r, err := t.putIntoNode(curr.Children[i], path, val) + if err != nil { + return nil, err + } + curr.Children[i] = r + curr.invalidateHash() + return curr, nil +} + +// putIntoExtension puts val to trie if current node is an Extension. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) { + if bytes.HasPrefix(path, curr.key) { + r, err := t.putIntoNode(curr.next, path[len(curr.key):], val) + if err != nil { + return nil, err + } + curr.next = r + curr.invalidateHash() + return curr, nil + } + + pref := lcp(curr.key, path) + lp := len(pref) + keyTail := curr.key[lp:] + pathTail := path[lp:] + + s1 := newSubTrie(keyTail[1:], curr.next) + b := NewBranchNode() + b.Children[keyTail[0]] = s1 + + i, pathTail := splitPath(pathTail) + s2 := newSubTrie(pathTail, val) + b.Children[i] = s2 + + if lp > 0 { + return NewExtensionNode(copySlice(pref), b), nil + } + return b, nil +} + +// putIntoHash puts val to trie if current node is a HashNode. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { + if curr.IsEmpty() { + return newSubTrie(path, val), nil + } + + result, err := t.getFromStore(curr.hash) + if err != nil { + return nil, err + } + return t.putIntoNode(result, path, val) +} + +// newSubTrie create new trie containing node at provided path. +func newSubTrie(path []byte, val Node) Node { + if len(path) == 0 { + return val + } + return NewExtensionNode(path, val) +} + +func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + return t.putIntoLeaf(n, path, val) + case *BranchNode: + return t.putIntoBranch(n, path, val) + case *ExtensionNode: + return t.putIntoExtension(n, path, val) + case *HashNode: + return t.putIntoHash(n, path, val) + default: + panic("invalid MPT node type") + } +} + +// Delete removes key from trie. +// It returns no error on missing key. +func (t *Trie) Delete(key []byte) error { + path := toNibbles(key) + r, err := t.deleteFromNode(t.root, path) + if err != nil { + return err + } + t.root = r + return nil +} + +func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { + i, path := splitPath(path) + r, err := t.deleteFromNode(b.Children[i], path) + if err != nil { + return nil, err + } + b.Children[i] = r + b.invalidateHash() + var count, index int + for i := range b.Children { + h, ok := b.Children[i].(*HashNode) + if !ok || !h.IsEmpty() { + index = i + count++ + } + } + // count is >= 1 because branch node had at least 2 children before deletion. + if count > 1 { + return b, nil + } + c := b.Children[index] + if index == lastChild { + return c, nil + } + if h, ok := c.(*HashNode); ok { + c, err = t.getFromStore(h.Hash()) + if err != nil { + return nil, err + } + } + if e, ok := c.(*ExtensionNode); ok { + e.key = append([]byte{byte(index)}, e.key...) + e.invalidateHash() + return e, nil + } + + return NewExtensionNode([]byte{byte(index)}, c), nil +} + +func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) { + if !bytes.HasPrefix(path, n.key) { + return nil, ErrNotFound + } + r, err := t.deleteFromNode(n.next, path[len(n.key):]) + if err != nil { + return nil, err + } + switch nxt := r.(type) { + case *ExtensionNode: + n.key = append(n.key, nxt.key...) + n.next = nxt.next + n.invalidateHash() + case *HashNode: + if nxt.IsEmpty() { + return nxt, nil + } + default: + n.next = r + } + return n, nil +} + +func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + return new(HashNode), nil + } + return nil, ErrNotFound + case *BranchNode: + return t.deleteFromBranch(n, path) + case *ExtensionNode: + return t.deleteFromExtension(n, path) + case *HashNode: + if n.IsEmpty() { + return nil, ErrNotFound + } + newNode, err := t.getFromStore(n.Hash()) + if err != nil { + return nil, err + } + return t.deleteFromNode(newNode, path) + default: + panic("invalid MPT node type") + } +} + +// StateRoot returns root hash of t. +func (t *Trie) StateRoot() util.Uint256 { + if hn, ok := t.root.(*HashNode); ok && hn.IsEmpty() { + return util.Uint256{} + } + return t.root.Hash() +} + +func makeStorageKey(mptKey []byte) []byte { + return append([]byte{byte(storage.DataMPT)}, mptKey...) +} + +// Flush puts every node in the trie except Hash ones to the storage. +// Because we care only about block-level changes, there is no need to put every +// new node to storage. Normally, flush should be called with every StateRoot persist, i.e. +// after every block. +func (t *Trie) Flush() { + t.flush(t.root) +} + +func (t *Trie) flush(node Node) { + switch n := node.(type) { + case *BranchNode: + for i := range n.Children { + t.flush(n.Children[i]) + } + case *ExtensionNode: + t.flush(n.next) + case *HashNode: + return + } + t.putToStore(node) +} + +func (t *Trie) putToStore(n Node) { + if n.Type() == HashT { + panic("can't put hash node in trie") + } + bs := toBytes(n) + h := hash.DoubleSha256(bs) + _ = t.Store.Put(makeStorageKey(h.BytesBE()), bs) // put in MemCached returns no errors +} + +func (t *Trie) getFromStore(h util.Uint256) (Node, error) { + data, err := t.Store.Get(makeStorageKey(h.BytesBE())) + if err != nil { + return nil, err + } + + var n NodeObject + r := io.NewBinReaderFromBuf(data) + n.DecodeBinary(r) + if r.Err != nil { + return nil, r.Err + } + return n.Node, nil +} diff --git a/pkg/core/mpt/trie_test.go b/pkg/core/mpt/trie_test.go new file mode 100644 index 000000000..470e0c8e5 --- /dev/null +++ b/pkg/core/mpt/trie_test.go @@ -0,0 +1,373 @@ +package mpt + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/stretchr/testify/require" +) + +func newTestStore() *storage.MemCachedStore { + return storage.NewMemCachedStore(storage.NewMemoryStore()) +} + +func newTestTrie(t *testing.T) *Trie { + b := NewBranchNode() + + l1 := NewLeafNode([]byte{0xAB, 0xCD}) + b.Children[0] = NewExtensionNode([]byte{0x01}, l1) + + l2 := NewLeafNode([]byte{0x22, 0x22}) + b.Children[9] = NewExtensionNode([]byte{0x09}, l2) + + v := NewLeafNode([]byte("hello")) + h := NewHashNode(v.Hash()) + b.Children[10] = NewExtensionNode([]byte{0x0e}, h) + + e := NewExtensionNode(toNibbles([]byte{0xAC}), b) + tr := NewTrie(e, newTestStore()) + + tr.putToStore(e) + tr.putToStore(b) + tr.putToStore(l1) + tr.putToStore(l2) + tr.putToStore(v) + tr.putToStore(b.Children[0]) + tr.putToStore(b.Children[9]) + tr.putToStore(b.Children[10]) + + return tr +} + +func TestTrie_PutIntoBranchNode(t *testing.T) { + b := NewBranchNode() + l := NewLeafNode([]byte{0x8}) + b.Children[0x7] = NewHashNode(l.Hash()) + b.Children[0x8] = NewHashNode(random.Uint256()) + tr := NewTrie(b, newTestStore()) + + // next + require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34})) + tr.testHas(t, []byte{}, []byte{0x12, 0x34}) + + // empty hash node child + require.NoError(t, tr.Put([]byte{0x66}, []byte{0x56})) + tr.testHas(t, []byte{0x66}, []byte{0x56}) + require.True(t, isValid(tr.root)) + + // missing hash + require.Error(t, tr.Put([]byte{0x70}, []byte{0x42})) + require.True(t, isValid(tr.root)) + + // hash is in store + tr.putToStore(l) + require.NoError(t, tr.Put([]byte{0x70}, []byte{0x42})) + require.True(t, isValid(tr.root)) +} + +func TestTrie_PutIntoExtensionNode(t *testing.T) { + l := NewLeafNode([]byte{0x11}) + key := []byte{0x12} + e := NewExtensionNode(toNibbles(key), NewHashNode(l.Hash())) + tr := NewTrie(e, newTestStore()) + + // missing hash + require.Error(t, tr.Put(key, []byte{0x42})) + + tr.putToStore(l) + require.NoError(t, tr.Put(key, []byte{0x42})) + tr.testHas(t, key, []byte{0x42}) + require.True(t, isValid(tr.root)) +} + +func TestTrie_PutIntoHashNode(t *testing.T) { + b := NewBranchNode() + l := NewLeafNode(random.Bytes(5)) + e := NewExtensionNode([]byte{0x02}, l) + b.Children[1] = NewHashNode(e.Hash()) + b.Children[9] = NewHashNode(random.Uint256()) + tr := NewTrie(b, newTestStore()) + + tr.putToStore(e) + + t.Run("MissingLeafHash", func(t *testing.T) { + _, err := tr.Get([]byte{0x12}) + require.Error(t, err) + }) + + tr.putToStore(l) + + val := random.Bytes(3) + require.NoError(t, tr.Put([]byte{0x12, 0x34}, val)) + tr.testHas(t, []byte{0x12, 0x34}, val) + tr.testHas(t, []byte{0x12}, l.value) + require.True(t, isValid(tr.root)) +} + +func TestTrie_Put(t *testing.T) { + trExp := newTestTrie(t) + + trAct := NewTrie(nil, newTestStore()) + require.NoError(t, trAct.Put([]byte{0xAC, 0x01}, []byte{0xAB, 0xCD})) + require.NoError(t, trAct.Put([]byte{0xAC, 0x99}, []byte{0x22, 0x22})) + require.NoError(t, trAct.Put([]byte{0xAC, 0xAE}, []byte("hello"))) + + // Note: the exact tries differ because of ("acae":"hello") node is stored as Hash node in test trie. + require.Equal(t, trExp.root.Hash(), trAct.root.Hash()) + require.True(t, isValid(trAct.root)) +} + +func TestTrie_PutInvalid(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + key, value := []byte("key"), []byte("value") + + // big key + require.Error(t, tr.Put(make([]byte, MaxKeyLength+1), value)) + + // big value + require.Error(t, tr.Put(key, make([]byte, MaxValueLength+1))) + + // this is ok though + require.NoError(t, tr.Put(key, value)) + tr.testHas(t, key, value) +} + +func TestTrie_BigPut(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + items := []struct{ k, v string }{ + {"item with long key", "value1"}, + {"item with matching prefix", "value2"}, + {"another prefix", "value3"}, + {"another prefix 2", "value4"}, + {"another ", "value5"}, + } + + for i := range items { + require.NoError(t, tr.Put([]byte(items[i].k), []byte(items[i].v))) + } + + for i := range items { + tr.testHas(t, []byte(items[i].k), []byte(items[i].v)) + } + + t.Run("Rewrite", func(t *testing.T) { + k, v := []byte(items[0].k), []byte{0x01, 0x23} + require.NoError(t, tr.Put(k, v)) + tr.testHas(t, k, v) + }) + + t.Run("Remove", func(t *testing.T) { + k := []byte(items[1].k) + require.NoError(t, tr.Put(k, []byte{})) + tr.testHas(t, k, nil) + }) +} + +func (tr *Trie) testHas(t *testing.T, key, value []byte) { + v, err := tr.Get(key) + if value == nil { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, value, v) +} + +// isValid checks for 3 invariants: +// - BranchNode contains > 1 children +// - ExtensionNode do not contain another extension node +// - ExtensionNode do not have nil key +// It is used only during testing to catch possible bugs. +func isValid(curr Node) bool { + switch n := curr.(type) { + case *BranchNode: + var count int + for i := range n.Children { + if !isValid(n.Children[i]) { + return false + } + hn, ok := n.Children[i].(*HashNode) + if !ok || !hn.IsEmpty() { + count++ + } + } + return count > 1 + case *ExtensionNode: + _, ok := n.next.(*ExtensionNode) + return len(n.key) != 0 && !ok + default: + return true + } +} + +func TestTrie_Get(t *testing.T) { + t.Run("HashNode", func(t *testing.T) { + tr := newTestTrie(t) + tr.testHas(t, []byte{0xAC, 0xAE}, []byte("hello")) + }) + t.Run("UnfoldRoot", func(t *testing.T) { + tr := newTestTrie(t) + single := NewTrie(NewHashNode(tr.root.Hash()), tr.Store) + single.testHas(t, []byte{0xAC}, nil) + single.testHas(t, []byte{0xAC, 0x01}, []byte{0xAB, 0xCD}) + single.testHas(t, []byte{0xAC, 0x99}, []byte{0x22, 0x22}) + single.testHas(t, []byte{0xAC, 0xAE}, []byte("hello")) + }) +} + +func TestTrie_Flush(t *testing.T) { + pairs := map[string][]byte{ + "": []byte("value0"), + "key1": []byte("value1"), + "key2": []byte("value2"), + } + + tr := NewTrie(nil, newTestStore()) + for k, v := range pairs { + require.NoError(t, tr.Put([]byte(k), v)) + } + + tr.Flush() + tr = NewTrie(NewHashNode(tr.StateRoot()), tr.Store) + for k, v := range pairs { + actual, err := tr.Get([]byte(k)) + require.NoError(t, err) + require.Equal(t, v, actual) + } +} + +func TestTrie_Delete(t *testing.T) { + t.Run("Hash", func(t *testing.T) { + t.Run("FromStore", func(t *testing.T) { + l := NewLeafNode([]byte{0x12}) + tr := NewTrie(NewHashNode(l.Hash()), newTestStore()) + t.Run("NotInStore", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{})) + }) + + tr.putToStore(l) + tr.testHas(t, []byte{}, []byte{0x12}) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + }) + + t.Run("Empty", func(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + require.Error(t, tr.Delete([]byte{})) + }) + }) + + t.Run("Leaf", func(t *testing.T) { + l := NewLeafNode([]byte{0x12, 0x34}) + tr := NewTrie(l, newTestStore()) + t.Run("NonExistentKey", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{0x12})) + tr.testHas(t, []byte{}, []byte{0x12, 0x34}) + }) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + }) + + t.Run("Extension", func(t *testing.T) { + t.Run("SingleKey", func(t *testing.T) { + l := NewLeafNode([]byte{0x12, 0x34}) + e := NewExtensionNode([]byte{0x0A, 0x0B}, l) + tr := NewTrie(e, newTestStore()) + + t.Run("NonExistentKey", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{0xAB}, []byte{0x12, 0x34}) + }) + + require.NoError(t, tr.Delete([]byte{0xAB})) + require.True(t, tr.root.(*HashNode).IsEmpty()) + }) + + t.Run("MultipleKeys", func(t *testing.T) { + b := NewBranchNode() + b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x12, 0x34})) + b.Children[6] = NewExtensionNode([]byte{0x07}, NewLeafNode([]byte{0x56, 0x78})) + e := NewExtensionNode([]byte{0x01, 0x02}, b) + tr := NewTrie(e, newTestStore()) + + h := e.Hash() + require.NoError(t, tr.Delete([]byte{0x12, 0x01})) + tr.testHas(t, []byte{0x12, 0x01}, nil) + tr.testHas(t, []byte{0x12, 0x67}, []byte{0x56, 0x78}) + + require.NotEqual(t, h, tr.root.Hash()) + require.Equal(t, toNibbles([]byte{0x12, 0x67}), e.key) + require.IsType(t, (*LeafNode)(nil), e.next) + }) + }) + + t.Run("Branch", func(t *testing.T) { + t.Run("3 Children", func(t *testing.T) { + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x34})) + b.Children[1] = NewExtensionNode([]byte{0x06}, NewLeafNode([]byte{0x56})) + tr := NewTrie(b, newTestStore()) + require.NoError(t, tr.Delete([]byte{0x16})) + tr.testHas(t, []byte{}, []byte{0x12}) + tr.testHas(t, []byte{0x01}, []byte{0x34}) + tr.testHas(t, []byte{0x16}, nil) + }) + t.Run("2 Children", func(t *testing.T) { + newt := func(t *testing.T) *Trie { + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + l := NewLeafNode([]byte{0x34}) + e := NewExtensionNode([]byte{0x06}, l) + b.Children[5] = NewHashNode(e.Hash()) + tr := NewTrie(b, newTestStore()) + tr.putToStore(l) + tr.putToStore(e) + return tr + } + + t.Run("DeleteLast", func(t *testing.T) { + t.Run("MergeExtension", func(t *testing.T) { + tr := newt(t) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + tr.testHas(t, []byte{0x56}, []byte{0x34}) + require.IsType(t, (*ExtensionNode)(nil), tr.root) + }) + + t.Run("LeaveLeaf", func(t *testing.T) { + c := NewBranchNode() + c.Children[5] = NewLeafNode([]byte{0x05}) + c.Children[6] = NewLeafNode([]byte{0x06}) + + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + b.Children[5] = c + tr := NewTrie(b, newTestStore()) + + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + tr.testHas(t, []byte{0x55}, []byte{0x05}) + tr.testHas(t, []byte{0x56}, []byte{0x06}) + require.IsType(t, (*ExtensionNode)(nil), tr.root) + }) + }) + + t.Run("DeleteMiddle", func(t *testing.T) { + tr := newt(t) + require.NoError(t, tr.Delete([]byte{0x56})) + tr.testHas(t, []byte{}, []byte{0x12}) + tr.testHas(t, []byte{0x56}, nil) + require.IsType(t, (*LeafNode)(nil), tr.root) + }) + }) + }) +} + +func TestTrie_PanicInvalidRoot(t *testing.T) { + tr := &Trie{Store: newTestStore()} + require.Panics(t, func() { _ = tr.Put([]byte{1}, []byte{2}) }) + require.Panics(t, func() { _, _ = tr.Get([]byte{1}) }) + require.Panics(t, func() { _ = tr.Delete([]byte{1}) }) +} diff --git a/pkg/core/storage/store.go b/pkg/core/storage/store.go index 5e70334ed..575c42ba3 100644 --- a/pkg/core/storage/store.go +++ b/pkg/core/storage/store.go @@ -9,6 +9,7 @@ import ( const ( DataBlock KeyPrefix = 0x01 DataTransaction KeyPrefix = 0x02 + DataMPT KeyPrefix = 0x03 STAccount KeyPrefix = 0x40 STNotification KeyPrefix = 0x4d STContract KeyPrefix = 0x50