Merge pull request #1507 from nspcc-dev/conflicts_attr

core: implement Conflicts transaction attribute
This commit is contained in:
Roman Khimov 2020-10-29 16:54:58 +03:00 committed by GitHub
commit d4da811d12
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 521 additions and 37 deletions

View file

@ -58,6 +58,10 @@ var (
// ErrInvalidBlockIndex is returned when trying to add block with index // ErrInvalidBlockIndex is returned when trying to add block with index
// other than expected height of the blockchain. // other than expected height of the blockchain.
ErrInvalidBlockIndex error = errors.New("invalid block index") ErrInvalidBlockIndex error = errors.New("invalid block index")
// ErrHasConflicts is returned when trying to add some transaction which
// conflicts with other transaction in the chain or pool according to
// Conflicts attribute.
ErrHasConflicts = errors.New("has conflicts")
) )
var ( var (
persistInterval = 1 * time.Second persistInterval = 1 * time.Second
@ -617,6 +621,18 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error
return fmt.Errorf("failed to store tx exec result: %w", err) return fmt.Errorf("failed to store tx exec result: %w", err)
} }
writeBuf.Reset() writeBuf.Reset()
if bc.config.P2PSigExtensions {
for _, attr := range tx.GetAttributes(transaction.ConflictsT) {
hash := attr.Value.(*transaction.Conflicts).Hash
dummyTx := transaction.NewTrimmedTX(hash)
dummyTx.Version = transaction.DummyVersion
if err = cache.StoreAsTransaction(dummyTx, block.Index, writeBuf); err != nil {
return fmt.Errorf("failed to store conflicting transaction %s for transaction %s: %w", hash.StringLE(), tx.Hash().StringLE(), err)
}
writeBuf.Reset()
}
}
} }
aer, err := bc.runPersist(bc.contracts.GetPostPersistScript(), block, cache) aer, err := bc.runPersist(bc.contracts.GetPostPersistScript(), block, cache)
@ -984,7 +1000,10 @@ func (bc *Blockchain) GetHeader(hash util.Uint256) (*block.Header, error) {
// HasTransaction returns true if the blockchain contains he given // HasTransaction returns true if the blockchain contains he given
// transaction hash. // transaction hash.
func (bc *Blockchain) HasTransaction(hash util.Uint256) bool { func (bc *Blockchain) HasTransaction(hash util.Uint256) bool {
return bc.memPool.ContainsKey(hash) || bc.dao.HasTransaction(hash) if bc.memPool.ContainsKey(hash) {
return true
}
return bc.dao.HasTransaction(hash) == dao.ErrAlreadyExists
} }
// HasBlock returns true if the blockchain contains the given // HasBlock returns true if the blockchain contains the given
@ -1227,8 +1246,16 @@ func (bc *Blockchain) verifyAndPoolTx(t *transaction.Transaction, pool *mempool.
if netFee < 0 { if netFee < 0 {
return fmt.Errorf("%w: net fee is %v, need %v", ErrTxSmallNetworkFee, t.NetworkFee, needNetworkFee) return fmt.Errorf("%w: net fee is %v, need %v", ErrTxSmallNetworkFee, t.NetworkFee, needNetworkFee)
} }
if bc.dao.HasTransaction(t.Hash()) { // check that current tx wasn't included in the conflicts attributes of some other transaction which is already in the chain
if err := bc.dao.HasTransaction(t.Hash()); err != nil {
switch {
case errors.Is(err, dao.ErrAlreadyExists):
return fmt.Errorf("blockchain: %w", ErrAlreadyExists) return fmt.Errorf("blockchain: %w", ErrAlreadyExists)
case errors.Is(err, dao.ErrHasConflicts):
return fmt.Errorf("blockchain: %w", ErrHasConflicts)
default:
return err
}
} }
err := bc.verifyTxWitnesses(t, nil) err := bc.verifyTxWitnesses(t, nil)
if err != nil { if err != nil {
@ -1248,6 +1275,8 @@ func (bc *Blockchain) verifyAndPoolTx(t *transaction.Transaction, pool *mempool.
return ErrInsufficientFunds return ErrInsufficientFunds
case errors.Is(err, mempool.ErrOOM): case errors.Is(err, mempool.ErrOOM):
return ErrOOM return ErrOOM
case errors.Is(err, mempool.ErrConflictsAttribute):
return fmt.Errorf("mempool: %w: %s", ErrHasConflicts, err)
default: default:
return err return err
} }
@ -1303,6 +1332,14 @@ func (bc *Blockchain) verifyTxAttributes(tx *transaction.Transaction) error {
if height := bc.BlockHeight(); height < nvb.Height { if height := bc.BlockHeight(); height < nvb.Height {
return fmt.Errorf("%w: NotValidBefore = %d, current height = %d", ErrTxNotYetValid, nvb.Height, height) return fmt.Errorf("%w: NotValidBefore = %d, current height = %d", ErrTxNotYetValid, nvb.Height, height)
} }
case transaction.ConflictsT:
if !bc.config.P2PSigExtensions {
return errors.New("Conflicts attribute was found, but P2PSigExtensions are disabled")
}
conflicts := tx.Attributes[i].Value.(*transaction.Conflicts)
if err := bc.dao.HasTransaction(conflicts.Hash); errors.Is(err, dao.ErrAlreadyExists) {
return fmt.Errorf("conflicting transaction %s is already on chain", conflicts.Hash.StringLE())
}
default: default:
if !bc.config.ReservedAttributes && attrType >= transaction.ReservedLowerBound && attrType <= transaction.ReservedUpperBound { if !bc.config.ReservedAttributes && attrType >= transaction.ReservedLowerBound && attrType <= transaction.ReservedUpperBound {
return errors.New("attribute of reserved type was found, but ReservedAttributes are disabled") return errors.New("attribute of reserved type was found, but ReservedAttributes are disabled")
@ -1326,10 +1363,10 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction, txpool *memp
return false return false
} }
if txpool == nil { if txpool == nil {
if bc.dao.HasTransaction(t.Hash()) { if bc.dao.HasTransaction(t.Hash()) != nil {
return false return false
} }
} else if txpool.ContainsKey(t.Hash()) { } else if txpool.HasConflicts(t, bc) {
return false return false
} }
if err := bc.verifyTxAttributes(t); err != nil { if err := bc.verifyTxAttributes(t); err != nil {
@ -1650,3 +1687,8 @@ func (bc *Blockchain) newInteropContext(trigger trigger.Type, d dao.DAO, block *
} }
return ic return ic
} }
// P2PSigExtensionsEnabled defines whether P2P signature extensions are enabled.
func (bc *Blockchain) P2PSigExtensionsEnabled() bool {
return bc.config.P2PSigExtensions
}

View file

@ -18,6 +18,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/internal/testchain" "github.com/nspcc-dev/neo-go/pkg/internal/testchain"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
@ -629,6 +630,64 @@ func TestVerifyTx(t *testing.T) {
require.NoError(t, bc.VerifyTx(tx)) require.NoError(t, bc.VerifyTx(tx))
}) })
}) })
t.Run("Conflicts", func(t *testing.T) {
getConflictsTx := func(hashes ...util.Uint256) *transaction.Transaction {
tx := bc.newTestTx(h, testScript)
tx.Attributes = make([]transaction.Attribute, len(hashes))
for i, h := range hashes {
tx.Attributes[i] = transaction.Attribute{
Type: transaction.ConflictsT,
Value: &transaction.Conflicts{
Hash: h,
},
}
}
tx.NetworkFee += 4_000_000 // multisig check
tx.Signers = []transaction.Signer{{
Account: testchain.CommitteeScriptHash(),
Scopes: transaction.None,
}}
rawScript := testchain.CommitteeVerificationScript()
require.NoError(t, err)
size := io.GetVarSize(tx)
netFee, sizeDelta := fee.Calculate(rawScript)
tx.NetworkFee += netFee
tx.NetworkFee += int64(size+sizeDelta) * bc.FeePerByte()
data := tx.GetSignedPart()
tx.Scripts = []transaction.Witness{{
InvocationScript: testchain.SignCommittee(data),
VerificationScript: rawScript,
}}
return tx
}
t.Run("disabled", func(t *testing.T) {
bc.config.P2PSigExtensions = false
tx := getConflictsTx(util.Uint256{1, 2, 3})
require.Error(t, bc.VerifyTx(tx))
})
t.Run("enabled", func(t *testing.T) {
bc.config.P2PSigExtensions = true
t.Run("dummy on-chain conflict", func(t *testing.T) {
tx := bc.newTestTx(h, testScript)
require.NoError(t, accs[0].SignTx(tx))
dummyTx := transaction.NewTrimmedTX(tx.Hash())
dummyTx.Version = transaction.DummyVersion
require.NoError(t, bc.dao.StoreAsTransaction(dummyTx, bc.blockHeight, nil))
require.True(t, errors.Is(bc.VerifyTx(tx), ErrHasConflicts))
})
t.Run("attribute on-chain conflict", func(t *testing.T) {
b, err := bc.GetBlock(bc.GetHeaderHash(0))
require.NoError(t, err)
conflictsHash := b.Transactions[0].Hash()
tx := getConflictsTx(conflictsHash)
require.Error(t, bc.VerifyTx(tx))
})
t.Run("positive", func(t *testing.T) {
tx := getConflictsTx(random.Uint256())
require.NoError(t, bc.VerifyTx(tx))
})
})
})
}) })
} }

