mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-12-05 03:58:23 +00:00
mpt: implement JSON marshaling/unmarshaling
Because there is no distinct type field in JSONized nodes, distinction is made via payload itself, thus all unmarshaling is done via NodeObject.
This commit is contained in:
parent
31d9aeddd2
commit
9c478378e1
6 changed files with 234 additions and 10 deletions
|
@ -1,6 +1,9 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
|
@ -69,6 +72,23 @@ func (b *BranchNode) DecodeBinary(r *io.BinReader) {
|
|||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (b *BranchNode) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(b.Children)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (b *BranchNode) UnmarshalJSON(data []byte) error {
|
||||
var obj NodeObject
|
||||
if err := obj.UnmarshalJSON(data); err != nil {
|
||||
return err
|
||||
} else if u, ok := obj.Node.(*BranchNode); ok {
|
||||
*b = *u
|
||||
return nil
|
||||
}
|
||||
return errors.New("expected branch node")
|
||||
}
|
||||
|
||||
// splitPath splits path for a branch node.
|
||||
func splitPath(path []byte) (byte, []byte) {
|
||||
if len(path) != 0 {
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
|
@ -68,3 +71,24 @@ func (e ExtensionNode) EncodeBinary(w *io.BinWriter) {
|
|||
n := NewHashNode(e.next.Hash())
|
||||
n.EncodeBinary(w)
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (e *ExtensionNode) MarshalJSON() ([]byte, error) {
|
||||
m := map[string]interface{}{
|
||||
"key": hex.EncodeToString(e.key),
|
||||
"next": e.next,
|
||||
}
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (e *ExtensionNode) UnmarshalJSON(data []byte) error {
|
||||
var obj NodeObject
|
||||
if err := obj.UnmarshalJSON(data); err != nil {
|
||||
return err
|
||||
} else if u, ok := obj.Node.(*ExtensionNode); ok {
|
||||
*e = *u
|
||||
return nil
|
||||
}
|
||||
return errors.New("expected extension node")
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
|
@ -59,3 +60,23 @@ func (h HashNode) EncodeBinary(w *io.BinWriter) {
|
|||
}
|
||||
w.WriteVarBytes(h.hash[:])
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (h *HashNode) MarshalJSON() ([]byte, error) {
|
||||
if !h.valid {
|
||||
return []byte(`{}`), nil
|
||||
}
|
||||
return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (h *HashNode) UnmarshalJSON(data []byte) error {
|
||||
var obj NodeObject
|
||||
if err := obj.UnmarshalJSON(data); err != nil {
|
||||
return err
|
||||
} else if u, ok := obj.Node.(*HashNode); ok {
|
||||
*h = *u
|
||||
return nil
|
||||
}
|
||||
return errors.New("expected hash node")
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
|
@ -54,3 +56,20 @@ func (n *LeafNode) DecodeBinary(r *io.BinReader) {
|
|||
func (n LeafNode) EncodeBinary(w *io.BinWriter) {
|
||||
w.WriteVarBytes(n.value)
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (n *LeafNode) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`{"value":"` + hex.EncodeToString(n.value) + `"}`), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (n *LeafNode) UnmarshalJSON(data []byte) error {
|
||||
var obj NodeObject
|
||||
if err := obj.UnmarshalJSON(data); err != nil {
|
||||
return err
|
||||
} else if u, ok := obj.Node.(*LeafNode); ok {
|
||||
*n = *u
|
||||
return nil
|
||||
}
|
||||
return errors.New("expected leaf node")
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
|
@ -28,6 +31,8 @@ type NodeObject struct {
|
|||
// Node represents common interface of all MPT nodes.
|
||||
type Node interface {
|
||||
io.Serializable
|
||||
json.Marshaler
|
||||
json.Unmarshaler
|
||||
Hash() util.Uint256
|
||||
Type() NodeType
|
||||
}
|
||||
|
@ -68,3 +73,76 @@ func toBytes(n Node) []byte {
|
|||
encodeNodeWithType(n, buf.BinWriter)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (n *NodeObject) UnmarshalJSON(data []byte) error {
|
||||
var m map[string]json.RawMessage
|
||||
err := json.Unmarshal(data, &m)
|
||||
if err != nil { // it can be a branch node
|
||||
var nodes []NodeObject
|
||||
if err := json.Unmarshal(data, &nodes); err != nil {
|
||||
return err
|
||||
} else if len(nodes) != childrenCount {
|
||||
return errors.New("invalid length of branch node")
|
||||
}
|
||||
|
||||
b := NewBranchNode()
|
||||
for i := range b.Children {
|
||||
b.Children[i] = nodes[i].Node
|
||||
}
|
||||
n.Node = b
|
||||
return nil
|
||||
}
|
||||
|
||||
switch len(m) {
|
||||
case 0:
|
||||
n.Node = new(HashNode)
|
||||
case 1:
|
||||
if v, ok := m["hash"]; ok {
|
||||
var h util.Uint256
|
||||
if err := json.Unmarshal(v, &h); err != nil {
|
||||
return err
|
||||
}
|
||||
n.Node = NewHashNode(h)
|
||||
} else if v, ok = m["value"]; ok {
|
||||
b, err := unmarshalHex(v)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(b) > MaxValueLength {
|
||||
return errors.New("leaf value is too big")
|
||||
}
|
||||
n.Node = NewLeafNode(b)
|
||||
} else {
|
||||
return errors.New("invalid field")
|
||||
}
|
||||
case 2:
|
||||
keyRaw, ok1 := m["key"]
|
||||
nextRaw, ok2 := m["next"]
|
||||
if !ok1 || !ok2 {
|
||||
return errors.New("invalid field")
|
||||
}
|
||||
key, err := unmarshalHex(keyRaw)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(key) > MaxKeyLength {
|
||||
return errors.New("extension key is too big")
|
||||
}
|
||||
|
||||
var next NodeObject
|
||||
if err := json.Unmarshal(nextRaw, &next); err != nil {
|
||||
return err
|
||||
}
|
||||
n.Node = NewExtensionNode(key, next.Node)
|
||||
default:
|
||||
return errors.New("0, 1 or 2 fields expected")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalHex(data json.RawMessage) ([]byte, error) {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hex.DecodeString(s)
|
||||
}
|
||||
|
|
|
@ -1,26 +1,42 @@
|
|||
package mpt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
bs, err := testserdes.EncodeBinary(expected)
|
||||
require.NoError(t, err)
|
||||
err = testserdes.DecodeBinary(bs, actual)
|
||||
if !ok {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected.Type(), actual.Type())
|
||||
require.Equal(t, expected.Hash(), actual.Hash())
|
||||
t.Run("IO", func(t *testing.T) {
|
||||
bs, err := testserdes.EncodeBinary(expected)
|
||||
require.NoError(t, err)
|
||||
err = testserdes.DecodeBinary(bs, actual)
|
||||
if !ok {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected.Type(), actual.Type())
|
||||
require.Equal(t, expected.Hash(), actual.Hash())
|
||||
})
|
||||
t.Run("JSON", func(t *testing.T) {
|
||||
bs, err := json.Marshal(expected)
|
||||
require.NoError(t, err)
|
||||
err = json.Unmarshal(bs, actual)
|
||||
if !ok {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected.Type(), actual.Type())
|
||||
require.Equal(t, expected.Hash(), actual.Hash())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -74,6 +90,52 @@ func TestNode_Serializable(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198
|
||||
func TestJSONSharp(t *testing.T) {
|
||||
tr := NewTrie(nil, newTestStore())
|
||||
require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11}))
|
||||
require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22}))
|
||||
require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac}))
|
||||
require.NoError(t, tr.Delete([]byte{0xac, 0x11}))
|
||||
require.NoError(t, tr.Delete([]byte{0xac, 0x22}))
|
||||
|
||||
js, err := tr.root.MarshalJSON()
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{"key":"0a0c", "next":{"value":"ac"}}`, string(js))
|
||||
}
|
||||
|
||||
func TestInvalidJSON(t *testing.T) {
|
||||
t.Run("InvalidChildrenCount", func(t *testing.T) {
|
||||
var cs [childrenCount + 1]Node
|
||||
for i := range cs {
|
||||
cs[i] = new(HashNode)
|
||||
}
|
||||
data, err := json.Marshal(cs)
|
||||
require.NoError(t, err)
|
||||
|
||||
var n NodeObject
|
||||
require.Error(t, json.Unmarshal(data, &n))
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
data []byte
|
||||
}{
|
||||
{"WrongFieldCount", []byte(`{"key":"0102", "next": {}, "field": {}}`)},
|
||||
{"InvalidField1", []byte(`{"next":{}}`)},
|
||||
{"InvalidField2", []byte(`{"key":"0102", "hash":{}}`)},
|
||||
{"InvalidKey", []byte(`{"key":"xy", "next":{}}`)},
|
||||
{"InvalidNext", []byte(`{"key":"01", "next":[]}`)},
|
||||
{"InvalidHash", []byte(`{"hash":"01"}`)},
|
||||
{"InvalidValue", []byte(`{"value":1}`)},
|
||||
{"InvalidBranch", []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]`)},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
var n NodeObject
|
||||
assert.Errorf(t, json.Unmarshal(tc.data, &n), "no error in "+tc.name)
|
||||
}
|
||||
}
|
||||
|
||||
// C# interoperability test
|
||||
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135
|
||||
func TestRootHash(t *testing.T) {
|
||||
|
|
Loading…
Reference in a new issue