mpt: restructure nodes a bit, implement serialization and hash cache

It drastically reduces the number of allocations and hash calculations.

Signed-off-by: Evgenii Stratonikov <evgeniy@nspcc.ru>
This commit is contained in:
Roman Khimov 2020-06-02 23:49:46 +03:00 committed by Evgenii Stratonikov
parent 6ca22027d5
commit 475bf2445a
8 changed files with 117 additions and 75 deletions

69
pkg/core/mpt/base.go Normal file
View file

@ -0,0 +1,69 @@
package mpt
import (
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// BaseNode implements basic things every node needs like caching hash and
// serialized representation. It's a basic node building block intended to be
// included into all node types.
type BaseNode struct {
hash util.Uint256
bytes []byte
hashValid bool
bytesValid bool
}
// BaseNodeIface abstracts away basic Node functions.
type BaseNodeIface interface {
Hash() util.Uint256
Type() NodeType
Bytes() []byte
}
// getHash returns a hash of this BaseNode.
func (b *BaseNode) getHash(n Node) util.Uint256 {
if !b.hashValid {
b.updateHash(n)
}
return b.hash
}
// getBytes returns a slice of bytes representing this node.
func (b *BaseNode) getBytes(n Node) []byte {
if !b.bytesValid {
b.updateBytes(n)
}
return b.bytes
}
// updateHash updates hash field for this BaseNode.
func (b *BaseNode) updateHash(n Node) {
if n.Type() == HashT {
panic("can't update hash for hash node")
}
b.hash = hash.DoubleSha256(b.getBytes(n))
b.hashValid = true
}
// updateCache updates hash and bytes fields for this BaseNode.
func (b *BaseNode) updateBytes(n Node) {
buf := io.NewBufBinWriter()
encodeNodeWithType(n, buf.BinWriter)
b.bytes = buf.Bytes()
b.bytesValid = true
}
// invalidateCache sets all cache fields to invalid state.
func (b *BaseNode) invalidateCache() {
b.bytesValid = false
b.hashValid = false
}
// encodeNodeWithType encodes node together with it's type.
func encodeNodeWithType(n Node, w *io.BinWriter) {
w.WriteB(byte(n.Type()))
n.EncodeBinary(w)
}

View file

@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
) )
@ -18,9 +17,7 @@ const (
// BranchNode represents MPT's branch node. // BranchNode represents MPT's branch node.
type BranchNode struct { type BranchNode struct {
hash util.Uint256 BaseNode
valid bool
Children [childrenCount]Node Children [childrenCount]Node
} }
@ -38,18 +35,14 @@ func NewBranchNode() *BranchNode {
// Type implements Node interface. // Type implements Node interface.
func (b *BranchNode) Type() NodeType { return BranchT } func (b *BranchNode) Type() NodeType { return BranchT }
// Hash implements Node interface. // Hash implements BaseNode interface.
func (b *BranchNode) Hash() util.Uint256 { func (b *BranchNode) Hash() util.Uint256 {
if !b.valid { return b.getHash(b)
b.hash = hash.DoubleSha256(toBytes(b))
b.valid = true
}
return b.hash
} }
// invalidateHash invalidates node hash. // Bytes implements BaseNode interface.
func (b *BranchNode) invalidateHash() { func (b *BranchNode) Bytes() []byte {
b.valid = false return b.getBytes(b)
} }
// EncodeBinary implements io.Serializable. // EncodeBinary implements io.Serializable.

View file

@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
) )
@ -16,9 +15,7 @@ const MaxKeyLength = 1125
// ExtensionNode represents MPT's extension node. // ExtensionNode represents MPT's extension node.
type ExtensionNode struct { type ExtensionNode struct {
hash util.Uint256 BaseNode
valid bool
key []byte key []byte
next Node next Node
} }
@ -37,18 +34,14 @@ func NewExtensionNode(key []byte, next Node) *ExtensionNode {
// Type implements Node interface. // Type implements Node interface.
func (e ExtensionNode) Type() NodeType { return ExtensionT } func (e ExtensionNode) Type() NodeType { return ExtensionT }
// Hash implements Node interface. // Hash implements BaseNode interface.
func (e *ExtensionNode) Hash() util.Uint256 { func (e *ExtensionNode) Hash() util.Uint256 {
if !e.valid { return e.getHash(e)
e.hash = hash.DoubleSha256(toBytes(e))
e.valid = true
}
return e.hash
} }
// invalidateHash invalidates node hash. // Bytes implements BaseNode interface.
func (e *ExtensionNode) invalidateHash() { func (e *ExtensionNode) Bytes() []byte {
e.valid = false return e.getBytes(e)
} }
// DecodeBinary implements io.Serializable. // DecodeBinary implements io.Serializable.
@ -58,11 +51,11 @@ func (e *ExtensionNode) DecodeBinary(r *io.BinReader) {
r.Err = fmt.Errorf("extension node key is too big: %d", sz) r.Err = fmt.Errorf("extension node key is too big: %d", sz)
return return
} }
e.valid = false
e.key = make([]byte, sz) e.key = make([]byte, sz)
r.ReadBytes(e.key) r.ReadBytes(e.key)
e.next = new(HashNode) e.next = new(HashNode)
e.next.DecodeBinary(r) e.next.DecodeBinary(r)
e.invalidateCache()
} }
// EncodeBinary implements io.Serializable. // EncodeBinary implements io.Serializable.

