diff --git a/pkg/core/mpt/batch_test.go b/pkg/core/mpt/batch_test.go index 9438b8e2b..e8233ee82 100644 --- a/pkg/core/mpt/batch_test.go +++ b/pkg/core/mpt/batch_test.go @@ -70,26 +70,26 @@ func TestTrie_PutBatchLeaf(t *testing.T) { prepareLeaf := func(t *testing.T) (*Trie, *Trie) { tr1 := NewTrie(new(HashNode), false, newTestStore()) tr2 := NewTrie(new(HashNode), false, newTestStore()) - require.NoError(t, tr1.Put([]byte{}, []byte("value"))) - require.NoError(t, tr2.Put([]byte{}, []byte("value"))) + require.NoError(t, tr1.Put([]byte{0}, []byte("value"))) + require.NoError(t, tr2.Put([]byte{0}, []byte("value"))) return tr1, tr2 } t.Run("remove", func(t *testing.T) { tr1, tr2 := prepareLeaf(t) - var ps = pairs{{[]byte{}, nil}} + var ps = pairs{{[]byte{0}, nil}} testPut(t, ps, tr1, tr2) }) t.Run("replace", func(t *testing.T) { tr1, tr2 := prepareLeaf(t) - var ps = pairs{{[]byte{}, []byte("replace")}} + var ps = pairs{{[]byte{0}, []byte("replace")}} testPut(t, ps, tr1, tr2) }) t.Run("remove and replace", func(t *testing.T) { tr1, tr2 := prepareLeaf(t) var ps = pairs{ - {[]byte{}, nil}, - {[]byte{2}, []byte("replace2")}, + {[]byte{0}, nil}, + {[]byte{0, 2}, []byte("replace2")}, } testPut(t, ps, tr1, tr2) }) @@ -122,7 +122,7 @@ func TestTrie_PutBatchExtension(t *testing.T) { t.Run("add to next with leaf", func(t *testing.T) { tr1, tr2 := prepareExtension(t) var ps = pairs{ - {[]byte{}, []byte("value3")}, + {[]byte{0}, []byte("value3")}, {[]byte{1, 2, 3}, []byte("value2")}, } testPut(t, ps, tr1, tr2) @@ -232,7 +232,7 @@ func TestTrie_PutBatchEmpty(t *testing.T) { tr1 := NewTrie(new(HashNode), false, newTestStore()) tr2 := NewTrie(new(HashNode), false, newTestStore()) var ps = pairs{ - {[]byte{}, []byte("value0")}, + {[]byte{0}, []byte("value0")}, {[]byte{1}, []byte("value1")}, {[]byte{3}, []byte("value3")}, } @@ -240,7 +240,7 @@ func TestTrie_PutBatchEmpty(t *testing.T) { }) t.Run("incomplete", func(t *testing.T) { var ps = pairs{ - {[]byte{}, []byte("replace0")}, + {[]byte{0}, []byte("replace0")}, {[]byte{1}, []byte("replace1")}, {[]byte{2}, nil}, {[]byte{3}, []byte("replace3")}, diff --git a/pkg/core/mpt/compat_test.go b/pkg/core/mpt/compat_test.go new file mode 100644 index 000000000..1581e64e9 --- /dev/null +++ b/pkg/core/mpt/compat_test.go @@ -0,0 +1,360 @@ +package mpt + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func prepareMPTCompat() *Trie { + b := NewBranchNode() + r := NewExtensionNode([]byte{0x0a, 0x0c}, b) + v1 := NewLeafNode([]byte{0xab, 0xcd}) //key=ac01 + v2 := NewLeafNode([]byte{0x22, 0x22}) //key=ac + v3 := NewLeafNode([]byte("existing")) //key=acae + v4 := NewLeafNode([]byte("missing")) + h3 := NewHashNode(v3.Hash()) + e1 := NewExtensionNode([]byte{0x01}, v1) + e3 := NewExtensionNode([]byte{0x0e}, h3) + e4 := NewExtensionNode([]byte{0x01}, v4) + b.Children[0] = e1 + b.Children[10] = e3 + b.Children[16] = v2 + b.Children[15] = NewHashNode(e4.Hash()) + + tr := NewTrie(r, true, newTestStore()) + tr.putToStore(r) + tr.putToStore(b) + tr.putToStore(e1) + tr.putToStore(e3) + tr.putToStore(v1) + tr.putToStore(v2) + tr.putToStore(v3) + + return tr +} + +// TestCompatibility contains tests present in C# implementation. +// https://github.com/neo-project/neo-modules/blob/master/tests/Neo.Plugins.StateService.Tests/MPT/UT_MPTTrie.cs +// There are some differences, though: +// 1. In our implementation delete is silent, i.e. we do not return an error is the key is missing. +// However, we do return error when contents of hash node are missing from the store +// (corresponds to exception in C# implementation). +// 2. If `GetProof` key is missing from the trie, we return error, while C# node just returns empty proof +// with no exception. +func TestCompatibility(t *testing.T) { + mainTrie := prepareMPTCompat() + + t.Run("TryGet", func(t *testing.T) { + tr := copyTrie(mainTrie) + tr.testHas(t, []byte{0xac, 0x01}, []byte{0xab, 0xcd}) + tr.testHas(t, []byte{0xac}, []byte{0x22, 0x22}) + tr.testHas(t, []byte{0xab, 0x99}, nil) + tr.testHas(t, []byte{0xac, 0x39}, nil) + tr.testHas(t, []byte{0xac, 0x02}, nil) + tr.testHas(t, []byte{0xac, 0x01, 0x00}, nil) + tr.testHas(t, []byte{0xac, 0x99, 0x10}, nil) + tr.testHas(t, []byte{0xac, 0xf1}, nil) + }) + + t.Run("TryGetResolve", func(t *testing.T) { + tr := copyTrie(mainTrie) + tr.testHas(t, []byte{0xac, 0xae}, []byte("existing")) + }) + + t.Run("TryPut", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0xac, 0x01}, []byte{0xab, 0xcd}, + []byte{0xac}, []byte{0x22, 0x22}, + []byte{0xac, 0xae}, []byte("existing"), + []byte{0xac, 0xf1}, []byte("missing")) + + require.Equal(t, mainTrie.root.Hash(), tr.root.Hash()) + require.Error(t, tr.Put(nil, []byte{0x01})) + require.NoError(t, tr.Put([]byte{0x01}, nil)) + require.Error(t, tr.Put(make([]byte, MaxKeyLength+1), nil)) + require.Error(t, tr.Put([]byte{0x01}, make([]byte, MaxValueLength+1))) + require.Equal(t, mainTrie.root.Hash(), tr.root.Hash()) + require.NoError(t, tr.Put([]byte{0xac, 0x01}, []byte{0xab})) + }) + + t.Run("PutCantResolve", func(t *testing.T) { + tr := copyTrie(mainTrie) + require.Error(t, tr.Put([]byte{0xac, 0xf1, 0x11}, []byte{1})) + }) + + t.Run("TryDelete", func(t *testing.T) { + tr := copyTrie(mainTrie) + tr.testHas(t, []byte{0xac}, []byte{0x22, 0x22}) + require.NoError(t, tr.Delete([]byte{0x0c, 0x99})) + require.NoError(t, tr.Delete(nil)) + require.NoError(t, tr.Delete([]byte{0xac, 0x20})) + + require.Error(t, tr.Delete([]byte{0xac, 0xf1})) // error for can't resolve + + // In our implementation missing keys are ignored. + require.NoError(t, tr.Delete([]byte{0xac})) + require.NoError(t, tr.Delete([]byte{0xac, 0xae, 0x01})) + require.NoError(t, tr.Delete([]byte{0xac, 0xae})) + + require.Equal(t, "cb06925428b7c727375c7fdd943a302fe2c818cf2e2eaf63a7932e3fd6cb3408", + tr.root.Hash().StringLE()) + }) + + t.Run("DeleteRemainCanResolve", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0xac, 0x00}, []byte{0xab, 0xcd}, + []byte{0xac, 0x10}, []byte{0xab, 0xcd}) + tr.Flush() + + tr2 := copyTrie(tr) + require.NoError(t, tr2.Delete([]byte{0xac, 0x00})) + + tr2.Flush() + require.NoError(t, tr2.Delete([]byte{0xac, 0x10})) + }) + + t.Run("DeleteRemainCantResolve", func(t *testing.T) { + b := NewBranchNode() + r := NewExtensionNode([]byte{0x0a, 0x0c}, b) + v1 := NewLeafNode([]byte{0xab, 0xcd}) + v4 := NewLeafNode([]byte("missing")) + e1 := NewExtensionNode([]byte{0x01}, v1) + e4 := NewExtensionNode([]byte{0x01}, v4) + b.Children[0] = e1 + b.Children[15] = NewHashNode(e4.Hash()) + + tr := NewTrie(NewHashNode(r.Hash()), false, newTestStore()) + tr.putToStore(r) + tr.putToStore(b) + tr.putToStore(e1) + tr.putToStore(v1) + + require.Error(t, tr.Delete([]byte{0xac, 0x01})) + }) + + t.Run("DeleteSameValue", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0xac, 0x01}, []byte{0xab, 0xcd}, + []byte{0xac, 0x02}, []byte{0xab, 0xcd}) + tr.testHas(t, []byte{0xac, 0x01}, []byte{0xab, 0xcd}) + tr.testHas(t, []byte{0xac, 0x02}, []byte{0xab, 0xcd}) + + require.NoError(t, tr.Delete([]byte{0xac, 0x01})) + tr.testHas(t, []byte{0xac, 0x02}, []byte{0xab, 0xcd}) + tr.Flush() + + tr2 := NewTrie(NewHashNode(tr.root.Hash()), false, tr.Store) + tr2.testHas(t, []byte{0xac, 0x02}, []byte{0xab, 0xcd}) + }) + + t.Run("BranchNodeRemainValue", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0xac, 0x11}, []byte{0xac, 0x11}, + []byte{0xac, 0x22}, []byte{0xac, 0x22}, + []byte{0xac}, []byte{0xac}) + tr.Flush() + checkBatchSize(t, tr, 7) + + require.NoError(t, tr.Delete([]byte{0xac, 0x11})) + tr.Flush() + checkBatchSize(t, tr, 5) + + require.NoError(t, tr.Delete([]byte{0xac, 0x22})) + tr.Flush() + checkBatchSize(t, tr, 2) + }) + + t.Run("GetProof", func(t *testing.T) { + b := NewBranchNode() + r := NewExtensionNode([]byte{0x0a, 0x0c}, b) + v1 := NewLeafNode([]byte{0xab, 0xcd}) //key=ac01 + v2 := NewLeafNode([]byte{0x22, 0x22}) //key=ac + v3 := NewLeafNode([]byte("existing")) //key=acae + v4 := NewLeafNode([]byte("missing")) + h3 := NewHashNode(v3.Hash()) + e1 := NewExtensionNode([]byte{0x01}, v1) + e3 := NewExtensionNode([]byte{0x0e}, h3) + e4 := NewExtensionNode([]byte{0x01}, v4) + b.Children[0] = e1 + b.Children[10] = e3 + b.Children[16] = v2 + b.Children[15] = NewHashNode(e4.Hash()) + + tr := NewTrie(NewHashNode(r.Hash()), true, mainTrie.Store) + require.Equal(t, r.Hash(), tr.root.Hash()) + + // Tail bytes contain reference counter thus check for prefix. + proof := testGetProof(t, tr, []byte{0xac, 0x01}, 4) + require.True(t, bytes.HasPrefix(r.Bytes(), proof[0])) + require.True(t, bytes.HasPrefix(b.Bytes(), proof[1])) + require.True(t, bytes.HasPrefix(e1.Bytes(), proof[2])) + require.True(t, bytes.HasPrefix(v1.Bytes(), proof[3])) + + testGetProof(t, tr, []byte{0xac}, 3) + testGetProof(t, tr, []byte{0xac, 0x10}, 0) + testGetProof(t, tr, []byte{0xac, 0xae}, 4) + testGetProof(t, tr, nil, 0) + testGetProof(t, tr, []byte{0xac, 0x01, 0x00}, 0) + testGetProof(t, tr, []byte{0xac, 0xf1}, 0) + }) + + t.Run("VerifyProof", func(t *testing.T) { + tr := copyTrie(mainTrie) + proof := testGetProof(t, tr, []byte{0xac, 0x01}, 4) + value, ok := VerifyProof(tr.root.Hash(), []byte{0xac, 0x01}, proof) + require.True(t, ok) + require.Equal(t, []byte{0xab, 0xcd}, value) + }) + + t.Run("AddLongerKey", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0xab}, []byte{0x01}, + []byte{0xab, 0xcd}, []byte{0x02}) + tr.testHas(t, []byte{0xab}, []byte{0x01}) + }) + + t.Run("SplitKey", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0xab, 0xcd}, []byte{0x01}, + []byte{0xab}, []byte{0x02}) + testGetProof(t, tr, []byte{0xab, 0xcd}, 4) + + tr2 := newFilledTrie(t, + []byte{0xab}, []byte{0x02}, + []byte{0xab, 0xcd}, []byte{0x01}) + testGetProof(t, tr, []byte{0xab, 0xcd}, 4) + + require.Equal(t, tr.root.Hash(), tr2.root.Hash()) + }) + + t.Run("Reference", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0xa1, 0x01}, []byte{0x01}, + []byte{0xa2, 0x01}, []byte{0x01}, + []byte{0xa3, 0x01}, []byte{0x01}) + tr.Flush() + + tr2 := copyTrie(tr) + require.NoError(t, tr2.Delete([]byte{0xa3, 0x01})) + tr2.Flush() + + tr3 := copyTrie(tr2) + require.NoError(t, tr3.Delete([]byte{0xa2, 0x01})) + tr3.testHas(t, []byte{0xa1, 0x01}, []byte{0x01}) + }) + + t.Run("Reference2", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0xa1, 0x01}, []byte{0x01}, + []byte{0xa2, 0x01}, []byte{0x01}, + []byte{0xa3, 0x01}, []byte{0x01}) + tr.Flush() + checkBatchSize(t, tr, 4) + + require.NoError(t, tr.Delete([]byte{0xa3, 0x01})) + tr.Flush() + checkBatchSize(t, tr, 4) + + require.NoError(t, tr.Delete([]byte{0xa2, 0x01})) + tr.Flush() + checkBatchSize(t, tr, 2) + tr.testHas(t, []byte{0xa1, 0x01}, []byte{0x01}) + }) + + t.Run("ExtensionDeleteDirty", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0xa1}, []byte{0x01}, + []byte{0xa2}, []byte{0x02}) + tr.Flush() + checkBatchSize(t, tr, 4) + + tr1 := copyTrie(tr) + require.NoError(t, tr1.Delete([]byte{0xa1})) + tr1.Flush() + require.Equal(t, 2, len(tr1.Store.GetBatch().Put)) + + tr2 := copyTrie(tr1) + require.NoError(t, tr2.Delete([]byte{0xa2})) + tr2.Flush() + require.Equal(t, 0, len(tr2.Store.GetBatch().Put)) + }) + + t.Run("BranchDeleteDirty", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0x10}, []byte{0x01}, + []byte{0x20}, []byte{0x02}, + []byte{0x30}, []byte{0x03}) + tr.Flush() + checkBatchSize(t, tr, 7) + + tr1 := copyTrie(tr) + require.NoError(t, tr1.Delete([]byte{0x10})) + tr1.Flush() + + tr2 := copyTrie(tr1) + require.NoError(t, tr2.Delete([]byte{0x20})) + tr2.Flush() + require.Equal(t, 2, len(tr2.Store.GetBatch().Put)) + + tr3 := copyTrie(tr2) + require.NoError(t, tr3.Delete([]byte{0x30})) + tr3.Flush() + require.Equal(t, 0, len(tr3.Store.GetBatch().Put)) + }) + + t.Run("ExtensionPutDirty", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0xa1}, []byte{0x01}, + []byte{0xa2}, []byte{0x02}) + tr.Flush() + checkBatchSize(t, tr, 4) + + tr1 := copyTrie(tr) + require.NoError(t, tr1.Put([]byte{0xa3}, []byte{0x03})) + tr1.Flush() + require.Equal(t, 5, len(tr1.Store.GetBatch().Put)) + }) + + t.Run("BranchPutDirty", func(t *testing.T) { + tr := newFilledTrie(t, + []byte{0x10}, []byte{0x01}, + []byte{0x20}, []byte{0x02}) + tr.Flush() + checkBatchSize(t, tr, 5) + + tr1 := copyTrie(tr) + require.NoError(t, tr1.Put([]byte{0x30}, []byte{0x03})) + tr1.Flush() + checkBatchSize(t, tr1, 7) + }) +} + +func copyTrie(t *Trie) *Trie { + return NewTrie(NewHashNode(t.root.Hash()), t.refcountEnabled, t.Store) +} + +func checkBatchSize(t *testing.T, tr *Trie, n int) { + require.Equal(t, n, len(tr.Store.GetBatch().Put)) +} + +func testGetProof(t *testing.T, tr *Trie, key []byte, size int) [][]byte { + proof, err := tr.GetProof(key) + if size == 0 { + require.Error(t, err) + return proof + } + + require.NoError(t, err) + require.Equal(t, size, len(proof)) + return proof +} + +func newFilledTrie(t *testing.T, args ...[]byte) *Trie { + tr := NewTrie(nil, true, newTestStore()) + for i := 0; i < len(args); i += 2 { + require.NoError(t, tr.Put(args[i], args[i+1])) + } + return tr +} diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go index 026201655..7cb0bb7f0 100644 --- a/pkg/core/mpt/extension.go +++ b/pkg/core/mpt/extension.go @@ -11,8 +11,14 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util" ) -// MaxKeyLength is the max length of the extension node key. -const MaxKeyLength = (storage.MaxStorageKeyLen + 4) * 2 +const ( + // maxPathLength is the max length of the extension node key. + maxPathLength = (storage.MaxStorageKeyLen + 4) * 2 + + // MaxKeyLength is the max length of the key to put in trie + // before transforming to nibbles. + MaxKeyLength = maxPathLength / 2 +) // ExtensionNode represents MPT's extension node. type ExtensionNode struct { @@ -48,7 +54,7 @@ func (e *ExtensionNode) Bytes() []byte { // DecodeBinary implements io.Serializable. func (e *ExtensionNode) DecodeBinary(r *io.BinReader) { sz := r.ReadVarUint() - if sz > MaxKeyLength { + if sz > maxPathLength { r.Err = fmt.Errorf("extension node key is too big: %d", sz) return } diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go index 2d2c42807..5f8aed265 100644 --- a/pkg/core/mpt/node.go +++ b/pkg/core/mpt/node.go @@ -96,7 +96,7 @@ func (n *NodeObject) UnmarshalJSON(data []byte) error { key, err := unmarshalHex(keyRaw) if err != nil { return err - } else if len(key) > MaxKeyLength { + } else if len(key) > maxPathLength { return errors.New("extension key is too big") } diff --git a/pkg/core/mpt/node_test.go b/pkg/core/mpt/node_test.go index 74c0247a2..0af0509a9 100644 --- a/pkg/core/mpt/node_test.go +++ b/pkg/core/mpt/node_test.go @@ -61,7 +61,7 @@ func TestNode_Serializable(t *testing.T) { t.Run("WithType", getTestFuncEncode(true, &NodeObject{e}, new(NodeObject))) }) t.Run("BigKey", getTestFuncEncode(false, - NewExtensionNode(random.Bytes(MaxKeyLength+1), NewLeafNode(random.Bytes(10))), new(ExtensionNode))) + NewExtensionNode(random.Bytes(maxPathLength+1), NewLeafNode(random.Bytes(10))), new(ExtensionNode))) }) t.Run("Branch", func(t *testing.T) { diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go index 6a0335273..6f4b452a4 100644 --- a/pkg/core/mpt/trie.go +++ b/pkg/core/mpt/trie.go @@ -97,7 +97,9 @@ func (t *Trie) getWithPath(curr Node, path []byte) (Node, []byte, error) { // Put puts key-value pair in t. func (t *Trie) Put(key, value []byte) error { - if len(key) > MaxKeyLength { + if len(key) == 0 { + return errors.New("key is empty") + } else if len(key) > MaxKeyLength { return errors.New("key is too big") } else if len(value) > MaxValueLength { return errors.New("value is too big") diff --git a/pkg/core/mpt/trie_test.go b/pkg/core/mpt/trie_test.go index bbb3b9828..18d0d43b6 100644 --- a/pkg/core/mpt/trie_test.go +++ b/pkg/core/mpt/trie_test.go @@ -85,10 +85,6 @@ func TestTrie_PutIntoBranchNode(t *testing.T) { b.Children[0x8] = NewHashNode(random.Uint256()) tr := NewTrie(b, false, newTestStore()) - // next - require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34})) - tr.testHas(t, []byte{}, []byte{0x12, 0x34}) - // empty hash node child require.NoError(t, tr.Put([]byte{0x66}, []byte{0x56})) tr.testHas(t, []byte{0x66}, []byte{0x56}) @@ -160,8 +156,11 @@ func TestTrie_PutInvalid(t *testing.T) { tr := NewTrie(nil, false, newTestStore()) key, value := []byte("key"), []byte("value") + // empty key + require.Error(t, tr.Put(nil, value)) + // big key - require.Error(t, tr.Put(make([]byte, MaxKeyLength+1), value)) + require.Error(t, tr.Put(make([]byte, maxPathLength+1), value)) // big value require.Error(t, tr.Put(key, make([]byte, MaxValueLength+1))) @@ -271,7 +270,7 @@ func TestTrie_Get(t *testing.T) { func TestTrie_Flush(t *testing.T) { pairs := map[string][]byte{ - "": []byte("value0"), + "x": []byte("value0"), "key1": []byte("value1"), "key2": []byte("value2"), }