Merge pull request #3000 from nspcc-dev/conloss-error-for-wsclient

rpcclient: provide some exported error for disconnected WSClient
This commit is contained in:
Roman Khimov 2023-05-03 16:42:12 +03:00 committed by GitHub
commit 8e6025fbc8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 5 deletions

View file

@ -454,6 +454,11 @@ const (
// ErrNilNotificationReceiver is returned when notification receiver channel is nil. // ErrNilNotificationReceiver is returned when notification receiver channel is nil.
var ErrNilNotificationReceiver = errors.New("nil notification receiver") var ErrNilNotificationReceiver = errors.New("nil notification receiver")
// ErrWSConnLost is a WSClient-specific error that will be returned for any
// requests after disconnection (including intentional ones via
// (*WSClient).Close).
var ErrWSConnLost = errors.New("connection lost")
// errConnClosedByUser is a WSClient error used iff the user calls (*WSClient).Close method by himself. // errConnClosedByUser is a WSClient error used iff the user calls (*WSClient).Close method by himself.
var errConnClosedByUser = errors.New("connection closed by user") var errConnClosedByUser = errors.New("connection closed by user")
@ -735,22 +740,22 @@ func (c *WSClient) makeWsRequest(r *neorpc.Request) (*neorpc.Response, error) {
select { select {
case <-c.done: case <-c.done:
c.respLock.Unlock() c.respLock.Unlock()
return nil, errors.New("connection lost before registering response channel") return nil, fmt.Errorf("%w: before registering response channel", ErrWSConnLost)
default: default:
c.respChannels[r.ID] = ch c.respChannels[r.ID] = ch
c.respLock.Unlock() c.respLock.Unlock()
} }
select { select {
case <-c.done: case <-c.done:
return nil, errors.New("connection lost before sending the request") return nil, fmt.Errorf("%w: before sending the request", ErrWSConnLost)
case c.requests <- r: case c.requests <- r:
} }
select { select {
case <-c.done: case <-c.done:
return nil, errors.New("connection lost while waiting for the response") return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost)
case resp, ok := <-ch: case resp, ok := <-ch:
if !ok { if !ok {
return nil, errors.New("connection lost while waiting for the response") return nil, fmt.Errorf("%w: while waiting for the response", ErrWSConnLost)
} }
c.unregisterRespChannel(r.ID) c.unregisterRespChannel(r.ID)
return resp, nil return resp, nil

View file

@ -3,6 +3,7 @@ package rpcclient
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -752,7 +753,7 @@ func TestWS_RequestAfterClose(t *testing.T) {
_, err = c.GetBlockCount() _, err = c.GetBlockCount()
}) })
require.Error(t, err) require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "connection lost before registering response channel")) require.True(t, errors.Is(err, ErrWSConnLost))
} }
func TestWSClient_ConnClosedError(t *testing.T) { func TestWSClient_ConnClosedError(t *testing.T) {