Merge pull request #448 from nspcc-dev/tx-processing-fixes

TX relaying fixes
This commit is contained in:
Vsevolod 2019-10-24 13:54:42 +03:00 committed by GitHub
commit 2e99d65554
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 160 additions and 4 deletions

View file

@ -467,6 +467,9 @@ func (bc *Blockchain) storeBlock(block *Block) error {
} }
atomic.StoreUint32(&bc.blockHeight, block.Index) atomic.StoreUint32(&bc.blockHeight, block.Index)
for _, tx := range block.Transactions {
bc.memPool.Remove(tx.Hash())
}
return nil return nil
} }

View file

@ -106,6 +106,7 @@ func (mp MemPool) TryAdd(hash util.Uint256, pItem *PoolItem) bool {
mp.lock.RLock() mp.lock.RLock()
if _, ok := mp.unsortedTxn[hash]; ok { if _, ok := mp.unsortedTxn[hash]; ok {
mp.lock.RUnlock()
return false return false
} }
mp.unsortedTxn[hash] = pItem mp.unsortedTxn[hash] = pItem
@ -131,6 +132,39 @@ func (mp MemPool) TryAdd(hash util.Uint256, pItem *PoolItem) bool {
return ok return ok
} }
// Remove removes an item from the mempool, if it exists there (and does
// nothing if it doesn't).
func (mp *MemPool) Remove(hash util.Uint256) {
var mapAndPools = []struct {
unsortedMap map[util.Uint256]*PoolItem
sortedPools []*PoolItems
}{
{unsortedMap: mp.unsortedTxn, sortedPools: []*PoolItems{&mp.sortedHighPrioTxn, &mp.sortedLowPrioTxn}},
{unsortedMap: mp.unverifiedTxn, sortedPools: []*PoolItems{&mp.unverifiedSortedHighPrioTxn, &mp.unverifiedSortedLowPrioTxn}},
}
mp.lock.Lock()
for _, mapAndPool := range mapAndPools {
if _, ok := mapAndPool.unsortedMap[hash]; ok {
delete(mapAndPool.unsortedMap, hash)
for _, pool := range mapAndPool.sortedPools {
var num int
var item *PoolItem
for num, item = range *pool {
if hash.Equals(item.txn.Hash()) {
break
}
}
if num < len(*pool)-1 {
*pool = append((*pool)[:num], (*pool)[num+1:]...)
} else if num == len(*pool)-1 {
*pool = (*pool)[:num]
}
}
}
}
mp.lock.Unlock()
}
// RemoveOverCapacity removes transactions with lowest fees until the total number of transactions // RemoveOverCapacity removes transactions with lowest fees until the total number of transactions
// in the MemPool is within the capacity of the MemPool. // in the MemPool is within the capacity of the MemPool.
func (mp *MemPool) RemoveOverCapacity() { func (mp *MemPool) RemoveOverCapacity() {

64
pkg/core/mem_pool_test.go Normal file
View file

@ -0,0 +1,64 @@
package core
import (
"testing"
"github.com/CityOfZion/neo-go/pkg/core/transaction"
"github.com/CityOfZion/neo-go/pkg/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type FeerStub struct {
lowPriority bool
sysFee util.Fixed8
netFee util.Fixed8
perByteFee util.Fixed8
}
func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 {
return fs.netFee
}
func (fs *FeerStub) IsLowPriority(*transaction.Transaction) bool {
return fs.lowPriority
}
func (fs *FeerStub) FeePerByte(*transaction.Transaction) util.Fixed8 {
return fs.perByteFee
}
func (fs *FeerStub) SystemFee(*transaction.Transaction) util.Fixed8 {
return fs.sysFee
}
func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) {
mp := NewMemPool(10)
tx := newMinerTX()
item := NewPoolItem(tx, fs)
_, ok := mp.TryGetValue(tx.Hash())
require.Equal(t, false, ok)
require.Equal(t, true, mp.TryAdd(tx.Hash(), item))
// Re-adding should fail.
require.Equal(t, false, mp.TryAdd(tx.Hash(), item))
tx2, ok := mp.TryGetValue(tx.Hash())
require.Equal(t, true, ok)
require.Equal(t, tx, tx2)
mp.Remove(tx.Hash())
_, ok = mp.TryGetValue(tx.Hash())
require.Equal(t, false, ok)
// Make sure nothing left in the mempool after removal.
assert.Equal(t, 0, len(mp.unsortedTxn))
assert.Equal(t, 0, len(mp.unverifiedTxn))
assert.Equal(t, 0, len(mp.sortedHighPrioTxn))
assert.Equal(t, 0, len(mp.sortedLowPrioTxn))
assert.Equal(t, 0, len(mp.unverifiedSortedHighPrioTxn))
assert.Equal(t, 0, len(mp.unverifiedSortedLowPrioTxn))
}
func TestMemPoolAddRemove(t *testing.T) {
var fs = &FeerStub{lowPriority: false}
t.Run("low priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) })
fs.lowPriority = true
t.Run("high priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) })
}

View file

@ -22,7 +22,8 @@ func NewInvocationTX(script []byte) *Transaction {
Type: InvocationType, Type: InvocationType,
Version: 1, Version: 1,
Data: &InvocationTX{ Data: &InvocationTX{
Script: script, Script: script,
Version: 1,
}, },
Attributes: []*Attribute{}, Attributes: []*Attribute{},
Inputs: []*Input{}, Inputs: []*Input{},

View file

@ -96,6 +96,25 @@ func TestDecodeEncodeInvocationTX(t *testing.T) {
assert.Equal(t, rawInvocationTX, hex.EncodeToString(buf.Bytes())) assert.Equal(t, rawInvocationTX, hex.EncodeToString(buf.Bytes()))
} }
func TestNewInvocationTX(t *testing.T) {
script := []byte{0x51}
tx := NewInvocationTX(script)
txData := tx.Data.(*InvocationTX)
assert.Equal(t, InvocationType, tx.Type)
assert.Equal(t, tx.Version, txData.Version)
assert.Equal(t, script, txData.Script)
buf := io.NewBufBinWriter()
// Update hash fields to match tx2 that is gonna autoupdate them on decode.
_ = tx.Hash()
tx.EncodeBinary(buf.BinWriter)
assert.Nil(t, buf.Err)
var tx2 Transaction
r := io.NewBinReaderFromBuf(buf.Bytes())
tx2.DecodeBinary(r)
assert.Nil(t, r.Err)
assert.Equal(t, *tx, tx2)
}
func TestDecodePublishTX(t *testing.T) { func TestDecodePublishTX(t *testing.T) {
expectedTXData := &PublishTX{} expectedTXData := &PublishTX{}
expectedTXData.Name = "Lock" expectedTXData.Name = "Lock"

View file

@ -279,13 +279,40 @@ func (s *Server) handleBlockCmd(p Peer, block *core.Block) error {
// handleInvCmd processes the received inventory. // handleInvCmd processes the received inventory.
func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
if !inv.Type.Valid() || len(inv.Hashes) == 0 {
return errInvalidInvType
}
payload := payload.NewInventory(inv.Type, inv.Hashes) payload := payload.NewInventory(inv.Type, inv.Hashes)
return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
} }
// handleInvCmd processes the received inventory.
func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
switch inv.Type {
case payload.TXType:
for _, hash := range inv.Hashes {
tx, _, err := s.chain.GetTransaction(hash)
if err == nil {
err = p.WriteMsg(NewMessage(s.Net, CMDTX, tx))
if err != nil {
return err
}
}
}
case payload.BlockType:
for _, hash := range inv.Hashes {
b, err := s.chain.GetBlock(hash)
if err == nil {
err = p.WriteMsg(NewMessage(s.Net, CMDBlock, b))
if err != nil {
return err
}
}
}
case payload.ConsensusType:
// TODO (#431)
}
return nil
}
// handleAddrCmd will process received addresses. // handleAddrCmd will process received addresses.
func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error { func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error {
for _, a := range addrs.Addrs { for _, a := range addrs.Addrs {
@ -350,6 +377,11 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
} }
if peer.Handshaked() { if peer.Handshaked() {
if inv, ok := msg.Payload.(*payload.Inventory); ok {
if !inv.Type.Valid() || len(inv.Hashes) == 0 {
return errInvalidInvType
}
}
switch msg.CommandType() { switch msg.CommandType() {
case CMDAddr: case CMDAddr:
addrs := msg.Payload.(*payload.AddressList) addrs := msg.Payload.(*payload.AddressList)
@ -357,6 +389,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
case CMDGetAddr: case CMDGetAddr:
// it has no payload // it has no payload
return s.handleGetAddrCmd(peer) return s.handleGetAddrCmd(peer)
case CMDGetData:
inv := msg.Payload.(*payload.Inventory)
return s.handleGetDataCmd(peer, inv)
case CMDHeaders: case CMDHeaders:
headers := msg.Payload.(*payload.Headers) headers := msg.Payload.(*payload.Headers)
go s.handleHeadersCmd(peer, headers) go s.handleHeadersCmd(peer, headers)