rpcclient: close subscriber channels on wsReader exit

The reader is about to exit and it will close legacy c.Notifications, but it
will leave subscription channels at the same time. This is wrong since these
channels will no longer receive any new events, game over.

Signed-off-by: Roman Khimov <roman@nspcc.ru>
This commit is contained in:
Roman Khimov 2023-04-18 19:36:57 +03:00 committed by Anna Shaleva
parent 08b273266b
commit 288dee8871
2 changed files with 35 additions and 21 deletions

View file

@ -46,7 +46,8 @@ import (
// subscriptions share the same receiver channel, then matching notification is // subscriptions share the same receiver channel, then matching notification is
// only sent once per channel. The receiver channel will be closed by the WSClient // only sent once per channel. The receiver channel will be closed by the WSClient
// immediately after MissedEvent is received from the server; no unsubscription // immediately after MissedEvent is received from the server; no unsubscription
// is performed in this case, so it's the user responsibility to unsubscribe. // is performed in this case, so it's the user responsibility to unsubscribe. It
// will also be closed on disconnection from server.
type WSClient struct { type WSClient struct {
Client Client
// Notifications is a channel that is used to send events received from // Notifications is a channel that is used to send events received from
@ -539,6 +540,16 @@ readloop:
} }
c.respChannels = nil c.respChannels = nil
c.respLock.Unlock() c.respLock.Unlock()
c.subscriptionsLock.Lock()
for rcvrCh, ids := range c.receivers {
rcvr := c.subscriptions[ids[0]]
_, ok := rcvr.(*naiveReceiver)
if !ok { // naiveReceiver uses c.Notifications that is about to be closed below.
c.subscriptions[ids[0]].Close()
}
delete(c.receivers, rcvrCh)
}
c.subscriptionsLock.Unlock()
close(c.Notifications) close(c.Notifications)
c.Client.ctxCancel() c.Client.ctxCancel()
} }

View file

@ -30,10 +30,17 @@ import (
) )
func TestWSClientClose(t *testing.T) { func TestWSClientClose(t *testing.T) {
srv := initTestServer(t, "") srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
bCh := make(chan *block.Block)
_, err = wsc.ReceiveBlocks(nil, bCh)
require.NoError(t, err)
wsc.Close() wsc.Close()
// Subscriber channel must be closed by server.
_, ok := <-bCh
require.False(t, ok)
} }
func TestWSClientSubscription(t *testing.T) { func TestWSClientSubscription(t *testing.T) {
@ -296,10 +303,6 @@ func TestWSExecutionVMStateCheck(t *testing.T) {
} }
func TestWSFilteredSubscriptions(t *testing.T) { func TestWSFilteredSubscriptions(t *testing.T) {
bCh := make(chan *block.Block)
txCh := make(chan *transaction.Transaction)
aerCh := make(chan *state.AppExecResult)
ntfCh := make(chan *state.ContainedNotificationEvent)
var cases = []struct { var cases = []struct {
name string name string
clientCode func(*testing.T, *WSClient) clientCode func(*testing.T, *WSClient)
@ -308,7 +311,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"blocks primary", {"blocks primary",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
primary := 3 primary := 3
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, bCh) _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Block))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -323,7 +326,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"blocks since", {"blocks since",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
var since uint32 = 3 var since uint32 = 3
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Since: &since}, bCh) _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Since: &since}, make(chan *block.Block))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -338,7 +341,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"blocks till", {"blocks till",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
var till uint32 = 3 var till uint32 = 3
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Till: &till}, bCh) _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Till: &till}, make(chan *block.Block))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -361,7 +364,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
Primary: &primary, Primary: &primary,
Since: &since, Since: &since,
Till: &till, Till: &till,
}, bCh) }, make(chan *block.Block))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -376,7 +379,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"transactions sender", {"transactions sender",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
sender := util.Uint160{1, 2, 3, 4, 5} sender := util.Uint160{1, 2, 3, 4, 5}
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender}, txCh) _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender}, make(chan *transaction.Transaction))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -390,7 +393,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"transactions signer", {"transactions signer",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
signer := util.Uint160{0, 42} signer := util.Uint160{0, 42}
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Signer: &signer}, txCh) _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Signer: &signer}, make(chan *transaction.Transaction))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -405,7 +408,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
sender := util.Uint160{1, 2, 3, 4, 5} sender := util.Uint160{1, 2, 3, 4, 5}
signer := util.Uint160{0, 42} signer := util.Uint160{0, 42}
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender, Signer: &signer}, txCh) _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender, Signer: &signer}, make(chan *transaction.Transaction))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -419,7 +422,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"notifications contract hash", {"notifications contract hash",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
contract := util.Uint160{1, 2, 3, 4, 5} contract := util.Uint160{1, 2, 3, 4, 5}
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract}, ntfCh) _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract}, make(chan *state.ContainedNotificationEvent))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -433,7 +436,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"notifications name", {"notifications name",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
name := "my_pretty_notification" name := "my_pretty_notification"
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &name}, ntfCh) _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &name}, make(chan *state.ContainedNotificationEvent))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -448,7 +451,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
contract := util.Uint160{1, 2, 3, 4, 5} contract := util.Uint160{1, 2, 3, 4, 5}
name := "my_pretty_notification" name := "my_pretty_notification"
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract, Name: &name}, ntfCh) _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract, Name: &name}, make(chan *state.ContainedNotificationEvent))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -461,8 +464,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
}, },
{"executions state", {"executions state",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
state := "FAULT" vmstate := "FAULT"
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &state}, aerCh) _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &vmstate}, make(chan *state.AppExecResult))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -476,7 +479,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"executions container", {"executions container",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
container := util.Uint256{1, 2, 3} container := util.Uint256{1, 2, 3}
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{Container: &container}, aerCh) _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{Container: &container}, make(chan *state.AppExecResult))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -489,9 +492,9 @@ func TestWSFilteredSubscriptions(t *testing.T) {
}, },
{"executions state and container", {"executions state and container",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
state := "FAULT" vmstate := "FAULT"
container := util.Uint256{1, 2, 3} container := util.Uint256{1, 2, 3}
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &state, Container: &container}, aerCh) _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &vmstate, Container: &container}, make(chan *state.AppExecResult))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {