diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go index 6898a1055..14cd25f8e 100644 --- a/pkg/rpc/client/wsclient.go +++ b/pkg/rpc/client/wsclient.go @@ -82,6 +82,9 @@ const ( wsWriteLimit = wsPingPeriod / 2 ) +// errConnClosedByUser is a WSClient error used iff the user calls (*WSClient).Close method by himself. +var errConnClosedByUser = errors.New("connection closed by user") + // NewWS returns a new WSClient ready to use (with established websocket // connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`. // You should call Init method to initialize the network magic the client is @@ -121,6 +124,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error // unusable. func (c *WSClient) Close() { if c.closeCalled.CAS(false, true) { + c.setCloseErr(errConnClosedByUser) // Closing shutdown channel sends a signal to wsWriter to break out of the // loop. In doing so it does ws.Close() closing the network connection // which in turn makes wsReader receive an err from ws.ReadJSON() and also @@ -146,25 +150,25 @@ readloop: err := c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) if err != nil { connCloseErr = fmt.Errorf("failed to set response read deadline: %w", err) - break + break readloop } err = c.ws.ReadJSON(rr) if err != nil { // Timeout/connection loss/malformed response. connCloseErr = fmt.Errorf("failed to read JSON response (timeout/connection loss/malformed response): %w", err) - break + break readloop } if rr.RawID == nil && rr.Method != "" { event, err := response.GetEventIDFromString(rr.Method) if err != nil { // Bad event received. connCloseErr = fmt.Errorf("failed to perse event ID from string %s: %w", rr.Method, err) - break + break readloop } if event != response.MissedEventID && len(rr.RawParams) != 1 { // Bad event received. connCloseErr = fmt.Errorf("bad event received: %s / %d", event, len(rr.RawParams)) - break + break readloop } var val interface{} switch event { @@ -196,7 +200,7 @@ readloop: if err != nil { // Bad event received. connCloseErr = fmt.Errorf("failed to unmarshal event of type %s from JSON: %w", event, err) - break + break readloop } } c.Notifications <- Notification{event, val} @@ -209,18 +213,18 @@ readloop: id, err := strconv.Atoi(string(resp.ID)) if err != nil { connCloseErr = fmt.Errorf("failed to retrieve response ID from string %s: %w", string(resp.ID), err) - break // Malformed response (invalid response ID). + break readloop // Malformed response (invalid response ID). } ch := c.getResponseChannel(uint64(id)) if ch == nil { connCloseErr = fmt.Errorf("unknown response channel for response %d", id) - break // Unknown response (unexpected response ID). + break readloop // Unknown response (unexpected response ID). } ch <- resp } else { // Malformed response, neither valid request, nor valid response. connCloseErr = fmt.Errorf("malformed response") - break + break readloop } } if connCloseErr != nil { @@ -443,10 +447,14 @@ func (c *WSClient) setCloseErr(err error) { } } -// GetError returns the reason of WS connection closing. +// GetError returns the reason of WS connection closing. It returns nil in case if connection +// was closed by the use via Close() method calling. func (c *WSClient) GetError() error { c.closeErrLock.RLock() defer c.closeErrLock.RUnlock() + if c.closeErr != nil && errors.Is(c.closeErr, errConnClosedByUser) { + return nil + } return c.closeErr } diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index b26106a9c..73ea93ff6 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -472,20 +472,24 @@ func TestWS_RequestAfterClose(t *testing.T) { } func TestWSClient_ConnClosedError(t *testing.T) { - srv := initTestServer(t, "") - t.Run("standard closing", func(t *testing.T) { + srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": 123}`) c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) - c.Close() - + // Check client is working. + _, err = c.GetBlockCount() + require.NoError(t, err) err = c.GetError() - require.Error(t, err) - require.True(t, strings.Contains(err.Error(), "use of closed network connection")) + require.NoError(t, err) + + c.Close() + err = c.GetError() + require.NoError(t, err) }) t.Run("malformed request", func(t *testing.T) { + srv := initTestServer(t, "") c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err)