View file

@ -3,6 +3,7 @@ package dao
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"sort" "sort"
@ -16,6 +17,15 @@ import (
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
) )
// HasTransaction errors
var (
// ErrAlreadyExists is returned when transaction exists in dao.
ErrAlreadyExists = errors.New("transaction already exists")
// ErrHasConflicts is returned when transaction is in the list of conflicting
// transactions which are already in dao.
ErrHasConflicts = errors.New("transaction has conflicts")
)
// DAO is a data access object. // DAO is a data access object.
type DAO interface { type DAO interface {
AppendNEP5Transfer(acc util.Uint160, index uint32, tr *state.NEP5Transfer) (bool, error) AppendNEP5Transfer(acc util.Uint160, index uint32, tr *state.NEP5Transfer) (bool, error)
@ -42,7 +52,7 @@ type DAO interface {
GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error)
GetVersion() (string, error) GetVersion() (string, error)
GetWrapped() DAO GetWrapped() DAO
HasTransaction(hash util.Uint256) bool HasTransaction(hash util.Uint256) error
Persist() (int, error) Persist() (int, error)
PutAppExecResult(aer *state.AppExecResult, buf *io.BufBinWriter) error PutAppExecResult(aer *state.AppExecResult, buf *io.BufBinWriter) error
PutContractState(cs *state.Contract) error PutContractState(cs *state.Contract) error
@ -515,13 +525,19 @@ func (dao *Simple) GetHeaderHashes() ([]util.Uint256, error) {
} }
// GetTransaction returns Transaction and its height by the given hash // GetTransaction returns Transaction and its height by the given hash
// if it exists in the store. // if it exists in the store. It does not return dummy transactions.
func (dao *Simple) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) { func (dao *Simple) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) {
key := storage.AppendPrefix(storage.DataTransaction, hash.BytesLE()) key := storage.AppendPrefix(storage.DataTransaction, hash.BytesLE())
b, err := dao.Store.Get(key) b, err := dao.Store.Get(key)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
if len(b) < 5 {
return nil, 0, errors.New("bad transaction bytes")
}
if b[4] == transaction.DummyVersion {
return nil, 0, storage.ErrKeyNotFound
}
r := io.NewBinReaderFromBuf(b) r := io.NewBinReaderFromBuf(b)
var height = r.ReadU32LE() var height = r.ReadU32LE()
@ -558,14 +574,23 @@ func read2000Uint256Hashes(b []byte) ([]util.Uint256, error) {
return hashes, nil return hashes, nil
} }
// HasTransaction returns true if the given store contains the given // HasTransaction returns nil if the given store does not contain the given
// Transaction hash. // Transaction hash. It returns an error in case if transaction is in chain
func (dao *Simple) HasTransaction(hash util.Uint256) bool { // or in the list of conflicting transactions.
func (dao *Simple) HasTransaction(hash util.Uint256) error {
key := storage.AppendPrefix(storage.DataTransaction, hash.BytesLE()) key := storage.AppendPrefix(storage.DataTransaction, hash.BytesLE())
if _, err := dao.Store.Get(key); err == nil { bytes, err := dao.Store.Get(key)
return true if err != nil {
return nil
} }
return false
if len(bytes) < 5 {
return nil
}
if bytes[4] == transaction.DummyVersion {
return ErrHasConflicts
}
return ErrAlreadyExists
} }
// StoreAsBlock stores given block as DataBlock. It can reuse given buffer for // StoreAsBlock stores given block as DataBlock. It can reuse given buffer for
@ -609,7 +634,11 @@ func (dao *Simple) StoreAsTransaction(tx *transaction.Transaction, index uint32,
buf = io.NewBufBinWriter() buf = io.NewBufBinWriter()
} }
buf.WriteU32LE(index) buf.WriteU32LE(index)
if tx.Version == transaction.DummyVersion {
buf.BinWriter.WriteB(tx.Version)
} else {
tx.EncodeBinary(buf.BinWriter) tx.EncodeBinary(buf.BinWriter)
}
if buf.Err != nil { if buf.Err != nil {
return buf.Err return buf.Err
} }

