mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-12-23 13:41:37 +00:00
Merge pull request #448 from nspcc-dev/tx-processing-fixes
TX relaying fixes
This commit is contained in:
commit
2e99d65554
6 changed files with 160 additions and 4 deletions
|
@ -467,6 +467,9 @@ func (bc *Blockchain) storeBlock(block *Block) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
atomic.StoreUint32(&bc.blockHeight, block.Index)
|
atomic.StoreUint32(&bc.blockHeight, block.Index)
|
||||||
|
for _, tx := range block.Transactions {
|
||||||
|
bc.memPool.Remove(tx.Hash())
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -106,6 +106,7 @@ func (mp MemPool) TryAdd(hash util.Uint256, pItem *PoolItem) bool {
|
||||||
|
|
||||||
mp.lock.RLock()
|
mp.lock.RLock()
|
||||||
if _, ok := mp.unsortedTxn[hash]; ok {
|
if _, ok := mp.unsortedTxn[hash]; ok {
|
||||||
|
mp.lock.RUnlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
mp.unsortedTxn[hash] = pItem
|
mp.unsortedTxn[hash] = pItem
|
||||||
|
@ -131,6 +132,39 @@ func (mp MemPool) TryAdd(hash util.Uint256, pItem *PoolItem) bool {
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove removes an item from the mempool, if it exists there (and does
|
||||||
|
// nothing if it doesn't).
|
||||||
|
func (mp *MemPool) Remove(hash util.Uint256) {
|
||||||
|
var mapAndPools = []struct {
|
||||||
|
unsortedMap map[util.Uint256]*PoolItem
|
||||||
|
sortedPools []*PoolItems
|
||||||
|
}{
|
||||||
|
{unsortedMap: mp.unsortedTxn, sortedPools: []*PoolItems{&mp.sortedHighPrioTxn, &mp.sortedLowPrioTxn}},
|
||||||
|
{unsortedMap: mp.unverifiedTxn, sortedPools: []*PoolItems{&mp.unverifiedSortedHighPrioTxn, &mp.unverifiedSortedLowPrioTxn}},
|
||||||
|
}
|
||||||
|
mp.lock.Lock()
|
||||||
|
for _, mapAndPool := range mapAndPools {
|
||||||
|
if _, ok := mapAndPool.unsortedMap[hash]; ok {
|
||||||
|
delete(mapAndPool.unsortedMap, hash)
|
||||||
|
for _, pool := range mapAndPool.sortedPools {
|
||||||
|
var num int
|
||||||
|
var item *PoolItem
|
||||||
|
for num, item = range *pool {
|
||||||
|
if hash.Equals(item.txn.Hash()) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if num < len(*pool)-1 {
|
||||||
|
*pool = append((*pool)[:num], (*pool)[num+1:]...)
|
||||||
|
} else if num == len(*pool)-1 {
|
||||||
|
*pool = (*pool)[:num]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mp.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// RemoveOverCapacity removes transactions with lowest fees until the total number of transactions
|
// RemoveOverCapacity removes transactions with lowest fees until the total number of transactions
|
||||||
// in the MemPool is within the capacity of the MemPool.
|
// in the MemPool is within the capacity of the MemPool.
|
||||||
func (mp *MemPool) RemoveOverCapacity() {
|
func (mp *MemPool) RemoveOverCapacity() {
|
||||||
|
|
64
pkg/core/mem_pool_test.go
Normal file
64
pkg/core/mem_pool_test.go
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/CityOfZion/neo-go/pkg/core/transaction"
|
||||||
|
"github.com/CityOfZion/neo-go/pkg/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FeerStub struct {
|
||||||
|
lowPriority bool
|
||||||
|
sysFee util.Fixed8
|
||||||
|
netFee util.Fixed8
|
||||||
|
perByteFee util.Fixed8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 {
|
||||||
|
return fs.netFee
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *FeerStub) IsLowPriority(*transaction.Transaction) bool {
|
||||||
|
return fs.lowPriority
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *FeerStub) FeePerByte(*transaction.Transaction) util.Fixed8 {
|
||||||
|
return fs.perByteFee
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *FeerStub) SystemFee(*transaction.Transaction) util.Fixed8 {
|
||||||
|
return fs.sysFee
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) {
|
||||||
|
mp := NewMemPool(10)
|
||||||
|
tx := newMinerTX()
|
||||||
|
item := NewPoolItem(tx, fs)
|
||||||
|
_, ok := mp.TryGetValue(tx.Hash())
|
||||||
|
require.Equal(t, false, ok)
|
||||||
|
require.Equal(t, true, mp.TryAdd(tx.Hash(), item))
|
||||||
|
// Re-adding should fail.
|
||||||
|
require.Equal(t, false, mp.TryAdd(tx.Hash(), item))
|
||||||
|
tx2, ok := mp.TryGetValue(tx.Hash())
|
||||||
|
require.Equal(t, true, ok)
|
||||||
|
require.Equal(t, tx, tx2)
|
||||||
|
mp.Remove(tx.Hash())
|
||||||
|
_, ok = mp.TryGetValue(tx.Hash())
|
||||||
|
require.Equal(t, false, ok)
|
||||||
|
// Make sure nothing left in the mempool after removal.
|
||||||
|
assert.Equal(t, 0, len(mp.unsortedTxn))
|
||||||
|
assert.Equal(t, 0, len(mp.unverifiedTxn))
|
||||||
|
assert.Equal(t, 0, len(mp.sortedHighPrioTxn))
|
||||||
|
assert.Equal(t, 0, len(mp.sortedLowPrioTxn))
|
||||||
|
assert.Equal(t, 0, len(mp.unverifiedSortedHighPrioTxn))
|
||||||
|
assert.Equal(t, 0, len(mp.unverifiedSortedLowPrioTxn))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemPoolAddRemove(t *testing.T) {
|
||||||
|
var fs = &FeerStub{lowPriority: false}
|
||||||
|
t.Run("low priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) })
|
||||||
|
fs.lowPriority = true
|
||||||
|
t.Run("high priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) })
|
||||||
|
}
|
|
@ -23,6 +23,7 @@ func NewInvocationTX(script []byte) *Transaction {
|
||||||
Version: 1,
|
Version: 1,
|
||||||
Data: &InvocationTX{
|
Data: &InvocationTX{
|
||||||
Script: script,
|
Script: script,
|
||||||
|
Version: 1,
|
||||||
},
|
},
|
||||||
Attributes: []*Attribute{},
|
Attributes: []*Attribute{},
|
||||||
Inputs: []*Input{},
|
Inputs: []*Input{},
|
||||||
|
|
|
@ -96,6 +96,25 @@ func TestDecodeEncodeInvocationTX(t *testing.T) {
|
||||||
assert.Equal(t, rawInvocationTX, hex.EncodeToString(buf.Bytes()))
|
assert.Equal(t, rawInvocationTX, hex.EncodeToString(buf.Bytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewInvocationTX(t *testing.T) {
|
||||||
|
script := []byte{0x51}
|
||||||
|
tx := NewInvocationTX(script)
|
||||||
|
txData := tx.Data.(*InvocationTX)
|
||||||
|
assert.Equal(t, InvocationType, tx.Type)
|
||||||
|
assert.Equal(t, tx.Version, txData.Version)
|
||||||
|
assert.Equal(t, script, txData.Script)
|
||||||
|
buf := io.NewBufBinWriter()
|
||||||
|
// Update hash fields to match tx2 that is gonna autoupdate them on decode.
|
||||||
|
_ = tx.Hash()
|
||||||
|
tx.EncodeBinary(buf.BinWriter)
|
||||||
|
assert.Nil(t, buf.Err)
|
||||||
|
var tx2 Transaction
|
||||||
|
r := io.NewBinReaderFromBuf(buf.Bytes())
|
||||||
|
tx2.DecodeBinary(r)
|
||||||
|
assert.Nil(t, r.Err)
|
||||||
|
assert.Equal(t, *tx, tx2)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDecodePublishTX(t *testing.T) {
|
func TestDecodePublishTX(t *testing.T) {
|
||||||
expectedTXData := &PublishTX{}
|
expectedTXData := &PublishTX{}
|
||||||
expectedTXData.Name = "Lock"
|
expectedTXData.Name = "Lock"
|
||||||
|
|
|
@ -279,13 +279,40 @@ func (s *Server) handleBlockCmd(p Peer, block *core.Block) error {
|
||||||
|
|
||||||
// handleInvCmd processes the received inventory.
|
// handleInvCmd processes the received inventory.
|
||||||
func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
|
func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
|
||||||
if !inv.Type.Valid() || len(inv.Hashes) == 0 {
|
|
||||||
return errInvalidInvType
|
|
||||||
}
|
|
||||||
payload := payload.NewInventory(inv.Type, inv.Hashes)
|
payload := payload.NewInventory(inv.Type, inv.Hashes)
|
||||||
return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
|
return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleInvCmd processes the received inventory.
|
||||||
|
func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
|
||||||
|
switch inv.Type {
|
||||||
|
case payload.TXType:
|
||||||
|
for _, hash := range inv.Hashes {
|
||||||
|
tx, _, err := s.chain.GetTransaction(hash)
|
||||||
|
if err == nil {
|
||||||
|
err = p.WriteMsg(NewMessage(s.Net, CMDTX, tx))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case payload.BlockType:
|
||||||
|
for _, hash := range inv.Hashes {
|
||||||
|
b, err := s.chain.GetBlock(hash)
|
||||||
|
if err == nil {
|
||||||
|
err = p.WriteMsg(NewMessage(s.Net, CMDBlock, b))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case payload.ConsensusType:
|
||||||
|
// TODO (#431)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// handleAddrCmd will process received addresses.
|
// handleAddrCmd will process received addresses.
|
||||||
func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error {
|
func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error {
|
||||||
for _, a := range addrs.Addrs {
|
for _, a := range addrs.Addrs {
|
||||||
|
@ -350,6 +377,11 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.Handshaked() {
|
if peer.Handshaked() {
|
||||||
|
if inv, ok := msg.Payload.(*payload.Inventory); ok {
|
||||||
|
if !inv.Type.Valid() || len(inv.Hashes) == 0 {
|
||||||
|
return errInvalidInvType
|
||||||
|
}
|
||||||
|
}
|
||||||
switch msg.CommandType() {
|
switch msg.CommandType() {
|
||||||
case CMDAddr:
|
case CMDAddr:
|
||||||
addrs := msg.Payload.(*payload.AddressList)
|
addrs := msg.Payload.(*payload.AddressList)
|
||||||
|
@ -357,6 +389,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
|
||||||
case CMDGetAddr:
|
case CMDGetAddr:
|
||||||
// it has no payload
|
// it has no payload
|
||||||
return s.handleGetAddrCmd(peer)
|
return s.handleGetAddrCmd(peer)
|
||||||
|
case CMDGetData:
|
||||||
|
inv := msg.Payload.(*payload.Inventory)
|
||||||
|
return s.handleGetDataCmd(peer, inv)
|
||||||
case CMDHeaders:
|
case CMDHeaders:
|
||||||
headers := msg.Payload.(*payload.Headers)
|
headers := msg.Payload.(*payload.Headers)
|
||||||
go s.handleHeadersCmd(peer, headers)
|
go s.handleHeadersCmd(peer, headers)
|
||||||
|
|
Loading…
Reference in a new issue