mempool: return fee along with tx when requesting tx

Users of GetVerifiedTransactions() don't want to recalculate tx fee and it's
nice to have it returned from TryGetValue() also sometimes.
This commit is contained in:
Roman Khimov 2020-02-18 18:56:41 +03:00
parent 06daeb44f3
commit 22f5667530
4 changed files with 31 additions and 25 deletions

View file

@ -9,6 +9,7 @@ import (
"github.com/CityOfZion/neo-go/config" "github.com/CityOfZion/neo-go/config"
"github.com/CityOfZion/neo-go/pkg/core" "github.com/CityOfZion/neo-go/pkg/core"
coreb "github.com/CityOfZion/neo-go/pkg/core/block" coreb "github.com/CityOfZion/neo-go/pkg/core/block"
"github.com/CityOfZion/neo-go/pkg/core/mempool"
"github.com/CityOfZion/neo-go/pkg/core/transaction" "github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/CityOfZion/neo-go/pkg/crypto/hash" "github.com/CityOfZion/neo-go/pkg/crypto/hash"
"github.com/CityOfZion/neo-go/pkg/crypto/keys" "github.com/CityOfZion/neo-go/pkg/crypto/keys"
@ -412,13 +413,13 @@ func (s *service) getBlock(h util.Uint256) block.Block {
func (s *service) getVerifiedTx(count int) []block.Transaction { func (s *service) getVerifiedTx(count int) []block.Transaction {
pool := s.Config.Chain.GetMemPool() pool := s.Config.Chain.GetMemPool()
var txx []*transaction.Transaction var txx []mempool.TxWithFee
if s.dbft.ViewNumber > 0 { if s.dbft.ViewNumber > 0 {
txx = make([]*transaction.Transaction, 0, len(s.lastProposal)) txx = make([]mempool.TxWithFee, 0, len(s.lastProposal))
for i := range s.lastProposal { for i := range s.lastProposal {
if tx, ok := pool.TryGetValue(s.lastProposal[i]); ok { if tx, fee, ok := pool.TryGetValue(s.lastProposal[i]); ok {
txx = append(txx, tx) txx = append(txx, mempool.TxWithFee{Tx: tx, Fee: fee})
} }
} }
@ -432,8 +433,8 @@ func (s *service) getVerifiedTx(count int) []block.Transaction {
res := make([]block.Transaction, len(txx)+1) res := make([]block.Transaction, len(txx)+1)
var netFee util.Fixed8 var netFee util.Fixed8
for i := range txx { for i := range txx {
res[i+1] = txx[i] res[i+1] = txx[i].Tx
netFee += s.Config.Chain.NetworkFee(txx[i]) netFee += txx[i].Fee
} }
var txOuts []transaction.Output var txOuts []transaction.Output

View file

@ -814,7 +814,7 @@ func (bc *Blockchain) headerListLen() (n int) {
// GetTransaction returns a TX and its height by the given hash. // GetTransaction returns a TX and its height by the given hash.
func (bc *Blockchain) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) { func (bc *Blockchain) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) {
if tx, ok := bc.memPool.TryGetValue(hash); ok { if tx, _, ok := bc.memPool.TryGetValue(hash); ok {
return tx, 0, nil // the height is not actually defined for memPool transaction. Not sure if zero is a good number in this case. return tx, 0, nil // the height is not actually defined for memPool transaction. Not sure if zero is a good number in this case.
} }
return bc.dao.GetTransaction(hash) return bc.dao.GetTransaction(hash)

View file

@ -35,6 +35,12 @@ type item struct {
// items is a slice of item. // items is a slice of item.
type items []*item type items []*item
// TxWithFee combines transaction and its precalculated network fee.
type TxWithFee struct {
Tx *transaction.Transaction
Fee util.Fixed8
}
// Pool stores the unconfirms transactions. // Pool stores the unconfirms transactions.
type Pool struct { type Pool struct {
lock sync.RWMutex lock sync.RWMutex
@ -224,29 +230,28 @@ func NewMemPool(capacity int) Pool {
} }
} }
// TryGetValue returns a transaction if it exists in the memory pool. // TryGetValue returns a transaction and its fee if it exists in the memory pool.
func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, bool) { func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, util.Fixed8, bool) {
mp.lock.RLock() mp.lock.RLock()
defer mp.lock.RUnlock() defer mp.lock.RUnlock()
if pItem, ok := mp.verifiedMap[hash]; ok { if pItem, ok := mp.verifiedMap[hash]; ok {
return pItem.txn, ok return pItem.txn, pItem.netFee, ok
} }
return nil, false return nil, 0, false
} }
// GetVerifiedTransactions returns a slice of Input from all the transactions in the memory pool // GetVerifiedTransactions returns a slice of Input from all the transactions in the memory pool
// whose hash is not included in excludedHashes. // whose hash is not included in excludedHashes.
func (mp *Pool) GetVerifiedTransactions() []*transaction.Transaction { func (mp *Pool) GetVerifiedTransactions() []TxWithFee {
mp.lock.RLock() mp.lock.RLock()
defer mp.lock.RUnlock() defer mp.lock.RUnlock()
var t = make([]*transaction.Transaction, len(mp.verifiedTxes)) var t = make([]TxWithFee, len(mp.verifiedTxes))
var i int
for _, p := range mp.verifiedTxes { for i := range mp.verifiedTxes {
t[i] = p.txn t[i].Tx = mp.verifiedTxes[i].txn
i++ t[i].Fee = mp.verifiedTxes[i].netFee
} }
return t return t

View file

@ -37,16 +37,16 @@ func (fs *FeerStub) SystemFee(*transaction.Transaction) util.Fixed8 {
func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) {
mp := NewMemPool(10) mp := NewMemPool(10)
tx := newMinerTX(0) tx := newMinerTX(0)
_, ok := mp.TryGetValue(tx.Hash()) _, _, ok := mp.TryGetValue(tx.Hash())
require.Equal(t, false, ok) require.Equal(t, false, ok)
require.NoError(t, mp.Add(tx, fs)) require.NoError(t, mp.Add(tx, fs))
// Re-adding should fail. // Re-adding should fail.
require.Error(t, mp.Add(tx, fs)) require.Error(t, mp.Add(tx, fs))
tx2, ok := mp.TryGetValue(tx.Hash()) tx2, _, ok := mp.TryGetValue(tx.Hash())
require.Equal(t, true, ok) require.Equal(t, true, ok)
require.Equal(t, tx, tx2) require.Equal(t, tx, tx2)
mp.Remove(tx.Hash()) mp.Remove(tx.Hash())
_, ok = mp.TryGetValue(tx.Hash()) _, _, ok = mp.TryGetValue(tx.Hash())
require.Equal(t, false, ok) require.Equal(t, false, ok)
// Make sure nothing left in the mempool after removal. // Make sure nothing left in the mempool after removal.
assert.Equal(t, 0, len(mp.verifiedMap)) assert.Equal(t, 0, len(mp.verifiedMap))
@ -173,8 +173,8 @@ func TestGetVerified(t *testing.T) {
require.Equal(t, mempoolSize, mp.Count()) require.Equal(t, mempoolSize, mp.Count())
verTxes := mp.GetVerifiedTransactions() verTxes := mp.GetVerifiedTransactions()
require.Equal(t, mempoolSize, len(verTxes)) require.Equal(t, mempoolSize, len(verTxes))
for _, tx := range verTxes { for _, txf := range verTxes {
require.Contains(t, txes, tx) require.Contains(t, txes, txf.Tx)
} }
for _, tx := range txes { for _, tx := range txes {
mp.Remove(tx.Hash()) mp.Remove(tx.Hash())
@ -210,8 +210,8 @@ func TestRemoveStale(t *testing.T) {
}) })
require.Equal(t, mempoolSize/2, mp.Count()) require.Equal(t, mempoolSize/2, mp.Count())
verTxes := mp.GetVerifiedTransactions() verTxes := mp.GetVerifiedTransactions()
for _, tx := range verTxes { for _, txf := range verTxes {
require.NotContains(t, txes1, tx) require.NotContains(t, txes1, txf.Tx)
require.Contains(t, txes2, tx) require.Contains(t, txes2, txf.Tx)
} }
} }