diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index 9a99a1774..e0371fe2a 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -46,6 +46,8 @@ type Pool struct { lock sync.RWMutex verifiedMap map[util.Uint256]*item verifiedTxes items + inputs []*transaction.Input + claims []*transaction.Input capacity int } @@ -127,6 +129,35 @@ func (mp *Pool) containsKey(hash util.Uint256) bool { return false } +// findIndexForInput finds an index in a sorted Input pointers slice that is +// appropriate to place this input into (or which contains an identical Input). +func findIndexForInput(slice []*transaction.Input, input *transaction.Input) int { + return sort.Search(len(slice), func(n int) bool { + return input.Cmp(slice[n]) <= 0 + }) +} + +// pushInputToSortedSlice pushes new Input into the given slice. +func pushInputToSortedSlice(slice *[]*transaction.Input, input *transaction.Input) { + n := findIndexForInput(*slice, input) + *slice = append(*slice, input) + if n != len(*slice)-1 { + copy((*slice)[n+1:], (*slice)[n:]) + (*slice)[n] = input + } +} + +// dropInputFromSortedSlice removes given input from the given slice. +func dropInputFromSortedSlice(slice *[]*transaction.Input, input *transaction.Input) { + n := findIndexForInput(*slice, input) + if n == len(*slice) || *input != *(*slice)[n] { + // Not present. + return + } + copy((*slice)[n:], (*slice)[n+1:]) + *slice = (*slice)[:len(*slice)-1] +} + // Add tries to add given transaction to the Pool. func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { var pItem = &item{ @@ -175,6 +206,19 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { copy(mp.verifiedTxes[n+1:], mp.verifiedTxes[n:]) mp.verifiedTxes[n] = pItem } + + // For lots of inputs it might be easier to push them all and sort + // afterwards, but that requires benchmarking. + for i := range t.Inputs { + pushInputToSortedSlice(&mp.inputs, &t.Inputs[i]) + } + if t.Type == transaction.ClaimType { + claim := t.Data.(*transaction.ClaimTX) + for i := range claim.Claims { + pushInputToSortedSlice(&mp.claims, &claim.Claims[i]) + } + } + updateMempoolMetrics(len(mp.verifiedTxes)) mp.lock.Unlock() return nil @@ -184,10 +228,10 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { // nothing if it doesn't). func (mp *Pool) Remove(hash util.Uint256) { mp.lock.Lock() - if _, ok := mp.verifiedMap[hash]; ok { + if it, ok := mp.verifiedMap[hash]; ok { var num int delete(mp.verifiedMap, hash) - for num := range mp.verifiedTxes { + for num = range mp.verifiedTxes { if hash.Equals(mp.verifiedTxes[num].txn.Hash()) { break } @@ -197,6 +241,15 @@ func (mp *Pool) Remove(hash util.Uint256) { } else if num == len(mp.verifiedTxes)-1 { mp.verifiedTxes = mp.verifiedTxes[:num] } + for i := range it.txn.Inputs { + dropInputFromSortedSlice(&mp.inputs, &it.txn.Inputs[i]) + } + if it.txn.Type == transaction.ClaimType { + claim := it.txn.Data.(*transaction.ClaimTX) + for i := range claim.Claims { + dropInputFromSortedSlice(&mp.claims, &claim.Claims[i]) + } + } } updateMempoolMetrics(len(mp.verifiedTxes)) mp.lock.Unlock() @@ -210,14 +263,33 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool) { // We expect a lot of changes, so it's easier to allocate a new slice // rather than move things in an old one. newVerifiedTxes := make([]*item, 0, mp.capacity) + newInputs := mp.inputs[:0] + newClaims := mp.claims[:0] for _, itm := range mp.verifiedTxes { if isOK(itm.txn) { newVerifiedTxes = append(newVerifiedTxes, itm) + for i := range itm.txn.Inputs { + newInputs = append(newInputs, &itm.txn.Inputs[i]) + } + if itm.txn.Type == transaction.ClaimType { + claim := itm.txn.Data.(*transaction.ClaimTX) + for i := range claim.Claims { + newClaims = append(newClaims, &claim.Claims[i]) + } + } } else { delete(mp.verifiedMap, itm.txn.Hash()) } } + sort.Slice(newInputs, func(i, j int) bool { + return newInputs[i].Cmp(newInputs[j]) < 0 + }) + sort.Slice(newClaims, func(i, j int) bool { + return newClaims[i].Cmp(newClaims[j]) < 0 + }) mp.verifiedTxes = newVerifiedTxes + mp.inputs = newInputs + mp.claims = newClaims mp.lock.Unlock() } @@ -259,16 +331,18 @@ func (mp *Pool) GetVerifiedTransactions() []TxWithFee { // verifyInputs is an internal unprotected version of Verify. func (mp *Pool) verifyInputs(tx *transaction.Transaction) bool { - if len(tx.Inputs) == 0 { - return true + for i := range tx.Inputs { + n := findIndexForInput(mp.inputs, &tx.Inputs[i]) + if n < len(mp.inputs) && *mp.inputs[n] == tx.Inputs[i] { + return false + } } - for num := range mp.verifiedTxes { - txn := mp.verifiedTxes[num].txn - for i := range txn.Inputs { - for j := 0; j < len(tx.Inputs); j++ { - if txn.Inputs[i] == tx.Inputs[j] { - return false - } + if tx.Type == transaction.ClaimType { + claim := tx.Data.(*transaction.ClaimTX) + for i := range claim.Claims { + n := findIndexForInput(mp.claims, &claim.Claims[i]) + if n < len(mp.claims) && *mp.claims[n] == claim.Claims[i] { + return false } } } diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index b6a9401be..6c1e1fe79 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -60,7 +60,73 @@ func TestMemPoolAddRemove(t *testing.T) { t.Run("high priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) }) } -func TestMemPoolVerify(t *testing.T) { +func TestMemPoolAddRemoveWithInputsAndClaims(t *testing.T) { + mp := NewMemPool(50) + hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25") + require.NoError(t, err) + hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9") + require.NoError(t, err) + mpLessInputs := func(i, j int) bool { + return mp.inputs[i].Cmp(mp.inputs[j]) < 0 + } + mpLessClaims := func(i, j int) bool { + return mp.claims[i].Cmp(mp.claims[j]) < 0 + } + + txm1 := newMinerTX(1) + txc1, claim1 := newClaimTX() + for i := 0; i < 5; i++ { + txm1.Inputs = append(txm1.Inputs, transaction.Input{PrevHash: hash1, PrevIndex: uint16(100 - i)}) + claim1.Claims = append(claim1.Claims, transaction.Input{PrevHash: hash1, PrevIndex: uint16(100 - i)}) + } + require.NoError(t, mp.Add(txm1, &FeerStub{})) + require.NoError(t, mp.Add(txc1, &FeerStub{})) + // Look inside. + assert.Equal(t, len(txm1.Inputs), len(mp.inputs)) + assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs)) + assert.Equal(t, len(claim1.Claims), len(mp.claims)) + assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims)) + + txm2 := newMinerTX(1) + txc2, claim2 := newClaimTX() + for i := 0; i < 10; i++ { + txm2.Inputs = append(txm2.Inputs, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i)}) + claim2.Claims = append(claim2.Claims, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i)}) + } + require.NoError(t, mp.Add(txm2, &FeerStub{})) + require.NoError(t, mp.Add(txc2, &FeerStub{})) + assert.Equal(t, len(txm1.Inputs)+len(txm2.Inputs), len(mp.inputs)) + assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs)) + assert.Equal(t, len(claim1.Claims)+len(claim2.Claims), len(mp.claims)) + assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims)) + + mp.Remove(txm1.Hash()) + mp.Remove(txc2.Hash()) + assert.Equal(t, len(txm2.Inputs), len(mp.inputs)) + assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs)) + assert.Equal(t, len(claim1.Claims), len(mp.claims)) + assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims)) + + require.NoError(t, mp.Add(txm1, &FeerStub{})) + require.NoError(t, mp.Add(txc2, &FeerStub{})) + assert.Equal(t, len(txm1.Inputs)+len(txm2.Inputs), len(mp.inputs)) + assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs)) + assert.Equal(t, len(claim1.Claims)+len(claim2.Claims), len(mp.claims)) + assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims)) + + mp.RemoveStale(func(t *transaction.Transaction) bool { + if t.Hash() == txc1.Hash() || t.Hash() == txm2.Hash() { + return false + } + return true + }) + assert.Equal(t, len(txm1.Inputs), len(mp.inputs)) + assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs)) + assert.Equal(t, len(claim2.Claims), len(mp.claims)) + assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims)) +} + +func TestMemPoolVerifyInputs(t *testing.T) { mp := NewMemPool(10) tx := newMinerTX(1) inhash1 := random.Uint256() @@ -84,6 +150,33 @@ func TestMemPoolVerify(t *testing.T) { require.Error(t, mp.Add(tx3, &FeerStub{})) } +func TestMemPoolVerifyClaims(t *testing.T) { + mp := NewMemPool(50) + tx1, claim1 := newClaimTX() + hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25") + require.NoError(t, err) + hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9") + require.NoError(t, err) + for i := 0; i < 10; i++ { + claim1.Claims = append(claim1.Claims, transaction.Input{PrevHash: hash1, PrevIndex: uint16(i)}) + claim1.Claims = append(claim1.Claims, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i)}) + } + require.Equal(t, true, mp.Verify(tx1)) + require.NoError(t, mp.Add(tx1, &FeerStub{})) + + tx2, claim2 := newClaimTX() + for i := 0; i < 10; i++ { + claim2.Claims = append(claim2.Claims, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i + 10)}) + } + require.Equal(t, true, mp.Verify(tx2)) + require.NoError(t, mp.Add(tx2, &FeerStub{})) + + tx3, claim3 := newClaimTX() + claim3.Claims = append(claim3.Claims, transaction.Input{PrevHash: hash1, PrevIndex: 0}) + require.Equal(t, false, mp.Verify(tx3)) + require.Error(t, mp.Add(tx3, &FeerStub{})) +} + func newMinerTX(i uint32) *transaction.Transaction { return &transaction.Transaction{ Type: transaction.MinerType, @@ -91,6 +184,14 @@ func newMinerTX(i uint32) *transaction.Transaction { } } +func newClaimTX() (*transaction.Transaction, *transaction.ClaimTX) { + cl := &transaction.ClaimTX{} + return &transaction.Transaction{ + Type: transaction.ClaimType, + Data: cl, + }, cl +} + func TestOverCapacity(t *testing.T) { var fs = &FeerStub{lowPriority: true} const mempoolSize = 10 diff --git a/pkg/core/transaction/input.go b/pkg/core/transaction/input.go index 3b7758ba3..c5b150c25 100644 --- a/pkg/core/transaction/input.go +++ b/pkg/core/transaction/input.go @@ -28,6 +28,16 @@ func (in *Input) EncodeBinary(bw *io.BinWriter) { bw.WriteU16LE(in.PrevIndex) } +// Cmp compares two Inputs by their hash and index allowing to make a set of +// transactions ordered. +func (in *Input) Cmp(other *Input) int { + hashcmp := in.PrevHash.CompareTo(other.PrevHash) + if hashcmp == 0 { + return int(in.PrevIndex) - int(other.PrevIndex) + } + return hashcmp +} + // MapInputsToSorted maps given slice of inputs into a new slice of pointers // to inputs sorted by their PrevHash and PrevIndex. func MapInputsToSorted(ins []Input) []*Input { @@ -36,11 +46,7 @@ func MapInputsToSorted(ins []Input) []*Input { ptrs[i] = &ins[i] } sort.Slice(ptrs, func(i, j int) bool { - hashcmp := ptrs[i].PrevHash.CompareTo(ptrs[j].PrevHash) - if hashcmp == 0 { - return ptrs[i].PrevIndex < ptrs[j].PrevIndex - } - return hashcmp < 0 + return ptrs[i].Cmp(ptrs[j]) < 0 }) return ptrs }