View file

@ -189,8 +189,8 @@ func TestStoreAsTransaction(t *testing.T) {
hash := tx.Hash() hash := tx.Hash()
err := dao.StoreAsTransaction(tx, 0, nil) err := dao.StoreAsTransaction(tx, 0, nil)
require.NoError(t, err) require.NoError(t, err)
hasTransaction := dao.HasTransaction(hash) err = dao.HasTransaction(hash)
require.True(t, hasTransaction) require.NotNil(t, err)
} }
func TestMakeStorageItemKey(t *testing.T) { func TestMakeStorageItemKey(t *testing.T) {

View file

@ -11,4 +11,5 @@ type Feer interface {
FeePerByte() int64 FeePerByte() int64
GetUtilityTokenBalance(util.Uint160) *big.Int GetUtilityTokenBalance(util.Uint160) *big.Int
BlockHeight() uint32 BlockHeight() uint32
P2PSigExtensionsEnabled() bool
} }

View file

@ -2,6 +2,7 @@ package mempool
import ( import (
"errors" "errors"
"fmt"
"math/big" "math/big"
"sort" "sort"
"sync" "sync"
@ -25,6 +26,9 @@ var (
// ErrOOM is returned when transaction just doesn't fit in the memory // ErrOOM is returned when transaction just doesn't fit in the memory
// pool because of its capacity constraints. // pool because of its capacity constraints.
ErrOOM = errors.New("out of memory") ErrOOM = errors.New("out of memory")
// ErrConflictsAttribute is returned when transaction conflicts with other transactions
// due to its (or theirs) Conflicts attributes.
ErrConflictsAttribute = errors.New("conflicts with memory pool due to Conflicts attribute")
) )
// item represents a transaction in the the Memory pool. // item represents a transaction in the the Memory pool.
@ -49,6 +53,8 @@ type Pool struct {
verifiedMap map[util.Uint256]*transaction.Transaction verifiedMap map[util.Uint256]*transaction.Transaction
verifiedTxes items verifiedTxes items
fees map[util.Uint160]utilityBalanceAndFees fees map[util.Uint160]utilityBalanceAndFees
// conflicts is a map of hashes of transactions which are conflicting with the mempooled ones.
conflicts map[util.Uint256][]util.Uint256
capacity int capacity int
feePerByte int64 feePerByte int64
@ -108,6 +114,29 @@ func (mp *Pool) containsKey(hash util.Uint256) bool {
return false return false
} }
// HasConflicts returns true if transaction is already in pool or in the Conflicts attributes
// of pooled transactions or has Conflicts attributes for pooled transactions.
func (mp *Pool) HasConflicts(t *transaction.Transaction, fee Feer) bool {
mp.lock.RLock()
defer mp.lock.RUnlock()
if mp.containsKey(t.Hash()) {
return true
}
if fee.P2PSigExtensionsEnabled() {
// do not check sender's signature and fee
if _, ok := mp.conflicts[t.Hash()]; ok {
return true
}
for _, attr := range t.GetAttributes(transaction.ConflictsT) {
if mp.containsKey(attr.Value.(*transaction.Conflicts).Hash) {
return true
}
}
}
return false
}
// tryAddSendersFee tries to add system fee and network fee to the total sender`s fee in mempool // tryAddSendersFee tries to add system fee and network fee to the total sender`s fee in mempool
// and returns false if both balance check is required and sender has not enough GAS to pay // and returns false if both balance check is required and sender has not enough GAS to pay
func (mp *Pool) tryAddSendersFee(tx *transaction.Transaction, feer Feer, needCheck bool) bool { func (mp *Pool) tryAddSendersFee(tx *transaction.Transaction, feer Feer, needCheck bool) bool {
@ -154,13 +183,19 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error {
mp.lock.Unlock() mp.lock.Unlock()
return ErrDup return ErrDup
} }
err := mp.checkTxConflicts(t, fee) conflictsToBeRemoved, err := mp.checkTxConflicts(t, fee)
if err != nil { if err != nil {
mp.lock.Unlock() mp.lock.Unlock()
return err return err
} }
mp.verifiedMap[t.Hash()] = t mp.verifiedMap[t.Hash()] = t
if fee.P2PSigExtensionsEnabled() {
// Remove conflicting transactions.
for _, conflictingTx := range conflictsToBeRemoved {
mp.removeInternal(conflictingTx.Hash(), fee)
}
}
// Insert into sorted array (from max to min, that could also be done // Insert into sorted array (from max to min, that could also be done
// using sort.Sort(sort.Reverse()), but it incurs more overhead. Notice // using sort.Sort(sort.Reverse()), but it incurs more overhead. Notice
// also that we're searching for position that is strictly more // also that we're searching for position that is strictly more
@ -181,6 +216,9 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error {
// Ditch the last one. // Ditch the last one.
unlucky := mp.verifiedTxes[len(mp.verifiedTxes)-1] unlucky := mp.verifiedTxes[len(mp.verifiedTxes)-1]
delete(mp.verifiedMap, unlucky.txn.Hash()) delete(mp.verifiedMap, unlucky.txn.Hash())
if fee.P2PSigExtensionsEnabled() {
mp.removeConflictsOf(unlucky.txn)
}
mp.verifiedTxes[len(mp.verifiedTxes)-1] = pItem mp.verifiedTxes[len(mp.verifiedTxes)-1] = pItem
} else { } else {
mp.verifiedTxes = append(mp.verifiedTxes, pItem) mp.verifiedTxes = append(mp.verifiedTxes, pItem)
@ -189,6 +227,13 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error {
copy(mp.verifiedTxes[n+1:], mp.verifiedTxes[n:]) copy(mp.verifiedTxes[n+1:], mp.verifiedTxes[n:])
mp.verifiedTxes[n] = pItem mp.verifiedTxes[n] = pItem
} }
if fee.P2PSigExtensionsEnabled() {
// Add conflicting hashes to the mp.conflicts list.
for _, attr := range t.GetAttributes(transaction.ConflictsT) {
hash := attr.Value.(*transaction.Conflicts).Hash
mp.conflicts[hash] = append(mp.conflicts[hash], t.Hash())
}
}
// we already checked balance in checkTxConflicts, so don't need to check again // we already checked balance in checkTxConflicts, so don't need to check again
mp.tryAddSendersFee(pItem.txn, fee, false) mp.tryAddSendersFee(pItem.txn, fee, false)
@ -199,8 +244,14 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error {
// Remove removes an item from the mempool, if it exists there (and does // Remove removes an item from the mempool, if it exists there (and does
// nothing if it doesn't). // nothing if it doesn't).
func (mp *Pool) Remove(hash util.Uint256) { func (mp *Pool) Remove(hash util.Uint256, feer Feer) {
mp.lock.Lock() mp.lock.Lock()
mp.removeInternal(hash, feer)
mp.lock.Unlock()
}
// removeInternal is an internal unlocked representation of Remove
func (mp *Pool) removeInternal(hash util.Uint256, feer Feer) {
if tx, ok := mp.verifiedMap[hash]; ok { if tx, ok := mp.verifiedMap[hash]; ok {
var num int var num int
delete(mp.verifiedMap, hash) delete(mp.verifiedMap, hash)
@ -217,9 +268,12 @@ func (mp *Pool) Remove(hash util.Uint256) {
senderFee := mp.fees[tx.Sender()] senderFee := mp.fees[tx.Sender()]
senderFee.feeSum.Sub(senderFee.feeSum, big.NewInt(tx.SystemFee+tx.NetworkFee)) senderFee.feeSum.Sub(senderFee.feeSum, big.NewInt(tx.SystemFee+tx.NetworkFee))
mp.fees[tx.Sender()] = senderFee mp.fees[tx.Sender()] = senderFee
if feer.P2PSigExtensionsEnabled() {
// remove all conflicting hashes from mp.conflicts list
mp.removeConflictsOf(tx)
}
} }
updateMempoolMetrics(len(mp.verifiedTxes)) updateMempoolMetrics(len(mp.verifiedTxes))
mp.lock.Unlock()
} }
// RemoveStale filters verified transactions through the given function keeping // RemoveStale filters verified transactions through the given function keeping
@ -232,9 +286,18 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer)
// because items are iterated one-by-one in increasing order. // because items are iterated one-by-one in increasing order.
newVerifiedTxes := mp.verifiedTxes[:0] newVerifiedTxes := mp.verifiedTxes[:0]
mp.fees = make(map[util.Uint160]utilityBalanceAndFees) // it'd be nice to reuse existing map, but we can't easily clear it mp.fees = make(map[util.Uint160]utilityBalanceAndFees) // it'd be nice to reuse existing map, but we can't easily clear it
if feer.P2PSigExtensionsEnabled() {
mp.conflicts = make(map[util.Uint256][]util.Uint256)
}
for _, itm := range mp.verifiedTxes { for _, itm := range mp.verifiedTxes {
if isOK(itm.txn) && mp.checkPolicy(itm.txn, policyChanged) && mp.tryAddSendersFee(itm.txn, feer, true) { if isOK(itm.txn) && mp.checkPolicy(itm.txn, policyChanged) && mp.tryAddSendersFee(itm.txn, feer, true) {
newVerifiedTxes = append(newVerifiedTxes, itm) newVerifiedTxes = append(newVerifiedTxes, itm)
if feer.P2PSigExtensionsEnabled() {
for _, attr := range itm.txn.GetAttributes(transaction.ConflictsT) {
hash := attr.Value.(*transaction.Conflicts).Hash
mp.conflicts[hash] = append(mp.conflicts[hash], itm.txn.Hash())
}
}
} else { } else {
delete(mp.verifiedMap, itm.txn.Hash()) delete(mp.verifiedMap, itm.txn.Hash())
} }
@ -269,6 +332,7 @@ func New(capacity int) *Pool {
verifiedTxes: make([]item, 0, capacity), verifiedTxes: make([]item, 0, capacity),
capacity: capacity, capacity: capacity,
fees: make(map[util.Uint160]utilityBalanceAndFees), fees: make(map[util.Uint160]utilityBalanceAndFees),
conflicts: make(map[util.Uint256][]util.Uint256),
} }
} }
@ -297,15 +361,59 @@ func (mp *Pool) GetVerifiedTransactions() []*transaction.Transaction {
return t return t
} }
// checkTxConflicts is an internal unprotected version of Verify. // checkTxConflicts is an internal unprotected version of Verify. It takes into
func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) error { // consideration conflicting transactions which are about to be removed from mempool.
senderFee, ok := mp.fees[tx.Sender()] func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) ([]*transaction.Transaction, error) {
actualSenderFee, ok := mp.fees[tx.Sender()]
if !ok { if !ok {
senderFee.balance = fee.GetUtilityTokenBalance(tx.Sender()) actualSenderFee.balance = fee.GetUtilityTokenBalance(tx.Sender())
senderFee.feeSum = big.NewInt(0) actualSenderFee.feeSum = big.NewInt(0)
} }
_, err := checkBalance(tx, senderFee)
return err var expectedSenderFee utilityBalanceAndFees
// Check Conflicts attributes.
var conflictsToBeRemoved []*transaction.Transaction
if fee.P2PSigExtensionsEnabled() {
// Step 1: check if `tx` was in attributes of mempooled transactions.
if conflictingHashes, ok := mp.conflicts[tx.Hash()]; ok {
for _, hash := range conflictingHashes {
existingTx := mp.verifiedMap[hash]
if existingTx.HasSigner(tx.Sender()) && existingTx.NetworkFee > tx.NetworkFee {
return nil, fmt.Errorf("%w: conflicting transaction %s has bigger network fee", ErrConflictsAttribute, existingTx.Hash().StringBE())
}
conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx)
}
}
// Step 2: check if mempooled transactions were in `tx`'s attributes.
for _, attr := range tx.GetAttributes(transaction.ConflictsT) {
hash := attr.Value.(*transaction.Conflicts).Hash
existingTx, ok := mp.verifiedMap[hash]
if !ok {
continue
}
if !tx.HasSigner(existingTx.Sender()) {
return nil, fmt.Errorf("%w: not signed by the sender of conflicting transaction %s", ErrConflictsAttribute, existingTx.Hash().StringBE())
}
if existingTx.NetworkFee >= tx.NetworkFee {
return nil, fmt.Errorf("%w: conflicting transaction %s has bigger or equal network fee", ErrConflictsAttribute, existingTx.Hash().StringBE())
}
conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx)
}
// Step 3: take into account sender's conflicting transactions before balance check.
expectedSenderFee = utilityBalanceAndFees{
balance: new(big.Int).Set(actualSenderFee.balance),
feeSum: new(big.Int).Set(actualSenderFee.feeSum),
}
for _, conflictingTx := range conflictsToBeRemoved {
if conflictingTx.Sender().Equals(tx.Sender()) {
expectedSenderFee.feeSum.Sub(expectedSenderFee.feeSum, big.NewInt(conflictingTx.SystemFee+conflictingTx.NetworkFee))
}
}
} else {
expectedSenderFee = actualSenderFee
}
_, err := checkBalance(tx, expectedSenderFee)
return conflictsToBeRemoved, err
} }
// Verify checks if a Sender of tx is able to pay for it (and all the other // Verify checks if a Sender of tx is able to pay for it (and all the other
@ -315,5 +423,26 @@ func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) error {
func (mp *Pool) Verify(tx *transaction.Transaction, feer Feer) bool { func (mp *Pool) Verify(tx *transaction.Transaction, feer Feer) bool {
mp.lock.RLock() mp.lock.RLock()
defer mp.lock.RUnlock() defer mp.lock.RUnlock()
return mp.checkTxConflicts(tx, feer) == nil _, err := mp.checkTxConflicts(tx, feer)
return err == nil
}
// removeConflictsOf removes hash of the given transaction from the conflicts list
// for each Conflicts attribute.
func (mp *Pool) removeConflictsOf(tx *transaction.Transaction) {
// remove all conflicting hashes from mp.conflicts list
for _, attr := range tx.GetAttributes(transaction.ConflictsT) {
conflictsHash := attr.Value.(*transaction.Conflicts).Hash
if len(mp.conflicts[conflictsHash]) == 1 {
delete(mp.conflicts, conflictsHash)
continue
}
for i, existingHash := range mp.conflicts[conflictsHash] {
if existingHash == tx.Hash() {
// tx.Hash can occur in the conflicting hashes array only once, because we can't add the same transaction to the mempol twice
mp.conflicts[conflictsHash] = append(mp.conflicts[conflictsHash][:i], mp.conflicts[conflictsHash][i+1:]...)
break
}
}
}
} }

View file

@ -1,12 +1,14 @@
package mempool package mempool
import ( import (
"errors"
"math/big" "math/big"
"sort" "sort"
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -15,6 +17,7 @@ import (
type FeerStub struct { type FeerStub struct {
feePerByte int64 feePerByte int64
p2pSigExt bool
} }
var balance = big.NewInt(10000000) var balance = big.NewInt(10000000)
@ -31,6 +34,10 @@ func (fs *FeerStub) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int {
return balance return balance
} }
func (fs *FeerStub) P2PSigExtensionsEnabled() bool {
return fs.p2pSigExt
}
func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) {
mp := New(10) mp := New(10)
tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0)
@ -44,7 +51,7 @@ func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) {
tx2, ok := mp.TryGetValue(tx.Hash()) tx2, ok := mp.TryGetValue(tx.Hash())
require.Equal(t, true, ok) require.Equal(t, true, ok)
require.Equal(t, tx, tx2) require.Equal(t, tx, tx2)
mp.Remove(tx.Hash()) mp.Remove(tx.Hash(), fs)
_, ok = mp.TryGetValue(tx.Hash()) _, ok = mp.TryGetValue(tx.Hash())
require.Equal(t, false, ok) require.Equal(t, false, ok)
// Make sure nothing left in the mempool after removal. // Make sure nothing left in the mempool after removal.
@ -148,7 +155,7 @@ func TestGetVerified(t *testing.T) {
require.Equal(t, mempoolSize, len(verTxes)) require.Equal(t, mempoolSize, len(verTxes))
require.ElementsMatch(t, txes, verTxes) require.ElementsMatch(t, txes, verTxes)
for _, tx := range txes { for _, tx := range txes {
mp.Remove(tx.Hash()) mp.Remove(tx.Hash(), fs)
} }
verTxes = mp.GetVerifiedTransactions() verTxes = mp.GetVerifiedTransactions()
require.Equal(t, 0, len(verTxes)) require.Equal(t, 0, len(verTxes))
@ -297,3 +304,122 @@ func TestMempoolItemsOrder(t *testing.T) {
require.True(t, item3.CompareTo(item4) > 0) require.True(t, item3.CompareTo(item4) > 0)
require.True(t, item4.CompareTo(item3) < 0) require.True(t, item4.CompareTo(item3) < 0)
} }
func TestMempoolAddRemoveConflicts(t *testing.T) {
capacity := 6
mp := New(capacity)
var fs = &FeerStub{
p2pSigExt: true,
}
getConflictsTx := func(netFee int64, hashes ...util.Uint256) *transaction.Transaction {
tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0)
tx.NetworkFee = netFee
tx.Nonce = uint32(random.Int(0, 1e4))
tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}}
tx.Attributes = make([]transaction.Attribute, len(hashes))
for i, h := range hashes {
tx.Attributes[i] = transaction.Attribute{
Type: transaction.ConflictsT,
Value: &transaction.Conflicts{
Hash: h,
},
}
}
_, ok := mp.TryGetValue(tx.Hash())
require.Equal(t, false, ok)
return tx
}
// tx1 in mempool and does not conflicts with anyone
smallNetFee := int64(3)
tx1 := getConflictsTx(smallNetFee)
require.NoError(t, mp.Add(tx1, fs))
// tx2 conflicts with tx1 and has smaller netfee (Step 2, negative)
tx2 := getConflictsTx(smallNetFee-1, tx1.Hash())
require.True(t, errors.Is(mp.Add(tx2, fs), ErrConflictsAttribute))
// tx3 conflicts with mempooled tx1 and has larger netfee => tx1 should be replaced by tx3 (Step 2, positive)
tx3 := getConflictsTx(smallNetFee+1, tx1.Hash())
require.NoError(t, mp.Add(tx3, fs))
require.Equal(t, 1, mp.Count())
require.Equal(t, 1, len(mp.conflicts))
require.Equal(t, []util.Uint256{tx3.Hash()}, mp.conflicts[tx1.Hash()])
// tx1 still does not conflicts with anyone, but tx3 is mempooled, conflicts with tx1
// and has larger netfee => tx1 shouldn't be added again (Step 1, negative)
require.True(t, errors.Is(mp.Add(tx1, fs), ErrConflictsAttribute))
// tx2 can now safely be added because conflicting tx1 is not in mempool => we
// cannot check that tx2 is signed by tx1.Sender
require.NoError(t, mp.Add(tx2, fs))
require.Equal(t, 1, len(mp.conflicts))
require.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()])
// mempooled tx4 conflicts with tx5, but tx4 has smaller netfee => tx4 should be replaced by tx5 (Step 1, positive)
tx5 := getConflictsTx(smallNetFee + 1)
tx4 := getConflictsTx(smallNetFee, tx5.Hash())
require.NoError(t, mp.Add(tx4, fs)) // unverified
require.Equal(t, 2, len(mp.conflicts))
require.Equal(t, []util.Uint256{tx4.Hash()}, mp.conflicts[tx5.Hash()])
require.NoError(t, mp.Add(tx5, fs))
// tx5 does not conflict with anyone
require.Equal(t, 1, len(mp.conflicts))
// multiple conflicts in attributes of single transaction
tx6 := getConflictsTx(smallNetFee)
tx7 := getConflictsTx(smallNetFee)
tx8 := getConflictsTx(smallNetFee)
// need small network fee later
tx9 := getConflictsTx(smallNetFee-2, tx6.Hash(), tx7.Hash(), tx8.Hash())
require.NoError(t, mp.Add(tx9, fs))
require.Equal(t, 4, len(mp.conflicts))
require.Equal(t, []util.Uint256{tx9.Hash()}, mp.conflicts[tx6.Hash()])
require.Equal(t, []util.Uint256{tx9.Hash()}, mp.conflicts[tx7.Hash()])
require.Equal(t, []util.Uint256{tx9.Hash()}, mp.conflicts[tx8.Hash()])
require.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()])
// multiple conflicts in attributes of multiple transactions
tx10 := getConflictsTx(smallNetFee, tx6.Hash())
tx11 := getConflictsTx(smallNetFee, tx6.Hash())
require.NoError(t, mp.Add(tx10, fs)) // unverified, because tx6 is not in the pool
require.NoError(t, mp.Add(tx11, fs)) // unverified, because tx6 is not in the pool
require.Equal(t, 4, len(mp.conflicts))
require.Equal(t, []util.Uint256{tx9.Hash(), tx10.Hash(), tx11.Hash()}, mp.conflicts[tx6.Hash()])
require.Equal(t, []util.Uint256{tx9.Hash()}, mp.conflicts[tx7.Hash()])
require.Equal(t, []util.Uint256{tx9.Hash()}, mp.conflicts[tx8.Hash()])
require.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()])
// reach capacity, remove less prioritised tx9 with its multiple conflicts
require.Equal(t, capacity, len(mp.verifiedTxes))
tx12 := getConflictsTx(smallNetFee + 2)
require.NoError(t, mp.Add(tx12, fs))
require.Equal(t, 2, len(mp.conflicts))
require.Equal(t, []util.Uint256{tx10.Hash(), tx11.Hash()}, mp.conflicts[tx6.Hash()])
require.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()])
// manually remove tx11 with its single conflict
mp.Remove(tx11.Hash(), fs)
require.Equal(t, 2, len(mp.conflicts))
require.Equal(t, []util.Uint256{tx10.Hash()}, mp.conflicts[tx6.Hash()])
// manually remove last tx which conflicts with tx6 => mp.conflicts[tx6] should also be deleted
mp.Remove(tx10.Hash(), fs)
require.Equal(t, 1, len(mp.conflicts))
require.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()])
// tx13 conflicts with tx2, but is not signed by tx2.Sender
tx13 := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0)
tx13.NetworkFee = smallNetFee
tx13.Nonce = uint32(random.Int(0, 1e4))
tx13.Signers = []transaction.Signer{{Account: util.Uint160{3, 2, 1}}}
tx13.Attributes = []transaction.Attribute{{
Type: transaction.ConflictsT,
Value: &transaction.Conflicts{
Hash: tx2.Hash(),
},
}}
_, ok := mp.TryGetValue(tx13.Hash())
require.Equal(t, false, ok)
require.True(t, errors.Is(mp.Add(tx13, fs), ErrConflictsAttribute))
}

