diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index c3f6172a1..49443f408 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -46,7 +46,8 @@ import ( // subscriptions share the same receiver channel, then matching notification is // only sent once per channel. The receiver channel will be closed by the WSClient // immediately after MissedEvent is received from the server; no unsubscription -// is performed in this case, so it's the user responsibility to unsubscribe. +// is performed in this case, so it's the user responsibility to unsubscribe. It +// will also be closed on disconnection from server. type WSClient struct { Client // Notifications is a channel that is used to send events received from @@ -539,6 +540,16 @@ readloop: } c.respChannels = nil c.respLock.Unlock() + c.subscriptionsLock.Lock() + for rcvrCh, ids := range c.receivers { + rcvr := c.subscriptions[ids[0]] + _, ok := rcvr.(*naiveReceiver) + if !ok { // naiveReceiver uses c.Notifications that is about to be closed below. + c.subscriptions[ids[0]].Close() + } + delete(c.receivers, rcvrCh) + } + c.subscriptionsLock.Unlock() close(c.Notifications) c.Client.ctxCancel() } @@ -638,7 +649,10 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) { select { case <-c.done: return nil, errors.New("connection lost while waiting for the response") - case resp := <-ch: + case resp, ok := <-ch: + if !ok { + return nil, errors.New("connection lost while waiting for the response") + } c.unregisterRespChannel(r.ID) return resp, nil } diff --git a/pkg/rpcclient/wsclient_test.go b/pkg/rpcclient/wsclient_test.go index 1d7f3f774..c3f30c894 100644 --- a/pkg/rpcclient/wsclient_test.go +++ b/pkg/rpcclient/wsclient_test.go @@ -30,10 +30,17 @@ import ( ) func TestWSClientClose(t *testing.T) { - srv := initTestServer(t, "") + srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID + bCh := make(chan *block.Block) + _, err = wsc.ReceiveBlocks(nil, bCh) + require.NoError(t, err) wsc.Close() + // Subscriber channel must be closed by server. + _, ok := <-bCh + require.False(t, ok) } func TestWSClientSubscription(t *testing.T) { @@ -296,10 +303,6 @@ func TestWSExecutionVMStateCheck(t *testing.T) { } func TestWSFilteredSubscriptions(t *testing.T) { - bCh := make(chan *block.Block) - txCh := make(chan *transaction.Transaction) - aerCh := make(chan *state.AppExecResult) - ntfCh := make(chan *state.ContainedNotificationEvent) var cases = []struct { name string clientCode func(*testing.T, *WSClient) @@ -308,7 +311,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"blocks primary", func(t *testing.T, wsc *WSClient) { primary := 3 - _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, bCh) + _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Primary: &primary}, make(chan *block.Block)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -323,7 +326,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"blocks since", func(t *testing.T, wsc *WSClient) { var since uint32 = 3 - _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Since: &since}, bCh) + _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Since: &since}, make(chan *block.Block)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -338,7 +341,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"blocks till", func(t *testing.T, wsc *WSClient) { var till uint32 = 3 - _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Till: &till}, bCh) + _, err := wsc.ReceiveBlocks(&neorpc.BlockFilter{Till: &till}, make(chan *block.Block)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -361,7 +364,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { Primary: &primary, Since: &since, Till: &till, - }, bCh) + }, make(chan *block.Block)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -376,7 +379,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"transactions sender", func(t *testing.T, wsc *WSClient) { sender := util.Uint160{1, 2, 3, 4, 5} - _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender}, txCh) + _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender}, make(chan *transaction.Transaction)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -390,7 +393,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"transactions signer", func(t *testing.T, wsc *WSClient) { signer := util.Uint160{0, 42} - _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Signer: &signer}, txCh) + _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Signer: &signer}, make(chan *transaction.Transaction)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -405,7 +408,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { func(t *testing.T, wsc *WSClient) { sender := util.Uint160{1, 2, 3, 4, 5} signer := util.Uint160{0, 42} - _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender, Signer: &signer}, txCh) + _, err := wsc.ReceiveTransactions(&neorpc.TxFilter{Sender: &sender, Signer: &signer}, make(chan *transaction.Transaction)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -419,7 +422,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"notifications contract hash", func(t *testing.T, wsc *WSClient) { contract := util.Uint160{1, 2, 3, 4, 5} - _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract}, ntfCh) + _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract}, make(chan *state.ContainedNotificationEvent)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -433,7 +436,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"notifications name", func(t *testing.T, wsc *WSClient) { name := "my_pretty_notification" - _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &name}, ntfCh) + _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Name: &name}, make(chan *state.ContainedNotificationEvent)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -448,7 +451,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { func(t *testing.T, wsc *WSClient) { contract := util.Uint160{1, 2, 3, 4, 5} name := "my_pretty_notification" - _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract, Name: &name}, ntfCh) + _, err := wsc.ReceiveExecutionNotifications(&neorpc.NotificationFilter{Contract: &contract, Name: &name}, make(chan *state.ContainedNotificationEvent)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -461,8 +464,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, {"executions state", func(t *testing.T, wsc *WSClient) { - state := "FAULT" - _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &state}, aerCh) + vmstate := "FAULT" + _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &vmstate}, make(chan *state.AppExecResult)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -476,7 +479,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { {"executions container", func(t *testing.T, wsc *WSClient) { container := util.Uint256{1, 2, 3} - _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{Container: &container}, aerCh) + _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{Container: &container}, make(chan *state.AppExecResult)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) { @@ -489,9 +492,9 @@ func TestWSFilteredSubscriptions(t *testing.T) { }, {"executions state and container", func(t *testing.T, wsc *WSClient) { - state := "FAULT" + vmstate := "FAULT" container := util.Uint256{1, 2, 3} - _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &state, Container: &container}, aerCh) + _, err := wsc.ReceiveExecutions(&neorpc.ExecutionFilter{State: &vmstate, Container: &container}, make(chan *state.AppExecResult)) require.NoError(t, err) }, func(t *testing.T, p *params.Params) {