View file

@ -10,8 +10,7 @@ import (
// HashNode represents MPT's hash node. // HashNode represents MPT's hash node.
type HashNode struct { type HashNode struct {
hash util.Uint256 BaseNode
valid bool
} }
var _ Node = (*HashNode)(nil) var _ Node = (*HashNode)(nil)
@ -19,8 +18,10 @@ var _ Node = (*HashNode)(nil)
// NewHashNode returns hash node with the specified hash. // NewHashNode returns hash node with the specified hash.
func NewHashNode(h util.Uint256) *HashNode { func NewHashNode(h util.Uint256) *HashNode {
return &HashNode{ return &HashNode{
hash: h, BaseNode: BaseNode{
valid: true, hash: h,
hashValid: true,
},
} }
} }
@ -29,23 +30,28 @@ func (h *HashNode) Type() NodeType { return HashT }
// Hash implements Node interface. // Hash implements Node interface.
func (h *HashNode) Hash() util.Uint256 { func (h *HashNode) Hash() util.Uint256 {
if !h.valid { if !h.hashValid {
panic("can't get hash of an empty HashNode") panic("can't get hash of an empty HashNode")
} }
return h.hash return h.hash
} }
// IsEmpty returns true iff h is an empty node i.e. contains no hash. // IsEmpty returns true iff h is an empty node i.e. contains no hash.
func (h *HashNode) IsEmpty() bool { return !h.valid } func (h *HashNode) IsEmpty() bool { return !h.hashValid }
// Bytes returns serialized HashNode.
func (h *HashNode) Bytes() []byte {
return h.getBytes(h)
}
// DecodeBinary implements io.Serializable. // DecodeBinary implements io.Serializable.
func (h *HashNode) DecodeBinary(r *io.BinReader) { func (h *HashNode) DecodeBinary(r *io.BinReader) {
sz := r.ReadVarUint() sz := r.ReadVarUint()
switch sz { switch sz {
case 0: case 0:
h.valid = false h.hashValid = false
case util.Uint256Size: case util.Uint256Size:
h.valid = true h.hashValid = true
r.ReadBytes(h.hash[:]) r.ReadBytes(h.hash[:])
default: default:
r.Err = fmt.Errorf("invalid hash node size: %d", sz) r.Err = fmt.Errorf("invalid hash node size: %d", sz)
@ -54,7 +60,7 @@ func (h *HashNode) DecodeBinary(r *io.BinReader) {
// EncodeBinary implements io.Serializable. // EncodeBinary implements io.Serializable.
func (h HashNode) EncodeBinary(w *io.BinWriter) { func (h HashNode) EncodeBinary(w *io.BinWriter) {
if !h.valid { if !h.hashValid {
w.WriteVarUint(0) w.WriteVarUint(0)
return return
} }
@ -63,7 +69,7 @@ func (h HashNode) EncodeBinary(w *io.BinWriter) {
// MarshalJSON implements json.Marshaler. // MarshalJSON implements json.Marshaler.
func (h *HashNode) MarshalJSON() ([]byte, error) { func (h *HashNode) MarshalJSON() ([]byte, error) {
if !h.valid { if !h.hashValid {
return []byte(`{}`), nil return []byte(`{}`), nil
} }
return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil

View file

@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
) )
@ -15,9 +14,7 @@ const MaxValueLength = 1024 * 1024
// LeafNode represents MPT's leaf node. // LeafNode represents MPT's leaf node.
type LeafNode struct { type LeafNode struct {
hash util.Uint256 BaseNode
valid bool
value []byte value []byte
} }
@ -31,13 +28,14 @@ func NewLeafNode(value []byte) *LeafNode {
// Type implements Node interface. // Type implements Node interface.
func (n LeafNode) Type() NodeType { return LeafT } func (n LeafNode) Type() NodeType { return LeafT }
// Hash implements Node interface. // Hash implements BaseNode interface.
func (n *LeafNode) Hash() util.Uint256 { func (n *LeafNode) Hash() util.Uint256 {
if !n.valid { return n.getHash(n)
n.hash = hash.DoubleSha256(toBytes(n)) }
n.valid = true
} // Bytes implements BaseNode interface.
return n.hash func (n *LeafNode) Bytes() []byte {
return n.getBytes(n)
} }
// DecodeBinary implements io.Serializable. // DecodeBinary implements io.Serializable.
@ -47,9 +45,9 @@ func (n *LeafNode) DecodeBinary(r *io.BinReader) {
r.Err = fmt.Errorf("leaf node value is too big: %d", sz) r.Err = fmt.Errorf("leaf node value is too big: %d", sz)
return return
} }
n.valid = false
n.value = make([]byte, sz) n.value = make([]byte, sz)
r.ReadBytes(n.value) r.ReadBytes(n.value)
n.invalidateCache()
} }
// EncodeBinary implements io.Serializable. // EncodeBinary implements io.Serializable.

View file

@ -33,8 +33,7 @@ type Node interface {
io.Serializable io.Serializable
json.Marshaler json.Marshaler
json.Unmarshaler json.Unmarshaler
Hash() util.Uint256 BaseNodeIface
Type() NodeType
} }
// EncodeBinary implements io.Serializable. // EncodeBinary implements io.Serializable.
@ -61,19 +60,6 @@ func (n *NodeObject) DecodeBinary(r *io.BinReader) {
n.Node.DecodeBinary(r) n.Node.DecodeBinary(r)
} }
// encodeNodeWithType encodes node together with it's type.
func encodeNodeWithType(n Node, w *io.BinWriter) {
w.WriteB(byte(n.Type()))
n.EncodeBinary(w)
}
// toBytes is a helper for serializing node.
func toBytes(n Node) []byte {
buf := io.NewBufBinWriter()
encodeNodeWithType(n, buf.BinWriter)
return buf.Bytes()
}
// UnmarshalJSON implements json.Unmarshaler. // UnmarshalJSON implements json.Unmarshaler.
func (n *NodeObject) UnmarshalJSON(data []byte) error { func (n *NodeObject) UnmarshalJSON(data []byte) error {
var m map[string]json.RawMessage var m map[string]json.RawMessage

View file

@ -25,11 +25,11 @@ func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error)
switch n := curr.(type) { switch n := curr.(type) {
case *LeafNode: case *LeafNode:
if len(path) == 0 { if len(path) == 0 {
*proofs = append(*proofs, toBytes(n)) *proofs = append(*proofs, copySlice(n.Bytes()))
return n, nil return n, nil
} }
case *BranchNode: case *BranchNode:
*proofs = append(*proofs, toBytes(n)) *proofs = append(*proofs, copySlice(n.Bytes()))
i, path := splitPath(path) i, path := splitPath(path)
r, err := t.getProof(n.Children[i], path, proofs) r, err := t.getProof(n.Children[i], path, proofs)
if err != nil { if err != nil {
@ -39,7 +39,7 @@ func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error)
return n, nil return n, nil
case *ExtensionNode: case *ExtensionNode:
if bytes.HasPrefix(path, n.key) { if bytes.HasPrefix(path, n.key) {
*proofs = append(*proofs, toBytes(n)) *proofs = append(*proofs, copySlice(n.Bytes()))
r, err := t.getProof(n.next, path[len(n.key):], proofs) r, err := t.getProof(n.next, path[len(n.key):], proofs)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -5,7 +5,6 @@ import (
"errors" "errors"
"github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
) )
@ -126,7 +125,7 @@ func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, err
return nil, err return nil, err
} }
curr.Children[i] = r curr.Children[i] = r
curr.invalidateHash() curr.invalidateCache()
return curr, nil return curr, nil
} }
@ -139,7 +138,7 @@ func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Nod
return nil, err return nil, err
} }
curr.next = r curr.next = r
curr.invalidateHash() curr.invalidateCache()
return curr, nil return curr, nil
} }
@ -218,7 +217,7 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) {
return nil, err return nil, err
} }
b.Children[i] = r b.Children[i] = r
b.invalidateHash() b.invalidateCache()
var count, index int var count, index int
for i := range b.Children { for i := range b.Children {
h, ok := b.Children[i].(*HashNode) h, ok := b.Children[i].(*HashNode)
@ -243,7 +242,7 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) {
} }
if e, ok := c.(*ExtensionNode); ok { if e, ok := c.(*ExtensionNode); ok {
e.key = append([]byte{byte(index)}, e.key...) e.key = append([]byte{byte(index)}, e.key...)
e.invalidateHash() e.invalidateCache()
return e, nil return e, nil
} }
@ -262,7 +261,7 @@ func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error)
case *ExtensionNode: case *ExtensionNode:
n.key = append(n.key, nxt.key...) n.key = append(n.key, nxt.key...)
n.next = nxt.next n.next = nxt.next
n.invalidateHash() n.invalidateCache()
case *HashNode: case *HashNode:
if nxt.IsEmpty() { if nxt.IsEmpty() {
return nxt, nil return nxt, nil
@ -336,9 +335,7 @@ func (t *Trie) putToStore(n Node) {
if n.Type() == HashT { if n.Type() == HashT {
panic("can't put hash node in trie") panic("can't put hash node in trie")
} }
bs := toBytes(n) _ = t.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) // put in MemCached returns no errors
h := hash.DoubleSha256(bs)
_ = t.Store.Put(makeStorageKey(h.BytesBE()), bs) // put in MemCached returns no errors
} }
func (t *Trie) getFromStore(h util.Uint256) (Node, error) { func (t *Trie) getFromStore(h util.Uint256) (Node, error) {