View file

@ -37,6 +37,8 @@ func (attr *Attribute) DecodeBinary(br *io.BinReader) {
attr.Value = new(OracleResponse) attr.Value = new(OracleResponse)
case NotValidBeforeT: case NotValidBeforeT:
attr.Value = new(NotValidBefore) attr.Value = new(NotValidBefore)
case ConflictsT:
attr.Value = new(Conflicts)
default: default:
if t >= ReservedLowerBound && t <= ReservedUpperBound { if t >= ReservedLowerBound && t <= ReservedUpperBound {
attr.Value = new(Reserved) attr.Value = new(Reserved)
@ -53,7 +55,7 @@ func (attr *Attribute) EncodeBinary(bw *io.BinWriter) {
bw.WriteB(byte(attr.Type)) bw.WriteB(byte(attr.Type))
switch t := attr.Type; t { switch t := attr.Type; t {
case HighPriority: case HighPriority:
case OracleResponseT, NotValidBeforeT: case OracleResponseT, NotValidBeforeT, ConflictsT:
attr.Value.EncodeBinary(bw) attr.Value.EncodeBinary(bw)
default: default:
if t >= ReservedLowerBound && t <= ReservedUpperBound { if t >= ReservedLowerBound && t <= ReservedUpperBound {
@ -92,6 +94,9 @@ func (attr *Attribute) UnmarshalJSON(data []byte) error {
case NotValidBeforeT.String(): case NotValidBeforeT.String():
attr.Type = NotValidBeforeT attr.Type = NotValidBeforeT
attr.Value = new(NotValidBefore) attr.Value = new(NotValidBefore)
case ConflictsT.String():
attr.Type = ConflictsT
attr.Value = new(Conflicts)
default: default:
return errors.New("wrong Type") return errors.New("wrong Type")
} }

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -59,6 +60,15 @@ func TestAttribute_EncodeBinary(t *testing.T) {
require.Error(t, err) require.Error(t, err)
}) })
}) })
t.Run("Conflicts", func(t *testing.T) {
attr := &Attribute{
Type: ConflictsT,
Value: &Conflicts{
Hash: random.Uint256(),
},
}
testserdes.EncodeDecodeBinary(t, attr, new(Attribute))
})
} }
func TestAttribute_MarshalJSON(t *testing.T) { func TestAttribute_MarshalJSON(t *testing.T) {
@ -104,4 +114,13 @@ func TestAttribute_MarshalJSON(t *testing.T) {
} }
testserdes.MarshalUnmarshalJSON(t, attr, new(Attribute)) testserdes.MarshalUnmarshalJSON(t, attr, new(Attribute))
}) })
t.Run("Conflicts", func(t *testing.T) {
attr := &Attribute{
Type: ConflictsT,
Value: &Conflicts{
Hash: random.Uint256(),
},
}
testserdes.MarshalUnmarshalJSON(t, attr, new(Attribute))
})
} }

