diff --git a/pkg/core/mpt/base.go b/pkg/core/mpt/base.go index 9f10cc333..49c7b2d59 100644 --- a/pkg/core/mpt/base.go +++ b/pkg/core/mpt/base.go @@ -1,6 +1,8 @@ package mpt import ( + "fmt" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -82,3 +84,23 @@ func encodeNodeWithType(n Node, w *io.BinWriter) { w.WriteB(byte(n.Type())) n.EncodeBinary(w) } + +// DecodeNodeWithType decodes node together with it's type. +func DecodeNodeWithType(r *io.BinReader) Node { + var n Node + switch typ := NodeType(r.ReadB()); typ { + case BranchT: + n = new(BranchNode) + case ExtensionT: + n = new(ExtensionNode) + case HashT: + n = new(HashNode) + case LeafT: + n = new(LeafNode) + default: + r.Err = fmt.Errorf("invalid node type: %x", typ) + return nil + } + n.DecodeBinary(r) + return n +} diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go index 86e675a01..04b085948 100644 --- a/pkg/core/mpt/node.go +++ b/pkg/core/mpt/node.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "encoding/json" "errors" - "fmt" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -43,21 +42,7 @@ func (n NodeObject) EncodeBinary(w *io.BinWriter) { // DecodeBinary implements io.Serializable. func (n *NodeObject) DecodeBinary(r *io.BinReader) { - typ := NodeType(r.ReadB()) - switch typ { - case BranchT: - n.Node = new(BranchNode) - case ExtensionT: - n.Node = new(ExtensionNode) - case HashT: - n.Node = new(HashNode) - case LeafT: - n.Node = new(LeafNode) - default: - r.Err = fmt.Errorf("invalid node type: %x", typ) - return - } - n.Node.DecodeBinary(r) + n.Node = DecodeNodeWithType(r) } // UnmarshalJSON implements json.Unmarshaler.