From ec63d5c456a15899abdba11be2d61b3e7874021b Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Thu, 15 Oct 2020 14:45:29 +0300 Subject: [PATCH] core: add conflicts attribute Close #1491 --- pkg/core/blockchain.go | 52 +++++++- pkg/core/blockchain_test.go | 59 +++++++++ pkg/core/dao/dao.go | 47 ++++++-- pkg/core/dao/dao_test.go | 4 +- pkg/core/mempool/feer.go | 1 + pkg/core/mempool/mem_pool.go | 151 ++++++++++++++++++++++-- pkg/core/mempool/mem_pool_test.go | 130 +++++++++++++++++++- pkg/core/transaction/attribute.go | 7 +- pkg/core/transaction/attribute_test.go | 19 +++ pkg/core/transaction/attrtype.go | 12 +- pkg/core/transaction/attrtype_string.go | 12 +- pkg/core/transaction/conflicts.go | 30 +++++ pkg/core/transaction/transaction.go | 4 +- pkg/network/helper_test.go | 4 + pkg/rpc/server/server_helper_test.go | 4 + 15 files changed, 499 insertions(+), 37 deletions(-) create mode 100644 pkg/core/transaction/conflicts.go diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 4994642eb..f47627bc7 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -58,6 +58,10 @@ var ( // ErrInvalidBlockIndex is returned when trying to add block with index // other than expected height of the blockchain. ErrInvalidBlockIndex error = errors.New("invalid block index") + // ErrHasConflicts is returned when trying to add some transaction which + // conflicts with other transaction in the chain or pool according to + // Conflicts attribute. + ErrHasConflicts = errors.New("has conflicts") ) var ( persistInterval = 1 * time.Second @@ -617,6 +621,18 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error return fmt.Errorf("failed to store tx exec result: %w", err) } writeBuf.Reset() + + if bc.config.P2PSigExtensions { + for _, attr := range tx.GetAttributes(transaction.ConflictsT) { + hash := attr.Value.(*transaction.Conflicts).Hash + dummyTx := transaction.NewTrimmedTX(hash) + dummyTx.Version = transaction.DummyVersion + if err = cache.StoreAsTransaction(dummyTx, block.Index, writeBuf); err != nil { + return fmt.Errorf("failed to store conflicting transaction %s for transaction %s: %w", hash.StringLE(), tx.Hash().StringLE(), err) + } + writeBuf.Reset() + } + } } aer, err := bc.runPersist(bc.contracts.GetPostPersistScript(), block, cache) @@ -984,7 +1000,10 @@ func (bc *Blockchain) GetHeader(hash util.Uint256) (*block.Header, error) { // HasTransaction returns true if the blockchain contains he given // transaction hash. func (bc *Blockchain) HasTransaction(hash util.Uint256) bool { - return bc.memPool.ContainsKey(hash) || bc.dao.HasTransaction(hash) + if bc.memPool.ContainsKey(hash) { + return true + } + return bc.dao.HasTransaction(hash) == dao.ErrAlreadyExists } // HasBlock returns true if the blockchain contains the given @@ -1227,8 +1246,16 @@ func (bc *Blockchain) verifyAndPoolTx(t *transaction.Transaction, pool *mempool. if netFee < 0 { return fmt.Errorf("%w: net fee is %v, need %v", ErrTxSmallNetworkFee, t.NetworkFee, needNetworkFee) } - if bc.dao.HasTransaction(t.Hash()) { - return fmt.Errorf("blockchain: %w", ErrAlreadyExists) + // check that current tx wasn't included in the conflicts attributes of some other transaction which is already in the chain + if err := bc.dao.HasTransaction(t.Hash()); err != nil { + switch { + case errors.Is(err, dao.ErrAlreadyExists): + return fmt.Errorf("blockchain: %w", ErrAlreadyExists) + case errors.Is(err, dao.ErrHasConflicts): + return fmt.Errorf("blockchain: %w", ErrHasConflicts) + default: + return err + } } err := bc.verifyTxWitnesses(t, nil) if err != nil { @@ -1248,6 +1275,8 @@ func (bc *Blockchain) verifyAndPoolTx(t *transaction.Transaction, pool *mempool. return ErrInsufficientFunds case errors.Is(err, mempool.ErrOOM): return ErrOOM + case errors.Is(err, mempool.ErrConflictsAttribute): + return fmt.Errorf("mempool: %w: %s", ErrHasConflicts, err) default: return err } @@ -1303,6 +1332,14 @@ func (bc *Blockchain) verifyTxAttributes(tx *transaction.Transaction) error { if height := bc.BlockHeight(); height < nvb.Height { return fmt.Errorf("%w: NotValidBefore = %d, current height = %d", ErrTxNotYetValid, nvb.Height, height) } + case transaction.ConflictsT: + if !bc.config.P2PSigExtensions { + return errors.New("Conflicts attribute was found, but P2PSigExtensions are disabled") + } + conflicts := tx.Attributes[i].Value.(*transaction.Conflicts) + if err := bc.dao.HasTransaction(conflicts.Hash); errors.Is(err, dao.ErrAlreadyExists) { + return fmt.Errorf("conflicting transaction %s is already on chain", conflicts.Hash.StringLE()) + } default: if !bc.config.ReservedAttributes && attrType >= transaction.ReservedLowerBound && attrType <= transaction.ReservedUpperBound { return errors.New("attribute of reserved type was found, but ReservedAttributes are disabled") @@ -1326,10 +1363,10 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction, txpool *memp return false } if txpool == nil { - if bc.dao.HasTransaction(t.Hash()) { + if bc.dao.HasTransaction(t.Hash()) != nil { return false } - } else if txpool.ContainsKey(t.Hash()) { + } else if txpool.HasConflicts(t, bc) { return false } if err := bc.verifyTxAttributes(t); err != nil { @@ -1650,3 +1687,8 @@ func (bc *Blockchain) newInteropContext(trigger trigger.Type, d dao.DAO, block * } return ic } + +// P2PSigExtensionsEnabled defines whether P2P signature extensions are enabled. +func (bc *Blockchain) P2PSigExtensionsEnabled() bool { + return bc.config.P2PSigExtensions +} diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 873957da0..a5c552e3c 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -18,6 +18,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" + "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/testchain" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract" @@ -629,6 +630,64 @@ func TestVerifyTx(t *testing.T) { require.NoError(t, bc.VerifyTx(tx)) }) }) + t.Run("Conflicts", func(t *testing.T) { + getConflictsTx := func(hashes ...util.Uint256) *transaction.Transaction { + tx := bc.newTestTx(h, testScript) + tx.Attributes = make([]transaction.Attribute, len(hashes)) + for i, h := range hashes { + tx.Attributes[i] = transaction.Attribute{ + Type: transaction.ConflictsT, + Value: &transaction.Conflicts{ + Hash: h, + }, + } + } + tx.NetworkFee += 4_000_000 // multisig check + tx.Signers = []transaction.Signer{{ + Account: testchain.CommitteeScriptHash(), + Scopes: transaction.None, + }} + rawScript := testchain.CommitteeVerificationScript() + require.NoError(t, err) + size := io.GetVarSize(tx) + netFee, sizeDelta := fee.Calculate(rawScript) + tx.NetworkFee += netFee + tx.NetworkFee += int64(size+sizeDelta) * bc.FeePerByte() + data := tx.GetSignedPart() + tx.Scripts = []transaction.Witness{{ + InvocationScript: testchain.SignCommittee(data), + VerificationScript: rawScript, + }} + return tx + } + t.Run("disabled", func(t *testing.T) { + bc.config.P2PSigExtensions = false + tx := getConflictsTx(util.Uint256{1, 2, 3}) + require.Error(t, bc.VerifyTx(tx)) + }) + t.Run("enabled", func(t *testing.T) { + bc.config.P2PSigExtensions = true + t.Run("dummy on-chain conflict", func(t *testing.T) { + tx := bc.newTestTx(h, testScript) + require.NoError(t, accs[0].SignTx(tx)) + dummyTx := transaction.NewTrimmedTX(tx.Hash()) + dummyTx.Version = transaction.DummyVersion + require.NoError(t, bc.dao.StoreAsTransaction(dummyTx, bc.blockHeight, nil)) + require.True(t, errors.Is(bc.VerifyTx(tx), ErrHasConflicts)) + }) + t.Run("attribute on-chain conflict", func(t *testing.T) { + b, err := bc.GetBlock(bc.GetHeaderHash(0)) + require.NoError(t, err) + conflictsHash := b.Transactions[0].Hash() + tx := getConflictsTx(conflictsHash) + require.Error(t, bc.VerifyTx(tx)) + }) + t.Run("positive", func(t *testing.T) { + tx := getConflictsTx(random.Uint256()) + require.NoError(t, bc.VerifyTx(tx)) + }) + }) + }) }) } diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index 373473320..a508bf2fc 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -3,6 +3,7 @@ package dao import ( "bytes" "encoding/binary" + "errors" "fmt" "sort" @@ -16,6 +17,15 @@ import ( "github.com/nspcc-dev/neo-go/pkg/util" ) +// HasTransaction errors +var ( + // ErrAlreadyExists is returned when transaction exists in dao. + ErrAlreadyExists = errors.New("transaction already exists") + // ErrHasConflicts is returned when transaction is in the list of conflicting + // transactions which are already in dao. + ErrHasConflicts = errors.New("transaction has conflicts") +) + // DAO is a data access object. type DAO interface { AppendNEP5Transfer(acc util.Uint160, index uint32, tr *state.NEP5Transfer) (bool, error) @@ -42,7 +52,7 @@ type DAO interface { GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) GetVersion() (string, error) GetWrapped() DAO - HasTransaction(hash util.Uint256) bool + HasTransaction(hash util.Uint256) error Persist() (int, error) PutAppExecResult(aer *state.AppExecResult, buf *io.BufBinWriter) error PutContractState(cs *state.Contract) error @@ -515,13 +525,19 @@ func (dao *Simple) GetHeaderHashes() ([]util.Uint256, error) { } // GetTransaction returns Transaction and its height by the given hash -// if it exists in the store. +// if it exists in the store. It does not return dummy transactions. func (dao *Simple) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) { key := storage.AppendPrefix(storage.DataTransaction, hash.BytesLE()) b, err := dao.Store.Get(key) if err != nil { return nil, 0, err } + if len(b) < 5 { + return nil, 0, errors.New("bad transaction bytes") + } + if b[4] == transaction.DummyVersion { + return nil, 0, storage.ErrKeyNotFound + } r := io.NewBinReaderFromBuf(b) var height = r.ReadU32LE() @@ -558,14 +574,23 @@ func read2000Uint256Hashes(b []byte) ([]util.Uint256, error) { return hashes, nil } -// HasTransaction returns true if the given store contains the given -// Transaction hash. -func (dao *Simple) HasTransaction(hash util.Uint256) bool { +// HasTransaction returns nil if the given store does not contain the given +// Transaction hash. It returns an error in case if transaction is in chain +// or in the list of conflicting transactions. +func (dao *Simple) HasTransaction(hash util.Uint256) error { key := storage.AppendPrefix(storage.DataTransaction, hash.BytesLE()) - if _, err := dao.Store.Get(key); err == nil { - return true + bytes, err := dao.Store.Get(key) + if err != nil { + return nil } - return false + + if len(bytes) < 5 { + return nil + } + if bytes[4] == transaction.DummyVersion { + return ErrHasConflicts + } + return ErrAlreadyExists } // StoreAsBlock stores given block as DataBlock. It can reuse given buffer for @@ -609,7 +634,11 @@ func (dao *Simple) StoreAsTransaction(tx *transaction.Transaction, index uint32, buf = io.NewBufBinWriter() } buf.WriteU32LE(index) - tx.EncodeBinary(buf.BinWriter) + if tx.Version == transaction.DummyVersion { + buf.BinWriter.WriteB(tx.Version) + } else { + tx.EncodeBinary(buf.BinWriter) + } if buf.Err != nil { return buf.Err } diff --git a/pkg/core/dao/dao_test.go b/pkg/core/dao/dao_test.go index 9fcbd0676..4550da3c6 100644 --- a/pkg/core/dao/dao_test.go +++ b/pkg/core/dao/dao_test.go @@ -189,8 +189,8 @@ func TestStoreAsTransaction(t *testing.T) { hash := tx.Hash() err := dao.StoreAsTransaction(tx, 0, nil) require.NoError(t, err) - hasTransaction := dao.HasTransaction(hash) - require.True(t, hasTransaction) + err = dao.HasTransaction(hash) + require.NotNil(t, err) } func TestMakeStorageItemKey(t *testing.T) { diff --git a/pkg/core/mempool/feer.go b/pkg/core/mempool/feer.go index cc86de243..40b5d0743 100644 --- a/pkg/core/mempool/feer.go +++ b/pkg/core/mempool/feer.go @@ -11,4 +11,5 @@ type Feer interface { FeePerByte() int64 GetUtilityTokenBalance(util.Uint160) *big.Int BlockHeight() uint32 + P2PSigExtensionsEnabled() bool } diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index a0800ccf3..74f17bf37 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -2,6 +2,7 @@ package mempool import ( "errors" + "fmt" "math/big" "sort" "sync" @@ -25,6 +26,9 @@ var ( // ErrOOM is returned when transaction just doesn't fit in the memory // pool because of its capacity constraints. ErrOOM = errors.New("out of memory") + // ErrConflictsAttribute is returned when transaction conflicts with other transactions + // due to its (or theirs) Conflicts attributes. + ErrConflictsAttribute = errors.New("conflicts with memory pool due to Conflicts attribute") ) // item represents a transaction in the the Memory pool. @@ -49,6 +53,8 @@ type Pool struct { verifiedMap map[util.Uint256]*transaction.Transaction verifiedTxes items fees map[util.Uint160]utilityBalanceAndFees + // conflicts is a map of hashes of transactions which are conflicting with the mempooled ones. + conflicts map[util.Uint256][]util.Uint256 capacity int feePerByte int64 @@ -108,6 +114,29 @@ func (mp *Pool) containsKey(hash util.Uint256) bool { return false } +// HasConflicts returns true if transaction is already in pool or in the Conflicts attributes +// of pooled transactions or has Conflicts attributes for pooled transactions. +func (mp *Pool) HasConflicts(t *transaction.Transaction, fee Feer) bool { + mp.lock.RLock() + defer mp.lock.RUnlock() + + if mp.containsKey(t.Hash()) { + return true + } + if fee.P2PSigExtensionsEnabled() { + // do not check sender's signature and fee + if _, ok := mp.conflicts[t.Hash()]; ok { + return true + } + for _, attr := range t.GetAttributes(transaction.ConflictsT) { + if mp.containsKey(attr.Value.(*transaction.Conflicts).Hash) { + return true + } + } + } + return false +} + // tryAddSendersFee tries to add system fee and network fee to the total sender`s fee in mempool // and returns false if both balance check is required and sender has not enough GAS to pay func (mp *Pool) tryAddSendersFee(tx *transaction.Transaction, feer Feer, needCheck bool) bool { @@ -154,13 +183,19 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { mp.lock.Unlock() return ErrDup } - err := mp.checkTxConflicts(t, fee) + conflictsToBeRemoved, err := mp.checkTxConflicts(t, fee) if err != nil { mp.lock.Unlock() return err } mp.verifiedMap[t.Hash()] = t + if fee.P2PSigExtensionsEnabled() { + // Remove conflicting transactions. + for _, conflictingTx := range conflictsToBeRemoved { + mp.removeInternal(conflictingTx.Hash(), fee) + } + } // Insert into sorted array (from max to min, that could also be done // using sort.Sort(sort.Reverse()), but it incurs more overhead. Notice // also that we're searching for position that is strictly more @@ -181,6 +216,9 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { // Ditch the last one. unlucky := mp.verifiedTxes[len(mp.verifiedTxes)-1] delete(mp.verifiedMap, unlucky.txn.Hash()) + if fee.P2PSigExtensionsEnabled() { + mp.removeConflictsOf(unlucky.txn) + } mp.verifiedTxes[len(mp.verifiedTxes)-1] = pItem } else { mp.verifiedTxes = append(mp.verifiedTxes, pItem) @@ -189,6 +227,13 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { copy(mp.verifiedTxes[n+1:], mp.verifiedTxes[n:]) mp.verifiedTxes[n] = pItem } + if fee.P2PSigExtensionsEnabled() { + // Add conflicting hashes to the mp.conflicts list. + for _, attr := range t.GetAttributes(transaction.ConflictsT) { + hash := attr.Value.(*transaction.Conflicts).Hash + mp.conflicts[hash] = append(mp.conflicts[hash], t.Hash()) + } + } // we already checked balance in checkTxConflicts, so don't need to check again mp.tryAddSendersFee(pItem.txn, fee, false) @@ -199,8 +244,14 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer) error { // Remove removes an item from the mempool, if it exists there (and does // nothing if it doesn't). -func (mp *Pool) Remove(hash util.Uint256) { +func (mp *Pool) Remove(hash util.Uint256, feer Feer) { mp.lock.Lock() + mp.removeInternal(hash, feer) + mp.lock.Unlock() +} + +// removeInternal is an internal unlocked representation of Remove +func (mp *Pool) removeInternal(hash util.Uint256, feer Feer) { if tx, ok := mp.verifiedMap[hash]; ok { var num int delete(mp.verifiedMap, hash) @@ -217,9 +268,12 @@ func (mp *Pool) Remove(hash util.Uint256) { senderFee := mp.fees[tx.Sender()] senderFee.feeSum.Sub(senderFee.feeSum, big.NewInt(tx.SystemFee+tx.NetworkFee)) mp.fees[tx.Sender()] = senderFee + if feer.P2PSigExtensionsEnabled() { + // remove all conflicting hashes from mp.conflicts list + mp.removeConflictsOf(tx) + } } updateMempoolMetrics(len(mp.verifiedTxes)) - mp.lock.Unlock() } // RemoveStale filters verified transactions through the given function keeping @@ -232,9 +286,18 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer) // because items are iterated one-by-one in increasing order. newVerifiedTxes := mp.verifiedTxes[:0] mp.fees = make(map[util.Uint160]utilityBalanceAndFees) // it'd be nice to reuse existing map, but we can't easily clear it + if feer.P2PSigExtensionsEnabled() { + mp.conflicts = make(map[util.Uint256][]util.Uint256) + } for _, itm := range mp.verifiedTxes { if isOK(itm.txn) && mp.checkPolicy(itm.txn, policyChanged) && mp.tryAddSendersFee(itm.txn, feer, true) { newVerifiedTxes = append(newVerifiedTxes, itm) + if feer.P2PSigExtensionsEnabled() { + for _, attr := range itm.txn.GetAttributes(transaction.ConflictsT) { + hash := attr.Value.(*transaction.Conflicts).Hash + mp.conflicts[hash] = append(mp.conflicts[hash], itm.txn.Hash()) + } + } } else { delete(mp.verifiedMap, itm.txn.Hash()) } @@ -269,6 +332,7 @@ func New(capacity int) *Pool { verifiedTxes: make([]item, 0, capacity), capacity: capacity, fees: make(map[util.Uint160]utilityBalanceAndFees), + conflicts: make(map[util.Uint256][]util.Uint256), } } @@ -297,15 +361,59 @@ func (mp *Pool) GetVerifiedTransactions() []*transaction.Transaction { return t } -// checkTxConflicts is an internal unprotected version of Verify. -func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) error { - senderFee, ok := mp.fees[tx.Sender()] +// checkTxConflicts is an internal unprotected version of Verify. It takes into +// consideration conflicting transactions which are about to be removed from mempool. +func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) ([]*transaction.Transaction, error) { + actualSenderFee, ok := mp.fees[tx.Sender()] if !ok { - senderFee.balance = fee.GetUtilityTokenBalance(tx.Sender()) - senderFee.feeSum = big.NewInt(0) + actualSenderFee.balance = fee.GetUtilityTokenBalance(tx.Sender()) + actualSenderFee.feeSum = big.NewInt(0) } - _, err := checkBalance(tx, senderFee) - return err + + var expectedSenderFee utilityBalanceAndFees + // Check Conflicts attributes. + var conflictsToBeRemoved []*transaction.Transaction + if fee.P2PSigExtensionsEnabled() { + // Step 1: check if `tx` was in attributes of mempooled transactions. + if conflictingHashes, ok := mp.conflicts[tx.Hash()]; ok { + for _, hash := range conflictingHashes { + existingTx := mp.verifiedMap[hash] + if existingTx.HasSigner(tx.Sender()) && existingTx.NetworkFee > tx.NetworkFee { + return nil, fmt.Errorf("%w: conflicting transaction %s has bigger network fee", ErrConflictsAttribute, existingTx.Hash().StringBE()) + } + conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx) + } + } + // Step 2: check if mempooled transactions were in `tx`'s attributes. + for _, attr := range tx.GetAttributes(transaction.ConflictsT) { + hash := attr.Value.(*transaction.Conflicts).Hash + existingTx, ok := mp.verifiedMap[hash] + if !ok { + continue + } + if !tx.HasSigner(existingTx.Sender()) { + return nil, fmt.Errorf("%w: not signed by the sender of conflicting transaction %s", ErrConflictsAttribute, existingTx.Hash().StringBE()) + } + if existingTx.NetworkFee >= tx.NetworkFee { + return nil, fmt.Errorf("%w: conflicting transaction %s has bigger or equal network fee", ErrConflictsAttribute, existingTx.Hash().StringBE()) + } + conflictsToBeRemoved = append(conflictsToBeRemoved, existingTx) + } + // Step 3: take into account sender's conflicting transactions before balance check. + expectedSenderFee = utilityBalanceAndFees{ + balance: new(big.Int).Set(actualSenderFee.balance), + feeSum: new(big.Int).Set(actualSenderFee.feeSum), + } + for _, conflictingTx := range conflictsToBeRemoved { + if conflictingTx.Sender().Equals(tx.Sender()) { + expectedSenderFee.feeSum.Sub(expectedSenderFee.feeSum, big.NewInt(conflictingTx.SystemFee+conflictingTx.NetworkFee)) + } + } + } else { + expectedSenderFee = actualSenderFee + } + _, err := checkBalance(tx, expectedSenderFee) + return conflictsToBeRemoved, err } // Verify checks if a Sender of tx is able to pay for it (and all the other @@ -315,5 +423,26 @@ func (mp *Pool) checkTxConflicts(tx *transaction.Transaction, fee Feer) error { func (mp *Pool) Verify(tx *transaction.Transaction, feer Feer) bool { mp.lock.RLock() defer mp.lock.RUnlock() - return mp.checkTxConflicts(tx, feer) == nil + _, err := mp.checkTxConflicts(tx, feer) + return err == nil +} + +// removeConflictsOf removes hash of the given transaction from the conflicts list +// for each Conflicts attribute. +func (mp *Pool) removeConflictsOf(tx *transaction.Transaction) { + // remove all conflicting hashes from mp.conflicts list + for _, attr := range tx.GetAttributes(transaction.ConflictsT) { + conflictsHash := attr.Value.(*transaction.Conflicts).Hash + if len(mp.conflicts[conflictsHash]) == 1 { + delete(mp.conflicts, conflictsHash) + continue + } + for i, existingHash := range mp.conflicts[conflictsHash] { + if existingHash == tx.Hash() { + // tx.Hash can occur in the conflicting hashes array only once, because we can't add the same transaction to the mempol twice + mp.conflicts[conflictsHash] = append(mp.conflicts[conflictsHash][:i], mp.conflicts[conflictsHash][i+1:]...) + break + } + } + } } diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index e2a1f670c..16ecbf9f3 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -1,12 +1,14 @@ package mempool import ( + "errors" "math/big" "sort" "testing" "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/stretchr/testify/assert" @@ -15,6 +17,7 @@ import ( type FeerStub struct { feePerByte int64 + p2pSigExt bool } var balance = big.NewInt(10000000) @@ -31,6 +34,10 @@ func (fs *FeerStub) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int { return balance } +func (fs *FeerStub) P2PSigExtensionsEnabled() bool { + return fs.p2pSigExt +} + func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { mp := New(10) tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) @@ -44,7 +51,7 @@ func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { tx2, ok := mp.TryGetValue(tx.Hash()) require.Equal(t, true, ok) require.Equal(t, tx, tx2) - mp.Remove(tx.Hash()) + mp.Remove(tx.Hash(), fs) _, ok = mp.TryGetValue(tx.Hash()) require.Equal(t, false, ok) // Make sure nothing left in the mempool after removal. @@ -148,7 +155,7 @@ func TestGetVerified(t *testing.T) { require.Equal(t, mempoolSize, len(verTxes)) require.ElementsMatch(t, txes, verTxes) for _, tx := range txes { - mp.Remove(tx.Hash()) + mp.Remove(tx.Hash(), fs) } verTxes = mp.GetVerifiedTransactions() require.Equal(t, 0, len(verTxes)) @@ -297,3 +304,122 @@ func TestMempoolItemsOrder(t *testing.T) { require.True(t, item3.CompareTo(item4) > 0) require.True(t, item4.CompareTo(item3) < 0) } + +func TestMempoolAddRemoveConflicts(t *testing.T) { + capacity := 6 + mp := New(capacity) + var fs = &FeerStub{ + p2pSigExt: true, + } + getConflictsTx := func(netFee int64, hashes ...util.Uint256) *transaction.Transaction { + tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) + tx.NetworkFee = netFee + tx.Nonce = uint32(random.Int(0, 1e4)) + tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} + tx.Attributes = make([]transaction.Attribute, len(hashes)) + for i, h := range hashes { + tx.Attributes[i] = transaction.Attribute{ + Type: transaction.ConflictsT, + Value: &transaction.Conflicts{ + Hash: h, + }, + } + } + _, ok := mp.TryGetValue(tx.Hash()) + require.Equal(t, false, ok) + return tx + } + + // tx1 in mempool and does not conflicts with anyone + smallNetFee := int64(3) + tx1 := getConflictsTx(smallNetFee) + require.NoError(t, mp.Add(tx1, fs)) + + // tx2 conflicts with tx1 and has smaller netfee (Step 2, negative) + tx2 := getConflictsTx(smallNetFee-1, tx1.Hash()) + require.True(t, errors.Is(mp.Add(tx2, fs), ErrConflictsAttribute)) + + // tx3 conflicts with mempooled tx1 and has larger netfee => tx1 should be replaced by tx3 (Step 2, positive) + tx3 := getConflictsTx(smallNetFee+1, tx1.Hash()) + require.NoError(t, mp.Add(tx3, fs)) + require.Equal(t, 1, mp.Count()) + require.Equal(t, 1, len(mp.conflicts)) + require.Equal(t, []util.Uint256{tx3.Hash()}, mp.conflicts[tx1.Hash()]) + + // tx1 still does not conflicts with anyone, but tx3 is mempooled, conflicts with tx1 + // and has larger netfee => tx1 shouldn't be added again (Step 1, negative) + require.True(t, errors.Is(mp.Add(tx1, fs), ErrConflictsAttribute)) + + // tx2 can now safely be added because conflicting tx1 is not in mempool => we + // cannot check that tx2 is signed by tx1.Sender + require.NoError(t, mp.Add(tx2, fs)) + require.Equal(t, 1, len(mp.conflicts)) + require.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()]) + + // mempooled tx4 conflicts with tx5, but tx4 has smaller netfee => tx4 should be replaced by tx5 (Step 1, positive) + tx5 := getConflictsTx(smallNetFee + 1) + tx4 := getConflictsTx(smallNetFee, tx5.Hash()) + require.NoError(t, mp.Add(tx4, fs)) // unverified + require.Equal(t, 2, len(mp.conflicts)) + require.Equal(t, []util.Uint256{tx4.Hash()}, mp.conflicts[tx5.Hash()]) + require.NoError(t, mp.Add(tx5, fs)) + // tx5 does not conflict with anyone + require.Equal(t, 1, len(mp.conflicts)) + + // multiple conflicts in attributes of single transaction + tx6 := getConflictsTx(smallNetFee) + tx7 := getConflictsTx(smallNetFee) + tx8 := getConflictsTx(smallNetFee) + // need small network fee later + tx9 := getConflictsTx(smallNetFee-2, tx6.Hash(), tx7.Hash(), tx8.Hash()) + require.NoError(t, mp.Add(tx9, fs)) + require.Equal(t, 4, len(mp.conflicts)) + require.Equal(t, []util.Uint256{tx9.Hash()}, mp.conflicts[tx6.Hash()]) + require.Equal(t, []util.Uint256{tx9.Hash()}, mp.conflicts[tx7.Hash()]) + require.Equal(t, []util.Uint256{tx9.Hash()}, mp.conflicts[tx8.Hash()]) + require.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()]) + + // multiple conflicts in attributes of multiple transactions + tx10 := getConflictsTx(smallNetFee, tx6.Hash()) + tx11 := getConflictsTx(smallNetFee, tx6.Hash()) + require.NoError(t, mp.Add(tx10, fs)) // unverified, because tx6 is not in the pool + require.NoError(t, mp.Add(tx11, fs)) // unverified, because tx6 is not in the pool + require.Equal(t, 4, len(mp.conflicts)) + require.Equal(t, []util.Uint256{tx9.Hash(), tx10.Hash(), tx11.Hash()}, mp.conflicts[tx6.Hash()]) + require.Equal(t, []util.Uint256{tx9.Hash()}, mp.conflicts[tx7.Hash()]) + require.Equal(t, []util.Uint256{tx9.Hash()}, mp.conflicts[tx8.Hash()]) + require.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()]) + + // reach capacity, remove less prioritised tx9 with its multiple conflicts + require.Equal(t, capacity, len(mp.verifiedTxes)) + tx12 := getConflictsTx(smallNetFee + 2) + require.NoError(t, mp.Add(tx12, fs)) + require.Equal(t, 2, len(mp.conflicts)) + require.Equal(t, []util.Uint256{tx10.Hash(), tx11.Hash()}, mp.conflicts[tx6.Hash()]) + require.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()]) + + // manually remove tx11 with its single conflict + mp.Remove(tx11.Hash(), fs) + require.Equal(t, 2, len(mp.conflicts)) + require.Equal(t, []util.Uint256{tx10.Hash()}, mp.conflicts[tx6.Hash()]) + + // manually remove last tx which conflicts with tx6 => mp.conflicts[tx6] should also be deleted + mp.Remove(tx10.Hash(), fs) + require.Equal(t, 1, len(mp.conflicts)) + require.Equal(t, []util.Uint256{tx3.Hash(), tx2.Hash()}, mp.conflicts[tx1.Hash()]) + + // tx13 conflicts with tx2, but is not signed by tx2.Sender + tx13 := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) + tx13.NetworkFee = smallNetFee + tx13.Nonce = uint32(random.Int(0, 1e4)) + tx13.Signers = []transaction.Signer{{Account: util.Uint160{3, 2, 1}}} + tx13.Attributes = []transaction.Attribute{{ + Type: transaction.ConflictsT, + Value: &transaction.Conflicts{ + Hash: tx2.Hash(), + }, + }} + _, ok := mp.TryGetValue(tx13.Hash()) + require.Equal(t, false, ok) + require.True(t, errors.Is(mp.Add(tx13, fs), ErrConflictsAttribute)) +} diff --git a/pkg/core/transaction/attribute.go b/pkg/core/transaction/attribute.go index ca3e1011e..5ef14dd17 100644 --- a/pkg/core/transaction/attribute.go +++ b/pkg/core/transaction/attribute.go @@ -37,6 +37,8 @@ func (attr *Attribute) DecodeBinary(br *io.BinReader) { attr.Value = new(OracleResponse) case NotValidBeforeT: attr.Value = new(NotValidBefore) + case ConflictsT: + attr.Value = new(Conflicts) default: if t >= ReservedLowerBound && t <= ReservedUpperBound { attr.Value = new(Reserved) @@ -53,7 +55,7 @@ func (attr *Attribute) EncodeBinary(bw *io.BinWriter) { bw.WriteB(byte(attr.Type)) switch t := attr.Type; t { case HighPriority: - case OracleResponseT, NotValidBeforeT: + case OracleResponseT, NotValidBeforeT, ConflictsT: attr.Value.EncodeBinary(bw) default: if t >= ReservedLowerBound && t <= ReservedUpperBound { @@ -92,6 +94,9 @@ func (attr *Attribute) UnmarshalJSON(data []byte) error { case NotValidBeforeT.String(): attr.Type = NotValidBeforeT attr.Value = new(NotValidBefore) + case ConflictsT.String(): + attr.Type = ConflictsT + attr.Value = new(Conflicts) default: return errors.New("wrong Type") } diff --git a/pkg/core/transaction/attribute_test.go b/pkg/core/transaction/attribute_test.go index c72b1d278..b769fe88b 100644 --- a/pkg/core/transaction/attribute_test.go +++ b/pkg/core/transaction/attribute_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "testing" + "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" "github.com/stretchr/testify/require" ) @@ -59,6 +60,15 @@ func TestAttribute_EncodeBinary(t *testing.T) { require.Error(t, err) }) }) + t.Run("Conflicts", func(t *testing.T) { + attr := &Attribute{ + Type: ConflictsT, + Value: &Conflicts{ + Hash: random.Uint256(), + }, + } + testserdes.EncodeDecodeBinary(t, attr, new(Attribute)) + }) } func TestAttribute_MarshalJSON(t *testing.T) { @@ -104,4 +114,13 @@ func TestAttribute_MarshalJSON(t *testing.T) { } testserdes.MarshalUnmarshalJSON(t, attr, new(Attribute)) }) + t.Run("Conflicts", func(t *testing.T) { + attr := &Attribute{ + Type: ConflictsT, + Value: &Conflicts{ + Hash: random.Uint256(), + }, + } + testserdes.MarshalUnmarshalJSON(t, attr, new(Attribute)) + }) } diff --git a/pkg/core/transaction/attrtype.go b/pkg/core/transaction/attrtype.go index 090e28d28..fa68ae5cc 100644 --- a/pkg/core/transaction/attrtype.go +++ b/pkg/core/transaction/attrtype.go @@ -15,10 +15,16 @@ const ( // List of valid attribute types. const ( HighPriority AttrType = 1 - OracleResponseT AttrType = 0x11 // OracleResponse - NotValidBeforeT AttrType = ReservedLowerBound // NotValidBefore + OracleResponseT AttrType = 0x11 // OracleResponse + NotValidBeforeT AttrType = ReservedLowerBound // NotValidBefore + ConflictsT AttrType = ReservedLowerBound + 1 // Conflicts ) func (a AttrType) allowMultiple() bool { - return false + switch a { + case ConflictsT: + return true + default: + return false + } } diff --git a/pkg/core/transaction/attrtype_string.go b/pkg/core/transaction/attrtype_string.go index e45529acb..70db86618 100644 --- a/pkg/core/transaction/attrtype_string.go +++ b/pkg/core/transaction/attrtype_string.go @@ -11,12 +11,17 @@ func _() { _ = x[HighPriority-1] _ = x[OracleResponseT-17] _ = x[NotValidBeforeT-224] + _ = x[ConflictsT-225] } const ( _AttrType_name_0 = "HighPriority" _AttrType_name_1 = "OracleResponse" - _AttrType_name_2 = "NotValidBefore" + _AttrType_name_2 = "NotValidBeforeConflicts" +) + +var ( + _AttrType_index_2 = [...]uint8{0, 14, 23} ) func (i AttrType) String() string { @@ -25,8 +30,9 @@ func (i AttrType) String() string { return _AttrType_name_0 case i == 17: return _AttrType_name_1 - case i == 224: - return _AttrType_name_2 + case 224 <= i && i <= 225: + i -= 224 + return _AttrType_name_2[_AttrType_index_2[i]:_AttrType_index_2[i+1]] default: return "AttrType(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/core/transaction/conflicts.go b/pkg/core/transaction/conflicts.go new file mode 100644 index 000000000..235ed3341 --- /dev/null +++ b/pkg/core/transaction/conflicts.go @@ -0,0 +1,30 @@ +package transaction + +import ( + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// Conflicts represents attribute for conflicting transactions. +type Conflicts struct { + Hash util.Uint256 `json:"hash"` +} + +// DecodeBinary implements io.Serializable interface. +func (c *Conflicts) DecodeBinary(br *io.BinReader) { + hash, err := util.Uint256DecodeBytesBE(br.ReadVarBytes(util.Uint256Size)) + if err != nil { + br.Err = err + return + } + c.Hash = hash +} + +// EncodeBinary implements io.Serializable interface. +func (c *Conflicts) EncodeBinary(w *io.BinWriter) { + w.WriteVarBytes(c.Hash.BytesBE()) +} + +func (c *Conflicts) toJSONMap(m map[string]interface{}) { + m["hash"] = c.Hash +} diff --git a/pkg/core/transaction/transaction.go b/pkg/core/transaction/transaction.go index d9c23b906..dc1d9abe5 100644 --- a/pkg/core/transaction/transaction.go +++ b/pkg/core/transaction/transaction.go @@ -26,6 +26,8 @@ const ( // MaxAttributes is maximum number of attributes including signers that can be contained // within a transaction. It is set to be 16. MaxAttributes = 16 + // DummyVersion represents reserved transaction version for trimmed transactions. + DummyVersion = 255 ) // Transaction is a process recorded in the NEO blockchain. @@ -370,7 +372,7 @@ var ( // isValid checks whether decoded/unmarshalled transaction has all fields valid. func (t *Transaction) isValid() error { - if t.Version > 0 { + if t.Version > 0 && t.Version != DummyVersion { return ErrInvalidVersion } if t.SystemFee < 0 { diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 38428daea..5f74e24b7 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -41,6 +41,10 @@ func (chain testChain) FeePerByte() int64 { panic("TODO") } +func (chain testChain) P2PSigExtensionsEnabled() bool { + return false +} + func (chain testChain) GetMaxBlockSystemFee() int64 { panic("TODO") } diff --git a/pkg/rpc/server/server_helper_test.go b/pkg/rpc/server/server_helper_test.go index 1ea4ab6c4..ee41e0be4 100644 --- a/pkg/rpc/server/server_helper_test.go +++ b/pkg/rpc/server/server_helper_test.go @@ -98,3 +98,7 @@ func (fs *FeerStub) BlockHeight() uint32 { func (fs *FeerStub) GetUtilityTokenBalance(acc util.Uint160) *big.Int { return big.NewInt(1000000 * native.GASFactor) } + +func (fs FeerStub) P2PSigExtensionsEnabled() bool { + return false +}