diff --git a/pkg/network/server.go b/pkg/network/server.go index 432f2c7f3..8d367424e 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -13,6 +13,7 @@ import ( "sort" "strconv" "sync" + satomic "sync/atomic" "time" "github.com/nspcc-dev/neo-go/pkg/config" @@ -109,7 +110,7 @@ type ( services map[string]Service extensHandlers map[string]func(*payload.Extensible) error txCallback func(*transaction.Transaction) - txCbEnabled atomic.Bool + txCbList satomic.Value txInLock sync.RWMutex txin chan *transaction.Transaction @@ -1106,8 +1107,17 @@ txloop: s.serviceLock.RLock() txCallback := s.txCallback s.serviceLock.RUnlock() - if txCallback != nil && s.txCbEnabled.Load() { - txCallback(tx) + if txCallback != nil { + var cbList = s.txCbList.Load() + if cbList != nil { + var list = cbList.([]util.Uint256) + var i = sort.Search(len(list), func(i int) bool { + return list[i].CompareTo(tx.Hash()) >= 0 + }) + if i < len(list) && list[i].Equals(tx.Hash()) { + txCallback(tx) + } + } } if s.verifyAndPoolTX(tx) == nil { s.broadcastTX(tx, nil) @@ -1420,7 +1430,13 @@ func (s *Server) RequestTx(hashes ...util.Uint256) { return } - s.txCbEnabled.Store(true) + var sorted = make([]util.Uint256, len(hashes)) + copy(sorted, hashes) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].CompareTo(sorted[j]) < 0 + }) + + s.txCbList.Store(sorted) for i := 0; i <= len(hashes)/payload.MaxHashesCount; i++ { start := i * payload.MaxHashesCount @@ -1440,7 +1456,8 @@ func (s *Server) RequestTx(hashes ...util.Uint256) { // StopTxFlow makes the server not call previously specified consensus transaction callback. func (s *Server) StopTxFlow() { - s.txCbEnabled.Store(false) + var hashes []util.Uint256 + s.txCbList.Store(hashes) } // iteratePeersWithSendMsg sends the given message to all peers using two functions diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index b6684b1b5..0f6b1e344 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -467,10 +467,10 @@ func TestTransaction(t *testing.T) { cons := new(fakeConsensus) s.AddConsensusService(cons, cons.OnPayload, cons.OnTransaction) startWithCleanup(t, s) - s.RequestTx(util.Uint256{1}) t.Run("good", func(t *testing.T) { tx := newDummyTx() + s.RequestTx(tx.Hash()) p := newLocalPeer(t, s) p.isFullNode = true p.messageHandler = func(t *testing.T, msg *Message) { @@ -497,6 +497,7 @@ func TestTransaction(t *testing.T) { }) t.Run("bad", func(t *testing.T) { tx := newDummyTx() + s.RequestTx(tx.Hash()) s.chain.(*fakechain.FakeChain).PoolTxF = func(*transaction.Transaction) error { return core.ErrInsufficientFunds } s.testHandleMessage(t, nil, CMDTX, tx) require.Eventually(t, func() bool {