diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index 38326766e..6aebda543 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -55,6 +55,7 @@ type service struct { // everything in single thread. messages chan Payload transactions chan *transaction.Transaction + lastProposal []util.Uint256 } // Config is a configuration for consensus services. @@ -205,7 +206,9 @@ func (s *service) OnPayload(cp *Payload) { // we use switch here because other payloads could be possibly added in future switch cp.Type() { case payload.PrepareRequestType: - s.txx.Add(&cp.GetPrepareRequest().(*prepareRequest).minerTx) + req := cp.GetPrepareRequest().(*prepareRequest) + s.txx.Add(&req.minerTx) + s.lastProposal = req.transactionHashes } s.messages <- *cp @@ -328,7 +331,23 @@ func (s *service) getBlock(h util.Uint256) block.Block { func (s *service) getVerifiedTx(count int) []block.Transaction { pool := s.Config.Chain.GetMemPool() - txx := pool.GetVerifiedTransactions() + + var txx []*transaction.Transaction + + if s.dbft.ViewNumber > 0 { + txx = make([]*transaction.Transaction, 0, len(s.lastProposal)) + for i := range s.lastProposal { + if tx, ok := pool.TryGetValue(s.lastProposal[i]); ok { + txx = append(txx, tx) + } + } + + if len(txx) < len(s.lastProposal)/2 { + txx = pool.GetVerifiedTransactions() + } + } else { + txx = pool.GetVerifiedTransactions() + } res := make([]block.Transaction, len(txx)+1) for i := range txx { diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index 66909fcf2..1cd450f77 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -9,6 +9,7 @@ import ( "github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/util" "github.com/nspcc-dev/dbft/block" + "github.com/nspcc-dev/dbft/payload" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) @@ -28,6 +29,53 @@ func TestNewService(t *testing.T) { require.Equal(t, tx, txx[1]) } +func TestService_GetVerified(t *testing.T) { + srv := newTestService(t) + txs := []*transaction.Transaction{ + newMinerTx(1), + newMinerTx(2), + newMinerTx(3), + newMinerTx(4), + } + pool := srv.Chain.GetMemPool() + item := core.NewPoolItem(txs[3], new(feer)) + + require.True(t, pool.TryAdd(txs[3].Hash(), item)) + + hashes := []util.Uint256{txs[0].Hash(), txs[1].Hash(), txs[2].Hash()} + + p := new(Payload) + p.SetType(payload.PrepareRequestType) + p.SetPayload(&prepareRequest{transactionHashes: hashes}) + p.SetValidatorIndex(1) + + priv, _ := getTestValidator(1) + require.NoError(t, p.Sign(priv)) + + srv.OnPayload(p) + require.Equal(t, hashes, srv.lastProposal) + + srv.dbft.ViewNumber = 1 + + t.Run("new transactions will be proposed in case of failure", func(t *testing.T) { + txx := srv.getVerifiedTx(10) + require.Equal(t, 2, len(txx), "there is only 1 tx in mempool") + require.Equal(t, txs[3], txx[1]) + }) + + t.Run("more than half of the last proposal will be reused", func(t *testing.T) { + for _, tx := range txs[:2] { + item := core.NewPoolItem(tx, new(feer)) + require.True(t, pool.TryAdd(tx.Hash(), item)) + } + + txx := srv.getVerifiedTx(10) + require.Contains(t, txx, txs[0]) + require.Contains(t, txx, txs[1]) + require.NotContains(t, txx, txs[2]) + }) +} + func TestService_ValidatePayload(t *testing.T) { srv := newTestService(t) priv, _ := getTestValidator(1)