forked from TrueCloudLab/neoneo-go
mempool: rework inputs verification, check Claim txes
Use more efficient check for Input and also check Claims to avoid double claiming.
This commit is contained in:
parent
f329de73e8
commit
3e2b490025
3 changed files with 197 additions and 16 deletions
|
@ -46,6 +46,8 @@ type Pool struct {
|
|||
lock sync.RWMutex
|
||||
verifiedMap map[util.Uint256]*item
|
||||
verifiedTxes items
|
||||
inputs []*transaction.Input
|
||||
claims []*transaction.Input
|
||||
|
||||
capacity int
|
||||
}
|
||||
|
@ -127,6 +129,35 @@ func (mp *Pool) containsKey(hash util.Uint256) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// findIndexForInput finds an index in a sorted Input pointers slice that is
|
||||
// appropriate to place this input into (or which contains an identical Input).
|
||||
func findIndexForInput(slice []*transaction.Input, input *transaction.Input) int {
|
||||
return sort.Search(len(slice), func(n int) bool {
|
||||
return input.Cmp(slice[n]) <= 0
|
||||
})
|
||||
}
|
||||
|
||||
// pushInputToSortedSlice pushes new Input into the given slice.
|
||||
func pushInputToSortedSlice(slice *[]*transaction.Input, input *transaction.Input) {
|
||||
n := findIndexForInput(*slice, input)
|
||||
*slice = append(*slice, input)
|
||||
if n != len(*slice)-1 {
|
||||
copy((*slice)[n+1:], (*slice)[n:])
|
||||
(*slice)[n] = input
|
||||
}
|
||||
}
|
||||
|
||||
// dropInputFromSortedSlice removes given input from the given slice.
|
||||
func dropInputFromSortedSlice(slice *[]*transaction.Input, input *transaction.Input) {
|
||||
n := findIndexForInput(*slice, input)
|
||||
if n == len(*slice) || *input != *(*slice)[n] {
|
||||
// Not present.
|
||||
return
|
||||
}
|
||||
copy((*slice)[n:], (*slice)[n+1:])
|
||||
*slice = (*slice)[:len(*slice)-1]
|
||||
}
|
||||
|
||||
// Add tries to add given transaction to the Pool.
|
||||
func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error {
|
||||
var pItem = &item{
|
||||
|
@ -175,6 +206,19 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error {
|
|||
copy(mp.verifiedTxes[n+1:], mp.verifiedTxes[n:])
|
||||
mp.verifiedTxes[n] = pItem
|
||||
}
|
||||
|
||||
// For lots of inputs it might be easier to push them all and sort
|
||||
// afterwards, but that requires benchmarking.
|
||||
for i := range t.Inputs {
|
||||
pushInputToSortedSlice(&mp.inputs, &t.Inputs[i])
|
||||
}
|
||||
if t.Type == transaction.ClaimType {
|
||||
claim := t.Data.(*transaction.ClaimTX)
|
||||
for i := range claim.Claims {
|
||||
pushInputToSortedSlice(&mp.claims, &claim.Claims[i])
|
||||
}
|
||||
}
|
||||
|
||||
updateMempoolMetrics(len(mp.verifiedTxes))
|
||||
mp.lock.Unlock()
|
||||
return nil
|
||||
|
@ -184,7 +228,7 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error {
|
|||
// nothing if it doesn't).
|
||||
func (mp *Pool) Remove(hash util.Uint256) {
|
||||
mp.lock.Lock()
|
||||
if _, ok := mp.verifiedMap[hash]; ok {
|
||||
if it, ok := mp.verifiedMap[hash]; ok {
|
||||
var num int
|
||||
delete(mp.verifiedMap, hash)
|
||||
for num = range mp.verifiedTxes {
|
||||
|
@ -197,6 +241,15 @@ func (mp *Pool) Remove(hash util.Uint256) {
|
|||
} else if num == len(mp.verifiedTxes)-1 {
|
||||
mp.verifiedTxes = mp.verifiedTxes[:num]
|
||||
}
|
||||
for i := range it.txn.Inputs {
|
||||
dropInputFromSortedSlice(&mp.inputs, &it.txn.Inputs[i])
|
||||
}
|
||||
if it.txn.Type == transaction.ClaimType {
|
||||
claim := it.txn.Data.(*transaction.ClaimTX)
|
||||
for i := range claim.Claims {
|
||||
dropInputFromSortedSlice(&mp.claims, &claim.Claims[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
updateMempoolMetrics(len(mp.verifiedTxes))
|
||||
mp.lock.Unlock()
|
||||
|
@ -210,14 +263,33 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool) {
|
|||
// We expect a lot of changes, so it's easier to allocate a new slice
|
||||
// rather than move things in an old one.
|
||||
newVerifiedTxes := make([]*item, 0, mp.capacity)
|
||||
newInputs := mp.inputs[:0]
|
||||
newClaims := mp.claims[:0]
|
||||
for _, itm := range mp.verifiedTxes {
|
||||
if isOK(itm.txn) {
|
||||
newVerifiedTxes = append(newVerifiedTxes, itm)
|
||||
for i := range itm.txn.Inputs {
|
||||
newInputs = append(newInputs, &itm.txn.Inputs[i])
|
||||
}
|
||||
if itm.txn.Type == transaction.ClaimType {
|
||||
claim := itm.txn.Data.(*transaction.ClaimTX)
|
||||
for i := range claim.Claims {
|
||||
newClaims = append(newClaims, &claim.Claims[i])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
delete(mp.verifiedMap, itm.txn.Hash())
|
||||
}
|
||||
}
|
||||
sort.Slice(newInputs, func(i, j int) bool {
|
||||
return newInputs[i].Cmp(newInputs[j]) < 0
|
||||
})
|
||||
sort.Slice(newClaims, func(i, j int) bool {
|
||||
return newClaims[i].Cmp(newClaims[j]) < 0
|
||||
})
|
||||
mp.verifiedTxes = newVerifiedTxes
|
||||
mp.inputs = newInputs
|
||||
mp.claims = newClaims
|
||||
mp.lock.Unlock()
|
||||
}
|
||||
|
||||
|
@ -259,16 +331,18 @@ func (mp *Pool) GetVerifiedTransactions() []TxWithFee {
|
|||
|
||||
// verifyInputs is an internal unprotected version of Verify.
|
||||
func (mp *Pool) verifyInputs(tx *transaction.Transaction) bool {
|
||||
if len(tx.Inputs) == 0 {
|
||||
return true
|
||||
for i := range tx.Inputs {
|
||||
n := findIndexForInput(mp.inputs, &tx.Inputs[i])
|
||||
if n < len(mp.inputs) && *mp.inputs[n] == tx.Inputs[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for num := range mp.verifiedTxes {
|
||||
txn := mp.verifiedTxes[num].txn
|
||||
for i := range txn.Inputs {
|
||||
for j := 0; j < len(tx.Inputs); j++ {
|
||||
if txn.Inputs[i] == tx.Inputs[j] {
|
||||
return false
|
||||
}
|
||||
if tx.Type == transaction.ClaimType {
|
||||
claim := tx.Data.(*transaction.ClaimTX)
|
||||
for i := range claim.Claims {
|
||||
n := findIndexForInput(mp.claims, &claim.Claims[i])
|
||||
if n < len(mp.claims) && *mp.claims[n] == claim.Claims[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,7 +60,73 @@ func TestMemPoolAddRemove(t *testing.T) {
|
|||
t.Run("high priority", func(t *testing.T) { testMemPoolAddRemoveWithFeer(t, fs) })
|
||||
}
|
||||
|
||||
func TestMemPoolVerify(t *testing.T) {
|
||||
func TestMemPoolAddRemoveWithInputsAndClaims(t *testing.T) {
|
||||
mp := NewMemPool(50)
|
||||
hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25")
|
||||
require.NoError(t, err)
|
||||
hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9")
|
||||
require.NoError(t, err)
|
||||
mpLessInputs := func(i, j int) bool {
|
||||
return mp.inputs[i].Cmp(mp.inputs[j]) < 0
|
||||
}
|
||||
mpLessClaims := func(i, j int) bool {
|
||||
return mp.claims[i].Cmp(mp.claims[j]) < 0
|
||||
}
|
||||
|
||||
txm1 := newMinerTX(1)
|
||||
txc1, claim1 := newClaimTX()
|
||||
for i := 0; i < 5; i++ {
|
||||
txm1.Inputs = append(txm1.Inputs, transaction.Input{PrevHash: hash1, PrevIndex: uint16(100 - i)})
|
||||
claim1.Claims = append(claim1.Claims, transaction.Input{PrevHash: hash1, PrevIndex: uint16(100 - i)})
|
||||
}
|
||||
require.NoError(t, mp.Add(txm1, &FeerStub{}))
|
||||
require.NoError(t, mp.Add(txc1, &FeerStub{}))
|
||||
// Look inside.
|
||||
assert.Equal(t, len(txm1.Inputs), len(mp.inputs))
|
||||
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
|
||||
assert.Equal(t, len(claim1.Claims), len(mp.claims))
|
||||
assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims))
|
||||
|
||||
txm2 := newMinerTX(1)
|
||||
txc2, claim2 := newClaimTX()
|
||||
for i := 0; i < 10; i++ {
|
||||
txm2.Inputs = append(txm2.Inputs, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i)})
|
||||
claim2.Claims = append(claim2.Claims, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i)})
|
||||
}
|
||||
require.NoError(t, mp.Add(txm2, &FeerStub{}))
|
||||
require.NoError(t, mp.Add(txc2, &FeerStub{}))
|
||||
assert.Equal(t, len(txm1.Inputs)+len(txm2.Inputs), len(mp.inputs))
|
||||
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
|
||||
assert.Equal(t, len(claim1.Claims)+len(claim2.Claims), len(mp.claims))
|
||||
assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims))
|
||||
|
||||
mp.Remove(txm1.Hash())
|
||||
mp.Remove(txc2.Hash())
|
||||
assert.Equal(t, len(txm2.Inputs), len(mp.inputs))
|
||||
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
|
||||
assert.Equal(t, len(claim1.Claims), len(mp.claims))
|
||||
assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims))
|
||||
|
||||
require.NoError(t, mp.Add(txm1, &FeerStub{}))
|
||||
require.NoError(t, mp.Add(txc2, &FeerStub{}))
|
||||
assert.Equal(t, len(txm1.Inputs)+len(txm2.Inputs), len(mp.inputs))
|
||||
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
|
||||
assert.Equal(t, len(claim1.Claims)+len(claim2.Claims), len(mp.claims))
|
||||
assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims))
|
||||
|
||||
mp.RemoveStale(func(t *transaction.Transaction) bool {
|
||||
if t.Hash() == txc1.Hash() || t.Hash() == txm2.Hash() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, len(txm1.Inputs), len(mp.inputs))
|
||||
assert.True(t, sort.SliceIsSorted(mp.inputs, mpLessInputs))
|
||||
assert.Equal(t, len(claim2.Claims), len(mp.claims))
|
||||
assert.True(t, sort.SliceIsSorted(mp.claims, mpLessClaims))
|
||||
}
|
||||
|
||||
func TestMemPoolVerifyInputs(t *testing.T) {
|
||||
mp := NewMemPool(10)
|
||||
tx := newMinerTX(1)
|
||||
inhash1 := random.Uint256()
|
||||
|
@ -84,6 +150,33 @@ func TestMemPoolVerify(t *testing.T) {
|
|||
require.Error(t, mp.Add(tx3, &FeerStub{}))
|
||||
}
|
||||
|
||||
func TestMemPoolVerifyClaims(t *testing.T) {
|
||||
mp := NewMemPool(50)
|
||||
tx1, claim1 := newClaimTX()
|
||||
hash1, err := util.Uint256DecodeStringBE("a83ba6ede918a501558d3170a124324aedc89909e64c4ff2c6f863094f980b25")
|
||||
require.NoError(t, err)
|
||||
hash2, err := util.Uint256DecodeStringBE("629397158f852e838077bb2715b13a2e29b0a51c2157e5466321b70ed7904ce9")
|
||||
require.NoError(t, err)
|
||||
for i := 0; i < 10; i++ {
|
||||
claim1.Claims = append(claim1.Claims, transaction.Input{PrevHash: hash1, PrevIndex: uint16(i)})
|
||||
claim1.Claims = append(claim1.Claims, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i)})
|
||||
}
|
||||
require.Equal(t, true, mp.Verify(tx1))
|
||||
require.NoError(t, mp.Add(tx1, &FeerStub{}))
|
||||
|
||||
tx2, claim2 := newClaimTX()
|
||||
for i := 0; i < 10; i++ {
|
||||
claim2.Claims = append(claim2.Claims, transaction.Input{PrevHash: hash2, PrevIndex: uint16(i + 10)})
|
||||
}
|
||||
require.Equal(t, true, mp.Verify(tx2))
|
||||
require.NoError(t, mp.Add(tx2, &FeerStub{}))
|
||||
|
||||
tx3, claim3 := newClaimTX()
|
||||
claim3.Claims = append(claim3.Claims, transaction.Input{PrevHash: hash1, PrevIndex: 0})
|
||||
require.Equal(t, false, mp.Verify(tx3))
|
||||
require.Error(t, mp.Add(tx3, &FeerStub{}))
|
||||
}
|
||||
|
||||
func newMinerTX(i uint32) *transaction.Transaction {
|
||||
return &transaction.Transaction{
|
||||
Type: transaction.MinerType,
|
||||
|
@ -91,6 +184,14 @@ func newMinerTX(i uint32) *transaction.Transaction {
|
|||
}
|
||||
}
|
||||
|
||||
func newClaimTX() (*transaction.Transaction, *transaction.ClaimTX) {
|
||||
cl := &transaction.ClaimTX{}
|
||||
return &transaction.Transaction{
|
||||
Type: transaction.ClaimType,
|
||||
Data: cl,
|
||||
}, cl
|
||||
}
|
||||
|
||||
func TestOverCapacity(t *testing.T) {
|
||||
var fs = &FeerStub{lowPriority: true}
|
||||
const mempoolSize = 10
|
||||
|
|
|
@ -28,6 +28,16 @@ func (in *Input) EncodeBinary(bw *io.BinWriter) {
|
|||
bw.WriteU16LE(in.PrevIndex)
|
||||
}
|
||||
|
||||
// Cmp compares two Inputs by their hash and index allowing to make a set of
|
||||
// transactions ordered.
|
||||
func (in *Input) Cmp(other *Input) int {
|
||||
hashcmp := in.PrevHash.CompareTo(other.PrevHash)
|
||||
if hashcmp == 0 {
|
||||
return int(in.PrevIndex) - int(other.PrevIndex)
|
||||
}
|
||||
return hashcmp
|
||||
}
|
||||
|
||||
// MapInputsToSorted maps given slice of inputs into a new slice of pointers
|
||||
// to inputs sorted by their PrevHash and PrevIndex.
|
||||
func MapInputsToSorted(ins []Input) []*Input {
|
||||
|
@ -36,11 +46,7 @@ func MapInputsToSorted(ins []Input) []*Input {
|
|||
ptrs[i] = &ins[i]
|
||||
}
|
||||
sort.Slice(ptrs, func(i, j int) bool {
|
||||
hashcmp := ptrs[i].PrevHash.CompareTo(ptrs[j].PrevHash)
|
||||
if hashcmp == 0 {
|
||||
return ptrs[i].PrevIndex < ptrs[j].PrevIndex
|
||||
}
|
||||
return hashcmp < 0
|
||||
return ptrs[i].Cmp(ptrs[j]) < 0
|
||||
})
|
||||
return ptrs
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue