From d93ddfda10176524ba45571f299260820809d81b Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Wed, 11 Nov 2020 15:49:51 +0300 Subject: [PATCH] network: retransmit stale transactions --- pkg/core/blockchain.go | 2 +- pkg/core/mempool/mem_pool.go | 33 +++++++++++++++++++- pkg/core/mempool/mem_pool_test.go | 50 +++++++++++++++++++++++++++++-- pkg/network/server.go | 13 ++++++++ 4 files changed, 93 insertions(+), 5 deletions(-) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 180b80da4..6e7d4e456 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -872,7 +872,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) - bc.memPool.RemoveStale(bc.isTxStillRelevant) + bc.memPool.RemoveStale(bc.isTxStillRelevant, block.Index) bc.lock.Unlock() updateBlockHeightMetric(block.Index) diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index 8f1895132..e47ef11a1 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -2,6 +2,7 @@ package mempool import ( "errors" + "math/bits" "sort" "sync" @@ -48,6 +49,9 @@ type Pool struct { inputs []*transaction.Input claims []*transaction.Input + resendThreshold uint32 + resendFunc func(*transaction.Transaction) + capacity int } @@ -257,13 +261,14 @@ func (mp *Pool) Remove(hash util.Uint256) { // RemoveStale filters verified transactions through the given function keeping // only the transactions for which it returns a true result. It's used to quickly // drop part of the mempool that is now invalid after the block acceptance. -func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool) { +func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, height uint32) { mp.lock.Lock() // We can reuse already allocated slice // because items are iterated one-by-one in increasing order. newVerifiedTxes := mp.verifiedTxes[:0] newInputs := mp.inputs[:0] newClaims := mp.claims[:0] + var staleTxs []*transaction.Transaction for _, itm := range mp.verifiedTxes { if isOK(itm.txn) { newVerifiedTxes = append(newVerifiedTxes, itm) @@ -276,10 +281,21 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool) { newClaims = append(newClaims, &claim.Claims[i]) } } + if mp.resendThreshold != 0 { + // tx is resend at resendThreshold, 2*resendThreshold, 4*resendThreshold ... + // so quotient must be a power of two. + diff := (height - itm.blockStamp) + if diff%mp.resendThreshold == 0 && bits.OnesCount32(diff/mp.resendThreshold) == 1 { + staleTxs = append(staleTxs, itm.txn) + } + } } else { delete(mp.verifiedMap, itm.txn.Hash()) } } + if len(staleTxs) != 0 { + go mp.resendStaleTxs(staleTxs) + } sort.Slice(newInputs, func(i, j int) bool { return newInputs[i].Cmp(newInputs[j]) < 0 }) @@ -301,6 +317,21 @@ func NewMemPool(capacity int) Pool { } } +// SetResendThreshold sets threshold after which transaction will be considered stale +// and returned for retransmission by `GetStaleTransactions`. +func (mp *Pool) SetResendThreshold(h uint32, f func(*transaction.Transaction)) { + mp.lock.Lock() + defer mp.lock.Unlock() + mp.resendThreshold = h + mp.resendFunc = f +} + +func (mp *Pool) resendStaleTxs(txs []*transaction.Transaction) { + for i := range txs { + mp.resendFunc(txs[i]) + } +} + // TryGetValue returns a transaction and its fee if it exists in the memory pool. func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, util.Fixed8, bool) { mp.lock.RLock() diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index 2d4059281..38fce9750 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -3,6 +3,7 @@ package mempool import ( "sort" "testing" + "time" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/internal/random" @@ -12,6 +13,7 @@ import ( ) type FeerStub struct { + blockHeight uint32 lowPriority bool sysFee util.Fixed8 netFee util.Fixed8 @@ -19,7 +21,7 @@ type FeerStub struct { } func (fs *FeerStub) BlockHeight() uint32 { - return 0 + return fs.blockHeight } func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 { @@ -57,6 +59,48 @@ func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { assert.Equal(t, 0, len(mp.verifiedTxes)) } +func TestMemPoolRemoveStale(t *testing.T) { + mp := NewMemPool(5) + txs := make([]*transaction.Transaction, 5) + for i := range txs { + txs[i] = newMinerTX(uint32(i)) + require.NoError(t, mp.Add(txs[i], &FeerStub{blockHeight: uint32(i)})) + } + + staleTxs := make(chan *transaction.Transaction, 5) + f := func(tx *transaction.Transaction) { + staleTxs <- tx + } + mp.SetResendThreshold(5, f) + + isValid := func(tx *transaction.Transaction) bool { + return tx.Data.(*transaction.MinerTX).Nonce%2 == 0 + } + + mp.RemoveStale(isValid, 5) // 0 + 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[0], <-staleTxs) + + mp.RemoveStale(isValid, 7) // 2 + 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[2], <-staleTxs) + + mp.RemoveStale(isValid, 10) // 0 + 2 * 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[0], <-staleTxs) + + mp.RemoveStale(isValid, 15) // 0 + 3 * 5 + + // tx[2] should appear, so it is also checked that tx[0] wasn't sent on height 15. + mp.RemoveStale(isValid, 22) // 2 + 4 * 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[2], <-staleTxs) + + // panic if something is sent after this. + close(staleTxs) + require.Len(t, staleTxs, 0) +} + func TestMemPoolAddRemove(t *testing.T) { var fs = &FeerStub{lowPriority: false} t.Run("low priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) }) @@ -123,7 +167,7 @@ func TestMemPoolAddRemoveWithInputsAndClaims(t *testing.T) { return false } return true - }) + }, 0) assert.Equal(t, len(txm1.Inputs), len(mp.inputs)) assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs)) assert.Equal(t, len(claim2.Claims), len(mp.claims)) @@ -337,7 +381,7 @@ func TestRemoveStale(t *testing.T) { } } return false - }) + }, 0) require.Equal(t, mempoolSize/2, mp.Count()) verTxes := mp.GetVerifiedTransactions() for _, txf := range verTxes { diff --git a/pkg/network/server.go b/pkg/network/server.go index d45162418..dc616f884 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -176,6 +176,7 @@ func (s *Server) Start(errChan chan error) { zap.Uint32("headerHeight", s.chain.HeaderHeight())) s.tryStartConsensus() + s.initStaleTxMemPool() go s.broadcastTxLoop() go s.relayBlocksLoop() @@ -968,6 +969,18 @@ func (s *Server) broadcastTxHashes(hs []util.Uint256) { }) } +// initStaleTxMemPool initializes mempool for stale tx processing. +func (s *Server) initStaleTxMemPool() { + cfg := s.chain.GetConfig() + threshold := 5 + if l := len(cfg.StandbyValidators); l*2 > threshold { + threshold = l * 2 + } + + mp := s.chain.GetMemPool() + mp.SetResendThreshold(uint32(threshold), s.broadcastTX) +} + // broadcastTxLoop is a loop for batching and sending // transactions hashes in an INV payload. func (s *Server) broadcastTxLoop() {