diff --git a/pkg/core/mempool/mem_pool.go b/pkg/core/mempool/mem_pool.go index 53b6f60bb..47a9a8723 100644 --- a/pkg/core/mempool/mem_pool.go +++ b/pkg/core/mempool/mem_pool.go @@ -8,6 +8,7 @@ import ( "sort" "sync" + "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/util" "go.uber.org/atomic" @@ -75,9 +76,9 @@ type Pool struct { subscriptionsEnabled bool subscriptionsOn atomic.Bool stopCh chan struct{} - events chan Event - subCh chan chan<- Event // there are no other events in mempool except Event, so no need in generic subscribers type - unsubCh chan chan<- Event + events chan mempoolevent.Event + subCh chan chan<- mempoolevent.Event // there are no other events in mempool except Event, so no need in generic subscribers type + unsubCh chan chan<- mempoolevent.Event } func (p items) Len() int { return len(p) } @@ -259,8 +260,8 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer, data ...interface{}) e } mp.verifiedTxes[len(mp.verifiedTxes)-1] = pItem if mp.subscriptionsOn.Load() { - mp.events <- Event{ - Type: TransactionRemoved, + mp.events <- mempoolevent.Event{ + Type: mempoolevent.TransactionRemoved, Tx: unlucky.txn, Data: unlucky.data, } @@ -287,8 +288,8 @@ func (mp *Pool) Add(t *transaction.Transaction, fee Feer, data ...interface{}) e mp.lock.Unlock() if mp.subscriptionsOn.Load() { - mp.events <- Event{ - Type: TransactionAdded, + mp.events <- mempoolevent.Event{ + Type: mempoolevent.TransactionAdded, Tx: pItem.txn, Data: pItem.data, } @@ -332,8 +333,8 @@ func (mp *Pool) removeInternal(hash util.Uint256, feer Feer) { delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID) } if mp.subscriptionsOn.Load() { - mp.events <- Event{ - Type: TransactionRemoved, + mp.events <- mempoolevent.Event{ + Type: mempoolevent.TransactionRemoved, Tx: itm.txn, Data: itm.data, } @@ -382,8 +383,8 @@ func (mp *Pool) RemoveStale(isOK func(*transaction.Transaction) bool, feer Feer) delete(mp.oracleResp, attrs[0].Value.(*transaction.OracleResponse).ID) } if mp.subscriptionsOn.Load() { - mp.events <- Event{ - Type: TransactionRemoved, + mp.events <- mempoolevent.Event{ + Type: mempoolevent.TransactionRemoved, Tx: itm.txn, Data: itm.data, } @@ -428,9 +429,9 @@ func New(capacity int, payerIndex int, enableSubscriptions bool) *Pool { oracleResp: make(map[uint64]util.Uint256), subscriptionsEnabled: enableSubscriptions, stopCh: make(chan struct{}), - events: make(chan Event), - subCh: make(chan chan<- Event), - unsubCh: make(chan chan<- Event), + events: make(chan mempoolevent.Event), + subCh: make(chan chan<- mempoolevent.Event), + unsubCh: make(chan chan<- mempoolevent.Event), } mp.subscriptionsOn.Store(false) return mp diff --git a/pkg/core/mempool/subscriptions.go b/pkg/core/mempool/subscriptions.go index 9e1f667a9..066d0651b 100644 --- a/pkg/core/mempool/subscriptions.go +++ b/pkg/core/mempool/subscriptions.go @@ -1,25 +1,6 @@ package mempool -import ( - "github.com/nspcc-dev/neo-go/pkg/core/transaction" -) - -// EventType represents mempool event type. -type EventType byte - -const ( - // TransactionAdded marks transaction addition mempool event. - TransactionAdded EventType = 0x01 - // TransactionRemoved marks transaction removal mempool event. - TransactionRemoved EventType = 0x02 -) - -// Event represents one of mempool events: transaction was added or removed from mempool. -type Event struct { - Type EventType - Tx *transaction.Transaction - Data interface{} -} +import "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" // RunSubscriptions runs subscriptions goroutine if mempool subscriptions are enabled. // You should manually free the resources by calling StopSubscriptions on mempool shutdown. @@ -47,7 +28,7 @@ func (mp *Pool) StopSubscriptions() { // SubscribeForTransactions adds given channel to new mempool event broadcasting, so when // there is a new transactions added to mempool or an existing transaction removed from // mempool you'll receive it via this channel. -func (mp *Pool) SubscribeForTransactions(ch chan<- Event) { +func (mp *Pool) SubscribeForTransactions(ch chan<- mempoolevent.Event) { if mp.subscriptionsOn.Load() { mp.subCh <- ch } @@ -55,7 +36,7 @@ func (mp *Pool) SubscribeForTransactions(ch chan<- Event) { // UnsubscribeFromTransactions unsubscribes given channel from new mempool notifications, // you can close it afterwards. Passing non-subscribed channel is a no-op. -func (mp *Pool) UnsubscribeFromTransactions(ch chan<- Event) { +func (mp *Pool) UnsubscribeFromTransactions(ch chan<- mempoolevent.Event) { if mp.subscriptionsOn.Load() { mp.unsubCh <- ch } @@ -67,7 +48,7 @@ func (mp *Pool) notificationDispatcher() { // These are just sets of subscribers, though modelled as maps // for ease of management (not a lot of subscriptions is really // expected, but maps are convenient for adding/deleting elements). - txFeed = make(map[chan<- Event]bool) + txFeed = make(map[chan<- mempoolevent.Event]bool) ) for { select { diff --git a/pkg/core/mempool/subscriptions_test.go b/pkg/core/mempool/subscriptions_test.go index a2fb3a91f..bbe1bc2a3 100644 --- a/pkg/core/mempool/subscriptions_test.go +++ b/pkg/core/mempool/subscriptions_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" @@ -25,8 +26,8 @@ func TestSubscriptions(t *testing.T) { fs := &FeerStub{balance: 100} mp := New(2, 0, true) mp.RunSubscriptions() - subChan1 := make(chan Event, 3) - subChan2 := make(chan Event, 3) + subChan1 := make(chan mempoolevent.Event, 3) + subChan2 := make(chan mempoolevent.Event, 3) mp.SubscribeForTransactions(subChan1) t.Cleanup(mp.StopSubscriptions) @@ -42,7 +43,7 @@ func TestSubscriptions(t *testing.T) { require.NoError(t, mp.Add(txs[0], fs)) require.Eventually(t, func() bool { return len(subChan1) == 1 }, time.Second, time.Millisecond*100) event := <-subChan1 - require.Equal(t, Event{Type: TransactionAdded, Tx: txs[0]}, event) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionAdded, Tx: txs[0]}, event) // severak subscribers mp.SubscribeForTransactions(subChan2) @@ -50,28 +51,28 @@ func TestSubscriptions(t *testing.T) { require.Eventually(t, func() bool { return len(subChan1) == 1 && len(subChan2) == 1 }, time.Second, time.Millisecond*100) event1 := <-subChan1 event2 := <-subChan2 - require.Equal(t, Event{Type: TransactionAdded, Tx: txs[1]}, event1) - require.Equal(t, Event{Type: TransactionAdded, Tx: txs[1]}, event2) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionAdded, Tx: txs[1]}, event1) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionAdded, Tx: txs[1]}, event2) // reach capacity require.NoError(t, mp.Add(txs[2], &FeerStub{})) require.Eventually(t, func() bool { return len(subChan1) == 2 && len(subChan2) == 2 }, time.Second, time.Millisecond*100) event1 = <-subChan1 event2 = <-subChan2 - require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[0]}, event1) - require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[0]}, event2) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionRemoved, Tx: txs[0]}, event1) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionRemoved, Tx: txs[0]}, event2) event1 = <-subChan1 event2 = <-subChan2 - require.Equal(t, Event{Type: TransactionAdded, Tx: txs[2]}, event1) - require.Equal(t, Event{Type: TransactionAdded, Tx: txs[2]}, event2) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionAdded, Tx: txs[2]}, event1) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionAdded, Tx: txs[2]}, event2) // remove tx mp.Remove(txs[1].Hash(), fs) require.Eventually(t, func() bool { return len(subChan1) == 1 && len(subChan2) == 1 }, time.Second, time.Millisecond*100) event1 = <-subChan1 event2 = <-subChan2 - require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[1]}, event1) - require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[1]}, event2) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionRemoved, Tx: txs[1]}, event1) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionRemoved, Tx: txs[1]}, event2) // remove stale mp.RemoveStale(func(tx *transaction.Transaction) bool { @@ -80,8 +81,8 @@ func TestSubscriptions(t *testing.T) { require.Eventually(t, func() bool { return len(subChan1) == 1 && len(subChan2) == 1 }, time.Second, time.Millisecond*100) event1 = <-subChan1 event2 = <-subChan2 - require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[2]}, event1) - require.Equal(t, Event{Type: TransactionRemoved, Tx: txs[2]}, event2) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionRemoved, Tx: txs[2]}, event1) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionRemoved, Tx: txs[2]}, event2) // unsubscribe mp.UnsubscribeFromTransactions(subChan1) @@ -89,6 +90,6 @@ func TestSubscriptions(t *testing.T) { require.Eventually(t, func() bool { return len(subChan2) == 1 }, time.Second, time.Millisecond*100) event2 = <-subChan2 require.Equal(t, 0, len(subChan1)) - require.Equal(t, Event{Type: TransactionAdded, Tx: txs[3]}, event2) + require.Equal(t, mempoolevent.Event{Type: mempoolevent.TransactionAdded, Tx: txs[3]}, event2) }) } diff --git a/pkg/core/mempoolevent/event.go b/pkg/core/mempoolevent/event.go new file mode 100644 index 000000000..5361c85a6 --- /dev/null +++ b/pkg/core/mempoolevent/event.go @@ -0,0 +1,70 @@ +package mempoolevent + +import ( + "encoding/json" + "errors" + + "github.com/nspcc-dev/neo-go/pkg/core/transaction" +) + +// Type represents mempool event type. +type Type byte + +const ( + // TransactionAdded marks transaction addition mempool event. + TransactionAdded Type = 0x01 + // TransactionRemoved marks transaction removal mempool event. + TransactionRemoved Type = 0x02 +) + +// Event represents one of mempool events: transaction was added or removed from mempool. +type Event struct { + Type Type + Tx *transaction.Transaction + Data interface{} +} + +// String is a Stringer implementation. +func (e Type) String() string { + switch e { + case TransactionAdded: + return "added" + case TransactionRemoved: + return "removed" + default: + return "unknown" + } +} + +// GetEventTypeFromString converts input string into an Type if it's possible. +func GetEventTypeFromString(s string) (Type, error) { + switch s { + case "added": + return TransactionAdded, nil + case "removed": + return TransactionRemoved, nil + default: + return 0, errors.New("invalid event type name") + } +} + +// MarshalJSON implements json.Marshaler interface. +func (e Type) MarshalJSON() ([]byte, error) { + return json.Marshal(e.String()) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (e *Type) UnmarshalJSON(b []byte) error { + var s string + + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + id, err := GetEventTypeFromString(s) + if err != nil { + return err + } + *e = id + return nil +} diff --git a/pkg/services/notary/notary.go b/pkg/services/notary/notary.go index a2f5d7eb5..8ba2f2486 100644 --- a/pkg/services/notary/notary.go +++ b/pkg/services/notary/notary.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/mempool" + "github.com/nspcc-dev/neo-go/pkg/core/mempoolevent" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" @@ -48,7 +49,7 @@ type ( mp *mempool.Pool // requests channel - reqCh chan mempool.Event + reqCh chan mempoolevent.Event blocksCh chan *block.Block stopCh chan struct{} } @@ -109,7 +110,7 @@ func NewNotary(cfg Config, net netmode.Magic, mp *mempool.Pool, onTransaction fu wallet: wallet, onTransaction: onTransaction, mp: mp, - reqCh: make(chan mempool.Event), + reqCh: make(chan mempoolevent.Event), blocksCh: make(chan *block.Block), stopCh: make(chan struct{}), }, nil @@ -129,9 +130,9 @@ func (n *Notary) Run() { case event := <-n.reqCh: if req, ok := event.Data.(*payload.P2PNotaryRequest); ok { switch event.Type { - case mempool.TransactionAdded: + case mempoolevent.TransactionAdded: n.OnNewRequest(req) - case mempool.TransactionRemoved: + case mempoolevent.TransactionRemoved: n.OnRequestRemoval(req) } }