mpt: implement reference counting

Also postpone MPT initialization until `storeBlock`
because we need to read-and-check or save info about refcounting
depending on starting height.
This commit is contained in:
Evgenii Stratonikov 2020-10-21 16:58:41 +03:00
parent 1c559634aa
commit 85f927d892
9 changed files with 260 additions and 79 deletions

View file

@ -9,6 +9,10 @@ type (
ProtocolConfiguration struct { ProtocolConfiguration struct {
Magic netmode.Magic `yaml:"Magic"` Magic netmode.Magic `yaml:"Magic"`
MemPoolSize int `yaml:"MemPoolSize"` MemPoolSize int `yaml:"MemPoolSize"`
// KeepOnlyLatestState specifies if MPT should only store latest state.
// If true, DB size will be smaller, but older roots won't be accessible.
// This value should remain the same for the same database.
KeepOnlyLatestState bool `yaml:"KeepOnlyLatestState"`
// MaxTraceableBlocks is the length of the chain accessible to smart contracts. // MaxTraceableBlocks is the length of the chain accessible to smart contracts.
MaxTraceableBlocks uint32 `yaml:"MaxTraceableBlocks"` MaxTraceableBlocks uint32 `yaml:"MaxTraceableBlocks"`
// P2PSigExtensions enables additional signature-related transaction attributes // P2PSigExtensions enables additional signature-related transaction attributes

View file

@ -197,6 +197,9 @@ func (bc *Blockchain) init() error {
if err != nil { if err != nil {
return err return err
} }
if err := bc.dao.InitMPT(0, bc.config.KeepOnlyLatestState); err != nil {
return fmt.Errorf("can't init MPT: %w", err)
}
return bc.storeBlock(genesisBlock, nil) return bc.storeBlock(genesisBlock, nil)
} }
if ver != version { if ver != version {
@ -214,7 +217,7 @@ func (bc *Blockchain) init() error {
} }
bc.blockHeight = bHeight bc.blockHeight = bHeight
bc.persistedHeight = bHeight bc.persistedHeight = bHeight
if err = bc.dao.InitMPT(bHeight); err != nil { if err = bc.dao.InitMPT(bHeight, bc.config.KeepOnlyLatestState); err != nil {
return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err) return fmt.Errorf("can't init MPT at height %d: %w", bHeight, err)
} }

View file

@ -81,7 +81,7 @@ type Simple struct {
// NewSimple creates new simple dao using provided backend store. // NewSimple creates new simple dao using provided backend store.
func NewSimple(backend storage.Store, network netmode.Magic) *Simple { func NewSimple(backend storage.Store, network netmode.Magic) *Simple {
st := storage.NewMemCachedStore(backend) st := storage.NewMemCachedStore(backend)
return &Simple{Store: st, network: network, MPT: mpt.NewTrie(nil, st)} return &Simple{Store: st, network: network}
} }
// GetBatch returns currently accumulated DB changeset. // GetBatch returns currently accumulated DB changeset.
@ -340,16 +340,28 @@ func makeStateRootKey(height uint32) []byte {
} }
// InitMPT initializes MPT at the given height. // InitMPT initializes MPT at the given height.
func (dao *Simple) InitMPT(height uint32) error { func (dao *Simple) InitMPT(height uint32, enableRefCount bool) error {
var gcKey = []byte{byte(storage.DataMPT), 1}
if height == 0 { if height == 0 {
dao.MPT = mpt.NewTrie(nil, dao.Store) dao.MPT = mpt.NewTrie(nil, enableRefCount, dao.Store)
return nil var val byte
if enableRefCount {
val = 1
}
return dao.Store.Put(gcKey, []byte{val})
}
var hasRefCount bool
if v, err := dao.Store.Get(gcKey); err == nil {
hasRefCount = v[0] != 0
}
if hasRefCount != enableRefCount {
return fmt.Errorf("KeepOnlyLatestState setting mismatch: old=%v, new=%v", hasRefCount, enableRefCount)
} }
r, err := dao.GetStateRoot(height) r, err := dao.GetStateRoot(height)
if err != nil { if err != nil {
return err return err
} }
dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), dao.Store) dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), enableRefCount, dao.Store)
return nil return nil
} }

View file

