diff --git a/pkg/core/mpt/base.go b/pkg/core/mpt/base.go new file mode 100644 index 000000000..9f10cc333 --- /dev/null +++ b/pkg/core/mpt/base.go @@ -0,0 +1,84 @@ +package mpt + +import ( + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// BaseNode implements basic things every node needs like caching hash and +// serialized representation. It's a basic node building block intended to be +// included into all node types. +type BaseNode struct { + hash util.Uint256 + bytes []byte + hashValid bool + bytesValid bool + + isFlushed bool +} + +// BaseNodeIface abstracts away basic Node functions. +type BaseNodeIface interface { + Hash() util.Uint256 + Type() NodeType + Bytes() []byte + IsFlushed() bool + SetFlushed() +} + +// getHash returns a hash of this BaseNode. +func (b *BaseNode) getHash(n Node) util.Uint256 { + if !b.hashValid { + b.updateHash(n) + } + return b.hash +} + +// getBytes returns a slice of bytes representing this node. +func (b *BaseNode) getBytes(n Node) []byte { + if !b.bytesValid { + b.updateBytes(n) + } + return b.bytes +} + +// updateHash updates hash field for this BaseNode. +func (b *BaseNode) updateHash(n Node) { + if n.Type() == HashT { + panic("can't update hash for hash node") + } + b.hash = hash.DoubleSha256(b.getBytes(n)) + b.hashValid = true +} + +// updateCache updates hash and bytes fields for this BaseNode. +func (b *BaseNode) updateBytes(n Node) { + buf := io.NewBufBinWriter() + encodeNodeWithType(n, buf.BinWriter) + b.bytes = buf.Bytes() + b.bytesValid = true +} + +// invalidateCache sets all cache fields to invalid state. +func (b *BaseNode) invalidateCache() { + b.bytesValid = false + b.hashValid = false + b.isFlushed = false +} + +// IsFlushed checks for node flush status. +func (b *BaseNode) IsFlushed() bool { + return b.isFlushed +} + +// SetFlushed sets 'flushed' flag to true for this node. +func (b *BaseNode) SetFlushed() { + b.isFlushed = true +} + +// encodeNodeWithType encodes node together with it's type. +func encodeNodeWithType(n Node, w *io.BinWriter) { + w.WriteB(byte(n.Type())) + n.EncodeBinary(w) +} diff --git a/pkg/core/mpt/branch.go b/pkg/core/mpt/branch.go index c4a383075..fbad5d29e 100644 --- a/pkg/core/mpt/branch.go +++ b/pkg/core/mpt/branch.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -18,9 +17,7 @@ const ( // BranchNode represents MPT's branch node. type BranchNode struct { - hash util.Uint256 - valid bool - + BaseNode Children [childrenCount]Node } @@ -38,18 +35,14 @@ func NewBranchNode() *BranchNode { // Type implements Node interface. func (b *BranchNode) Type() NodeType { return BranchT } -// Hash implements Node interface. +// Hash implements BaseNode interface. func (b *BranchNode) Hash() util.Uint256 { - if !b.valid { - b.hash = hash.DoubleSha256(toBytes(b)) - b.valid = true - } - return b.hash + return b.getHash(b) } -// invalidateHash invalidates node hash. -func (b *BranchNode) invalidateHash() { - b.valid = false +// Bytes implements BaseNode interface. +func (b *BranchNode) Bytes() []byte { + return b.getBytes(b) } // EncodeBinary implements io.Serializable. diff --git a/pkg/core/mpt/extension.go b/pkg/core/mpt/extension.go index a337c4de2..8bcc11c24 100644 --- a/pkg/core/mpt/extension.go +++ b/pkg/core/mpt/extension.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -16,9 +15,7 @@ const MaxKeyLength = 1125 // ExtensionNode represents MPT's extension node. type ExtensionNode struct { - hash util.Uint256 - valid bool - + BaseNode key []byte next Node } @@ -37,18 +34,14 @@ func NewExtensionNode(key []byte, next Node) *ExtensionNode { // Type implements Node interface. func (e ExtensionNode) Type() NodeType { return ExtensionT } -// Hash implements Node interface. +// Hash implements BaseNode interface. func (e *ExtensionNode) Hash() util.Uint256 { - if !e.valid { - e.hash = hash.DoubleSha256(toBytes(e)) - e.valid = true - } - return e.hash + return e.getHash(e) } -// invalidateHash invalidates node hash. -func (e *ExtensionNode) invalidateHash() { - e.valid = false +// Bytes implements BaseNode interface. +func (e *ExtensionNode) Bytes() []byte { + return e.getBytes(e) } // DecodeBinary implements io.Serializable. @@ -58,11 +51,11 @@ func (e *ExtensionNode) DecodeBinary(r *io.BinReader) { r.Err = fmt.Errorf("extension node key is too big: %d", sz) return } - e.valid = false e.key = make([]byte, sz) r.ReadBytes(e.key) e.next = new(HashNode) e.next.DecodeBinary(r) + e.invalidateCache() } // EncodeBinary implements io.Serializable. diff --git a/pkg/core/mpt/hash.go b/pkg/core/mpt/hash.go index 51c6095fd..42519a1ac 100644 --- a/pkg/core/mpt/hash.go +++ b/pkg/core/mpt/hash.go @@ -10,8 +10,7 @@ import ( // HashNode represents MPT's hash node. type HashNode struct { - hash util.Uint256 - valid bool + BaseNode } var _ Node = (*HashNode)(nil) @@ -19,8 +18,10 @@ var _ Node = (*HashNode)(nil) // NewHashNode returns hash node with the specified hash. func NewHashNode(h util.Uint256) *HashNode { return &HashNode{ - hash: h, - valid: true, + BaseNode: BaseNode{ + hash: h, + hashValid: true, + }, } } @@ -29,23 +30,28 @@ func (h *HashNode) Type() NodeType { return HashT } // Hash implements Node interface. func (h *HashNode) Hash() util.Uint256 { - if !h.valid { + if !h.hashValid { panic("can't get hash of an empty HashNode") } return h.hash } // IsEmpty returns true iff h is an empty node i.e. contains no hash. -func (h *HashNode) IsEmpty() bool { return !h.valid } +func (h *HashNode) IsEmpty() bool { return !h.hashValid } + +// Bytes returns serialized HashNode. +func (h *HashNode) Bytes() []byte { + return h.getBytes(h) +} // DecodeBinary implements io.Serializable. func (h *HashNode) DecodeBinary(r *io.BinReader) { sz := r.ReadVarUint() switch sz { case 0: - h.valid = false + h.hashValid = false case util.Uint256Size: - h.valid = true + h.hashValid = true r.ReadBytes(h.hash[:]) default: r.Err = fmt.Errorf("invalid hash node size: %d", sz) @@ -54,7 +60,7 @@ func (h *HashNode) DecodeBinary(r *io.BinReader) { // EncodeBinary implements io.Serializable. func (h HashNode) EncodeBinary(w *io.BinWriter) { - if !h.valid { + if !h.hashValid { w.WriteVarUint(0) return } @@ -63,7 +69,7 @@ func (h HashNode) EncodeBinary(w *io.BinWriter) { // MarshalJSON implements json.Marshaler. func (h *HashNode) MarshalJSON() ([]byte, error) { - if !h.valid { + if !h.hashValid { return []byte(`{}`), nil } return []byte(`{"hash":"` + h.hash.StringLE() + `"}`), nil diff --git a/pkg/core/mpt/leaf.go b/pkg/core/mpt/leaf.go index 4ae509a1c..82dd8eef6 100644 --- a/pkg/core/mpt/leaf.go +++ b/pkg/core/mpt/leaf.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -15,9 +14,7 @@ const MaxValueLength = 1024 * 1024 // LeafNode represents MPT's leaf node. type LeafNode struct { - hash util.Uint256 - valid bool - + BaseNode value []byte } @@ -31,13 +28,14 @@ func NewLeafNode(value []byte) *LeafNode { // Type implements Node interface. func (n LeafNode) Type() NodeType { return LeafT } -// Hash implements Node interface. +// Hash implements BaseNode interface. func (n *LeafNode) Hash() util.Uint256 { - if !n.valid { - n.hash = hash.DoubleSha256(toBytes(n)) - n.valid = true - } - return n.hash + return n.getHash(n) +} + +// Bytes implements BaseNode interface. +func (n *LeafNode) Bytes() []byte { + return n.getBytes(n) } // DecodeBinary implements io.Serializable. @@ -47,9 +45,9 @@ func (n *LeafNode) DecodeBinary(r *io.BinReader) { r.Err = fmt.Errorf("leaf node value is too big: %d", sz) return } - n.valid = false n.value = make([]byte, sz) r.ReadBytes(n.value) + n.invalidateCache() } // EncodeBinary implements io.Serializable. diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go index 53a2fdec1..86e675a01 100644 --- a/pkg/core/mpt/node.go +++ b/pkg/core/mpt/node.go @@ -33,8 +33,7 @@ type Node interface { io.Serializable json.Marshaler json.Unmarshaler - Hash() util.Uint256 - Type() NodeType + BaseNodeIface } // EncodeBinary implements io.Serializable. @@ -61,19 +60,6 @@ func (n *NodeObject) DecodeBinary(r *io.BinReader) { n.Node.DecodeBinary(r) } -// encodeNodeWithType encodes node together with it's type. -func encodeNodeWithType(n Node, w *io.BinWriter) { - w.WriteB(byte(n.Type())) - n.EncodeBinary(w) -} - -// toBytes is a helper for serializing node. -func toBytes(n Node) []byte { - buf := io.NewBufBinWriter() - encodeNodeWithType(n, buf.BinWriter) - return buf.Bytes() -} - // UnmarshalJSON implements json.Unmarshaler. func (n *NodeObject) UnmarshalJSON(data []byte) error { var m map[string]json.RawMessage diff --git a/pkg/core/mpt/proof.go b/pkg/core/mpt/proof.go index f785bd9d4..5f8fcdc84 100644 --- a/pkg/core/mpt/proof.go +++ b/pkg/core/mpt/proof.go @@ -25,11 +25,11 @@ func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) switch n := curr.(type) { case *LeafNode: if len(path) == 0 { - *proofs = append(*proofs, toBytes(n)) + *proofs = append(*proofs, copySlice(n.Bytes())) return n, nil } case *BranchNode: - *proofs = append(*proofs, toBytes(n)) + *proofs = append(*proofs, copySlice(n.Bytes())) i, path := splitPath(path) r, err := t.getProof(n.Children[i], path, proofs) if err != nil { @@ -39,7 +39,7 @@ func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) return n, nil case *ExtensionNode: if bytes.HasPrefix(path, n.key) { - *proofs = append(*proofs, toBytes(n)) + *proofs = append(*proofs, copySlice(n.Bytes())) r, err := t.getProof(n.next, path[len(n.key):], proofs) if err != nil { return nil, err diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go index c5093d614..3c38424c0 100644 --- a/pkg/core/mpt/trie.go +++ b/pkg/core/mpt/trie.go @@ -5,7 +5,6 @@ import ( "errors" "github.com/nspcc-dev/neo-go/pkg/core/storage" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" ) @@ -126,7 +125,7 @@ func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, err return nil, err } curr.Children[i] = r - curr.invalidateHash() + curr.invalidateCache() return curr, nil } @@ -139,7 +138,7 @@ func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Nod return nil, err } curr.next = r - curr.invalidateHash() + curr.invalidateCache() return curr, nil } @@ -218,7 +217,7 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { return nil, err } b.Children[i] = r - b.invalidateHash() + b.invalidateCache() var count, index int for i := range b.Children { h, ok := b.Children[i].(*HashNode) @@ -243,7 +242,7 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { } if e, ok := c.(*ExtensionNode); ok { e.key = append([]byte{byte(index)}, e.key...) - e.invalidateHash() + e.invalidateCache() return e, nil } @@ -262,7 +261,7 @@ func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) case *ExtensionNode: n.key = append(n.key, nxt.key...) n.next = nxt.next - n.invalidateHash() + n.invalidateCache() case *HashNode: if nxt.IsEmpty() { return nxt, nil @@ -319,6 +318,9 @@ func (t *Trie) Flush() { } func (t *Trie) flush(node Node) { + if node.IsFlushed() { + return + } switch n := node.(type) { case *BranchNode: for i := range n.Children { @@ -336,9 +338,8 @@ func (t *Trie) putToStore(n Node) { if n.Type() == HashT { panic("can't put hash node in trie") } - bs := toBytes(n) - h := hash.DoubleSha256(bs) - _ = t.Store.Put(makeStorageKey(h.BytesBE()), bs) // put in MemCached returns no errors + _ = t.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) // put in MemCached returns no errors + n.SetFlushed() } func (t *Trie) getFromStore(h util.Uint256) (Node, error) {