From dc6741bce77aaf85115d96331bb9653db1f0afe4 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 22 May 2020 10:37:07 +0300 Subject: [PATCH 01/13] mpt: implement MPT trie MPT is a trie with a branching factor = 16, i.e. it consists of sequences in 16-element alphabet. Signed-off-by: Evgenii Stratonikov --- pkg/core/mpt/branch.go | 78 ++++++++ pkg/core/mpt/doc.go | 45 +++++ pkg/core/mpt/extension.go | 70 +++++++ pkg/core/mpt/hash.go | 61 +++++++ pkg/core/mpt/helpers.go | 35 ++++ pkg/core/mpt/leaf.go | 56 ++++++ pkg/core/mpt/node.go | 70 +++++++ pkg/core/mpt/node_test.go | 94 ++++++++++ pkg/core/mpt/trie.go | 357 ++++++++++++++++++++++++++++++++++++ pkg/core/mpt/trie_test.go | 373 ++++++++++++++++++++++++++++++++++++++ pkg/core/storage/store.go | 1 + 11 files changed, 1240 insertions(+) create mode 100644 pkg/core/mpt/branch.go create mode 100644 pkg/core/mpt/doc.go create mode 100644 pkg/core/mpt/extension.go create mode 100644 pkg/core/mpt/hash.go create mode 100644 pkg/core/mpt/helpers.go create mode 100644 pkg/core/mpt/leaf.go create mode 100644 pkg/core/mpt/node.go create mode 100644 pkg/core/mpt/node_test.go create mode 100644 pkg/core/mpt/trie.go create mode 100644 pkg/core/mpt/trie_test.go diff --git a/pkg/core/mpt/branch.go b/pkg/core/mpt/branch.go new file mode 100644 index 000000000..ac3f2400f --- /dev/null +++ b/pkg/core/mpt/branch.go @@ -0,0 +1,78 @@ +package mpt + +import ( + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +const ( + // childrenCount represents a number of children of a branch node. + childrenCount = 17 + // lastChild is the index of the last child. + lastChild = childrenCount - 1 +) + +// BranchNode represents MPT's branch node. +type BranchNode struct { + hash util.Uint256 + valid bool + + Children [childrenCount]Node +} + +var _ Node = (*BranchNode)(nil) + +// NewBranchNode returns new branch node. 
+func NewBranchNode() *BranchNode { + b := new(BranchNode) + for i := 0; i < childrenCount; i++ { + b.Children[i] = new(HashNode) + } + return b +} + +// Type implements Node interface. +func (b *BranchNode) Type() NodeType { return BranchT } + +// Hash implements Node interface. +func (b *BranchNode) Hash() util.Uint256 { + if !b.valid { + b.hash = hash.DoubleSha256(toBytes(b)) + b.valid = true + } + return b.hash +} + +// invalidateHash invalidates node hash. +func (b *BranchNode) invalidateHash() { + b.valid = false +} + +// EncodeBinary implements io.Serializable. +func (b *BranchNode) EncodeBinary(w *io.BinWriter) { + for i := 0; i < childrenCount; i++ { + if hn, ok := b.Children[i].(*HashNode); ok { + hn.EncodeBinary(w) + continue + } + n := NewHashNode(b.Children[i].Hash()) + n.EncodeBinary(w) + } +} + +// DecodeBinary implements io.Serializable. +func (b *BranchNode) DecodeBinary(r *io.BinReader) { + for i := 0; i < childrenCount; i++ { + b.Children[i] = new(HashNode) + b.Children[i].DecodeBinary(r) + } +} + +// splitPath splits path for a branch node. +func splitPath(path []byte) (byte, []byte) { + if len(path) != 0 { + return path[0], path[1:] + } + return lastChild, path +} diff --git a/pkg/core/mpt/doc.go b/pkg/core/mpt/doc.go new file mode 100644 index 000000000..c307665b3 --- /dev/null +++ b/pkg/core/mpt/doc.go @@ -0,0 +1,45 @@ +/* +Package mpt implements MPT (Merkle-Patricia Tree). + +MPT stores key-value pairs and is a trie over 16-symbol alphabet. https://en.wikipedia.org/wiki/Trie +Trie is a tree where values are stored in leafs and keys are paths from root to the leaf node. +MPT consists of 4 type of nodes: +- Leaf node contains only value. +- Extension node contains both key and value. +- Branch node contains 2 or more children. +- Hash node is a compressed node and contains only actual node's hash. + The actual node must be retrieved from storage or over the network. 
+ +As an example here is a trie containing 4 pairs: +- 0x1201 -> val1 +- 0x1203 -> val2 +- 0x1224 -> val3 +- 0x12 -> val4 + +ExtensionNode(0x0102), Next + _______________________| + | +BranchNode [0, 1, 2, ...], Last -> Leaf(val4) + | | + | ExtensionNode [0x04], Next -> Leaf(val3) + | + BranchNode [0, 1, 2, 3, ...], Last -> HashNode(nil) + | | + | Leaf(val2) + | + Leaf(val1) + +There are 3 invariants that this implementation has: +- Branch node cannot have <= 1 children +- Extension node cannot have zero-length key +- Extension node cannot have another Extension node in its next field + +Thanks to these restrictions, there is a single root hash for every set of key-value pairs +regardless of the order they were added/removed with. +The actual trie structure can vary because of node -> HashNode compressing. + +There is also one optimization which costs us almost nothing in terms of complexity but is very beneficial: +When we perform get/put/delete on a specific path, every Hash node which was retrieved from storage is +replaced by its uncompressed form, so that subsequent hits of this node don't use storage. +*/ +package mpt diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go new file mode 100644 index 000000000..775078827 --- /dev/null +++ b/pkg/core/mpt/extension.go @@ -0,0 +1,70 @@ +package mpt + +import ( + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MaxKeyLength is the max length of the extension node key. +const MaxKeyLength = 1125 + +// ExtensionNode represents MPT's extension node. +type ExtensionNode struct { + hash util.Uint256 + valid bool + + key []byte + next Node +} + +var _ Node = (*ExtensionNode)(nil) + +// NewExtensionNode returns extension node with the specified key and next node. +// Note: because it is a part of Trie, key must be mangled, i.e. must contain only bytes with high half = 0. 
+func NewExtensionNode(key []byte, next Node) *ExtensionNode { + return &ExtensionNode{ + key: key, + next: next, + } +} + +// Type implements Node interface. +func (e ExtensionNode) Type() NodeType { return ExtensionT } + +// Hash implements Node interface. +func (e *ExtensionNode) Hash() util.Uint256 { + if !e.valid { + e.hash = hash.DoubleSha256(toBytes(e)) + e.valid = true + } + return e.hash +} + +// invalidateHash invalidates node hash. +func (e *ExtensionNode) invalidateHash() { + e.valid = false +} + +// DecodeBinary implements io.Serializable. +func (e *ExtensionNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + if sz > MaxKeyLength { + r.Err = fmt.Errorf("extension node key is too big: %d", sz) + return + } + e.valid = false + e.key = make([]byte, sz) + r.ReadBytes(e.key) + e.next = new(HashNode) + e.next.DecodeBinary(r) +} + +// EncodeBinary implements io.Serializable. +func (e ExtensionNode) EncodeBinary(w *io.BinWriter) { + w.WriteVarBytes(e.key) + n := NewHashNode(e.next.Hash()) + n.EncodeBinary(w) +} diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go new file mode 100644 index 000000000..a14dec879 --- /dev/null +++ b/pkg/core/mpt/hash.go @@ -0,0 +1,61 @@ +package mpt + +import ( + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// HashNode represents MPT's hash node. +type HashNode struct { + hash util.Uint256 + valid bool +} + +var _ Node = (*HashNode)(nil) + +// NewHashNode returns hash node with the specified hash. +func NewHashNode(h util.Uint256) *HashNode { + return &HashNode{ + hash: h, + valid: true, + } +} + +// Type implements Node interface. +func (h *HashNode) Type() NodeType { return HashT } + +// Hash implements Node interface. +func (h *HashNode) Hash() util.Uint256 { + if !h.valid { + panic("can't get hash of an empty HashNode") + } + return h.hash +} + +// IsEmpty returns true iff h is an empty node i.e. contains no hash. 
+func (h *HashNode) IsEmpty() bool { return !h.valid } + +// DecodeBinary implements io.Serializable. +func (h *HashNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + switch sz { + case 0: + h.valid = false + case util.Uint256Size: + h.valid = true + r.ReadBytes(h.hash[:]) + default: + r.Err = fmt.Errorf("invalid hash node size: %d", sz) + } +} + +// EncodeBinary implements io.Serializable. +func (h HashNode) EncodeBinary(w *io.BinWriter) { + if !h.valid { + w.WriteVarUint(0) + return + } + w.WriteVarBytes(h.hash[:]) +} diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go new file mode 100644 index 000000000..1c67c6c59 --- /dev/null +++ b/pkg/core/mpt/helpers.go @@ -0,0 +1,35 @@ +package mpt + +// lcp returns longest common prefix of a and b. +// Note: it does no allocations. +func lcp(a, b []byte) []byte { + if len(a) < len(b) { + return lcp(b, a) + } + + var i int + for i = 0; i < len(b); i++ { + if a[i] != b[i] { + break + } + } + + return a[:i] +} + +// copySlice is a helper for copying slice if needed. +func copySlice(a []byte) []byte { + b := make([]byte, len(a)) + copy(b, a) + return b +} + +// toNibbles mangles path by splitting every byte into 2 bytes holding its high and low 4-bit halves (high half first). +func toNibbles(path []byte) []byte { + result := make([]byte, len(path)*2) + for i := range path { + result[i*2] = path[i] >> 4 + result[i*2+1] = path[i] & 0x0F + } + return result +} diff --git a/pkg/core/mpt/leaf.go b/pkg/core/mpt/leaf.go new file mode 100644 index 000000000..455ae3feb --- /dev/null +++ b/pkg/core/mpt/leaf.go @@ -0,0 +1,56 @@ +package mpt + +import ( + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MaxValueLength is a max length of a leaf node value. +const MaxValueLength = 1024 * 1024 + +// LeafNode represents MPT's leaf node. 
+type LeafNode struct { + hash util.Uint256 + valid bool + + value []byte +} + +var _ Node = (*LeafNode)(nil) + +// NewLeafNode returns leaf node with the specified value. +func NewLeafNode(value []byte) *LeafNode { + return &LeafNode{value: value} +} + +// Type implements Node interface. +func (n LeafNode) Type() NodeType { return LeafT } + +// Hash implements Node interface. +func (n *LeafNode) Hash() util.Uint256 { + if !n.valid { + n.hash = hash.DoubleSha256(toBytes(n)) + n.valid = true + } + return n.hash +} + +// DecodeBinary implements io.Serializable. +func (n *LeafNode) DecodeBinary(r *io.BinReader) { + sz := r.ReadVarUint() + if sz > MaxValueLength { + r.Err = fmt.Errorf("leaf node value is too big: %d", sz) + return + } + n.valid = false + n.value = make([]byte, sz) + r.ReadBytes(n.value) +} + +// EncodeBinary implements io.Serializable. +func (n LeafNode) EncodeBinary(w *io.BinWriter) { + w.WriteVarBytes(n.value) +} diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go new file mode 100644 index 000000000..83abb2e60 --- /dev/null +++ b/pkg/core/mpt/node.go @@ -0,0 +1,70 @@ +package mpt + +import ( + "fmt" + + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// NodeType represents node type. +type NodeType byte + +// Node types definitions. +const ( + BranchT NodeType = 0x00 + ExtensionT NodeType = 0x01 + HashT NodeType = 0x02 + LeafT NodeType = 0x03 +) + +// NodeObject represents Node together with its type. +// It is used for serialization/deserialization where type info +// is also expected. +type NodeObject struct { + Node +} + +// Node represents common interface of all MPT nodes. +type Node interface { + io.Serializable + Hash() util.Uint256 + Type() NodeType +} + +// EncodeBinary implements io.Serializable. +func (n NodeObject) EncodeBinary(w *io.BinWriter) { + encodeNodeWithType(n.Node, w) +} + +// DecodeBinary implements io.Serializable. 
+func (n *NodeObject) DecodeBinary(r *io.BinReader) { + typ := NodeType(r.ReadB()) + switch typ { + case BranchT: + n.Node = new(BranchNode) + case ExtensionT: + n.Node = new(ExtensionNode) + case HashT: + n.Node = new(HashNode) + case LeafT: + n.Node = new(LeafNode) + default: + r.Err = fmt.Errorf("invalid node type: %x", typ) + return + } + n.Node.DecodeBinary(r) +} + +// encodeNodeWithType encodes node together with it's type. +func encodeNodeWithType(n Node, w *io.BinWriter) { + w.WriteB(byte(n.Type())) + n.EncodeBinary(w) +} + +// toBytes is a helper for serializing node. +func toBytes(n Node) []byte { + buf := io.NewBufBinWriter() + encodeNodeWithType(n, buf.BinWriter) + return buf.Bytes() +} diff --git a/pkg/core/mpt/node_test.go b/pkg/core/mpt/node_test.go new file mode 100644 index 000000000..0e2c17c96 --- /dev/null +++ b/pkg/core/mpt/node_test.go @@ -0,0 +1,94 @@ +package mpt + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/stretchr/testify/require" +) + +func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) { + return func(t *testing.T) { + bs, err := testserdes.EncodeBinary(expected) + require.NoError(t, err) + err = testserdes.DecodeBinary(bs, actual) + if !ok { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, expected.Type(), actual.Type()) + require.Equal(t, expected.Hash(), actual.Hash()) + } +} + +func TestNode_Serializable(t *testing.T) { + t.Run("Leaf", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + l := NewLeafNode(random.Bytes(123)) + t.Run("Raw", getTestFuncEncode(true, l, new(LeafNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{l}, new(NodeObject))) + }) + t.Run("BigValue", getTestFuncEncode(false, + NewLeafNode(random.Bytes(MaxValueLength+1)), new(LeafNode))) + }) + + t.Run("Extension", func(t *testing.T) { + t.Run("Good", 
func(t *testing.T) { + e := NewExtensionNode(random.Bytes(42), NewLeafNode(random.Bytes(10))) + t.Run("Raw", getTestFuncEncode(true, e, new(ExtensionNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{e}, new(NodeObject))) + }) + t.Run("BigKey", getTestFuncEncode(false, + NewExtensionNode(random.Bytes(MaxKeyLength+1), NewLeafNode(random.Bytes(10))), new(ExtensionNode))) + }) + + t.Run("Branch", func(t *testing.T) { + b := NewBranchNode() + b.Children[0] = NewLeafNode(random.Bytes(10)) + b.Children[lastChild] = NewHashNode(random.Uint256()) + t.Run("Raw", getTestFuncEncode(true, b, new(BranchNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{b}, new(NodeObject))) + }) + + t.Run("Hash", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + h := NewHashNode(random.Uint256()) + t.Run("Raw", getTestFuncEncode(true, h, new(HashNode))) + t.Run("WithType", getTestFuncEncode(true, &NodeObject{h}, new(NodeObject))) + }) + t.Run("Empty", func(t *testing.T) { // compare nodes, not hashes + testserdes.EncodeDecodeBinary(t, new(HashNode), new(HashNode)) + }) + t.Run("InvalidSize", func(t *testing.T) { + buf := io.NewBufBinWriter() + buf.BinWriter.WriteVarBytes(make([]byte, 13)) + require.Error(t, testserdes.DecodeBinary(buf.Bytes(), new(HashNode))) + }) + }) + + t.Run("Invalid", func(t *testing.T) { + require.Error(t, testserdes.DecodeBinary([]byte{0xFF}, new(NodeObject))) + }) +} + +// C# interoperability test +// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135 +func TestRootHash(t *testing.T) { + b := NewBranchNode() + r := NewExtensionNode([]byte{0x0A, 0x0C}, b) + + v1 := NewLeafNode([]byte{0xAB, 0xCD}) + l1 := NewExtensionNode([]byte{0x01}, v1) + b.Children[0] = l1 + + v2 := NewLeafNode([]byte{0x22, 0x22}) + l2 := NewExtensionNode([]byte{0x09}, v2) + b.Children[9] = l2 + + r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1) + require.Equal(t, 
"dea3ab46e9461e885ed7091c1e533e0a8030b248d39cbc638962394eaca0fbb3", r1.Hash().StringLE()) + require.Equal(t, "93e8e1ffe2f83dd92fca67330e273bcc811bf64b8f8d9d1b25d5e7366b47d60d", r.Hash().StringLE()) +} diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go new file mode 100644 index 000000000..f9589fde3 --- /dev/null +++ b/pkg/core/mpt/trie.go @@ -0,0 +1,357 @@ +package mpt + +import ( + "bytes" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// Trie is an MPT trie storing all key-value pairs. +type Trie struct { + Store *storage.MemCachedStore + + root Node +} + +// ErrNotFound is returned when requested trie item is missing. +var ErrNotFound = errors.New("item not found") + +// NewTrie returns new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors +// so that all storage errors are processed during `store.Persist()` at the caller. +// This also has the benefit, that every `Put` can be considered an atomic operation. +func NewTrie(root Node, store *storage.MemCachedStore) *Trie { + if root == nil { + root = new(HashNode) + } + + return &Trie{ + Store: store, + root: root, + } +} + +// Get returns value for the provided key in t. +func (t *Trie) Get(key []byte) ([]byte, error) { + path := toNibbles(key) + r, bs, err := t.getWithPath(t.root, path) + if err != nil { + return nil, err + } + t.root = r + return bs, nil +} + +// getWithPath returns value the provided path in a subtrie rooting in curr. +// It also returns a current node with all hash nodes along the path +// replaced to their "unhashed" counterparts. 
+func (t *Trie) getWithPath(curr Node, path []byte) (Node, []byte, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + return curr, copySlice(n.value), nil + } + case *BranchNode: + i, path := splitPath(path) + r, bs, err := t.getWithPath(n.Children[i], path) + if err != nil { + return nil, nil, err + } + n.Children[i] = r + return n, bs, nil + case *HashNode: + if !n.IsEmpty() { + if r, err := t.getFromStore(n.hash); err == nil { + return t.getWithPath(r, path) + } + } + case *ExtensionNode: + if bytes.HasPrefix(path, n.key) { + r, bs, err := t.getWithPath(n.next, path[len(n.key):]) + if err != nil { + return nil, nil, err + } + n.next = r + return curr, bs, err + } + default: + panic("invalid MPT node type") + } + return curr, nil, ErrNotFound +} + +// Put puts key-value pair in t. +func (t *Trie) Put(key, value []byte) error { + if len(key) > MaxKeyLength { + return errors.New("key is too big") + } else if len(value) > MaxValueLength { + return errors.New("value is too big") + } + if len(value) == 0 { + return t.Delete(key) + } + path := toNibbles(key) + n := NewLeafNode(value) + r, err := t.putIntoNode(t.root, path, n) + if err != nil { + return err + } + t.root = r + return nil +} + +// putIntoLeaf puts val to trie if current node is a Leaf. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) { + v := val.(*LeafNode) + if len(path) == 0 { + return v, nil + } + + b := NewBranchNode() + b.Children[path[0]] = newSubTrie(path[1:], v) + b.Children[lastChild] = curr + return b, nil +} + +// putIntoBranch puts val to trie if current node is a Branch. +// It returns Node if curr needs to be replaced and error if any. 
+func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) { + i, path := splitPath(path) + r, err := t.putIntoNode(curr.Children[i], path, val) + if err != nil { + return nil, err + } + curr.Children[i] = r + curr.invalidateHash() + return curr, nil +} + +// putIntoExtension puts val to trie if current node is an Extension. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) { + if bytes.HasPrefix(path, curr.key) { + r, err := t.putIntoNode(curr.next, path[len(curr.key):], val) + if err != nil { + return nil, err + } + curr.next = r + curr.invalidateHash() + return curr, nil + } + + pref := lcp(curr.key, path) + lp := len(pref) + keyTail := curr.key[lp:] + pathTail := path[lp:] + + s1 := newSubTrie(keyTail[1:], curr.next) + b := NewBranchNode() + b.Children[keyTail[0]] = s1 + + i, pathTail := splitPath(pathTail) + s2 := newSubTrie(pathTail, val) + b.Children[i] = s2 + + if lp > 0 { + return NewExtensionNode(copySlice(pref), b), nil + } + return b, nil +} + +// putIntoHash puts val to trie if current node is a HashNode. +// It returns Node if curr needs to be replaced and error if any. +func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { + if curr.IsEmpty() { + return newSubTrie(path, val), nil + } + + result, err := t.getFromStore(curr.hash) + if err != nil { + return nil, err + } + return t.putIntoNode(result, path, val) +} + +// newSubTrie create new trie containing node at provided path. 
+func newSubTrie(path []byte, val Node) Node { + if len(path) == 0 { + return val + } + return NewExtensionNode(path, val) +} + +func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + return t.putIntoLeaf(n, path, val) + case *BranchNode: + return t.putIntoBranch(n, path, val) + case *ExtensionNode: + return t.putIntoExtension(n, path, val) + case *HashNode: + return t.putIntoHash(n, path, val) + default: + panic("invalid MPT node type") + } +} + +// Delete removes key from trie. +// It returns no error on missing key. +func (t *Trie) Delete(key []byte) error { + path := toNibbles(key) + r, err := t.deleteFromNode(t.root, path) + if err != nil { + return err + } + t.root = r + return nil +} + +func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { + i, path := splitPath(path) + r, err := t.deleteFromNode(b.Children[i], path) + if err != nil { + return nil, err + } + b.Children[i] = r + b.invalidateHash() + var count, index int + for i := range b.Children { + h, ok := b.Children[i].(*HashNode) + if !ok || !h.IsEmpty() { + index = i + count++ + } + } + // count is >= 1 because branch node had at least 2 children before deletion. + if count > 1 { + return b, nil + } + c := b.Children[index] + if index == lastChild { + return c, nil + } + if h, ok := c.(*HashNode); ok { + c, err = t.getFromStore(h.Hash()) + if err != nil { + return nil, err + } + } + if e, ok := c.(*ExtensionNode); ok { + e.key = append([]byte{byte(index)}, e.key...) + e.invalidateHash() + return e, nil + } + + return NewExtensionNode([]byte{byte(index)}, c), nil +} + +func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) { + if !bytes.HasPrefix(path, n.key) { + return nil, ErrNotFound + } + r, err := t.deleteFromNode(n.next, path[len(n.key):]) + if err != nil { + return nil, err + } + switch nxt := r.(type) { + case *ExtensionNode: + n.key = append(n.key, nxt.key...) 
+ n.next = nxt.next + n.invalidateHash() + case *HashNode: + if nxt.IsEmpty() { + return nxt, nil + } + default: + n.next = r + } + return n, nil +} + +func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + return new(HashNode), nil + } + return nil, ErrNotFound + case *BranchNode: + return t.deleteFromBranch(n, path) + case *ExtensionNode: + return t.deleteFromExtension(n, path) + case *HashNode: + if n.IsEmpty() { + return nil, ErrNotFound + } + newNode, err := t.getFromStore(n.Hash()) + if err != nil { + return nil, err + } + return t.deleteFromNode(newNode, path) + default: + panic("invalid MPT node type") + } +} + +// StateRoot returns root hash of t. +func (t *Trie) StateRoot() util.Uint256 { + if hn, ok := t.root.(*HashNode); ok && hn.IsEmpty() { + return util.Uint256{} + } + return t.root.Hash() +} + +func makeStorageKey(mptKey []byte) []byte { + return append([]byte{byte(storage.DataMPT)}, mptKey...) +} + +// Flush puts every node in the trie except Hash ones to the storage. +// Because we care only about block-level changes, there is no need to put every +// new node to storage. Normally, flush should be called with every StateRoot persist, i.e. +// after every block. 
+func (t *Trie) Flush() { + t.flush(t.root) +} + +func (t *Trie) flush(node Node) { + switch n := node.(type) { + case *BranchNode: + for i := range n.Children { + t.flush(n.Children[i]) + } + case *ExtensionNode: + t.flush(n.next) + case *HashNode: + return + } + t.putToStore(node) +} + +func (t *Trie) putToStore(n Node) { + if n.Type() == HashT { + panic("can't put hash node in trie") + } + bs := toBytes(n) + h := hash.DoubleSha256(bs) + _ = t.Store.Put(makeStorageKey(h.BytesBE()), bs) // put in MemCached returns no errors +} + +func (t *Trie) getFromStore(h util.Uint256) (Node, error) { + data, err := t.Store.Get(makeStorageKey(h.BytesBE())) + if err != nil { + return nil, err + } + + var n NodeObject + r := io.NewBinReaderFromBuf(data) + n.DecodeBinary(r) + if r.Err != nil { + return nil, r.Err + } + return n.Node, nil +} diff --git a/pkg/core/mpt/trie_test.go b/pkg/core/mpt/trie_test.go new file mode 100644 index 000000000..470e0c8e5 --- /dev/null +++ b/pkg/core/mpt/trie_test.go @@ -0,0 +1,373 @@ +package mpt + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/stretchr/testify/require" +) + +func newTestStore() *storage.MemCachedStore { + return storage.NewMemCachedStore(storage.NewMemoryStore()) +} + +func newTestTrie(t *testing.T) *Trie { + b := NewBranchNode() + + l1 := NewLeafNode([]byte{0xAB, 0xCD}) + b.Children[0] = NewExtensionNode([]byte{0x01}, l1) + + l2 := NewLeafNode([]byte{0x22, 0x22}) + b.Children[9] = NewExtensionNode([]byte{0x09}, l2) + + v := NewLeafNode([]byte("hello")) + h := NewHashNode(v.Hash()) + b.Children[10] = NewExtensionNode([]byte{0x0e}, h) + + e := NewExtensionNode(toNibbles([]byte{0xAC}), b) + tr := NewTrie(e, newTestStore()) + + tr.putToStore(e) + tr.putToStore(b) + tr.putToStore(l1) + tr.putToStore(l2) + tr.putToStore(v) + tr.putToStore(b.Children[0]) + tr.putToStore(b.Children[9]) + tr.putToStore(b.Children[10]) + + return tr +} + +func 
TestTrie_PutIntoBranchNode(t *testing.T) { + b := NewBranchNode() + l := NewLeafNode([]byte{0x8}) + b.Children[0x7] = NewHashNode(l.Hash()) + b.Children[0x8] = NewHashNode(random.Uint256()) + tr := NewTrie(b, newTestStore()) + + // next + require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34})) + tr.testHas(t, []byte{}, []byte{0x12, 0x34}) + + // empty hash node child + require.NoError(t, tr.Put([]byte{0x66}, []byte{0x56})) + tr.testHas(t, []byte{0x66}, []byte{0x56}) + require.True(t, isValid(tr.root)) + + // missing hash + require.Error(t, tr.Put([]byte{0x70}, []byte{0x42})) + require.True(t, isValid(tr.root)) + + // hash is in store + tr.putToStore(l) + require.NoError(t, tr.Put([]byte{0x70}, []byte{0x42})) + require.True(t, isValid(tr.root)) +} + +func TestTrie_PutIntoExtensionNode(t *testing.T) { + l := NewLeafNode([]byte{0x11}) + key := []byte{0x12} + e := NewExtensionNode(toNibbles(key), NewHashNode(l.Hash())) + tr := NewTrie(e, newTestStore()) + + // missing hash + require.Error(t, tr.Put(key, []byte{0x42})) + + tr.putToStore(l) + require.NoError(t, tr.Put(key, []byte{0x42})) + tr.testHas(t, key, []byte{0x42}) + require.True(t, isValid(tr.root)) +} + +func TestTrie_PutIntoHashNode(t *testing.T) { + b := NewBranchNode() + l := NewLeafNode(random.Bytes(5)) + e := NewExtensionNode([]byte{0x02}, l) + b.Children[1] = NewHashNode(e.Hash()) + b.Children[9] = NewHashNode(random.Uint256()) + tr := NewTrie(b, newTestStore()) + + tr.putToStore(e) + + t.Run("MissingLeafHash", func(t *testing.T) { + _, err := tr.Get([]byte{0x12}) + require.Error(t, err) + }) + + tr.putToStore(l) + + val := random.Bytes(3) + require.NoError(t, tr.Put([]byte{0x12, 0x34}, val)) + tr.testHas(t, []byte{0x12, 0x34}, val) + tr.testHas(t, []byte{0x12}, l.value) + require.True(t, isValid(tr.root)) +} + +func TestTrie_Put(t *testing.T) { + trExp := newTestTrie(t) + + trAct := NewTrie(nil, newTestStore()) + require.NoError(t, trAct.Put([]byte{0xAC, 0x01}, []byte{0xAB, 0xCD})) + require.NoError(t, 
trAct.Put([]byte{0xAC, 0x99}, []byte{0x22, 0x22})) + require.NoError(t, trAct.Put([]byte{0xAC, 0xAE}, []byte("hello"))) + + // Note: the exact tries differ because of ("acae":"hello") node is stored as Hash node in test trie. + require.Equal(t, trExp.root.Hash(), trAct.root.Hash()) + require.True(t, isValid(trAct.root)) +} + +func TestTrie_PutInvalid(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + key, value := []byte("key"), []byte("value") + + // big key + require.Error(t, tr.Put(make([]byte, MaxKeyLength+1), value)) + + // big value + require.Error(t, tr.Put(key, make([]byte, MaxValueLength+1))) + + // this is ok though + require.NoError(t, tr.Put(key, value)) + tr.testHas(t, key, value) +} + +func TestTrie_BigPut(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + items := []struct{ k, v string }{ + {"item with long key", "value1"}, + {"item with matching prefix", "value2"}, + {"another prefix", "value3"}, + {"another prefix 2", "value4"}, + {"another ", "value5"}, + } + + for i := range items { + require.NoError(t, tr.Put([]byte(items[i].k), []byte(items[i].v))) + } + + for i := range items { + tr.testHas(t, []byte(items[i].k), []byte(items[i].v)) + } + + t.Run("Rewrite", func(t *testing.T) { + k, v := []byte(items[0].k), []byte{0x01, 0x23} + require.NoError(t, tr.Put(k, v)) + tr.testHas(t, k, v) + }) + + t.Run("Remove", func(t *testing.T) { + k := []byte(items[1].k) + require.NoError(t, tr.Put(k, []byte{})) + tr.testHas(t, k, nil) + }) +} + +func (tr *Trie) testHas(t *testing.T, key, value []byte) { + v, err := tr.Get(key) + if value == nil { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, value, v) +} + +// isValid checks for 3 invariants: +// - BranchNode contains > 1 children +// - ExtensionNode do not contain another extension node +// - ExtensionNode do not have nil key +// It is used only during testing to catch possible bugs. 
+func isValid(curr Node) bool { + switch n := curr.(type) { + case *BranchNode: + var count int + for i := range n.Children { + if !isValid(n.Children[i]) { + return false + } + hn, ok := n.Children[i].(*HashNode) + if !ok || !hn.IsEmpty() { + count++ + } + } + return count > 1 + case *ExtensionNode: + _, ok := n.next.(*ExtensionNode) + return len(n.key) != 0 && !ok + default: + return true + } +} + +func TestTrie_Get(t *testing.T) { + t.Run("HashNode", func(t *testing.T) { + tr := newTestTrie(t) + tr.testHas(t, []byte{0xAC, 0xAE}, []byte("hello")) + }) + t.Run("UnfoldRoot", func(t *testing.T) { + tr := newTestTrie(t) + single := NewTrie(NewHashNode(tr.root.Hash()), tr.Store) + single.testHas(t, []byte{0xAC}, nil) + single.testHas(t, []byte{0xAC, 0x01}, []byte{0xAB, 0xCD}) + single.testHas(t, []byte{0xAC, 0x99}, []byte{0x22, 0x22}) + single.testHas(t, []byte{0xAC, 0xAE}, []byte("hello")) + }) +} + +func TestTrie_Flush(t *testing.T) { + pairs := map[string][]byte{ + "": []byte("value0"), + "key1": []byte("value1"), + "key2": []byte("value2"), + } + + tr := NewTrie(nil, newTestStore()) + for k, v := range pairs { + require.NoError(t, tr.Put([]byte(k), v)) + } + + tr.Flush() + tr = NewTrie(NewHashNode(tr.StateRoot()), tr.Store) + for k, v := range pairs { + actual, err := tr.Get([]byte(k)) + require.NoError(t, err) + require.Equal(t, v, actual) + } +} + +func TestTrie_Delete(t *testing.T) { + t.Run("Hash", func(t *testing.T) { + t.Run("FromStore", func(t *testing.T) { + l := NewLeafNode([]byte{0x12}) + tr := NewTrie(NewHashNode(l.Hash()), newTestStore()) + t.Run("NotInStore", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{})) + }) + + tr.putToStore(l) + tr.testHas(t, []byte{}, []byte{0x12}) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + }) + + t.Run("Empty", func(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + require.Error(t, tr.Delete([]byte{})) + }) + }) + + t.Run("Leaf", func(t *testing.T) { + l := 
NewLeafNode([]byte{0x12, 0x34}) + tr := NewTrie(l, newTestStore()) + t.Run("NonExistentKey", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{0x12})) + tr.testHas(t, []byte{}, []byte{0x12, 0x34}) + }) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + }) + + t.Run("Extension", func(t *testing.T) { + t.Run("SingleKey", func(t *testing.T) { + l := NewLeafNode([]byte{0x12, 0x34}) + e := NewExtensionNode([]byte{0x0A, 0x0B}, l) + tr := NewTrie(e, newTestStore()) + + t.Run("NonExistentKey", func(t *testing.T) { + require.Error(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{0xAB}, []byte{0x12, 0x34}) + }) + + require.NoError(t, tr.Delete([]byte{0xAB})) + require.True(t, tr.root.(*HashNode).IsEmpty()) + }) + + t.Run("MultipleKeys", func(t *testing.T) { + b := NewBranchNode() + b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x12, 0x34})) + b.Children[6] = NewExtensionNode([]byte{0x07}, NewLeafNode([]byte{0x56, 0x78})) + e := NewExtensionNode([]byte{0x01, 0x02}, b) + tr := NewTrie(e, newTestStore()) + + h := e.Hash() + require.NoError(t, tr.Delete([]byte{0x12, 0x01})) + tr.testHas(t, []byte{0x12, 0x01}, nil) + tr.testHas(t, []byte{0x12, 0x67}, []byte{0x56, 0x78}) + + require.NotEqual(t, h, tr.root.Hash()) + require.Equal(t, toNibbles([]byte{0x12, 0x67}), e.key) + require.IsType(t, (*LeafNode)(nil), e.next) + }) + }) + + t.Run("Branch", func(t *testing.T) { + t.Run("3 Children", func(t *testing.T) { + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x34})) + b.Children[1] = NewExtensionNode([]byte{0x06}, NewLeafNode([]byte{0x56})) + tr := NewTrie(b, newTestStore()) + require.NoError(t, tr.Delete([]byte{0x16})) + tr.testHas(t, []byte{}, []byte{0x12}) + tr.testHas(t, []byte{0x01}, []byte{0x34}) + tr.testHas(t, []byte{0x16}, nil) + }) + t.Run("2 Children", func(t *testing.T) { + newt := func(t *testing.T) *Trie { + b := NewBranchNode() + 
b.Children[lastChild] = NewLeafNode([]byte{0x12}) + l := NewLeafNode([]byte{0x34}) + e := NewExtensionNode([]byte{0x06}, l) + b.Children[5] = NewHashNode(e.Hash()) + tr := NewTrie(b, newTestStore()) + tr.putToStore(l) + tr.putToStore(e) + return tr + } + + t.Run("DeleteLast", func(t *testing.T) { + t.Run("MergeExtension", func(t *testing.T) { + tr := newt(t) + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + tr.testHas(t, []byte{0x56}, []byte{0x34}) + require.IsType(t, (*ExtensionNode)(nil), tr.root) + }) + + t.Run("LeaveLeaf", func(t *testing.T) { + c := NewBranchNode() + c.Children[5] = NewLeafNode([]byte{0x05}) + c.Children[6] = NewLeafNode([]byte{0x06}) + + b := NewBranchNode() + b.Children[lastChild] = NewLeafNode([]byte{0x12}) + b.Children[5] = c + tr := NewTrie(b, newTestStore()) + + require.NoError(t, tr.Delete([]byte{})) + tr.testHas(t, []byte{}, nil) + tr.testHas(t, []byte{0x55}, []byte{0x05}) + tr.testHas(t, []byte{0x56}, []byte{0x06}) + require.IsType(t, (*ExtensionNode)(nil), tr.root) + }) + }) + + t.Run("DeleteMiddle", func(t *testing.T) { + tr := newt(t) + require.NoError(t, tr.Delete([]byte{0x56})) + tr.testHas(t, []byte{}, []byte{0x12}) + tr.testHas(t, []byte{0x56}, nil) + require.IsType(t, (*LeafNode)(nil), tr.root) + }) + }) + }) +} + +func TestTrie_PanicInvalidRoot(t *testing.T) { + tr := &Trie{Store: newTestStore()} + require.Panics(t, func() { _ = tr.Put([]byte{1}, []byte{2}) }) + require.Panics(t, func() { _, _ = tr.Get([]byte{1}) }) + require.Panics(t, func() { _ = tr.Delete([]byte{1}) }) +} diff --git a/pkg/core/storage/store.go b/pkg/core/storage/store.go index 5e70334ed..575c42ba3 100644 --- a/pkg/core/storage/store.go +++ b/pkg/core/storage/store.go @@ -9,6 +9,7 @@ import ( const ( DataBlock KeyPrefix = 0x01 DataTransaction KeyPrefix = 0x02 + DataMPT KeyPrefix = 0x03 STAccount KeyPrefix = 0x40 STNotification KeyPrefix = 0x4d STContract KeyPrefix = 0x50 From 9b328240dd9f4041e43d7ff6de99f88b77821fe8 Mon Sep 17 
00:00:00 2001 From: Evgenii Stratonikov Date: Sun, 24 May 2020 14:23:29 +0300 Subject: [PATCH 02/13] mpt: implement MPT proof Get and Verify Signed-off-by: Evgenii Stratonikov --- pkg/core/mpt/proof.go | 74 ++++++++++++++++++++++++++++++++++++++ pkg/core/mpt/proof_test.go | 73 +++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+) create mode 100644 pkg/core/mpt/proof.go create mode 100644 pkg/core/mpt/proof_test.go diff --git a/pkg/core/mpt/proof.go b/pkg/core/mpt/proof.go new file mode 100644 index 000000000..f785bd9d4 --- /dev/null +++ b/pkg/core/mpt/proof.go @@ -0,0 +1,74 @@ +package mpt + +import ( + "bytes" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// GetProof returns a proof that key belongs to t. +// Proof consists of serialized nodes occurring on the path from the root to the leaf of key. +func (t *Trie) GetProof(key []byte) ([][]byte, error) { + var proof [][]byte + path := toNibbles(key) + r, err := t.getProof(t.root, path, &proof) + if err != nil { + return proof, err + } + t.root = r + return proof, nil +} + +func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + *proofs = append(*proofs, toBytes(n)) + return n, nil + } + case *BranchNode: + *proofs = append(*proofs, toBytes(n)) + i, path := splitPath(path) + r, err := t.getProof(n.Children[i], path, proofs) + if err != nil { + return nil, err + } + n.Children[i] = r + return n, nil + case *ExtensionNode: + if bytes.HasPrefix(path, n.key) { + *proofs = append(*proofs, toBytes(n)) + r, err := t.getProof(n.next, path[len(n.key):], proofs) + if err != nil { + return nil, err + } + n.next = r + return n, nil + } + case *HashNode: + if !n.IsEmpty() { + r, err := t.getFromStore(n.Hash()) + if err != nil { + return nil, err + } + return t.getProof(r, path, proofs) + } + } + return nil, ErrNotFound +} 
+ +// VerifyProof verifies that path indeed belongs to a MPT with the specified root hash. +// It also returns value for the key. +func VerifyProof(rh util.Uint256, key []byte, proofs [][]byte) ([]byte, bool) { + path := toNibbles(key) + tr := NewTrie(NewHashNode(rh), storage.NewMemCachedStore(storage.NewMemoryStore())) + for i := range proofs { + h := hash.DoubleSha256(proofs[i]) + // no errors in Put to memory store + _ = tr.Store.Put(makeStorageKey(h[:]), proofs[i]) + } + _, bs, err := tr.getWithPath(tr.root, path) + return bs, err == nil +} diff --git a/pkg/core/mpt/proof_test.go b/pkg/core/mpt/proof_test.go new file mode 100644 index 000000000..17301af15 --- /dev/null +++ b/pkg/core/mpt/proof_test.go @@ -0,0 +1,73 @@ +package mpt + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func newProofTrie(t *testing.T) *Trie { + l := NewLeafNode([]byte("somevalue")) + e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l) + l2 := NewLeafNode([]byte("invalid")) + e2 := NewExtensionNode([]byte{0x05}, NewHashNode(l2.Hash())) + b := NewBranchNode() + b.Children[4] = NewHashNode(e.Hash()) + b.Children[5] = e2 + + tr := NewTrie(b, newTestStore()) + require.NoError(t, tr.Put([]byte{0x12, 0x31}, []byte("value1"))) + require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2"))) + tr.putToStore(l) + tr.putToStore(e) + return tr +} + +func TestTrie_GetProof(t *testing.T) { + tr := newProofTrie(t) + + t.Run("MissingKey", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x12}) + require.Error(t, err) + }) + + t.Run("Valid", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x12, 0x31}) + require.NoError(t, err) + }) + + t.Run("MissingHashNode", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x55}) + require.Error(t, err) + }) +} + +func TestVerifyProof(t *testing.T) { + tr := newProofTrie(t) + + t.Run("Simple", func(t *testing.T) { + proof, err := tr.GetProof([]byte{0x12, 0x32}) + require.NoError(t, err) + + t.Run("Good", func(t *testing.T) { + v, ok 
:= VerifyProof(tr.root.Hash(), []byte{0x12, 0x32}, proof) + require.True(t, ok) + require.Equal(t, []byte("value2"), v) + }) + + t.Run("Bad", func(t *testing.T) { + _, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x31}, proof) + require.False(t, ok) + }) + }) + + t.Run("InsideHash", func(t *testing.T) { + key := []byte{0x45, 0x67} + proof, err := tr.GetProof(key) + require.NoError(t, err) + + v, ok := VerifyProof(tr.root.Hash(), key, proof) + require.True(t, ok) + require.Equal(t, []byte("somevalue"), v) + }) +} From f0b85f8af7914050d865f5ee27b75d95293bd2e8 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 28 May 2020 08:55:12 +0300 Subject: [PATCH 03/13] mpt: implement JSON marshaling/unmarshaling Because there is no distinct type field in JSONized nodes, distinction is made via payload itself, thus all unmarshaling is done via NodeObject. Signed-off-by: Evgenii Stratonikov --- pkg/core/mpt/branch.go | 20 ++++++++++ pkg/core/mpt/extension.go | 24 ++++++++++++ pkg/core/mpt/hash.go | 21 ++++++++++ pkg/core/mpt/leaf.go | 19 +++++++++ pkg/core/mpt/node.go | 78 +++++++++++++++++++++++++++++++++++++ pkg/core/mpt/node_test.go | 82 ++++++++++++++++++++++++++++++++++----- 6 files changed, 234 insertions(+), 10 deletions(-) diff --git a/pkg/core/mpt/branch.go b/pkg/core/mpt/branch.go index ac3f2400f..c4a383075 100644 --- a/pkg/core/mpt/branch.go +++ b/pkg/core/mpt/branch.go @@ -1,6 +1,9 @@ package mpt import ( + "encoding/json" + "errors" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -69,6 +72,23 @@ func (b *BranchNode) DecodeBinary(r *io.BinReader) { } } +// MarshalJSON implements json.Marshaler. +func (b *BranchNode) MarshalJSON() ([]byte, error) { + return json.Marshal(b.Children) +} + +// UnmarshalJSON implements json.Unmarshaler. 
+func (b *BranchNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*BranchNode); ok { + *b = *u + return nil + } + return errors.New("expected branch node") +} + // splitPath splits path for a branch node. func splitPath(path []byte) (byte, []byte) { if len(path) != 0 { diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go index 775078827..a337c4de2 100644 --- a/pkg/core/mpt/extension.go +++ b/pkg/core/mpt/extension.go @@ -1,6 +1,9 @@ package mpt import ( + "encoding/hex" + "encoding/json" + "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -68,3 +71,24 @@ func (e ExtensionNode) EncodeBinary(w *io.BinWriter) { n := NewHashNode(e.next.Hash()) n.EncodeBinary(w) } + +// MarshalJSON implements json.Marshaler. +func (e *ExtensionNode) MarshalJSON() ([]byte, error) { + m := map[string]interface{}{ + "key": hex.EncodeToString(e.key), + "next": e.next, + } + return json.Marshal(m) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (e *ExtensionNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*ExtensionNode); ok { + *e = *u + return nil + } + return errors.New("expected extension node") +} diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go index a14dec879..51c6095fd 100644 --- a/pkg/core/mpt/hash.go +++ b/pkg/core/mpt/hash.go @@ -1,6 +1,7 @@ package mpt import ( + "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/io" @@ -59,3 +60,23 @@ func (h HashNode) EncodeBinary(w *io.BinWriter) { } w.WriteVarBytes(h.hash[:]) } + +// MarshalJSON implements json.Marshaler. +func (h *HashNode) MarshalJSON() ([]byte, error) { + if !h.valid { + return []byte(`{}`), nil + } + return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. 
+func (h *HashNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*HashNode); ok { + *h = *u + return nil + } + return errors.New("expected hash node") +} diff --git a/pkg/core/mpt/leaf.go b/pkg/core/mpt/leaf.go index 455ae3feb..4ae509a1c 100644 --- a/pkg/core/mpt/leaf.go +++ b/pkg/core/mpt/leaf.go @@ -1,6 +1,8 @@ package mpt import ( + "encoding/hex" + "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -54,3 +56,20 @@ func (n *LeafNode) DecodeBinary(r *io.BinReader) { func (n LeafNode) EncodeBinary(w *io.BinWriter) { w.WriteVarBytes(n.value) } + +// MarshalJSON implements json.Marshaler. +func (n *LeafNode) MarshalJSON() ([]byte, error) { + return []byte(`{"value":"` + hex.EncodeToString(n.value) + `"}`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (n *LeafNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*LeafNode); ok { + *n = *u + return nil + } + return errors.New("expected leaf node") +} diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go index 83abb2e60..53a2fdec1 100644 --- a/pkg/core/mpt/node.go +++ b/pkg/core/mpt/node.go @@ -1,6 +1,9 @@ package mpt import ( + "encoding/hex" + "encoding/json" + "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/io" @@ -28,6 +31,8 @@ type NodeObject struct { // Node represents common interface of all MPT nodes. type Node interface { io.Serializable + json.Marshaler + json.Unmarshaler Hash() util.Uint256 Type() NodeType } @@ -68,3 +73,76 @@ func toBytes(n Node) []byte { encodeNodeWithType(n, buf.BinWriter) return buf.Bytes() } + +// UnmarshalJSON implements json.Unmarshaler. 
+func (n *NodeObject) UnmarshalJSON(data []byte) error { + var m map[string]json.RawMessage + err := json.Unmarshal(data, &m) + if err != nil { // it can be a branch node + var nodes []NodeObject + if err := json.Unmarshal(data, &nodes); err != nil { + return err + } else if len(nodes) != childrenCount { + return errors.New("invalid length of branch node") + } + + b := NewBranchNode() + for i := range b.Children { + b.Children[i] = nodes[i].Node + } + n.Node = b + return nil + } + + switch len(m) { + case 0: + n.Node = new(HashNode) + case 1: + if v, ok := m["hash"]; ok { + var h util.Uint256 + if err := json.Unmarshal(v, &h); err != nil { + return err + } + n.Node = NewHashNode(h) + } else if v, ok = m["value"]; ok { + b, err := unmarshalHex(v) + if err != nil { + return err + } else if len(b) > MaxValueLength { + return errors.New("leaf value is too big") + } + n.Node = NewLeafNode(b) + } else { + return errors.New("invalid field") + } + case 2: + keyRaw, ok1 := m["key"] + nextRaw, ok2 := m["next"] + if !ok1 || !ok2 { + return errors.New("invalid field") + } + key, err := unmarshalHex(keyRaw) + if err != nil { + return err + } else if len(key) > MaxKeyLength { + return errors.New("extension key is too big") + } + + var next NodeObject + if err := json.Unmarshal(nextRaw, &next); err != nil { + return err + } + n.Node = NewExtensionNode(key, next.Node) + default: + return errors.New("0, 1 or 2 fields expected") + } + return nil +} + +func unmarshalHex(data json.RawMessage) ([]byte, error) { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return hex.DecodeString(s) +} diff --git a/pkg/core/mpt/node_test.go b/pkg/core/mpt/node_test.go index 0e2c17c96..e3aab54d6 100644 --- a/pkg/core/mpt/node_test.go +++ b/pkg/core/mpt/node_test.go @@ -1,26 +1,42 @@ package mpt import ( + "encoding/json" "testing" "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" 
"github.com/nspcc-dev/neo-go/pkg/io" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) { return func(t *testing.T) { - bs, err := testserdes.EncodeBinary(expected) - require.NoError(t, err) - err = testserdes.DecodeBinary(bs, actual) - if !ok { - require.Error(t, err) - return - } - require.NoError(t, err) - require.Equal(t, expected.Type(), actual.Type()) - require.Equal(t, expected.Hash(), actual.Hash()) + t.Run("IO", func(t *testing.T) { + bs, err := testserdes.EncodeBinary(expected) + require.NoError(t, err) + err = testserdes.DecodeBinary(bs, actual) + if !ok { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, expected.Type(), actual.Type()) + require.Equal(t, expected.Hash(), actual.Hash()) + }) + t.Run("JSON", func(t *testing.T) { + bs, err := json.Marshal(expected) + require.NoError(t, err) + err = json.Unmarshal(bs, actual) + if !ok { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, expected.Type(), actual.Type()) + require.Equal(t, expected.Hash(), actual.Hash()) + }) } } @@ -74,6 +90,52 @@ func TestNode_Serializable(t *testing.T) { }) } +// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198 +func TestJSONSharp(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11})) + require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22})) + require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac})) + require.NoError(t, tr.Delete([]byte{0xac, 0x11})) + require.NoError(t, tr.Delete([]byte{0xac, 0x22})) + + js, err := tr.root.MarshalJSON() + require.NoError(t, err) + require.JSONEq(t, `{"key":"0a0c", "next":{"value":"ac"}}`, string(js)) +} + +func TestInvalidJSON(t *testing.T) { + t.Run("InvalidChildrenCount", func(t *testing.T) { + var cs [childrenCount + 1]Node + for i := range cs { + cs[i] = 
new(HashNode) + } + data, err := json.Marshal(cs) + require.NoError(t, err) + + var n NodeObject + require.Error(t, json.Unmarshal(data, &n)) + }) + + testCases := []struct { + name string + data []byte + }{ + {"WrongFieldCount", []byte(`{"key":"0102", "next": {}, "field": {}}`)}, + {"InvalidField1", []byte(`{"next":{}}`)}, + {"InvalidField2", []byte(`{"key":"0102", "hash":{}}`)}, + {"InvalidKey", []byte(`{"key":"xy", "next":{}}`)}, + {"InvalidNext", []byte(`{"key":"01", "next":[]}`)}, + {"InvalidHash", []byte(`{"hash":"01"}`)}, + {"InvalidValue", []byte(`{"value":1}`)}, + {"InvalidBranch", []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]`)}, + } + for _, tc := range testCases { + var n NodeObject + assert.Errorf(t, json.Unmarshal(tc.data, &n), "no error in "+tc.name) + } +} + // C# interoperability test // https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135 func TestRootHash(t *testing.T) { From 6ca22027d5991fd8a5119030b7b64b5ccf02bde1 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 28 May 2020 11:53:19 +0300 Subject: [PATCH 04/13] mpt: implement (*Trie).Collapse() Because trie size is rather big, it can't be stored in memory. Thus some form of caching should also be implemented. To avoid marshaling/unmarshaling of items which are close to root and are used very frequently we can save them across the persists. This commit implements pruning items at the specified depth, replacing them by hash nodes. Signed-off-by: Evgenii Stratonikov --- pkg/core/mpt/trie.go | 32 +++++++++++++++++ pkg/core/mpt/trie_test.go | 73 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go index f9589fde3..c5093d614 100644 --- a/pkg/core/mpt/trie.go +++ b/pkg/core/mpt/trie.go @@ -355,3 +355,35 @@ func (t *Trie) getFromStore(h util.Uint256) (Node, error) { } return n.Node, nil } + +// Collapse compresses all nodes at depth n to the hash nodes. 
+// Note: this function does not perform any kind of storage flushing so +// `Flush()` should be called explicitly before invoking function. +func (t *Trie) Collapse(depth int) { + if depth < 0 { + panic("negative depth") + } + t.root = collapse(depth, t.root) +} + +func collapse(depth int, node Node) Node { + if _, ok := node.(*HashNode); ok { + return node + } else if depth == 0 { + return NewHashNode(node.Hash()) + } + + switch n := node.(type) { + case *BranchNode: + for i := range n.Children { + n.Children[i] = collapse(depth-1, n.Children[i]) + } + case *ExtensionNode: + n.next = collapse(depth-1, n.next) + case *LeafNode: + case *HashNode: + default: + panic("invalid MPT node type") + } + return node +} diff --git a/pkg/core/mpt/trie_test.go b/pkg/core/mpt/trie_test.go index 470e0c8e5..d06e08168 100644 --- a/pkg/core/mpt/trie_test.go +++ b/pkg/core/mpt/trie_test.go @@ -371,3 +371,76 @@ func TestTrie_PanicInvalidRoot(t *testing.T) { require.Panics(t, func() { _, _ = tr.Get([]byte{1}) }) require.Panics(t, func() { _ = tr.Delete([]byte{1}) }) } + +func TestTrie_Collapse(t *testing.T) { + t.Run("PanicNegative", func(t *testing.T) { + tr := newTestTrie(t) + require.Panics(t, func() { tr.Collapse(-1) }) + }) + t.Run("Depth=0", func(t *testing.T) { + tr := newTestTrie(t) + h := tr.root.Hash() + + _, ok := tr.root.(*HashNode) + require.False(t, ok) + + tr.Collapse(0) + _, ok = tr.root.(*HashNode) + require.True(t, ok) + require.Equal(t, h, tr.root.Hash()) + }) + t.Run("Branch,Depth=1", func(t *testing.T) { + b := NewBranchNode() + e := NewExtensionNode([]byte{0x01}, NewLeafNode([]byte("value1"))) + he := e.Hash() + b.Children[0] = e + hb := b.Hash() + + tr := NewTrie(b, newTestStore()) + tr.Collapse(1) + + newb, ok := tr.root.(*BranchNode) + require.True(t, ok) + require.Equal(t, hb, newb.Hash()) + require.IsType(t, (*HashNode)(nil), b.Children[0]) + require.Equal(t, he, b.Children[0].Hash()) + }) + t.Run("Extension,Depth=1", func(t *testing.T) { + l := 
NewLeafNode([]byte("value")) + hl := l.Hash() + e := NewExtensionNode([]byte{0x01}, l) + h := e.Hash() + tr := NewTrie(e, newTestStore()) + tr.Collapse(1) + + newe, ok := tr.root.(*ExtensionNode) + require.True(t, ok) + require.Equal(t, h, newe.Hash()) + require.IsType(t, (*HashNode)(nil), newe.next) + require.Equal(t, hl, newe.next.Hash()) + }) + t.Run("Leaf", func(t *testing.T) { + l := NewLeafNode([]byte("value")) + tr := NewTrie(l, newTestStore()) + tr.Collapse(10) + require.Equal(t, NewLeafNode([]byte("value")), tr.root) + }) + t.Run("Hash", func(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + tr := NewTrie(new(HashNode), newTestStore()) + require.NotPanics(t, func() { tr.Collapse(1) }) + hn, ok := tr.root.(*HashNode) + require.True(t, ok) + require.True(t, hn.IsEmpty()) + }) + + h := random.Uint256() + hn := NewHashNode(h) + tr := NewTrie(hn, newTestStore()) + tr.Collapse(10) + + newRoot, ok := tr.root.(*HashNode) + require.True(t, ok) + require.Equal(t, NewHashNode(h), newRoot) + }) +} From 475bf2445a697d50d60e2c78b495b4a2fc477c92 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Tue, 2 Jun 2020 23:49:46 +0300 Subject: [PATCH 05/13] mpt: restructure nodes a bit, implement serialization and hash cache It drastically reduces the number of allocations and hash calculations. 
Signed-off-by: Evgenii Stratonikov --- pkg/core/mpt/base.go | 69 +++++++++++++++++++++++++++++++++++++++ pkg/core/mpt/branch.go | 19 ++++------- pkg/core/mpt/extension.go | 21 ++++-------- pkg/core/mpt/hash.go | 26 +++++++++------ pkg/core/mpt/leaf.go | 20 +++++------- pkg/core/mpt/node.go | 16 +-------- pkg/core/mpt/proof.go | 6 ++-- pkg/core/mpt/trie.go | 15 ++++----- 8 files changed, 117 insertions(+), 75 deletions(-) create mode 100644 pkg/core/mpt/base.go diff --git a/pkg/core/mpt/base.go b/pkg/core/mpt/base.go new file mode 100644 index 000000000..f8bce0b34 --- /dev/null +++ b/pkg/core/mpt/base.go @@ -0,0 +1,69 @@ +package mpt + +import ( + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// BaseNode implements basic things every node needs like caching hash and +// serialized representation. It's a basic node building block intended to be +// included into all node types. +type BaseNode struct { + hash util.Uint256 + bytes []byte + hashValid bool + bytesValid bool +} + +// BaseNodeIface abstracts away basic Node functions. +type BaseNodeIface interface { + Hash() util.Uint256 + Type() NodeType + Bytes() []byte +} + +// getHash returns a hash of this BaseNode. +func (b *BaseNode) getHash(n Node) util.Uint256 { + if !b.hashValid { + b.updateHash(n) + } + return b.hash +} + +// getBytes returns a slice of bytes representing this node. +func (b *BaseNode) getBytes(n Node) []byte { + if !b.bytesValid { + b.updateBytes(n) + } + return b.bytes +} + +// updateHash updates hash field for this BaseNode. +func (b *BaseNode) updateHash(n Node) { + if n.Type() == HashT { + panic("can't update hash for hash node") + } + b.hash = hash.DoubleSha256(b.getBytes(n)) + b.hashValid = true +} + +// updateBytes updates the bytes field for this BaseNode. 
+func (b *BaseNode) updateBytes(n Node) { + buf := io.NewBufBinWriter() + encodeNodeWithType(n, buf.BinWriter) + b.bytes = buf.Bytes() + b.bytesValid = true +} + +// invalidateCache sets all cache fields to invalid state. +func (b *BaseNode) invalidateCache() { + b.bytesValid = false + b.hashValid = false +} + +// encodeNodeWithType encodes node together with it's type. +func encodeNodeWithType(n Node, w *io.BinWriter) { + w.WriteB(byte(n.Type())) + n.EncodeBinary(w) +} diff --git a/pkg/core/mpt/branch.go b/pkg/core/mpt/branch.go index c4a383075..fbad5d29e 100644 --- a/pkg/core/mpt/branch.go +++ b/pkg/core/mpt/branch.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -18,9 +17,7 @@ const ( // BranchNode represents MPT's branch node. type BranchNode struct { - hash util.Uint256 - valid bool - + BaseNode Children [childrenCount]Node } @@ -38,18 +35,14 @@ func NewBranchNode() *BranchNode { // Type implements Node interface. func (b *BranchNode) Type() NodeType { return BranchT } -// Hash implements Node interface. +// Hash implements BaseNode interface. func (b *BranchNode) Hash() util.Uint256 { - if !b.valid { - b.hash = hash.DoubleSha256(toBytes(b)) - b.valid = true - } - return b.hash + return b.getHash(b) } -// invalidateHash invalidates node hash. -func (b *BranchNode) invalidateHash() { - b.valid = false +// Bytes implements BaseNode interface. +func (b *BranchNode) Bytes() []byte { + return b.getBytes(b) } // EncodeBinary implements io.Serializable. 
diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go index a337c4de2..8bcc11c24 100644 --- a/pkg/core/mpt/extension.go +++ b/pkg/core/mpt/extension.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -16,9 +15,7 @@ const MaxKeyLength = 1125 // ExtensionNode represents MPT's extension node. type ExtensionNode struct { - hash util.Uint256 - valid bool - + BaseNode key []byte next Node } @@ -37,18 +34,14 @@ func NewExtensionNode(key []byte, next Node) *ExtensionNode { // Type implements Node interface. func (e ExtensionNode) Type() NodeType { return ExtensionT } -// Hash implements Node interface. +// Hash implements BaseNode interface. func (e *ExtensionNode) Hash() util.Uint256 { - if !e.valid { - e.hash = hash.DoubleSha256(toBytes(e)) - e.valid = true - } - return e.hash + return e.getHash(e) } -// invalidateHash invalidates node hash. -func (e *ExtensionNode) invalidateHash() { - e.valid = false +// Bytes implements BaseNode interface. +func (e *ExtensionNode) Bytes() []byte { + return e.getBytes(e) } // DecodeBinary implements io.Serializable. @@ -58,11 +51,11 @@ func (e *ExtensionNode) DecodeBinary(r *io.BinReader) { r.Err = fmt.Errorf("extension node key is too big: %d", sz) return } - e.valid = false e.key = make([]byte, sz) r.ReadBytes(e.key) e.next = new(HashNode) e.next.DecodeBinary(r) + e.invalidateCache() } // EncodeBinary implements io.Serializable. diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go index 51c6095fd..42519a1ac 100644 --- a/pkg/core/mpt/hash.go +++ b/pkg/core/mpt/hash.go @@ -10,8 +10,7 @@ import ( // HashNode represents MPT's hash node. type HashNode struct { - hash util.Uint256 - valid bool + BaseNode } var _ Node = (*HashNode)(nil) @@ -19,8 +18,10 @@ var _ Node = (*HashNode)(nil) // NewHashNode returns hash node with the specified hash. 
func NewHashNode(h util.Uint256) *HashNode { return &HashNode{ - hash: h, - valid: true, + BaseNode: BaseNode{ + hash: h, + hashValid: true, + }, } } @@ -29,23 +30,28 @@ func (h *HashNode) Type() NodeType { return HashT } // Hash implements Node interface. func (h *HashNode) Hash() util.Uint256 { - if !h.valid { + if !h.hashValid { panic("can't get hash of an empty HashNode") } return h.hash } // IsEmpty returns true iff h is an empty node i.e. contains no hash. -func (h *HashNode) IsEmpty() bool { return !h.valid } +func (h *HashNode) IsEmpty() bool { return !h.hashValid } + +// Bytes returns serialized HashNode. +func (h *HashNode) Bytes() []byte { + return h.getBytes(h) +} // DecodeBinary implements io.Serializable. func (h *HashNode) DecodeBinary(r *io.BinReader) { sz := r.ReadVarUint() switch sz { case 0: - h.valid = false + h.hashValid = false case util.Uint256Size: - h.valid = true + h.hashValid = true r.ReadBytes(h.hash[:]) default: r.Err = fmt.Errorf("invalid hash node size: %d", sz) @@ -54,7 +60,7 @@ func (h *HashNode) DecodeBinary(r *io.BinReader) { // EncodeBinary implements io.Serializable. func (h HashNode) EncodeBinary(w *io.BinWriter) { - if !h.valid { + if !h.hashValid { w.WriteVarUint(0) return } @@ -63,7 +69,7 @@ func (h HashNode) EncodeBinary(w *io.BinWriter) { // MarshalJSON implements json.Marshaler. func (h *HashNode) MarshalJSON() ([]byte, error) { - if !h.valid { + if !h.hashValid { return []byte(`{}`), nil } return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil diff --git a/pkg/core/mpt/leaf.go b/pkg/core/mpt/leaf.go index 4ae509a1c..82dd8eef6 100644 --- a/pkg/core/mpt/leaf.go +++ b/pkg/core/mpt/leaf.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -15,9 +14,7 @@ const MaxValueLength = 1024 * 1024 // LeafNode represents MPT's leaf node. 
type LeafNode struct { - hash util.Uint256 - valid bool - + BaseNode value []byte } @@ -31,13 +28,14 @@ func NewLeafNode(value []byte) *LeafNode { // Type implements Node interface. func (n LeafNode) Type() NodeType { return LeafT } -// Hash implements Node interface. +// Hash implements BaseNode interface. func (n *LeafNode) Hash() util.Uint256 { - if !n.valid { - n.hash = hash.DoubleSha256(toBytes(n)) - n.valid = true - } - return n.hash + return n.getHash(n) +} + +// Bytes implements BaseNode interface. +func (n *LeafNode) Bytes() []byte { + return n.getBytes(n) } // DecodeBinary implements io.Serializable. @@ -47,9 +45,9 @@ func (n *LeafNode) DecodeBinary(r *io.BinReader) { r.Err = fmt.Errorf("leaf node value is too big: %d", sz) return } - n.valid = false n.value = make([]byte, sz) r.ReadBytes(n.value) + n.invalidateCache() } // EncodeBinary implements io.Serializable. diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go index 53a2fdec1..86e675a01 100644 --- a/pkg/core/mpt/node.go +++ b/pkg/core/mpt/node.go @@ -33,8 +33,7 @@ type Node interface { io.Serializable json.Marshaler json.Unmarshaler - Hash() util.Uint256 - Type() NodeType + BaseNodeIface } // EncodeBinary implements io.Serializable. @@ -61,19 +60,6 @@ func (n *NodeObject) DecodeBinary(r *io.BinReader) { n.Node.DecodeBinary(r) } -// encodeNodeWithType encodes node together with it's type. -func encodeNodeWithType(n Node, w *io.BinWriter) { - w.WriteB(byte(n.Type())) - n.EncodeBinary(w) -} - -// toBytes is a helper for serializing node. -func toBytes(n Node) []byte { - buf := io.NewBufBinWriter() - encodeNodeWithType(n, buf.BinWriter) - return buf.Bytes() -} - // UnmarshalJSON implements json.Unmarshaler. 
func (n *NodeObject) UnmarshalJSON(data []byte) error { var m map[string]json.RawMessage diff --git a/pkg/core/mpt/proof.go b/pkg/core/mpt/proof.go index f785bd9d4..5f8fcdc84 100644 --- a/pkg/core/mpt/proof.go +++ b/pkg/core/mpt/proof.go @@ -25,11 +25,11 @@ func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) switch n := curr.(type) { case *LeafNode: if len(path) == 0 { - *proofs = append(*proofs, toBytes(n)) + *proofs = append(*proofs, copySlice(n.Bytes())) return n, nil } case *BranchNode: - *proofs = append(*proofs, toBytes(n)) + *proofs = append(*proofs, copySlice(n.Bytes())) i, path := splitPath(path) r, err := t.getProof(n.Children[i], path, proofs) if err != nil { @@ -39,7 +39,7 @@ func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) return n, nil case *ExtensionNode: if bytes.HasPrefix(path, n.key) { - *proofs = append(*proofs, toBytes(n)) + *proofs = append(*proofs, copySlice(n.Bytes())) r, err := t.getProof(n.next, path[len(n.key):], proofs) if err != nil { return nil, err diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go index c5093d614..d9b6f7b0d 100644 --- a/pkg/core/mpt/trie.go +++ b/pkg/core/mpt/trie.go @@ -5,7 +5,6 @@ import ( "errors" "github.com/nspcc-dev/neo-go/pkg/core/storage" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -126,7 +125,7 @@ func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, err return nil, err } curr.Children[i] = r - curr.invalidateHash() + curr.invalidateCache() return curr, nil } @@ -139,7 +138,7 @@ func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Nod return nil, err } curr.next = r - curr.invalidateHash() + curr.invalidateCache() return curr, nil } @@ -218,7 +217,7 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { return nil, err } b.Children[i] = r - b.invalidateHash() + b.invalidateCache() var count, 
index int for i := range b.Children { h, ok := b.Children[i].(*HashNode) @@ -243,7 +242,7 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { } if e, ok := c.(*ExtensionNode); ok { e.key = append([]byte{byte(index)}, e.key...) - e.invalidateHash() + e.invalidateCache() return e, nil } @@ -262,7 +261,7 @@ func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) case *ExtensionNode: n.key = append(n.key, nxt.key...) n.next = nxt.next - n.invalidateHash() + n.invalidateCache() case *HashNode: if nxt.IsEmpty() { return nxt, nil @@ -336,9 +335,7 @@ func (t *Trie) putToStore(n Node) { if n.Type() == HashT { panic("can't put hash node in trie") } - bs := toBytes(n) - h := hash.DoubleSha256(bs) - _ = t.Store.Put(makeStorageKey(h.BytesBE()), bs) // put in MemCached returns no errors + _ = t.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) // put in MemCached returns no errors } func (t *Trie) getFromStore(h util.Uint256) (Node, error) { From 2b53877dff71f592d3f7ee01125403b14bd45565 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 3 Jun 2020 00:02:45 +0300 Subject: [PATCH 06/13] mpt: don't flush nodes already present in the DB It's just a waste of time. Signed-off-by: Evgenii Stratonikov --- pkg/core/mpt/base.go | 15 +++++++++++++++ pkg/core/mpt/trie.go | 4 ++++ 2 files changed, 19 insertions(+) diff --git a/pkg/core/mpt/base.go b/pkg/core/mpt/base.go index f8bce0b34..9f10cc333 100644 --- a/pkg/core/mpt/base.go +++ b/pkg/core/mpt/base.go @@ -14,6 +14,8 @@ type BaseNode struct { bytes []byte hashValid bool bytesValid bool + + isFlushed bool } // BaseNodeIface abstracts away basic Node functions. @@ -21,6 +23,8 @@ type BaseNodeIface interface { Hash() util.Uint256 Type() NodeType Bytes() []byte + IsFlushed() bool + SetFlushed() } // getHash returns a hash of this BaseNode. 
@@ -60,6 +64,17 @@ func (b *BaseNode) updateBytes(n Node) { func (b *BaseNode) invalidateCache() { b.bytesValid = false b.hashValid = false + b.isFlushed = false +} + +// IsFlushed checks for node flush status. +func (b *BaseNode) IsFlushed() bool { + return b.isFlushed +} + +// SetFlushed sets 'flushed' flag to true for this node. +func (b *BaseNode) SetFlushed() { + b.isFlushed = true } // encodeNodeWithType encodes node together with it's type. diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go index d9b6f7b0d..3c38424c0 100644 --- a/pkg/core/mpt/trie.go +++ b/pkg/core/mpt/trie.go @@ -318,6 +318,9 @@ func (t *Trie) Flush() { } func (t *Trie) flush(node Node) { + if node.IsFlushed() { + return + } switch n := node.(type) { case *BranchNode: for i := range n.Children { @@ -336,6 +339,7 @@ func (t *Trie) putToStore(n Node) { panic("can't put hash node in trie") } _ = t.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) // put in MemCached returns no errors + n.SetFlushed() } func (t *Trie) getFromStore(h util.Uint256) (Node, error) { From 0e29382035514dd75ee8b29b831f2ff8c1fa6452 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 29 May 2020 17:20:00 +0300 Subject: [PATCH 07/13] core: update MPT during block processing Signed-off-by: Evgenii Stratonikov --- pkg/core/blockchain.go | 92 +++++++++++++++++++++- pkg/core/blockchainer/blockchainer.go | 2 + pkg/core/dao/dao.go | 65 +++++++++++++++- pkg/core/state/mpt_root.go | 105 ++++++++++++++++++++++++++ pkg/core/state/mpt_root_test.go | 61 +++++++++++++++ pkg/network/helper_test.go | 6 ++ 6 files changed, 326 insertions(+), 5 deletions(-) create mode 100644 pkg/core/state/mpt_root.go create mode 100644 pkg/core/state/mpt_root_test.go diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 7ac605bed..d592eb085 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -221,6 +221,9 @@ func (bc *Blockchain) init() error { } bc.blockHeight = bHeight bc.persistedHeight = bHeight + 
if err = bc.dao.InitMPT(bHeight); err != nil { + return errors.Wrapf(err, "can't init MPT at height %d", bHeight) + } hashes, err := bc.dao.GetHeaderHashes() if err != nil { @@ -550,6 +553,11 @@ func (bc *Blockchain) processHeader(h *block.Header, batch storage.Batch, header return nil } +// GetStateRoot returns state root for a given height. +func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + return bc.dao.GetStateRoot(height) +} + // storeBlock performs chain update using the block given, it executes all // transactions with all appropriate side-effects and updates Blockchain state. // This is the only way to change Blockchain state. @@ -633,17 +641,38 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } } + root := bc.dao.MPT.StateRoot() + var prevHash util.Uint256 + if block.Index > 0 { + prev, err := bc.dao.GetStateRoot(block.Index - 1) + if err != nil { + return errors.WithMessagef(err, "can't get previous state root") + } + prevHash = prev.Root + } + err := bc.AddStateRoot(&state.MPTRoot{ + MPTRootBase: state.MPTRootBase{ + Index: block.Index, + PrevHash: prevHash, + Root: root, + }, + }) + if err != nil { + return err + } + if bc.config.SaveStorageBatch { bc.lastBatch = cache.DAO.GetBatch() } bc.lock.Lock() - _, err := cache.Persist() + _, err = cache.Persist() if err != nil { bc.lock.Unlock() return err } bc.contracts.Policy.OnPersistEnd(bc.dao) + bc.dao.MPT.Flush() bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) bc.memPool.RemoveStale(bc.isTxStillRelevant, bc) @@ -1194,6 +1223,67 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool { } +// AddStateRoot adds a new (possibly unverified) state root to the blockchain. 
+func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { + our, err := bc.GetStateRoot(r.Index) + if err == nil { + if our.Flag == state.Verified { + return nil + } else if r.Witness == nil && our.Witness != nil { + r.Witness = our.Witness + } + } + if err := bc.verifyStateRoot(r); err != nil { + return errors.WithMessage(err, "invalid state root") + } + if r.Index > bc.BlockHeight() { // just put it into the store for future checks + return bc.dao.PutStateRoot(&state.MPTRootState{ + MPTRoot: *r, + Flag: state.Unverified, + }) + } + + flag := state.Unverified + if r.Witness != nil { + if err := bc.verifyStateRootWitness(r); err != nil { + return errors.WithMessage(err, "can't verify signature") + } + flag = state.Verified + } + return bc.dao.PutStateRoot(&state.MPTRootState{ + MPTRoot: *r, + Flag: flag, + }) +} + +// verifyStateRoot checks if state root is valid. +func (bc *Blockchain) verifyStateRoot(r *state.MPTRoot) error { + if r.Index == 0 { + return nil + } + prev, err := bc.GetStateRoot(r.Index - 1) + if err != nil { + return errors.New("can't get previous state root") + } else if !prev.Root.Equals(r.PrevHash) { + return errors.New("previous hash mismatch") + } else if prev.Version != r.Version { + return errors.New("version mismatch") + } + return nil +} + +// verifyStateRootWitness verifies that state root signature is correct. +func (bc *Blockchain) verifyStateRootWitness(r *state.MPTRoot) error { + b, err := bc.GetBlock(bc.GetHeaderHash(int(r.Index))) + if err != nil { + return err + } + interopCtx := bc.newInteropContext(trigger.Verification, bc.dao, nil, nil) + interopCtx.Container = r + return bc.verifyHashAgainstScript(b.NextConsensus, r.Witness, interopCtx, true, + bc.contracts.Policy.GetMaxVerificationGas(interopCtx.DAO)) +} + // VerifyTx verifies whether a transaction is bonafide or not. Block parameter // is used for easy interop access and can be omitted for transactions that are // not yet added into any block. 
diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go index 9dcac9e33..1086c6bee 100644 --- a/pkg/core/blockchainer/blockchainer.go +++ b/pkg/core/blockchainer/blockchainer.go @@ -20,6 +20,7 @@ type Blockchainer interface { GetConfig() config.ProtocolConfiguration AddHeaders(...*block.Header) error AddBlock(*block.Block) error + AddStateRoot(r *state.MPTRoot) error BlockHeight() uint32 CalculateClaimable(value *big.Int, startHeight, endHeight uint32) *big.Int Close() @@ -42,6 +43,7 @@ type Blockchainer interface { GetValidators() ([]*keys.PublicKey, error) GetStandByValidators() keys.PublicKeys GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) + GetStateRoot(height uint32) (*state.MPTRootState, error) GetStorageItem(id int32, key []byte) *state.StorageItem GetStorageItems(id int32) (map[string]*state.StorageItem, error) GetTestVM(tx *transaction.Transaction) *vm.VM diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index f24c7267c..f55fd68e1 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -8,6 +8,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -33,6 +34,8 @@ type DAO interface { GetNEP5Balances(acc util.Uint160) (*state.NEP5Balances, error) GetNEP5TransferLog(acc util.Uint160, index uint32) (*state.NEP5TransferLog, error) GetAndUpdateNextContractID() (int32, error) + GetStateRoot(height uint32) (*state.MPTRootState, error) + PutStateRoot(root *state.MPTRootState) error GetStorageItem(id int32, key []byte) *state.StorageItem GetStorageItems(id int32) (map[string]*state.StorageItem, error) GetStorageItemsWithPrefix(id int32, prefix []byte) (map[string]*state.StorageItem, error) @@ -58,13 +61,15 @@ type DAO interface { // Simple is 
memCached wrapper around DB, simple DAO implementation. type Simple struct { + MPT *mpt.Trie Store *storage.MemCachedStore network netmode.Magic } // NewSimple creates new simple dao using provided backend store. func NewSimple(backend storage.Store, network netmode.Magic) *Simple { - return &Simple{Store: storage.NewMemCachedStore(backend), network: network} + st := storage.NewMemCachedStore(backend) + return &Simple{Store: st, network: network, MPT: mpt.NewTrie(nil, st)} } // GetBatch returns currently accumulated DB changeset. @@ -75,7 +80,9 @@ func (dao *Simple) GetBatch() *storage.MemBatch { // GetWrapped returns new DAO instance with another layer of wrapped // MemCachedStore around the current DAO Store. func (dao *Simple) GetWrapped() DAO { - return NewSimple(dao.Store, dao.network) + d := NewSimple(dao.Store, dao.network) + d.MPT = dao.MPT + return d } // GetAndDecode performs get operation and decoding with serializable structures. @@ -288,6 +295,42 @@ func (dao *Simple) PutAppExecResult(aer *state.AppExecResult) error { // -- start storage item. +func makeStateRootKey(height uint32) []byte { + key := make([]byte, 5) + key[0] = byte(storage.DataMPT) + binary.LittleEndian.PutUint32(key[1:], height) + return key +} + +// InitMPT initializes MPT at the given height. +func (dao *Simple) InitMPT(height uint32) error { + if height == 0 { + dao.MPT = mpt.NewTrie(nil, dao.Store) + return nil + } + r, err := dao.GetStateRoot(height) + if err != nil { + return err + } + dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), dao.Store) + return nil +} + +// GetStateRoot returns state root of a given height. +func (dao *Simple) GetStateRoot(height uint32) (*state.MPTRootState, error) { + r := new(state.MPTRootState) + err := dao.GetAndDecode(r, makeStateRootKey(height)) + if err != nil { + return nil, err + } + return r, nil +} + +// PutStateRoot puts state root of a given height into the store. 
+func (dao *Simple) PutStateRoot(r *state.MPTRootState) error { + return dao.Put(r, makeStateRootKey(r.Index)) +} + // GetStorageItem returns StorageItem if it exists in the given store. func (dao *Simple) GetStorageItem(id int32, key []byte) *state.StorageItem { b, err := dao.Store.Get(makeStorageItemKey(id, key)) @@ -308,13 +351,27 @@ func (dao *Simple) GetStorageItem(id int32, key []byte) *state.StorageItem { // PutStorageItem puts given StorageItem for given id with given // key into the given store. func (dao *Simple) PutStorageItem(id int32, key []byte, si *state.StorageItem) error { - return dao.Put(si, makeStorageItemKey(id, key)) + stKey := makeStorageItemKey(id, key) + buf := io.NewBufBinWriter() + si.EncodeBinary(buf.BinWriter) + if buf.Err != nil { + return buf.Err + } + v := buf.Bytes() + if err := dao.MPT.Put(stKey[1:], v); err != nil && err != mpt.ErrNotFound { + return err + } + return dao.Store.Put(stKey, v) } // DeleteStorageItem drops storage item for the given id with the // given key from the store. func (dao *Simple) DeleteStorageItem(id int32, key []byte) error { - return dao.Store.Delete(makeStorageItemKey(id, key)) + stKey := makeStorageItemKey(id, key) + if err := dao.MPT.Delete(stKey[1:]); err != nil && err != mpt.ErrNotFound { + return err + } + return dao.Store.Delete(stKey) } // GetStorageItems returns all storage items for a given id. diff --git a/pkg/core/state/mpt_root.go b/pkg/core/state/mpt_root.go new file mode 100644 index 000000000..facf3da45 --- /dev/null +++ b/pkg/core/state/mpt_root.go @@ -0,0 +1,105 @@ +package state + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MPTRootBase represents storage state root. 
+type MPTRootBase struct { + Version byte + Index uint32 + PrevHash util.Uint256 + Root util.Uint256 +} + +// MPTRoot represents storage state root together with sign info. +type MPTRoot struct { + MPTRootBase + Witness *transaction.Witness +} + +// MPTRootStateFlag represents verification state of the state root. +type MPTRootStateFlag byte + +// Possible verification states of MPTRoot. +const ( + Unverified MPTRootStateFlag = 0x00 + Verified MPTRootStateFlag = 0x01 + Invalid MPTRootStateFlag = 0x03 +) + +// MPTRootState represents state root together with its verification state. +type MPTRootState struct { + MPTRoot + Flag MPTRootStateFlag +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRootState) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(s.Flag)) + s.MPTRoot.EncodeBinary(w) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRootState) DecodeBinary(r *io.BinReader) { + s.Flag = MPTRootStateFlag(r.ReadB()) + s.MPTRoot.DecodeBinary(r) +} + +// GetSignedPart returns part of MPTRootBase which needs to be signed. +func (s *MPTRootBase) GetSignedPart() []byte { + buf := io.NewBufBinWriter() + s.EncodeBinary(buf.BinWriter) + return buf.Bytes() +} + +// Equals checks if s == other. +func (s *MPTRootBase) Equals(other *MPTRootBase) bool { + return s.Version == other.Version && s.Index == other.Index && + s.PrevHash.Equals(other.PrevHash) && s.Root.Equals(other.Root) +} + +// Hash returns hash of s. +func (s *MPTRootBase) Hash() util.Uint256 { + return hash.DoubleSha256(s.GetSignedPart()) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRootBase) DecodeBinary(r *io.BinReader) { + s.Version = r.ReadB() + s.Index = r.ReadU32LE() + s.PrevHash.DecodeBinary(r) + s.Root.DecodeBinary(r) +} + +// EncodeBinary implements io.Serializable. 
+func (s *MPTRootBase) EncodeBinary(w *io.BinWriter) { + w.WriteB(s.Version) + w.WriteU32LE(s.Index) + s.PrevHash.EncodeBinary(w) + s.Root.EncodeBinary(w) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRoot) DecodeBinary(r *io.BinReader) { + s.MPTRootBase.DecodeBinary(r) + + var ws []transaction.Witness + r.ReadArray(&ws, 1) + if len(ws) == 1 { + s.Witness = &ws[0] + } +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRoot) EncodeBinary(w *io.BinWriter) { + s.MPTRootBase.EncodeBinary(w) + if s.Witness == nil { + w.WriteVarUint(0) + } else { + w.WriteArray([]*transaction.Witness{s.Witness}) + } +} diff --git a/pkg/core/state/mpt_root_test.go b/pkg/core/state/mpt_root_test.go new file mode 100644 index 000000000..15a3ca043 --- /dev/null +++ b/pkg/core/state/mpt_root_test.go @@ -0,0 +1,61 @@ +package state + +import ( + "math/rand" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/stretchr/testify/require" +) + +func testStateRoot() *MPTRoot { + return &MPTRoot{ + MPTRootBase: MPTRootBase{ + Version: byte(rand.Uint32()), + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + } +} + +func TestStateRoot_Serializable(t *testing.T) { + r := testStateRoot() + testserdes.EncodeDecodeBinary(t, r, new(MPTRoot)) + + t.Run("WithWitness", func(t *testing.T) { + r.Witness = &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + } + testserdes.EncodeDecodeBinary(t, r, new(MPTRoot)) + }) +} + +func TestStateRootEquals(t *testing.T) { + r1 := testStateRoot() + r2 := *r1 + require.True(t, r1.Equals(&r2.MPTRootBase)) + + r2.MPTRootBase.Index++ + require.False(t, r1.Equals(&r2.MPTRootBase)) +} + +func TestMPTRootState_Serializable(t *testing.T) { + rs := &MPTRootState{ + MPTRoot: *testStateRoot(), + Flag: 0x04, + } + rs.MPTRoot.Witness = 
&transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + } + testserdes.EncodeDecodeBinary(t, rs, new(MPTRootState)) +} + +func TestMPTRootStateUnverifiedByDefault(t *testing.T) { + var r MPTRootState + require.Equal(t, Unverified, r.Flag) +} diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 6004c44a4..61cf1939e 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -49,6 +49,9 @@ func (chain *testChain) AddBlock(block *block.Block) error { } return nil } +func (chain *testChain) AddStateRoot(r *state.MPTRoot) error { + panic("TODO") +} func (chain *testChain) BlockHeight() uint32 { return atomic.LoadUint32(&chain.blockheight) } @@ -98,6 +101,9 @@ func (chain testChain) GetEnrollments() ([]state.Validator, error) { func (chain testChain) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) { panic("TODO") } +func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + panic("TODO") +} func (chain testChain) GetStorageItem(id int32, key []byte) *state.StorageItem { panic("TODO") } From caea6d6ca871102be532c80a0fdf4c71e88a1781 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Thu, 4 Jun 2020 17:16:32 +0300 Subject: [PATCH 08/13] mpt: fix extension node cache invalidation It should always be invalidated if something changes in the `next` (below the extension node). Signed-off-by: Evgenii Stratonikov --- pkg/core/mpt/trie.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go index 3c38424c0..08d128d88 100644 --- a/pkg/core/mpt/trie.go +++ b/pkg/core/mpt/trie.go @@ -261,7 +261,6 @@ func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) case *ExtensionNode: n.key = append(n.key, nxt.key...) 
n.next = nxt.next - n.invalidateCache() case *HashNode: if nxt.IsEmpty() { return nxt, nil @@ -269,6 +268,7 @@ func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) default: n.next = r } + n.invalidateCache() return n, nil } From 58b7e16e0e6fdc0ae58e40723126fdf3525829d8 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Thu, 4 Jun 2020 17:19:30 +0300 Subject: [PATCH 09/13] core: fix PrevHash calculation for MPTRoot This was differing from C# notion of PrevHash. It's not a previous root, but rather a hash of the previous serialized MPTRoot structure (that is to be signed by CNs). Signed-off-by: Evgenii Stratonikov --- pkg/core/blockchain.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index d592eb085..7b2b45064 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -17,6 +17,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/io" @@ -648,7 +649,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { if err != nil { return errors.WithMessagef(err, "can't get previous state root") } - prevHash = prev.Root + prevHash = hash.DoubleSha256(prev.GetSignedPart()) } err := bc.AddStateRoot(&state.MPTRoot{ MPTRootBase: state.MPTRootBase{ @@ -1264,7 +1265,7 @@ func (bc *Blockchain) verifyStateRoot(r *state.MPTRoot) error { prev, err := bc.GetStateRoot(r.Index - 1) if err != nil { return errors.New("can't get previous state root") - } else if !prev.Root.Equals(r.PrevHash) { + } else if !r.PrevHash.Equals(hash.DoubleSha256(prev.GetSignedPart())) { return errors.New("previous hash mismatch") } else if prev.Version != r.Version { return errors.New("version mismatch") From 
236438d799eb7de497ae8c7a2f509b7212b2c770 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Thu, 4 Jun 2020 17:25:57 +0300 Subject: [PATCH 10/13] core: do MPT compaction every once in a while We need to compact our in-memory MPT from time to time, otherwise it quickly fills up all available memory. This raises two obvious questions --- when to do that and to what level to do that. As for 'when', I think it's quite easy to use our regular persistence interval as an anchor (and it also frees up some memory), but we can't do that in the persistence routine itself because of synchronization issues (adding some synchronization primitives would add some cost that I'd also like to avoid), so do it indirectly by comparing persisted and current height in `storeBlock`. Choosing proper level is another problem, but if we're to roughly estimate one full branch node to use 1K of memory (usually it's way less than that) then we can easily store 1K of these nodes and that gives us a depth of 10 for our trie. Signed-off-by: Evgenii Stratonikov --- pkg/core/blockchain.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 7b2b45064..c140ec561 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -674,6 +674,12 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } bc.contracts.Policy.OnPersistEnd(bc.dao) bc.dao.MPT.Flush() + // Every persist cycle we also compact our in-memory MPT. + persistedHeight := atomic.LoadUint32(&bc.persistedHeight) + if persistedHeight == block.Index-1 { + // 10 is good and roughly estimated to fit remaining trie into 1M of memory. 
+ bc.dao.MPT.Collapse(10) + } bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) bc.memPool.RemoveStale(bc.isTxStillRelevant, bc) From ab802cdd5f75fe410f2061353f893cfec8287758 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Wed, 3 Jun 2020 18:09:28 +0300 Subject: [PATCH 11/13] state: implement JSON marshaling for MPT* items Signed-off-by: Evgenii Stratonikov --- pkg/core/state/mpt_root.go | 55 ++++++++++++++++++++++++++++----- pkg/core/state/mpt_root_test.go | 39 +++++++++++++++++++++++ 2 files changed, 87 insertions(+), 7 deletions(-) diff --git a/pkg/core/state/mpt_root.go b/pkg/core/state/mpt_root.go index facf3da45..dea3f62fa 100644 --- a/pkg/core/state/mpt_root.go +++ b/pkg/core/state/mpt_root.go @@ -1,6 +1,9 @@ package state import ( + "encoding/json" + "errors" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" @@ -9,16 +12,16 @@ import ( // MPTRootBase represents storage state root. type MPTRootBase struct { - Version byte - Index uint32 - PrevHash util.Uint256 - Root util.Uint256 + Version byte `json:"version"` + Index uint32 `json:"index"` + PrevHash util.Uint256 `json:"prehash"` + Root util.Uint256 `json:"stateroot"` } // MPTRoot represents storage state root together with sign info. type MPTRoot struct { MPTRootBase - Witness *transaction.Witness + Witness *transaction.Witness `json:"witness,omitempty"` } // MPTRootStateFlag represents verification state of the state root. @@ -33,8 +36,8 @@ const ( // MPTRootState represents state root together with its verification state. type MPTRootState struct { - MPTRoot - Flag MPTRootStateFlag + MPTRoot `json:"stateroot"` + Flag MPTRootStateFlag `json:"flag"` } // EncodeBinary implements io.Serializable. @@ -103,3 +106,41 @@ func (s *MPTRoot) EncodeBinary(w *io.BinWriter) { w.WriteArray([]*transaction.Witness{s.Witness}) } } + +// String implements fmt.Stringer. 
+func (f MPTRootStateFlag) String() string { + switch f { + case Unverified: + return "Unverified" + case Verified: + return "Verified" + case Invalid: + return "Invalid" + default: + return "" + } +} + +// MarshalJSON implements json.Marshaler. +func (f MPTRootStateFlag) MarshalJSON() ([]byte, error) { + return []byte(`"` + f.String() + `"`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (f *MPTRootStateFlag) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + switch s { + case "Unverified": + *f = Unverified + case "Verified": + *f = Verified + case "Invalid": + *f = Invalid + default: + return errors.New("unknown flag") + } + return nil +} diff --git a/pkg/core/state/mpt_root_test.go b/pkg/core/state/mpt_root_test.go index 15a3ca043..f1c0b5c61 100644 --- a/pkg/core/state/mpt_root_test.go +++ b/pkg/core/state/mpt_root_test.go @@ -1,12 +1,14 @@ package state import ( + "encoding/json" "math/rand" "testing" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" ) @@ -59,3 +61,40 @@ func TestMPTRootStateUnverifiedByDefault(t *testing.T) { var r MPTRootState require.Equal(t, Unverified, r.Flag) } + +func TestMPTRoot_MarshalJSON(t *testing.T) { + t.Run("Good", func(t *testing.T) { + r := testStateRoot() + rs := &MPTRootState{ + MPTRoot: *r, + Flag: Verified, + } + testserdes.MarshalUnmarshalJSON(t, rs, new(MPTRootState)) + }) + + t.Run("Compatibility", func(t *testing.T) { + js := []byte(`{ + "flag": "Unverified", + "stateroot": { + "version": 1, + "index": 3000000, + "prehash": "0x4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90", + "stateroot": "0xb2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3" + }}`) + + rs := new(MPTRootState) + require.NoError(t, json.Unmarshal(js, &rs)) 
+ + require.EqualValues(t, 1, rs.Version) + require.EqualValues(t, 3000000, rs.Index) + require.Nil(t, rs.Witness) + + u, err := util.Uint256DecodeStringLE("4f30f43af8dd2262fc331c45bfcd9066ebbacda204e6e81371cbd884fe7d6c90") + require.NoError(t, err) + require.Equal(t, u, rs.PrevHash) + + u, err = util.Uint256DecodeStringLE("b2fd7e368a848ef70d27cf44940a35237333ed05f1d971c9408f0eb285e0b6f3") + require.NoError(t, err) + require.Equal(t, u, rs.Root) + }) +} From 5ee3ecf3810ea3d60c67b74604b06c3eb93682fc Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Mon, 22 Jun 2020 10:42:46 +0300 Subject: [PATCH 12/13] core: update verified state root height Signed-off-by: Evgenii Stratonikov --- pkg/core/blockchain.go | 18 ++++++++++++++++-- pkg/core/dao/dao.go | 22 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index c140ec561..08e959bfb 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1235,7 +1235,7 @@ func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { our, err := bc.GetStateRoot(r.Index) if err == nil { if our.Flag == state.Verified { - return nil + return bc.updateStateHeight(r.Index) } else if r.Witness == nil && our.Witness != nil { r.Witness = our.Witness } @@ -1257,10 +1257,24 @@ func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { } flag = state.Verified } - return bc.dao.PutStateRoot(&state.MPTRootState{ + err = bc.dao.PutStateRoot(&state.MPTRootState{ MPTRoot: *r, Flag: flag, }) + if err != nil { + return err + } + return bc.updateStateHeight(r.Index) +} + +func (bc *Blockchain) updateStateHeight(newHeight uint32) error { + h, err := bc.dao.GetCurrentStateRootHeight() + if err != nil { + return errors.WithMessage(err, "can't get current state root height") + } else if newHeight == h+1 { + return bc.dao.PutCurrentStateRootHeight(h + 1) + } + return nil } // verifyStateRoot checks if state root is valid. 
diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index f55fd68e1..006b1ecdb 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -30,6 +30,7 @@ type DAO interface { GetContractState(hash util.Uint160) (*state.Contract, error) GetCurrentBlockHeight() (uint32, error) GetCurrentHeaderHeight() (i uint32, h util.Uint256, err error) + GetCurrentStateRootHeight() (uint32, error) GetHeaderHashes() ([]util.Uint256, error) GetNEP5Balances(acc util.Uint160) (*state.NEP5Balances, error) GetNEP5TransferLog(acc util.Uint160, index uint32) (*state.NEP5TransferLog, error) @@ -316,6 +317,27 @@ func (dao *Simple) InitMPT(height uint32) error { return nil } +// GetCurrentStateRootHeight returns current state root height. +func (dao *Simple) GetCurrentStateRootHeight() (uint32, error) { + key := []byte{byte(storage.DataMPT)} + val, err := dao.Store.Get(key) + if err != nil { + if err == storage.ErrKeyNotFound { + err = nil + } + return 0, err + } + return binary.LittleEndian.Uint32(val), nil +} + +// PutCurrentStateRootHeight updates current state root height. +func (dao *Simple) PutCurrentStateRootHeight(height uint32) error { + key := []byte{byte(storage.DataMPT)} + val := make([]byte, 4) + binary.LittleEndian.PutUint32(val, height) + return dao.Store.Put(key, val) +} + // GetStateRoot returns state root of a given height. 
func (dao *Simple) GetStateRoot(height uint32) (*state.MPTRootState, error) { r := new(state.MPTRootState) From e21c65c59fd2941ed9c2d469a3e9a00251fbd295 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 24 Jun 2020 14:47:08 +0300 Subject: [PATCH 13/13] core: add state height to prometheus metrics Signed-off-by: Evgenii Stratonikov --- pkg/core/blockchain.go | 1 + pkg/core/prometheus.go | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 08e959bfb..055008e86 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1272,6 +1272,7 @@ func (bc *Blockchain) updateStateHeight(newHeight uint32) error { if err != nil { return errors.WithMessage(err, "can't get current state root height") } else if newHeight == h+1 { + updateStateHeightMetric(newHeight) return bc.dao.PutCurrentStateRootHeight(h + 1) } return nil diff --git a/pkg/core/prometheus.go b/pkg/core/prometheus.go index b81fb847d..c849e3459 100644 --- a/pkg/core/prometheus.go +++ b/pkg/core/prometheus.go @@ -30,6 +30,14 @@ var ( Namespace: "neogo", }, ) + //stateHeight prometheus metric. + stateHeight = prometheus.NewGauge( + prometheus.GaugeOpts{ + Help: "Current verified state height", + Name: "current_state_height", + Namespace: "neogo", + }, + ) ) func init() { @@ -51,3 +59,7 @@ func updateHeaderHeightMetric(hHeight int) { func updateBlockHeightMetric(bHeight uint32) { blockHeight.Set(float64(bHeight)) } + +func updateStateHeightMetric(sHeight uint32) { + stateHeight.Set(float64(sHeight)) +}