mempool: properly remove fees when removing tx during Add

Fixes #3488.

Signed-off-by: Roman Khimov <roman@nspcc.ru>
This commit is contained in:
Roman Khimov 2024-07-30 18:11:05 +03:00
parent a11e433754
commit 5d1d7b104e
2 changed files with 31 additions and 24 deletions

View file

@ -250,19 +250,8 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer, data ...any) error {
} }
// Ditch the last one. // Ditch the last one.
unlucky := mp.verifiedTxes[len(mp.verifiedTxes)-1] unlucky := mp.verifiedTxes[len(mp.verifiedTxes)-1]
delete(mp.verifiedMap, unlucky.txn.Hash())
mp.removeConflictsOf(unlucky.txn)
if attrs := unlucky.txn.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 {
delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID)
}
mp.verifiedTxes[len(mp.verifiedTxes)-1] = pItem mp.verifiedTxes[len(mp.verifiedTxes)-1] = pItem
if mp.subscriptionsOn.Load() { mp.removeFromMapWithFeesAndAttrs(unlucky)
mp.events <- mempoolevent.Event{
Type: mempoolevent.TransactionRemoved,
Tx: unlucky.txn,
Data: unlucky.data,
}
}
} else { } else {
mp.verifiedTxes = append(mp.verifiedTxes, pItem) mp.verifiedTxes = append(mp.verifiedTxes, pItem)
} }
@ -305,14 +294,15 @@ func (mp *Pool) Remove(hash util.Uint256) {
mp.lock.Unlock() mp.lock.Unlock()
} }
// removeInternal is an internal unlocked representation of Remove. // removeInternal is an internal unlocked representation of Remove, it drops
// transaction from verifiedMap and verifiedTxs, adjusts fees and fires a
// "removed" event.
func (mp *Pool) removeInternal(hash util.Uint256) { func (mp *Pool) removeInternal(hash util.Uint256) {
tx, ok := mp.verifiedMap[hash] _, ok := mp.verifiedMap[hash]
if !ok { if !ok {
return return
} }
var num int var num int
delete(mp.verifiedMap, hash)
for num = range mp.verifiedTxes { for num = range mp.verifiedTxes {
if hash.Equals(mp.verifiedTxes[num].txn.Hash()) { if hash.Equals(mp.verifiedTxes[num].txn.Hash()) {
break break
@ -324,13 +314,23 @@ func (mp *Pool) removeInternal(hash util.Uint256) {
} else if num == len(mp.verifiedTxes)-1 { } else if num == len(mp.verifiedTxes)-1 {
mp.verifiedTxes = mp.verifiedTxes[:num] mp.verifiedTxes = mp.verifiedTxes[:num]
} }
mp.removeFromMapWithFeesAndAttrs(itm)
}
// removeFromMapWithFeesAndAttrs removes given item (with the given hash) from
// verifiedMap, adjusts fees, handles attributes and fires an event. Notice
// that it does not do anything to verifiedTxes (the presumption is that if
// you have itm already, you can handle it fine for the specific case).
// It's an internal method, locking is to be handled by the caller.
func (mp *Pool) removeFromMapWithFeesAndAttrs(itm item) {
delete(mp.verifiedMap, itm.txn.Hash())
payer := itm.txn.Signers[mp.payerIndex].Account payer := itm.txn.Signers[mp.payerIndex].Account
senderFee := mp.fees[payer] senderFee := mp.fees[payer]
senderFee.feeSum.SubUint64(&senderFee.feeSum, uint64(tx.SystemFee+tx.NetworkFee)) senderFee.feeSum.SubUint64(&senderFee.feeSum, uint64(itm.txn.SystemFee+itm.txn.NetworkFee))
mp.fees[payer] = senderFee mp.fees[payer] = senderFee
// remove all conflicting hashes from mp.conflicts list // remove all conflicting hashes from mp.conflicts list
mp.removeConflictsOf(tx) mp.removeConflictsOf(itm.txn)
if attrs := tx.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 { if attrs := itm.txn.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 {
delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID) delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID)
} }
if mp.subscriptionsOn.Load() { if mp.subscriptionsOn.Load() {

View file

@ -112,18 +112,20 @@ func TestMemPoolAddRemove(t *testing.T) {
func TestOverCapacity(t *testing.T) { func TestOverCapacity(t *testing.T) {
var fs = &FeerStub{balance: 10000000} var fs = &FeerStub{balance: 10000000}
var acc = util.Uint160{1, 2, 3}
const mempoolSize = 10 const mempoolSize = 10
mp := New(mempoolSize, 0, false, nil) mp := New(mempoolSize, 0, false, nil)
for i := 0; i < mempoolSize; i++ { for i := 0; i < mempoolSize; i++ {
tx := transaction.New([]byte{byte(opcode.PUSH1)}, 0) tx := transaction.New([]byte{byte(opcode.PUSH1)}, 0)
tx.Nonce = uint32(i) tx.Nonce = uint32(i)
tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} tx.Signers = []transaction.Signer{{Account: acc}}
require.NoError(t, mp.Add(tx, fs)) require.NoError(t, mp.Add(tx, fs))
} }
txcnt := uint32(mempoolSize) txcnt := uint32(mempoolSize)
require.Equal(t, mempoolSize, mp.Count()) require.Equal(t, mempoolSize, mp.Count())
require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes))) require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes)))
require.Equal(t, *uint256.NewInt(0), mp.fees[acc].feeSum)
bigScript := make([]byte, 64) bigScript := make([]byte, 64)
bigScript[0] = byte(opcode.PUSH1) bigScript[0] = byte(opcode.PUSH1)
@ -133,18 +135,20 @@ func TestOverCapacity(t *testing.T) {
tx := transaction.New(bigScript, 0) tx := transaction.New(bigScript, 0)
tx.NetworkFee = 10000 tx.NetworkFee = 10000
tx.Nonce = txcnt tx.Nonce = txcnt
tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} tx.Signers = []transaction.Signer{{Account: acc}}
txcnt++ txcnt++
// size is ~90, networkFee is 10000 => feePerByte is 119 // size is ~90, networkFee is 10000 => feePerByte is 119
require.NoError(t, mp.Add(tx, fs)) require.NoError(t, mp.Add(tx, fs))
require.Equal(t, mempoolSize, mp.Count()) require.Equal(t, mempoolSize, mp.Count())
require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes))) require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes)))
} }
require.Equal(t, *uint256.NewInt(10 * 10000), mp.fees[acc].feeSum)
// Less prioritized txes are not allowed anymore. // Less prioritized txes are not allowed anymore.
tx := transaction.New(bigScript, 0) tx := transaction.New(bigScript, 0)
tx.NetworkFee = 100 tx.NetworkFee = 100
tx.Nonce = txcnt tx.Nonce = txcnt
tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} tx.Signers = []transaction.Signer{{Account: acc}}
txcnt++ txcnt++
require.Error(t, mp.Add(tx, fs)) require.Error(t, mp.Add(tx, fs))
require.Equal(t, mempoolSize, mp.Count()) require.Equal(t, mempoolSize, mp.Count())
@ -152,35 +156,38 @@ func TestOverCapacity(t *testing.T) {
require.Equal(t, mempoolSize, len(mp.verifiedTxes)) require.Equal(t, mempoolSize, len(mp.verifiedTxes))
require.False(t, mp.containsKey(tx.Hash())) require.False(t, mp.containsKey(tx.Hash()))
require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes))) require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes)))
require.Equal(t, *uint256.NewInt(100000), mp.fees[acc].feeSum)
// Low net fee, but higher per-byte fee is still a better combination. // Low net fee, but higher per-byte fee is still a better combination.
tx = transaction.New([]byte{byte(opcode.PUSH1)}, 0) tx = transaction.New([]byte{byte(opcode.PUSH1)}, 0)
tx.Nonce = txcnt tx.Nonce = txcnt
tx.NetworkFee = 7000 tx.NetworkFee = 7000
tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} tx.Signers = []transaction.Signer{{Account: acc}}
txcnt++ txcnt++
// size is ~51 (small script), networkFee is 7000 (<10000) // size is ~51 (small script), networkFee is 7000 (<10000)
// => feePerByte is 137 (>119) // => feePerByte is 137 (>119)
require.NoError(t, mp.Add(tx, fs)) require.NoError(t, mp.Add(tx, fs))
require.Equal(t, mempoolSize, mp.Count()) require.Equal(t, mempoolSize, mp.Count())
require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes))) require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes)))
require.Equal(t, *uint256.NewInt(9*10000 + 7000), mp.fees[acc].feeSum)
// High priority always wins over low priority. // High priority always wins over low priority.
for i := 0; i < mempoolSize; i++ { for i := 0; i < mempoolSize; i++ {
tx := transaction.New([]byte{byte(opcode.PUSH1)}, 0) tx := transaction.New([]byte{byte(opcode.PUSH1)}, 0)
tx.NetworkFee = 8000 tx.NetworkFee = 8000
tx.Nonce = txcnt tx.Nonce = txcnt
tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} tx.Signers = []transaction.Signer{{Account: acc}}
txcnt++ txcnt++
require.NoError(t, mp.Add(tx, fs)) require.NoError(t, mp.Add(tx, fs))
require.Equal(t, mempoolSize, mp.Count()) require.Equal(t, mempoolSize, mp.Count())
require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes))) require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes)))
} }
require.Equal(t, *uint256.NewInt(10 * 8000), mp.fees[acc].feeSum)
// Good luck with low priority now. // Good luck with low priority now.
tx = transaction.New([]byte{byte(opcode.PUSH1)}, 0) tx = transaction.New([]byte{byte(opcode.PUSH1)}, 0)
tx.Nonce = txcnt tx.Nonce = txcnt
tx.NetworkFee = 7000 tx.NetworkFee = 7000
tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} tx.Signers = []transaction.Signer{{Account: acc}}
require.Error(t, mp.Add(tx, fs)) require.Error(t, mp.Add(tx, fs))
require.Equal(t, mempoolSize, mp.Count()) require.Equal(t, mempoolSize, mp.Count())
require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes))) require.Equal(t, true, sort.IsSorted(sort.Reverse(mp.verifiedTxes)))