diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index b7722501d..fc9bc88c8 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -9,6 +9,7 @@ import ( "github.com/CityOfZion/neo-go/config" "github.com/CityOfZion/neo-go/pkg/core" coreb "github.com/CityOfZion/neo-go/pkg/core/block" + "github.com/CityOfZion/neo-go/pkg/core/mempool" "github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/crypto/hash" "github.com/CityOfZion/neo-go/pkg/crypto/keys" @@ -412,13 +413,13 @@ func (s *service) getBlock(h util.Uint256) block.Block { func (s *service) getVerifiedTx(count int) []block.Transaction { pool := s.Config.Chain.GetMemPool() - var txx []*transaction.Transaction + var txx []mempool.TxWithFee if s.dbft.ViewNumber > 0 { - txx = make([]*transaction.Transaction, 0, len(s.lastProposal)) + txx = make([]mempool.TxWithFee, 0, len(s.lastProposal)) for i := range s.lastProposal { - if tx, ok := pool.TryGetValue(s.lastProposal[i]); ok { - txx = append(txx, tx) + if tx, fee, ok := pool.TryGetValue(s.lastProposal[i]); ok { + txx = append(txx, mempool.TxWithFee{Tx: tx, Fee: fee}) } } @@ -432,8 +433,8 @@ func (s *service) getVerifiedTx(count int) []block.Transaction { res := make([]block.Transaction, len(txx)+1) var netFee util.Fixed8 for i := range txx { - res[i+1] = txx[i] - netFee += s.Config.Chain.NetworkFee(txx[i]) + res[i+1] = txx[i].Tx + netFee += txx[i].Fee } var txOuts []transaction.Output diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 2aa323562..5199c643d 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -814,7 +814,7 @@ func (bc *Blockchain) headerListLen() (n int) { // GetTransaction returns a TX and its height by the given hash. func (bc *Blockchain) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) { - if tx, ok := bc.memPool.TryGetValue(hash); ok { + if tx, _, ok := bc.memPool.TryGetValue(hash); ok { return tx, 0, nil // the height is not actually defined for memPool transaction. Not sure if zero is a good number in this case. } return bc.dao.GetTransaction(hash) diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index 3facd377a..9a99a1774 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -35,6 +35,12 @@ type item struct { // items is a slice of item. type items []*item +// TxWithFee combines transaction and its precalculated network fee. +type TxWithFee struct { + Tx *transaction.Transaction + Fee util.Fixed8 +} + // Pool stores the unconfirms transactions. type Pool struct { lock sync.RWMutex @@ -224,29 +230,28 @@ func NewMemPool(capacity int) Pool { } } -// TryGetValue returns a transaction if it exists in the memory pool. -func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, bool) { +// TryGetValue returns a transaction and its fee if it exists in the memory pool. +func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, util.Fixed8, bool) { mp.lock.RLock() defer mp.lock.RUnlock() if pItem, ok := mp.verifiedMap[hash]; ok { - return pItem.txn, ok + return pItem.txn, pItem.netFee, ok } - return nil, false + return nil, 0, false } // GetVerifiedTransactions returns a slice of Input from all the transactions in the memory pool // whose hash is not included in excludedHashes. -func (mp *Pool) GetVerifiedTransactions() []*transaction.Transaction { +func (mp *Pool) GetVerifiedTransactions() []TxWithFee { mp.lock.RLock() defer mp.lock.RUnlock() - var t = make([]*transaction.Transaction, len(mp.verifiedTxes)) - var i int + var t = make([]TxWithFee, len(mp.verifiedTxes)) - for _, p := range mp.verifiedTxes { - t[i] = p.txn - i++ + for i := range mp.verifiedTxes { + t[i].Tx = mp.verifiedTxes[i].txn + t[i].Fee = mp.verifiedTxes[i].netFee } return t diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index f60829a2a..b6a9401be 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -37,16 +37,16 @@ func (fs *FeerStub) SystemFee(*transaction.Transaction) util.Fixed8 { func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { mp := NewMemPool(10) tx := newMinerTX(0) - _, ok := mp.TryGetValue(tx.Hash()) + _, _, ok := mp.TryGetValue(tx.Hash()) require.Equal(t, false, ok) require.NoError(t, mp.Add(tx, fs)) // Re-adding should fail. require.Error(t, mp.Add(tx, fs)) - tx2, ok := mp.TryGetValue(tx.Hash()) + tx2, _, ok := mp.TryGetValue(tx.Hash()) require.Equal(t, true, ok) require.Equal(t, tx, tx2) mp.Remove(tx.Hash()) - _, ok = mp.TryGetValue(tx.Hash()) + _, _, ok = mp.TryGetValue(tx.Hash()) require.Equal(t, false, ok) // Make sure nothing left in the mempool after removal. assert.Equal(t, 0, len(mp.verifiedMap)) @@ -173,8 +173,8 @@ func TestGetVerified(t *testing.T) { require.Equal(t, mempoolSize, mp.Count()) verTxes := mp.GetVerifiedTransactions() require.Equal(t, mempoolSize, len(verTxes)) - for _, tx := range verTxes { - require.Contains(t, txes, tx) + for _, txf := range verTxes { + require.Contains(t, txes, txf.Tx) } for _, tx := range txes { mp.Remove(tx.Hash()) @@ -210,8 +210,8 @@ func TestRemoveStale(t *testing.T) { }) require.Equal(t, mempoolSize/2, mp.Count()) verTxes := mp.GetVerifiedTransactions() - for _, tx := range verTxes { - require.NotContains(t, txes1, tx) - require.Contains(t, txes2, tx) + for _, txf := range verTxes { + require.NotContains(t, txes1, txf.Tx) + require.Contains(t, txes2, txf.Tx) } }