Merge pull request #697 from nspcc-dev/claim-verification-in-mempool

Claim TX verification for mempool
This commit is contained in:
Roman Khimov 2020-02-27 15:33:53 +03:00 committed by GitHub
commit 26c4e83ddf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 198 additions and 17 deletions

View file

@ -46,6 +46,8 @@ type Pool struct {
lock sync.RWMutex lock sync.RWMutex
verifiedMap map[util.Uint256]*item verifiedMap map[util.Uint256]*item
verifiedTxes items verifiedTxes items
inputs []*transaction.Input
claims []*transaction.Input
capacity int capacity int
} }
@ -127,6 +129,35 @@ func (mp *Pool) containsKey(hash util.Uint256) bool {
return false return false
} }
// findIndexForInput finds an index in a sorted Input pointers slice that is
// appropriate to place this input into (or which contains an identical Input).
func findIndexForInput(slice []*transaction.Input, input *transaction.Input) int {
return sort.Search(len(slice), func(n int) bool {
return input.Cmp(slice[n]) <= 0
})
}
// pushInputToSortedSlice pushes new Input into the given slice.
func pushInputToSortedSlice(slice *[]*transaction.Input, input *transaction.Input) {
n := findIndexForInput(*slice, input)
*slice = append(*slice, input)
if n != len(*slice)-1 {
copy((*slice)[n+1:], (*slice)[n:])
(*slice)[n] = input
}
}
// dropInputFromSortedSlice removes given input from the given slice.
func dropInputFromSortedSlice(slice *[]*transaction.Input, input *transaction.Input) {
n := findIndexForInput(*slice, input)
if n == len(*slice) || *input != *(*slice)[n] {
// Not present.
return
}
copy((*slice)[n:], (*slice)[n+1:])
*slice = (*slice)[:len(*slice)-1]
}
// Add tries to add given transaction to the Pool. // Add tries to add given transaction to the Pool.
func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error {
var pItem = &item{ var pItem = &item{
@ -175,6 +206,19 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error {
copy(mp.verifiedTxes[n+1:], mp.verifiedTxes[n:]) copy(mp.verifiedTxes[n+1:], mp.verifiedTxes[n:])
mp.verifiedTxes[n] = pItem mp.verifiedTxes[n] = pItem
} }
// For lots of inputs it might be easier to push them all and sort
// afterwards, but that requires benchmarking.
for i := range t.Inputs {
pushInputToSortedSlice(&mp.inputs, &t.Inputs[i])
}
if t.Type == transaction.ClaimType {
claim := t.Data.(*transaction.ClaimTX)
for i := range claim.Claims {
pushInputToSortedSlice(&mp.claims, &claim.Claims[i])
}
}
updateMempoolMetrics(len(mp.verifiedTxes)) updateMempoolMetrics(len(mp.verifiedTxes))
mp.lock.Unlock() mp.lock.Unlock()
return nil return nil
@ -184,10 +228,10 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error {
// nothing if it doesn't). // nothing if it doesn't).
func (mp *Pool) Remove(hash util.Uint256) { func (mp *Pool) Remove(hash util.Uint256) {
mp.lock.Lock() mp.lock.Lock()
if _, ok := mp.verifiedMap[hash]; ok { if it, ok := mp.verifiedMap[hash]; ok {
var num int var num int
delete(mp.verifiedMap, hash) delete(mp.verifiedMap, hash)
for num := range mp.verifiedTxes { for num = range mp.verifiedTxes {
if hash.Equals(mp.verifiedTxes[num].txn.Hash()) { if hash.Equals(mp.verifiedTxes[num].txn.Hash()) {
break break
} }
@ -197,6 +241,15 @@ func (mp *Pool) Remove(hash util.Uint256) {
} else if num == len(mp.verifiedTxes)-1 { } else if num == len(mp.verifiedTxes)-1 {
mp.verifiedTxes = mp.verifiedTxes[:num] mp.verifiedTxes = mp.verifiedTxes[:num]
} }
for i := range it.txn.Inputs {
dropInputFromSortedSlice(&mp.inputs, &it.txn.Inputs[i])
}
if it.txn.Type == transaction.ClaimType {
claim := it.txn.Data.(*transaction.ClaimTX)
for i := range claim.Claims {
dropInputFromSortedSlice(&mp.claims, &claim.Claims[i])
}
}
} }
updateMempoolMetrics(len(mp.verifiedTxes)) updateMempoolMetrics(len(mp.verifiedTxes))
mp.lock.Unlock() mp.lock.Unlock()
@ -210,14 +263,33 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool) {
// We expect a lot of changes, so it's easier to allocate a new slice // We expect a lot of changes, so it's easier to allocate a new slice
// rather than move things in an old one. // rather than move things in an old one.
newVerifiedTxes := make([]*item, 0, mp.capacity) newVerifiedTxes := make([]*item, 0, mp.capacity)
newInputs := mp.inputs[:0]
newClaims := mp.claims[:0]
for _, itm := range mp.verifiedTxes { for _, itm := range mp.verifiedTxes {
if isOK(itm.txn) { if isOK(itm.txn) {
newVerifiedTxes = append(newVerifiedTxes, itm) newVerifiedTxes = append(newVerifiedTxes, itm)
for i := range itm.txn.Inputs {
newInputs = append(newInputs, &itm.txn.Inputs[i])
}
if itm.txn.Type == transaction.ClaimType {
claim := itm.txn.Data.(*transaction.ClaimTX)
for i := range claim.Claims {
newClaims = append(newClaims, &claim.Claims[i])
}
}
} else { } else {
delete(mp.verifiedMap, itm.txn.Hash()) delete(mp.verifiedMap, itm.txn.Hash())
} }
} }
sort.Slice(newInputs, func(i, j int) bool {
return newInputs[i].Cmp(newInputs[j]) < 0
})
sort.Slice(newClaims, func(i, j int) bool {
return newClaims[i].Cmp(newClaims[j]) < 0
})
mp.verifiedTxes = newVerifiedTxes mp.verifiedTxes = newVerifiedTxes
mp.inputs = newInputs
mp.claims = newClaims
mp.lock.Unlock() mp.lock.Unlock()
} }
@ -259,16 +331,18 @@ func (mp *Pool) GetVerifiedTransactions() []TxWithFee {
// verifyInputs is an internal unprotected version of Verify. // verifyInputs is an internal unprotected version of Verify.
func (mp *Pool) verifyInputs(tx *transaction.Transaction) bool { func (mp *Pool) verifyInputs(tx *transaction.Transaction) bool {
if len(tx.Inputs) == 0 { for i := range tx.Inputs {
return true n := findIndexForInput(mp.inputs, &tx.Inputs[i])
if n < len(mp.inputs) && *mp.inputs[n] == tx.Inputs[i] {
return false
}
} }
for num := range mp.verifiedTxes { if tx.Type == transaction.ClaimType {
txn := mp.verifiedTxes[num].txn claim := tx.Data.(*transaction.ClaimTX)
for i := range txn.Inputs { for i := range claim.Claims {
for j := 0; j < len(tx.Inputs); j++ { n := findIndexForInput(mp.claims, &claim.Claims[i])
if txn.Inputs[i] == tx.Inputs[j] { if n < len(mp.claims) && *mp.claims[n] == claim.Claims[i] {
return false return false
}
} }
} }
} }

View file

@ -60,7 +60,73 @@ func TestMemPoolAddRemove(t *testing.T) {
t.Run("high priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) }) t.Run("high priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) })
} }
func TestMemPoolVerify(t *testing.T) { func TestMemPoolAddRemoveWithInputsAndClaims(t *testing.T) {
mp := NewMemPool(50)
hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25")
require.NoError(t, err)
hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9")
require.NoError(t, err)
mpLessInputs := func(i, j int) bool {
return mp.inputs[i].Cmp(mp.inputs[j]) < 0
}
mpLessClaims := func(i, j int) bool {
return mp.claims[i].Cmp(mp.claims[j]) < 0
}
txm1 := newMinerTX(1)
txc1, claim1 := newClaimTX()
for i := 0; i < 5; i++ {
txm1.Inputs = append(txm1.Inputs, transaction.Input{PrevHash: hash1, PrevIndex: uint16(100 - i)})
claim1.Claims = append(claim1.Claims, transaction.Input{PrevHash: hash1, PrevIndex: uint16(100 - i)})
}
require.NoError(t, mp.Add(txm1, &FeerStub{}))
require.NoError(t, mp.Add(txc1, &FeerStub{}))
// Look inside.
assert.Equal(t, len(txm1.Inputs), len(mp.inputs))
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
assert.Equal(t, len(claim1.Claims), len(mp.claims))
assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims))
txm2 := newMinerTX(1)
txc2, claim2 := newClaimTX()
for i := 0; i < 10; i++ {
txm2.Inputs = append(txm2.Inputs, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i)})
claim2.Claims = append(claim2.Claims, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i)})
}
require.NoError(t, mp.Add(txm2, &FeerStub{}))
require.NoError(t, mp.Add(txc2, &FeerStub{}))
assert.Equal(t, len(txm1.Inputs)+len(txm2.Inputs), len(mp.inputs))
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
assert.Equal(t, len(claim1.Claims)+len(claim2.Claims), len(mp.claims))
assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims))
mp.Remove(txm1.Hash())
mp.Remove(txc2.Hash())
assert.Equal(t, len(txm2.Inputs), len(mp.inputs))
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
assert.Equal(t, len(claim1.Claims), len(mp.claims))
assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims))
require.NoError(t, mp.Add(txm1, &FeerStub{}))
require.NoError(t, mp.Add(txc2, &FeerStub{}))
assert.Equal(t, len(txm1.Inputs)+len(txm2.Inputs), len(mp.inputs))
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
assert.Equal(t, len(claim1.Claims)+len(claim2.Claims), len(mp.claims))
assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims))
mp.RemoveStale(func(t *transaction.Transaction) bool {
if t.Hash() == txc1.Hash() || t.Hash() == txm2.Hash() {
return false
}
return true
})
assert.Equal(t, len(txm1.Inputs), len(mp.inputs))
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
assert.Equal(t, len(claim2.Claims), len(mp.claims))
assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims))
}
func TestMemPoolVerifyInputs(t *testing.T) {
mp := NewMemPool(10) mp := NewMemPool(10)
tx := newMinerTX(1) tx := newMinerTX(1)
inhash1 := random.Uint256() inhash1 := random.Uint256()
@ -84,6 +150,33 @@ func TestMemPoolVerify(t *testing.T) {
require.Error(t, mp.Add(tx3, &FeerStub{})) require.Error(t, mp.Add(tx3, &FeerStub{}))
} }
func TestMemPoolVerifyClaims(t *testing.T) {
mp := NewMemPool(50)
tx1, claim1 := newClaimTX()
hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25")
require.NoError(t, err)
hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9")
require.NoError(t, err)
for i := 0; i < 10; i++ {
claim1.Claims = append(claim1.Claims, transaction.Input{PrevHash: hash1, PrevIndex: uint16(i)})
claim1.Claims = append(claim1.Claims, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i)})
}
require.Equal(t, true, mp.Verify(tx1))
require.NoError(t, mp.Add(tx1, &FeerStub{}))
tx2, claim2 := newClaimTX()
for i := 0; i < 10; i++ {
claim2.Claims = append(claim2.Claims, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i + 10)})
}
require.Equal(t, true, mp.Verify(tx2))
require.NoError(t, mp.Add(tx2, &FeerStub{}))
tx3, claim3 := newClaimTX()
claim3.Claims = append(claim3.Claims, transaction.Input{PrevHash: hash1, PrevIndex: 0})
require.Equal(t, false, mp.Verify(tx3))
require.Error(t, mp.Add(tx3, &FeerStub{}))
}
func newMinerTX(i uint32) *transaction.Transaction { func newMinerTX(i uint32) *transaction.Transaction {
return &transaction.Transaction{ return &transaction.Transaction{
Type: transaction.MinerType, Type: transaction.MinerType,
@ -91,6 +184,14 @@ func newMinerTX(i uint32) *transaction.Transaction {
} }
} }
func newClaimTX() (*transaction.Transaction, *transaction.ClaimTX) {
cl := &transaction.ClaimTX{}
return &transaction.Transaction{
Type: transaction.ClaimType,
Data: cl,
}, cl
}
func TestOverCapacity(t *testing.T) { func TestOverCapacity(t *testing.T) {
var fs = &FeerStub{lowPriority: true} var fs = &FeerStub{lowPriority: true}
const mempoolSize = 10 const mempoolSize = 10

View file

@ -28,6 +28,16 @@ func (in *Input) EncodeBinary(bw *io.BinWriter) {
bw.WriteU16LE(in.PrevIndex) bw.WriteU16LE(in.PrevIndex)
} }
// Cmp compares two Inputs by their hash and index allowing to make a set of
// transactions ordered.
func (in *Input) Cmp(other *Input) int {
hashcmp := in.PrevHash.CompareTo(other.PrevHash)
if hashcmp == 0 {
return int(in.PrevIndex) - int(other.PrevIndex)
}
return hashcmp
}
// MapInputsToSorted maps given slice of inputs into a new slice of pointers // MapInputsToSorted maps given slice of inputs into a new slice of pointers
// to inputs sorted by their PrevHash and PrevIndex. // to inputs sorted by their PrevHash and PrevIndex.
func MapInputsToSorted(ins []Input) []*Input { func MapInputsToSorted(ins []Input) []*Input {
@ -36,11 +46,7 @@ func MapInputsToSorted(ins []Input) []*Input {
ptrs[i] = &ins[i] ptrs[i] = &ins[i]
} }
sort.Slice(ptrs, func(i, j int) bool { sort.Slice(ptrs, func(i, j int) bool {
hashcmp := ptrs[i].PrevHash.CompareTo(ptrs[j].PrevHash) return ptrs[i].Cmp(ptrs[j]) < 0
if hashcmp == 0 {
return ptrs[i].PrevIndex < ptrs[j].PrevIndex
}
return hashcmp < 0
}) })
return ptrs return ptrs
} }