rpc: support multiple WSClient notification receivers

This commit is contained in:
Anna Shaleva 2022-10-17 13:31:24 +03:00
parent 4ce6bc6a66
commit 6d38e75149
7 changed files with 599 additions and 96 deletions

View file

@ -0,0 +1,83 @@
package rpcevent
import (
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/neorpc"
"github.com/nspcc-dev/neo-go/pkg/neorpc/result"
)
type (
// Comparator is an interface required from notification event filter to be able to
// filter notifications.
Comparator interface {
EventID() neorpc.EventID
Filter() interface{}
}
// Container is an interface required from notification event to be able to
// pass filter.
Container interface {
EventID() neorpc.EventID
EventPayload() interface{}
}
)
// Matches filters our given Container against Comparator filter.
func Matches(f Comparator, r Container) bool {
expectedEvent := f.EventID()
filter := f.Filter()
if r.EventID() != expectedEvent {
return false
}
if filter == nil {
return true
}
switch f.EventID() {
case neorpc.BlockEventID:
filt := filter.(neorpc.BlockFilter)
b := r.EventPayload().(*block.Block)
return int(b.PrimaryIndex) == filt.Primary
case neorpc.TransactionEventID:
filt := filter.(neorpc.TxFilter)
tx := r.EventPayload().(*transaction.Transaction)
senderOK := filt.Sender == nil || tx.Sender().Equals(*filt.Sender)
signerOK := true
if filt.Signer != nil {
signerOK = false
for i := range tx.Signers {
if tx.Signers[i].Account.Equals(*filt.Signer) {
signerOK = true
break
}
}
}
return senderOK && signerOK
case neorpc.NotificationEventID:
filt := filter.(neorpc.NotificationFilter)
notification := r.EventPayload().(*state.ContainedNotificationEvent)
hashOk := filt.Contract == nil || notification.ScriptHash.Equals(*filt.Contract)
nameOk := filt.Name == nil || notification.Name == *filt.Name
return hashOk && nameOk
case neorpc.ExecutionEventID:
filt := filter.(neorpc.ExecutionFilter)
applog := r.EventPayload().(*state.AppExecResult)
return applog.VMState.String() == filt.State
case neorpc.NotaryRequestEventID:
filt := filter.(neorpc.TxFilter)
req := r.EventPayload().(*result.NotaryRequestEvent)
senderOk := filt.Sender == nil || req.NotaryRequest.FallbackTransaction.Signers[1].Account == *filt.Sender
signerOK := true
if filt.Signer != nil {
signerOK = false
for _, signer := range req.NotaryRequest.MainTransaction.Signers {
if signer.Account.Equals(*filt.Signer) {
signerOK = true
break
}
}
}
return senderOk && signerOK
}
return false
}

View file

