diff --git a/pkg/config/protocol_config.go b/pkg/config/protocol_config.go index e0d9bf0cd..b094a1958 100644 --- a/pkg/config/protocol_config.go +++ b/pkg/config/protocol_config.go @@ -22,6 +22,10 @@ type ( AddressVersion byte `yaml:"AddressVersion"` // EnableStateRoot specifies if exchange of state roots should be enabled. EnableStateRoot bool `yaml:"EnableStateRoot"` + // KeepOnlyLatestState specifies if MPT should only store latest state. + // If true, DB size will be smaller, but older roots won't be accessible. + // This value should remain the same for the same database. + KeepOnlyLatestState bool `yaml:"KeepOnlyLatestState"` // FeePerExtraByte sets the expected per-byte fee for // transactions exceeding the MaxFreeTransactionSize. FeePerExtraByte float64 `yaml:"FeePerExtraByte"` diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 6e7d4e456..c5d88068c 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -210,7 +210,7 @@ func (bc *Blockchain) init() error { return err } if bc.config.EnableStateRoot { - if err := bc.dao.InitMPT(0); err != nil { + if err := bc.dao.InitMPT(0, bc.config.KeepOnlyLatestState); err != nil { return err } } @@ -232,7 +232,7 @@ func (bc *Blockchain) init() error { bc.blockHeight = bHeight bc.persistedHeight = bHeight if bc.config.EnableStateRoot { - if err = bc.dao.InitMPT(bHeight); err != nil { + if err = bc.dao.InitMPT(bHeight, bc.config.KeepOnlyLatestState); err != nil { return errors.Wrapf(err, "can't init MPT at height %d", bHeight) } } @@ -563,7 +563,7 @@ func (bc *Blockchain) GetStateProof(root util.Uint256, key []byte) ([][]byte, er if !bc.config.EnableStateRoot { return nil, errors.New("state root feature is not enabled") } - tr := mpt.NewTrie(mpt.NewHashNode(root), storage.NewMemCachedStore(bc.dao.Store)) + tr := mpt.NewTrie(mpt.NewHashNode(root), bc.config.KeepOnlyLatestState, storage.NewMemCachedStore(bc.dao.Store)) return tr.GetProof(key) } diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index 15b59c682..360dec0ea 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -510,16 +510,28 @@ func makeStateRootKey(height uint32) []byte { } // InitMPT initializes MPT at the given height. -func (dao *Simple) InitMPT(height uint32) error { +func (dao *Simple) InitMPT(height uint32, enableRefCount bool) error { + var gcKey = []byte{byte(storage.DataMPT), 1} if height == 0 { - dao.MPT = mpt.NewTrie(nil, dao.Store) - return nil + dao.MPT = mpt.NewTrie(nil, enableRefCount, dao.Store) + var val byte + if enableRefCount { + val = 1 + } + return dao.Store.Put(gcKey, []byte{val}) + } + var hasRefCount bool + if v, err := dao.Store.Get(gcKey); err == nil { + hasRefCount = v[0] != 0 + } + if hasRefCount != enableRefCount { + return fmt.Errorf("KeepOnlyLatestState setting mismatch: old=%v, new=%v", hasRefCount, enableRefCount) } r, err := dao.GetStateRoot(height) if err != nil { return err } - dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), dao.Store) + dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), enableRefCount, dao.Store) return nil } diff --git a/pkg/core/mpt/base.go b/pkg/core/mpt/base.go index 9f10cc333..1c4ebd8a6 100644 --- a/pkg/core/mpt/base.go +++ b/pkg/core/mpt/base.go @@ -1,6 +1,8 @@ package mpt import ( + "fmt" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -14,8 +16,6 @@ type BaseNode struct { bytes []byte hashValid bool bytesValid bool - - isFlushed bool } // BaseNodeIface abstracts away basic Node functions. @@ -23,8 +23,17 @@ type BaseNodeIface interface { Hash() util.Uint256 Type() NodeType Bytes() []byte - IsFlushed() bool - SetFlushed() +} + +type flushedNode interface { + setCache([]byte, util.Uint256) +} + +func (b *BaseNode) setCache(bs []byte, h util.Uint256) { + b.bytes = bs + b.hash = h + b.bytesValid = true + b.hashValid = true } // getHash returns a hash of this BaseNode. @@ -64,17 +73,6 @@ func (b *BaseNode) updateBytes(n Node) { func (b *BaseNode) invalidateCache() { b.bytesValid = false b.hashValid = false - b.isFlushed = false -} - -// IsFlushed checks for node flush status. -func (b *BaseNode) IsFlushed() bool { - return b.isFlushed -} - -// SetFlushed sets 'flushed' flag to true for this node. -func (b *BaseNode) SetFlushed() { - b.isFlushed = true } // encodeNodeWithType encodes node together with it's type. @@ -82,3 +80,26 @@ func encodeNodeWithType(n Node, w *io.BinWriter) { w.WriteB(byte(n.Type())) n.EncodeBinary(w) } + +// DecodeNodeWithType decodes node together with it's type. +func DecodeNodeWithType(r *io.BinReader) Node { + if r.Err != nil { + return nil + } + var n Node + switch typ := NodeType(r.ReadB()); typ { + case BranchT: + n = new(BranchNode) + case ExtensionT: + n = new(ExtensionNode) + case HashT: + n = new(HashNode) + case LeafT: + n = new(LeafNode) + default: + r.Err = fmt.Errorf("invalid node type: %x", typ) + return nil + } + n.DecodeBinary(r) + return n +} diff --git a/pkg/core/mpt/node.go b/pkg/core/mpt/node.go index 86e675a01..04b085948 100644 --- a/pkg/core/mpt/node.go +++ b/pkg/core/mpt/node.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "encoding/json" "errors" - "fmt" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -43,21 +42,7 @@ func (n NodeObject) EncodeBinary(w *io.BinWriter) { // DecodeBinary implements io.Serializable. func (n *NodeObject) DecodeBinary(r *io.BinReader) { - typ := NodeType(r.ReadB()) - switch typ { - case BranchT: - n.Node = new(BranchNode) - case ExtensionT: - n.Node = new(ExtensionNode) - case HashT: - n.Node = new(HashNode) - case LeafT: - n.Node = new(LeafNode) - default: - r.Err = fmt.Errorf("invalid node type: %x", typ) - return - } - n.Node.DecodeBinary(r) + n.Node = DecodeNodeWithType(r) } // UnmarshalJSON implements json.Unmarshaler. diff --git a/pkg/core/mpt/node_test.go b/pkg/core/mpt/node_test.go index e3aab54d6..dc2b66ec6 100644 --- a/pkg/core/mpt/node_test.go +++ b/pkg/core/mpt/node_test.go @@ -92,7 +92,7 @@ func TestNode_Serializable(t *testing.T) { // https://github.com/neo-project/neo/blob/neox-2.x/neo.UnitTests/UT_MPTTrie.cs#L198 func TestJSONSharp(t *testing.T) { - tr := NewTrie(nil, newTestStore()) + tr := NewTrie(nil, false, newTestStore()) require.NoError(t, tr.Put([]byte{0xac, 0x11}, []byte{0xac, 0x11})) require.NoError(t, tr.Put([]byte{0xac, 0x22}, []byte{0xac, 0x22})) require.NoError(t, tr.Put([]byte{0xac}, []byte{0xac})) diff --git a/pkg/core/mpt/proof.go b/pkg/core/mpt/proof.go index 5f8fcdc84..db18ead59 100644 --- a/pkg/core/mpt/proof.go +++ b/pkg/core/mpt/proof.go @@ -63,7 +63,7 @@ func (t *Trie) getProof(curr Node, path []byte, proofs *[][]byte) (Node, error) // It also returns value for the key. func VerifyProof(rh util.Uint256, key []byte, proofs [][]byte) ([]byte, bool) { path := toNibbles(key) - tr := NewTrie(NewHashNode(rh), storage.NewMemCachedStore(storage.NewMemoryStore())) + tr := NewTrie(NewHashNode(rh), false, storage.NewMemCachedStore(storage.NewMemoryStore())) for i := range proofs { h := hash.DoubleSha256(proofs[i]) // no errors in Put to memory store diff --git a/pkg/core/mpt/proof_test.go b/pkg/core/mpt/proof_test.go index 17301af15..75a76408f 100644 --- a/pkg/core/mpt/proof_test.go +++ b/pkg/core/mpt/proof_test.go @@ -15,7 +15,7 @@ func newProofTrie(t *testing.T) *Trie { b.Children[4] = NewHashNode(e.Hash()) b.Children[5] = e2 - tr := NewTrie(b, newTestStore()) + tr := NewTrie(b, false, newTestStore()) require.NoError(t, tr.Put([]byte{0x12, 0x31}, []byte("value1"))) require.NoError(t, tr.Put([]byte{0x12, 0x32}, []byte("value2"))) tr.putToStore(l) diff --git a/pkg/core/mpt/trie.go b/pkg/core/mpt/trie.go index 08d128d88..751b1b2f3 100644 --- a/pkg/core/mpt/trie.go +++ b/pkg/core/mpt/trie.go @@ -2,7 +2,9 @@ package mpt import ( "bytes" + "encoding/binary" "errors" + "fmt" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/io" @@ -13,7 +15,15 @@ import ( type Trie struct { Store *storage.MemCachedStore - root Node + root Node + refcountEnabled bool + refcount map[util.Uint256]*cachedNode +} + +type cachedNode struct { + bytes []byte + initial int32 + refcount int32 } // ErrNotFound is returned when requested trie item is missing. @@ -22,7 +32,7 @@ var ErrNotFound = errors.New("item not found") // NewTrie returns new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors // so that all storage errors are processed during `store.Persist()` at the caller. // This also has the benefit, that every `Put` can be considered an atomic operation. -func NewTrie(root Node, store *storage.MemCachedStore) *Trie { +func NewTrie(root Node, enableRefCount bool, store *storage.MemCachedStore) *Trie { if root == nil { root = new(HashNode) } @@ -30,6 +40,9 @@ func NewTrie(root Node, store *storage.MemCachedStore) *Trie { return &Trie{ Store: store, root: root, + + refcountEnabled: enableRefCount, + refcount: make(map[util.Uint256]*cachedNode), } } @@ -107,12 +120,15 @@ func (t *Trie) Put(key, value []byte) error { func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) { v := val.(*LeafNode) if len(path) == 0 { + t.removeRef(curr.Hash(), curr.bytes) + t.addRef(val.Hash(), val.Bytes()) return v, nil } b := NewBranchNode() - b.Children[path[0]] = newSubTrie(path[1:], v) + b.Children[path[0]] = t.newSubTrie(path[1:], v, true) b.Children[lastChild] = curr + t.addRef(b.Hash(), b.bytes) return b, nil } @@ -120,18 +136,21 @@ func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) // It returns Node if curr needs to be replaced and error if any. func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) { i, path := splitPath(path) + t.removeRef(curr.Hash(), curr.bytes) r, err := t.putIntoNode(curr.Children[i], path, val) if err != nil { return nil, err } curr.Children[i] = r curr.invalidateCache() + t.addRef(curr.Hash(), curr.bytes) return curr, nil } // putIntoExtension puts val to trie if current node is an Extension. // It returns Node if curr needs to be replaced and error if any. func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) { + t.removeRef(curr.Hash(), curr.bytes) if bytes.HasPrefix(path, curr.key) { r, err := t.putIntoNode(curr.next, path[len(curr.key):], val) if err != nil { @@ -139,6 +158,7 @@ func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Nod } curr.next = r curr.invalidateCache() + t.addRef(curr.Hash(), curr.bytes) return curr, nil } @@ -147,16 +167,19 @@ func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Nod keyTail := curr.key[lp:] pathTail := path[lp:] - s1 := newSubTrie(keyTail[1:], curr.next) + s1 := t.newSubTrie(keyTail[1:], curr.next, false) b := NewBranchNode() b.Children[keyTail[0]] = s1 i, pathTail := splitPath(pathTail) - s2 := newSubTrie(pathTail, val) + s2 := t.newSubTrie(pathTail, val, true) b.Children[i] = s2 + t.addRef(b.Hash(), b.bytes) if lp > 0 { - return NewExtensionNode(copySlice(pref), b), nil + e := NewExtensionNode(copySlice(pref), b) + t.addRef(e.Hash(), e.bytes) + return e, nil } return b, nil } @@ -165,7 +188,8 @@ func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Nod // It returns Node if curr needs to be replaced and error if any. func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { if curr.IsEmpty() { - return newSubTrie(path, val), nil + hn := t.newSubTrie(path, val, true) + return hn, nil } result, err := t.getFromStore(curr.hash) @@ -176,13 +200,20 @@ func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) } // newSubTrie create new trie containing node at provided path. -func newSubTrie(path []byte, val Node) Node { +func (t *Trie) newSubTrie(path []byte, val Node, newVal bool) Node { + if newVal { + t.addRef(val.Hash(), val.Bytes()) + } if len(path) == 0 { return val } - return NewExtensionNode(path, val) + e := NewExtensionNode(path, val) + t.addRef(e.Hash(), e.bytes) + return e } +// putIntoNode puts val with provided path inside curr and returns updated node. +// Reference counters are updated for both curr and returned value. func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) { switch n := curr.(type) { case *LeafNode: @@ -212,10 +243,13 @@ func (t *Trie) Delete(key []byte) error { func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { i, path := splitPath(path) + h := b.Hash() + bs := b.bytes r, err := t.deleteFromNode(b.Children[i], path) if err != nil { return nil, err } + t.removeRef(h, bs) b.Children[i] = r b.invalidateCache() var count, index int @@ -228,6 +262,7 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { } // count is >= 1 because branch node had at least 2 children before deletion. if count > 1 { + t.addRef(b.Hash(), b.bytes) return b, nil } c := b.Children[index] @@ -241,24 +276,32 @@ func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { } } if e, ok := c.(*ExtensionNode); ok { + t.removeRef(e.Hash(), e.bytes) e.key = append([]byte{byte(index)}, e.key...) e.invalidateCache() + t.addRef(e.Hash(), e.bytes) return e, nil } - return NewExtensionNode([]byte{byte(index)}, c), nil + e := NewExtensionNode([]byte{byte(index)}, c) + t.addRef(e.Hash(), e.bytes) + return e, nil } func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) { if !bytes.HasPrefix(path, n.key) { return nil, ErrNotFound } + h := n.Hash() + bs := n.bytes r, err := t.deleteFromNode(n.next, path[len(n.key):]) if err != nil { return nil, err } + t.removeRef(h, bs) switch nxt := r.(type) { case *ExtensionNode: + t.removeRef(nxt.Hash(), nxt.bytes) n.key = append(n.key, nxt.key...) n.next = nxt.next case *HashNode: @@ -269,13 +312,17 @@ func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) n.next = r } n.invalidateCache() + t.addRef(n.Hash(), n.bytes) return n, nil } +// deleteFromNode removes value with provided path from curr and returns an updated node. +// Reference counters are updated for both curr and returned value. func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) { switch n := curr.(type) { case *LeafNode: if len(path) == 0 { + t.removeRef(curr.Hash(), curr.Bytes()) return new(HashNode), nil } return nil, ErrNotFound @@ -314,32 +361,88 @@ func makeStorageKey(mptKey []byte) []byte { // new node to storage. Normally, flush should be called with every StateRoot persist, i.e. // after every block. func (t *Trie) Flush() { - t.flush(t.root) -} - -func (t *Trie) flush(node Node) { - if node.IsFlushed() { - return - } - switch n := node.(type) { - case *BranchNode: - for i := range n.Children { - t.flush(n.Children[i]) + for h, node := range t.refcount { + if node.refcount != 0 { + if node.bytes == nil { + panic("item not in trie") + } + if t.refcountEnabled { + node.initial = t.updateRefCount(h) + if node.initial == 0 { + delete(t.refcount, h) + } + } else if node.refcount > 0 { + _ = t.Store.Put(makeStorageKey(h.BytesBE()), node.bytes) + } + node.refcount = 0 + } else { + delete(t.refcount, h) } - case *ExtensionNode: - t.flush(n.next) - case *HashNode: - return } - t.putToStore(node) } -func (t *Trie) putToStore(n Node) { - if n.Type() == HashT { - panic("can't put hash node in trie") +// updateRefCount should be called only when refcounting is enabled. +func (t *Trie) updateRefCount(h util.Uint256) int32 { + if !t.refcountEnabled { + panic("`updateRefCount` is called, but GC is disabled") + } + var data []byte + key := makeStorageKey(h.BytesBE()) + node := t.refcount[h] + cnt := node.initial + if cnt == 0 { + // A newly created item which may be in store. + var err error + data, err = t.Store.Get(key) + if err == nil { + cnt = int32(binary.LittleEndian.Uint32(data[len(data)-4:])) + } + } + if len(data) == 0 { + data = append(node.bytes, 0, 0, 0, 0) + } + cnt += node.refcount + switch { + case cnt < 0: + // BUG: negative reference count + panic(fmt.Sprintf("negative reference count: %s new %d, upd %d", h.StringBE(), cnt, t.refcount[h])) + case cnt == 0: + _ = t.Store.Delete(key) + default: + binary.LittleEndian.PutUint32(data[len(data)-4:], uint32(cnt)) + _ = t.Store.Put(key, data) + } + return cnt +} + +func (t *Trie) addRef(h util.Uint256, bs []byte) { + node := t.refcount[h] + if node == nil { + t.refcount[h] = &cachedNode{ + refcount: 1, + bytes: bs, + } + return + } + node.refcount++ + if node.bytes == nil { + node.bytes = bs + } +} + +func (t *Trie) removeRef(h util.Uint256, bs []byte) { + node := t.refcount[h] + if node == nil { + t.refcount[h] = &cachedNode{ + refcount: -1, + bytes: bs, + } + return + } + node.refcount-- + if node.bytes == nil { + node.bytes = bs } - _ = t.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) // put in MemCached returns no errors - n.SetFlushed() } func (t *Trie) getFromStore(h util.Uint256) (Node, error) { @@ -354,6 +457,16 @@ func (t *Trie) getFromStore(h util.Uint256) (Node, error) { if r.Err != nil { return nil, r.Err } + + if t.refcountEnabled { + data = data[:len(data)-4] + node := t.refcount[h] + if node != nil { + node.bytes = data + node.initial = int32(r.ReadU32LE()) + } + } + n.Node.(flushedNode).setCache(data, h) return n.Node, nil } @@ -365,6 +478,7 @@ func (t *Trie) Collapse(depth int) { panic("negative depth") } t.root = collapse(depth, t.root) + t.refcount = make(map[util.Uint256]*cachedNode) } func collapse(depth int, node Node) Node { diff --git a/pkg/core/mpt/trie_test.go b/pkg/core/mpt/trie_test.go index d06e08168..e8fc532e5 100644 --- a/pkg/core/mpt/trie_test.go +++ b/pkg/core/mpt/trie_test.go @@ -26,7 +26,7 @@ func newTestTrie(t *testing.T) *Trie { b.Children[10] = NewExtensionNode([]byte{0x0e}, h) e := NewExtensionNode(toNibbles([]byte{0xAC}), b) - tr := NewTrie(e, newTestStore()) + tr := NewTrie(e, false, newTestStore()) tr.putToStore(e) tr.putToStore(b) @@ -40,12 +40,50 @@ func newTestTrie(t *testing.T) *Trie { return tr } +func testTrieRefcount(t *testing.T, key1, key2 []byte) { + tr := NewTrie(nil, true, storage.NewMemCachedStore(storage.NewMemoryStore())) + require.NoError(t, tr.Put(key1, []byte{1})) + tr.Flush() + require.NoError(t, tr.Put(key2, []byte{1})) + tr.Flush() + tr.testHas(t, key1, []byte{1}) + tr.testHas(t, key2, []byte{1}) + + // remove first, keep second + require.NoError(t, tr.Delete(key1)) + tr.Flush() + tr.testHas(t, key1, nil) + tr.testHas(t, key2, []byte{1}) + + // no-op + require.NoError(t, tr.Put(key1, []byte{1})) + require.NoError(t, tr.Delete(key1)) + tr.Flush() + tr.testHas(t, key1, nil) + tr.testHas(t, key2, []byte{1}) + + // error on delete, refcount should not be updated + require.Error(t, tr.Delete(key1)) + tr.Flush() + tr.testHas(t, key1, nil) + tr.testHas(t, key2, []byte{1}) +} + +func TestTrie_Refcount(t *testing.T) { + t.Run("Leaf", func(t *testing.T) { + testTrieRefcount(t, []byte{0x11}, []byte{0x12}) + }) + t.Run("Extension", func(t *testing.T) { + testTrieRefcount(t, []byte{0x10, 11}, []byte{0x11, 12}) + }) +} + func TestTrie_PutIntoBranchNode(t *testing.T) { b := NewBranchNode() l := NewLeafNode([]byte{0x8}) b.Children[0x7] = NewHashNode(l.Hash()) b.Children[0x8] = NewHashNode(random.Uint256()) - tr := NewTrie(b, newTestStore()) + tr := NewTrie(b, false, newTestStore()) // next require.NoError(t, tr.Put([]byte{}, []byte{0x12, 0x34})) @@ -70,7 +108,7 @@ func TestTrie_PutIntoExtensionNode(t *testing.T) { l := NewLeafNode([]byte{0x11}) key := []byte{0x12} e := NewExtensionNode(toNibbles(key), NewHashNode(l.Hash())) - tr := NewTrie(e, newTestStore()) + tr := NewTrie(e, false, newTestStore()) // missing hash require.Error(t, tr.Put(key, []byte{0x42})) @@ -87,7 +125,7 @@ func TestTrie_PutIntoHashNode(t *testing.T) { e := NewExtensionNode([]byte{0x02}, l) b.Children[1] = NewHashNode(e.Hash()) b.Children[9] = NewHashNode(random.Uint256()) - tr := NewTrie(b, newTestStore()) + tr := NewTrie(b, false, newTestStore()) tr.putToStore(e) @@ -108,7 +146,7 @@ func TestTrie_PutIntoHashNode(t *testing.T) { func TestTrie_Put(t *testing.T) { trExp := newTestTrie(t) - trAct := NewTrie(nil, newTestStore()) + trAct := NewTrie(nil, false, newTestStore()) require.NoError(t, trAct.Put([]byte{0xAC, 0x01}, []byte{0xAB, 0xCD})) require.NoError(t, trAct.Put([]byte{0xAC, 0x99}, []byte{0x22, 0x22})) require.NoError(t, trAct.Put([]byte{0xAC, 0xAE}, []byte("hello"))) @@ -119,7 +157,7 @@ func TestTrie_Put(t *testing.T) { } func TestTrie_PutInvalid(t *testing.T) { - tr := NewTrie(nil, newTestStore()) + tr := NewTrie(nil, false, newTestStore()) key, value := []byte("key"), []byte("value") // big key @@ -134,7 +172,7 @@ func TestTrie_PutInvalid(t *testing.T) { } func TestTrie_BigPut(t *testing.T) { - tr := NewTrie(nil, newTestStore()) + tr := NewTrie(nil, false, newTestStore()) items := []struct{ k, v string }{ {"item with long key", "value1"}, {"item with matching prefix", "value2"}, @@ -164,6 +202,21 @@ func TestTrie_BigPut(t *testing.T) { }) } +func (tr *Trie) putToStore(n Node) { + if n.Type() == HashT { + panic("can't put hash node in trie") + } + if tr.refcountEnabled { + tr.refcount[n.Hash()] = &cachedNode{ + bytes: n.Bytes(), + refcount: 1, + } + tr.updateRefCount(n.Hash()) + } else { + _ = tr.Store.Put(makeStorageKey(n.Hash().BytesBE()), n.Bytes()) + } +} + func (tr *Trie) testHas(t *testing.T, key, value []byte) { v, err := tr.Get(key) if value == nil { @@ -208,7 +261,7 @@ func TestTrie_Get(t *testing.T) { }) t.Run("UnfoldRoot", func(t *testing.T) { tr := newTestTrie(t) - single := NewTrie(NewHashNode(tr.root.Hash()), tr.Store) + single := NewTrie(NewHashNode(tr.root.Hash()), false, tr.Store) single.testHas(t, []byte{0xAC}, nil) single.testHas(t, []byte{0xAC, 0x01}, []byte{0xAB, 0xCD}) single.testHas(t, []byte{0xAC, 0x99}, []byte{0x22, 0x22}) @@ -223,13 +276,13 @@ func TestTrie_Flush(t *testing.T) { "key2": []byte("value2"), } - tr := NewTrie(nil, newTestStore()) + tr := NewTrie(nil, false, newTestStore()) for k, v := range pairs { require.NoError(t, tr.Put([]byte(k), v)) } tr.Flush() - tr = NewTrie(NewHashNode(tr.StateRoot()), tr.Store) + tr = NewTrie(NewHashNode(tr.StateRoot()), false, tr.Store) for k, v := range pairs { actual, err := tr.Get([]byte(k)) require.NoError(t, err) @@ -238,10 +291,19 @@ func TestTrie_Flush(t *testing.T) { } func TestTrie_Delete(t *testing.T) { + t.Run("No GC", func(t *testing.T) { + testTrieDelete(t, false) + }) + t.Run("With GC", func(t *testing.T) { + testTrieDelete(t, true) + }) +} + +func testTrieDelete(t *testing.T, enableGC bool) { t.Run("Hash", func(t *testing.T) { t.Run("FromStore", func(t *testing.T) { l := NewLeafNode([]byte{0x12}) - tr := NewTrie(NewHashNode(l.Hash()), newTestStore()) + tr := NewTrie(NewHashNode(l.Hash()), enableGC, newTestStore()) t.Run("NotInStore", func(t *testing.T) { require.Error(t, tr.Delete([]byte{})) }) @@ -253,14 +315,14 @@ func TestTrie_Delete(t *testing.T) { }) t.Run("Empty", func(t *testing.T) { - tr := NewTrie(nil, newTestStore()) + tr := NewTrie(nil, enableGC, newTestStore()) require.Error(t, tr.Delete([]byte{})) }) }) t.Run("Leaf", func(t *testing.T) { l := NewLeafNode([]byte{0x12, 0x34}) - tr := NewTrie(l, newTestStore()) + tr := NewTrie(l, enableGC, newTestStore()) t.Run("NonExistentKey", func(t *testing.T) { require.Error(t, tr.Delete([]byte{0x12})) tr.testHas(t, []byte{}, []byte{0x12, 0x34}) @@ -273,7 +335,7 @@ func TestTrie_Delete(t *testing.T) { t.Run("SingleKey", func(t *testing.T) { l := NewLeafNode([]byte{0x12, 0x34}) e := NewExtensionNode([]byte{0x0A, 0x0B}, l) - tr := NewTrie(e, newTestStore()) + tr := NewTrie(e, enableGC, newTestStore()) t.Run("NonExistentKey", func(t *testing.T) { require.Error(t, tr.Delete([]byte{})) @@ -289,7 +351,7 @@ func TestTrie_Delete(t *testing.T) { b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x12, 0x34})) b.Children[6] = NewExtensionNode([]byte{0x07}, NewLeafNode([]byte{0x56, 0x78})) e := NewExtensionNode([]byte{0x01, 0x02}, b) - tr := NewTrie(e, newTestStore()) + tr := NewTrie(e, enableGC, newTestStore()) h := e.Hash() require.NoError(t, tr.Delete([]byte{0x12, 0x01})) @@ -308,7 +370,7 @@ func TestTrie_Delete(t *testing.T) { b.Children[lastChild] = NewLeafNode([]byte{0x12}) b.Children[0] = NewExtensionNode([]byte{0x01}, NewLeafNode([]byte{0x34})) b.Children[1] = NewExtensionNode([]byte{0x06}, NewLeafNode([]byte{0x56})) - tr := NewTrie(b, newTestStore()) + tr := NewTrie(b, enableGC, newTestStore()) require.NoError(t, tr.Delete([]byte{0x16})) tr.testHas(t, []byte{}, []byte{0x12}) tr.testHas(t, []byte{0x01}, []byte{0x34}) @@ -321,7 +383,7 @@ func TestTrie_Delete(t *testing.T) { l := NewLeafNode([]byte{0x34}) e := NewExtensionNode([]byte{0x06}, l) b.Children[5] = NewHashNode(e.Hash()) - tr := NewTrie(b, newTestStore()) + tr := NewTrie(b, enableGC, newTestStore()) tr.putToStore(l) tr.putToStore(e) return tr @@ -344,7 +406,7 @@ func TestTrie_Delete(t *testing.T) { b := NewBranchNode() b.Children[lastChild] = NewLeafNode([]byte{0x12}) b.Children[5] = c - tr := NewTrie(b, newTestStore()) + tr := NewTrie(b, enableGC, newTestStore()) require.NoError(t, tr.Delete([]byte{})) tr.testHas(t, []byte{}, nil) @@ -396,7 +458,7 @@ func TestTrie_Collapse(t *testing.T) { b.Children[0] = e hb := b.Hash() - tr := NewTrie(b, newTestStore()) + tr := NewTrie(b, false, newTestStore()) tr.Collapse(1) newb, ok := tr.root.(*BranchNode) @@ -410,7 +472,7 @@ func TestTrie_Collapse(t *testing.T) { hl := l.Hash() e := NewExtensionNode([]byte{0x01}, l) h := e.Hash() - tr := NewTrie(e, newTestStore()) + tr := NewTrie(e, false, newTestStore()) tr.Collapse(1) newe, ok := tr.root.(*ExtensionNode) @@ -421,13 +483,13 @@ func TestTrie_Collapse(t *testing.T) { }) t.Run("Leaf", func(t *testing.T) { l := NewLeafNode([]byte("value")) - tr := NewTrie(l, newTestStore()) + tr := NewTrie(l, false, newTestStore()) tr.Collapse(10) require.Equal(t, NewLeafNode([]byte("value")), tr.root) }) t.Run("Hash", func(t *testing.T) { t.Run("Empty", func(t *testing.T) { - tr := NewTrie(new(HashNode), newTestStore()) + tr := NewTrie(new(HashNode), false, newTestStore()) require.NotPanics(t, func() { tr.Collapse(1) }) hn, ok := tr.root.(*HashNode) require.True(t, ok) @@ -436,7 +498,7 @@ func TestTrie_Collapse(t *testing.T) { h := random.Uint256() hn := NewHashNode(h) - tr := NewTrie(hn, newTestStore()) + tr := NewTrie(hn, false, newTestStore()) tr.Collapse(10) newRoot, ok := tr.root.(*HashNode) diff --git a/pkg/rpc/response/result/mpt_test.go b/pkg/rpc/response/result/mpt_test.go index 22e0c021c..27173fb9b 100644 --- a/pkg/rpc/response/result/mpt_test.go +++ b/pkg/rpc/response/result/mpt_test.go @@ -41,10 +41,9 @@ func TestGetProof_MarshalJSON(t *testing.T) { require.Equal(t, 8, len(p.Result.Proof)) for i := range p.Result.Proof { // smoke test that every chunk is correctly encoded node r := io.NewBinReaderFromBuf(p.Result.Proof[i]) - var n mpt.NodeObject - n.DecodeBinary(r) + n := mpt.DecodeNodeWithType(r) require.NoError(t, r.Err) - require.NotNil(t, n.Node) + require.NotNil(t, n) } }) }