diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index 74f17bf37..ae8b9d9bf 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math/big" + "math/bits" "sort" "sync" @@ -58,6 +59,9 @@ type Pool struct { capacity int feePerByte int64 + + resendThreshold uint32 + resendFunc func(*transaction.Transaction) } func (p items) Len() int { return len(p) } @@ -289,6 +293,8 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer) if feer.P2PSigExtensionsEnabled() { mp.conflicts = make(map[util.Uint256][]util.Uint256) } + height := feer.BlockHeight() + var staleTxs []*transaction.Transaction for _, itm := range mp.verifiedTxes { if isOK(itm.txn) && mp.checkPolicy(itm.txn, policyChanged) && mp.tryAddSendersFee(itm.txn, feer, true) { newVerifiedTxes = append(newVerifiedTxes, itm) @@ -298,10 +304,21 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer) mp.conflicts[hash] = append(mp.conflicts[hash], itm.txn.Hash()) } } + if mp.resendThreshold != 0 { + // tx is resend at resendThreshold, 2*resendThreshold, 4*resendThreshold ... + // so quotient must be a power of two. + diff := (height - itm.blockStamp) + if diff%mp.resendThreshold == 0 && bits.OnesCount32(diff/mp.resendThreshold) == 1 { + staleTxs = append(staleTxs, itm.txn) + } + } } else { delete(mp.verifiedMap, itm.txn.Hash()) } } + if len(staleTxs) != 0 { + go mp.resendStaleTxs(staleTxs) + } mp.verifiedTxes = newVerifiedTxes mp.lock.Unlock() } @@ -336,6 +353,21 @@ func New(capacity int) *Pool { } } +// SetResendThreshold sets threshold after which transaction will be considered stale +// and returned for retransmission by `GetStaleTransactions`. +func (mp *Pool) SetResendThreshold(h uint32, f func(*transaction.Transaction)) { + mp.lock.Lock() + defer mp.lock.Unlock() + mp.resendThreshold = h + mp.resendFunc = f +} + +func (mp *Pool) resendStaleTxs(txs []*transaction.Transaction) { + for i := range txs { + mp.resendFunc(txs[i]) + } +} + // TryGetValue returns a transaction and its fee if it exists in the memory pool. func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, bool) { mp.lock.RLock() diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index 16ecbf9f3..a872ebe7f 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -5,6 +5,7 @@ import ( "math/big" "sort" "testing" + "time" "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -16,8 +17,9 @@ import ( ) type FeerStub struct { - feePerByte int64 - p2pSigExt bool + feePerByte int64 + p2pSigExt bool + blockHeight uint32 } var balance = big.NewInt(10000000) @@ -27,7 +29,7 @@ func (fs *FeerStub) FeePerByte() int64 { } func (fs *FeerStub) BlockHeight() uint32 { - return 0 + return fs.blockHeight } func (fs *FeerStub) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int { @@ -59,6 +61,50 @@ func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { assert.Equal(t, 0, len(mp.verifiedTxes)) } +func TestMemPoolRemoveStale(t *testing.T) { + mp := New(5) + txs := make([]*transaction.Transaction, 5) + for i := range txs { + txs[i] = transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) + txs[i].Nonce = uint32(i) + txs[i].Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} + require.NoError(t, mp.Add(txs[i], &FeerStub{blockHeight: uint32(i)})) + } + + staleTxs := make(chan *transaction.Transaction, 5) + f := func(tx *transaction.Transaction) { + staleTxs <- tx + } + mp.SetResendThreshold(5, f) + + isValid := func(tx *transaction.Transaction) bool { + return tx.Nonce%2 == 0 + } + + mp.RemoveStale(isValid, &FeerStub{blockHeight: 5}) // 0 + 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[0], <-staleTxs) + + mp.RemoveStale(isValid, &FeerStub{blockHeight: 7}) // 2 + 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[2], <-staleTxs) + + mp.RemoveStale(isValid, &FeerStub{blockHeight: 10}) // 0 + 2 * 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[0], <-staleTxs) + + mp.RemoveStale(isValid, &FeerStub{blockHeight: 15}) // 0 + 3 * 5 + + // tx[2] should appear, so it is also checked that tx[0] wasn't sent on height 15. + mp.RemoveStale(isValid, &FeerStub{blockHeight: 22}) // 2 + 4 * 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[2], <-staleTxs) + + // panic if something is sent after this. + close(staleTxs) + require.Len(t, staleTxs, 0) +} + func TestMemPoolAddRemove(t *testing.T) { var fs = &FeerStub{} testMemPoolAddRemoveWithFeer(t, fs) diff --git a/pkg/network/server.go b/pkg/network/server.go index fd7b64a5e..6dd91b0ba 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -171,6 +171,7 @@ func (s *Server) Start(errChan chan error) { zap.Uint32("headerHeight", s.chain.HeaderHeight())) s.tryStartConsensus() + s.initStaleTxMemPool() go s.broadcastTxLoop() go s.relayBlocksLoop() @@ -905,6 +906,18 @@ func (s *Server) broadcastTxHashes(hs []util.Uint256) { s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, Peer.IsFullNode) } +// initStaleTxMemPool initializes mempool for stale tx processing. +func (s *Server) initStaleTxMemPool() { + cfg := s.chain.GetConfig() + threshold := 5 + if cfg.ValidatorsCount*2 > threshold { + threshold = cfg.ValidatorsCount * 2 + } + + mp := s.chain.GetMemPool() + mp.SetResendThreshold(uint32(threshold), s.broadcastTX) +} + // broadcastTxLoop is a loop for batching and sending // transactions hashes in an INV payload. func (s *Server) broadcastTxLoop() {