diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index 901b528bb..10512f946 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -520,14 +520,17 @@ func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) ([]*tran var expectedSenderFee utilityBalanceAndFees // Check Conflicts attributes. - var conflictsToBeRemoved []*transaction.Transaction + var ( + conflictsToBeRemoved []*transaction.Transaction + conflictingFee int64 + ) if fee.P2PSigExtensionsEnabled() { // Step 1: check if `tx` was in attributes of mempooled transactions. if conflictingHashes, ok := mp.conflicts[tx.Hash()]; ok { for _, hash := range conflictingHashes { existingTx := mp.verifiedMap[hash] - if existingTx.HasSigner(payer) && existingTx.NetworkFee > tx.NetworkFee { - return nil, fmt.Errorf("%w: conflicting transaction %s has bigger network fee", ErrConflictsAttribute, existingTx.Hash().StringBE()) + if existingTx.HasSigner(payer) { + conflictingFee += existingTx.NetworkFee } conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx) } @@ -542,11 +545,12 @@ func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) ([]*tran if !tx.HasSigner(existingTx.Signers[mp.payerIndex].Account) { return nil, fmt.Errorf("%w: not signed by the sender of conflicting transaction %s", ErrConflictsAttribute, existingTx.Hash().StringBE()) } - if existingTx.NetworkFee >= tx.NetworkFee { - return nil, fmt.Errorf("%w: conflicting transaction %s has bigger or equal network fee", ErrConflictsAttribute, existingTx.Hash().StringBE()) - } + conflictingFee += existingTx.NetworkFee conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx) } + if conflictingFee != 0 && tx.NetworkFee <= conflictingFee { + return nil, fmt.Errorf("%w: conflicting transactions have bigger or equal network fee: %d vs %d", ErrConflictsAttribute, tx.NetworkFee, conflictingFee) + } // Step 3: take into account sender's conflicting transactions before balance check. expectedSenderFee = actualSenderFee for _, conflictingTx := range conflictsToBeRemoved { diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index 1d6514ee8..d1d247dc9 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -1,13 +1,14 @@ package mempool import ( + "fmt" "math/big" "sort" + "strings" "testing" "time" "github.com/holiman/uint256" - "github.com/nspcc-dev/neo-go/internal/random" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/util" @@ -423,18 +424,23 @@ func TestMempoolAddRemoveOracleResponse(t *testing.T) { } func TestMempoolAddRemoveConflicts(t *testing.T) { - capacity := 6 - mp := New(capacity, 0, false, nil) + var ( + capacity = 6 + mp = New(capacity, 0, false, nil) + sender = transaction.Signer{Account: util.Uint160{1, 2, 3}} + maliciousSender = transaction.Signer{Account: util.Uint160{4, 5, 6}} + ) + var ( fs = &FeerStub{p2pSigExt: true, balance: 100000} nonce uint32 = 1 ) - getConflictsTx := func(netFee int64, hashes ...util.Uint256) *transaction.Transaction { + getTx := func(netFee int64, sender transaction.Signer, hashes ...util.Uint256) *transaction.Transaction { tx := transaction.New([]byte{byte(opcode.PUSH1)}, 0) tx.NetworkFee = netFee tx.Nonce = nonce nonce++ - tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} + tx.Signers = []transaction.Signer{sender} tx.Attributes = make([]transaction.Attribute, len(hashes)) for i, h := range hashes { tx.Attributes[i] = transaction.Attribute{ @@ -448,6 +454,12 @@ func TestMempoolAddRemoveConflicts(t *testing.T) { require.Equal(t, false, ok) return tx } + getConflictsTx := func(netFee int64, hashes ...util.Uint256) *transaction.Transaction { + return getTx(netFee, sender, hashes...) + } + getMaliciousTx := func(netFee int64, hashes ...util.Uint256) *transaction.Transaction { + return getTx(netFee, maliciousSender, hashes...) + } // tx1 in mempool and does not conflicts with anyone smallNetFee := int64(3) @@ -528,26 +540,89 @@ func TestMempoolAddRemoveConflicts(t *testing.T) { assert.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()]) // tx13 conflicts with tx2, but is not signed by tx2.Sender - tx13 := transaction.New([]byte{byte(opcode.PUSH1)}, 0) - tx13.NetworkFee = smallNetFee - tx13.Nonce = uint32(random.Int(0, 1e4)) - tx13.Signers = []transaction.Signer{{Account: util.Uint160{3, 2, 1}}} - tx13.Attributes = []transaction.Attribute{{ - Type: transaction.ConflictsT, - Value: &transaction.Conflicts{ - Hash: tx2.Hash(), - }, - }} + tx13 := getMaliciousTx(smallNetFee, tx2.Hash()) _, ok := mp.TryGetValue(tx13.Hash()) require.Equal(t, false, ok) require.ErrorIs(t, mp.Add(tx13, fs), ErrConflictsAttribute) + // tx15 conflicts with tx14, but added firstly and has the same network fee => tx14 must not be added. tx14 := getConflictsTx(smallNetFee) tx15 := getConflictsTx(smallNetFee, tx14.Hash()) require.NoError(t, mp.Add(tx15, fs)) - require.NoError(t, mp.Add(tx14, fs)) - err := mp.Add(tx15, fs) - require.ErrorIs(t, err, ErrConflictsAttribute) + err := mp.Add(tx14, fs) + require.Error(t, err) + + require.True(t, strings.Contains(err.Error(), fmt.Sprintf("conflicting transactions have bigger or equal network fee: %d vs %d", smallNetFee, smallNetFee))) + + check := func(t *testing.T, mainFee int64, fail bool) { + // Clear mempool. + mp.RemoveStale(func(t *transaction.Transaction) bool { + return false + }, fs) + + // mempooled tx17, tx18, tx19 conflict with tx16 + tx16 := getConflictsTx(mainFee) + tx17 := getConflictsTx(smallNetFee, tx16.Hash()) + tx18 := getConflictsTx(smallNetFee, tx16.Hash()) + tx19 := getMaliciousTx(smallNetFee, tx16.Hash()) // malicious, thus, doesn't take into account during fee evaluation + require.NoError(t, mp.Add(tx17, fs)) + require.NoError(t, mp.Add(tx18, fs)) + require.NoError(t, mp.Add(tx19, fs)) + if fail { + require.Error(t, mp.Add(tx16, fs)) + _, ok = mp.TryGetValue(tx17.Hash()) + require.True(t, ok) + _, ok = mp.TryGetValue(tx18.Hash()) + require.True(t, ok) + _, ok = mp.TryGetValue(tx19.Hash()) + require.True(t, ok) + } else { + require.NoError(t, mp.Add(tx16, fs)) + _, ok = mp.TryGetValue(tx17.Hash()) + require.False(t, ok) + _, ok = mp.TryGetValue(tx18.Hash()) + require.False(t, ok) + _, ok = mp.TryGetValue(tx19.Hash()) + require.False(t, ok) + } + } + check(t, smallNetFee*2, true) + check(t, smallNetFee*2+1, false) + + check = func(t *testing.T, mainFee int64, fail bool) { + // Clear mempool. + mp.RemoveStale(func(t *transaction.Transaction) bool { + return false + }, fs) + + // mempooled tx20, tx21, tx22 don't conflict with anyone, but tx23 conflicts with them + tx20 := getConflictsTx(smallNetFee) + tx21 := getConflictsTx(smallNetFee) + tx22 := getConflictsTx(smallNetFee) + tx23 := getConflictsTx(mainFee, tx20.Hash(), tx21.Hash(), tx22.Hash()) + require.NoError(t, mp.Add(tx20, fs)) + require.NoError(t, mp.Add(tx21, fs)) + require.NoError(t, mp.Add(tx22, fs)) + if fail { + require.Error(t, mp.Add(tx23, fs)) + _, ok = mp.TryGetData(tx20.Hash()) + require.True(t, ok) + _, ok = mp.TryGetData(tx21.Hash()) + require.True(t, ok) + _, ok = mp.TryGetData(tx22.Hash()) + require.True(t, ok) + } else { + require.NoError(t, mp.Add(tx23, fs)) + _, ok = mp.TryGetData(tx20.Hash()) + require.False(t, ok) + _, ok = mp.TryGetData(tx21.Hash()) + require.False(t, ok) + _, ok = mp.TryGetData(tx22.Hash()) + require.False(t, ok) + } + } + check(t, smallNetFee*3, true) + check(t, smallNetFee*3+1, false) } func TestMempoolAddWithDataGetData(t *testing.T) {