diff --git a/pkg/core/mpt/branch.go b/pkg/core/mpt/branch.go index ac3f2400f..c4a383075 100644 --- a/pkg/core/mpt/branch.go +++ b/pkg/core/mpt/branch.go @@ -1,6 +1,9 @@ package mpt import ( + "encoding/json" + "errors" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -69,6 +72,23 @@ func (b *BranchNode) DecodeBinary(r *io.BinReader) { } } +// MarshalJSON implements json.Marshaler. +func (b *BranchNode) MarshalJSON() ([]byte, error) { + return json.Marshal(b.Children) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (b *BranchNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*BranchNode); ok { + *b = *u + return nil + } + return errors.New("expected branch node") +} + // splitPath splits path for a branch node. func splitPath(path []byte) (byte, []byte) { if len(path) != 0 { diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go index 775078827..a337c4de2 100644 --- a/pkg/core/mpt/extension.go +++ b/pkg/core/mpt/extension.go @@ -1,6 +1,9 @@ package mpt import ( + "encoding/hex" + "encoding/json" + "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -68,3 +71,24 @@ func (e ExtensionNode) EncodeBinary(w *io.BinWriter) { n := NewHashNode(e.next.Hash()) n.EncodeBinary(w) } + +// MarshalJSON implements json.Marshaler. +func (e *ExtensionNode) MarshalJSON() ([]byte, error) { + m := map[string]interface{}{ + "key": hex.EncodeToString(e.key), + "next": e.next, + } + return json.Marshal(m) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (e *ExtensionNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*ExtensionNode); ok { + *e = *u + return nil + } + return errors.New("expected extension node") +} diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go index a14dec879..51c6095fd 100644 --- a/pkg/core/mpt/hash.go +++ b/pkg/core/mpt/hash.go @@ -1,6 +1,7 @@ package mpt import ( + "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/io" @@ -59,3 +60,23 @@ func (h HashNode) EncodeBinary(w *io.BinWriter) { } w.WriteVarBytes(h.hash[:]) } + +// MarshalJSON implements json.Marshaler. +func (h *HashNode) MarshalJSON() ([]byte, error) { + if !h.valid { + return []byte(`{}`), nil + } + return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (h *HashNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*HashNode); ok { + *h = *u + return nil + } + return errors.New("expected hash node") +} diff --git a/pkg/core/mpt/leaf.go b/pkg/core/mpt/leaf.go index 455ae3feb..4ae509a1c 100644 --- a/pkg/core/mpt/leaf.go +++ b/pkg/core/mpt/leaf.go @@ -1,6 +1,8 @@ package mpt import ( + "encoding/hex" + "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -54,3 +56,20 @@ func (n *LeafNode) DecodeBinary(r *io.BinReader) { func (n LeafNode) EncodeBinary(w *io.BinWriter) { w.WriteVarBytes(n.value) } + +// MarshalJSON implements json.Marshaler. +func (n *LeafNode) MarshalJSON() ([]byte, error) { + return []byte(`{"value":"` + hex.EncodeToString(n.value) + `"}`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (n *LeafNode) UnmarshalJSON(data []byte) error { + var obj NodeObject + if err := obj.UnmarshalJSON(data); err != nil { + return err + } else if u, ok := obj.Node.(*LeafNode); ok { + *n = *u + return nil + } + return errors.New("expected leaf node") +} diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go index 83abb2e60..53a2fdec1 100644 --- a/pkg/core/mpt/node.go +++ b/pkg/core/mpt/node.go @@ -1,6 +1,9 @@ package mpt import ( + "encoding/hex" + "encoding/json" + "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/io" @@ -28,6 +31,8 @@ type NodeObject struct { // Node represents common interface of all MPT nodes. type Node interface { io.Serializable + json.Marshaler + json.Unmarshaler Hash() util.Uint256 Type() NodeType } @@ -68,3 +73,76 @@ func toBytes(n Node) []byte { encodeNodeWithType(n, buf.BinWriter) return buf.Bytes() } + +// UnmarshalJSON implements json.Unmarshaler. +func (n *NodeObject) UnmarshalJSON(data []byte) error { + var m map[string]json.RawMessage + err := json.Unmarshal(data, &m) + if err != nil { // it can be a branch node + var nodes []NodeObject + if err := json.Unmarshal(data, &nodes); err != nil { + return err + } else if len(nodes) != childrenCount { + return errors.New("invalid length of branch node") + } + + b := NewBranchNode() + for i := range b.Children { + b.Children[i] = nodes[i].Node + } + n.Node = b + return nil + } + + switch len(m) { + case 0: + n.Node = new(HashNode) + case 1: + if v, ok := m["hash"]; ok { + var h util.Uint256 + if err := json.Unmarshal(v, &h); err != nil { + return err + } + n.Node = NewHashNode(h) + } else if v, ok = m["value"]; ok { + b, err := unmarshalHex(v) + if err != nil { + return err + } else if len(b) > MaxValueLength { + return errors.New("leaf value is too big") + } + n.Node = NewLeafNode(b) + } else { + return errors.New("invalid field") + } + case 2: + keyRaw, ok1 := m["key"] + nextRaw, ok2 := m["next"] + if !ok1 || !ok2 { + return errors.New("invalid field") + } + key, err := unmarshalHex(keyRaw) + if err != nil { + return err + } else if len(key) > MaxKeyLength { + return errors.New("extension key is too big") + } + + var next NodeObject + if err := json.Unmarshal(nextRaw, &next); err != nil { + return err + } + n.Node = NewExtensionNode(key, next.Node) + default: + return errors.New("0, 1 or 2 fields expected") + } + return nil +} + +func unmarshalHex(data json.RawMessage) ([]byte, error) { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return hex.DecodeString(s) +} diff --git a/pkg/core/mpt/node_test.go b/pkg/core/mpt/node_test.go index 0e2c17c96..e3aab54d6 100644 --- a/pkg/core/mpt/node_test.go +++ b/pkg/core/mpt/node_test.go @@ -1,26 +1,42 @@ package mpt import ( + "encoding/json" "testing" "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) { return func(t *testing.T) { - bs, err := testserdes.EncodeBinary(expected) - require.NoError(t, err) - err = testserdes.DecodeBinary(bs, actual) - if !ok { - require.Error(t, err) - return - } - require.NoError(t, err) - require.Equal(t, expected.Type(), actual.Type()) - require.Equal(t, expected.Hash(), actual.Hash()) + t.Run("IO", func(t *testing.T) { + bs, err := testserdes.EncodeBinary(expected) + require.NoError(t, err) + err = testserdes.DecodeBinary(bs, actual) + if !ok { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, expected.Type(), actual.Type()) + require.Equal(t, expected.Hash(), actual.Hash()) + }) + t.Run("JSON", func(t *testing.T) { + bs, err := json.Marshal(expected) + require.NoError(t, err) + err = json.Unmarshal(bs, actual) + if !ok { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, expected.Type(), actual.Type()) + require.Equal(t, expected.Hash(), actual.Hash()) + }) } } @@ -74,6 +90,52 @@ func TestNode_Serializable(t *testing.T) { }) } +// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198 +func TestJSONSharp(t *testing.T) { + tr := NewTrie(nil, newTestStore()) + require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11})) + require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22})) + require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac})) + require.NoError(t, tr.Delete([]byte{0xac, 0x11})) + require.NoError(t, tr.Delete([]byte{0xac, 0x22})) + + js, err := tr.root.MarshalJSON() + require.NoError(t, err) + require.JSONEq(t, `{"key":"0a0c", "next":{"value":"ac"}}`, string(js)) +} + +func TestInvalidJSON(t *testing.T) { + t.Run("InvalidChildrenCount", func(t *testing.T) { + var cs [childrenCount + 1]Node + for i := range cs { + cs[i] = new(HashNode) + } + data, err := json.Marshal(cs) + require.NoError(t, err) + + var n NodeObject + require.Error(t, json.Unmarshal(data, &n)) + }) + + testCases := []struct { + name string + data []byte + }{ + {"WrongFieldCount", []byte(`{"key":"0102", "next": {}, "field": {}}`)}, + {"InvalidField1", []byte(`{"next":{}}`)}, + {"InvalidField2", []byte(`{"key":"0102", "hash":{}}`)}, + {"InvalidKey", []byte(`{"key":"xy", "next":{}}`)}, + {"InvalidNext", []byte(`{"key":"01", "next":[]}`)}, + {"InvalidHash", []byte(`{"hash":"01"}`)}, + {"InvalidValue", []byte(`{"value":1}`)}, + {"InvalidBranch", []byte(`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]`)}, + } + for _, tc := range testCases { + var n NodeObject + assert.Errorf(t, json.Unmarshal(tc.data, &n), "no error in "+tc.name) + } +} + // C# interoperability test // https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L135 func TestRootHash(t *testing.T) {