diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index ae8b9d9bf..1689af8bb 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -30,6 +30,9 @@ var ( // ErrConflictsAttribute is returned when transaction conflicts with other transactions // due to its (or theirs) Conflicts attributes. ErrConflictsAttribute = errors.New("conflicts with memory pool due to Conflicts attribute") + // ErrOracleResponse is returned when mempool already contains transaction + // with the same oracle response ID and higher network fee. + ErrOracleResponse = errors.New("conflicts with memory pool due to OracleResponse attribute") ) // item represents a transaction in the the Memory pool. @@ -56,6 +59,8 @@ type Pool struct { fees map[util.Uint160]utilityBalanceAndFees // conflicts is a map of hashes of transactions which are conflicting with the mempooled ones. conflicts map[util.Uint256][]util.Uint256 + // oracleResp contains ids of oracle responses for tx in pool. + oracleResp map[uint64]util.Uint256 capacity int feePerByte int64 @@ -192,6 +197,18 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { mp.lock.Unlock() return err } + if attrs := t.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 { + id := attrs[0].Value.(*transaction.OracleResponse).ID + h, ok := mp.oracleResp[id] + if ok { + if mp.verifiedMap[h].NetworkFee >= t.NetworkFee { + mp.lock.Unlock() + return ErrOracleResponse + } + mp.removeInternal(h, fee) + } + mp.oracleResp[id] = t.Hash() + } mp.verifiedMap[t.Hash()] = t if fee.P2PSigExtensionsEnabled() { @@ -276,6 +293,9 @@ func (mp *Pool) removeInternal(hash util.Uint256, feer Feer) { // remove all conflicting hashes from mp.conflicts list mp.removeConflictsOf(tx) } + if attrs := tx.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 { + delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID) + } } updateMempoolMetrics(len(mp.verifiedTxes)) } @@ -314,6 +334,9 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer) } } else { delete(mp.verifiedMap, itm.txn.Hash()) + if attrs := itm.txn.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 { + delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID) + } } } if len(staleTxs) != 0 { @@ -350,6 +373,7 @@ func New(capacity int) *Pool { capacity: capacity, fees: make(map[util.Uint160]utilityBalanceAndFees), conflicts: make(map[util.Uint256][]util.Uint256), + oracleResp: make(map[uint64]util.Uint256), } } diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index 31fcf7212..e8167dc01 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -351,6 +351,59 @@ func TestMempoolItemsOrder(t *testing.T) { require.True(t, item4.CompareTo(item3) < 0) } +func TestMempoolAddRemoveOracleResponse(t *testing.T) { + mp := New(5) + nonce := uint32(0) + fs := &FeerStub{} + newTx := func(netFee int64, id uint64) *transaction.Transaction { + tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) + tx.NetworkFee = netFee + tx.Nonce = nonce + nonce++ + tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} + tx.Attributes = []transaction.Attribute{{ + Type: transaction.OracleResponseT, + Value: &transaction.OracleResponse{ID: id}, + }} + // sanity check + _, ok := mp.TryGetValue(tx.Hash()) + require.False(t, ok) + return tx + } + + tx1 := newTx(10, 1) + require.NoError(t, mp.Add(tx1, fs)) + + // smaller network fee + tx2 := newTx(5, 1) + err := mp.Add(tx2, fs) + require.True(t, errors.Is(err, ErrOracleResponse)) + + // ok if old tx is removed + mp.Remove(tx1.Hash(), fs) + require.NoError(t, mp.Add(tx2, fs)) + + // higher network fee + tx3 := newTx(6, 1) + require.NoError(t, mp.Add(tx3, fs)) + _, ok := mp.TryGetValue(tx2.Hash()) + require.False(t, ok) + _, ok = mp.TryGetValue(tx3.Hash()) + require.True(t, ok) + + // another oracle response ID + tx4 := newTx(4, 2) + require.NoError(t, mp.Add(tx4, fs)) + + mp.RemoveStale(func(tx *transaction.Transaction) bool { + return tx.Hash() != tx4.Hash() + }, fs) + + // check that oracle id was removed. + tx5 := newTx(3, 2) + require.NoError(t, mp.Add(tx5, fs)) +} + func TestMempoolAddRemoveConflicts(t *testing.T) { capacity := 6 mp := New(capacity)