From bfbd096feda04abe5656530c7de696422a1b7cf0 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Fri, 15 Jan 2021 15:40:15 +0300 Subject: [PATCH] core: introduce mempool notifications --- pkg/consensus/consensus.go | 2 +- pkg/core/blockchain.go | 27 ++----- pkg/core/blockchain_test.go | 4 +- pkg/core/blockchainer/blockchainer.go | 1 - pkg/core/mempool/feer.go | 1 - pkg/core/mempool/mem_pool.go | 85 ++++++++++++++-------- pkg/core/mempool/mem_pool_test.go | 22 +++--- pkg/core/mempool/subscriptions.go | 86 ++++++++++++++++++++++ pkg/core/mempool/subscriptions_test.go | 98 ++++++++++++++++++++++++++ pkg/network/helper_test.go | 9 +-- pkg/network/notary_feer.go | 5 -- pkg/network/server.go | 24 +++---- pkg/network/server_test.go | 1 - pkg/rpc/server/server_helper_test.go | 4 -- pkg/services/notary/notary.go | 36 +++++++++- 15 files changed, 300 insertions(+), 105 deletions(-) create mode 100644 pkg/core/mempool/subscriptions.go create mode 100644 pkg/core/mempool/subscriptions_test.go diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index a35230f02..3e31d6ecf 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -451,7 +451,7 @@ func (s *service) verifyBlock(b block.Block) bool { } var fee int64 - var pool = mempool.New(len(coreb.Transactions), 0) + var pool = mempool.New(len(coreb.Transactions), 0, false) var mainPool = s.Chain.GetMemPool() for _, tx := range coreb.Transactions { var err error diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 2cd8974fc..20481bd02 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -122,8 +122,6 @@ type Blockchain struct { // postBlock is a set of callback methods which should be run under the Blockchain lock after new block is persisted. // Block's transactions are passed via mempool. postBlock []func(blockchainer.Blockchainer, *mempool.Pool, *block.Block) - // poolTxWithDataCallbacks is a set of callback methods which should be run nuder the Blockchain lock after successful PoolTxWithData invocation. - poolTxWithDataCallbacks []func(t *transaction.Transaction, data interface{}) sbCommittee keys.PublicKeys @@ -179,7 +177,7 @@ func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration, log *zap.L dao: dao.NewSimple(s, cfg.Magic, cfg.StateRootInHeader), stopCh: make(chan struct{}), runToExitCh: make(chan struct{}), - memPool: mempool.New(cfg.MemPoolSize, 0), + memPool: mempool.New(cfg.MemPoolSize, 0, false), sbCommittee: committee, log: log, events: make(chan bcEvent), @@ -485,7 +483,7 @@ func (bc *Blockchain) AddBlock(block *block.Block) error { if !block.MerkleRoot.Equals(merkle) { return errors.New("invalid block: MerkleRoot mismatch") } - mp = mempool.New(len(block.Transactions), 0) + mp = mempool.New(len(block.Transactions), 0, false) for _, tx := range block.Transactions { var err error // Transactions are verified before adding them @@ -1645,7 +1643,7 @@ func (bc *Blockchain) verifyStateRootWitness(r *state.MPTRoot) error { // current blockchain state. Note that this verification is completely isolated // from the main node's mempool. func (bc *Blockchain) VerifyTx(t *transaction.Transaction) error { - var mp = mempool.New(1, 0) + var mp = mempool.New(1, 0, false) bc.lock.RLock() defer bc.lock.RUnlock() return bc.verifyAndPoolTx(t, mp, bc) @@ -1679,19 +1677,7 @@ func (bc *Blockchain) PoolTxWithData(t *transaction.Transaction, data interface{ return err } } - if err := bc.verifyAndPoolTx(t, mp, feer, data); err != nil { - return err - } - for _, f := range bc.poolTxWithDataCallbacks { - f(t, data) - } - return nil -} - -// RegisterPoolTxWithDataCallback registers new callback function which is called -// under the Blockchain lock after successful PoolTxWithData invocation. -func (bc *Blockchain) RegisterPoolTxWithDataCallback(f func(t *transaction.Transaction, data interface{})) { - bc.poolTxWithDataCallbacks = append(bc.poolTxWithDataCallbacks, f) + return bc.verifyAndPoolTx(t, mp, feer, data) } //GetStandByValidators returns validators from the configuration. @@ -1912,11 +1898,6 @@ func (bc *Blockchain) newInteropContext(trigger trigger.Type, d dao.DAO, block * return ic } -// P2PNotaryModuleEnabled defines whether P2P notary module is enabled. -func (bc *Blockchain) P2PNotaryModuleEnabled() bool { - return bc.config.P2PNotary.Enabled -} - // P2PSigExtensionsEnabled defines whether P2P signature extensions are enabled. func (bc *Blockchain) P2PSigExtensionsEnabled() bool { return bc.config.P2PSigExtensions diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 4b4741b91..624eb5d05 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -449,7 +449,7 @@ func TestVerifyTx(t *testing.T) { require.True(t, errors.Is(err, ErrAlreadyExists)) }) t.Run("MemPoolOOM", func(t *testing.T) { - bc.memPool = mempool.New(1, 0) + bc.memPool = mempool.New(1, 0, false) tx1 := bc.newTestTx(h, testScript) tx1.NetworkFee += 10000 // Give it more priority. require.NoError(t, accs[0].SignTx(tx1)) @@ -988,7 +988,7 @@ func TestVerifyTx(t *testing.T) { return tx } - mp := mempool.New(10, 1) + mp := mempool.New(10, 1, false) verificationF := func(bc blockchainer.Blockchainer, tx *transaction.Transaction, data interface{}) error { if data.(int) > 5 { return errors.New("bad data") diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go index b303cf978..9dad9966c 100644 --- a/pkg/core/blockchainer/blockchainer.go +++ b/pkg/core/blockchainer/blockchainer.go @@ -64,7 +64,6 @@ type Blockchainer interface { ManagementContractHash() util.Uint160 PoolTx(t *transaction.Transaction, pools ...*mempool.Pool) error PoolTxWithData(t *transaction.Transaction, data interface{}, mp *mempool.Pool, feer mempool.Feer, verificationFunction func(bc Blockchainer, t *transaction.Transaction, data interface{}) error) error - RegisterPoolTxWithDataCallback(f func(t *transaction.Transaction, data interface{})) RegisterPostBlock(f func(Blockchainer, *mempool.Pool, *block.Block)) SetNotary(mod services.Notary) SubscribeForBlocks(ch chan<- *block.Block) diff --git a/pkg/core/mempool/feer.go b/pkg/core/mempool/feer.go index 1f9e865ed..40b5d0743 100644 --- a/pkg/core/mempool/feer.go +++ b/pkg/core/mempool/feer.go @@ -12,5 +12,4 @@ type Feer interface { GetUtilityTokenBalance(util.Uint160) *big.Int BlockHeight() uint32 P2PSigExtensionsEnabled() bool - P2PNotaryModuleEnabled() bool } diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index 306457267..1215b9f52 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -10,6 +10,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/util" + "go.uber.org/atomic" ) var ( @@ -69,8 +70,14 @@ type Pool struct { resendThreshold uint32 resendFunc func(*transaction.Transaction, interface{}) - // removeStaleCallback is a callback method which is called after item is removed from the mempool. - removeStaleCallback func(*transaction.Transaction, interface{}) + + // subscriptions for mempool events + subscriptionsEnabled bool + subscriptionsOn atomic.Bool + stopCh chan struct{} + events chan Event + subCh chan chan<- Event // there are no other events in mempool except Event, so no need in generic subscribers type + unsubCh chan chan<- Event } func (p items) Len() int { return len(p) } @@ -251,6 +258,13 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer, data ...interface{}) e delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID) } mp.verifiedTxes[len(mp.verifiedTxes)-1] = pItem + if mp.subscriptionsOn.Load() { + mp.events <- Event{ + Type: TransactionRemoved, + Tx: unlucky.txn, + Data: unlucky.data, + } + } } else { mp.verifiedTxes = append(mp.verifiedTxes, pItem) } @@ -271,6 +285,14 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer, data ...interface{}) e updateMempoolMetrics(len(mp.verifiedTxes)) mp.lock.Unlock() + + if mp.subscriptionsOn.Load() { + mp.events <- Event{ + Type: TransactionAdded, + Tx: pItem.txn, + Data: pItem.data, + } + } return nil } @@ -309,6 +331,13 @@ func (mp *Pool) removeInternal(hash util.Uint256, feer Feer) { if attrs := tx.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 { delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID) } + if mp.subscriptionsOn.Load() { + mp.events <- Event{ + Type: TransactionRemoved, + Tx: itm.txn, + Data: itm.data, + } + } } updateMempoolMetrics(len(mp.verifiedTxes)) } @@ -328,8 +357,7 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer) } height := feer.BlockHeight() var ( - staleItems []item - removedItems []item + staleItems []item ) for _, itm := range mp.verifiedTxes { if isOK(itm.txn) && mp.checkPolicy(itm.txn, policyChanged) && mp.tryAddSendersFee(itm.txn, feer, true) { @@ -353,17 +381,18 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer) if attrs := itm.txn.GetAttributes(transaction.OracleResponseT); len(attrs) != 0 { delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID) } - if feer.P2PSigExtensionsEnabled() && feer.P2PNotaryModuleEnabled() && mp.removeStaleCallback != nil { - removedItems = append(removedItems, itm) + if mp.subscriptionsOn.Load() { + mp.events <- Event{ + Type: TransactionRemoved, + Tx: itm.txn, + Data: itm.data, + } } } } if len(staleItems) != 0 { go mp.resendStaleItems(staleItems) } - if len(removedItems) != 0 { - go mp.postRemoveStale(removedItems) - } mp.verifiedTxes = newVerifiedTxes mp.lock.Unlock() } @@ -388,16 +417,23 @@ func (mp *Pool) checkPolicy(tx *transaction.Transaction, policyChanged bool) boo } // New returns a new Pool struct. -func New(capacity int, payerIndex int) *Pool { - return &Pool{ - verifiedMap: make(map[util.Uint256]*transaction.Transaction), - verifiedTxes: make([]item, 0, capacity), - capacity: capacity, - payerIndex: payerIndex, - fees: make(map[util.Uint160]utilityBalanceAndFees), - conflicts: make(map[util.Uint256][]util.Uint256), - oracleResp: make(map[uint64]util.Uint256), +func New(capacity int, payerIndex int, enableSubscriptions bool) *Pool { + mp := &Pool{ + verifiedMap: make(map[util.Uint256]*transaction.Transaction), + verifiedTxes: make([]item, 0, capacity), + capacity: capacity, + payerIndex: payerIndex, + fees: make(map[util.Uint160]utilityBalanceAndFees), + conflicts: make(map[util.Uint256][]util.Uint256), + oracleResp: make(map[uint64]util.Uint256), + subscriptionsEnabled: enableSubscriptions, + stopCh: make(chan struct{}), + events: make(chan Event), + subCh: make(chan chan<- Event), + unsubCh: make(chan chan<- Event), } + mp.subscriptionsOn.Store(false) + return mp } // SetResendThreshold sets threshold after which transaction will be considered stale @@ -409,25 +445,12 @@ func (mp *Pool) SetResendThreshold(h uint32, f func(*transaction.Transaction, in mp.resendFunc = f } -// SetRemoveStaleCallback registers new callback method which should be called after mempool item is kicked off. -func (mp *Pool) SetRemoveStaleCallback(f func(t *transaction.Transaction, data interface{})) { - mp.lock.Lock() - defer mp.lock.Unlock() - mp.removeStaleCallback = f -} - func (mp *Pool) resendStaleItems(items []item) { for i := range items { mp.resendFunc(items[i].txn, items[i].data) } } -func (mp *Pool) postRemoveStale(items []item) { - for i := range items { - mp.removeStaleCallback(items[i].txn, items[i].data) - } -} - // TryGetValue returns a transaction and its fee if it exists in the memory pool. func (mp *Pool) TryGetValue(hash util.Uint256) (*transaction.Transaction, bool) { mp.lock.RLock() diff --git a/pkg/core/mempool/mem_pool_test.go b/pkg/core/mempool/mem_pool_test.go index 8885ebe7d..c7b79ef80 100644 --- a/pkg/core/mempool/mem_pool_test.go +++ b/pkg/core/mempool/mem_pool_test.go @@ -44,12 +44,8 @@ func (fs *FeerStub) P2PSigExtensionsEnabled() bool { return fs.p2pSigExt } -func (fs *FeerStub) P2PNotaryModuleEnabled() bool { - return false -} - func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { - mp := New(10, 0) + mp := New(10, 0, false) tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) tx.Nonce = 0 tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} @@ -70,7 +66,7 @@ func testMemPoolAddRemoveWithFeer(t *testing.T, fs Feer) { } func TestMemPoolRemoveStale(t *testing.T) { - mp := New(5, 0) + mp := New(5, 0, false) txs := make([]*transaction.Transaction, 5) for i := range txs { txs[i] = transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) @@ -121,7 +117,7 @@ func TestMemPoolAddRemove(t *testing.T) { func TestOverCapacity(t *testing.T) { var fs = &FeerStub{balance: 10000000} const mempoolSize = 10 - mp := New(mempoolSize, 0) + mp := New(mempoolSize, 0, false) for i := 0; i < mempoolSize; i++ { tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) @@ -197,7 +193,7 @@ func TestOverCapacity(t *testing.T) { func TestGetVerified(t *testing.T) { var fs = &FeerStub{} const mempoolSize = 10 - mp := New(mempoolSize, 0) + mp := New(mempoolSize, 0, false) txes := make([]*transaction.Transaction, 0, mempoolSize) for i := 0; i < mempoolSize; i++ { @@ -221,7 +217,7 @@ func TestGetVerified(t *testing.T) { func TestRemoveStale(t *testing.T) { var fs = &FeerStub{} const mempoolSize = 10 - mp := New(mempoolSize, 0) + mp := New(mempoolSize, 0, false) txes1 := make([]*transaction.Transaction, 0, mempoolSize/2) txes2 := make([]*transaction.Transaction, 0, mempoolSize/2) @@ -254,7 +250,7 @@ func TestRemoveStale(t *testing.T) { } func TestMemPoolFees(t *testing.T) { - mp := New(10, 0) + mp := New(10, 0, false) fs := &FeerStub{balance: 10000000} sender0 := util.Uint160{1, 2, 3} tx0 := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) @@ -365,7 +361,7 @@ func TestMempoolItemsOrder(t *testing.T) { } func TestMempoolAddRemoveOracleResponse(t *testing.T) { - mp := New(3, 0) + mp := New(3, 0, false) nonce := uint32(0) fs := &FeerStub{balance: 10000} newTx := func(netFee int64, id uint64) *transaction.Transaction { @@ -435,7 +431,7 @@ func TestMempoolAddRemoveOracleResponse(t *testing.T) { func TestMempoolAddRemoveConflicts(t *testing.T) { capacity := 6 - mp := New(capacity, 0) + mp := New(capacity, 0, false) var ( fs = &FeerStub{p2pSigExt: true, balance: 100000} nonce uint32 = 1 @@ -565,7 +561,7 @@ func TestMempoolAddWithDataGetData(t *testing.T) { blockHeight: 5, balance: 100, } - mp := New(10, 1) + mp := New(10, 1, false) newTx := func(t *testing.T, netFee int64) *transaction.Transaction { tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.RET)}, 0) tx.Signers = []transaction.Signer{{}, {}} diff --git a/pkg/core/mempool/subscriptions.go b/pkg/core/mempool/subscriptions.go new file mode 100644 index 000000000..9e1f667a9 --- /dev/null +++ b/pkg/core/mempool/subscriptions.go @@ -0,0 +1,86 @@ +package mempool + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/transaction" +) + +// EventType represents mempool event type. +type EventType byte + +const ( + // TransactionAdded marks transaction addition mempool event. + TransactionAdded EventType = 0x01 + // TransactionRemoved marks transaction removal mempool event. + TransactionRemoved EventType = 0x02 +) + +// Event represents one of mempool events: transaction was added or removed from mempool. +type Event struct { + Type EventType + Tx *transaction.Transaction + Data interface{} +} + +// RunSubscriptions runs subscriptions goroutine if mempool subscriptions are enabled. +// You should manually free the resources by calling StopSubscriptions on mempool shutdown. +func (mp *Pool) RunSubscriptions() { + if !mp.subscriptionsEnabled { + panic("subscriptions are disabled") + } + if !mp.subscriptionsOn.Load() { + mp.subscriptionsOn.Store(true) + go mp.notificationDispatcher() + } +} + +// StopSubscriptions stops mempool events loop. +func (mp *Pool) StopSubscriptions() { + if !mp.subscriptionsEnabled { + panic("subscriptions are disabled") + } + if mp.subscriptionsOn.Load() { + mp.subscriptionsOn.Store(false) + close(mp.stopCh) + } +} + +// SubscribeForTransactions adds given channel to new mempool event broadcasting, so when +// there is a new transactions added to mempool or an existing transaction removed from +// mempool you'll receive it via this channel. +func (mp *Pool) SubscribeForTransactions(ch chan<- Event) { + if mp.subscriptionsOn.Load() { + mp.subCh <- ch + } +} + +// UnsubscribeFromTransactions unsubscribes given channel from new mempool notifications, +// you can close it afterwards. Passing non-subscribed channel is a no-op. +func (mp *Pool) UnsubscribeFromTransactions(ch chan<- Event) { + if mp.subscriptionsOn.Load() { + mp.unsubCh <- ch + } +} + +// notificationDispatcher manages subscription to events and broadcasts new events. +func (mp *Pool) notificationDispatcher() { + var ( + // These are just sets of subscribers, though modelled as maps + // for ease of management (not a lot of subscriptions is really + // expected, but maps are convenient for adding/deleting elements). + txFeed = make(map[chan<- Event]bool) + ) + for { + select { + case <-mp.stopCh: + return + case sub := <-mp.subCh: + txFeed[sub] = true + case unsub := <-mp.unsubCh: + delete(txFeed, unsub) + case event := <-mp.events: + for ch := range txFeed { + ch <- event + } + } + } +} diff --git a/pkg/core/mempool/subscriptions_test.go b/pkg/core/mempool/subscriptions_test.go new file mode 100644 index 000000000..5c123ec0e --- /dev/null +++ b/pkg/core/mempool/subscriptions_test.go @@ -0,0 +1,98 @@ +package mempool + +import ( + "testing" + "time" + + "github.com/nspcc-dev/neo-go/pkg/config/netmode" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "github.com/stretchr/testify/require" +) + +func TestSubscriptions(t *testing.T) { + t.Run("disabled subscriptions", func(t *testing.T) { + mp := New(5, 0, false) + require.Panics(t, func() { + mp.RunSubscriptions() + }) + require.Panics(t, func() { + mp.StopSubscriptions() + }) + }) + + t.Run("enabled subscriptions", func(t *testing.T) { + fs := &FeerStub{balance: 100} + mp := New(2, 0, true) + mp.RunSubscriptions() + subChan1 := make(chan Event, 3) + subChan2 := make(chan Event, 3) + mp.SubscribeForTransactions(subChan1) + defer mp.StopSubscriptions() + + txs := make([]*transaction.Transaction, 4) + for i := range txs { + txs[i] = transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 0) + txs[i].Nonce = uint32(i) + txs[i].Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3}}} + txs[i].NetworkFee = int64(i) + } + + // add tx + require.NoError(t, mp.Add(txs[0], fs)) + require.Eventually(t, func() bool { return len(subChan1) == 1 }, time.Second, time.Millisecond*100) + event := <-subChan1 + require.Equal(t, Event{Type: TransactionAdded, Tx: txs[0]}, event) + + // severak subscribers + mp.SubscribeForTransactions(subChan2) + require.NoError(t, mp.Add(txs[1], fs)) + require.Eventually(t, func() bool { return len(subChan1) == 1 && len(subChan2) == 1 }, time.Second, time.Millisecond*100) + event1 := <-subChan1 + event2 := <-subChan2 + require.Equal(t, Event{Type: TransactionAdded, Tx: txs[1]}, event1) + require.Equal(t, Event{Type: TransactionAdded, Tx: txs[1]}, event2) + + // reach capacity + require.NoError(t, mp.Add(txs[2], &FeerStub{})) + require.Eventually(t, func() bool { return len(subChan1) == 2 && len(subChan2) == 2 }, time.Second, time.Millisecond*100) + event1 = <-subChan1 + event2 = <-subChan2 + require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[0]}, event1) + require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[0]}, event2) + event1 = <-subChan1 + event2 = <-subChan2 + require.Equal(t, Event{Type: TransactionAdded, Tx: txs[2]}, event1) + require.Equal(t, Event{Type: TransactionAdded, Tx: txs[2]}, event2) + + // remove tx + mp.Remove(txs[1].Hash(), fs) + require.Eventually(t, func() bool { return len(subChan1) == 1 && len(subChan2) == 1 }, time.Second, time.Millisecond*100) + event1 = <-subChan1 + event2 = <-subChan2 + require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[1]}, event1) + require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[1]}, event2) + + // remove stale + mp.RemoveStale(func(tx *transaction.Transaction) bool { + if tx.Hash().Equals(txs[2].Hash()) { + return false + } + return true + }, fs) + require.Eventually(t, func() bool { return len(subChan1) == 1 && len(subChan2) == 1 }, time.Second, time.Millisecond*100) + event1 = <-subChan1 + event2 = <-subChan2 + require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[2]}, event1) + require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[2]}, event2) + + // unsubscribe + mp.UnsubscribeFromTransactions(subChan1) + require.NoError(t, mp.Add(txs[3], fs)) + require.Eventually(t, func() bool { return len(subChan2) == 1 }, time.Second, time.Millisecond*100) + event2 = <-subChan2 + require.Equal(t, 0, len(subChan1)) + require.Equal(t, Event{Type: TransactionAdded, Tx: txs[3]}, event2) + }) +} diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index e14868772..eb4ab9509 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -52,7 +52,7 @@ type testChain struct { func newTestChain() *testChain { return &testChain{ - Pool: mempool.New(10, 0), + Pool: mempool.New(10, 0, false), poolTx: func(*transaction.Transaction) error { return nil }, poolTxWithData: func(*transaction.Transaction, interface{}, *mempool.Pool) error { return nil }, blocks: make(map[util.Uint256]*block.Block), @@ -143,10 +143,6 @@ func (chain *testChain) P2PSigExtensionsEnabled() bool { return true } -func (chain *testChain) P2PNotaryModuleEnabled() bool { - return false -} - func (chain *testChain) GetMaxBlockSystemFee() int64 { panic("TODO") } @@ -289,9 +285,6 @@ func (chain *testChain) PoolTx(tx *transaction.Transaction, _ ...*mempool.Pool) func (chain testChain) SetOracle(services.Oracle) { panic("TODO") } -func (chain *testChain) RegisterPoolTxWithDataCallback(f func(t *transaction.Transaction, data interface{})) { - panic("TODO") -} func (chain *testChain) SetNotary(notary services.Notary) { panic("TODO") } diff --git a/pkg/network/notary_feer.go b/pkg/network/notary_feer.go index 2d7f78969..97e179234 100644 --- a/pkg/network/notary_feer.go +++ b/pkg/network/notary_feer.go @@ -32,11 +32,6 @@ func (f NotaryFeer) P2PSigExtensionsEnabled() bool { return f.bc.P2PSigExtensionsEnabled() } -// P2PNotaryModuleEnabled implements mempool.Feer interface. -func (f NotaryFeer) P2PNotaryModuleEnabled() bool { - return f.bc.P2PNotaryModuleEnabled() -} - // NewNotaryFeer returns new NotaryFeer instance. func NewNotaryFeer(bc blockchainer.Blockchainer) NotaryFeer { return NotaryFeer{ diff --git a/pkg/network/server.go b/pkg/network/server.go index ab6698a6f..e2be7b021 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -137,14 +137,14 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai } if chain.P2PSigExtensionsEnabled() { s.notaryFeer = NewNotaryFeer(chain) - s.notaryRequestPool = mempool.New(chain.GetConfig().P2PNotaryRequestPayloadPoolSize, 1) + s.notaryRequestPool = mempool.New(chain.GetConfig().P2PNotaryRequestPayloadPoolSize, 1, chain.GetConfig().P2PNotary.Enabled) chain.RegisterPostBlock(func(bc blockchainer.Blockchainer, txpool *mempool.Pool, _ *block.Block) { s.notaryRequestPool.RemoveStale(func(t *transaction.Transaction) bool { return bc.IsTxStillRelevant(t, txpool, true) }, s.notaryFeer) }) if chain.GetConfig().P2PNotary.Enabled { - n, err := notary.NewNotary(chain, s.log, func(tx *transaction.Transaction) error { + n, err := notary.NewNotary(chain, s.notaryRequestPool, s.log, func(tx *transaction.Transaction) error { r := s.RelayTxn(tx) if r != RelaySucceed { return fmt.Errorf("can't pool notary tx: hash %s, reason: %d", tx.Hash().StringLE(), byte(r)) @@ -156,11 +156,6 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai } s.notaryModule = n chain.SetNotary(n) - chain.RegisterPoolTxWithDataCallback(func(_ *transaction.Transaction, data interface{}) { - if notaryRequest, ok := data.(*payload.P2PNotaryRequest); ok { - s.notaryModule.OnNewRequest(notaryRequest) - } - }) chain.RegisterPostBlock(func(bc blockchainer.Blockchainer, pool *mempool.Pool, b *block.Block) { s.notaryModule.PostPersist(bc, pool, b) }) @@ -261,6 +256,10 @@ func (s *Server) Start(errChan chan error) { if s.oracle != nil { go s.oracle.Run() } + if s.notaryModule != nil { + s.notaryRequestPool.RunSubscriptions() + go s.notaryModule.Run() + } go s.relayBlocksLoop() go s.bQueue.run() go s.transport.Accept() @@ -283,6 +282,10 @@ func (s *Server) Shutdown() { if s.oracle != nil { s.oracle.Shutdown() } + if s.notaryModule != nil { + s.notaryModule.Stop() + s.notaryRequestPool.StopSubscriptions() + } close(s.quit) } @@ -1195,13 +1198,6 @@ func (s *Server) initStaleMemPools() { mp.SetResendThreshold(uint32(threshold), s.broadcastTX) if s.chain.P2PSigExtensionsEnabled() { s.notaryRequestPool.SetResendThreshold(uint32(threshold), s.broadcastP2PNotaryRequestPayload) - if s.chain.GetConfig().P2PNotary.Enabled { - s.notaryRequestPool.SetRemoveStaleCallback(func(_ *transaction.Transaction, data interface{}) { - if notaryRequest, ok := data.(*payload.P2PNotaryRequest); ok { - s.notaryModule.OnRequestRemoval(notaryRequest) - } - }) - } } } diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 0ec9ec7b8..04cf91ca7 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -850,7 +850,6 @@ func (f feerStub) FeePerByte() int64 { return 1 } func (f feerStub) GetUtilityTokenBalance(util.Uint160) *big.Int { return big.NewInt(100000000) } func (f feerStub) BlockHeight() uint32 { return f.blockHeight } func (f feerStub) P2PSigExtensionsEnabled() bool { return false } -func (f feerStub) P2PNotaryModuleEnabled() bool { return false } func (f feerStub) GetBaseExecFee() int64 { return interop.DefaultBaseExecFee } func TestMemPool(t *testing.T) { diff --git a/pkg/rpc/server/server_helper_test.go b/pkg/rpc/server/server_helper_test.go index 2079aa575..3511ac5c2 100644 --- a/pkg/rpc/server/server_helper_test.go +++ b/pkg/rpc/server/server_helper_test.go @@ -126,10 +126,6 @@ func (fs FeerStub) P2PSigExtensionsEnabled() bool { return false } -func (fs FeerStub) P2PNotaryModuleEnabled() bool { - return false -} - func (fs FeerStub) GetBaseExecFee() int64 { return interop.DefaultBaseExecFee } diff --git a/pkg/services/notary/notary.go b/pkg/services/notary/notary.go index 5bbffe5b0..1c16ff9ba 100644 --- a/pkg/services/notary/notary.go +++ b/pkg/services/notary/notary.go @@ -41,6 +41,11 @@ type ( accMtx sync.RWMutex currAccount *wallet.Account wallet *wallet.Wallet + + mp *mempool.Pool + // requests channel + reqCh chan mempool.Event + stopCh chan struct{} } // Config represents external configuration for Notary module. @@ -74,7 +79,7 @@ type request struct { } // NewNotary returns new Notary module. -func NewNotary(bc blockchainer.Blockchainer, log *zap.Logger, onTransaction func(tx *transaction.Transaction) error) (*Notary, error) { +func NewNotary(bc blockchainer.Blockchainer, mp *mempool.Pool, log *zap.Logger, onTransaction func(tx *transaction.Transaction) error) (*Notary, error) { cfg := bc.GetConfig().P2PNotary w := cfg.UnlockWallet wallet, err := wallet.NewWalletFromFile(w.Path) @@ -102,9 +107,38 @@ func NewNotary(bc blockchainer.Blockchainer, log *zap.Logger, onTransaction func }, wallet: wallet, onTransaction: onTransaction, + mp: mp, + reqCh: make(chan mempool.Event), + stopCh: make(chan struct{}), }, nil } +// Run runs Notary module and should be called in a separate goroutine. +func (n *Notary) Run() { + n.mp.SubscribeForTransactions(n.reqCh) + for { + select { + case <-n.stopCh: + n.mp.UnsubscribeFromTransactions(n.reqCh) + return + case event := <-n.reqCh: + if req, ok := event.Data.(*payload.P2PNotaryRequest); ok { + switch event.Type { + case mempool.TransactionAdded: + n.OnNewRequest(req) + case mempool.TransactionRemoved: + n.OnRequestRemoval(req) + } + } + } + } +} + +// Stop shutdowns Notary module. +func (n *Notary) Stop() { + close(n.stopCh) +} + // OnNewRequest is a callback method which is called after new notary request is added to the notary request pool. func (n *Notary) OnNewRequest(payload *payload.P2PNotaryRequest) { if n.getAccount() == nil {