diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 3ecd2c3b7..13159122e 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -467,6 +467,9 @@ func (bc *Blockchain) storeBlock(block *Block) error { } atomic.StoreUint32(&bc.blockHeight, block.Index) + for _, tx := range block.Transactions { + bc.memPool.Remove(tx.Hash()) + } return nil } diff --git a/pkg/core/mem_pool.go b/pkg/core/mem_pool.go index 1c8396719..3c675e5b7 100644 --- a/pkg/core/mem_pool.go +++ b/pkg/core/mem_pool.go @@ -106,6 +106,7 @@ func (mp MemPool) TryAdd(hash util.Uint256, pItem *PoolItem) bool { mp.lock.RLock() if _, ok := mp.unsortedTxn[hash]; ok { + mp.lock.RUnlock() return false } mp.unsortedTxn[hash] = pItem @@ -131,6 +132,39 @@ func (mp MemPool) TryAdd(hash util.Uint256, pItem *PoolItem) bool { return ok } +// Remove removes an item from the mempool, if it exists there (and does +// nothing if it doesn't). +func (mp *MemPool) Remove(hash util.Uint256) { + var mapAndPools = []struct { + unsortedMap map[util.Uint256]*PoolItem + sortedPools []*PoolItems + }{ + {unsortedMap: mp.unsortedTxn, sortedPools: []*PoolItems{&mp.sortedHighPrioTxn, &mp.sortedLowPrioTxn}}, + {unsortedMap: mp.unverifiedTxn, sortedPools: []*PoolItems{&mp.unverifiedSortedHighPrioTxn, &mp.unverifiedSortedLowPrioTxn}}, + } + mp.lock.Lock() + for _, mapAndPool := range mapAndPools { + if _, ok := mapAndPool.unsortedMap[hash]; ok { + delete(mapAndPool.unsortedMap, hash) + for _, pool := range mapAndPool.sortedPools { + var num int + var item *PoolItem + for num, item = range *pool { + if hash.Equals(item.txn.Hash()) { + break + } + } + if num < len(*pool)-1 { + *pool = append((*pool)[:num], (*pool)[num+1:]...) + } else if num == len(*pool)-1 { + *pool = (*pool)[:num] + } + } + } + } + mp.lock.Unlock() +} + // RemoveOverCapacity removes transactions with lowest fees until the total number of transactions // in the MemPool is within the capacity of the MemPool. func (mp *MemPool) RemoveOverCapacity() { diff --git a/pkg/core/mem_pool_test.go b/pkg/core/mem_pool_test.go new file mode 100644 index 000000000..bfd12230b --- /dev/null +++ b/pkg/core/mem_pool_test.go @@ -0,0 +1,64 @@ +package core + +import ( + "testing" + + "github.com/CityOfZion/neo-go/pkg/core/transaction" + "github.com/CityOfZion/neo-go/pkg/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type FeerStub struct { + lowPriority bool + sysFee util.Fixed8 + netFee util.Fixed8 + perByteFee util.Fixed8 +} + +func (fs *FeerStub) NetworkFee(*transaction.Transaction) util.Fixed8 { + return fs.netFee +} + +func (fs *FeerStub) IsLowPriority(*transaction.Transaction) bool { + return fs.lowPriority +} + +func (fs *FeerStub) FeePerByte(*transaction.Transaction) util.Fixed8 { + return fs.perByteFee +} + +func (fs *FeerStub) SystemFee(*transaction.Transaction) util.Fixed8 { + return fs.sysFee +} + +func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { + mp := NewMemPool(10) + tx := newMinerTX() + item := NewPoolItem(tx, fs) + _, ok := mp.TryGetValue(tx.Hash()) + require.Equal(t, false, ok) + require.Equal(t, true, mp.TryAdd(tx.Hash(), item)) + // Re-adding should fail. + require.Equal(t, false, mp.TryAdd(tx.Hash(), item)) + tx2, ok := mp.TryGetValue(tx.Hash()) + require.Equal(t, true, ok) + require.Equal(t, tx, tx2) + mp.Remove(tx.Hash()) + _, ok = mp.TryGetValue(tx.Hash()) + require.Equal(t, false, ok) + // Make sure nothing left in the mempool after removal. + assert.Equal(t, 0, len(mp.unsortedTxn)) + assert.Equal(t, 0, len(mp.unverifiedTxn)) + assert.Equal(t, 0, len(mp.sortedHighPrioTxn)) + assert.Equal(t, 0, len(mp.sortedLowPrioTxn)) + assert.Equal(t, 0, len(mp.unverifiedSortedHighPrioTxn)) + assert.Equal(t, 0, len(mp.unverifiedSortedLowPrioTxn)) +} + +func TestMemPoolAddRemove(t *testing.T) { + var fs = &FeerStub{lowPriority: false} + t.Run("low priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) }) + fs.lowPriority = true + t.Run("high priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) }) +} diff --git a/pkg/core/transaction/invocation.go b/pkg/core/transaction/invocation.go index 2ac3550e0..a0d82cb58 100644 --- a/pkg/core/transaction/invocation.go +++ b/pkg/core/transaction/invocation.go @@ -22,7 +22,8 @@ func NewInvocationTX(script []byte) *Transaction { Type: InvocationType, Version: 1, Data: &InvocationTX{ - Script: script, + Script: script, + Version: 1, }, Attributes: []*Attribute{}, Inputs: []*Input{}, diff --git a/pkg/core/transaction/transaction_test.go b/pkg/core/transaction/transaction_test.go index 2c416c570..4d5ef35ac 100644 --- a/pkg/core/transaction/transaction_test.go +++ b/pkg/core/transaction/transaction_test.go @@ -96,6 +96,25 @@ func TestDecodeEncodeInvocationTX(t *testing.T) { assert.Equal(t, rawInvocationTX, hex.EncodeToString(buf.Bytes())) } +func TestNewInvocationTX(t *testing.T) { + script := []byte{0x51} + tx := NewInvocationTX(script) + txData := tx.Data.(*InvocationTX) + assert.Equal(t, InvocationType, tx.Type) + assert.Equal(t, tx.Version, txData.Version) + assert.Equal(t, script, txData.Script) + buf := io.NewBufBinWriter() + // Update hash fields to match tx2 that is gonna autoupdate them on decode. + _ = tx.Hash() + tx.EncodeBinary(buf.BinWriter) + assert.Nil(t, buf.Err) + var tx2 Transaction + r := io.NewBinReaderFromBuf(buf.Bytes()) + tx2.DecodeBinary(r) + assert.Nil(t, r.Err) + assert.Equal(t, *tx, tx2) +} + func TestDecodePublishTX(t *testing.T) { expectedTXData := &PublishTX{} expectedTXData.Name = "Lock" diff --git a/pkg/network/server.go b/pkg/network/server.go index eec400c9a..f5efc4fac 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -279,13 +279,40 @@ func (s *Server) handleBlockCmd(p Peer, block *core.Block) error { // handleInvCmd processes the received inventory. func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { - if !inv.Type.Valid() || len(inv.Hashes) == 0 { - return errInvalidInvType - } payload := payload.NewInventory(inv.Type, inv.Hashes) return p.WriteMsg(NewMessage(s.Net, CMDGetData, payload)) } +// handleInvCmd processes the received inventory. +func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { + switch inv.Type { + case payload.TXType: + for _, hash := range inv.Hashes { + tx, _, err := s.chain.GetTransaction(hash) + if err == nil { + err = p.WriteMsg(NewMessage(s.Net, CMDTX, tx)) + if err != nil { + return err + } + + } + } + case payload.BlockType: + for _, hash := range inv.Hashes { + b, err := s.chain.GetBlock(hash) + if err == nil { + err = p.WriteMsg(NewMessage(s.Net, CMDBlock, b)) + if err != nil { + return err + } + } + } + case payload.ConsensusType: + // TODO (#431) + } + return nil +} + // handleAddrCmd will process received addresses. func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error { for _, a := range addrs.Addrs { @@ -350,6 +377,11 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { } if peer.Handshaked() { + if inv, ok := msg.Payload.(*payload.Inventory); ok { + if !inv.Type.Valid() || len(inv.Hashes) == 0 { + return errInvalidInvType + } + } switch msg.CommandType() { case CMDAddr: addrs := msg.Payload.(*payload.AddressList) @@ -357,6 +389,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDGetAddr: // it has no payload return s.handleGetAddrCmd(peer) + case CMDGetData: + inv := msg.Payload.(*payload.Inventory) + return s.handleGetDataCmd(peer, inv) case CMDHeaders: headers := msg.Payload.(*payload.Headers) go s.handleHeadersCmd(peer, headers)