From cba117352ca2cabcf27caaa0b43dfffc2ae7b4a8 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Mon, 30 Nov 2020 12:48:18 +0300 Subject: [PATCH] mempool: correctly handle tx with oracle response If tx with the same oracle response ID is already in mempool, replace it if network fee of added transaction is higher and return error otherwise. --- pkg/core/mempool/mem_pool.go | 24 ++++++++++++++ pkg/core/mempool/mem_pool_test.go | 53 +++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index ae8b9d9bf..1689af8bb 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -30,6 +30,9 @@ var ( // ErrConflictsAttribute is returned when transaction conflicts with other transactions // due to its (or theirs) Conflicts attributes. ErrConflictsAttribute = errors.New("conflicts with memory pool due to Conflicts attribute") + // ErrOracleResponse is returned when mempool already contains transaction + // with the same oracle response ID and higher network fee. + ErrOracleResponse = errors.New("conflicts with memory pool due to OracleResponse attribute") ) // item represents a transaction in the the Memory pool. @@ -56,6 +59,8 @@ type Pool struct { fees map[util.Uint160]utilityBalanceAndFees // conflicts is a map of hashes of transactions which are conflicting with the mempooled ones. conflicts map[util.Uint256][]util.Uint256 + // oracleResp contains ids of oracle responses for tx in pool. + oracleResp map[uint64]util.Uint256 capacity int feePerByte int64 @@ -192,6 +197,18 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { mp.lock.Unlock() return err } + if attrs := t.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 { + id := attrs[0].Value.(*transaction.OracleResponse).ID + h, ok := mp.oracleResp[id] + if ok { + if mp.verifiedMap[h].NetworkFee >= t.NetworkFee { + mp.lock.Unlock() + return ErrOracleResponse + } + mp.removeInternal(h, fee) + } + mp.oracleResp[id] = t.Hash() + } mp.verifiedMap[t.Hash()] = t if fee.P2PSigExtensionsEnabled() { @@ -276,6 +293,9 @@ func (mp *Pool) removeInternal(hash util.Uint256, feer Feer) { // remove all conflicting hashes from mp.conflicts list mp.removeConflictsOf(tx) } + if attrs := tx.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 { + delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID) + } } updateMempoolMetrics(len(mp.verifiedTxes)) } @@ -314,6 +334,9 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer) } } else { delete(mp.verifiedMap, itm.txn.Hash()) + if attrs := itm.txn.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 { + delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID) + } } } if len(staleTxs) != 0 { @@ -350,6 +373,7 @@ func New(capacity int) *Pool { capacity: capacity, fees: make(map[util.Uint160]utilityBalanceAndFees), conflicts: make(map[util.Uint256][]util.Uint256), + oracleResp: make(map[uint64]util.Uint256), } } diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index 31fcf7212..e8167dc01 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -351,6 +351,59 @@ func TestMempoolItemsOrder(t *testing.T) { require.True(t, item4.CompareTo(item3) < 0) } +func TestMempoolAddRemoveOracleResponse(t *testing.T) { + mp := New(5) + nonce := uint32(0) + fs := &FeerStub{} + newTx := func(netFee int64, id uint64) *transaction.Transaction { + tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) + tx.NetworkFee = netFee + tx.Nonce = nonce + nonce++ + tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} + tx.Attributes = []transaction.Attribute{{ + Type: transaction.OracleResponseT, + Value: &transaction.OracleResponse{ID: id}, + }} + // sanity check + _, ok := mp.TryGetValue(tx.Hash()) + require.False(t, ok) + return tx + } + + tx1 := newTx(10, 1) + require.NoError(t, mp.Add(tx1, fs)) + + // smaller network fee + tx2 := newTx(5, 1) + err := mp.Add(tx2, fs) + require.True(t, errors.Is(err, ErrOracleResponse)) + + // ok if old tx is removed + mp.Remove(tx1.Hash(), fs) + require.NoError(t, mp.Add(tx2, fs)) + + // higher network fee + tx3 := newTx(6, 1) + require.NoError(t, mp.Add(tx3, fs)) + _, ok := mp.TryGetValue(tx2.Hash()) + require.False(t, ok) + _, ok = mp.TryGetValue(tx3.Hash()) + require.True(t, ok) + + // another oracle response ID + tx4 := newTx(4, 2) + require.NoError(t, mp.Add(tx4, fs)) + + mp.RemoveStale(func(tx *transaction.Transaction) bool { + return tx.Hash() != tx4.Hash() + }, fs) + + // check that oracle id was removed. + tx5 := newTx(3, 2) + require.NoError(t, mp.Add(tx5, fs)) +} + func TestMempoolAddRemoveConflicts(t *testing.T) { capacity := 6 mp := New(capacity)