From 68b9ff1f17021fde9602516b565109ef85ab2af9 Mon Sep 17 00:00:00 2001
From: Anna Shaleva <shaleva.ann@nspcc.ru>
Date: Wed, 31 May 2023 19:22:04 +0300
Subject: [PATCH] mempool: adjust the rule of conflicting transaction ranking

Pay for all the conflicts if you'd like to went in. Close #3028.

Signed-off-by: Anna Shaleva <shaleva.ann@nspcc.ru>
---
 pkg/core/mempool/mem_pool.go      |  16 +++--
 pkg/core/mempool/mem_pool_test.go | 111 +++++++++++++++++++++++++-----
 2 files changed, 103 insertions(+), 24 deletions(-)

diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go
index 901b528bb..10512f946 100644
--- a/pkg/core/mempool/mem_pool.go
+++ b/pkg/core/mempool/mem_pool.go
@@ -520,14 +520,17 @@ func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) ([]*tran
 
 	var expectedSenderFee utilityBalanceAndFees
 	// Check Conflicts attributes.
-	var conflictsToBeRemoved []*transaction.Transaction
+	var (
+		conflictsToBeRemoved []*transaction.Transaction
+		conflictingFee       int64
+	)
 	if fee.P2PSigExtensionsEnabled() {
 		// Step 1: check if `tx` was in attributes of mempooled transactions.
 		if conflictingHashes, ok := mp.conflicts[tx.Hash()]; ok {
 			for _, hash := range conflictingHashes {
 				existingTx := mp.verifiedMap[hash]
-				if existingTx.HasSigner(payer) && existingTx.NetworkFee > tx.NetworkFee {
-					return nil, fmt.Errorf("%w: conflicting transaction %s has bigger network fee", ErrConflictsAttribute, existingTx.Hash().StringBE())
+				if existingTx.HasSigner(payer) {
+					conflictingFee += existingTx.NetworkFee
 				}
 				conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx)
 			}
@@ -542,11 +545,12 @@ func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) ([]*tran
 			if !tx.HasSigner(existingTx.Signers[mp.payerIndex].Account) {
 				return nil, fmt.Errorf("%w: not signed by the sender of conflicting transaction %s", ErrConflictsAttribute, existingTx.Hash().StringBE())
 			}
-			if existingTx.NetworkFee >= tx.NetworkFee {
-				return nil, fmt.Errorf("%w: conflicting transaction %s has bigger or equal network fee", ErrConflictsAttribute, existingTx.Hash().StringBE())
-			}
+			conflictingFee += existingTx.NetworkFee
 			conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx)
 		}
+		if conflictingFee != 0 && tx.NetworkFee <= conflictingFee {
+			return nil, fmt.Errorf("%w: conflicting transactions have bigger or equal network fee: %d vs %d", ErrConflictsAttribute, tx.NetworkFee, conflictingFee)
+		}
 		// Step 3: take into account sender's conflicting transactions before balance check.
 		expectedSenderFee = actualSenderFee
 		for _, conflictingTx := range conflictsToBeRemoved {
diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go
index 1d6514ee8..d1d247dc9 100644
--- a/pkg/core/mempool/mem_pool_test.go
+++ b/pkg/core/mempool/mem_pool_test.go
@@ -1,13 +1,14 @@
 package mempool
 
 import (
+	"fmt"
 	"math/big"
 	"sort"
+	"strings"
 	"testing"
 	"time"
 
 	"github.com/holiman/uint256"
-	"github.com/nspcc-dev/neo-go/internal/random"
 	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
 	"github.com/nspcc-dev/neo-go/pkg/network/payload"
 	"github.com/nspcc-dev/neo-go/pkg/util"
@@ -423,18 +424,23 @@ func TestMempoolAddRemoveOracleResponse(t *testing.T) {
 }
 
 func TestMempoolAddRemoveConflicts(t *testing.T) {
-	capacity := 6
-	mp := New(capacity, 0, false, nil)
+	var (
+		capacity        = 6
+		mp              = New(capacity, 0, false, nil)
+		sender          = transaction.Signer{Account: util.Uint160{1, 2, 3}}
+		maliciousSender = transaction.Signer{Account: util.Uint160{4, 5, 6}}
+	)
+
 	var (
 		fs           = &FeerStub{p2pSigExt: true, balance: 100000}
 		nonce uint32 = 1
 	)
-	getConflictsTx := func(netFee int64, hashes ...util.Uint256) *transaction.Transaction {
+	getTx := func(netFee int64, sender transaction.Signer, hashes ...util.Uint256) *transaction.Transaction {
 		tx := transaction.New([]byte{byte(opcode.PUSH1)}, 0)
 		tx.NetworkFee = netFee
 		tx.Nonce = nonce
 		nonce++
-		tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}}
+		tx.Signers = []transaction.Signer{sender}
 		tx.Attributes = make([]transaction.Attribute, len(hashes))
 		for i, h := range hashes {
 			tx.Attributes[i] = transaction.Attribute{
@@ -448,6 +454,12 @@ func TestMempoolAddRemoveConflicts(t *testing.T) {
 		require.Equal(t, false, ok)
 		return tx
 	}
+	getConflictsTx := func(netFee int64, hashes ...util.Uint256) *transaction.Transaction {
+		return getTx(netFee, sender, hashes...)
+	}
+	getMaliciousTx := func(netFee int64, hashes ...util.Uint256) *transaction.Transaction {
+		return getTx(netFee, maliciousSender, hashes...)
+	}
 
 	// tx1 in mempool and does not conflicts with anyone
 	smallNetFee := int64(3)
@@ -528,26 +540,89 @@ func TestMempoolAddRemoveConflicts(t *testing.T) {
 	assert.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()])
 
 	// tx13 conflicts with tx2, but is not signed by tx2.Sender
-	tx13 := transaction.New([]byte{byte(opcode.PUSH1)}, 0)
-	tx13.NetworkFee = smallNetFee
-	tx13.Nonce = uint32(random.Int(0, 1e4))
-	tx13.Signers = []transaction.Signer{{Account: util.Uint160{3, 2, 1}}}
-	tx13.Attributes = []transaction.Attribute{{
-		Type: transaction.ConflictsT,
-		Value: &transaction.Conflicts{
-			Hash: tx2.Hash(),
-		},
-	}}
+	tx13 := getMaliciousTx(smallNetFee, tx2.Hash())
 	_, ok := mp.TryGetValue(tx13.Hash())
 	require.Equal(t, false, ok)
 	require.ErrorIs(t, mp.Add(tx13, fs), ErrConflictsAttribute)
 
+	// tx15 conflicts with tx14, but added firstly and has the same network fee => tx14 must not be added.
 	tx14 := getConflictsTx(smallNetFee)
 	tx15 := getConflictsTx(smallNetFee, tx14.Hash())
 	require.NoError(t, mp.Add(tx15, fs))
-	require.NoError(t, mp.Add(tx14, fs))
-	err := mp.Add(tx15, fs)
-	require.ErrorIs(t, err, ErrConflictsAttribute)
+	err := mp.Add(tx14, fs)
+	require.Error(t, err)
+
+	require.True(t, strings.Contains(err.Error(), fmt.Sprintf("conflicting transactions have bigger or equal network fee: %d vs %d", smallNetFee, smallNetFee)))
+
+	check := func(t *testing.T, mainFee int64, fail bool) {
+		// Clear mempool.
+		mp.RemoveStale(func(t *transaction.Transaction) bool {
+			return false
+		}, fs)
+
+		// mempooled tx17, tx18, tx19 conflict with tx16
+		tx16 := getConflictsTx(mainFee)
+		tx17 := getConflictsTx(smallNetFee, tx16.Hash())
+		tx18 := getConflictsTx(smallNetFee, tx16.Hash())
+		tx19 := getMaliciousTx(smallNetFee, tx16.Hash()) // malicious, thus, doesn't take into account during fee evaluation
+		require.NoError(t, mp.Add(tx17, fs))
+		require.NoError(t, mp.Add(tx18, fs))
+		require.NoError(t, mp.Add(tx19, fs))
+		if fail {
+			require.Error(t, mp.Add(tx16, fs))
+			_, ok = mp.TryGetValue(tx17.Hash())
+			require.True(t, ok)
+			_, ok = mp.TryGetValue(tx18.Hash())
+			require.True(t, ok)
+			_, ok = mp.TryGetValue(tx19.Hash())
+			require.True(t, ok)
+		} else {
+			require.NoError(t, mp.Add(tx16, fs))
+			_, ok = mp.TryGetValue(tx17.Hash())
+			require.False(t, ok)
+			_, ok = mp.TryGetValue(tx18.Hash())
+			require.False(t, ok)
+			_, ok = mp.TryGetValue(tx19.Hash())
+			require.False(t, ok)
+		}
+	}
+	check(t, smallNetFee*2, true)
+	check(t, smallNetFee*2+1, false)
+
+	check = func(t *testing.T, mainFee int64, fail bool) {
+		// Clear mempool.
+		mp.RemoveStale(func(t *transaction.Transaction) bool {
+			return false
+		}, fs)
+
+		// mempooled tx20, tx21, tx22 don't conflict with anyone, but tx23 conflicts with them
+		tx20 := getConflictsTx(smallNetFee)
+		tx21 := getConflictsTx(smallNetFee)
+		tx22 := getConflictsTx(smallNetFee)
+		tx23 := getConflictsTx(mainFee, tx20.Hash(), tx21.Hash(), tx22.Hash())
+		require.NoError(t, mp.Add(tx20, fs))
+		require.NoError(t, mp.Add(tx21, fs))
+		require.NoError(t, mp.Add(tx22, fs))
+		if fail {
+			require.Error(t, mp.Add(tx23, fs))
+			_, ok = mp.TryGetData(tx20.Hash())
+			require.True(t, ok)
+			_, ok = mp.TryGetData(tx21.Hash())
+			require.True(t, ok)
+			_, ok = mp.TryGetData(tx22.Hash())
+			require.True(t, ok)
+		} else {
+			require.NoError(t, mp.Add(tx23, fs))
+			_, ok = mp.TryGetData(tx20.Hash())
+			require.False(t, ok)
+			_, ok = mp.TryGetData(tx21.Hash())
+			require.False(t, ok)
+			_, ok = mp.TryGetData(tx22.Hash())
+			require.False(t, ok)
+		}
+	}
+	check(t, smallNetFee*3, true)
+	check(t, smallNetFee*3+1, false)
 }
 
 func TestMempoolAddWithDataGetData(t *testing.T) {