diff --git a/pkg/neorpc/rpcevent/filter.go b/pkg/neorpc/rpcevent/filter.go new file mode 100644 index 000000000..e2ca74c1b --- /dev/null +++ b/pkg/neorpc/rpcevent/filter.go @@ -0,0 +1,83 @@ +package rpcevent + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/neorpc" + "github.com/nspcc-dev/neo-go/pkg/neorpc/result" +) + +type ( + // Comparator is an interface required from notification event filter to be able to + // filter notifications. + Comparator interface { + EventID() neorpc.EventID + Filter() interface{} + } + // Container is an interface required from notification event to be able to + // pass filter. + Container interface { + EventID() neorpc.EventID + EventPayload() interface{} + } +) + +// Matches filters our given Container against Comparator filter. +func Matches(f Comparator, r Container) bool { + expectedEvent := f.EventID() + filter := f.Filter() + if r.EventID() != expectedEvent { + return false + } + if filter == nil { + return true + } + switch f.EventID() { + case neorpc.BlockEventID: + filt := filter.(neorpc.BlockFilter) + b := r.EventPayload().(*block.Block) + return int(b.PrimaryIndex) == filt.Primary + case neorpc.TransactionEventID: + filt := filter.(neorpc.TxFilter) + tx := r.EventPayload().(*transaction.Transaction) + senderOK := filt.Sender == nil || tx.Sender().Equals(*filt.Sender) + signerOK := true + if filt.Signer != nil { + signerOK = false + for i := range tx.Signers { + if tx.Signers[i].Account.Equals(*filt.Signer) { + signerOK = true + break + } + } + } + return senderOK && signerOK + case neorpc.NotificationEventID: + filt := filter.(neorpc.NotificationFilter) + notification := r.EventPayload().(*state.ContainedNotificationEvent) + hashOk := filt.Contract == nil || notification.ScriptHash.Equals(*filt.Contract) + nameOk := filt.Name == nil || notification.Name == *filt.Name + return hashOk && nameOk + case neorpc.ExecutionEventID: + filt := filter.(neorpc.ExecutionFilter) + applog := r.EventPayload().(*state.AppExecResult) + return applog.VMState.String() == filt.State + case neorpc.NotaryRequestEventID: + filt := filter.(neorpc.TxFilter) + req := r.EventPayload().(*result.NotaryRequestEvent) + senderOk := filt.Sender == nil || req.NotaryRequest.FallbackTransaction.Signers[1].Account == *filt.Sender + signerOK := true + if filt.Signer != nil { + signerOK = false + for _, signer := range req.NotaryRequest.MainTransaction.Signers { + if signer.Account.Equals(*filt.Signer) { + signerOK = true + break + } + } + } + return senderOk && signerOK + } + return false +} diff --git a/pkg/neorpc/rpcevent/filter_test.go b/pkg/neorpc/rpcevent/filter_test.go new file mode 100644 index 000000000..d16c222e6 --- /dev/null +++ b/pkg/neorpc/rpcevent/filter_test.go @@ -0,0 +1,252 @@ +package rpcevent + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/neorpc" + "github.com/nspcc-dev/neo-go/pkg/neorpc/result" + "github.com/nspcc-dev/neo-go/pkg/network/payload" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/vmstate" + "github.com/stretchr/testify/require" +) + +type ( + testComparator struct { + id neorpc.EventID + filter interface{} + } + testContainer struct { + id neorpc.EventID + pld interface{} + } +) + +func (c testComparator) EventID() neorpc.EventID { + return c.id +} +func (c testComparator) Filter() interface{} { + return c.filter +} +func (c testContainer) EventID() neorpc.EventID { + return c.id +} +func (c testContainer) EventPayload() interface{} { + return c.pld +} + +func TestMatches(t *testing.T) { + primary := byte(1) + sender := util.Uint160{1, 2, 3} + signer := util.Uint160{4, 5, 6} + contract := util.Uint160{7, 8, 9} + badUint160 := util.Uint160{9, 9, 9} + name := "ntf name" + badName := "bad name" + bContainer := testContainer{ + id: neorpc.BlockEventID, + pld: &block.Block{ + Header: block.Header{PrimaryIndex: primary}, + }, + } + st := vmstate.Halt + badState := "FAULT" + txContainer := testContainer{ + id: neorpc.TransactionEventID, + pld: &transaction.Transaction{Signers: []transaction.Signer{{Account: sender}, {Account: signer}}}, + } + ntfContainer := testContainer{ + id: neorpc.NotificationEventID, + pld: &state.ContainedNotificationEvent{NotificationEvent: state.NotificationEvent{ScriptHash: contract, Name: name}}, + } + exContainer := testContainer{ + id: neorpc.ExecutionEventID, + pld: &state.AppExecResult{Execution: state.Execution{VMState: st}}, + } + ntrContainer := testContainer{ + id: neorpc.NotaryRequestEventID, + pld: &result.NotaryRequestEvent{ + NotaryRequest: &payload.P2PNotaryRequest{ + MainTransaction: &transaction.Transaction{Signers: []transaction.Signer{{Account: signer}}}, + FallbackTransaction: &transaction.Transaction{Signers: []transaction.Signer{{Account: util.Uint160{}}, {Account: sender}}}, + }, + }, + } + missedContainer := testContainer{ + id: neorpc.MissedEventID, + } + var testCases = []struct { + name string + comparator testComparator + container testContainer + expected bool + }{ + { + name: "ID mismatch", + comparator: testComparator{id: neorpc.TransactionEventID}, + container: bContainer, + expected: false, + }, + { + name: "missed event", + comparator: testComparator{id: neorpc.BlockEventID}, + container: missedContainer, + expected: false, + }, + { + name: "block, no filter", + comparator: testComparator{id: neorpc.BlockEventID}, + container: bContainer, + expected: true, + }, + { + name: "block, primary mismatch", + comparator: testComparator{ + id: neorpc.BlockEventID, + filter: neorpc.BlockFilter{Primary: int(primary + 1)}, + }, + container: bContainer, + expected: false, + }, + { + name: "block, filter match", + comparator: testComparator{ + id: neorpc.BlockEventID, + filter: neorpc.BlockFilter{Primary: int(primary)}, + }, + container: bContainer, + expected: true, + }, + { + name: "transaction, no filter", + comparator: testComparator{id: neorpc.TransactionEventID}, + container: txContainer, + expected: true, + }, + { + name: "transaction, sender mismatch", + comparator: testComparator{ + id: neorpc.TransactionEventID, + filter: neorpc.TxFilter{Sender: &badUint160}, + }, + container: txContainer, + expected: false, + }, + { + name: "transaction, signer mismatch", + comparator: testComparator{ + id: neorpc.TransactionEventID, + filter: neorpc.TxFilter{Signer: &badUint160}, + }, + container: txContainer, + expected: false, + }, + { + name: "transaction, filter match", + comparator: testComparator{ + id: neorpc.TransactionEventID, + filter: neorpc.TxFilter{Sender: &sender, Signer: &signer}, + }, + container: txContainer, + expected: true, + }, + { + name: "notification, no filter", + comparator: testComparator{id: neorpc.NotificationEventID}, + container: ntfContainer, + expected: true, + }, + { + name: "notification, contract mismatch", + comparator: testComparator{ + id: neorpc.NotificationEventID, + filter: neorpc.NotificationFilter{Contract: &badUint160}, + }, + container: ntfContainer, + expected: false, + }, + { + name: "notification, name mismatch", + comparator: testComparator{ + id: neorpc.NotificationEventID, + filter: neorpc.NotificationFilter{Name: &badName}, + }, + container: ntfContainer, + expected: false, + }, + { + name: "notification, filter match", + comparator: testComparator{ + id: neorpc.NotificationEventID, + filter: neorpc.NotificationFilter{Name: &name, Contract: &contract}, + }, + container: ntfContainer, + expected: true, + }, + { + name: "execution, no filter", + comparator: testComparator{id: neorpc.ExecutionEventID}, + container: exContainer, + expected: true, + }, + { + name: "execution, state mismatch", + comparator: testComparator{ + id: neorpc.ExecutionEventID, + filter: neorpc.ExecutionFilter{State: badState}, + }, + container: exContainer, + expected: false, + }, + { + name: "execution, filter mismatch", + comparator: testComparator{ + id: neorpc.ExecutionEventID, + filter: neorpc.ExecutionFilter{State: st.String()}, + }, + container: exContainer, + expected: true, + }, + { + name: "notary request, no filter", + comparator: testComparator{id: neorpc.NotaryRequestEventID}, + container: ntrContainer, + expected: true, + }, + { + name: "notary request, sender mismatch", + comparator: testComparator{ + id: neorpc.NotaryRequestEventID, + filter: neorpc.TxFilter{Sender: &badUint160}, + }, + container: ntrContainer, + expected: false, + }, + { + name: "notary request, signer mismatch", + comparator: testComparator{ + id: neorpc.NotaryRequestEventID, + filter: neorpc.TxFilter{Signer: &badUint160}, + }, + container: ntrContainer, + expected: false, + }, + { + name: "notary request, filter match", + comparator: testComparator{ + id: neorpc.NotaryRequestEventID, + filter: neorpc.TxFilter{Sender: &sender, Signer: &signer}, + }, + container: ntrContainer, + expected: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expected, Matches(tc.comparator, tc.container)) + }) + } +} diff --git a/pkg/neorpc/types.go b/pkg/neorpc/types.go index d78991380..ff01cbfa7 100644 --- a/pkg/neorpc/types.go +++ b/pkg/neorpc/types.go @@ -155,3 +155,14 @@ func (s *SignerWithWitness) UnmarshalJSON(data []byte) error { } return nil } + +// EventID implements EventContainer interface and returns notification ID. +func (n *Notification) EventID() EventID { + return n.Event +} + +// EventPayload implements EventContainer interface and returns notification +// object. +func (n *Notification) EventPayload() interface{} { + return n.Payload[0] +} diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index 33de3826d..1711a1f26 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -15,6 +15,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/neorpc" "github.com/nspcc-dev/neo-go/pkg/neorpc/result" + "github.com/nspcc-dev/neo-go/pkg/neorpc/rpcevent" "github.com/nspcc-dev/neo-go/pkg/util" "go.uber.org/atomic" ) @@ -45,12 +46,30 @@ type WSClient struct { closeErr error subscriptionsLock sync.RWMutex - subscriptions map[string]bool + subscriptions map[string]notificationReceiver respLock sync.RWMutex respChannels map[uint64]chan *neorpc.Response } +// notificationReceiver is a server events receiver. It stores desired notifications ID +// and filter and a channel used to receive matching notifications. +type notificationReceiver struct { + typ neorpc.EventID + filter interface{} + ch chan<- Notification +} + +// EventID implements neorpc.Comparator interface and returns notification ID. +func (r notificationReceiver) EventID() neorpc.EventID { + return r.typ +} + +// Filter implements neorpc.Comparator interface and returns notification filter. +func (r notificationReceiver) Filter() interface{} { + return r.filter +} + // Notification represents a server-generated notification for client subscriptions. // Value can be one of *block.Block, *state.AppExecResult, *state.ContainedNotificationEvent // *transaction.Transaction or *subscriptions.NotaryRequestEvent based on Type. @@ -59,6 +78,17 @@ type Notification struct { Value interface{} } +// EventID implements Container interface and returns notification ID. +func (n Notification) EventID() neorpc.EventID { + return n.Type +} + +// EventPayload implements Container interface and returns notification +// object. +func (n Notification) EventPayload() interface{} { + return n.Value +} + // requestResponse is a combined type for request and response since we can get // any of them here. type requestResponse struct { @@ -107,7 +137,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error closeCalled: *atomic.NewBool(false), respChannels: make(map[uint64]chan *neorpc.Response), requests: make(chan *neorpc.Request), - subscriptions: make(map[string]bool), + subscriptions: make(map[string]notificationReceiver), } err = initClient(ctx, &wsc.Client, endpoint, opts) @@ -205,7 +235,16 @@ readloop: break readloop } } - c.Notifications <- Notification{Type: event, Value: val} + ok := make(map[chan<- Notification]bool) + c.subscriptionsLock.RLock() + for _, rcvr := range c.subscriptions { + ntf := Notification{Type: event, Value: val} + if (rpcevent.Matches(rcvr, ntf) || event == neorpc.MissedEventID /*missed event must be delivered to each receiver*/) && !ok[rcvr.ch] { + ok[rcvr.ch] = true // strictly one notification per channel + rcvr.ch <- ntf // this will block other receivers + } + } + c.subscriptionsLock.RUnlock() } else if rr.ID != nil && (rr.Error != nil || rr.Result != nil) { id, err := strconv.ParseUint(string(rr.ID), 10, 64) if err != nil { @@ -317,7 +356,7 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) { } } -func (c *WSClient) performSubscription(params []interface{}) (string, error) { +func (c *WSClient) performSubscription(params []interface{}, rcvr notificationReceiver) (string, error) { var resp string if err := c.performRequest("subscribe", params, &resp); err != nil { @@ -327,7 +366,7 @@ func (c *WSClient) performSubscription(params []interface{}) (string, error) { c.subscriptionsLock.Lock() defer c.subscriptionsLock.Unlock() - c.subscriptions[resp] = true + c.subscriptions[resp] = rcvr return resp, nil } @@ -337,7 +376,7 @@ func (c *WSClient) performUnsubscription(id string) error { c.subscriptionsLock.Lock() defer c.subscriptionsLock.Unlock() - if !c.subscriptions[id] { + if _, ok := c.subscriptions[id]; !ok { return errors.New("no subscription with this ID") } if err := c.performRequest("unsubscribe", []interface{}{id}, &resp); err != nil { @@ -354,22 +393,52 @@ func (c *WSClient) performUnsubscription(id string) error { // of the client. It can be filtered by primary consensus node index, nil value doesn't // add any filters. func (c *WSClient) SubscribeForNewBlocks(primary *int) (string, error) { + return c.SubscribeForNewBlocksWithChan(primary, c.Notifications) +} + +// SubscribeForNewBlocksWithChan registers provided channel as a receiver for the +// specified new blocks notifications. The receiver channel must be properly read and +// drained after usage in order not to block other notification receivers. +// See SubscribeForNewBlocks for parameter details. +func (c *WSClient) SubscribeForNewBlocksWithChan(primary *int, rcvrCh chan<- Notification) (string, error) { params := []interface{}{"block_added"} + var flt *neorpc.BlockFilter if primary != nil { - params = append(params, neorpc.BlockFilter{Primary: *primary}) + flt = &neorpc.BlockFilter{Primary: *primary} + params = append(params, *flt) } - return c.performSubscription(params) + rcvr := notificationReceiver{ + typ: neorpc.BlockEventID, + filter: flt, + ch: rcvrCh, + } + return c.performSubscription(params, rcvr) } // SubscribeForNewTransactions adds subscription for new transaction events to // this instance of the client. It can be filtered by the sender and/or the signer, nil // value is treated as missing filter. func (c *WSClient) SubscribeForNewTransactions(sender *util.Uint160, signer *util.Uint160) (string, error) { + return c.SubscribeForNewTransactionsWithChan(sender, signer, c.Notifications) +} + +// SubscribeForNewTransactionsWithChan registers provided channel as a receiver +// for the specified new transactions notifications. The receiver channel must be +// properly read and drained after usage in order not to block other notification +// receivers. See SubscribeForNewTransactions for parameter details. +func (c *WSClient) SubscribeForNewTransactionsWithChan(sender *util.Uint160, signer *util.Uint160, rcvrCh chan<- Notification) (string, error) { params := []interface{}{"transaction_added"} + var flt *neorpc.TxFilter if sender != nil || signer != nil { - params = append(params, neorpc.TxFilter{Sender: sender, Signer: signer}) + flt = &neorpc.TxFilter{Sender: sender, Signer: signer} + params = append(params, *flt) } - return c.performSubscription(params) + rcvr := notificationReceiver{ + typ: neorpc.TransactionEventID, + filter: flt, + ch: rcvrCh, + } + return c.performSubscription(params, rcvr) } // SubscribeForExecutionNotifications adds subscription for notifications @@ -377,11 +446,26 @@ func (c *WSClient) SubscribeForNewTransactions(sender *util.Uint160, signer *uti // filtered by the contract's hash (that emits notifications), nil value puts no such // restrictions. func (c *WSClient) SubscribeForExecutionNotifications(contract *util.Uint160, name *string) (string, error) { + return c.SubscribeForExecutionNotificationsWithChan(contract, name, c.Notifications) +} + +// SubscribeForExecutionNotificationsWithChan registers provided channel as a +// receiver for the specified execution events. The receiver channel must be +// properly read and drained after usage in order not to block other notification +// receivers. See SubscribeForExecutionNotifications for parameter details. +func (c *WSClient) SubscribeForExecutionNotificationsWithChan(contract *util.Uint160, name *string, rcvrCh chan<- Notification) (string, error) { params := []interface{}{"notification_from_execution"} + var flt *neorpc.NotificationFilter if contract != nil || name != nil { - params = append(params, neorpc.NotificationFilter{Contract: contract, Name: name}) + flt = &neorpc.NotificationFilter{Contract: contract, Name: name} + params = append(params, *flt) } - return c.performSubscription(params) + rcvr := notificationReceiver{ + typ: neorpc.NotificationEventID, + filter: flt, + ch: rcvrCh, + } + return c.performSubscription(params, rcvr) } // SubscribeForTransactionExecutions adds subscription for application execution @@ -389,14 +473,29 @@ func (c *WSClient) SubscribeForExecutionNotifications(contract *util.Uint160, na // be filtered by state (HALT/FAULT) to check for successful or failing // transactions, nil value means no filtering. func (c *WSClient) SubscribeForTransactionExecutions(state *string) (string, error) { + return c.SubscribeForTransactionExecutionsWithChan(state, c.Notifications) +} + +// SubscribeForTransactionExecutionsWithChan registers provided channel as a +// receiver for the specified execution notifications. The receiver channel must be +// properly read and drained after usage in order not to block other notification +// receivers. See SubscribeForTransactionExecutions for parameter details. +func (c *WSClient) SubscribeForTransactionExecutionsWithChan(state *string, rcvrCh chan<- Notification) (string, error) { params := []interface{}{"transaction_executed"} + var flt *neorpc.ExecutionFilter if state != nil { if *state != "HALT" && *state != "FAULT" { return "", errors.New("bad state parameter") } - params = append(params, neorpc.ExecutionFilter{State: *state}) + flt = &neorpc.ExecutionFilter{State: *state} + params = append(params, *flt) } - return c.performSubscription(params) + rcvr := notificationReceiver{ + typ: neorpc.ExecutionEventID, + filter: flt, + ch: rcvrCh, + } + return c.performSubscription(params, rcvr) } // SubscribeForNotaryRequests adds subscription for notary request payloads @@ -404,11 +503,26 @@ func (c *WSClient) SubscribeForTransactionExecutions(state *string) (string, err // request sender's hash, or main tx signer's hash, nil value puts no such // restrictions. func (c *WSClient) SubscribeForNotaryRequests(sender *util.Uint160, mainSigner *util.Uint160) (string, error) { + return c.SubscribeForNotaryRequestsWithChan(sender, mainSigner, c.Notifications) +} + +// SubscribeForNotaryRequestsWithChan registers provided channel as a receiver +// for the specified notary requests notifications. The receiver channel must be +// properly read and drained after usage in order not to block other notification +// receivers. See SubscribeForNotaryRequests for parameter details. +func (c *WSClient) SubscribeForNotaryRequestsWithChan(sender *util.Uint160, mainSigner *util.Uint160, rcvrCh chan<- Notification) (string, error) { params := []interface{}{"notary_request_event"} + var flt *neorpc.TxFilter if sender != nil { - params = append(params, neorpc.TxFilter{Sender: sender, Signer: mainSigner}) + flt = &neorpc.TxFilter{Sender: sender, Signer: mainSigner} + params = append(params, *flt) } - return c.performSubscription(params) + rcvr := notificationReceiver{ + typ: neorpc.NotaryRequestEventID, + filter: flt, + ch: rcvrCh, + } + return c.performSubscription(params, rcvr) } // Unsubscribe removes subscription for the given event stream. diff --git a/pkg/rpcclient/wsclient_test.go b/pkg/rpcclient/wsclient_test.go index eb730d2c0..1bc396a21 100644 --- a/pkg/rpcclient/wsclient_test.go +++ b/pkg/rpcclient/wsclient_test.go @@ -32,19 +32,32 @@ func TestWSClientClose(t *testing.T) { } func TestWSClientSubscription(t *testing.T) { + ch := make(chan Notification) var cases = map[string]func(*WSClient) (string, error){ "blocks": func(wsc *WSClient) (string, error) { return wsc.SubscribeForNewBlocks(nil) }, + "blocks_with_custom_ch": func(wsc *WSClient) (string, error) { + return wsc.SubscribeForNewBlocksWithChan(nil, ch) + }, "transactions": func(wsc *WSClient) (string, error) { return wsc.SubscribeForNewTransactions(nil, nil) }, + "transactions_with_custom_ch": func(wsc *WSClient) (string, error) { + return wsc.SubscribeForNewTransactionsWithChan(nil, nil, ch) + }, "notifications": func(wsc *WSClient) (string, error) { return wsc.SubscribeForExecutionNotifications(nil, nil) }, + "notifications_with_custom_ch": func(wsc *WSClient) (string, error) { + return wsc.SubscribeForExecutionNotificationsWithChan(nil, nil, ch) + }, "executions": func(wsc *WSClient) (string, error) { return wsc.SubscribeForTransactionExecutions(nil) }, + "executions_with_custom_ch": func(wsc *WSClient) (string, error) { + return wsc.SubscribeForTransactionExecutionsWithChan(nil, ch) + }, } t.Run("good", func(t *testing.T) { for name, f := range cases { @@ -83,13 +96,13 @@ func TestWSClientUnsubscription(t *testing.T) { var cases = map[string]responseCheck{ "good": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) { // We can't really subscribe using this stub server, so set up wsc internals. - wsc.subscriptions["0"] = true + wsc.subscriptions["0"] = notificationReceiver{} err := wsc.Unsubscribe("0") require.NoError(t, err) }}, "all": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) { // We can't really subscribe using this stub server, so set up wsc internals. - wsc.subscriptions["0"] = true + wsc.subscriptions["0"] = notificationReceiver{} err := wsc.UnsubscribeAll() require.NoError(t, err) require.Equal(t, 0, len(wsc.subscriptions)) @@ -100,13 +113,13 @@ func TestWSClientUnsubscription(t *testing.T) { }}, "error returned": {`{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`, func(t *testing.T, wsc *WSClient) { // We can't really subscribe using this stub server, so set up wsc internals. - wsc.subscriptions["0"] = true + wsc.subscriptions["0"] = notificationReceiver{} err := wsc.Unsubscribe("0") require.Error(t, err) }}, "false returned": {`{"jsonrpc": "2.0", "id": 1, "result": false}`, func(t *testing.T, wsc *WSClient) { // We can't really subscribe using this stub server, so set up wsc internals. - wsc.subscriptions["0"] = true + wsc.subscriptions["0"] = notificationReceiver{} err := wsc.Unsubscribe("0") require.Error(t, err) }}, @@ -151,26 +164,104 @@ func TestWSClientEvents(t *testing.T) { } })) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) - require.NoError(t, err) - wsc.getNextRequestID = getTestRequestID - wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually. - wsc.cache.network = netmode.UnitTestNet - for range events { + t.Run("default ntf channel", func(t *testing.T) { + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID + wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually. + // Our server mock is restricted, so perform subscriptions manually with default notifications channel. + wsc.subscriptionsLock.Lock() + wsc.subscriptions["0"] = notificationReceiver{typ: neorpc.BlockEventID, ch: wsc.Notifications} + wsc.subscriptions["1"] = notificationReceiver{typ: neorpc.ExecutionEventID, ch: wsc.Notifications} + wsc.subscriptions["2"] = notificationReceiver{typ: neorpc.NotificationEventID, ch: wsc.Notifications} + // MissedEvent must be delivered without subscription. + wsc.subscriptionsLock.Unlock() + wsc.cache.network = netmode.UnitTestNet + for range events { + select { + case _, ok = <-wsc.Notifications: + case <-time.After(time.Second): + t.Fatal("timeout waiting for event") + } + require.True(t, ok) + } select { case _, ok = <-wsc.Notifications: case <-time.After(time.Second): t.Fatal("timeout waiting for event") } - require.True(t, ok) - } - select { - case _, ok = <-wsc.Notifications: - case <-time.After(time.Second): - t.Fatal("timeout waiting for event") - } - // Connection closed by server. - require.False(t, ok) + // Connection closed by server. + require.False(t, ok) + }) + t.Run("multiple ntf channels", func(t *testing.T) { + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID + wsc.cacheLock.Lock() + wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually. + wsc.cache.network = netmode.UnitTestNet + wsc.cacheLock.Unlock() + + // Our server mock is restricted, so perform subscriptions manually with default notifications channel. + ch1 := make(chan Notification) + ch2 := make(chan Notification) + ch3 := make(chan Notification) + wsc.subscriptionsLock.Lock() + wsc.subscriptions["0"] = notificationReceiver{typ: neorpc.BlockEventID, ch: wsc.Notifications} + wsc.subscriptions["1"] = notificationReceiver{typ: neorpc.ExecutionEventID, ch: wsc.Notifications} + wsc.subscriptions["2"] = notificationReceiver{typ: neorpc.NotificationEventID, ch: wsc.Notifications} + wsc.subscriptions["3"] = notificationReceiver{typ: neorpc.BlockEventID, ch: ch1} + wsc.subscriptions["4"] = notificationReceiver{typ: neorpc.NotificationEventID, ch: ch2} + wsc.subscriptions["5"] = notificationReceiver{typ: neorpc.NotificationEventID, ch: ch2} // check duplicating subscriptions + wsc.subscriptions["6"] = notificationReceiver{typ: neorpc.ExecutionEventID, filter: neorpc.ExecutionFilter{State: "HALT"}, ch: ch2} + wsc.subscriptions["7"] = notificationReceiver{typ: neorpc.ExecutionEventID, filter: neorpc.ExecutionFilter{State: "FAULT"}, ch: ch3} + // MissedEvent must be delivered without subscription. + wsc.subscriptionsLock.Unlock() + + var ( + defaultChCnt int + ch1Cnt int + ch2Cnt int + ch3Cnt int + expectedDefaultCnCount = len(events) + expectedCh1Cnt = 1 + 1 // Block event + Missed event + expectedCh2Cnt = 1 + 2 + 1 // Notification event + 2 Execution events + Missed event + expectedCh3Cnt = 1 // Missed event + ntf Notification + ) + for i := 0; i < expectedDefaultCnCount+expectedCh1Cnt+expectedCh2Cnt+expectedCh3Cnt; i++ { + select { + case ntf, ok = <-wsc.Notifications: + defaultChCnt++ + case ntf, ok = <-ch1: + require.True(t, ntf.Type == neorpc.BlockEventID || ntf.Type == neorpc.MissedEventID, ntf.Type) + ch1Cnt++ + case ntf, ok = <-ch2: + require.True(t, ntf.Type == neorpc.NotificationEventID || ntf.Type == neorpc.MissedEventID || ntf.Type == neorpc.ExecutionEventID) + ch2Cnt++ + case ntf, ok = <-ch3: + require.True(t, ntf.Type == neorpc.MissedEventID) + ch3Cnt++ + case <-time.After(time.Second): + t.Fatal("timeout waiting for event") + } + require.True(t, ok) + } + select { + case _, ok = <-wsc.Notifications: + case _, ok = <-ch1: + case _, ok = <-ch2: + case _, ok = <-ch3: + case <-time.After(time.Second): + t.Fatal("timeout waiting for event") + } + // Connection closed by server. + require.False(t, ok) + require.Equal(t, expectedDefaultCnCount, defaultChCnt) + require.Equal(t, expectedCh1Cnt, ch1Cnt) + require.Equal(t, expectedCh2Cnt, ch2Cnt) + require.Equal(t, expectedCh3Cnt, ch3Cnt) + }) } func TestWSExecutionVMStateCheck(t *testing.T) { diff --git a/pkg/services/rpcsrv/server.go b/pkg/services/rpcsrv/server.go index fbc19bd90..d5874e92d 100644 --- a/pkg/services/rpcsrv/server.go +++ b/pkg/services/rpcsrv/server.go @@ -40,6 +40,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/neorpc" "github.com/nspcc-dev/neo-go/pkg/neorpc/result" + "github.com/nspcc-dev/neo-go/pkg/neorpc/rpcevent" "github.com/nspcc-dev/neo-go/pkg/network" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/services/oracle/broadcaster" @@ -2593,7 +2594,7 @@ chloop: continue } for i := range sub.feeds { - if sub.feeds[i].Matches(&resp) { + if rpcevent.Matches(sub.feeds[i], &resp) { if msg == nil { b, err = json.Marshal(resp) if err != nil { diff --git a/pkg/services/rpcsrv/subscription.go b/pkg/services/rpcsrv/subscription.go index 6ec1be3fb..85e9a7036 100644 --- a/pkg/services/rpcsrv/subscription.go +++ b/pkg/services/rpcsrv/subscription.go @@ -2,11 +2,7 @@ package rpcsrv import ( "github.com/gorilla/websocket" - "github.com/nspcc-dev/neo-go/pkg/core/block" - "github.com/nspcc-dev/neo-go/pkg/core/state" - "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/neorpc" - "github.com/nspcc-dev/neo-go/pkg/neorpc/result" "go.uber.org/atomic" ) @@ -22,12 +18,23 @@ type ( // that's not for long. feeds [maxFeeds]feed } + // feed stores subscriber's desired event ID with filter. feed struct { event neorpc.EventID filter interface{} } ) +// EventID implements neorpc.EventComparator interface and returns notification ID. +func (f feed) EventID() neorpc.EventID { + return f.event +} + +// Filter implements neorpc.EventComparator interface and returns notification filter. +func (f feed) Filter() interface{} { + return f.filter +} + const ( // Maximum number of subscriptions per one client. maxFeeds = 16 @@ -42,59 +49,3 @@ const ( // a lot in terms of memory used. notificationBufSize = 1024 ) - -func (f *feed) Matches(r *neorpc.Notification) bool { - if r.Event != f.event { - return false - } - if f.filter == nil { - return true - } - switch f.event { - case neorpc.BlockEventID: - filt := f.filter.(neorpc.BlockFilter) - b := r.Payload[0].(*block.Block) - return int(b.PrimaryIndex) == filt.Primary - case neorpc.TransactionEventID: - filt := f.filter.(neorpc.TxFilter) - tx := r.Payload[0].(*transaction.Transaction) - senderOK := filt.Sender == nil || tx.Sender().Equals(*filt.Sender) - signerOK := true - if filt.Signer != nil { - signerOK = false - for i := range tx.Signers { - if tx.Signers[i].Account.Equals(*filt.Signer) { - signerOK = true - break - } - } - } - return senderOK && signerOK - case neorpc.NotificationEventID: - filt := f.filter.(neorpc.NotificationFilter) - notification := r.Payload[0].(*state.ContainedNotificationEvent) - hashOk := filt.Contract == nil || notification.ScriptHash.Equals(*filt.Contract) - nameOk := filt.Name == nil || notification.Name == *filt.Name - return hashOk && nameOk - case neorpc.ExecutionEventID: - filt := f.filter.(neorpc.ExecutionFilter) - applog := r.Payload[0].(*state.AppExecResult) - return applog.VMState.String() == filt.State - case neorpc.NotaryRequestEventID: - filt := f.filter.(neorpc.TxFilter) - req := r.Payload[0].(*result.NotaryRequestEvent) - senderOk := filt.Sender == nil || req.NotaryRequest.FallbackTransaction.Signers[1].Account == *filt.Sender - signerOK := true - if filt.Signer != nil { - signerOK = false - for _, signer := range req.NotaryRequest.MainTransaction.Signers { - if signer.Account.Equals(*filt.Signer) { - signerOK = true - break - } - } - } - return senderOk && signerOK - } - return false -}