Merge pull request #2980 from nspcc-dev/close-notification-chans-on-wsclient-disconnect

Close notification channels on wsclient disconnect
This commit is contained in:
Roman Khimov 2023-04-19 18:11:36 +03:00 committed by GitHub
commit 7b109586ca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 22 deletions

View file

@ -46,7 +46,8 @@ import (
// subscriptions share the same receiver channel, then matching notification is // subscriptions share the same receiver channel, then matching notification is
// only sent once per channel. The receiver channel will be closed by the WSClient // only sent once per channel. The receiver channel will be closed by the WSClient
// immediately after MissedEvent is received from the server; no unsubscription // immediately after MissedEvent is received from the server; no unsubscription
// is performed in this case, so it's the user responsibility to unsubscribe. // is performed in this case, so it's the user responsibility to unsubscribe. It
// will also be closed on disconnection from server.
type WSClient struct { type WSClient struct {
Client Client
// Notifications is a channel that is used to send events received from // Notifications is a channel that is used to send events received from
@ -539,6 +540,16 @@ readloop:
} }
c.respChannels = nil c.respChannels = nil
c.respLock.Unlock() c.respLock.Unlock()
c.subscriptionsLock.Lock()
for rcvrCh, ids := range c.receivers {
rcvr := c.subscriptions[ids[0]]
_, ok := rcvr.(*naiveReceiver)
if !ok { // naiveReceiver uses c.Notifications that is about to be closed below.
c.subscriptions[ids[0]].Close()
}
delete(c.receivers, rcvrCh)
}
c.subscriptionsLock.Unlock()
close(c.Notifications) close(c.Notifications)
c.Client.ctxCancel() c.Client.ctxCancel()
} }
@ -638,7 +649,10 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) {
select { select {
case <-c.done: case <-c.done:
return nil, errors.New("connection lost while waiting for the response") return nil, errors.New("connection lost while waiting for the response")
case resp := <-ch: case resp, ok := <-ch:
if !ok {
return nil, errors.New("connection lost while waiting for the response")
}
c.unregisterRespChannel(r.ID) c.unregisterRespChannel(r.ID)
return resp, nil return resp, nil
} }

View file

@ -30,10 +30,17 @@ import (
) )
func TestWSClientClose(t *testing.T) { func TestWSClientClose(t *testing.T) {
srv := initTestServer(t, "") srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
bCh := make(chan *block.Block)
_, err = wsc.ReceiveBlocks(nil, bCh)
require.NoError(t, err)
wsc.Close() wsc.Close()
// Subscriber channel must be closed by server.
_, ok := <-bCh
require.False(t, ok)
} }
func TestWSClientSubscription(t *testing.T) { func TestWSClientSubscription(t *testing.T) {
@ -296,10 +303,6 @@ func TestWSExecutionVMStateCheck(t *testing.T) {
} }
func TestWSFilteredSubscriptions(t *testing.T) { func TestWSFilteredSubscriptions(t *testing.T) {
bCh := make(chan *block.Block)
txCh := make(chan *transaction.Transaction)
aerCh := make(chan *state.AppExecResult)
ntfCh := make(chan *state.ContainedNotificationEvent)
var cases = []struct { var cases = []struct {
name string name string
clientCode func(*testing.T, *WSClient) clientCode func(*testing.T, *WSClient)
@ -308,7 +311,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"blocks primary", {"blocks primary",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
primary := 3 primary := 3
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, bCh) _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Block))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -323,7 +326,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"blocks since", {"blocks since",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
var since uint32 = 3 var since uint32 = 3
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Since: &since}, bCh) _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Since: &since}, make(chan *block.Block))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -338,7 +341,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"blocks till", {"blocks till",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
var till uint32 = 3 var till uint32 = 3
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Till: &till}, bCh) _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Till: &till}, make(chan *block.Block))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -361,7 +364,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
Primary: &primary, Primary: &primary,
Since: &since, Since: &since,
Till: &till, Till: &till,
}, bCh) }, make(chan *block.Block))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -376,7 +379,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"transactions sender", {"transactions sender",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
sender := util.Uint160{1, 2, 3, 4, 5} sender := util.Uint160{1, 2, 3, 4, 5}
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender}, txCh) _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender}, make(chan *transaction.Transaction))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -390,7 +393,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"transactions signer", {"transactions signer",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
signer := util.Uint160{0, 42} signer := util.Uint160{0, 42}
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Signer: &signer}, txCh) _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Signer: &signer}, make(chan *transaction.Transaction))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -405,7 +408,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
sender := util.Uint160{1, 2, 3, 4, 5} sender := util.Uint160{1, 2, 3, 4, 5}
signer := util.Uint160{0, 42} signer := util.Uint160{0, 42}
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender, Signer: &signer}, txCh) _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender, Signer: &signer}, make(chan *transaction.Transaction))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -419,7 +422,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"notifications contract hash", {"notifications contract hash",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
contract := util.Uint160{1, 2, 3, 4, 5} contract := util.Uint160{1, 2, 3, 4, 5}
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract}, ntfCh) _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract}, make(chan *state.ContainedNotificationEvent))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -433,7 +436,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"notifications name", {"notifications name",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
name := "my_pretty_notification" name := "my_pretty_notification"
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &name}, ntfCh) _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &name}, make(chan *state.ContainedNotificationEvent))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -448,7 +451,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
contract := util.Uint160{1, 2, 3, 4, 5} contract := util.Uint160{1, 2, 3, 4, 5}
name := "my_pretty_notification" name := "my_pretty_notification"
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract, Name: &name}, ntfCh) _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract, Name: &name}, make(chan *state.ContainedNotificationEvent))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -461,8 +464,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
}, },
{"executions state", {"executions state",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
state := "FAULT" vmstate := "FAULT"
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &state}, aerCh) _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &vmstate}, make(chan *state.AppExecResult))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -476,7 +479,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"executions container", {"executions container",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
container := util.Uint256{1, 2, 3} container := util.Uint256{1, 2, 3}
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{Container: &container}, aerCh) _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{Container: &container}, make(chan *state.AppExecResult))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {
@ -489,9 +492,9 @@ func TestWSFilteredSubscriptions(t *testing.T) {
}, },
{"executions state and container", {"executions state and container",
func(t *testing.T, wsc *WSClient) { func(t *testing.T, wsc *WSClient) {
state := "FAULT" vmstate := "FAULT"
container := util.Uint256{1, 2, 3} container := util.Uint256{1, 2, 3}
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &state, Container: &container}, aerCh) _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &vmstate, Container: &container}, make(chan *state.AppExecResult))
require.NoError(t, err) require.NoError(t, err)
}, },
func(t *testing.T, p *params.Params) { func(t *testing.T, p *params.Params) {