View file

@ -17,8 +17,14 @@ const (
HighPriority AttrType = 1 HighPriority AttrType = 1
OracleResponseT AttrType = 0x11 // OracleResponse OracleResponseT AttrType = 0x11 // OracleResponse
NotValidBeforeT AttrType = ReservedLowerBound // NotValidBefore NotValidBeforeT AttrType = ReservedLowerBound // NotValidBefore
ConflictsT AttrType = ReservedLowerBound + 1 // Conflicts
) )
func (a AttrType) allowMultiple() bool { func (a AttrType) allowMultiple() bool {
switch a {
case ConflictsT:
return true
default:
return false return false
}
} }

View file

@ -11,12 +11,17 @@ func _() {
_ = x[HighPriority-1] _ = x[HighPriority-1]
_ = x[OracleResponseT-17] _ = x[OracleResponseT-17]
_ = x[NotValidBeforeT-224] _ = x[NotValidBeforeT-224]
_ = x[ConflictsT-225]
} }
const ( const (
_AttrType_name_0 = "HighPriority" _AttrType_name_0 = "HighPriority"
_AttrType_name_1 = "OracleResponse" _AttrType_name_1 = "OracleResponse"
_AttrType_name_2 = "NotValidBefore" _AttrType_name_2 = "NotValidBeforeConflicts"
)
var (
_AttrType_index_2 = [...]uint8{0, 14, 23}
) )
func (i AttrType) String() string { func (i AttrType) String() string {
@ -25,8 +30,9 @@ func (i AttrType) String() string {
return _AttrType_name_0 return _AttrType_name_0
case i == 17: case i == 17:
return _AttrType_name_1 return _AttrType_name_1
case i == 224: case 224 <= i && i <= 225:
return _AttrType_name_2 i -= 224
return _AttrType_name_2[_AttrType_index_2[i]:_AttrType_index_2[i+1]]
default: default:
return "AttrType(" + strconv.FormatInt(int64(i), 10) + ")" return "AttrType(" + strconv.FormatInt(int64(i), 10) + ")"
} }

View file

@ -0,0 +1,30 @@
package transaction
import (
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// Conflicts represents attribute for conflicting transactions.
type Conflicts struct {
Hash util.Uint256 `json:"hash"`
}
// DecodeBinary implements io.Serializable interface.
func (c *Conflicts) DecodeBinary(br *io.BinReader) {
hash, err := util.Uint256DecodeBytesBE(br.ReadVarBytes(util.Uint256Size))
if err != nil {
br.Err = err
return
}
c.Hash = hash
}
// EncodeBinary implements io.Serializable interface.
func (c *Conflicts) EncodeBinary(w *io.BinWriter) {
w.WriteVarBytes(c.Hash.BytesBE())
}
func (c *Conflicts) toJSONMap(m map[string]interface{}) {
m["hash"] = c.Hash
}

View file

@ -26,6 +26,8 @@ const (
// MaxAttributes is maximum number of attributes including signers that can be contained // MaxAttributes is maximum number of attributes including signers that can be contained
// within a transaction. It is set to be 16. // within a transaction. It is set to be 16.
MaxAttributes = 16 MaxAttributes = 16
// DummyVersion represents reserved transaction version for trimmed transactions.
DummyVersion = 255
) )
// Transaction is a process recorded in the NEO blockchain. // Transaction is a process recorded in the NEO blockchain.
@ -370,7 +372,7 @@ var (
// isValid checks whether decoded/unmarshalled transaction has all fields valid. // isValid checks whether decoded/unmarshalled transaction has all fields valid.
func (t *Transaction) isValid() error { func (t *Transaction) isValid() error {
if t.Version > 0 { if t.Version > 0 && t.Version != DummyVersion {
return ErrInvalidVersion return ErrInvalidVersion
} }
if t.SystemFee < 0 { if t.SystemFee < 0 {
@ -407,3 +409,13 @@ func (t *Transaction) isValid() error {
} }
return nil return nil
} }
// HasSigner returns true in case if hash is present in the list of signers.
func (t *Transaction) HasSigner(hash util.Uint160) bool {
for _, h := range t.Signers {
if h.Account.Equals(hash) {
return true
}
}
return false
}

View file

@ -12,6 +12,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/internal/random"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
) )
@ -248,3 +249,14 @@ func TestTransaction_GetAttributes(t *testing.T) {
require.Equal(t, conflictsAttrs, tx.GetAttributes(typ)) require.Equal(t, conflictsAttrs, tx.GetAttributes(typ))
}) })
} }
func TestTransaction_HasSigner(t *testing.T) {
u1, u2 := random.Uint160(), random.Uint160()
tx := Transaction{
Signers: []Signer{
{Account: u1}, {Account: u2},
},
}
require.True(t, tx.HasSigner(u1))
require.False(t, tx.HasSigner(util.Uint160{}))
}

View file

@ -41,6 +41,10 @@ func (chain testChain) FeePerByte() int64 {
panic("TODO") panic("TODO")
} }
func (chain testChain) P2PSigExtensionsEnabled() bool {
return false
}
func (chain testChain) GetMaxBlockSystemFee() int64 { func (chain testChain) GetMaxBlockSystemFee() int64 {
panic("TODO") panic("TODO")
} }

View file

@ -98,3 +98,7 @@ func (fs *FeerStub) BlockHeight() uint32 {
func (fs *FeerStub) GetUtilityTokenBalance(acc util.Uint160) *big.Int { func (fs *FeerStub) GetUtilityTokenBalance(acc util.Uint160) *big.Int {
return big.NewInt(1000000 * native.GASFactor) return big.NewInt(1000000 * native.GASFactor)
} }
func (fs FeerStub) P2PSigExtensionsEnabled() bool {
return false
}