Merge pull request #2980 from nspcc-dev/close-notification-chans-on-wsclient-disconnect

Close notification channels on wsclient disconnect
This commit is contained in:
Roman Khimov 2023-04-19 18:11:36 +03:00 committed by GitHub
commit 7b109586ca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 22 deletions

View file

@ -46,7 +46,8 @@ import (
// subscriptions share the same receiver channel, then matching notification is
// only sent once per channel. The receiver channel will be closed by the WSClient
// immediately after MissedEvent is received from the server; no unsubscription
// is performed in this case, so it's the user responsibility to unsubscribe.
// is performed in this case, so it's the user responsibility to unsubscribe. It
// will also be closed on disconnection from server.
type WSClient struct {
Client
// Notifications is a channel that is used to send events received from
@ -539,6 +540,16 @@ readloop:
}
c.respChannels = nil
c.respLock.Unlock()
c.subscriptionsLock.Lock()
for rcvrCh, ids := range c.receivers {
rcvr := c.subscriptions[ids[0]]
_, ok := rcvr.(*naiveReceiver)
if !ok { // naiveReceiver uses c.Notifications that is about to be closed below.
c.subscriptions[ids[0]].Close()
}
delete(c.receivers, rcvrCh)
}
c.subscriptionsLock.Unlock()
close(c.Notifications)
c.Client.ctxCancel()
}
@ -638,7 +649,10 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) {
select {
case <-c.done:
return nil, errors.New("connection lost while waiting for the response")
case resp := <-ch:
case resp, ok := <-ch:
if !ok {
return nil, errors.New("connection lost while waiting for the response")
}
c.unregisterRespChannel(r.ID)
return resp, nil
}

View file

@ -30,10 +30,17 @@ import (
)
func TestWSClientClose(t *testing.T) {
srv := initTestServer(t, "")
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
bCh := make(chan *block.Block)
_, err = wsc.ReceiveBlocks(nil, bCh)
require.NoError(t, err)
wsc.Close()
// Subscriber channel must be closed by server.
_, ok := <-bCh
require.False(t, ok)
}
func TestWSClientSubscription(t *testing.T) {
@ -296,10 +303,6 @@ func TestWSExecutionVMStateCheck(t *testing.T) {
}
func TestWSFilteredSubscriptions(t *testing.T) {
bCh := make(chan *block.Block)
txCh := make(chan *transaction.Transaction)
aerCh := make(chan *state.AppExecResult)
ntfCh := make(chan *state.ContainedNotificationEvent)
var cases = []struct {
name string
clientCode func(*testing.T, *WSClient)
@ -308,7 +311,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"blocks primary",
func(t *testing.T, wsc *WSClient) {
primary := 3
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, bCh)
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Block))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -323,7 +326,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"blocks since",
func(t *testing.T, wsc *WSClient) {
var since uint32 = 3
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Since: &since}, bCh)
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Since: &since}, make(chan *block.Block))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -338,7 +341,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"blocks till",
func(t *testing.T, wsc *WSClient) {
var till uint32 = 3
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Till: &till}, bCh)
_, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Till: &till}, make(chan *block.Block))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -361,7 +364,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
Primary: &primary,
Since: &since,
Till: &till,
}, bCh)
}, make(chan *block.Block))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -376,7 +379,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"transactions sender",
func(t *testing.T, wsc *WSClient) {
sender := util.Uint160{1, 2, 3, 4, 5}
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender}, txCh)
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender}, make(chan *transaction.Transaction))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -390,7 +393,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"transactions signer",
func(t *testing.T, wsc *WSClient) {
signer := util.Uint160{0, 42}
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Signer: &signer}, txCh)
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Signer: &signer}, make(chan *transaction.Transaction))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -405,7 +408,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
func(t *testing.T, wsc *WSClient) {
sender := util.Uint160{1, 2, 3, 4, 5}
signer := util.Uint160{0, 42}
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender, Signer: &signer}, txCh)
_, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender, Signer: &signer}, make(chan *transaction.Transaction))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -419,7 +422,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"notifications contract hash",
func(t *testing.T, wsc *WSClient) {
contract := util.Uint160{1, 2, 3, 4, 5}
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract}, ntfCh)
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract}, make(chan *state.ContainedNotificationEvent))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -433,7 +436,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"notifications name",
func(t *testing.T, wsc *WSClient) {
name := "my_pretty_notification"
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &name}, ntfCh)
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &name}, make(chan *state.ContainedNotificationEvent))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -448,7 +451,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
func(t *testing.T, wsc *WSClient) {
contract := util.Uint160{1, 2, 3, 4, 5}
name := "my_pretty_notification"
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract, Name: &name}, ntfCh)
_, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract, Name: &name}, make(chan *state.ContainedNotificationEvent))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -461,8 +464,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
},
{"executions state",
func(t *testing.T, wsc *WSClient) {
state := "FAULT"
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &state}, aerCh)
vmstate := "FAULT"
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &vmstate}, make(chan *state.AppExecResult))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -476,7 +479,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
{"executions container",
func(t *testing.T, wsc *WSClient) {
container := util.Uint256{1, 2, 3}
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{Container: &container}, aerCh)
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{Container: &container}, make(chan *state.AppExecResult))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {
@ -489,9 +492,9 @@ func TestWSFilteredSubscriptions(t *testing.T) {
},
{"executions state and container",
func(t *testing.T, wsc *WSClient) {
state := "FAULT"
vmstate := "FAULT"
container := util.Uint256{1, 2, 3}
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &state, Container: &container}, aerCh)
_, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &vmstate, Container: &container}, make(chan *state.AppExecResult))
require.NoError(t, err)
},
func(t *testing.T, p *params.Params) {