Merge pull request #990 from nspcc-dev/feature/mpt

Initial MPT implementation (2.x)
This commit is contained in:
Roman Khimov 2020-06-01 18:58:35 +03:00 committed by GitHub
commit 2f90a06db3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 1784 additions and 29 deletions

View file

@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/util"
)
@ -33,35 +34,7 @@ func toNeoStorageKey(key []byte) []byte {
if len(key) < util.Uint160Size {
panic("invalid key in storage")
}
var nkey []byte
for i := util.Uint160Size - 1; i >= 0; i-- {
nkey = append(nkey, key[i])
}
key = key[util.Uint160Size:]
index := 0
remain := len(key)
for remain >= 16 {
nkey = append(nkey, key[index:index+16]...)
nkey = append(nkey, 0)
index += 16
remain -= 16
}
if remain > 0 {
nkey = append(nkey, key[index:]...)
}
padding := 16 - remain
for i := 0; i < padding; i++ {
nkey = append(nkey, 0)
}
nkey = append(nkey, byte(padding))
return nkey
return mpt.ToNeoStorageKey(key)
}
// batchToMap converts batch to a map so that JSON is compatible

98
pkg/core/mpt/branch.go Normal file
View file

@ -0,0 +1,98 @@
package mpt
import (
"encoding/json"
"errors"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
const (
// childrenCount represents a number of children of a branch node.
childrenCount = 17
// lastChild is the index of the last child.
lastChild = childrenCount - 1
)
// BranchNode represents MPT's branch node.
type BranchNode struct {
hash util.Uint256
valid bool
Children [childrenCount]Node
}
var _ Node = (*BranchNode)(nil)
// NewBranchNode returns new branch node.
func NewBranchNode() *BranchNode {
b := new(BranchNode)
for i := 0; i < childrenCount; i++ {
b.Children[i] = new(HashNode)
}
return b
}
// Type implements Node interface.
func (b *BranchNode) Type() NodeType { return BranchT }
// Hash implements Node interface.
func (b *BranchNode) Hash() util.Uint256 {
if !b.valid {
b.hash = hash.DoubleSha256(toBytes(b))
b.valid = true
}
return b.hash
}
// invalidateHash invalidates node hash.
func (b *BranchNode) invalidateHash() {
b.valid = false
}
// EncodeBinary implements io.Serializable.
func (b *BranchNode) EncodeBinary(w *io.BinWriter) {
for i := 0; i < childrenCount; i++ {
if hn, ok := b.Children[i].(*HashNode); ok {
hn.EncodeBinary(w)
continue
}
n := NewHashNode(b.Children[i].Hash())
n.EncodeBinary(w)
}
}
// DecodeBinary implements io.Serializable.
func (b *BranchNode) DecodeBinary(r *io.BinReader) {
for i := 0; i < childrenCount; i++ {
b.Children[i] = new(HashNode)
b.Children[i].DecodeBinary(r)
}
}
// MarshalJSON implements json.Marshaler.
func (b *BranchNode) MarshalJSON() ([]byte, error) {
return json.Marshal(b.Children)
}
// UnmarshalJSON implements json.Unmarshaler.
func (b *BranchNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*BranchNode); ok {
*b = *u
return nil
}
return errors.New("expected branch node")
}
// splitPath splits path for a branch node.
func splitPath(path []byte) (byte, []byte) {
if len(path) != 0 {
return path[0], path[1:]
}
return lastChild, path
}

45
pkg/core/mpt/doc.go Normal file
View file

@ -0,0 +1,45 @@
/*
Package mpt implements MPT (Merkle-Patricia Tree).
MPT stores key-value pairs and is a trie over 16-symbol alphabet. https://en.wikipedia.org/wiki/Trie
Trie is a tree where values are stored in leafs and keys are paths from root to the leaf node.
MPT consists of 4 type of nodes:
- Leaf node contains only value.
- Extension node contains both key and value.
- Branch node contains 2 or more children.
- Hash node is a compressed node and contains only actual node's hash.
The actual node must be retrieved from storage or over the network.
As an example here is a trie containing 3 pairs:
- 0x1201 -> val1
- 0x1203 -> val2
- 0x1224 -> val3
- 0x12 -> val4
ExtensionNode(0x0102), Next
_______________________|
|
BranchNode [0, 1, 2, ...], Last -> Leaf(val4)
| |
| ExtensionNode [0x04], Next -> Leaf(val3)
|
BranchNode [0, 1, 2, 3, ...], Last -> HashNode(nil)
| |
| Leaf(val2)
|
Leaf(val1)
There are 3 invariants that this implementation has:
- Branch node cannot have <= 1 children
- Extension node cannot have zero-length key
- Extension node cannot have another Extension node in it's next field
Thank to these restrictions, there is a single root hash for every set of key-value pairs
irregardless of the order they were added/removed with.
The actual trie structure can vary because of node -> HashNode compressing.
There is also one optimization which cost us almost nothing in terms of complexity but is very beneficial:
When we perform get/put/delete on a speficic path, every Hash node which was retreived from storage is
replaced by its uncompressed form, so that subsequent hits of this not don't use storage.
*/
package mpt

94
pkg/core/mpt/extension.go Normal file
View file

@ -0,0 +1,94 @@
package mpt
import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// MaxKeyLength is the max length of the extension node key.
const MaxKeyLength = 1125
// ExtensionNode represents MPT's extension node.
type ExtensionNode struct {
hash util.Uint256
valid bool
key []byte
next Node
}
var _ Node = (*ExtensionNode)(nil)
// NewExtensionNode returns hash node with the specified key and next node.
// Note: because it is a part of Trie, key must be mangled, i.e. must contain only bytes with high half = 0.
func NewExtensionNode(key []byte, next Node) *ExtensionNode {
return &ExtensionNode{
key: key,
next: next,
}
}
// Type implements Node interface.
func (e ExtensionNode) Type() NodeType { return ExtensionT }
// Hash implements Node interface.
func (e *ExtensionNode) Hash() util.Uint256 {
if !e.valid {
e.hash = hash.DoubleSha256(toBytes(e))
e.valid = true
}
return e.hash
}
// invalidateHash invalidates node hash.
func (e *ExtensionNode) invalidateHash() {
e.valid = false
}
// DecodeBinary implements io.Serializable.
func (e *ExtensionNode) DecodeBinary(r *io.BinReader) {
sz := r.ReadVarUint()
if sz > MaxKeyLength {
r.Err = fmt.Errorf("extension node key is too big: %d", sz)
return
}
e.valid = false
e.key = make([]byte, sz)
r.ReadBytes(e.key)
e.next = new(HashNode)
e.next.DecodeBinary(r)
}
// EncodeBinary implements io.Serializable.
func (e ExtensionNode) EncodeBinary(w *io.BinWriter) {
w.WriteVarBytes(e.key)
n := NewHashNode(e.next.Hash())
n.EncodeBinary(w)
}
// MarshalJSON implements json.Marshaler.
func (e *ExtensionNode) MarshalJSON() ([]byte, error) {
m := map[string]interface{}{
"key": hex.EncodeToString(e.key),
"next": e.next,
}
return json.Marshal(m)
}
// UnmarshalJSON implements json.Unmarshaler.
func (e *ExtensionNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*ExtensionNode); ok {
*e = *u
return nil
}
return errors.New("expected extension node")
}

