mpt: implement JSON marshaling/unmarshaling

Because there is no distinct type field in JSONized nodes, distinction
is made via payload itself, thus all unmarshaling is done via
NodeObject.
This commit is contained in:
Evgenii Stratonikov 2020-05-28 08:55:12 +03:00
parent 31d9aeddd2
commit 9c478378e1
6 changed files with 234 additions and 10 deletions

View file

@ -1,6 +1,9 @@
package mpt package mpt
import ( import (
"encoding/json"
"errors"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
@ -69,6 +72,23 @@ func (b *BranchNode) DecodeBinary(r *io.BinReader) {
} }
} }
// MarshalJSON implements json.Marshaler.
func (b *BranchNode) MarshalJSON() ([]byte, error) {
return json.Marshal(b.Children)
}
// UnmarshalJSON implements json.Unmarshaler.
func (b *BranchNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*BranchNode); ok {
*b = *u
return nil
}
return errors.New("expected branch node")
}
// splitPath splits path for a branch node. // splitPath splits path for a branch node.
func splitPath(path []byte) (byte, []byte) { func splitPath(path []byte) (byte, []byte) {
if len(path) != 0 { if len(path) != 0 {

View file

@ -1,6 +1,9 @@
package mpt package mpt
import ( import (
"encoding/hex"
"encoding/json"
"errors"
"fmt" "fmt"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
@ -68,3 +71,24 @@ func (e ExtensionNode) EncodeBinary(w *io.BinWriter) {
n := NewHashNode(e.next.Hash()) n := NewHashNode(e.next.Hash())
n.EncodeBinary(w) n.EncodeBinary(w)
} }
// MarshalJSON implements json.Marshaler.
func (e *ExtensionNode) MarshalJSON() ([]byte, error) {
m := map[string]interface{}{
"key": hex.EncodeToString(e.key),
"next": e.next,
}
return json.Marshal(m)
}
// UnmarshalJSON implements json.Unmarshaler.
func (e *ExtensionNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*ExtensionNode); ok {
*e = *u
return nil
}
return errors.New("expected extension node")
}

View file

@ -1,6 +1,7 @@
package mpt package mpt
import ( import (
"errors"
"fmt" "fmt"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
@ -59,3 +60,23 @@ func (h HashNode) EncodeBinary(w *io.BinWriter) {
} }
w.WriteVarBytes(h.hash[:]) w.WriteVarBytes(h.hash[:])
} }
// MarshalJSON implements json.Marshaler.
func (h *HashNode) MarshalJSON() ([]byte, error) {
if !h.valid {
return []byte(`{}`), nil
}
return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (h *HashNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*HashNode); ok {
*h = *u
return nil
}
return errors.New("expected hash node")
}

View file

@ -1,6 +1,8 @@
package mpt package mpt
import ( import (
"encoding/hex"
"errors"
"fmt" "fmt"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
@ -54,3 +56,20 @@ func (n *LeafNode) DecodeBinary(r *io.BinReader) {
func (n LeafNode) EncodeBinary(w *io.BinWriter) { func (n LeafNode) EncodeBinary(w *io.BinWriter) {
w.WriteVarBytes(n.value) w.WriteVarBytes(n.value)
} }
// MarshalJSON implements json.Marshaler.
func (n *LeafNode) MarshalJSON() ([]byte, error) {
return []byte(`{"value":"` + hex.EncodeToString(n.value) + `"}`), nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (n *LeafNode) UnmarshalJSON(data []byte) error {
var obj NodeObject
if err := obj.UnmarshalJSON(data); err != nil {
return err
} else if u, ok := obj.Node.(*LeafNode); ok {
*n = *u
return nil
}
return errors.New("expected leaf node")
}

View file

@ -1,6 +1,9 @@
package mpt package mpt
import ( import (
"encoding/hex"
"encoding/json"
"errors"
"fmt" "fmt"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
@ -28,6 +31,8 @@ type NodeObject struct {
// Node represents common interface of all MPT nodes. // Node represents common interface of all MPT nodes.
type Node interface { type Node interface {
io.Serializable io.Serializable
json.Marshaler
json.Unmarshaler
Hash() util.Uint256 Hash() util.Uint256
Type() NodeType Type() NodeType
} }
@ -68,3 +73,76 @@ func toBytes(n Node) []byte {
encodeNodeWithType(n, buf.BinWriter) encodeNodeWithType(n, buf.BinWriter)
return buf.Bytes() return buf.Bytes()
} }
// UnmarshalJSON implements json.Unmarshaler.
func (n *NodeObject) UnmarshalJSON(data []byte) error {
var m map[string]json.RawMessage
err := json.Unmarshal(data, &m)
if err != nil { // it can be a branch node
var nodes []NodeObject
if err := json.Unmarshal(data, &nodes); err != nil {
return err
} else if len(nodes) != childrenCount {
return errors.New("invalid length of branch node")
}
b := NewBranchNode()
for i := range b.Children {
b.Children[i] = nodes[i].Node
}
n.Node = b
return nil
}
switch len(m) {
case 0:
n.Node = new(HashNode)
case 1:
if v, ok := m["hash"]; ok {
var h util.Uint256
if err := json.Unmarshal(v, &h); err != nil {
return err
}
n.Node = NewHashNode(h)
} else if v, ok = m["value"]; ok {
b, err := unmarshalHex(v)
if err != nil {
return err
} else if len(b) > MaxValueLength {
return errors.New("leaf value is too big")
}
n.Node = NewLeafNode(b)
} else {
return errors.New("invalid field")
}
case 2:
keyRaw, ok1 := m["key"]
nextRaw, ok2 := m["next"]
if !ok1 || !ok2 {
return errors.New("invalid field")
}
key, err := unmarshalHex(keyRaw)
if err != nil {
return err
} else if len(key) > MaxKeyLength {
return errors.New("extension key is too big")
}
var next NodeObject
if err := json.Unmarshal(nextRaw, &next); err != nil {
return err
}
n.Node = NewExtensionNode(key, next.Node)
default:
return errors.New("0, 1 or 2 fields expected")
}
return nil
}
func unmarshalHex(data json.RawMessage) ([]byte, error) {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return nil, err
}
return hex.DecodeString(s)
}

View file

@ -1,26 +1,42 @@
package mpt package mpt
import ( import (
"encoding/json"
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) { func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) {
return func(t *testing.T) { return func(t *testing.T) {
bs, err := testserdes.EncodeBinary(expected) t.Run("IO", func(t *testing.T) {
require.NoError(t, err) bs, err := testserdes.EncodeBinary(expected)
err = testserdes.DecodeBinary(bs, actual) require.NoError(t, err)
if !ok { err = testserdes.DecodeBinary(bs, actual)
require.Error(t, err) if !ok {
return require.Error(t, err)
} return
require.NoError(t, err) }
require.Equal(t, expected.Type(), actual.Type()) require.NoError(t, err)
require.Equal(t, expected.Hash(), actual.Hash()) require.Equal(t, expected.Type(), actual.Type())
require.Equal(t, expected.Hash(), actual.Hash())
})
t.Run("JSON", func(t *testing.T) {
bs, err := json.Marshal(expected)
require.NoError(t, err)
err = json.Unmarshal(bs, actual)
if !ok {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, expected.Type(), actual.Type())
require.Equal(t, expected.Hash(), actual.Hash())
})
} }
} }
@ -74,6 +90,52 @@ func TestNode_Serializable(t *testing.T) {
}) })
} }
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198
func TestJSONSharp(t *testing.T) {
tr := NewTrie(nil, newTestStore())
require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11}))
require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22}))
require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac}))
require.NoError(t, tr.Delete([]byte{0xac, 0x11}))
require.NoError(t, tr.Delete([]byte{0xac, 0x22}))
js, err := tr.root.MarshalJSON()
require.NoError(t, err)
require.JSONEq(t, `{"key":"0a0c", "next":{"value":"ac"}}`, string(js))
}
func TestInvalidJSON(t *testing.T) {
t.Run("InvalidChildrenCount", func(t *testing.T) {
var cs [childrenCount + 1]Node
for i := range cs {
cs[i] = new(HashNode)
}
data, err := json.Marshal(cs)
require.NoError(t, err)
var n NodeObject
require.Error(t, json.Unmarshal(data, &n))
})
testCases := []struct {
name string
data []byte
}{
{"WrongFieldCount", []byte(`{"key":"0102", "next": {}, "field": {}}`)},
{"InvalidField1", []byte(`{"next":{}}`)},
{"InvalidField2", []byte(`{"key":"0102", "hash":{}}`)},
{"InvalidKey", []byte(`{"key":"xy", "next":{}}`)},
{"InvalidNext", []byte(`{"key":"01", "next":[]}`)},
{"InvalidHash", []byte(`{"hash":"01"}`)},
{"InvalidValue", []byte(`{"value":1}`)},
{"InvalidBranch", []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]`)},
}
for _, tc := range testCases {
var n NodeObject
assert.Errorf(t, json.Unmarshal(tc.data, &n), "no error in "+tc.name)
}
}
// C# interoperability test // C# interoperability test
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135 // https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135
func TestRootHash(t *testing.T) { func TestRootHash(t *testing.T) {