@ -16,8 +16,6 @@ type BaseNode struct {
bytes []byte bytes []byte
hashValid bool hashValid bool
bytesValid bool bytesValid bool
isFlushed bool
} }
// BaseNodeIface abstracts away basic Node functions. // BaseNodeIface abstracts away basic Node functions.
@ -25,8 +23,6 @@ type BaseNodeIface interface {
Hash() util.Uint256 Hash() util.Uint256
Type() NodeType Type() NodeType
Bytes() []byte Bytes() []byte
IsFlushed() bool
SetFlushed()
} }
type flushedNode interface { type flushedNode interface {
@ -38,7 +34,6 @@ func (b *BaseNode) setCache(bs []byte, h util.Uint256) {
b.hash = h b.hash = h
b.bytesValid = true b.bytesValid = true
b.hashValid = true b.hashValid = true
b.isFlushed = true
} }
// getHash returns a hash of this BaseNode. // getHash returns a hash of this BaseNode.
@ -78,17 +73,6 @@ func (b *BaseNode) updateBytes(n Node) {
func (b *BaseNode) invalidateCache() { func (b *BaseNode) invalidateCache() {
b.bytesValid = false b.bytesValid = false
b.hashValid = false b.hashValid = false
b.isFlushed = false
}
// IsFlushed checks for node flush status.
func (b *BaseNode) IsFlushed() bool {
return b.isFlushed
}
// SetFlushed sets 'flushed' flag to true for this node.
func (b *BaseNode) SetFlushed() {
b.isFlushed = true
} }
// encodeNodeWithType encodes node together with it's type. // encodeNodeWithType encodes node together with it's type.
@ -99,6 +83,9 @@ func encodeNodeWithType(n Node, w *io.BinWriter) {
// DecodeNodeWithType decodes node together with it's type. // DecodeNodeWithType decodes node together with it's type.
func DecodeNodeWithType(r *io.BinReader) Node { func DecodeNodeWithType(r *io.BinReader) Node {
if r.Err != nil {
return nil
}
var n Node var n Node
switch typ := NodeType(r.ReadB()); typ { switch typ := NodeType(r.ReadB()); typ {
case BranchT: case BranchT:

View file

@ -92,7 +92,7 @@ func TestNode_Serializable(t *testing.T) {
// https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198 // https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198
func TestJSONSharp(t *testing.T) { func TestJSONSharp(t *testing.T) {
tr := NewTrie(nil, newTestStore()) tr := NewTrie(nil, false, newTestStore())
require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11})) require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11}))
require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22})) require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22}))
require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac})) require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac}))

View file

