neoneo-go/pkg/core/mpt/node.go

149 lines
3.2 KiB
Go
Raw Normal View History

package mpt
import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// NodeType represents the type of an MPT node. It is a single byte that is
// written in front of a node's payload when the node is serialized together
// with its type (see encodeNodeWithType).
type NodeType byte
// Node type definitions. NodeObject.DecodeBinary reads one of these values
// as the first byte and picks the concrete node implementation from it.
const (
// BranchT is decoded into a BranchNode.
BranchT NodeType = 0x00
// ExtensionT is decoded into an ExtensionNode.
ExtensionT NodeType = 0x01
// HashT is decoded into a HashNode.
HashT NodeType = 0x02
// LeafT is decoded into a LeafNode.
LeafT NodeType = 0x03
)
// NodeObject represents Node together with its type.
// It is used for serialization/deserialization where type info
// is also expected: its binary form is the node's type byte followed by
// the node's own encoding, and its JSON decoder selects the concrete node
// implementation from the shape of the input.
type NodeObject struct {
// Node is the wrapped node; it is replaced by DecodeBinary and
// UnmarshalJSON on successful decoding.
Node
}
// Node represents common interface of all MPT nodes.
// Every node must support binary serialization (io.Serializable) and JSON
// marshaling/unmarshaling in addition to the methods declared here.
type Node interface {
io.Serializable
json.Marshaler
json.Unmarshaler
// Hash returns a util.Uint256 value for the node.
Hash() util.Uint256
// Type returns the NodeType tag that encodeNodeWithType writes before
// the node's payload.
Type() NodeType
}
// EncodeBinary implements io.Serializable.
// The wrapped node's type byte is emitted first, followed by the wrapped
// node's own binary encoding.
func (n NodeObject) EncodeBinary(w *io.BinWriter) {
	inner := n.Node
	w.WriteB(byte(inner.Type()))
	inner.EncodeBinary(w)
}
// DecodeBinary implements io.Serializable.
//
// It reads a single type byte from r, allocates the node implementation
// that corresponds to it and lets that node decode the rest of its payload
// from r. The result is stored in n.Node.
//
// Errors are reported through r.Err:
//   - if r is in the error state once the type byte has been read, decoding
//     stops and n.Node is left untouched;
//   - if the type byte is not one of the known NodeType values, r.Err is
//     set to an "invalid node type" error and n.Node is left untouched.
func (n *NodeObject) DecodeBinary(r *io.BinReader) {
	typ := NodeType(r.ReadB())
	if r.Err != nil {
		// The type byte could not be read reliably, so typ carries no
		// information. Do not allocate and decode a node based on it.
		return
	}
	switch typ {
	case BranchT:
		n.Node = new(BranchNode)
	case ExtensionT:
		n.Node = new(ExtensionNode)
	case HashT:
		n.Node = new(HashNode)
	case LeafT:
		n.Node = new(LeafNode)
	default:
		r.Err = fmt.Errorf("invalid node type: %x", typ)
		return
	}
	// The freshly allocated node consumes its own payload from the reader.
	n.Node.DecodeBinary(r)
}
// encodeNodeWithType encodes node together with its type: a single type
// byte taken from n.Type() is written to w first, then the node's own
// binary encoding follows.
func encodeNodeWithType(n Node, w *io.BinWriter) {
	tag := byte(n.Type())
	w.WriteB(tag)
	n.EncodeBinary(w)
}
// toBytes is a helper for serializing node. It encodes n together with its
// type into a fresh in-memory writer and returns whatever that writer's
// Bytes method reports.
func toBytes(n Node) []byte {
	bw := io.NewBufBinWriter()
	sink := bw.BinWriter
	encodeNodeWithType(n, sink)
	return bw.Bytes()
}
// UnmarshalJSON implements json.Unmarshaler.
//
// The concrete node kind is inferred from the shape of the input:
//   - anything that does not decode as a JSON object is treated as a branch
//     node and must be an array of exactly childrenCount nodes;
//   - an object without fields becomes an empty HashNode;
//   - an object with the single field "hash" becomes a hash node;
//   - an object with the single field "value" (hex string) becomes a leaf
//     node, the value being limited to MaxValueLength bytes;
//   - an object with the fields "key" (hex string) and "next" (nested node)
//     becomes an extension node, the key being limited to MaxKeyLength bytes.
//
// Any other input is rejected with an error and n.Node is not modified.
func (n *NodeObject) UnmarshalJSON(data []byte) error {
	var fields map[string]json.RawMessage
	if json.Unmarshal(data, &fields) != nil {
		// Not an object, so the only remaining option is a branch node
		// serialized as an array of its children.
		var children []NodeObject
		if err := json.Unmarshal(data, &children); err != nil {
			return err
		}
		if len(children) != childrenCount {
			return errors.New("invalid length of branch node")
		}
		branch := NewBranchNode()
		for i := range branch.Children {
			branch.Children[i] = children[i].Node
		}
		n.Node = branch
		return nil
	}

	numFields := len(fields)
	if numFields > 2 {
		return errors.New("0, 1 or 2 fields expected")
	}
	if numFields == 0 {
		n.Node = new(HashNode)
		return nil
	}

	if numFields == 1 {
		hashRaw, isHash := fields["hash"]
		valueRaw, isLeaf := fields["value"]
		switch {
		case isHash:
			var hash util.Uint256
			if err := json.Unmarshal(hashRaw, &hash); err != nil {
				return err
			}
			n.Node = NewHashNode(hash)
		case isLeaf:
			value, err := unmarshalHex(valueRaw)
			if err != nil {
				return err
			}
			if len(value) > MaxValueLength {
				return errors.New("leaf value is too big")
			}
			n.Node = NewLeafNode(value)
		default:
			return errors.New("invalid field")
		}
		return nil
	}

	// Exactly two fields: only the extension node layout is acceptable.
	keyRaw, hasKey := fields["key"]
	nextRaw, hasNext := fields["next"]
	if !(hasKey && hasNext) {
		return errors.New("invalid field")
	}
	key, err := unmarshalHex(keyRaw)
	if err != nil {
		return err
	}
	if len(key) > MaxKeyLength {
		return errors.New("extension key is too big")
	}
	var child NodeObject
	if err := json.Unmarshal(nextRaw, &child); err != nil {
		return err
	}
	n.Node = NewExtensionNode(key, child.Node)
	return nil
}
// unmarshalHex decodes data as a JSON string and then interprets that
// string as hexadecimal, returning the resulting bytes. An error is
// returned if data is not a valid JSON string or if the string is not
// valid hex.
func unmarshalHex(data json.RawMessage) ([]byte, error) {
	var encoded string
	err := json.Unmarshal(data, &encoded)
	if err != nil {
		return nil, err
	}
	decoded, err := hex.DecodeString(encoded)
	return decoded, err
}