diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index 3ec07198f..e4f11a0f0 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -454,6 +454,11 @@ const ( // ErrNilNotificationReceiver is returned when notification receiver channel is nil. var ErrNilNotificationReceiver = errors.New("nil notification receiver") +// ErrWSConnLost is a WSClient-specific error that will be returned for any +// requests after disconnection (including intentional ones via +// (*WSClient).Close). +var ErrWSConnLost = errors.New("connection lost") + // errConnClosedByUser is a WSClient error used iff the user calls (*WSClient).Close method by himself. var errConnClosedByUser = errors.New("connection closed by user") @@ -735,22 +740,22 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) { select { case <-c.done: c.respLock.Unlock() - return nil, errors.New("connection lost before registering response channel") + return nil, fmt.Errorf("%w: before registering response channel", ErrWSConnLost) default: c.respChannels[r.ID] = ch c.respLock.Unlock() } select { case <-c.done: - return nil, errors.New("connection lost before sending the request") + return nil, fmt.Errorf("%w: before sending the request", ErrWSConnLost) case c.requests <- r: } select { case <-c.done: - return nil, errors.New("connection lost while waiting for the response") + return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost) case resp, ok := <-ch: if !ok { - return nil, errors.New("connection lost while waiting for the response") + return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost) } c.unregisterRespChannel(r.ID) return resp, nil diff --git a/pkg/rpcclient/wsclient_test.go b/pkg/rpcclient/wsclient_test.go index ccd272e3b..f414d4c80 100644 --- a/pkg/rpcclient/wsclient_test.go +++ b/pkg/rpcclient/wsclient_test.go @@ -3,6 +3,7 @@ package rpcclient import ( "context" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -752,7 +753,7 @@ func TestWS_RequestAfterClose(t *testing.T) { _, err = c.GetBlockCount() }) require.Error(t, err) - require.True(t, strings.Contains(err.Error(), "connection lost before registering response channel")) + require.True(t, errors.Is(err, ErrWSConnLost)) } func TestWSClient_ConnClosedError(t *testing.T) {