From 7f038bd4650340d6da0941088a4aeaa356cfad98 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Tue, 30 Mar 2021 20:35:41 +0300 Subject: [PATCH] mpt: split HashNode in two types First type is non-empty HashNode, and the second one is an Empty node. --- pkg/core/mpt/base.go | 24 ++++++++++++++++++++++-- pkg/core/mpt/hash.go | 15 +++------------ pkg/core/mpt/node.go | 1 + pkg/core/mpt/node_test.go | 11 +++++++---- pkg/rpc/response/result/mpt_test.go | 6 ++++-- 5 files changed, 37 insertions(+), 20 deletions(-) diff --git a/pkg/core/mpt/base.go b/pkg/core/mpt/base.go index 38645ff05..8762281d6 100644 --- a/pkg/core/mpt/base.go +++ b/pkg/core/mpt/base.go @@ -78,7 +78,17 @@ func (b *BaseNode) invalidateCache() { // encodeNodeWithType encodes node together with it's type. func encodeNodeWithType(n Node, w *io.BinWriter) { - w.WriteB(byte(n.Type())) + switch t := n.Type(); t { + case HashT: + hn := n.(*HashNode) + if !hn.hashValid { + w.WriteB(byte(EmptyT)) + break + } + fallthrough + default: + w.WriteB(byte(t)) + } n.EncodeBinary(w) } @@ -94,9 +104,19 @@ func DecodeNodeWithType(r *io.BinReader) Node { case ExtensionT: n = new(ExtensionNode) case HashT: - n = new(HashNode) + n = &HashNode{ + BaseNode: BaseNode{ + hashValid: true, + }, + } case LeafT: n = new(LeafNode) + case EmptyT: + n = &HashNode{ + BaseNode: BaseNode{ + hashValid: false, + }, + } default: r.Err = fmt.Errorf("invalid node type: %x", typ) return nil diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go index 03c4fdc1d..ca24dc457 100644 --- a/pkg/core/mpt/hash.go +++ b/pkg/core/mpt/hash.go @@ -2,7 +2,6 @@ package mpt import ( "errors" - "fmt" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -46,25 +45,17 @@ func (h *HashNode) Bytes() []byte { // DecodeBinary implements io.Serializable. func (h *HashNode) DecodeBinary(r *io.BinReader) { - sz := r.ReadVarUint() - switch sz { - case 0: - h.hashValid = false - case util.Uint256Size: - h.hashValid = true - r.ReadBytes(h.hash[:]) - default: - r.Err = fmt.Errorf("invalid hash node size: %d", sz) + if h.hashValid { + h.hash.DecodeBinary(r) } } // EncodeBinary implements io.Serializable. func (h HashNode) EncodeBinary(w *io.BinWriter) { if !h.hashValid { - w.WriteVarUint(0) return } - w.WriteVarBytes(h.hash[:]) + w.WriteBytes(h.hash[:]) } // EncodeBinaryAsChild implements BaseNode interface. diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go index befc216fd..2d2c42807 100644 --- a/pkg/core/mpt/node.go +++ b/pkg/core/mpt/node.go @@ -18,6 +18,7 @@ const ( ExtensionT NodeType = 0x01 LeafT NodeType = 0x02 HashT NodeType = 0x03 + EmptyT NodeType = 0x04 ) // NodeObject represents Node together with it's type. diff --git a/pkg/core/mpt/node_test.go b/pkg/core/mpt/node_test.go index 8dd65ac9d..74c0247a2 100644 --- a/pkg/core/mpt/node_test.go +++ b/pkg/core/mpt/node_test.go @@ -16,6 +16,9 @@ func getTestFuncEncode(ok bool, expected, actual Node) func(t *testing.T) { t.Run("IO", func(t *testing.T) { bs, err := testserdes.EncodeBinary(expected) require.NoError(t, err) + if hn, ok := actual.(*HashNode); ok { + hn.hashValid = true // this field is set during NodeObject decoding + } err = testserdes.DecodeBinary(bs, actual) if !ok { require.Error(t, err) @@ -80,8 +83,8 @@ func TestNode_Serializable(t *testing.T) { }) t.Run("InvalidSize", func(t *testing.T) { buf := io.NewBufBinWriter() - buf.BinWriter.WriteVarBytes(make([]byte, 13)) - require.Error(t, testserdes.DecodeBinary(buf.Bytes(), new(HashNode))) + buf.BinWriter.WriteBytes(make([]byte, 13)) + require.Error(t, testserdes.DecodeBinary(buf.Bytes(), &HashNode{BaseNode: BaseNode{hashValid: true}})) }) }) @@ -151,6 +154,6 @@ func TestRootHash(t *testing.T) { b.Children[9] = l2 r1 := NewExtensionNode([]byte{0x0A, 0x0C, 0x00, 0x01}, v1) - require.Equal(t, "a6d1385fa2e089fd9ca79e58bee47cb4c9c949140a382580138840113412931d", r1.Hash().StringLE()) - require.Equal(t, "62d14dc02b9f905ca6ec73fb499b1eef835e482d936744e3b6298cf9ad26ba03", r.Hash().StringLE()) + require.Equal(t, "cedd9897dd1559fbd5dfe5cfb223464da6de438271028afb8d647e950cbd18e0", r1.Hash().StringLE()) + require.Equal(t, "1037e779c8a0313bd0d99c4151fa70a277c43c53a549b6444079f2e67e8ffb7b", r.Hash().StringLE()) } diff --git a/pkg/rpc/response/result/mpt_test.go b/pkg/rpc/response/result/mpt_test.go index 91adeaabb..e328288f8 100644 --- a/pkg/rpc/response/result/mpt_test.go +++ b/pkg/rpc/response/result/mpt_test.go @@ -1,10 +1,13 @@ package result import ( + "encoding/json" "testing" "github.com/nspcc-dev/neo-go/internal/random" "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/stretchr/testify/require" ) @@ -19,7 +22,6 @@ func testProofWithKey() *ProofWithKey { } } -/* func TestGetProof_MarshalJSON(t *testing.T) { t.Run("Good", func(t *testing.T) { p := testProofWithKey() @@ -40,7 +42,7 @@ func TestGetProof_MarshalJSON(t *testing.T) { } }) } -*/ + func TestProofWithKey_EncodeString(t *testing.T) { expected := testProofWithKey() var actual ProofWithKey