diff --git a/pkg/rpcclient/rpc_test.go b/pkg/rpcclient/rpc_test.go index 3cf3b9f9b..c60e77ec5 100644 --- a/pkg/rpcclient/rpc_test.go +++ b/pkg/rpcclient/rpc_test.go @@ -1803,7 +1803,7 @@ func TestRPCClients(t *testing.T) { }) t.Run("WSClient", func(t *testing.T) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { - wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts) + wsc, err := NewWS(ctx, httpURLtoWS(endpoint), WSOptions{Options: opts}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index 49443f408..b77e03f4b 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -38,7 +38,11 @@ import ( // will make WSClient wait for the channel reader to get the event and while // it waits every other messages (subscription-related or request replies) // will be blocked. This also means that subscription channel must be properly -// drained after unsubscription. +// drained after unsubscription. If CloseNotificationChannelIfFull option is on +// then the receiver channel will be closed immediately in case if a subsequent +// notification can't be sent to it, which means WSClient's operations are +// unblocking in this mode. No unsubscription is performed in this case, so it's +// still the user responsibility to unsubscribe. // // Any received subscription items (blocks/transactions/nofitications) are passed // via pointers for efficiency, but the actual structures MUST NOT be changed, as @@ -47,7 +51,9 @@ import ( // only sent once per channel. The receiver channel will be closed by the WSClient // immediately after MissedEvent is received from the server; no unsubscription // is performed in this case, so it's the user responsibility to unsubscribe. It -// will also be closed on disconnection from server. +// will also be closed on disconnection from server or on situation when it's +// impossible to send a subsequent notification to the subscriber's channel and +// CloseNotificationChannelIfFull option is on. type WSClient struct { Client // Notifications is a channel that is used to send events received from @@ -66,6 +72,7 @@ type WSClient struct { Notifications chan Notification ws *websocket.Conn + wsOpts WSOptions done chan struct{} requests chan *neorpc.Request shutdown chan struct{} @@ -86,6 +93,21 @@ type WSClient struct { respChannels map[uint64]chan *neorpc.Response } +// WSOptions defines options for the web-socket RPC client. It contains a +// set of options for the underlying standard RPC client as far as +// WSClient-specific options. See Options documentation for more details. +type WSOptions struct { + Options + // CloseNotificationChannelIfFull allows WSClient to close a subscriber's + // receive channel in case if the channel isn't read properly and no more + // events can be pushed to it. This option, if set, allows to avoid WSClient + // blocking on a subsequent notification dispatch. However, if enabled, the + // corresponding subscription is kept even after receiver's channel closing, + // thus it's still the caller's duty to call Unsubscribe() for this + // subscription. + CloseNotificationChannelIfFull bool +} + // notificationReceiver is an interface aimed to provide WS subscriber functionality // for different types of subscriptions. type notificationReceiver interface { @@ -94,8 +116,11 @@ type notificationReceiver interface { // Receiver returns notification receiver channel. Receiver() any // TrySend checks whether notification passes receiver filter and sends it - // to the underlying channel if so. - TrySend(ntf Notification) bool + // to the underlying channel if so. It is performed under subscriptions lock + // taken. nonBlocking denotes whether the receiving operation shouldn't block + // the client's operation. It returns whether notification matches the filter + // and whether the receiver channel is overflown. + TrySend(ntf Notification, nonBlocking bool) (bool, bool) // Close closes underlying receiver channel. Close() } @@ -125,12 +150,21 @@ func (r *blockReceiver) Receiver() any { } // TrySend implements notificationReceiver interface. -func (r *blockReceiver) TrySend(ntf Notification) bool { +func (r *blockReceiver) TrySend(ntf Notification, nonBlocking bool) (bool, bool) { if rpcevent.Matches(r, ntf) { - r.ch <- ntf.Value.(*block.Block) - return true + if nonBlocking { + select { + case r.ch <- ntf.Value.(*block.Block): + default: + return true, true + } + } else { + r.ch <- ntf.Value.(*block.Block) + } + + return true, false } - return false + return false, false } // Close implements notificationReceiver interface. @@ -163,12 +197,21 @@ func (r *txReceiver) Receiver() any { } // TrySend implements notificationReceiver interface. -func (r *txReceiver) TrySend(ntf Notification) bool { +func (r *txReceiver) TrySend(ntf Notification, nonBlocking bool) (bool, bool) { if rpcevent.Matches(r, ntf) { - r.ch <- ntf.Value.(*transaction.Transaction) - return true + if nonBlocking { + select { + case r.ch <- ntf.Value.(*transaction.Transaction): + default: + return true, true + } + } else { + r.ch <- ntf.Value.(*transaction.Transaction) + } + + return true, false } - return false + return false, false } // Close implements notificationReceiver interface. @@ -201,12 +244,21 @@ func (r *executionNotificationReceiver) Receiver() any { } // TrySend implements notificationReceiver interface. -func (r *executionNotificationReceiver) TrySend(ntf Notification) bool { +func (r *executionNotificationReceiver) TrySend(ntf Notification, nonBlocking bool) (bool, bool) { if rpcevent.Matches(r, ntf) { - r.ch <- ntf.Value.(*state.ContainedNotificationEvent) - return true + if nonBlocking { + select { + case r.ch <- ntf.Value.(*state.ContainedNotificationEvent): + default: + return true, true + } + } else { + r.ch <- ntf.Value.(*state.ContainedNotificationEvent) + } + + return true, false } - return false + return false, false } // Close implements notificationReceiver interface. @@ -239,12 +291,21 @@ func (r *executionReceiver) Receiver() any { } // TrySend implements notificationReceiver interface. -func (r *executionReceiver) TrySend(ntf Notification) bool { +func (r *executionReceiver) TrySend(ntf Notification, nonBlocking bool) (bool, bool) { if rpcevent.Matches(r, ntf) { - r.ch <- ntf.Value.(*state.AppExecResult) - return true + if nonBlocking { + select { + case r.ch <- ntf.Value.(*state.AppExecResult): + default: + return true, true + } + } else { + r.ch <- ntf.Value.(*state.AppExecResult) + } + + return true, false } - return false + return false, false } // Close implements notificationReceiver interface. @@ -277,12 +338,21 @@ func (r *notaryRequestReceiver) Receiver() any { } // TrySend implements notificationReceiver interface. -func (r *notaryRequestReceiver) TrySend(ntf Notification) bool { +func (r *notaryRequestReceiver) TrySend(ntf Notification, nonBlocking bool) (bool, bool) { if rpcevent.Matches(r, ntf) { - r.ch <- ntf.Value.(*result.NotaryRequestEvent) - return true + if nonBlocking { + select { + case r.ch <- ntf.Value.(*result.NotaryRequestEvent): + default: + return true, true + } + } else { + r.ch <- ntf.Value.(*result.NotaryRequestEvent) + } + + return true, false } - return false + return false, false } // Close implements notificationReceiver interface. @@ -316,12 +386,21 @@ func (r *naiveReceiver) Receiver() any { } // TrySend implements notificationReceiver interface. -func (r *naiveReceiver) TrySend(ntf Notification) bool { +func (r *naiveReceiver) TrySend(ntf Notification, nonBlocking bool) (bool, bool) { if rpcevent.Matches(r, ntf) { - r.ch <- ntf - return true + if nonBlocking { + select { + case r.ch <- ntf: + default: + return true, true + } + } else { + r.ch <- ntf + } + + return true, false } - return false + return false, false } // Close implements notificationReceiver interface. @@ -382,7 +461,7 @@ var errConnClosedByUser = errors.New("connection closed by user") // connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`. // You should call Init method to initialize the network magic the client is // operating on. -func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) { +func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, error) { dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout} ws, resp, err := dialer.DialContext(ctx, endpoint, nil) if resp != nil && resp.Body != nil { // Can be non-nil even with error returned. @@ -405,6 +484,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error Notifications: make(chan Notification), ws: ws, + wsOpts: opts, shutdown: make(chan struct{}), done: make(chan struct{}), closeCalled: *atomic.NewBool(false), @@ -414,7 +494,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error receivers: make(map[any][]string), } - err = initClient(ctx, &wsc.Client, endpoint, opts) + err = initClient(ctx, &wsc.Client, endpoint, opts.Options) if err != nil { return nil, err } @@ -542,16 +622,30 @@ readloop: c.respLock.Unlock() c.subscriptionsLock.Lock() for rcvrCh, ids := range c.receivers { - rcvr := c.subscriptions[ids[0]] + c.dropSubCh(rcvrCh, ids[0], true) + } + c.subscriptionsLock.Unlock() + c.Client.ctxCancel() +} + +// dropSubCh closes corresponding subscriber's channel and removes it from the +// receivers map. If the channel belongs to a naive subscriber then it will be +// closed manually without call to Close(). The channel is still being kept in +// the subscribers map as technically the server-side subscription still exists +// and the user is responsible for unsubscription. This method must be called +// under subscriptionsLock taken. It's the caller's duty to ensure dropSubCh +// will be called once per channel, otherwise panic will occur. +func (c *WSClient) dropSubCh(rcvrCh any, id string, ignoreCloseNotificationChannelIfFull bool) { + if ignoreCloseNotificationChannelIfFull || c.wsOpts.CloseNotificationChannelIfFull { + rcvr := c.subscriptions[id] _, ok := rcvr.(*naiveReceiver) - if !ok { // naiveReceiver uses c.Notifications that is about to be closed below. - c.subscriptions[ids[0]].Close() + if ok { // naiveReceiver uses c.Notifications that should be handled separately. + close(c.Notifications) + } else { + c.subscriptions[id].Close() } delete(c.receivers, rcvrCh) } - c.subscriptionsLock.Unlock() - close(c.Notifications) - c.Client.ctxCancel() } func (c *WSClient) wsWriter() { @@ -604,15 +698,20 @@ func (c *WSClient) notifySubscribers(ntf Notification) { c.subscriptionsLock.Unlock() return } - c.subscriptionsLock.RLock() - for _, ids := range c.receivers { + c.subscriptionsLock.Lock() + for rcvrCh, ids := range c.receivers { for _, id := range ids { - if c.subscriptions[id].TrySend(ntf) { + ok, dropCh := c.subscriptions[id].TrySend(ntf, c.wsOpts.CloseNotificationChannelIfFull) + if dropCh { + c.dropSubCh(rcvrCh, id, false) + break // strictly single drop per channel + } + if ok { break // strictly one notification per channel } } } - c.subscriptionsLock.RUnlock() + c.subscriptionsLock.Unlock() } func (c *WSClient) unregisterRespChannel(id uint64) { diff --git a/pkg/rpcclient/wsclient_test.go b/pkg/rpcclient/wsclient_test.go index c3f30c894..ccd272e3b 100644 --- a/pkg/rpcclient/wsclient_test.go +++ b/pkg/rpcclient/wsclient_test.go @@ -31,7 +31,7 @@ import ( func TestWSClientClose(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID bCh := make(chan *block.Block) @@ -70,7 +70,7 @@ func TestWSClientSubscription(t *testing.T) { for name, f := range cases { t.Run(name, func(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) @@ -84,7 +84,7 @@ func TestWSClientSubscription(t *testing.T) { for name, f := range cases { t.Run(name, func(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) @@ -134,7 +134,7 @@ func TestWSClientUnsubscription(t *testing.T) { for name, rc := range cases { t.Run(name, func(t *testing.T) { srv := initTestServer(t, rc.response) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) @@ -170,7 +170,7 @@ func TestWSClientEvents(t *testing.T) { return } })) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID wsc.cacheLock.Lock() @@ -289,10 +289,92 @@ func TestWSClientEvents(t *testing.T) { require.False(t, ok) } +func TestWSClientNonBlockingEvents(t *testing.T) { + // Use buffered channel as a receiver to check it will be closed by WSClient + // after overflow if CloseNotificationChannelIfFull option is enabled. + const chCap = 3 + bCh := make(chan *block.Block, chCap) + + // Events from RPC server testchain. Require events len to be larger than chCap to reach + // subscriber's chanel overflow. + var events = []string{ + fmt.Sprintf(`{"jsonrpc":"2.0","method":"block_added","params":[%s]}`, b1Verbose), + fmt.Sprintf(`{"jsonrpc":"2.0","method":"block_added","params":[%s]}`, b1Verbose), + fmt.Sprintf(`{"jsonrpc":"2.0","method":"block_added","params":[%s]}`, b1Verbose), + fmt.Sprintf(`{"jsonrpc":"2.0","method":"block_added","params":[%s]}`, b1Verbose), + fmt.Sprintf(`{"jsonrpc":"2.0","method":"block_added","params":[%s]}`, b1Verbose), + } + require.True(t, chCap < len(events)) + + var blocksSent atomic.Bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/ws" && req.Method == "GET" { + var upgrader = websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + for _, event := range events { + err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) + require.NoError(t, err) + err = ws.WriteMessage(1, []byte(event)) + if err != nil { + break + } + } + blocksSent.Store(true) + ws.Close() + return + } + })) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{CloseNotificationChannelIfFull: true}) + require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID + wsc.cacheLock.Lock() + wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually. + wsc.cache.network = netmode.UnitTestNet + wsc.cacheLock.Unlock() + + // Our server mock is restricted, so perform subscriptions manually. + wsc.subscriptionsLock.Lock() + wsc.subscriptions["0"] = &blockReceiver{ch: bCh} + wsc.subscriptions["1"] = &blockReceiver{ch: bCh} + wsc.receivers[chan<- *block.Block(bCh)] = []string{"0", "1"} + wsc.subscriptionsLock.Unlock() + + // Check that events are sent to WSClient. + require.Eventually(t, func() bool { + return blocksSent.Load() + }, time.Second, 100*time.Millisecond) + + // Check that block receiver channel was removed from the receivers list due to overflow. + require.Eventually(t, func() bool { + wsc.subscriptionsLock.RLock() + defer wsc.subscriptionsLock.RUnlock() + return len(wsc.receivers) == 0 + }, 2*time.Second, 200*time.Millisecond) + + // Check that subscriptions are still there and waiting for the call to Unsubscribe() + // to be excluded from the subscriptions map. + wsc.subscriptionsLock.RLock() + require.True(t, len(wsc.subscriptions) == 2) + wsc.subscriptionsLock.RUnlock() + + // Check that receiver was closed after overflow. + for i := 0; i < chCap; i++ { + _, ok := <-bCh + require.True(t, ok) + } + select { + case _, ok := <-bCh: + require.False(t, ok) + default: + t.Fatal("channel wasn't closed by WSClient") + } +} + func TestWSExecutionVMStateCheck(t *testing.T) { // Will answer successfully if request slips through. srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) @@ -527,7 +609,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { ws.Close() } })) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID wsc.cache.network = netmode.UnitTestNet @@ -541,14 +623,14 @@ func TestNewWS(t *testing.T) { srv := initTestServer(t, "") t.Run("good", func(t *testing.T) { - c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) c.getNextRequestID = getTestRequestID c.cache.network = netmode.UnitTestNet require.NoError(t, c.Init()) }) t.Run("bad URL", func(t *testing.T) { - _, err := NewWS(context.TODO(), strings.TrimPrefix(srv.URL, "http://"), Options{}) + _, err := NewWS(context.TODO(), strings.TrimPrefix(srv.URL, "http://"), WSOptions{}) require.Error(t, err) }) } @@ -605,7 +687,7 @@ func TestWSConcurrentAccess(t *testing.T) { })) t.Cleanup(srv.Close) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) batchCount := 100 completed := atomic.NewInt32(0) @@ -649,7 +731,7 @@ func TestWSConcurrentAccess(t *testing.T) { func TestWSDoubleClose(t *testing.T) { srv := initTestServer(t, "") - c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) require.NotPanics(t, func() { @@ -661,7 +743,7 @@ func TestWSDoubleClose(t *testing.T) { func TestWS_RequestAfterClose(t *testing.T) { srv := initTestServer(t, "") - c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) c.Close() @@ -676,7 +758,7 @@ func TestWS_RequestAfterClose(t *testing.T) { func TestWSClient_ConnClosedError(t *testing.T) { t.Run("standard closing", func(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": 123}`) - c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) // Check client is working. @@ -692,7 +774,7 @@ func TestWSClient_ConnClosedError(t *testing.T) { t.Run("malformed request", func(t *testing.T) { srv := initTestServer(t, "") - c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) defaultMaxBlockSize := 262144 diff --git a/pkg/services/rpcsrv/client_test.go b/pkg/services/rpcsrv/client_test.go index af5a80cd1..73f36154f 100644 --- a/pkg/services/rpcsrv/client_test.go +++ b/pkg/services/rpcsrv/client_test.go @@ -2073,7 +2073,7 @@ func mkSubsClient(t *testing.T, rpcSrv *Server, httpSrv *httptest.Server, local icl, err = rpcclient.NewInternal(context.Background(), rpcSrv.RegisterLocal) } else { url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" - c, err = rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) + c, err = rpcclient.NewWS(context.Background(), url, rpcclient.WSOptions{}) } require.NoError(t, err) if local { @@ -2239,7 +2239,7 @@ func TestWSClientHandshakeError(t *testing.T) { defer rpcSrv.Shutdown() url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" - _, err := rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) + _, err := rpcclient.NewWS(context.Background(), url, rpcclient.WSOptions{}) require.ErrorContains(t, err, "websocket users limit reached") }