network: retransmit stale transactions

This commit is contained in:
Evgenii Stratonikov 2020-11-11 15:49:51 +03:00
parent e700fb2c96
commit 3e5b84348d
3 changed files with 94 additions and 3 deletions

View file

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
"math/bits"
"sort" "sort"
"sync" "sync"
@ -58,6 +59,9 @@ type Pool struct {
capacity int capacity int
feePerByte int64 feePerByte int64
resendThreshold uint32
resendFunc func(*transaction.Transaction)
} }
func (p items) Len() int { return len(p) } func (p items) Len() int { return len(p) }
@ -289,6 +293,8 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer)
if feer.P2PSigExtensionsEnabled() { if feer.P2PSigExtensionsEnabled() {
mp.conflicts = make(map[util.Uint256][]util.Uint256) mp.conflicts = make(map[util.Uint256][]util.Uint256)
} }
height := feer.BlockHeight()
var staleTxs []*transaction.Transaction
for _, itm := range mp.verifiedTxes { for _, itm := range mp.verifiedTxes {
if isOK(itm.txn) && mp.checkPolicy(itm.txn, policyChanged) && mp.tryAddSendersFee(itm.txn, feer, true) { if isOK(itm.txn) && mp.checkPolicy(itm.txn, policyChanged) && mp.tryAddSendersFee(itm.txn, feer, true) {
newVerifiedTxes = append(newVerifiedTxes, itm) newVerifiedTxes = append(newVerifiedTxes, itm)
@ -298,10 +304,21 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer)
mp.conflicts[hash] = append(mp.conflicts[hash], itm.txn.Hash()) mp.conflicts[hash] = append(mp.conflicts[hash], itm.txn.Hash())
} }
} }
if mp.resendThreshold != 0 {
// tx is resend at resendThreshold, 2*resendThreshold, 4*resendThreshold ...
// so quotient must be a power of two.
diff := (height - itm.blockStamp)
if diff%mp.resendThreshold == 0 && bits.OnesCount32(diff/mp.resendThreshold) == 1 {
staleTxs = append(staleTxs, itm.txn)
}
}
} else { } else {
delete(mp.verifiedMap, itm.txn.Hash()) delete(mp.verifiedMap, itm.txn.Hash())
} }
} }
if len(staleTxs) != 0 {
go mp.resendStaleTxs(staleTxs)
}
mp.verifiedTxes = newVerifiedTxes mp.verifiedTxes = newVerifiedTxes
mp.lock.Unlock() mp.lock.Unlock()
} }
@ -336,6 +353,21 @@ func New(capacity int) *Pool {
} }
} }
// SetResendThreshold sets threshold after which transaction will be considered stale
// and returned for retransmission by `GetStaleTransactions`.
func (mp *Pool) SetResendThreshold(h uint32, f func(*transaction.Transaction)) {
mp.lock.Lock()
defer mp.lock.Unlock()
mp.resendThreshold = h
mp.resendFunc = f
}
func (mp *Pool) resendStaleTxs(txs []*transaction.Transaction) {
for i := range txs {
mp.resendFunc(txs[i])
}
}
// TryGetValue returns a transaction and its fee if it exists in the memory pool. // TryGetValue returns a transaction and its fee if it exists in the memory pool.
func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, bool) { func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, bool) {
mp.lock.RLock() mp.lock.RLock()

View file

@ -5,6 +5,7 @@ import (
"math/big" "math/big"
"sort" "sort"
"testing" "testing"
"time"
"github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
@ -16,8 +17,9 @@ import (
) )
type FeerStub struct { type FeerStub struct {
feePerByte int64 feePerByte int64
p2pSigExt bool p2pSigExt bool
blockHeight uint32
} }
var balance = big.NewInt(10000000) var balance = big.NewInt(10000000)
@ -27,7 +29,7 @@ func (fs *FeerStub) FeePerByte() int64 {
} }
func (fs *FeerStub) BlockHeight() uint32 { func (fs *FeerStub) BlockHeight() uint32 {
return 0 return fs.blockHeight
} }
func (fs *FeerStub) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int { func (fs *FeerStub) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int {
@ -59,6 +61,50 @@ func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) {
assert.Equal(t, 0, len(mp.verifiedTxes)) assert.Equal(t, 0, len(mp.verifiedTxes))
} }
func TestMemPoolRemoveStale(t *testing.T) {
mp := New(5)
txs := make([]*transaction.Transaction, 5)
for i := range txs {
txs[i] = transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0)
txs[i].Nonce = uint32(i)
txs[i].Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}}
require.NoError(t, mp.Add(txs[i], &FeerStub{blockHeight: uint32(i)}))
}
staleTxs := make(chan *transaction.Transaction, 5)
f := func(tx *transaction.Transaction) {
staleTxs <- tx
}
mp.SetResendThreshold(5, f)
isValid := func(tx *transaction.Transaction) bool {
return tx.Nonce%2 == 0
}
mp.RemoveStale(isValid, &FeerStub{blockHeight: 5}) // 0 + 5
require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100)
require.Equal(t, txs[0], <-staleTxs)
mp.RemoveStale(isValid, &FeerStub{blockHeight: 7}) // 2 + 5
require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100)
require.Equal(t, txs[2], <-staleTxs)
mp.RemoveStale(isValid, &FeerStub{blockHeight: 10}) // 0 + 2 * 5
require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100)
require.Equal(t, txs[0], <-staleTxs)
mp.RemoveStale(isValid, &FeerStub{blockHeight: 15}) // 0 + 3 * 5
// tx[2] should appear, so it is also checked that tx[0] wasn't sent on height 15.
mp.RemoveStale(isValid, &FeerStub{blockHeight: 22}) // 2 + 4 * 5
require.Eventually(t, func() bool { return len(staleTxs) == 1 }, time.Second, time.Millisecond*100)
require.Equal(t, txs[2], <-staleTxs)
// panic if something is sent after this.
close(staleTxs)
require.Len(t, staleTxs, 0)
}
func TestMemPoolAddRemove(t *testing.T) { func TestMemPoolAddRemove(t *testing.T) {
var fs = &FeerStub{} var fs = &FeerStub{}
testMemPoolAddRemoveWithFeer(t, fs) testMemPoolAddRemoveWithFeer(t, fs)

View file

@ -171,6 +171,7 @@ func (s *Server) Start(errChan chan error) {
zap.Uint32("headerHeight", s.chain.HeaderHeight())) zap.Uint32("headerHeight", s.chain.HeaderHeight()))
s.tryStartConsensus() s.tryStartConsensus()
s.initStaleTxMemPool()
go s.broadcastTxLoop() go s.broadcastTxLoop()
go s.relayBlocksLoop() go s.relayBlocksLoop()
@ -905,6 +906,18 @@ func (s *Server) broadcastTxHashes(hs []util.Uint256) {
s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, Peer.IsFullNode) s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, Peer.IsFullNode)
} }
// initStaleTxMemPool initializes mempool for stale tx processing.
func (s *Server) initStaleTxMemPool() {
cfg := s.chain.GetConfig()
threshold := 5
if cfg.ValidatorsCount*2 > threshold {
threshold = cfg.ValidatorsCount * 2
}
mp := s.chain.GetMemPool()
mp.SetResendThreshold(uint32(threshold), s.broadcastTX)
}
// broadcastTxLoop is a loop for batching and sending // broadcastTxLoop is a loop for batching and sending
// transactions hashes in an INV payload. // transactions hashes in an INV payload.
func (s *Server) broadcastTxLoop() { func (s *Server) broadcastTxLoop() {