diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 180b80da4..6e7d4e456 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -872,7 +872,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) - bc.memPool.RemoveStale(bc.isTxStillRelevant) + bc.memPool.RemoveStale(bc.isTxStillRelevant, block.Index) bc.lock.Unlock() updateBlockHeightMetric(block.Index) diff --git a/pkg/core/blockchainer.go b/pkg/core/blockchainer.go index 6ead1cd03..76c032c61 100644 --- a/pkg/core/blockchainer.go +++ b/pkg/core/blockchainer.go @@ -19,7 +19,6 @@ type Blockchainer interface { AddHeaders(...*block.Header) error AddBlock(*block.Block) error AddStateRoot(r *state.MPTRoot) error - BlockHeight() uint32 CalculateClaimable(value util.Fixed8, startHeight, endHeight uint32) (util.Fixed8, util.Fixed8, error) Close() HeaderHeight() uint32 diff --git a/pkg/core/mempool/feer.go b/pkg/core/mempool/feer.go index 89b63dab5..f425ddd2a 100644 --- a/pkg/core/mempool/feer.go +++ b/pkg/core/mempool/feer.go @@ -7,6 +7,7 @@ import ( // Feer is an interface that abstract the implementation of the fee calculation. type Feer interface { + BlockHeight() uint32 NetworkFee(t *transaction.Transaction) util.Fixed8 IsLowPriority(util.Fixed8) bool FeePerByte(t *transaction.Transaction) util.Fixed8 diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index 7dc38e2d4..e47ef11a1 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -2,9 +2,9 @@ package mempool import ( "errors" + "math/bits" "sort" "sync" - "time" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/util" @@ -26,7 +26,7 @@ var ( // item represents a transaction in the the Memory pool. type item struct { txn *transaction.Transaction - timeStamp time.Time + blockStamp uint32 perByteFee util.Fixed8 netFee util.Fixed8 isLowPrio bool @@ -49,6 +49,9 @@ type Pool struct { inputs []*transaction.Input claims []*transaction.Input + resendThreshold uint32 + resendFunc func(*transaction.Transaction) + capacity int } @@ -162,7 +165,7 @@ func dropInputFromSortedSlice(slice *[]*transaction.Input, input *transaction.In func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { var pItem = &item{ txn: t, - timeStamp: time.Now().UTC(), + blockStamp: fee.BlockHeight(), perByteFee: fee.FeePerByte(t), netFee: fee.NetworkFee(t), } @@ -258,13 +261,14 @@ func (mp *Pool) Remove(hash util.Uint256) { // RemoveStale filters verified transactions through the given function keeping // only the transactions for which it returns a true result. It's used to quickly // drop part of the mempool that is now invalid after the block acceptance. -func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool) { +func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, height uint32) { mp.lock.Lock() // We can reuse already allocated slice // because items are iterated one-by-one in increasing order. newVerifiedTxes := mp.verifiedTxes[:0] newInputs := mp.inputs[:0] newClaims := mp.claims[:0] + var staleTxs []*transaction.Transaction for _, itm := range mp.verifiedTxes { if isOK(itm.txn) { newVerifiedTxes = append(newVerifiedTxes, itm) @@ -277,10 +281,21 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool) { newClaims = append(newClaims, &claim.Claims[i]) } } + if mp.resendThreshold != 0 { + // tx is resend at resendThreshold, 2*resendThreshold, 4*resendThreshold ... + // so quotient must be a power of two. + diff := (height - itm.blockStamp) + if diff%mp.resendThreshold == 0 && bits.OnesCount32(diff/mp.resendThreshold) == 1 { + staleTxs = append(staleTxs, itm.txn) + } + } } else { delete(mp.verifiedMap, itm.txn.Hash()) } } + if len(staleTxs) != 0 { + go mp.resendStaleTxs(staleTxs) + } sort.Slice(newInputs, func(i, j int) bool { return newInputs[i].Cmp(newInputs[j]) < 0 }) @@ -302,6 +317,21 @@ func NewMemPool(capacity int) Pool { } } +// SetResendThreshold sets threshold after which transaction will be considered stale +// and returned for retransmission by `GetStaleTransactions`. +func (mp *Pool) SetResendThreshold(h uint32, f func(*transaction.Transaction)) { + mp.lock.Lock() + defer mp.lock.Unlock() + mp.resendThreshold = h + mp.resendFunc = f +} + +func (mp *Pool) resendStaleTxs(txs []*transaction.Transaction) { + for i := range txs { + mp.resendFunc(txs[i]) + } +} + // TryGetValue returns a transaction and its fee if it exists in the memory pool. func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, util.Fixed8, bool) { mp.lock.RLock() diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index 5cfcf7b70..38fce9750 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -3,6 +3,7 @@ package mempool import ( "sort" "testing" + "time" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/internal/random" @@ -12,12 +13,17 @@ import ( ) type FeerStub struct { + blockHeight uint32 lowPriority bool sysFee util.Fixed8 netFee util.Fixed8 perByteFee util.Fixed8 } +func (fs *FeerStub) BlockHeight() uint32 { + return fs.blockHeight +} + func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 { return fs.netFee } @@ -53,6 +59,48 @@ func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { assert.Equal(t, 0, len(mp.verifiedTxes)) } +func TestMemPoolRemoveStale(t *testing.T) { + mp := NewMemPool(5) + txs := make([]*transaction.Transaction, 5) + for i := range txs { + txs[i] = newMinerTX(uint32(i)) + require.NoError(t, mp.Add(txs[i], &FeerStub{blockHeight: uint32(i)})) + } + + staleTxs := make(chan *transaction.Transaction, 5) + f := func(tx *transaction.Transaction) { + staleTxs <- tx + } + mp.SetResendThreshold(5, f) + + isValid := func(tx *transaction.Transaction) bool { + return tx.Data.(*transaction.MinerTX).Nonce%2 == 0 + } + + mp.RemoveStale(isValid, 5) // 0 + 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[0], <-staleTxs) + + mp.RemoveStale(isValid, 7) // 2 + 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[2], <-staleTxs) + + mp.RemoveStale(isValid, 10) // 0 + 2 * 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[0], <-staleTxs) + + mp.RemoveStale(isValid, 15) // 0 + 3 * 5 + + // tx[2] should appear, so it is also checked that tx[0] wasn't sent on height 15. + mp.RemoveStale(isValid, 22) // 2 + 4 * 5 + require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100) + require.Equal(t, txs[2], <-staleTxs) + + // panic if something is sent after this. + close(staleTxs) + require.Len(t, staleTxs, 0) +} + func TestMemPoolAddRemove(t *testing.T) { var fs = &FeerStub{lowPriority: false} t.Run("low priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) }) @@ -119,7 +167,7 @@ func TestMemPoolAddRemoveWithInputsAndClaims(t *testing.T) { return false } return true - }) + }, 0) assert.Equal(t, len(txm1.Inputs), len(mp.inputs)) assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs)) assert.Equal(t, len(claim2.Claims), len(mp.claims)) @@ -333,7 +381,7 @@ func TestRemoveStale(t *testing.T) { } } return false - }) + }, 0) require.Equal(t, mempoolSize/2, mp.Count()) verTxes := mp.GetVerifiedTransactions() for _, txf := range verTxes { diff --git a/pkg/network/server.go b/pkg/network/server.go index d45162418..dc616f884 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -176,6 +176,7 @@ func (s *Server) Start(errChan chan error) { zap.Uint32("headerHeight", s.chain.HeaderHeight())) s.tryStartConsensus() + s.initStaleTxMemPool() go s.broadcastTxLoop() go s.relayBlocksLoop() @@ -968,6 +969,18 @@ func (s *Server) broadcastTxHashes(hs []util.Uint256) { }) } +// initStaleTxMemPool initializes mempool for stale tx processing. +func (s *Server) initStaleTxMemPool() { + cfg := s.chain.GetConfig() + threshold := 5 + if l := len(cfg.StandbyValidators); l*2 > threshold { + threshold = l * 2 + } + + mp := s.chain.GetMemPool() + mp.SetResendThreshold(uint32(threshold), s.broadcastTX) +} + // broadcastTxLoop is a loop for batching and sending // transactions hashes in an INV payload. func (s *Server) broadcastTxLoop() { diff --git a/pkg/rpc/server/server_helper_test.go b/pkg/rpc/server/server_helper_test.go index 61bcc2e0d..9a84e44ce 100644 --- a/pkg/rpc/server/server_helper_test.go +++ b/pkg/rpc/server/server_helper_test.go @@ -85,6 +85,10 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *Server, *http type FeerStub struct{} +func (fs *FeerStub) BlockHeight() uint32 { + return 0 +} + func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 { return 0 }