network: retransmit stale transactions

This commit is contained in:
Evgenii Stratonikov 2020-11-11 15:49:51 +03:00
parent 06f3c34981
commit d93ddfda10
4 changed files with 93 additions and 5 deletions

View file

@ -872,7 +872,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
} }
bc.topBlock.Store(block) bc.topBlock.Store(block)
atomic.StoreUint32(&bc.blockHeight, block.Index) atomic.StoreUint32(&bc.blockHeight, block.Index)
bc.memPool.RemoveStale(bc.isTxStillRelevant) bc.memPool.RemoveStale(bc.isTxStillRelevant, block.Index)
bc.lock.Unlock() bc.lock.Unlock()
updateBlockHeightMetric(block.Index) updateBlockHeightMetric(block.Index)

View file

@ -2,6 +2,7 @@ package mempool
import ( import (
"errors" "errors"
"math/bits"
"sort" "sort"
"sync" "sync"
@ -48,6 +49,9 @@ type Pool struct {
inputs []*transaction.Input inputs []*transaction.Input
claims []*transaction.Input claims []*transaction.Input
resendThreshold uint32
resendFunc func(*transaction.Transaction)
capacity int capacity int
} }
@ -257,13 +261,14 @@ func (mp *Pool) Remove(hash util.Uint256) {
// RemoveStale filters verified transactions through the given function keeping // RemoveStale filters verified transactions through the given function keeping
// only the transactions for which it returns a true result. It's used to quickly // only the transactions for which it returns a true result. It's used to quickly
// drop part of the mempool that is now invalid after the block acceptance. // drop part of the mempool that is now invalid after the block acceptance.
func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool) { func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, height uint32) {
mp.lock.Lock() mp.lock.Lock()
// We can reuse already allocated slice // We can reuse already allocated slice
// because items are iterated one-by-one in increasing order. // because items are iterated one-by-one in increasing order.
newVerifiedTxes := mp.verifiedTxes[:0] newVerifiedTxes := mp.verifiedTxes[:0]
newInputs := mp.inputs[:0] newInputs := mp.inputs[:0]
newClaims := mp.claims[:0] newClaims := mp.claims[:0]
var staleTxs []*transaction.Transaction
for _, itm := range mp.verifiedTxes { for _, itm := range mp.verifiedTxes {
if isOK(itm.txn) { if isOK(itm.txn) {
newVerifiedTxes = append(newVerifiedTxes, itm) newVerifiedTxes = append(newVerifiedTxes, itm)
@ -276,10 +281,21 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool) {
newClaims = append(newClaims, &claim.Claims[i]) newClaims = append(newClaims, &claim.Claims[i])
} }
} }
if mp.resendThreshold != 0 {
// tx is resend at resendThreshold, 2*resendThreshold, 4*resendThreshold ...
// so quotient must be a power of two.
diff := (height - itm.blockStamp)
if diff%mp.resendThreshold == 0 && bits.OnesCount32(diff/mp.resendThreshold) == 1 {
staleTxs = append(staleTxs, itm.txn)
}
}
} else { } else {
delete(mp.verifiedMap, itm.txn.Hash()) delete(mp.verifiedMap, itm.txn.Hash())
} }
} }
if len(staleTxs) != 0 {
go mp.resendStaleTxs(staleTxs)
}
sort.Slice(newInputs, func(i, j int) bool { sort.Slice(newInputs, func(i, j int) bool {
return newInputs[i].Cmp(newInputs[j]) < 0 return newInputs[i].Cmp(newInputs[j]) < 0
}) })
@ -301,6 +317,21 @@ func NewMemPool(capacity int) Pool {
} }
} }
// SetResendThreshold sets threshold after which transaction will be considered stale
// and returned for retransmission by `GetStaleTransactions`.
func (mp *Pool) SetResendThreshold(h uint32, f func(*transaction.Transaction)) {
mp.lock.Lock()
defer mp.lock.Unlock()
mp.resendThreshold = h
mp.resendFunc = f
}
func (mp *Pool) resendStaleTxs(txs []*transaction.Transaction) {
for i := range txs {
mp.resendFunc(txs[i])
}
}
// TryGetValue returns a transaction and its fee if it exists in the memory pool. // TryGetValue returns a transaction and its fee if it exists in the memory pool.
func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, util.Fixed8, bool) { func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, util.Fixed8, bool) {
mp.lock.RLock() mp.lock.RLock()

View file

@ -3,6 +3,7 @@ package mempool
import ( import (
"sort" "sort"
"testing" "testing"
"time"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/random"
@ -12,6 +13,7 @@ import (
) )
type FeerStub struct { type FeerStub struct {
blockHeight uint32
lowPriority bool lowPriority bool
sysFee util.Fixed8 sysFee util.Fixed8
netFee util.Fixed8 netFee util.Fixed8
@ -19,7 +21,7 @@ type FeerStub struct {
} }
func (fs *FeerStub) BlockHeight() uint32 { func (fs *FeerStub) BlockHeight() uint32 {
return 0 return fs.blockHeight
} }
func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 { func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 {
@ -57,6 +59,48 @@ func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) {
assert.Equal(t, 0, len(mp.verifiedTxes)) assert.Equal(t, 0, len(mp.verifiedTxes))
} }
func TestMemPoolRemoveStale(t *testing.T) {
mp := NewMemPool(5)
txs := make([]*transaction.Transaction, 5)
for i := range txs {
txs[i] = newMinerTX(uint32(i))
require.NoError(t, mp.Add(txs[i], &FeerStub{blockHeight: uint32(i)}))
}
staleTxs := make(chan *transaction.Transaction, 5)
f := func(tx *transaction.Transaction) {
staleTxs <- tx
}
mp.SetResendThreshold(5, f)
isValid := func(tx *transaction.Transaction) bool {
return tx.Data.(*transaction.MinerTX).Nonce%2 == 0
}
mp.RemoveStale(isValid, 5) // 0 + 5
require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100)
require.Equal(t, txs[0], <-staleTxs)
mp.RemoveStale(isValid, 7) // 2 + 5
require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100)
require.Equal(t, txs[2], <-staleTxs)
mp.RemoveStale(isValid, 10) // 0 + 2 * 5
require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100)
require.Equal(t, txs[0], <-staleTxs)
mp.RemoveStale(isValid, 15) // 0 + 3 * 5
// tx[2] should appear, so it is also checked that tx[0] wasn't sent on height 15.
mp.RemoveStale(isValid, 22) // 2 + 4 * 5
require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100)
require.Equal(t, txs[2], <-staleTxs)
// panic if something is sent after this.
close(staleTxs)
require.Len(t, staleTxs, 0)
}
func TestMemPoolAddRemove(t *testing.T) { func TestMemPoolAddRemove(t *testing.T) {
var fs = &FeerStub{lowPriority: false} var fs = &FeerStub{lowPriority: false}
t.Run("low priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) }) t.Run("low priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) })
@ -123,7 +167,7 @@ func TestMemPoolAddRemoveWithInputsAndClaims(t *testing.T) {
return false return false
} }
return true return true
}) }, 0)
assert.Equal(t, len(txm1.Inputs), len(mp.inputs)) assert.Equal(t, len(txm1.Inputs), len(mp.inputs))
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs)) assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
assert.Equal(t, len(claim2.Claims), len(mp.claims)) assert.Equal(t, len(claim2.Claims), len(mp.claims))
@ -337,7 +381,7 @@ func TestRemoveStale(t *testing.T) {
} }
} }
return false return false
}) }, 0)
require.Equal(t, mempoolSize/2, mp.Count()) require.Equal(t, mempoolSize/2, mp.Count())
verTxes := mp.GetVerifiedTransactions() verTxes := mp.GetVerifiedTransactions()
for _, txf := range verTxes { for _, txf := range verTxes {

View file

@ -176,6 +176,7 @@ func (s *Server) Start(errChan chan error) {
zap.Uint32("headerHeight", s.chain.HeaderHeight())) zap.Uint32("headerHeight", s.chain.HeaderHeight()))
s.tryStartConsensus() s.tryStartConsensus()
s.initStaleTxMemPool()
go s.broadcastTxLoop() go s.broadcastTxLoop()
go s.relayBlocksLoop() go s.relayBlocksLoop()
@ -968,6 +969,18 @@ func (s *Server) broadcastTxHashes(hs []util.Uint256) {
}) })
} }
// initStaleTxMemPool initializes mempool for stale tx processing.
func (s *Server) initStaleTxMemPool() {
cfg := s.chain.GetConfig()
threshold := 5
if l := len(cfg.StandbyValidators); l*2 > threshold {
threshold = l * 2
}
mp := s.chain.GetMemPool()
mp.SetResendThreshold(uint32(threshold), s.broadcastTX)
}
// broadcastTxLoop is a loop for batching and sending // broadcastTxLoop is a loop for batching and sending
// transactions hashes in an INV payload. // transactions hashes in an INV payload.
func (s *Server) broadcastTxLoop() { func (s *Server) broadcastTxLoop() {