@ -0,0 +1,252 @@
package rpcevent
import (
"testing"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/neorpc"
"github.com/nspcc-dev/neo-go/pkg/neorpc/result"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/vmstate"
"github.com/stretchr/testify/require"
)
type (
testComparator struct {
id neorpc.EventID
filter interface{}
}
testContainer struct {
id neorpc.EventID
pld interface{}
}
)
func (c testComparator) EventID() neorpc.EventID {
return c.id
}
func (c testComparator) Filter() interface{} {
return c.filter
}
func (c testContainer) EventID() neorpc.EventID {
return c.id
}
func (c testContainer) EventPayload() interface{} {
return c.pld
}
func TestMatches(t *testing.T) {
primary := byte(1)
sender := util.Uint160{1, 2, 3}
signer := util.Uint160{4, 5, 6}
contract := util.Uint160{7, 8, 9}
badUint160 := util.Uint160{9, 9, 9}
name := "ntf name"
badName := "bad name"
bContainer := testContainer{
id: neorpc.BlockEventID,
pld: &block.Block{
Header: block.Header{PrimaryIndex: primary},
},
}
st := vmstate.Halt
badState := "FAULT"
txContainer := testContainer{
id: neorpc.TransactionEventID,
pld: &transaction.Transaction{Signers: []transaction.Signer{{Account: sender}, {Account: signer}}},
}
ntfContainer := testContainer{
id: neorpc.NotificationEventID,
pld: &state.ContainedNotificationEvent{NotificationEvent: state.NotificationEvent{ScriptHash: contract, Name: name}},
}
exContainer := testContainer{
id: neorpc.ExecutionEventID,
pld: &state.AppExecResult{Execution: state.Execution{VMState: st}},
}
ntrContainer := testContainer{
id: neorpc.NotaryRequestEventID,
pld: &result.NotaryRequestEvent{
NotaryRequest: &payload.P2PNotaryRequest{
MainTransaction: &transaction.Transaction{Signers: []transaction.Signer{{Account: signer}}},
FallbackTransaction: &transaction.Transaction{Signers: []transaction.Signer{{Account: util.Uint160{}}, {Account: sender}}},
},
},
}
missedContainer := testContainer{
id: neorpc.MissedEventID,
}
var testCases = []struct {
name string
comparator testComparator
container testContainer
expected bool
}{
{
name: "ID mismatch",
comparator: testComparator{id: neorpc.TransactionEventID},
container: bContainer,
expected: false,
},
{
name: "missed event",
comparator: testComparator{id: neorpc.BlockEventID},
container: missedContainer,
expected: false,
},
{
name: "block, no filter",
comparator: testComparator{id: neorpc.BlockEventID},
container: bContainer,
expected: true,
},
{
name: "block, primary mismatch",
comparator: testComparator{
id: neorpc.BlockEventID,
filter: neorpc.BlockFilter{Primary: int(primary + 1)},
},
container: bContainer,
expected: false,
},
{
name: "block, filter match",
comparator: testComparator{
id: neorpc.BlockEventID,
filter: neorpc.BlockFilter{Primary: int(primary)},
},
container: bContainer,
expected: true,
},
{
name: "transaction, no filter",
comparator: testComparator{id: neorpc.TransactionEventID},
container: txContainer,
expected: true,
},
{
name: "transaction, sender mismatch",
comparator: testComparator{
id: neorpc.TransactionEventID,
filter: neorpc.TxFilter{Sender: &badUint160},
},
container: txContainer,
expected: false,
},
{
name: "transaction, signer mismatch",
comparator: testComparator{
id: neorpc.TransactionEventID,
filter: neorpc.TxFilter{Signer: &badUint160},
},
container: txContainer,
expected: false,
},
{
name: "transaction, filter match",
comparator: testComparator{
id: neorpc.TransactionEventID,
filter: neorpc.TxFilter{Sender: &sender, Signer: &signer},
},
container: txContainer,
expected: true,
},
{
name: "notification, no filter",
comparator: testComparator{id: neorpc.NotificationEventID},
container: ntfContainer,
expected: true,
},
{
name: "notification, contract mismatch",
comparator: testComparator{
id: neorpc.NotificationEventID,
filter: neorpc.NotificationFilter{Contract: &badUint160},
},
container: ntfContainer,
expected: false,
},
{
name: "notification, name mismatch",
comparator: testComparator{
id: neorpc.NotificationEventID,
filter: neorpc.NotificationFilter{Name: &badName},
},
container: ntfContainer,
expected: false,
},
{
name: "notification, filter match",
comparator: testComparator{
id: neorpc.NotificationEventID,
filter: neorpc.NotificationFilter{Name: &name, Contract: &contract},
},
container: ntfContainer,
expected: true,
},
{
name: "execution, no filter",
comparator: testComparator{id: neorpc.ExecutionEventID},
container: exContainer,
expected: true,
},
{
name: "execution, state mismatch",
comparator: testComparator{
id: neorpc.ExecutionEventID,
filter: neorpc.ExecutionFilter{State: badState},
},
container: exContainer,
expected: false,
},
{
name: "execution, filter mismatch",
comparator: testComparator{
id: neorpc.ExecutionEventID,
filter: neorpc.ExecutionFilter{State: st.String()},
},
container: exContainer,
expected: true,
},
{
name: "notary request, no filter",
comparator: testComparator{id: neorpc.NotaryRequestEventID},
container: ntrContainer,
expected: true,
},
{
name: "notary request, sender mismatch",
comparator: testComparator{
id: neorpc.NotaryRequestEventID,
filter: neorpc.TxFilter{Sender: &badUint160},
},
container: ntrContainer,
expected: false,
},
{
name: "notary request, signer mismatch",
comparator: testComparator{
id: neorpc.NotaryRequestEventID,
filter: neorpc.TxFilter{Signer: &badUint160},
},
container: ntrContainer,
expected: false,
},
{
name: "notary request, filter match",
comparator: testComparator{
id: neorpc.NotaryRequestEventID,
filter: neorpc.TxFilter{Sender: &sender, Signer: &signer},
},
container: ntrContainer,
expected: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.expected, Matches(tc.comparator, tc.container))
})
}
}

View file

