diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index fd47b9f91..c69116048 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math/big" + "sort" "sync" "sync/atomic" "time" @@ -607,7 +608,8 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } } - for _, tx := range block.Transactions { + var txHashes = make([]util.Uint256, len(block.Transactions)) + for i, tx := range block.Transactions { if err := cache.StoreAsTransaction(tx, block.Index); err != nil { return err } @@ -624,8 +626,8 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { if err != nil { return fmt.Errorf("failed to persist invocation results: %w", err) } - for i := range systemInterop.Notifications { - bc.handleNotification(&systemInterop.Notifications[i], cache, block, tx.Hash()) + for j := range systemInterop.Notifications { + bc.handleNotification(&systemInterop.Notifications[j], cache, block, tx.Hash()) } } else { bc.log.Warn("contract invocation failed", @@ -646,7 +648,11 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { if err != nil { return fmt.Errorf("failed to store tx exec result: %w", err) } + txHashes[i] = tx.Hash() } + sort.Slice(txHashes, func(i, j int) bool { + return txHashes[i].CompareTo(txHashes[j]) < 0 + }) root := bc.dao.MPT.StateRoot() var prevHash util.Uint256 @@ -688,7 +694,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) - bc.memPool.RemoveStale(bc.isTxStillRelevant, bc) + bc.memPool.RemoveStale(func(tx *transaction.Transaction) bool { return bc.isTxStillRelevant(tx, txHashes) }, bc) bc.lock.Unlock() updateBlockHeightMetric(block.Index) @@ -1251,16 +1257,18 @@ func (bc *Blockchain) verifyTx(t *transaction.Transaction, block *block.Block) e } // isTxStillRelevant is a callback for mempool transaction filtering after the -// new block addition. It returns false for transactions already present in the -// chain (added by the new block), transactions using some inputs that are -// already used (double spends) and does witness reverification for non-standard +// new block addition. It returns false for transactions added by the new block +// (passed via txHashes) and does witness reverification for non-standard // contracts. It operates under the assumption that full transaction verification // was already done so we don't need to check basic things like size, input/output -// correctness, etc. -func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool { +// correctness, presence in blocks before the new one, etc. +func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction, txHashes []util.Uint256) bool { var recheckWitness bool - if bc.dao.HasTransaction(t.Hash()) { + index := sort.Search(len(txHashes), func(i int) bool { + return txHashes[i].CompareTo(t.Hash()) >= 0 + }) + if index < len(txHashes) && txHashes[index].Equals(t.Hash()) { return false } for i := range t.Scripts { diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 58aa342cc..dbb589b4a 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -402,6 +402,34 @@ func TestVerifyHashAgainstScript(t *testing.T) { }) } +func TestMemPoolRemoval(t *testing.T) { + const added = 16 + const notAdded = 32 + bc := newTestChain(t) + defer bc.Close() + addedTxes := make([]*transaction.Transaction, added) + notAddedTxes := make([]*transaction.Transaction, notAdded) + for i := range addedTxes { + addedTxes[i] = bc.newTestTx(testchain.MultisigScriptHash(), []byte{byte(opcode.PUSH1)}) + require.NoError(t, signTx(bc, addedTxes[i])) + require.NoError(t, bc.PoolTx(addedTxes[i])) + } + for i := range notAddedTxes { + notAddedTxes[i] = bc.newTestTx(testchain.MultisigScriptHash(), []byte{byte(opcode.PUSH1)}) + require.NoError(t, signTx(bc, notAddedTxes[i])) + require.NoError(t, bc.PoolTx(notAddedTxes[i])) + } + b := bc.newBlock(addedTxes...) + require.NoError(t, bc.AddBlock(b)) + mempool := bc.GetMemPool() + for _, tx := range addedTxes { + require.False(t, mempool.ContainsKey(tx.Hash())) + } + for _, tx := range notAddedTxes { + require.True(t, mempool.ContainsKey(tx.Hash())) + } +} + func TestHasBlock(t *testing.T) { bc := newTestChain(t) blocks, err := bc.genBlocks(50)