82
pkg/core/mpt/hash.go Normal file
View file

@ -0,0 +1,82 @@
package mpt
import (
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// HashNode represents MPT's hash node.
type HashNode struct {
hash util.Uint256
valid bool
}
var _ Node = (*HashNode)(nil)
// NewHashNode returns hash node with the specified hash.
func NewHashNode(h util.Uint256) *HashNode {
return &HashNode{
hash: h,
valid: true,
}
}
// Type implements Node interface.
func (h *HashNode) Type() NodeType { return HashT }
// Hash implements Node interface.
func (h *HashNode) Hash() util.Uint256 {
if !h.valid {
panic("can't get hash of an empty HashNode")
}
return h.hash
}
// IsEmpty returns true iff h is an empty node i.e. contains no hash.
func (h *HashNode) IsEmpty() bool { return !h.valid }
// DecodeBinary implements io.Serializable.
func (h *HashNode) DecodeBinary(r *io.BinReader) {
sz := r.ReadVarUint()
switch sz {
case 0:
h.valid = false
case util.Uint256Size:
h.valid = true
r.ReadBytes(h.hash[:])
default:
r.Err = fmt.Errorf("invalid hash node size: %d", sz)
}
}
// EncodeBinary implements io.Serializable.
func (h HashNode) EncodeBinary(w *io.BinWriter) {
if !h.valid {
w.WriteVarUint(0)
return
}
w.WriteVarBytes(h.hash[:])
}
// MarshalJSON implements json.Marshaler.
func (h *HashNode) MarshalJSON() ([]byte, error) {
if !h.valid {
return []byte(`{}`), nil
}
return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (h *HashNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*HashNode); ok {
*h = *u
return nil
}
return errors.New("expected hash node")
}

71
pkg/core/mpt/helpers.go Normal file
View file

@ -0,0 +1,71 @@
package mpt
import "github.com/nspcc-dev/neo-go/pkg/util"
// lcp returns longest common prefix of a and b.
// Note: it does no allocations.
func lcp(a, b []byte) []byte {
if len(a) < len(b) {
return lcp(b, a)
}
var i int
for i = 0; i < len(b); i++ {
if a[i] != b[i] {
break
}
}
return a[:i]
}
// copySlice is a helper for copying slice if needed.
func copySlice(a []byte) []byte {
b := make([]byte, len(a))
copy(b, a)
return b
}
// toNibbles mangles path by splitting every byte into 2 containing low- and high- 4-byte part.
func toNibbles(path []byte) []byte {
result := make([]byte, len(path)*2)
for i := range path {
result[i*2] = path[i] >> 4
result[i*2+1] = path[i] & 0x0F
}
return result
}
// ToNeoStorageKey converts storage key to C# neo node's format.
// Key is expected to be at least 20 bytes in length.
// our format: script hash in BE + key
// neo format: script hash in LE + key with 0 between every 16 bytes, padded to len 16.
func ToNeoStorageKey(key []byte) []byte {
const groupSize = 16
var nkey []byte
for i := util.Uint160Size - 1; i >= 0; i-- {
nkey = append(nkey, key[i])
}
key = key[util.Uint160Size:]
index := 0
remain := len(key)
for remain >= groupSize {
nkey = append(nkey, key[index:index+groupSize]...)
nkey = append(nkey, 0)
index += groupSize
remain -= groupSize
}
if remain > 0 {
nkey = append(nkey, key[index:]...)
}
padding := groupSize - remain
for i := 0; i < padding; i++ {
nkey = append(nkey, 0)
}
return append(nkey, byte(padding))
}

View file

@ -0,0 +1,30 @@
package mpt
import (
"encoding/hex"
"testing"
"github.com/stretchr/testify/require"
)
func TestToNeoStorageKey(t *testing.T) {
testCases := []struct{ key, res string }{
{
"0102030405060708091011121314151617181920",
"20191817161514131211100908070605040302010000000000000000000000000000000010",
},
{
"010203040506070809101112131415161718192021222324",
"2019181716151413121110090807060504030201212223240000000000000000000000000c",
},
{
"0102030405060708091011121314151617181920212223242526272829303132333435363738",
"20191817161514131211100908070605040302012122232425262728293031323334353600373800000000000000000000000000000e",
},
}
for _, tc := range testCases {
key, _ := hex.DecodeString(tc.key)
res, _ := hex.DecodeString(tc.res)
require.Equal(t, res, ToNeoStorageKey(key))
}
}

75
pkg/core/mpt/leaf.go Normal file
View file

@ -0,0 +1,75 @@
package mpt
import (
"encoding/hex"
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// MaxValueLength is a max length of a leaf node value.
const MaxValueLength = 1024 * 1024
// LeafNode represents MPT's leaf node.
type LeafNode struct {
hash util.Uint256
valid bool
value []byte
}
var _ Node = (*LeafNode)(nil)
// NewLeafNode returns hash node with the specified value.
func NewLeafNode(value []byte) *LeafNode {
return &LeafNode{value: value}
}
// Type implements Node interface.
func (n LeafNode) Type() NodeType { return LeafT }
// Hash implements Node interface.
func (n *LeafNode) Hash() util.Uint256 {
if !n.valid {
n.hash = hash.DoubleSha256(toBytes(n))
n.valid = true
}
return n.hash
}
// DecodeBinary implements io.Serializable.
func (n *LeafNode) DecodeBinary(r *io.BinReader) {
sz := r.ReadVarUint()
if sz > MaxValueLength {
r.Err = fmt.Errorf("leaf node value is too big: %d", sz)
return
}
n.valid = false
n.value = make([]byte, sz)
r.ReadBytes(n.value)
}
// EncodeBinary implements io.Serializable.
func (n LeafNode) EncodeBinary(w *io.BinWriter) {
w.WriteVarBytes(n.value)
}
// MarshalJSON implements json.Marshaler.
func (n *LeafNode) MarshalJSON() ([]byte, error) {
return []byte(`{"value":"` + hex.EncodeToString(n.value) + `"}`), nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (n *LeafNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*LeafNode); ok {
*n = *u
return nil
}
return errors.New("expected leaf node")
}

148
pkg/core/mpt/node.go Normal file
View file

@ -0,0 +1,148 @@
package mpt
import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// NodeType represents node type..
type NodeType byte
// Node types definitions.
const (
BranchT NodeType = 0x00
ExtensionT NodeType = 0x01
HashT NodeType = 0x02
LeafT NodeType = 0x03
)
// NodeObject represents Node together with it's type.
// It is used for serialization/deserialization where type info
// is also expected.
type NodeObject struct {
Node
}
// Node represents common interface of all MPT nodes.
type Node interface {
io.Serializable
json.Marshaler
json.Unmarshaler
Hash() util.Uint256
Type() NodeType
}
// EncodeBinary implements io.Serializable.
func (n NodeObject) EncodeBinary(w *io.BinWriter) {
encodeNodeWithType(n.Node, w)
}
// DecodeBinary implements io.Serializable.
func (n *NodeObject) DecodeBinary(r *io.BinReader) {
typ := NodeType(r.ReadB())
switch typ {
case BranchT:
n.Node = new(BranchNode)
case ExtensionT:
n.Node = new(ExtensionNode)
case HashT:
n.Node = new(HashNode)
case LeafT:
n.Node = new(LeafNode)
default:
r.Err = fmt.Errorf("invalid node type: %x", typ)
return
}
n.Node.DecodeBinary(r)
}
// encodeNodeWithType encodes node together with it's type.
func encodeNodeWithType(n Node, w *io.BinWriter) {
w.WriteB(byte(n.Type()))
n.EncodeBinary(w)
}
// toBytes is a helper for serializing node.
func toBytes(n Node) []byte {
buf := io.NewBufBinWriter()
encodeNodeWithType(n, buf.BinWriter)
return buf.Bytes()
}
// UnmarshalJSON implements json.Unmarshaler.
func (n *NodeObject) UnmarshalJSON(data []byte) error {
var m map[string]json.RawMessage
err := json.Unmarshal(data, &m)
if err != nil { // it can be a branch node
var nodes []NodeObject
if err := json.Unmarshal(data, &nodes); err != nil {
return err
} else if len(nodes) != childrenCount {
return errors.New("invalid length of branch node")
}
b := NewBranchNode()
for i := range b.Children {
b.Children[i] = nodes[i].Node
}
n.Node = b
return nil
}
switch len(m) {
case 0:
n.Node = new(HashNode)
case 1:
if v, ok := m["hash"]; ok {
var h util.Uint256
if err := json.Unmarshal(v, &h); err != nil {
return err
}
n.Node = NewHashNode(h)
} else if v, ok = m["value"]; ok {
b, err := unmarshalHex(v)
if err != nil {
return err
} else if len(b) > MaxValueLength {
return errors.New("leaf value is too big")
}
n.Node = NewLeafNode(b)
} else {
return errors.New("invalid field")
}
case 2:
keyRaw, ok1 := m["key"]
nextRaw, ok2 := m["next"]
if !ok1 || !ok2 {
return errors.New("invalid field")
}
key, err := unmarshalHex(keyRaw)
if err != nil {
return err
} else if len(key) > MaxKeyLength {
return errors.New("extension key is too big")
}
var next NodeObject
if err := json.Unmarshal(nextRaw, &next); err != nil {
return err
}
n.Node = NewExtensionNode(key, next.Node)
default:
return errors.New("0, 1 or 2 fields expected")
}
return nil
}
func unmarshalHex(data json.RawMessage) ([]byte, error) {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return nil, err
}
return hex.DecodeString(s)
}

156
pkg/core/mpt/node_test.go Normal file
View file

@ -0,0 +1,156 @@
package mpt
import (
"encoding/json"
"testing"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) {
return func(t *testing.T) {
t.Run("IO", func(t *testing.T) {
bs, err := testserdes.EncodeBinary(expected)
require.NoError(t, err)
err = testserdes.DecodeBinary(bs, actual)
if !ok {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, expected.Type(), actual.Type())
require.Equal(t, expected.Hash(), actual.Hash())
})
t.Run("JSON", func(t *testing.T) {
bs, err := json.Marshal(expected)
require.NoError(t, err)
err = json.Unmarshal(bs, actual)
if !ok {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, expected.Type(), actual.Type())
require.Equal(t, expected.Hash(), actual.Hash())
})
}
}
func TestNode_Serializable(t *testing.T) {
t.Run("Leaf", func(t *testing.T) {
t.Run("Good", func(t *testing.T) {
l := NewLeafNode(random.Bytes(123))
t.Run("Raw", getTestFuncEncode(true, l, new(LeafNode)))
t.Run("WithType", getTestFuncEncode(true, &NodeObject{l}, new(NodeObject)))
})
t.Run("BigValue", getTestFuncEncode(false,
NewLeafNode(random.Bytes(MaxValueLength+1)), new(LeafNode)))
})
t.Run("Extension", func(t *testing.T) {
t.Run("Good", func(t *testing.T) {
e := NewExtensionNode(random.Bytes(42), NewLeafNode(random.Bytes(10)))
t.Run("Raw", getTestFuncEncode(true, e, new(ExtensionNode)))
t.Run("WithType", getTestFuncEncode(true, &NodeObject{e}, new(NodeObject)))
})
t.Run("BigKey", getTestFuncEncode(false,
NewExtensionNode(random.Bytes(MaxKeyLength+1), NewLeafNode(random.Bytes(10))), new(ExtensionNode)))
})
t.Run("Branch", func(t *testing.T) {
b := NewBranchNode()
b.Children[0] = NewLeafNode(random.Bytes(10))
b.Children[lastChild] = NewHashNode(random.Uint256())
t.Run("Raw", getTestFuncEncode(true, b, new(BranchNode)))
t.Run("WithType", getTestFuncEncode(true, &NodeObject{b}, new(NodeObject)))
})
t.Run("Hash", func(t *testing.T) {
t.Run("Good", func(t *testing.T) {
h := NewHashNode(random.Uint256())
t.Run("Raw", getTestFuncEncode(true, h, new(HashNode)))
t.Run("WithType", getTestFuncEncode(true, &NodeObject{h}, new(NodeObject)))
})
t.Run("Empty", func(t *testing.T) { // compare nodes, not hashes
testserdes.EncodeDecodeBinary(t, new(HashNode), new(HashNode))
})
t.Run("InvalidSize", func(t *testing.T) {
buf := io.NewBufBinWriter()
buf.BinWriter.WriteVarBytes(make([]byte, 13))
require.Error(t, testserdes.DecodeBinary(buf.Bytes(), new(HashNode)))
})
})
t.Run("Invalid", func(t *testing.T) {
require.Error(t, testserdes.DecodeBinary([]byte{0xFF}, new(NodeObject)))
})
}
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198
func TestJSONSharp(t *testing.T) {
tr := NewTrie(nil, newTestStore())
require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11}))
require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22}))
require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac}))
require.NoError(t, tr.Delete([]byte{0xac, 0x11}))
require.NoError(t, tr.Delete([]byte{0xac, 0x22}))
js, err := tr.root.MarshalJSON()
require.NoError(t, err)
require.JSONEq(t, `{"key":"0a0c", "next":{"value":"ac"}}`, string(js))
}
func TestInvalidJSON(t *testing.T) {
t.Run("InvalidChildrenCount", func(t *testing.T) {
var cs [childrenCount + 1]Node
for i := range cs {
cs[i] = new(HashNode)
}
data, err := json.Marshal(cs)
require.NoError(t, err)
var n NodeObject
require.Error(t, json.Unmarshal(data, &n))
})
testCases := []struct {
name string
data []byte
}{
{"WrongFieldCount", []byte(`{"key":"0102", "next": {}, "field": {}}`)},
{"InvalidField1", []byte(`{"next":{}}`)},
{"InvalidField2", []byte(`{"key":"0102", "hash":{}}`)},
{"InvalidKey", []byte(`{"key":"xy", "next":{}}`)},
{"InvalidNext", []byte(`{"key":"01", "next":[]}`)},
{"InvalidHash", []byte(`{"hash":"01"}`)},
{"InvalidValue", []byte(`{"value":1}`)},
{"InvalidBranch", []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]`)},
}
for _, tc := range testCases {
var n NodeObject
assert.Errorf(t, json.Unmarshal(tc.data, &n), "no error in "+tc.name)
}
}
// C# interoperability test
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135
func TestRootHash(t *testing.T) {
b := NewBranchNode()
r := NewExtensionNode([]byte{0x0A, 0x0C}, b)
v1 := NewLeafNode([]byte{0xAB, 0xCD})
l1 := NewExtensionNode([]byte{0x01}, v1)
b.Children[0] = l1
v2 := NewLeafNode([]byte{0x22, 0x22})
l2 := NewExtensionNode([]byte{0x09}, v2)
b.Children[9] = l2
r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1)
require.Equal(t, "dea3ab46e9461e885ed7091c1e533e0a8030b248d39cbc638962394eaca0fbb3", r1.Hash().StringLE())
require.Equal(t, "93e8e1ffe2f83dd92fca67330e273bcc811bf64b8f8d9d1b25d5e7366b47d60d", r.Hash().StringLE())
}

74
pkg/core/mpt/proof.go Normal file
View file

@ -0,0 +1,74 @@
package mpt
import (
"bytes"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// GetProof returns a proof that key belongs to t.
// Proof consist of serialized nodes occuring on path from the root to the leaf of key.
func (t *Trie) GetProof(key []byte) ([][]byte, error) {
var proof [][]byte
path := toNibbles(key)
r, err := t.getProof(t.root, path, &proof)
if err != nil {
return proof, err
}
t.root = r
return proof, nil
}
func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) {
switch n := curr.(type) {
case *LeafNode:
if len(path) == 0 {
*proofs = append(*proofs, toBytes(n))
return n, nil
}
case *BranchNode:
*proofs = append(*proofs, toBytes(n))
i, path := splitPath(path)
r, err := t.getProof(n.Children[i], path, proofs)
if err != nil {
return nil, err
}
n.Children[i] = r
return n, nil
case *ExtensionNode:
if bytes.HasPrefix(path, n.key) {
*proofs = append(*proofs, toBytes(n))
r, err := t.getProof(n.next, path[len(n.key):], proofs)
if err != nil {
return nil, err
}
n.next = r
return n, nil
}
case *HashNode:
if !n.IsEmpty() {
r, err := t.getFromStore(n.Hash())
if err != nil {
return nil, err
}
return t.getProof(r, path, proofs)
}
}
return nil, ErrNotFound
}
// VerifyProof verifies that path indeed belongs to a MPT with the specified root hash.
// It also returns value for the key.
func VerifyProof(rh util.Uint256, key []byte, proofs [][]byte) ([]byte, bool) {
path := toNibbles(key)
tr := NewTrie(NewHashNode(rh), storage.NewMemCachedStore(storage.NewMemoryStore()))
for i := range proofs {
h := hash.DoubleSha256(proofs[i])
// no errors in Put to memory store
_ = tr.Store.Put(makeStorageKey(h[:]), proofs[i])
}
_, bs, err := tr.getWithPath(tr.root, path)
return bs, err == nil
}

View file

@ -0,0 +1,73 @@
package mpt
import (
"testing"
"github.com/stretchr/testify/require"
)
func newProofTrie(t *testing.T) *Trie {
l := NewLeafNode([]byte("somevalue"))
e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l)
l2 := NewLeafNode([]byte("invalid"))
e2 := NewExtensionNode([]byte{0x05}, NewHashNode(l2.Hash()))
b := NewBranchNode()
b.Children[4] = NewHashNode(e.Hash())
b.Children[5] = e2
tr := NewTrie(b, newTestStore())
require.NoError(t, tr.Put([]byte{0x12, 0x31}, []byte("value1")))
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
tr.putToStore(l)
tr.putToStore(e)
return tr
}
func TestTrie_GetProof(t *testing.T) {
tr := newProofTrie(t)
t.Run("MissingKey", func(t *testing.T) {
_, err := tr.GetProof([]byte{0x12})
require.Error(t, err)
})
t.Run("Valid", func(t *testing.T) {
_, err := tr.GetProof([]byte{0x12, 0x31})
require.NoError(t, err)
})
t.Run("MissingHashNode", func(t *testing.T) {
_, err := tr.GetProof([]byte{0x55})
require.Error(t, err)
})
}
func TestVerifyProof(t *testing.T) {
tr := newProofTrie(t)
t.Run("Simple", func(t *testing.T) {
proof, err := tr.GetProof([]byte{0x12, 0x32})
require.NoError(t, err)
t.Run("Good", func(t *testing.T) {
v, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x32}, proof)
require.True(t, ok)
require.Equal(t, []byte("value2"), v)
})
t.Run("Bad", func(t *testing.T) {
_, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x31}, proof)
require.False(t, ok)
})
})
t.Run("InsideHash", func(t *testing.T) {
key := []byte{0x45, 0x67}
proof, err := tr.GetProof(key)
require.NoError(t, err)
v, ok := VerifyProof(tr.root.Hash(), key, proof)
require.True(t, ok)
require.Equal(t, []byte("somevalue"), v)
})
}

389
pkg/core/mpt/trie.go Normal file
View file

@ -0,0 +1,389 @@
package mpt
import (
"bytes"
"errors"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// Trie is an MPT trie storing all key-value pairs.
type Trie struct {
Store *storage.MemCachedStore
root Node
}
// ErrNotFound is returned when requested trie item is missing.
var ErrNotFound = errors.New("item not found")
// NewTrie returns new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors
// so that all storage errors are processed during `store.Persist()` at the caller.
// This also has the benefit, that every `Put` can be considered an atomic operation.
func NewTrie(root Node, store *storage.MemCachedStore) *Trie {
if root == nil {
root = new(HashNode)
}
return &Trie{
Store: store,
root: root,
}
}
// Get returns value for the provided key in t.
func (t *Trie) Get(key []byte) ([]byte, error) {
path := toNibbles(key)
r, bs, err := t.getWithPath(t.root, path)
if err != nil {
return nil, err
}
t.root = r
return bs, nil
}
// getWithPath returns value the provided path in a subtrie rooting in curr.
// It also returns a current node with all hash nodes along the path
// replaced to their "unhashed" counterparts.
func (t *Trie) getWithPath(curr Node, path []byte) (Node, []byte, error) {
switch n := curr.(type) {
case *LeafNode:
if len(path) == 0 {
return curr, copySlice(n.value), nil
}
case *BranchNode:
i, path := splitPath(path)
r, bs, err := t.getWithPath(n.Children[i], path)
if err != nil {
return nil, nil, err
}
n.Children[i] = r
return n, bs, nil
case *HashNode:
if !n.IsEmpty() {
if r, err := t.getFromStore(n.hash); err == nil {
return t.getWithPath(r, path)
}
}
case *ExtensionNode:
if bytes.HasPrefix(path, n.key) {
r, bs, err := t.getWithPath(n.next, path[len(n.key):])
if err != nil {
return nil, nil, err
}
n.next = r
return curr, bs, err
}
default:
panic("invalid MPT node type")
}
return curr, nil, ErrNotFound
}
// Put puts key-value pair in t.
func (t *Trie) Put(key, value []byte) error {
if len(key) > MaxKeyLength {
return errors.New("key is too big")
} else if len(value) > MaxValueLength {
return errors.New("value is too big")
}
if len(value) == 0 {
return t.Delete(key)
}
path := toNibbles(key)
n := NewLeafNode(value)
r, err := t.putIntoNode(t.root, path, n)
if err != nil {
return err
}
t.root = r
return nil
}
// putIntoLeaf puts val to trie if current node is a Leaf.
// It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) {
v := val.(*LeafNode)
if len(path) == 0 {
return v, nil
}
b := NewBranchNode()
b.Children[path[0]] = newSubTrie(path[1:], v)
b.Children[lastChild] = curr
return b, nil
}
// putIntoBranch puts val to trie if current node is a Branch.
// It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) {
i, path := splitPath(path)
r, err := t.putIntoNode(curr.Children[i], path, val)
if err != nil {
return nil, err
}
curr.Children[i] = r
curr.invalidateHash()
return curr, nil
}
// putIntoExtension puts val to trie if current node is an Extension.
// It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) {
if bytes.HasPrefix(path, curr.key) {
r, err := t.putIntoNode(curr.next, path[len(curr.key):], val)
if err != nil {
return nil, err
}
curr.next = r
curr.invalidateHash()
return curr, nil
}
pref := lcp(curr.key, path)
lp := len(pref)
keyTail := curr.key[lp:]
pathTail := path[lp:]
s1 := newSubTrie(keyTail[1:], curr.next)
b := NewBranchNode()
b.Children[keyTail[0]] = s1
i, pathTail := splitPath(pathTail)
s2 := newSubTrie(pathTail, val)
b.Children[i] = s2
if lp > 0 {
return NewExtensionNode(copySlice(pref), b), nil
}
return b, nil
}
// putIntoHash puts val to trie if current node is a HashNode.
// It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) {
if curr.IsEmpty() {
return newSubTrie(path, val), nil
}
result, err := t.getFromStore(curr.hash)
if err != nil {
return nil, err
}
return t.putIntoNode(result, path, val)
}
// newSubTrie create new trie containing node at provided path.
func newSubTrie(path []byte, val Node) Node {
if len(path) == 0 {
return val
}
return NewExtensionNode(path, val)
}
func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) {
switch n := curr.(type) {
case *LeafNode:
return t.putIntoLeaf(n, path, val)
case *BranchNode:
return t.putIntoBranch(n, path, val)
case *ExtensionNode:
return t.putIntoExtension(n, path, val)
case *HashNode:
return t.putIntoHash(n, path, val)
default:
panic("invalid MPT node type")
}
}
// Delete removes key from trie.
// It returns no error on missing key.
func (t *Trie) Delete(key []byte) error {
path := toNibbles(key)
r, err := t.deleteFromNode(t.root, path)
if err != nil {
return err
}
t.root = r
return nil
}
func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) {
i, path := splitPath(path)
r, err := t.deleteFromNode(b.Children[i], path)
if err != nil {
return nil, err
}
b.Children[i] = r
b.invalidateHash()
var count, index int
for i := range b.Children {
h, ok := b.Children[i].(*HashNode)
if !ok || !h.IsEmpty() {
index = i
count++
}
}
// count is >= 1 because branch node had at least 2 children before deletion.
if count > 1 {
return b, nil
}
c := b.Children[index]
if index == lastChild {
return c, nil
}
if h, ok := c.(*HashNode); ok {
c, err = t.getFromStore(h.Hash())
if err != nil {
return nil, err
}
}
if e, ok := c.(*ExtensionNode); ok {
e.key = append([]byte{byte(index)}, e.key...)
e.invalidateHash()
return e, nil
}
return NewExtensionNode([]byte{byte(index)}, c), nil
}
func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) {
if !bytes.HasPrefix(path, n.key) {
return nil, ErrNotFound
}
r, err := t.deleteFromNode(n.next, path[len(n.key):])
if err != nil {
return nil, err
}
switch nxt := r.(type) {
case *ExtensionNode:
n.key = append(n.key, nxt.key...)
n.next = nxt.next
n.invalidateHash()
case *HashNode:
if nxt.IsEmpty() {
return nxt, nil
}
default:
n.next = r
}
return n, nil
}
func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) {
switch n := curr.(type) {
case *LeafNode:
if len(path) == 0 {
return new(HashNode), nil
}
return nil, ErrNotFound
case *BranchNode:
return t.deleteFromBranch(n, path)
case *ExtensionNode:
return t.deleteFromExtension(n, path)
case *HashNode:
if n.IsEmpty() {
return nil, ErrNotFound
}
newNode, err := t.getFromStore(n.Hash())
if err != nil {
return nil, err
}
return t.deleteFromNode(newNode, path)
default:
panic("invalid MPT node type")
}
}
// StateRoot returns root hash of t.
func (t *Trie) StateRoot() util.Uint256 {
if hn, ok := t.root.(*HashNode); ok && hn.IsEmpty() {
return util.Uint256{}
}
return t.root.Hash()
}
func makeStorageKey(mptKey []byte) []byte {
return append([]byte{byte(storage.DataMPT)}, mptKey...)
}
// Flush puts every node in the trie except Hash ones to the storage.
// Because we care only about block-level changes, there is no need to put every
// new node to storage. Normally, flush should be called with every StateRoot persist, i.e.
// after every block.
func (t *Trie) Flush() {
t.flush(t.root)
}
func (t *Trie) flush(node Node) {
switch n := node.(type) {
case *BranchNode:
for i := range n.Children {
t.flush(n.Children[i])
}
case *ExtensionNode:
t.flush(n.next)
case *HashNode:
return
}
t.putToStore(node)
}
func (t *Trie) putToStore(n Node) {
if n.Type() == HashT {
panic("can't put hash node in trie")
}
bs := toBytes(n)
h := hash.DoubleSha256(bs)
_ = t.Store.Put(makeStorageKey(h.BytesBE()), bs) // put in MemCached returns no errors
}
func (t *Trie) getFromStore(h util.Uint256) (Node, error) {
data, err := t.Store.Get(makeStorageKey(h.BytesBE()))
if err != nil {
return nil, err
}
var n NodeObject
r := io.NewBinReaderFromBuf(data)
n.DecodeBinary(r)
if r.Err != nil {
return nil, r.Err
}
return n.Node, nil
}
// Collapse compresses all nodes at depth n to the hash nodes.
// Note: this function does not perform any kind of storage flushing so
// `Flush()` should be called explicitly before invoking function.
func (t *Trie) Collapse(depth int) {
if depth < 0 {
panic("negative depth")
}
t.root = collapse(depth, t.root)
}
func collapse(depth int, node Node) Node {
if _, ok := node.(*HashNode); ok {
return node
} else if depth == 0 {
return NewHashNode(node.Hash())
}
switch n := node.(type) {
case *BranchNode:
for i := range n.Children {
n.Children[i] = collapse(depth-1, n.Children[i])
}
case *ExtensionNode:
n.next = collapse(depth-1, n.next)
case *LeafNode:
case *HashNode:
default:
panic("invalid MPT node type")
}
return node
}

446
pkg/core/mpt/trie_test.go Normal file
View file

@ -0,0 +1,446 @@
package mpt
import (
"testing"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/stretchr/testify/require"
)
func newTestStore() *storage.MemCachedStore {
return storage.NewMemCachedStore(storage.NewMemoryStore())
}
func newTestTrie(t *testing.T) *Trie {
b := NewBranchNode()
l1 := NewLeafNode([]byte{0xAB, 0xCD})
b.Children[0] = NewExtensionNode([]byte{0x01}, l1)
l2 := NewLeafNode([]byte{0x22, 0x22})
b.Children[9] = NewExtensionNode([]byte{0x09}, l2)
v := NewLeafNode([]byte("hello"))
h := NewHashNode(v.Hash())
b.Children[10] = NewExtensionNode([]byte{0x0e}, h)
e := NewExtensionNode(toNibbles([]byte{0xAC}), b)
tr := NewTrie(e, newTestStore())
tr.putToStore(e)
tr.putToStore(b)
tr.putToStore(l1)
tr.putToStore(l2)
tr.putToStore(v)
tr.putToStore(b.Children[0])
tr.putToStore(b.Children[9])
tr.putToStore(b.Children[10])
return tr
}
func TestTrie_PutIntoBranchNode(t *testing.T) {
b := NewBranchNode()
l := NewLeafNode([]byte{0x8})
b.Children[0x7] = NewHashNode(l.Hash())
b.Children[0x8] = NewHashNode(random.Uint256())
tr := NewTrie(b, newTestStore())
// next
require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34}))
tr.testHas(t, []byte{}, []byte{0x12, 0x34})
// empty hash node child
require.NoError(t, tr.Put([]byte{0x66}, []byte{0x56}))
tr.testHas(t, []byte{0x66}, []byte{0x56})
require.True(t, isValid(tr.root))
// missing hash
require.Error(t, tr.Put([]byte{0x70}, []byte{0x42}))
require.True(t, isValid(tr.root))
// hash is in store
tr.putToStore(l)
require.NoError(t, tr.Put([]byte{0x70}, []byte{0x42}))
require.True(t, isValid(tr.root))
}
func TestTrie_PutIntoExtensionNode(t *testing.T) {
l := NewLeafNode([]byte{0x11})
key := []byte{0x12}
e := NewExtensionNode(toNibbles(key), NewHashNode(l.Hash()))
tr := NewTrie(e, newTestStore())
// missing hash
require.Error(t, tr.Put(key, []byte{0x42}))
tr.putToStore(l)
require.NoError(t, tr.Put(key, []byte{0x42}))
tr.testHas(t, key, []byte{0x42})
require.True(t, isValid(tr.root))
}
func TestTrie_PutIntoHashNode(t *testing.T) {
b := NewBranchNode()
l := NewLeafNode(random.Bytes(5))
e := NewExtensionNode([]byte{0x02}, l)
b.Children[1] = NewHashNode(e.Hash())
b.Children[9] = NewHashNode(random.Uint256())
tr := NewTrie(b, newTestStore())
tr.putToStore(e)
t.Run("MissingLeafHash", func(t *testing.T) {
_, err := tr.Get([]byte{0x12})
require.Error(t, err)
})
tr.putToStore(l)
val := random.Bytes(3)
require.NoError(t, tr.Put([]byte{0x12, 0x34}, val))
tr.testHas(t, []byte{0x12, 0x34}, val)
tr.testHas(t, []byte{0x12}, l.value)
require.True(t, isValid(tr.root))
}
func TestTrie_Put(t *testing.T) {
trExp := newTestTrie(t)
trAct := NewTrie(nil, newTestStore())
require.NoError(t, trAct.Put([]byte{0xAC, 0x01}, []byte{0xAB, 0xCD}))
require.NoError(t, trAct.Put([]byte{0xAC, 0x99}, []byte{0x22, 0x22}))
require.NoError(t, trAct.Put([]byte{0xAC, 0xAE}, []byte("hello")))
// Note: the exact tries differ because of ("acae":"hello") node is stored as Hash node in test trie.
require.Equal(t, trExp.root.Hash(), trAct.root.Hash())
require.True(t, isValid(trAct.root))
}
func TestTrie_PutInvalid(t *testing.T) {
tr := NewTrie(nil, newTestStore())
key, value := []byte("key"), []byte("value")
// big key
require.Error(t, tr.Put(make([]byte, MaxKeyLength+1), value))
// big value
require.Error(t, tr.Put(key, make([]byte, MaxValueLength+1)))
// this is ok though
require.NoError(t, tr.Put(key, value))
tr.testHas(t, key, value)
}
func TestTrie_BigPut(t *testing.T) {
tr := NewTrie(nil, newTestStore())
items := []struct{ k, v string }{
{"item with long key", "value1"},
{"item with matching prefix", "value2"},
{"another prefix", "value3"},
{"another prefix 2", "value4"},
{"another ", "value5"},
}
for i := range items {
require.NoError(t, tr.Put([]byte(items[i].k), []byte(items[i].v)))
}
for i := range items {
tr.testHas(t, []byte(items[i].k), []byte(items[i].v))
}
t.Run("Rewrite", func(t *testing.T) {
k, v := []byte(items[0].k), []byte{0x01, 0x23}
require.NoError(t, tr.Put(k, v))
tr.testHas(t, k, v)
})
t.Run("Remove", func(t *testing.T) {
k := []byte(items[1].k)
require.NoError(t, tr.Put(k, []byte{}))
tr.testHas(t, k, nil)
})
}
func (tr *Trie) testHas(t *testing.T, key, value []byte) {
v, err := tr.Get(key)
if value == nil {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, value, v)
}
// isValid checks for 3 invariants:
// - BranchNode contains > 1 children
// - ExtensionNode do not contain another extension node
// - ExtensionNode do not have nil key
// It is used only during testing to catch possible bugs.
func isValid(curr Node) bool {
switch n := curr.(type) {
case *BranchNode:
var count int
for i := range n.Children {
if !isValid(n.Children[i]) {
return false
}
hn, ok := n.Children[i].(*HashNode)
if !ok || !hn.IsEmpty() {
count++
}
}
return count > 1
case *ExtensionNode:
_, ok := n.next.(*ExtensionNode)
return len(n.key) != 0 && !ok
default:
return true
}
}
func TestTrie_Get(t *testing.T) {
t.Run("HashNode", func(t *testing.T) {
tr := newTestTrie(t)
tr.testHas(t, []byte{0xAC, 0xAE}, []byte("hello"))
})
t.Run("UnfoldRoot", func(t *testing.T) {
tr := newTestTrie(t)
single := NewTrie(NewHashNode(tr.root.Hash()), tr.Store)
single.testHas(t, []byte{0xAC}, nil)
single.testHas(t, []byte{0xAC, 0x01}, []byte{0xAB, 0xCD})
single.testHas(t, []byte{0xAC, 0x99}, []byte{0x22, 0x22})
single.testHas(t, []byte{0xAC, 0xAE}, []byte("hello"))
})
}
func TestTrie_Flush(t *testing.T) {
pairs := map[string][]byte{
"": []byte("value0"),
"key1": []byte("value1"),
"key2": []byte("value2"),
}
tr := NewTrie(nil, newTestStore())
for k, v := range pairs {
require.NoError(t, tr.Put([]byte(k), v))
}
tr.Flush()
tr = NewTrie(NewHashNode(tr.StateRoot()), tr.Store)
for k, v := range pairs {
actual, err := tr.Get([]byte(k))
require.NoError(t, err)
require.Equal(t, v, actual)
}
}
func TestTrie_Delete(t *testing.T) {
t.Run("Hash", func(t *testing.T) {
t.Run("FromStore", func(t *testing.T) {
l := NewLeafNode([]byte{0x12})
tr := NewTrie(NewHashNode(l.Hash()), newTestStore())
t.Run("NotInStore", func(t *testing.T) {
require.Error(t, tr.Delete([]byte{}))
})
tr.putToStore(l)
tr.testHas(t, []byte{}, []byte{0x12})
require.NoError(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{}, nil)
})
t.Run("Empty", func(t *testing.T) {
tr := NewTrie(nil, newTestStore())
require.Error(t, tr.Delete([]byte{}))
})
})
t.Run("Leaf", func(t *testing.T) {
l := NewLeafNode([]byte{0x12, 0x34})
tr := NewTrie(l, newTestStore())
t.Run("NonExistentKey", func(t *testing.T) {
require.Error(t, tr.Delete([]byte{0x12}))
tr.testHas(t, []byte{}, []byte{0x12, 0x34})
})
require.NoError(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{}, nil)
})
t.Run("Extension", func(t *testing.T) {
t.Run("SingleKey", func(t *testing.T) {
l := NewLeafNode([]byte{0x12, 0x34})
e := NewExtensionNode([]byte{0x0A, 0x0B}, l)
tr := NewTrie(e, newTestStore())
t.Run("NonExistentKey", func(t *testing.T) {
require.Error(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{0xAB}, []byte{0x12, 0x34})
})
require.NoError(t, tr.Delete([]byte{0xAB}))
require.True(t, tr.root.(*HashNode).IsEmpty())
})
t.Run("MultipleKeys", func(t *testing.T) {
b := NewBranchNode()
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x12, 0x34}))
b.Children[6] = NewExtensionNode([]byte{0x07}, NewLeafNode([]byte{0x56, 0x78}))
e := NewExtensionNode([]byte{0x01, 0x02}, b)
tr := NewTrie(e, newTestStore())
h := e.Hash()
require.NoError(t, tr.Delete([]byte{0x12, 0x01}))
tr.testHas(t, []byte{0x12, 0x01}, nil)
tr.testHas(t, []byte{0x12, 0x67}, []byte{0x56, 0x78})
require.NotEqual(t, h, tr.root.Hash())
require.Equal(t, toNibbles([]byte{0x12, 0x67}), e.key)
require.IsType(t, (*LeafNode)(nil), e.next)
})
})
t.Run("Branch", func(t *testing.T) {
t.Run("3 Children", func(t *testing.T) {
b := NewBranchNode()
b.Children[lastChild] = NewLeafNode([]byte{0x12})
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x34}))
b.Children[1] = NewExtensionNode([]byte{0x06}, NewLeafNode([]byte{0x56}))
tr := NewTrie(b, newTestStore())
require.NoError(t, tr.Delete([]byte{0x16}))
tr.testHas(t, []byte{}, []byte{0x12})
tr.testHas(t, []byte{0x01}, []byte{0x34})
tr.testHas(t, []byte{0x16}, nil)
})
t.Run("2 Children", func(t *testing.T) {
newt := func(t *testing.T) *Trie {
b := NewBranchNode()
b.Children[lastChild] = NewLeafNode([]byte{0x12})
l := NewLeafNode([]byte{0x34})
e := NewExtensionNode([]byte{0x06}, l)
b.Children[5] = NewHashNode(e.Hash())
tr := NewTrie(b, newTestStore())
tr.putToStore(l)
tr.putToStore(e)
return tr
}
t.Run("DeleteLast", func(t *testing.T) {
t.Run("MergeExtension", func(t *testing.T) {
tr := newt(t)
require.NoError(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{}, nil)
tr.testHas(t, []byte{0x56}, []byte{0x34})
require.IsType(t, (*ExtensionNode)(nil), tr.root)
})
t.Run("LeaveLeaf", func(t *testing.T) {
c := NewBranchNode()
c.Children[5] = NewLeafNode([]byte{0x05})
c.Children[6] = NewLeafNode([]byte{0x06})
b := NewBranchNode()
b.Children[lastChild] = NewLeafNode([]byte{0x12})
b.Children[5] = c
tr := NewTrie(b, newTestStore())
require.NoError(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{}, nil)
tr.testHas(t, []byte{0x55}, []byte{0x05})
tr.testHas(t, []byte{0x56}, []byte{0x06})
require.IsType(t, (*ExtensionNode)(nil), tr.root)
})
})
t.Run("DeleteMiddle", func(t *testing.T) {
tr := newt(t)
require.NoError(t, tr.Delete([]byte{0x56}))
tr.testHas(t, []byte{}, []byte{0x12})
tr.testHas(t, []byte{0x56}, nil)
require.IsType(t, (*LeafNode)(nil), tr.root)
})
})
})
}
func TestTrie_PanicInvalidRoot(t *testing.T) {
tr := &Trie{Store: newTestStore()}
require.Panics(t, func() { _ = tr.Put([]byte{1}, []byte{2}) })
require.Panics(t, func() { _, _ = tr.Get([]byte{1}) })
require.Panics(t, func() { _ = tr.Delete([]byte{1}) })
}
func TestTrie_Collapse(t *testing.T) {
t.Run("PanicNegative", func(t *testing.T) {
tr := newTestTrie(t)
require.Panics(t, func() { tr.Collapse(-1) })
})
t.Run("Depth=0", func(t *testing.T) {
tr := newTestTrie(t)
h := tr.root.Hash()
_, ok := tr.root.(*HashNode)
require.False(t, ok)
tr.Collapse(0)
_, ok = tr.root.(*HashNode)
require.True(t, ok)
require.Equal(t, h, tr.root.Hash())
})
t.Run("Branch,Depth=1", func(t *testing.T) {
b := NewBranchNode()
e := NewExtensionNode([]byte{0x01}, NewLeafNode([]byte("value1")))
he := e.Hash()
b.Children[0] = e
hb := b.Hash()
tr := NewTrie(b, newTestStore())
tr.Collapse(1)
newb, ok := tr.root.(*BranchNode)
require.True(t, ok)
require.Equal(t, hb, newb.Hash())
require.IsType(t, (*HashNode)(nil), b.Children[0])
require.Equal(t, he, b.Children[0].Hash())
})
t.Run("Extension,Depth=1", func(t *testing.T) {
l := NewLeafNode([]byte("value"))
hl := l.Hash()
e := NewExtensionNode([]byte{0x01}, l)
h := e.Hash()
tr := NewTrie(e, newTestStore())
tr.Collapse(1)
newe, ok := tr.root.(*ExtensionNode)
require.True(t, ok)
require.Equal(t, h, newe.Hash())
require.IsType(t, (*HashNode)(nil), newe.next)
require.Equal(t, hl, newe.next.Hash())
})
t.Run("Leaf", func(t *testing.T) {
l := NewLeafNode([]byte("value"))
tr := NewTrie(l, newTestStore())
tr.Collapse(10)
require.Equal(t, NewLeafNode([]byte("value")), tr.root)
})
t.Run("Hash", func(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
tr := NewTrie(new(HashNode), newTestStore())
require.NotPanics(t, func() { tr.Collapse(1) })
hn, ok := tr.root.(*HashNode)
require.True(t, ok)
require.True(t, hn.IsEmpty())
})
h := random.Uint256()
hn := NewHashNode(h)
tr := NewTrie(hn, newTestStore())
tr.Collapse(10)
newRoot, ok := tr.root.(*HashNode)
require.True(t, ok)
require.Equal(t, NewHashNode(h), newRoot)
})
}

View file

@ -9,6 +9,7 @@ import (
const (
DataBlock KeyPrefix = 0x01
DataTransaction KeyPrefix = 0x02
DataMPT KeyPrefix = 0x03
STAccount KeyPrefix = 0x40
STCoin KeyPrefix = 0x44
STSpentCoin KeyPrefix = 0x45