From df1792c80b2f99629b698288c8beeacfc29daf5f Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 23 Oct 2020 10:55:08 +0300 Subject: [PATCH] mpt: export func for decoding node with type `NodeObject` can contain auxilliary fields and shouldn't be used from outside. --- pkg/core/mpt/base.go | 22 ++++++++++++++++++++++ pkg/core/mpt/node.go | 17 +---------------- pkg/rpc/response/result/mpt_test.go | 5 ++--- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/pkg/core/mpt/base.go b/pkg/core/mpt/base.go index 9f10cc333..49c7b2d59 100644 --- a/pkg/core/mpt/base.go +++ b/pkg/core/mpt/base.go @@ -1,6 +1,8 @@ package mpt import ( + "fmt" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -82,3 +84,23 @@ func encodeNodeWithType(n Node, w *io.BinWriter) { w.WriteB(byte(n.Type())) n.EncodeBinary(w) } + +// DecodeNodeWithType decodes node together with it's type. +func DecodeNodeWithType(r *io.BinReader) Node { + var n Node + switch typ := NodeType(r.ReadB()); typ { + case BranchT: + n = new(BranchNode) + case ExtensionT: + n = new(ExtensionNode) + case HashT: + n = new(HashNode) + case LeafT: + n = new(LeafNode) + default: + r.Err = fmt.Errorf("invalid node type: %x", typ) + return nil + } + n.DecodeBinary(r) + return n +} diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go index 86e675a01..04b085948 100644 --- a/pkg/core/mpt/node.go +++ b/pkg/core/mpt/node.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "encoding/json" "errors" - "fmt" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -43,21 +42,7 @@ func (n NodeObject) EncodeBinary(w *io.BinWriter) { // DecodeBinary implements io.Serializable. func (n *NodeObject) DecodeBinary(r *io.BinReader) { - typ := NodeType(r.ReadB()) - switch typ { - case BranchT: - n.Node = new(BranchNode) - case ExtensionT: - n.Node = new(ExtensionNode) - case HashT: - n.Node = new(HashNode) - case LeafT: - n.Node = new(LeafNode) - default: - r.Err = fmt.Errorf("invalid node type: %x", typ) - return - } - n.Node.DecodeBinary(r) + n.Node = DecodeNodeWithType(r) } // UnmarshalJSON implements json.Unmarshaler. diff --git a/pkg/rpc/response/result/mpt_test.go b/pkg/rpc/response/result/mpt_test.go index 22e0c021c..27173fb9b 100644 --- a/pkg/rpc/response/result/mpt_test.go +++ b/pkg/rpc/response/result/mpt_test.go @@ -41,10 +41,9 @@ func TestGetProof_MarshalJSON(t *testing.T) { require.Equal(t, 8, len(p.Result.Proof)) for i := range p.Result.Proof { // smoke test that every chunk is correctly encoded node r := io.NewBinReaderFromBuf(p.Result.Proof[i]) - var n mpt.NodeObject - n.DecodeBinary(r) + n := mpt.DecodeNodeWithType(r) require.NoError(t, r.Err) - require.NotNil(t, n.Node) + require.NotNil(t, n) } }) }