@ -63,7 +63,7 @@ func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error)
// It also returns value for the key. // It also returns value for the key.
func VerifyProof(rh util.Uint256, key []byte, proofs [][]byte) ([]byte, bool) { func VerifyProof(rh util.Uint256, key []byte, proofs [][]byte) ([]byte, bool) {
path := toNibbles(key) path := toNibbles(key)
tr := NewTrie(NewHashNode(rh), storage.NewMemCachedStore(storage.NewMemoryStore())) tr := NewTrie(NewHashNode(rh), false, storage.NewMemCachedStore(storage.NewMemoryStore()))
for i := range proofs { for i := range proofs {
h := hash.DoubleSha256(proofs[i]) h := hash.DoubleSha256(proofs[i])
// no errors in Put to memory store // no errors in Put to memory store

View file

@ -15,7 +15,7 @@ func newProofTrie(t *testing.T) *Trie {
b.Children[4] = NewHashNode(e.Hash()) b.Children[4] = NewHashNode(e.Hash())
b.Children[5] = e2 b.Children[5] = e2
tr := NewTrie(b, newTestStore()) tr := NewTrie(b, false, newTestStore())
require.NoError(t, tr.Put([]byte{0x12, 0x31}, []byte("value1"))) require.NoError(t, tr.Put([]byte{0x12, 0x31}, []byte("value1")))
require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2"))) require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2")))
tr.putToStore(l) tr.putToStore(l)

View file

@ -2,7 +2,9 @@ package mpt
import ( import (
"bytes" "bytes"
"encoding/binary"
"errors" "errors"
"fmt"
"github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
@ -14,6 +16,14 @@ type Trie struct {
Store *storage.MemCachedStore Store *storage.MemCachedStore
root Node root Node
refcountEnabled bool
refcount map[util.Uint256]*cachedNode
}
type cachedNode struct {
bytes []byte
initial int32
refcount int32
} }
// ErrNotFound is returned when requested trie item is missing. // ErrNotFound is returned when requested trie item is missing.
@ -22,7 +32,7 @@ var ErrNotFound = errors.New("item not found")
// NewTrie returns new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors // NewTrie returns new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors
// so that all storage errors are processed during `store.Persist()` at the caller. // so that all storage errors are processed during `store.Persist()` at the caller.
// This also has the benefit, that every `Put` can be considered an atomic operation. // This also has the benefit, that every `Put` can be considered an atomic operation.
func NewTrie(root Node, store *storage.MemCachedStore) *Trie { func NewTrie(root Node, enableRefCount bool, store *storage.MemCachedStore) *Trie {
if root == nil { if root == nil {
root = new(HashNode) root = new(HashNode)
} }
@ -30,6 +40,9 @@ func NewTrie(root Node, store *storage.MemCachedStore) *Trie {
return &Trie{ return &Trie{
Store: store, Store: store,
root: root, root: root,
refcountEnabled: enableRefCount,
refcount: make(map[util.Uint256]*cachedNode),
} }
} }
@ -107,12 +120,15 @@ func (t *Trie) Put(key, value []byte) error {
func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) { func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) {
v := val.(*LeafNode) v := val.(*LeafNode)
if len(path) == 0 { if len(path) == 0 {
t.removeRef(curr.Hash(), curr.bytes)
t.addRef(val.Hash(), val.Bytes())
return v, nil return v, nil
} }
b := NewBranchNode() b := NewBranchNode()
b.Children[path[0]] = newSubTrie(path[1:], v) b.Children[path[0]] = t.newSubTrie(path[1:], v, true)
b.Children[lastChild] = curr b.Children[lastChild] = curr
t.addRef(b.Hash(), b.bytes)
return b, nil return b, nil
} }
@ -120,18 +136,21 @@ func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error)
// It returns Node if curr needs to be replaced and error if any. // It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) { func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) {
i, path := splitPath(path) i, path := splitPath(path)
t.removeRef(curr.Hash(), curr.bytes)
r, err := t.putIntoNode(curr.Children[i], path, val) r, err := t.putIntoNode(curr.Children[i], path, val)
if err != nil { if err != nil {
return nil, err return nil, err
} }
curr.Children[i] = r curr.Children[i] = r
curr.invalidateCache() curr.invalidateCache()
t.addRef(curr.Hash(), curr.bytes)
return curr, nil return curr, nil
} }
// putIntoExtension puts val to trie if current node is an Extension. // putIntoExtension puts val to trie if current node is an Extension.
// It returns Node if curr needs to be replaced and error if any. // It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) { func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) {
t.removeRef(curr.Hash(), curr.bytes)
if bytes.HasPrefix(path, curr.key) { if bytes.HasPrefix(path, curr.key) {
r, err := t.putIntoNode(curr.next, path[len(curr.key):], val) r, err := t.putIntoNode(curr.next, path[len(curr.key):], val)
if err != nil { if err != nil {
@ -139,6 +158,7 @@ func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Nod
} }
curr.next = r curr.next = r
curr.invalidateCache() curr.invalidateCache()
t.addRef(curr.Hash(), curr.bytes)
return curr, nil return curr, nil
} }
@ -147,16 +167,19 @@ func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Nod
keyTail := curr.key[lp:] keyTail := curr.key[lp:]
pathTail := path[lp:] pathTail := path[lp:]
s1 := newSubTrie(keyTail[1:], curr.next) s1 := t.newSubTrie(keyTail[1:], curr.next, false)
b := NewBranchNode() b := NewBranchNode()
b.Children[keyTail[0]] = s1 b.Children[keyTail[0]] = s1
i, pathTail := splitPath(pathTail) i, pathTail := splitPath(pathTail)
s2 := newSubTrie(pathTail, val) s2 := t.newSubTrie(pathTail, val, true)
b.Children[i] = s2 b.Children[i] = s2
t.addRef(b.Hash(), b.bytes)
if lp > 0 { if lp > 0 {
return NewExtensionNode(copySlice(pref), b), nil e := NewExtensionNode(copySlice(pref), b)
t.addRef(e.Hash(), e.bytes)
return e, nil
} }
return b, nil return b, nil
} }
@ -165,7 +188,8 @@ func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Nod
// It returns Node if curr needs to be replaced and error if any. // It returns Node if curr needs to be replaced and error if any.
func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) {
if curr.IsEmpty() { if curr.IsEmpty() {
return newSubTrie(path, val), nil hn := t.newSubTrie(path, val, true)
return hn, nil
} }
result, err := t.getFromStore(curr.hash) result, err := t.getFromStore(curr.hash)
@ -176,13 +200,20 @@ func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error)
} }
// newSubTrie create new trie containing node at provided path. // newSubTrie create new trie containing node at provided path.
func newSubTrie(path []byte, val Node) Node { func (t *Trie) newSubTrie(path []byte, val Node, newVal bool) Node {
if newVal {
t.addRef(val.Hash(), val.Bytes())
}
if len(path) == 0 { if len(path) == 0 {
return val return val
} }
return NewExtensionNode(path, val) e := NewExtensionNode(path, val)
t.addRef(e.Hash(), e.bytes)
return e
} }
// putIntoNode puts val with provided path inside curr and returns updated node.
// Reference counters are updated for both curr and returned value.
func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) { func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) {
switch n := curr.(type) { switch n := curr.(type) {
case *LeafNode: case *LeafNode:
@ -212,10 +243,13 @@ func (t *Trie) Delete(key []byte) error {
func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) {
i, path := splitPath(path) i, path := splitPath(path)
h := b.Hash()
bs := b.bytes
r, err := t.deleteFromNode(b.Children[i], path) r, err := t.deleteFromNode(b.Children[i], path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.removeRef(h, bs)
b.Children[i] = r b.Children[i] = r
b.invalidateCache() b.invalidateCache()
var count, index int var count, index int
@ -228,6 +262,7 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) {
} }
// count is >= 1 because branch node had at least 2 children before deletion. // count is >= 1 because branch node had at least 2 children before deletion.
if count > 1 { if count > 1 {
t.addRef(b.Hash(), b.bytes)
return b, nil return b, nil
} }
c := b.Children[index] c := b.Children[index]
@ -241,24 +276,32 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) {
} }
} }
if e, ok := c.(*ExtensionNode); ok { if e, ok := c.(*ExtensionNode); ok {
t.removeRef(e.Hash(), e.bytes)
e.key = append([]byte{byte(index)}, e.key...) e.key = append([]byte{byte(index)}, e.key...)
e.invalidateCache() e.invalidateCache()
t.addRef(e.Hash(), e.bytes)
return e, nil return e, nil
} }
return NewExtensionNode([]byte{byte(index)}, c), nil e := NewExtensionNode([]byte{byte(index)}, c)
t.addRef(e.Hash(), e.bytes)
return e, nil
} }
func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) { func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) {
if !bytes.HasPrefix(path, n.key) { if !bytes.HasPrefix(path, n.key) {
return nil, ErrNotFound return nil, ErrNotFound
} }
h := n.Hash()
bs := n.bytes
r, err := t.deleteFromNode(n.next, path[len(n.key):]) r, err := t.deleteFromNode(n.next, path[len(n.key):])
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.removeRef(h, bs)
switch nxt := r.(type) { switch nxt := r.(type) {
case *ExtensionNode: case *ExtensionNode:
t.removeRef(nxt.Hash(), nxt.bytes)
n.key = append(n.key, nxt.key...) n.key = append(n.key, nxt.key...)
n.next = nxt.next n.next = nxt.next
case *HashNode: case *HashNode:
@ -269,13 +312,17 @@ func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error)
n.next = r n.next = r
} }
n.invalidateCache() n.invalidateCache()
t.addRef(n.Hash(), n.bytes)
return n, nil return n, nil
} }
// deleteFromNode removes value with provided path from curr and returns an updated node.
// Reference counters are updated for both curr and returned value.
func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) { func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) {
switch n := curr.(type) { switch n := curr.(type) {
case *LeafNode: case *LeafNode:
if len(path) == 0 { if len(path) == 0 {
t.removeRef(curr.Hash(), curr.Bytes())
return new(HashNode), nil return new(HashNode), nil
} }
return nil, ErrNotFound return nil, ErrNotFound
@ -314,32 +361,88 @@ func makeStorageKey(mptKey []byte) []byte {
// new node to storage. Normally, flush should be called with every StateRoot persist, i.e. // new node to storage. Normally, flush should be called with every StateRoot persist, i.e.
// after every block. // after every block.
func (t *Trie) Flush() { func (t *Trie) Flush() {
t.flush(t.root) for h, node := range t.refcount {
if node.refcount != 0 {
if node.bytes == nil {
panic("item not in trie")
}
if t.refcountEnabled {
node.initial = t.updateRefCount(h)
if node.initial == 0 {
delete(t.refcount, h)
}
} else if node.refcount > 0 {
_ = t.Store.Put(makeStorageKey(h.BytesBE()), node.bytes)
}
node.refcount = 0
} else {
delete(t.refcount, h)
}
}
} }
func (t *Trie) flush(node Node) { // updateRefCount should be called only when refcounting is enabled.
if node.IsFlushed() { func (t *Trie) updateRefCount(h util.Uint256) int32 {
return if !t.refcountEnabled {
panic("`updateRefCount` is called, but GC is disabled")
} }
switch n := node.(type) { var data []byte
case *BranchNode: key := makeStorageKey(h.BytesBE())
for i := range n.Children { node := t.refcount[h]
t.flush(n.Children[i]) cnt := node.initial
if cnt == 0 {
// A newly created item which may be in store.
var err error
data, err = t.Store.Get(key)
if err == nil {
cnt = int32(binary.LittleEndian.Uint32(data[len(data)-4:]))
} }
case *ExtensionNode:
t.flush(n.next)
case *HashNode:
return
} }
t.putToStore(node) if len(data) == 0 {
data = append(node.bytes, 0, 0, 0, 0)
}
cnt += node.refcount
switch {
case cnt < 0:
// BUG: negative reference count
panic(fmt.Sprintf("negative reference count: %s new %d, upd %d", h.StringBE(), cnt, t.refcount[h]))
case cnt == 0:
_ = t.Store.Delete(key)
default:
binary.LittleEndian.PutUint32(data[len(data)-4:], uint32(cnt))
_ = t.Store.Put(key, data)
}
return cnt
} }
func (t *Trie) putToStore(n Node) { func (t *Trie) addRef(h util.Uint256, bs []byte) {
if n.Type() == HashT { node := t.refcount[h]
panic("can't put hash node in trie") if node == nil {
t.refcount[h] = &cachedNode{
refcount: 1,
bytes: bs,
}
return
}
node.refcount++
if node.bytes == nil {
node.bytes = bs
}
}
func (t *Trie) removeRef(h util.Uint256, bs []byte) {
node := t.refcount[h]
if node == nil {
t.refcount[h] = &cachedNode{
refcount: -1,
bytes: bs,
}
return
}
node.refcount--
if node.bytes == nil {
node.bytes = bs
} }
_ = t.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) // put in MemCached returns no errors
n.SetFlushed()
} }
func (t *Trie) getFromStore(h util.Uint256) (Node, error) { func (t *Trie) getFromStore(h util.Uint256) (Node, error) {
@ -354,6 +457,15 @@ func (t *Trie) getFromStore(h util.Uint256) (Node, error) {
if r.Err != nil { if r.Err != nil {
return nil, r.Err return nil, r.Err
} }
if t.refcountEnabled {
data = data[:len(data)-4]
node := t.refcount[h]
if node != nil {
node.bytes = data
node.initial = int32(r.ReadU32LE())
}
}
n.Node.(flushedNode).setCache(data, h) n.Node.(flushedNode).setCache(data, h)
return n.Node, nil return n.Node, nil
} }
@ -366,6 +478,7 @@ func (t *Trie) Collapse(depth int) {
panic("negative depth") panic("negative depth")
} }
t.root = collapse(depth, t.root) t.root = collapse(depth, t.root)
t.refcount = make(map[util.Uint256]*cachedNode)
} }
func collapse(depth int, node Node) Node { func collapse(depth int, node Node) Node {

View file

@ -26,7 +26,7 @@ func newTestTrie(t *testing.T) *Trie {
b.Children[10] = NewExtensionNode([]byte{0x0e}, h) b.Children[10] = NewExtensionNode([]byte{0x0e}, h)
e := NewExtensionNode(toNibbles([]byte{0xAC}), b) e := NewExtensionNode(toNibbles([]byte{0xAC}), b)
tr := NewTrie(e, newTestStore()) tr := NewTrie(e, false, newTestStore())
tr.putToStore(e) tr.putToStore(e)
tr.putToStore(b) tr.putToStore(b)
@ -40,12 +40,50 @@ func newTestTrie(t *testing.T) *Trie {
return tr return tr
} }
func testTrieRefcount(t *testing.T, key1, key2 []byte) {
tr := NewTrie(nil, true, storage.NewMemCachedStore(storage.NewMemoryStore()))
require.NoError(t, tr.Put(key1, []byte{1}))
tr.Flush()
require.NoError(t, tr.Put(key2, []byte{1}))
tr.Flush()
tr.testHas(t, key1, []byte{1})
tr.testHas(t, key2, []byte{1})
// remove first, keep second
require.NoError(t, tr.Delete(key1))
tr.Flush()
tr.testHas(t, key1, nil)
tr.testHas(t, key2, []byte{1})
// no-op
require.NoError(t, tr.Put(key1, []byte{1}))
require.NoError(t, tr.Delete(key1))
tr.Flush()
tr.testHas(t, key1, nil)
tr.testHas(t, key2, []byte{1})
// error on delete, refcount should not be updated
require.Error(t, tr.Delete(key1))
tr.Flush()
tr.testHas(t, key1, nil)
tr.testHas(t, key2, []byte{1})
}
func TestTrie_Refcount(t *testing.T) {
t.Run("Leaf", func(t *testing.T) {
testTrieRefcount(t, []byte{0x11}, []byte{0x12})
})
t.Run("Extension", func(t *testing.T) {
testTrieRefcount(t, []byte{0x10, 11}, []byte{0x11, 12})
})
}
func TestTrie_PutIntoBranchNode(t *testing.T) { func TestTrie_PutIntoBranchNode(t *testing.T) {
b := NewBranchNode() b := NewBranchNode()
l := NewLeafNode([]byte{0x8}) l := NewLeafNode([]byte{0x8})
b.Children[0x7] = NewHashNode(l.Hash()) b.Children[0x7] = NewHashNode(l.Hash())
b.Children[0x8] = NewHashNode(random.Uint256()) b.Children[0x8] = NewHashNode(random.Uint256())
tr := NewTrie(b, newTestStore()) tr := NewTrie(b, false, newTestStore())
// next // next
require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34})) require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34}))
@ -70,7 +108,7 @@ func TestTrie_PutIntoExtensionNode(t *testing.T) {
l := NewLeafNode([]byte{0x11}) l := NewLeafNode([]byte{0x11})
key := []byte{0x12} key := []byte{0x12}
e := NewExtensionNode(toNibbles(key), NewHashNode(l.Hash())) e := NewExtensionNode(toNibbles(key), NewHashNode(l.Hash()))
tr := NewTrie(e, newTestStore()) tr := NewTrie(e, false, newTestStore())
// missing hash // missing hash
require.Error(t, tr.Put(key, []byte{0x42})) require.Error(t, tr.Put(key, []byte{0x42}))
@ -87,7 +125,7 @@ func TestTrie_PutIntoHashNode(t *testing.T) {
e := NewExtensionNode([]byte{0x02}, l) e := NewExtensionNode([]byte{0x02}, l)
b.Children[1] = NewHashNode(e.Hash()) b.Children[1] = NewHashNode(e.Hash())
b.Children[9] = NewHashNode(random.Uint256()) b.Children[9] = NewHashNode(random.Uint256())
tr := NewTrie(b, newTestStore()) tr := NewTrie(b, false, newTestStore())
tr.putToStore(e) tr.putToStore(e)
@ -108,7 +146,7 @@ func TestTrie_PutIntoHashNode(t *testing.T) {
func TestTrie_Put(t *testing.T) { func TestTrie_Put(t *testing.T) {
trExp := newTestTrie(t) trExp := newTestTrie(t)
trAct := NewTrie(nil, newTestStore()) trAct := NewTrie(nil, false, newTestStore())
require.NoError(t, trAct.Put([]byte{0xAC, 0x01}, []byte{0xAB, 0xCD})) require.NoError(t, trAct.Put([]byte{0xAC, 0x01}, []byte{0xAB, 0xCD}))
require.NoError(t, trAct.Put([]byte{0xAC, 0x99}, []byte{0x22, 0x22})) require.NoError(t, trAct.Put([]byte{0xAC, 0x99}, []byte{0x22, 0x22}))
require.NoError(t, trAct.Put([]byte{0xAC, 0xAE}, []byte("hello"))) require.NoError(t, trAct.Put([]byte{0xAC, 0xAE}, []byte("hello")))
@ -119,7 +157,7 @@ func TestTrie_Put(t *testing.T) {
} }
func TestTrie_PutInvalid(t *testing.T) { func TestTrie_PutInvalid(t *testing.T) {
tr := NewTrie(nil, newTestStore()) tr := NewTrie(nil, false, newTestStore())
key, value := []byte("key"), []byte("value") key, value := []byte("key"), []byte("value")
// big key // big key
@ -134,7 +172,7 @@ func TestTrie_PutInvalid(t *testing.T) {
} }
func TestTrie_BigPut(t *testing.T) { func TestTrie_BigPut(t *testing.T) {
tr := NewTrie(nil, newTestStore()) tr := NewTrie(nil, false, newTestStore())
items := []struct{ k, v string }{ items := []struct{ k, v string }{
{"item with long key", "value1"}, {"item with long key", "value1"},
{"item with matching prefix", "value2"}, {"item with matching prefix", "value2"},
@ -164,6 +202,21 @@ func TestTrie_BigPut(t *testing.T) {
}) })
} }
func (tr *Trie) putToStore(n Node) {
if n.Type() == HashT {
panic("can't put hash node in trie")
}
if tr.refcountEnabled {
tr.refcount[n.Hash()] = &cachedNode{
bytes: n.Bytes(),
refcount: 1,
}
tr.updateRefCount(n.Hash())
} else {
_ = tr.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes())
}
}
func (tr *Trie) testHas(t *testing.T, key, value []byte) { func (tr *Trie) testHas(t *testing.T, key, value []byte) {
v, err := tr.Get(key) v, err := tr.Get(key)
if value == nil { if value == nil {
@ -208,7 +261,7 @@ func TestTrie_Get(t *testing.T) {
}) })
t.Run("UnfoldRoot", func(t *testing.T) { t.Run("UnfoldRoot", func(t *testing.T) {
tr := newTestTrie(t) tr := newTestTrie(t)
single := NewTrie(NewHashNode(tr.root.Hash()), tr.Store) single := NewTrie(NewHashNode(tr.root.Hash()), false, tr.Store)
single.testHas(t, []byte{0xAC}, nil) single.testHas(t, []byte{0xAC}, nil)
single.testHas(t, []byte{0xAC, 0x01}, []byte{0xAB, 0xCD}) single.testHas(t, []byte{0xAC, 0x01}, []byte{0xAB, 0xCD})
single.testHas(t, []byte{0xAC, 0x99}, []byte{0x22, 0x22}) single.testHas(t, []byte{0xAC, 0x99}, []byte{0x22, 0x22})
@ -223,13 +276,13 @@ func TestTrie_Flush(t *testing.T) {
"key2": []byte("value2"), "key2": []byte("value2"),
} }
tr := NewTrie(nil, newTestStore()) tr := NewTrie(nil, false, newTestStore())
for k, v := range pairs { for k, v := range pairs {
require.NoError(t, tr.Put([]byte(k), v)) require.NoError(t, tr.Put([]byte(k), v))
} }
tr.Flush() tr.Flush()
tr = NewTrie(NewHashNode(tr.StateRoot()), tr.Store) tr = NewTrie(NewHashNode(tr.StateRoot()), false, tr.Store)
for k, v := range pairs { for k, v := range pairs {
actual, err := tr.Get([]byte(k)) actual, err := tr.Get([]byte(k))
require.NoError(t, err) require.NoError(t, err)
@ -238,10 +291,19 @@ func TestTrie_Flush(t *testing.T) {
} }
func TestTrie_Delete(t *testing.T) { func TestTrie_Delete(t *testing.T) {
t.Run("No GC", func(t *testing.T) {
testTrieDelete(t, false)
})
t.Run("With GC", func(t *testing.T) {
testTrieDelete(t, true)
})
}
func testTrieDelete(t *testing.T, enableGC bool) {
t.Run("Hash", func(t *testing.T) { t.Run("Hash", func(t *testing.T) {
t.Run("FromStore", func(t *testing.T) { t.Run("FromStore", func(t *testing.T) {
l := NewLeafNode([]byte{0x12}) l := NewLeafNode([]byte{0x12})
tr := NewTrie(NewHashNode(l.Hash()), newTestStore()) tr := NewTrie(NewHashNode(l.Hash()), enableGC, newTestStore())
t.Run("NotInStore", func(t *testing.T) { t.Run("NotInStore", func(t *testing.T) {
require.Error(t, tr.Delete([]byte{})) require.Error(t, tr.Delete([]byte{}))
}) })
@ -253,14 +315,14 @@ func TestTrie_Delete(t *testing.T) {
}) })
t.Run("Empty", func(t *testing.T) { t.Run("Empty", func(t *testing.T) {
tr := NewTrie(nil, newTestStore()) tr := NewTrie(nil, enableGC, newTestStore())
require.Error(t, tr.Delete([]byte{})) require.Error(t, tr.Delete([]byte{}))
}) })
}) })
t.Run("Leaf", func(t *testing.T) { t.Run("Leaf", func(t *testing.T) {
l := NewLeafNode([]byte{0x12, 0x34}) l := NewLeafNode([]byte{0x12, 0x34})
tr := NewTrie(l, newTestStore()) tr := NewTrie(l, enableGC, newTestStore())
t.Run("NonExistentKey", func(t *testing.T) { t.Run("NonExistentKey", func(t *testing.T) {
require.Error(t, tr.Delete([]byte{0x12})) require.Error(t, tr.Delete([]byte{0x12}))
tr.testHas(t, []byte{}, []byte{0x12, 0x34}) tr.testHas(t, []byte{}, []byte{0x12, 0x34})
@ -273,7 +335,7 @@ func TestTrie_Delete(t *testing.T) {
t.Run("SingleKey", func(t *testing.T) { t.Run("SingleKey", func(t *testing.T) {
l := NewLeafNode([]byte{0x12, 0x34}) l := NewLeafNode([]byte{0x12, 0x34})
e := NewExtensionNode([]byte{0x0A, 0x0B}, l) e := NewExtensionNode([]byte{0x0A, 0x0B}, l)
tr := NewTrie(e, newTestStore()) tr := NewTrie(e, enableGC, newTestStore())
t.Run("NonExistentKey", func(t *testing.T) { t.Run("NonExistentKey", func(t *testing.T) {
require.Error(t, tr.Delete([]byte{})) require.Error(t, tr.Delete([]byte{}))
@ -289,7 +351,7 @@ func TestTrie_Delete(t *testing.T) {
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x12, 0x34})) b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x12, 0x34}))
b.Children[6] = NewExtensionNode([]byte{0x07}, NewLeafNode([]byte{0x56, 0x78})) b.Children[6] = NewExtensionNode([]byte{0x07}, NewLeafNode([]byte{0x56, 0x78}))
e := NewExtensionNode([]byte{0x01, 0x02}, b) e := NewExtensionNode([]byte{0x01, 0x02}, b)
tr := NewTrie(e, newTestStore()) tr := NewTrie(e, enableGC, newTestStore())
h := e.Hash() h := e.Hash()
require.NoError(t, tr.Delete([]byte{0x12, 0x01})) require.NoError(t, tr.Delete([]byte{0x12, 0x01}))
@ -308,7 +370,7 @@ func TestTrie_Delete(t *testing.T) {
b.Children[lastChild] = NewLeafNode([]byte{0x12}) b.Children[lastChild] = NewLeafNode([]byte{0x12})
b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x34})) b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x34}))
b.Children[1] = NewExtensionNode([]byte{0x06}, NewLeafNode([]byte{0x56})) b.Children[1] = NewExtensionNode([]byte{0x06}, NewLeafNode([]byte{0x56}))
tr := NewTrie(b, newTestStore()) tr := NewTrie(b, enableGC, newTestStore())
require.NoError(t, tr.Delete([]byte{0x16})) require.NoError(t, tr.Delete([]byte{0x16}))
tr.testHas(t, []byte{}, []byte{0x12}) tr.testHas(t, []byte{}, []byte{0x12})
tr.testHas(t, []byte{0x01}, []byte{0x34}) tr.testHas(t, []byte{0x01}, []byte{0x34})
@ -321,7 +383,7 @@ func TestTrie_Delete(t *testing.T) {
l := NewLeafNode([]byte{0x34}) l := NewLeafNode([]byte{0x34})
e := NewExtensionNode([]byte{0x06}, l) e := NewExtensionNode([]byte{0x06}, l)
b.Children[5] = NewHashNode(e.Hash()) b.Children[5] = NewHashNode(e.Hash())
tr := NewTrie(b, newTestStore()) tr := NewTrie(b, enableGC, newTestStore())
tr.putToStore(l) tr.putToStore(l)
tr.putToStore(e) tr.putToStore(e)
return tr return tr
@ -344,7 +406,7 @@ func TestTrie_Delete(t *testing.T) {
b := NewBranchNode() b := NewBranchNode()
b.Children[lastChild] = NewLeafNode([]byte{0x12}) b.Children[lastChild] = NewLeafNode([]byte{0x12})
b.Children[5] = c b.Children[5] = c
tr := NewTrie(b, newTestStore()) tr := NewTrie(b, enableGC, newTestStore())
require.NoError(t, tr.Delete([]byte{})) require.NoError(t, tr.Delete([]byte{}))
tr.testHas(t, []byte{}, nil) tr.testHas(t, []byte{}, nil)
@ -396,7 +458,7 @@ func TestTrie_Collapse(t *testing.T) {
b.Children[0] = e b.Children[0] = e
hb := b.Hash() hb := b.Hash()
tr := NewTrie(b, newTestStore()) tr := NewTrie(b, false, newTestStore())
tr.Collapse(1) tr.Collapse(1)
newb, ok := tr.root.(*BranchNode) newb, ok := tr.root.(*BranchNode)
@ -410,7 +472,7 @@ func TestTrie_Collapse(t *testing.T) {
hl := l.Hash() hl := l.Hash()
e := NewExtensionNode([]byte{0x01}, l) e := NewExtensionNode([]byte{0x01}, l)
h := e.Hash() h := e.Hash()
tr := NewTrie(e, newTestStore()) tr := NewTrie(e, false, newTestStore())
tr.Collapse(1) tr.Collapse(1)
newe, ok := tr.root.(*ExtensionNode) newe, ok := tr.root.(*ExtensionNode)
@ -421,13 +483,13 @@ func TestTrie_Collapse(t *testing.T) {
}) })
t.Run("Leaf", func(t *testing.T) { t.Run("Leaf", func(t *testing.T) {
l := NewLeafNode([]byte("value")) l := NewLeafNode([]byte("value"))
tr := NewTrie(l, newTestStore()) tr := NewTrie(l, false, newTestStore())
tr.Collapse(10) tr.Collapse(10)
require.Equal(t, NewLeafNode([]byte("value")), tr.root) require.Equal(t, NewLeafNode([]byte("value")), tr.root)
}) })
t.Run("Hash", func(t *testing.T) { t.Run("Hash", func(t *testing.T) {
t.Run("Empty", func(t *testing.T) { t.Run("Empty", func(t *testing.T) {
tr := NewTrie(new(HashNode), newTestStore()) tr := NewTrie(new(HashNode), false, newTestStore())
require.NotPanics(t, func() { tr.Collapse(1) }) require.NotPanics(t, func() { tr.Collapse(1) })
hn, ok := tr.root.(*HashNode) hn, ok := tr.root.(*HashNode)
require.True(t, ok) require.True(t, ok)
@ -436,7 +498,7 @@ func TestTrie_Collapse(t *testing.T) {
h := random.Uint256() h := random.Uint256()
hn := NewHashNode(h) hn := NewHashNode(h)
tr := NewTrie(hn, newTestStore()) tr := NewTrie(hn, false, newTestStore())
tr.Collapse(10) tr.Collapse(10)
newRoot, ok := tr.root.(*HashNode) newRoot, ok := tr.root.(*HashNode)