@ -155,3 +155,14 @@ func (s *SignerWithWitness) UnmarshalJSON(data []byte) error {
} }
return nil return nil
} }
// EventID implements EventContainer interface and returns notification ID.
func (n *Notification) EventID() EventID {
return n.Event
}
// EventPayload implements EventContainer interface and returns notification
// object.
func (n *Notification) EventPayload() interface{} {
return n.Payload[0]
}

View file

@ -15,6 +15,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/neorpc" "github.com/nspcc-dev/neo-go/pkg/neorpc"
"github.com/nspcc-dev/neo-go/pkg/neorpc/result" "github.com/nspcc-dev/neo-go/pkg/neorpc/result"
"github.com/nspcc-dev/neo-go/pkg/neorpc/rpcevent"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"go.uber.org/atomic" "go.uber.org/atomic"
) )
@ -45,12 +46,30 @@ type WSClient struct {
closeErr error closeErr error
subscriptionsLock sync.RWMutex subscriptionsLock sync.RWMutex
subscriptions map[string]bool subscriptions map[string]notificationReceiver
respLock sync.RWMutex respLock sync.RWMutex
respChannels map[uint64]chan *neorpc.Response respChannels map[uint64]chan *neorpc.Response
} }
// notificationReceiver is a server events receiver. It stores desired notifications ID
// and filter and a channel used to receive matching notifications.
type notificationReceiver struct {
typ neorpc.EventID
filter interface{}
ch chan<- Notification
}
// EventID implements neorpc.Comparator interface and returns notification ID.
func (r notificationReceiver) EventID() neorpc.EventID {
return r.typ
}
// Filter implements neorpc.Comparator interface and returns notification filter.
func (r notificationReceiver) Filter() interface{} {
return r.filter
}
// Notification represents a server-generated notification for client subscriptions. // Notification represents a server-generated notification for client subscriptions.
// Value can be one of *block.Block, *state.AppExecResult, *state.ContainedNotificationEvent // Value can be one of *block.Block, *state.AppExecResult, *state.ContainedNotificationEvent
// *transaction.Transaction or *subscriptions.NotaryRequestEvent based on Type. // *transaction.Transaction or *subscriptions.NotaryRequestEvent based on Type.
@ -59,6 +78,17 @@ type Notification struct {
Value interface{} Value interface{}
} }
// EventID implements Container interface and returns notification ID.
func (n Notification) EventID() neorpc.EventID {
return n.Type
}
// EventPayload implements Container interface and returns notification
// object.
func (n Notification) EventPayload() interface{} {
return n.Value
}
// requestResponse is a combined type for request and response since we can get // requestResponse is a combined type for request and response since we can get
// any of them here. // any of them here.
type requestResponse struct { type requestResponse struct {
@ -107,7 +137,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error
closeCalled: *atomic.NewBool(false), closeCalled: *atomic.NewBool(false),
respChannels: make(map[uint64]chan *neorpc.Response), respChannels: make(map[uint64]chan *neorpc.Response),
requests: make(chan *neorpc.Request), requests: make(chan *neorpc.Request),
subscriptions: make(map[string]bool), subscriptions: make(map[string]notificationReceiver),
} }
err = initClient(ctx, &wsc.Client, endpoint, opts) err = initClient(ctx, &wsc.Client, endpoint, opts)
@ -205,7 +235,16 @@ readloop:
break readloop break readloop
} }
} }
c.Notifications <- Notification{Type: event, Value: val} ok := make(map[chan<- Notification]bool)
c.subscriptionsLock.RLock()
for _, rcvr := range c.subscriptions {
ntf := Notification{Type: event, Value: val}
if (rpcevent.Matches(rcvr, ntf) || event == neorpc.MissedEventID /*missed event must be delivered to each receiver*/) && !ok[rcvr.ch] {
ok[rcvr.ch] = true // strictly one notification per channel
rcvr.ch <- ntf // this will block other receivers
}
}
c.subscriptionsLock.RUnlock()
} else if rr.ID != nil && (rr.Error != nil || rr.Result != nil) { } else if rr.ID != nil && (rr.Error != nil || rr.Result != nil) {
id, err := strconv.ParseUint(string(rr.ID), 10, 64) id, err := strconv.ParseUint(string(rr.ID), 10, 64)
if err != nil { if err != nil {
@ -317,7 +356,7 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) {
} }
} }
func (c *WSClient) performSubscription(params []interface{}) (string, error) { func (c *WSClient) performSubscription(params []interface{}, rcvr notificationReceiver) (string, error) {
var resp string var resp string
if err := c.performRequest("subscribe", params, &resp); err != nil { if err := c.performRequest("subscribe", params, &resp); err != nil {
@ -327,7 +366,7 @@ func (c *WSClient) performSubscription(params []interface{}) (string, error) {
c.subscriptionsLock.Lock() c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock() defer c.subscriptionsLock.Unlock()
c.subscriptions[resp] = true c.subscriptions[resp] = rcvr
return resp, nil return resp, nil
} }
@ -337,7 +376,7 @@ func (c *WSClient) performUnsubscription(id string) error {
c.subscriptionsLock.Lock() c.subscriptionsLock.Lock()
defer c.subscriptionsLock.Unlock() defer c.subscriptionsLock.Unlock()
if !c.subscriptions[id] { if _, ok := c.subscriptions[id]; !ok {
return errors.New("no subscription with this ID") return errors.New("no subscription with this ID")
} }
if err := c.performRequest("unsubscribe", []interface{}{id}, &resp); err != nil { if err := c.performRequest("unsubscribe", []interface{}{id}, &resp); err != nil {
@ -354,22 +393,52 @@ func (c *WSClient) performUnsubscription(id string) error {
// of the client. It can be filtered by primary consensus node index, nil value doesn't // of the client. It can be filtered by primary consensus node index, nil value doesn't
// add any filters. // add any filters.
func (c *WSClient) SubscribeForNewBlocks(primary *int) (string, error) { func (c *WSClient) SubscribeForNewBlocks(primary *int) (string, error) {
return c.SubscribeForNewBlocksWithChan(primary, c.Notifications)
}
// SubscribeForNewBlocksWithChan registers provided channel as a receiver for the
// specified new blocks notifications. The receiver channel must be properly read and
// drained after usage in order not to block other notification receivers.
// See SubscribeForNewBlocks for parameter details.
func (c *WSClient) SubscribeForNewBlocksWithChan(primary *int, rcvrCh chan<- Notification) (string, error) {
params := []interface{}{"block_added"} params := []interface{}{"block_added"}
var flt *neorpc.BlockFilter
if primary != nil { if primary != nil {
params = append(params, neorpc.BlockFilter{Primary: *primary}) flt = &neorpc.BlockFilter{Primary: *primary}
params = append(params, *flt)
} }
return c.performSubscription(params) rcvr := notificationReceiver{
typ: neorpc.BlockEventID,
filter: flt,
ch: rcvrCh,
}
return c.performSubscription(params, rcvr)
} }
// SubscribeForNewTransactions adds subscription for new transaction events to // SubscribeForNewTransactions adds subscription for new transaction events to
// this instance of the client. It can be filtered by the sender and/or the signer, nil // this instance of the client. It can be filtered by the sender and/or the signer, nil
// value is treated as missing filter. // value is treated as missing filter.
func (c *WSClient) SubscribeForNewTransactions(sender *util.Uint160, signer *util.Uint160) (string, error) { func (c *WSClient) SubscribeForNewTransactions(sender *util.Uint160, signer *util.Uint160) (string, error) {
return c.SubscribeForNewTransactionsWithChan(sender, signer, c.Notifications)
}
// SubscribeForNewTransactionsWithChan registers provided channel as a receiver
// for the specified new transactions notifications. The receiver channel must be
// properly read and drained after usage in order not to block other notification
// receivers. See SubscribeForNewTransactions for parameter details.
func (c *WSClient) SubscribeForNewTransactionsWithChan(sender *util.Uint160, signer *util.Uint160, rcvrCh chan<- Notification) (string, error) {
params := []interface{}{"transaction_added"} params := []interface{}{"transaction_added"}
var flt *neorpc.TxFilter
if sender != nil || signer != nil { if sender != nil || signer != nil {
params = append(params, neorpc.TxFilter{Sender: sender, Signer: signer}) flt = &neorpc.TxFilter{Sender: sender, Signer: signer}
params = append(params, *flt)
} }
return c.performSubscription(params) rcvr := notificationReceiver{
typ: neorpc.TransactionEventID,
filter: flt,
ch: rcvrCh,
}
return c.performSubscription(params, rcvr)
} }
// SubscribeForExecutionNotifications adds subscription for notifications // SubscribeForExecutionNotifications adds subscription for notifications
@ -377,11 +446,26 @@ func (c *WSClient) SubscribeForNewTransactions(sender *util.Uint160, signer *uti
// filtered by the contract's hash (that emits notifications), nil value puts no such // filtered by the contract's hash (that emits notifications), nil value puts no such
// restrictions. // restrictions.
func (c *WSClient) SubscribeForExecutionNotifications(contract *util.Uint160, name *string) (string, error) { func (c *WSClient) SubscribeForExecutionNotifications(contract *util.Uint160, name *string) (string, error) {
return c.SubscribeForExecutionNotificationsWithChan(contract, name, c.Notifications)
}
// SubscribeForExecutionNotificationsWithChan registers provided channel as a
// receiver for the specified execution events. The receiver channel must be
// properly read and drained after usage in order not to block other notification
// receivers. See SubscribeForExecutionNotifications for parameter details.
func (c *WSClient) SubscribeForExecutionNotificationsWithChan(contract *util.Uint160, name *string, rcvrCh chan<- Notification) (string, error) {
params := []interface{}{"notification_from_execution"} params := []interface{}{"notification_from_execution"}
var flt *neorpc.NotificationFilter
if contract != nil || name != nil { if contract != nil || name != nil {
params = append(params, neorpc.NotificationFilter{Contract: contract, Name: name}) flt = &neorpc.NotificationFilter{Contract: contract, Name: name}
params = append(params, *flt)
} }
return c.performSubscription(params) rcvr := notificationReceiver{
typ: neorpc.NotificationEventID,
filter: flt,
ch: rcvrCh,
}
return c.performSubscription(params, rcvr)
} }
// SubscribeForTransactionExecutions adds subscription for application execution // SubscribeForTransactionExecutions adds subscription for application execution
@ -389,14 +473,29 @@ func (c *WSClient) SubscribeForExecutionNotifications(contract *util.Uint160, na
// be filtered by state (HALT/FAULT) to check for successful or failing // be filtered by state (HALT/FAULT) to check for successful or failing
// transactions, nil value means no filtering. // transactions, nil value means no filtering.
func (c *WSClient) SubscribeForTransactionExecutions(state *string) (string, error) { func (c *WSClient) SubscribeForTransactionExecutions(state *string) (string, error) {
return c.SubscribeForTransactionExecutionsWithChan(state, c.Notifications)
}
// SubscribeForTransactionExecutionsWithChan registers provided channel as a
// receiver for the specified execution notifications. The receiver channel must be
// properly read and drained after usage in order not to block other notification
// receivers. See SubscribeForTransactionExecutions for parameter details.
func (c *WSClient) SubscribeForTransactionExecutionsWithChan(state *string, rcvrCh chan<- Notification) (string, error) {
params := []interface{}{"transaction_executed"} params := []interface{}{"transaction_executed"}
var flt *neorpc.ExecutionFilter
if state != nil { if state != nil {
if *state != "HALT" && *state != "FAULT" { if *state != "HALT" && *state != "FAULT" {
return "", errors.New("bad state parameter") return "", errors.New("bad state parameter")
} }
params = append(params, neorpc.ExecutionFilter{State: *state}) flt = &neorpc.ExecutionFilter{State: *state}
params = append(params, *flt)
} }
return c.performSubscription(params) rcvr := notificationReceiver{
typ: neorpc.ExecutionEventID,
filter: flt,
ch: rcvrCh,
}
return c.performSubscription(params, rcvr)
} }
// SubscribeForNotaryRequests adds subscription for notary request payloads // SubscribeForNotaryRequests adds subscription for notary request payloads
@ -404,11 +503,26 @@ func (c *WSClient) SubscribeForTransactionExecutions(state *string) (string, err
// request sender's hash, or main tx signer's hash, nil value puts no such // request sender's hash, or main tx signer's hash, nil value puts no such
// restrictions. // restrictions.
func (c *WSClient) SubscribeForNotaryRequests(sender *util.Uint160, mainSigner *util.Uint160) (string, error) { func (c *WSClient) SubscribeForNotaryRequests(sender *util.Uint160, mainSigner *util.Uint160) (string, error) {
return c.SubscribeForNotaryRequestsWithChan(sender, mainSigner, c.Notifications)
}
// SubscribeForNotaryRequestsWithChan registers provided channel as a receiver
// for the specified notary requests notifications. The receiver channel must be
// properly read and drained after usage in order not to block other notification
// receivers. See SubscribeForNotaryRequests for parameter details.
func (c *WSClient) SubscribeForNotaryRequestsWithChan(sender *util.Uint160, mainSigner *util.Uint160, rcvrCh chan<- Notification) (string, error) {
params := []interface{}{"notary_request_event"} params := []interface{}{"notary_request_event"}
var flt *neorpc.TxFilter
if sender != nil { if sender != nil {
params = append(params, neorpc.TxFilter{Sender: sender, Signer: mainSigner}) flt = &neorpc.TxFilter{Sender: sender, Signer: mainSigner}
params = append(params, *flt)
} }
return c.performSubscription(params) rcvr := notificationReceiver{
typ: neorpc.NotaryRequestEventID,
filter: flt,
ch: rcvrCh,
}
return c.performSubscription(params, rcvr)
} }
// Unsubscribe removes subscription for the given event stream. // Unsubscribe removes subscription for the given event stream.

View file

@ -32,19 +32,32 @@ func TestWSClientClose(t *testing.T) {
} }
func TestWSClientSubscription(t *testing.T) { func TestWSClientSubscription(t *testing.T) {
ch := make(chan Notification)
var cases = map[string]func(*WSClient) (string, error){ var cases = map[string]func(*WSClient) (string, error){
"blocks": func(wsc *WSClient) (string, error) { "blocks": func(wsc *WSClient) (string, error) {
return wsc.SubscribeForNewBlocks(nil) return wsc.SubscribeForNewBlocks(nil)
}, },
"blocks_with_custom_ch": func(wsc *WSClient) (string, error) {
return wsc.SubscribeForNewBlocksWithChan(nil, ch)
},
"transactions": func(wsc *WSClient) (string, error) { "transactions": func(wsc *WSClient) (string, error) {
return wsc.SubscribeForNewTransactions(nil, nil) return wsc.SubscribeForNewTransactions(nil, nil)
}, },
"transactions_with_custom_ch": func(wsc *WSClient) (string, error) {
return wsc.SubscribeForNewTransactionsWithChan(nil, nil, ch)
},
"notifications": func(wsc *WSClient) (string, error) { "notifications": func(wsc *WSClient) (string, error) {
return wsc.SubscribeForExecutionNotifications(nil, nil) return wsc.SubscribeForExecutionNotifications(nil, nil)
}, },
"notifications_with_custom_ch": func(wsc *WSClient) (string, error) {
return wsc.SubscribeForExecutionNotificationsWithChan(nil, nil, ch)
},
"executions": func(wsc *WSClient) (string, error) { "executions": func(wsc *WSClient) (string, error) {
return wsc.SubscribeForTransactionExecutions(nil) return wsc.SubscribeForTransactionExecutions(nil)
}, },
"executions_with_custom_ch": func(wsc *WSClient) (string, error) {
return wsc.SubscribeForTransactionExecutionsWithChan(nil, ch)
},
} }
t.Run("good", func(t *testing.T) { t.Run("good", func(t *testing.T) {
for name, f := range cases { for name, f := range cases {
@ -83,13 +96,13 @@ func TestWSClientUnsubscription(t *testing.T) {
var cases = map[string]responseCheck{ var cases = map[string]responseCheck{
"good": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) { "good": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) {
// We can't really subscribe using this stub server, so set up wsc internals. // We can't really subscribe using this stub server, so set up wsc internals.
wsc.subscriptions["0"] = true wsc.subscriptions["0"] = notificationReceiver{}
err := wsc.Unsubscribe("0") err := wsc.Unsubscribe("0")
require.NoError(t, err) require.NoError(t, err)
}}, }},
"all": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) { "all": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) {
// We can't really subscribe using this stub server, so set up wsc internals. // We can't really subscribe using this stub server, so set up wsc internals.
wsc.subscriptions["0"] = true wsc.subscriptions["0"] = notificationReceiver{}
err := wsc.UnsubscribeAll() err := wsc.UnsubscribeAll()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(wsc.subscriptions)) require.Equal(t, 0, len(wsc.subscriptions))
@ -100,13 +113,13 @@ func TestWSClientUnsubscription(t *testing.T) {
}}, }},
"error returned": {`{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`, func(t *testing.T, wsc *WSClient) { "error returned": {`{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`, func(t *testing.T, wsc *WSClient) {
// We can't really subscribe using this stub server, so set up wsc internals. // We can't really subscribe using this stub server, so set up wsc internals.
wsc.subscriptions["0"] = true wsc.subscriptions["0"] = notificationReceiver{}
err := wsc.Unsubscribe("0") err := wsc.Unsubscribe("0")
require.Error(t, err) require.Error(t, err)
}}, }},
"false returned": {`{"jsonrpc": "2.0", "id": 1, "result": false}`, func(t *testing.T, wsc *WSClient) { "false returned": {`{"jsonrpc": "2.0", "id": 1, "result": false}`, func(t *testing.T, wsc *WSClient) {
// We can't really subscribe using this stub server, so set up wsc internals. // We can't really subscribe using this stub server, so set up wsc internals.
wsc.subscriptions["0"] = true wsc.subscriptions["0"] = notificationReceiver{}
err := wsc.Unsubscribe("0") err := wsc.Unsubscribe("0")
require.Error(t, err) require.Error(t, err)
}}, }},
@ -151,10 +164,18 @@ func TestWSClientEvents(t *testing.T) {
} }
})) }))
t.Run("default ntf channel", func(t *testing.T) {
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID wsc.getNextRequestID = getTestRequestID
wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually. wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually.
// Our server mock is restricted, so perform subscriptions manually with default notifications channel.
wsc.subscriptionsLock.Lock()
wsc.subscriptions["0"] = notificationReceiver{typ: neorpc.BlockEventID, ch: wsc.Notifications}
wsc.subscriptions["1"] = notificationReceiver{typ: neorpc.ExecutionEventID, ch: wsc.Notifications}
wsc.subscriptions["2"] = notificationReceiver{typ: neorpc.NotificationEventID, ch: wsc.Notifications}
// MissedEvent must be delivered without subscription.
wsc.subscriptionsLock.Unlock()
wsc.cache.network = netmode.UnitTestNet wsc.cache.network = netmode.UnitTestNet
for range events { for range events {
select { select {
@ -171,6 +192,76 @@ func TestWSClientEvents(t *testing.T) {
} }
// Connection closed by server. // Connection closed by server.
require.False(t, ok) require.False(t, ok)
})
t.Run("multiple ntf channels", func(t *testing.T) {
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
wsc.cacheLock.Lock()
wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually.
wsc.cache.network = netmode.UnitTestNet
wsc.cacheLock.Unlock()
// Our server mock is restricted, so perform subscriptions manually with default notifications channel.
ch1 := make(chan Notification)
ch2 := make(chan Notification)
ch3 := make(chan Notification)
wsc.subscriptionsLock.Lock()
wsc.subscriptions["0"] = notificationReceiver{typ: neorpc.BlockEventID, ch: wsc.Notifications}
wsc.subscriptions["1"] = notificationReceiver{typ: neorpc.ExecutionEventID, ch: wsc.Notifications}
wsc.subscriptions["2"] = notificationReceiver{typ: neorpc.NotificationEventID, ch: wsc.Notifications}
wsc.subscriptions["3"] = notificationReceiver{typ: neorpc.BlockEventID, ch: ch1}
wsc.subscriptions["4"] = notificationReceiver{typ: neorpc.NotificationEventID, ch: ch2}
wsc.subscriptions["5"] = notificationReceiver{typ: neorpc.NotificationEventID, ch: ch2} // check duplicating subscriptions
wsc.subscriptions["6"] = notificationReceiver{typ: neorpc.ExecutionEventID, filter: neorpc.ExecutionFilter{State: "HALT"}, ch: ch2}
wsc.subscriptions["7"] = notificationReceiver{typ: neorpc.ExecutionEventID, filter: neorpc.ExecutionFilter{State: "FAULT"}, ch: ch3}
// MissedEvent must be delivered without subscription.
wsc.subscriptionsLock.Unlock()
var (
defaultChCnt int
ch1Cnt int
ch2Cnt int
ch3Cnt int
expectedDefaultCnCount = len(events)
expectedCh1Cnt = 1 + 1 // Block event + Missed event
expectedCh2Cnt = 1 + 2 + 1 // Notification event + 2 Execution events + Missed event
expectedCh3Cnt = 1 // Missed event
ntf Notification
)
for i := 0; i < expectedDefaultCnCount+expectedCh1Cnt+expectedCh2Cnt+expectedCh3Cnt; i++ {
select {
case ntf, ok = <-wsc.Notifications:
defaultChCnt++
case ntf, ok = <-ch1:
require.True(t, ntf.Type == neorpc.BlockEventID || ntf.Type == neorpc.MissedEventID, ntf.Type)
ch1Cnt++
case ntf, ok = <-ch2:
require.True(t, ntf.Type == neorpc.NotificationEventID || ntf.Type == neorpc.MissedEventID || ntf.Type == neorpc.ExecutionEventID)
ch2Cnt++
case ntf, ok = <-ch3:
require.True(t, ntf.Type == neorpc.MissedEventID)
ch3Cnt++
case <-time.After(time.Second):
t.Fatal("timeout waiting for event")
}
require.True(t, ok)
}
select {
case _, ok = <-wsc.Notifications:
case _, ok = <-ch1:
case _, ok = <-ch2:
case _, ok = <-ch3:
case <-time.After(time.Second):
t.Fatal("timeout waiting for event")
}
// Connection closed by server.
require.False(t, ok)
require.Equal(t, expectedDefaultCnCount, defaultChCnt)
require.Equal(t, expectedCh1Cnt, ch1Cnt)
require.Equal(t, expectedCh2Cnt, ch2Cnt)
require.Equal(t, expectedCh3Cnt, ch3Cnt)
})
} }
func TestWSExecutionVMStateCheck(t *testing.T) { func TestWSExecutionVMStateCheck(t *testing.T) {

View file

@ -40,6 +40,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/neorpc" "github.com/nspcc-dev/neo-go/pkg/neorpc"
"github.com/nspcc-dev/neo-go/pkg/neorpc/result" "github.com/nspcc-dev/neo-go/pkg/neorpc/result"
"github.com/nspcc-dev/neo-go/pkg/neorpc/rpcevent"
"github.com/nspcc-dev/neo-go/pkg/network" "github.com/nspcc-dev/neo-go/pkg/network"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/services/oracle/broadcaster" "github.com/nspcc-dev/neo-go/pkg/services/oracle/broadcaster"
@ -2593,7 +2594,7 @@ chloop:
continue continue
} }
for i := range sub.feeds { for i := range sub.feeds {
if sub.feeds[i].Matches(&resp) { if rpcevent.Matches(sub.feeds[i], &resp) {
if msg == nil { if msg == nil {
b, err = json.Marshal(resp) b, err = json.Marshal(resp)
if err != nil { if err != nil {

View file

@ -2,11 +2,7 @@ package rpcsrv
import ( import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/neorpc" "github.com/nspcc-dev/neo-go/pkg/neorpc"
"github.com/nspcc-dev/neo-go/pkg/neorpc/result"
"go.uber.org/atomic" "go.uber.org/atomic"
) )
@ -22,12 +18,23 @@ type (
// that's not for long. // that's not for long.
feeds [maxFeeds]feed feeds [maxFeeds]feed
} }
// feed stores subscriber's desired event ID with filter.
feed struct { feed struct {
event neorpc.EventID event neorpc.EventID
filter interface{} filter interface{}
} }
) )
// EventID implements neorpc.EventComparator interface and returns notification ID.
func (f feed) EventID() neorpc.EventID {
return f.event
}
// Filter implements neorpc.EventComparator interface and returns notification filter.
func (f feed) Filter() interface{} {
return f.filter
}
const ( const (
// Maximum number of subscriptions per one client. // Maximum number of subscriptions per one client.
maxFeeds = 16 maxFeeds = 16
@ -42,59 +49,3 @@ const (
// a lot in terms of memory used. // a lot in terms of memory used.
notificationBufSize = 1024 notificationBufSize = 1024
) )
func (f *feed) Matches(r *neorpc.Notification) bool {
if r.Event != f.event {
return false
}
if f.filter == nil {
return true
}
switch f.event {
case neorpc.BlockEventID:
filt := f.filter.(neorpc.BlockFilter)
b := r.Payload[0].(*block.Block)
return int(b.PrimaryIndex) == filt.Primary
case neorpc.TransactionEventID:
filt := f.filter.(neorpc.TxFilter)
tx := r.Payload[0].(*transaction.Transaction)
senderOK := filt.Sender == nil || tx.Sender().Equals(*filt.Sender)
signerOK := true
if filt.Signer != nil {
signerOK = false
for i := range tx.Signers {
if tx.Signers[i].Account.Equals(*filt.Signer) {
signerOK = true
break
}
}
}
return senderOK && signerOK
case neorpc.NotificationEventID:
filt := f.filter.(neorpc.NotificationFilter)
notification := r.Payload[0].(*state.ContainedNotificationEvent)
hashOk := filt.Contract == nil || notification.ScriptHash.Equals(*filt.Contract)
nameOk := filt.Name == nil || notification.Name == *filt.Name
return hashOk && nameOk
case neorpc.ExecutionEventID:
filt := f.filter.(neorpc.ExecutionFilter)
applog := r.Payload[0].(*state.AppExecResult)
return applog.VMState.String() == filt.State
case neorpc.NotaryRequestEventID:
filt := f.filter.(neorpc.TxFilter)
req := r.Payload[0].(*result.NotaryRequestEvent)
senderOk := filt.Sender == nil || req.NotaryRequest.FallbackTransaction.Signers[1].Account == *filt.Sender
signerOK := true
if filt.Signer != nil {
signerOK = false
for _, signer := range req.NotaryRequest.MainTransaction.Signers {
if signer.Account.Equals(*filt.Signer) {
signerOK = true
break
}
}
}
return senderOk && signerOK
}
return false
}