From 6eaa76520f08441113768dec8918b9407a347f09 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Tue, 18 Apr 2023 19:36:57 +0300 Subject: [PATCH] rpcclient: close subscriber channels on wsReader exit The reader is about to exit and it will close legacy c.Notifications, but it will leave subscription channels at the same time. This is wrong since these channels will no longer receive any new events, game over. Signed-off-by: Roman Khimov --- pkg/rpcclient/wsclient.go | 13 +++++++++- pkg/rpcclient/wsclient_test.go | 43 ++++++++++++++++++---------------- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index ff1e21bd1..49443f408 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -46,7 +46,8 @@ import ( // subscriptions share the same receiver channel, then matching notification is // only sent once per channel. The receiver channel will be closed by the WSClient // immediately after MissedEvent is received from the server; no unsubscription -// is performed in this case, so it's the user responsibility to unsubscribe. +// is performed in this case, so it's the user responsibility to unsubscribe. It +// will also be closed on disconnection from server. type WSClient struct { Client // Notifications is a channel that is used to send events received from @@ -539,6 +540,16 @@ readloop: } c.respChannels = nil c.respLock.Unlock() + c.subscriptionsLock.Lock() + for rcvrCh, ids := range c.receivers { + rcvr := c.subscriptions[ids[0]] + _, ok := rcvr.(*naiveReceiver) + if !ok { // naiveReceiver uses c.Notifications that is about to be closed below. + c.subscriptions[ids[0]].Close() + } + delete(c.receivers, rcvrCh) + } + c.subscriptionsLock.Unlock() close(c.Notifications) c.Client.ctxCancel() } diff --git a/pkg/rpcclient/wsclient_test.go b/pkg/rpcclient/wsclient_test.go index 1d7f3f774..c3f30c894 100644 --- a/pkg/rpcclient/wsclient_test.go +++ b/pkg/rpcclient/wsclient_test.go @@ -30,10 +30,17 @@ import ( ) func TestWSClientClose(t *testing.T) { - srv := initTestServer(t, "") + srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID + bCh := make(chan *block.Block) + _, err = wsc.ReceiveBlocks(nil, bCh) + require.NoError(t, err) wsc.Close() + // Subscriber channel must be closed by server. + _, ok := <-bCh + require.False(t, ok) } func TestWSClientSubscription(t *testing.T) { @@ -296,10 +303,6 @@ func TestWSExecutionVMStateCheck(t *testing.T) { } func TestWSFilteredSubscriptions(t *testing.T) { - bCh := make(chan *block.Block) - txCh := make(chan *transaction.Transaction) - aerCh := make(chan *state.AppExecResult) - ntfCh := make(chan *state.ContainedNotificationEvent) var cases = []struct { name string clientCode func(*testing.T, *WSClient) @@ -308,7 +311,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"blocks primary", func(t *testing.T, wsc *WSClient) { primary := 3 - _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, bCh) + _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Block)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -323,7 +326,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"blocks since", func(t *testing.T, wsc *WSClient) { var since uint32 = 3 - _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Since: &since}, bCh) + _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Since: &since}, make(chan *block.Block)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -338,7 +341,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"blocks till", func(t *testing.T, wsc *WSClient) { var till uint32 = 3 - _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Till: &till}, bCh) + _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Till: &till}, make(chan *block.Block)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -361,7 +364,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { Primary: &primary, Since: &since, Till: &till, - }, bCh) + }, make(chan *block.Block)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -376,7 +379,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"transactions sender", func(t *testing.T, wsc *WSClient) { sender := util.Uint160{1, 2, 3, 4, 5} - _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender}, txCh) + _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender}, make(chan *transaction.Transaction)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -390,7 +393,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"transactions signer", func(t *testing.T, wsc *WSClient) { signer := util.Uint160{0, 42} - _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Signer: &signer}, txCh) + _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Signer: &signer}, make(chan *transaction.Transaction)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -405,7 +408,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { func(t *testing.T, wsc *WSClient) { sender := util.Uint160{1, 2, 3, 4, 5} signer := util.Uint160{0, 42} - _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender, Signer: &signer}, txCh) + _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender, Signer: &signer}, make(chan *transaction.Transaction)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -419,7 +422,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"notifications contract hash", func(t *testing.T, wsc *WSClient) { contract := util.Uint160{1, 2, 3, 4, 5} - _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract}, ntfCh) + _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract}, make(chan *state.ContainedNotificationEvent)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -433,7 +436,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"notifications name", func(t *testing.T, wsc *WSClient) { name := "my_pretty_notification" - _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &name}, ntfCh) + _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &name}, make(chan *state.ContainedNotificationEvent)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -448,7 +451,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { func(t *testing.T, wsc *WSClient) { contract := util.Uint160{1, 2, 3, 4, 5} name := "my_pretty_notification" - _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract, Name: &name}, ntfCh) + _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract, Name: &name}, make(chan *state.ContainedNotificationEvent)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -461,8 +464,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, {"executions state", func(t *testing.T, wsc *WSClient) { - state := "FAULT" - _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &state}, aerCh) + vmstate := "FAULT" + _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &vmstate}, make(chan *state.AppExecResult)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -476,7 +479,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"executions container", func(t *testing.T, wsc *WSClient) { container := util.Uint256{1, 2, 3} - _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{Container: &container}, aerCh) + _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{Container: &container}, make(chan *state.AppExecResult)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -489,9 +492,9 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, {"executions state and container", func(t *testing.T, wsc *WSClient) { - state := "FAULT" + vmstate := "FAULT" container := util.Uint256{1, 2, 3} - _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &state, Container: &container}, aerCh) + _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &vmstate, Container: &container}, make(chan *state.AppExecResult)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) {