diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index 5f9fe82ed..7dc38e2d4 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -168,7 +168,7 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { } pItem.isLowPrio = fee.IsLowPriority(pItem.netFee) mp.lock.Lock() - if !mp.verifyInputs(t) { + if !mp.checkTxConflicts(t) { mp.lock.Unlock() return ErrConflict } @@ -329,24 +329,39 @@ func (mp *Pool) GetVerifiedTransactions() []TxWithFee { return t } -// verifyInputs is an internal unprotected version of Verify. -func (mp *Pool) verifyInputs(tx *transaction.Transaction) bool { - for i := range tx.Inputs { - n := findIndexForInput(mp.inputs, &tx.Inputs[i]) - if n < len(mp.inputs) && *mp.inputs[n] == tx.Inputs[i] { - return false +// areInputsInPool tries to find inputs in a given sorted pool and returns true +// if it finds any. +func areInputsInPool(inputs []transaction.Input, pool []*transaction.Input) bool { + for i := range inputs { + n := findIndexForInput(pool, &inputs[i]) + if n < len(pool) && *pool[n] == inputs[i] { + return true } } - if tx.Type == transaction.ClaimType { + return false +} + +// checkTxConflicts is an internal unprotected version of Verify. +func (mp *Pool) checkTxConflicts(tx *transaction.Transaction) bool { + if areInputsInPool(tx.Inputs, mp.inputs) { + return false + } + switch tx.Type { + case transaction.ClaimType: claim := tx.Data.(*transaction.ClaimTX) - for i := range claim.Claims { - n := findIndexForInput(mp.claims, &claim.Claims[i]) - if n < len(mp.claims) && *mp.claims[n] == claim.Claims[i] { + if areInputsInPool(claim.Claims, mp.claims) { + return false + } + case transaction.IssueType: + // It's a hack, because technically we could check for + // available asset amount, but these transactions are so rare + // that no one really cares about this restriction. + for i := range mp.verifiedTxes { + if mp.verifiedTxes[i].txn.Type == transaction.IssueType { return false } } } - return true } @@ -356,5 +371,5 @@ func (mp *Pool) verifyInputs(tx *transaction.Transaction) bool { func (mp *Pool) Verify(tx *transaction.Transaction) bool { mp.lock.RLock() defer mp.lock.RUnlock() - return mp.verifyInputs(tx) + return mp.checkTxConflicts(tx) } diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index 315361bb3..212a623f0 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -177,6 +177,31 @@ func TestMemPoolVerifyClaims(t *testing.T) { require.Error(t, mp.Add(tx3, &FeerStub{})) } +func TestMemPoolVerifyIssue(t *testing.T) { + mp := NewMemPool(50) + tx1 := newIssueTX() + require.Equal(t, true, mp.Verify(tx1)) + require.NoError(t, mp.Add(tx1, &FeerStub{})) + + tx2 := newIssueTX() + require.Equal(t, false, mp.Verify(tx2)) + require.Error(t, mp.Add(tx2, &FeerStub{})) +} + +func newIssueTX() *transaction.Transaction { + return &transaction.Transaction{ + Type: transaction.IssueType, + Data: &transaction.IssueTX{}, + Outputs: []transaction.Output{ + transaction.Output{ + AssetID: random.Uint256(), + Amount: util.Fixed8FromInt64(42), + ScriptHash: random.Uint160(), + }, + }, + } +} + func newMinerTX(i uint32) *transaction.Transaction { return &transaction.Transaction{ Type: transaction.MinerType,