diff --git a/pkg/core/mpt/proof.go b/pkg/core/mpt/proof.go new file mode 100644 index 000000000..f785bd9d4 --- /dev/null +++ b/pkg/core/mpt/proof.go @@ -0,0 +1,74 @@ +package mpt + +import ( + "bytes" + + "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// GetProof returns a proof that key belongs to t. +// Proof consist of serialized nodes occuring on path from the root to the leaf of key. +func (t *Trie) GetProof(key []byte) ([][]byte, error) { + var proof [][]byte + path := toNibbles(key) + r, err := t.getProof(t.root, path, &proof) + if err != nil { + return proof, err + } + t.root = r + return proof, nil +} + +func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) { + switch n := curr.(type) { + case *LeafNode: + if len(path) == 0 { + *proofs = append(*proofs, toBytes(n)) + return n, nil + } + case *BranchNode: + *proofs = append(*proofs, toBytes(n)) + i, path := splitPath(path) + r, err := t.getProof(n.Children[i], path, proofs) + if err != nil { + return nil, err + } + n.Children[i] = r + return n, nil + case *ExtensionNode: + if bytes.HasPrefix(path, n.key) { + *proofs = append(*proofs, toBytes(n)) + r, err := t.getProof(n.next, path[len(n.key):], proofs) + if err != nil { + return nil, err + } + n.next = r + return n, nil + } + case *HashNode: + if !n.IsEmpty() { + r, err := t.getFromStore(n.Hash()) + if err != nil { + return nil, err + } + return t.getProof(r, path, proofs) + } + } + return nil, ErrNotFound +} + +// VerifyProof verifies that path indeed belongs to a MPT with the specified root hash. +// It also returns value for the key. +func VerifyProof(rh util.Uint256, key []byte, proofs [][]byte) ([]byte, bool) { + path := toNibbles(key) + tr := NewTrie(NewHashNode(rh), storage.NewMemCachedStore(storage.NewMemoryStore())) + for i := range proofs { + h := hash.DoubleSha256(proofs[i]) + // no errors in Put to memory store + _ = tr.Store.Put(makeStorageKey(h[:]), proofs[i]) + } + _, bs, err := tr.getWithPath(tr.root, path) + return bs, err == nil +} diff --git a/pkg/core/mpt/proof_test.go b/pkg/core/mpt/proof_test.go new file mode 100644 index 000000000..17301af15 --- /dev/null +++ b/pkg/core/mpt/proof_test.go @@ -0,0 +1,73 @@ +package mpt + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func newProofTrie(t *testing.T) *Trie { + l := NewLeafNode([]byte("somevalue")) + e := NewExtensionNode([]byte{0x05, 0x06, 0x07}, l) + l2 := NewLeafNode([]byte("invalid")) + e2 := NewExtensionNode([]byte{0x05}, NewHashNode(l2.Hash())) + b := NewBranchNode() + b.Children[4] = NewHashNode(e.Hash()) + b.Children[5] = e2 + + tr := NewTrie(b, newTestStore()) + require.NoError(t, tr.Put([]byte{0x12, 0x31}, []byte("value1"))) + require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2"))) + tr.putToStore(l) + tr.putToStore(e) + return tr +} + +func TestTrie_GetProof(t *testing.T) { + tr := newProofTrie(t) + + t.Run("MissingKey", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x12}) + require.Error(t, err) + }) + + t.Run("Valid", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x12, 0x31}) + require.NoError(t, err) + }) + + t.Run("MissingHashNode", func(t *testing.T) { + _, err := tr.GetProof([]byte{0x55}) + require.Error(t, err) + }) +} + +func TestVerifyProof(t *testing.T) { + tr := newProofTrie(t) + + t.Run("Simple", func(t *testing.T) { + proof, err := tr.GetProof([]byte{0x12, 0x32}) + require.NoError(t, err) + + t.Run("Good", func(t *testing.T) { + v, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x32}, proof) + require.True(t, ok) + require.Equal(t, []byte("value2"), v) + }) + + t.Run("Bad", func(t *testing.T) { + _, ok := VerifyProof(tr.root.Hash(), []byte{0x12, 0x31}, proof) + require.False(t, ok) + }) + }) + + t.Run("InsideHash", func(t *testing.T) { + key := []byte{0x45, 0x67} + proof, err := tr.GetProof(key) + require.NoError(t, err) + + v, ok := VerifyProof(tr.root.Hash(), key, proof) + require.True(t, ok) + require.Equal(t, []byte("somevalue"), v) + }) +}