diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go index b67689852..95fd7f256 100644 --- a/pkg/rpc/client/wsclient.go +++ b/pkg/rpc/client/wsclient.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "strconv" "sync" "time" @@ -41,6 +42,9 @@ type WSClient struct { shutdown chan struct{} closeCalled atomic.Bool + closeErrLock sync.RWMutex + closeErr error + subscriptionsLock sync.RWMutex subscriptions map[string]bool @@ -128,27 +132,38 @@ func (c *WSClient) Close() { func (c *WSClient) wsReader() { c.ws.SetReadLimit(wsReadLimit) - c.ws.SetPongHandler(func(string) error { return c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) }) + c.ws.SetPongHandler(func(string) error { + err := c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) + if err != nil { + c.setCloseErr(fmt.Errorf("failed to set pong read deadline: %w", err)) + } + return err + }) + var connCloseErr error readloop: for { rr := new(requestResponse) err := c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) if err != nil { + connCloseErr = fmt.Errorf("failed to set response read deadline: %w", err) break } err = c.ws.ReadJSON(rr) if err != nil { // Timeout/connection loss/malformed response. + connCloseErr = fmt.Errorf("failed to read JSON response (timeout/connection loss/malformed response): %w", err) break } if rr.RawID == nil && rr.Method != "" { event, err := response.GetEventIDFromString(rr.Method) if err != nil { // Bad event received. + connCloseErr = fmt.Errorf("failed to perse event ID from string %s: %w", rr.Method, err) break } if event != response.MissedEventID && len(rr.RawParams) != 1 { // Bad event received. + connCloseErr = fmt.Errorf("bad event received: %s / %d", event, len(rr.RawParams)) break } var val interface{} @@ -157,7 +172,8 @@ readloop: sr, err := c.StateRootInHeader() if err != nil { // Client is not initialized. - break + connCloseErr = fmt.Errorf("failed to fetch StateRootInHeader: %w", err) + break readloop } val = block.New(sr) case response.TransactionEventID: @@ -172,12 +188,14 @@ readloop: // No value. default: // Bad event received. + connCloseErr = fmt.Errorf("unknown event received: %d", event) break readloop } if event != response.MissedEventID { err = json.Unmarshal(rr.RawParams[0].RawMessage, val) if err != nil { // Bad event received. + connCloseErr = fmt.Errorf("failed to unmarshal event of type %s from JSON: %w", event, err) break } } @@ -190,18 +208,24 @@ readloop: resp.Result = rr.Result id, err := strconv.Atoi(string(resp.ID)) if err != nil { + connCloseErr = fmt.Errorf("failed to retrieve response ID from string %s: %w", string(resp.ID), err) break // Malformed response (invalid response ID). } ch := c.getResponseChannel(uint64(id)) if ch == nil { + connCloseErr = fmt.Errorf("unknown response channel for response %d", id) break // Unknown response (unexpected response ID). } ch <- resp } else { // Malformed response, neither valid request, nor valid response. + connCloseErr = fmt.Errorf("malformed response") break } } + if connCloseErr != nil { + c.setCloseErr(connCloseErr) + } close(c.done) c.respLock.Lock() for _, ch := range c.respChannels { @@ -216,6 +240,8 @@ func (c *WSClient) wsWriter() { pingTicker := time.NewTicker(wsPingPeriod) defer c.ws.Close() defer pingTicker.Stop() + var connCloseErr error +writeloop: for { select { case <-c.shutdown: @@ -227,20 +253,27 @@ func (c *WSClient) wsWriter() { return } if err := c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout)); err != nil { - return + connCloseErr = fmt.Errorf("failed to set request write deadline: %w", err) + break writeloop } if err := c.ws.WriteJSON(req); err != nil { - return + connCloseErr = fmt.Errorf("failed to write JSON request: %w", err) + break writeloop } case <-pingTicker.C: if err := c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)); err != nil { - return + connCloseErr = fmt.Errorf("failed to set ping write deadline: %w", err) + break writeloop } if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { - return + connCloseErr = fmt.Errorf("failed to write ping message: %w", err) + break writeloop } } } + if connCloseErr != nil { + c.setCloseErr(connCloseErr) + } } func (c *WSClient) unregisterRespChannel(id uint64) { @@ -399,3 +432,21 @@ func (c *WSClient) UnsubscribeAll() error { } return nil } + +// setCloseErr is a thread-safe method setting closeErr in case if it's not yet set. +func (c *WSClient) setCloseErr(err error) { + c.closeErrLock.Lock() + defer c.closeErrLock.Unlock() + + if c.closeErr == nil { + c.closeErr = err + } +} + +// GetError returns the reason of WS connection closing. +func (c *WSClient) GetError() error { + c.closeErrLock.RLock() + defer c.closeErrLock.RUnlock() + + return c.closeErr +} diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 6e6ca2e1d..b26106a9c 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -15,6 +15,8 @@ import ( "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/config/netmode" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" @@ -468,3 +470,36 @@ func TestWS_RequestAfterClose(t *testing.T) { require.Error(t, err) require.True(t, strings.Contains(err.Error(), "connection lost before registering response channel")) } + +func TestWSClient_ConnClosedError(t *testing.T) { + srv := initTestServer(t, "") + + t.Run("standard closing", func(t *testing.T) { + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + + c.Close() + + err = c.GetError() + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "use of closed network connection")) + }) + + t.Run("malformed request", func(t *testing.T) { + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + + defaultMaxBlockSize := 262144 + _, err = c.SubmitP2PNotaryRequest(&payload.P2PNotaryRequest{ + MainTransaction: &transaction.Transaction{ + Script: make([]byte, defaultMaxBlockSize*3), + }, + FallbackTransaction: &transaction.Transaction{}, + }) + require.Error(t, err) + + err = c.GetError() + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "failed to read JSON response (timeout/connection loss/malformed response)"), err.Error()) + }) +}