rpc: add ability to get WSClient connection closure error

Close #2421.
This commit is contained in:
Anna Shaleva 2022-05-23 13:47:52 +03:00
parent 73ef36e03e
commit 19646e0967
2 changed files with 92 additions and 6 deletions

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -41,6 +42,9 @@ type WSClient struct {
shutdown chan struct{} shutdown chan struct{}
closeCalled atomic.Bool closeCalled atomic.Bool
closeErrLock sync.RWMutex
closeErr error
subscriptionsLock sync.RWMutex subscriptionsLock sync.RWMutex
subscriptions map[string]bool subscriptions map[string]bool
@ -128,27 +132,38 @@ func (c *WSClient) Close() {
func (c *WSClient) wsReader() { func (c *WSClient) wsReader() {
c.ws.SetReadLimit(wsReadLimit) c.ws.SetReadLimit(wsReadLimit)
c.ws.SetPongHandler(func(string) error { return c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) }) c.ws.SetPongHandler(func(string) error {
err := c.ws.SetReadDeadline(time.Now().Add(wsPongLimit))
if err != nil {
c.setCloseErr(fmt.Errorf("failed to set pong read deadline: %w", err))
}
return err
})
var connCloseErr error
readloop: readloop:
for { for {
rr := new(requestResponse) rr := new(requestResponse)
err := c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) err := c.ws.SetReadDeadline(time.Now().Add(wsPongLimit))
if err != nil { if err != nil {
connCloseErr = fmt.Errorf("failed to set response read deadline: %w", err)
break break
} }
err = c.ws.ReadJSON(rr) err = c.ws.ReadJSON(rr)
if err != nil { if err != nil {
// Timeout/connection loss/malformed response. // Timeout/connection loss/malformed response.
connCloseErr = fmt.Errorf("failed to read JSON response (timeout/connection loss/malformed response): %w", err)
break break
} }
if rr.RawID == nil && rr.Method != "" { if rr.RawID == nil && rr.Method != "" {
event, err := response.GetEventIDFromString(rr.Method) event, err := response.GetEventIDFromString(rr.Method)
if err != nil { if err != nil {
// Bad event received. // Bad event received.
connCloseErr = fmt.Errorf("failed to perse event ID from string %s: %w", rr.Method, err)
break break
} }
if event != response.MissedEventID && len(rr.RawParams) != 1 { if event != response.MissedEventID && len(rr.RawParams) != 1 {
// Bad event received. // Bad event received.
connCloseErr = fmt.Errorf("bad event received: %s / %d", event, len(rr.RawParams))
break break
} }
var val interface{} var val interface{}
@ -157,7 +172,8 @@ readloop:
sr, err := c.StateRootInHeader() sr, err := c.StateRootInHeader()
if err != nil { if err != nil {
// Client is not initialized. // Client is not initialized.
break connCloseErr = fmt.Errorf("failed to fetch StateRootInHeader: %w", err)
break readloop
} }
val = block.New(sr) val = block.New(sr)
case response.TransactionEventID: case response.TransactionEventID:
@ -172,12 +188,14 @@ readloop:
// No value. // No value.
default: default:
// Bad event received. // Bad event received.
connCloseErr = fmt.Errorf("unknown event received: %d", event)
break readloop break readloop
} }
if event != response.MissedEventID { if event != response.MissedEventID {
err = json.Unmarshal(rr.RawParams[0].RawMessage, val) err = json.Unmarshal(rr.RawParams[0].RawMessage, val)
if err != nil { if err != nil {
// Bad event received. // Bad event received.
connCloseErr = fmt.Errorf("failed to unmarshal event of type %s from JSON: %w", event, err)
break break
} }
} }
@ -190,18 +208,24 @@ readloop:
resp.Result = rr.Result resp.Result = rr.Result
id, err := strconv.Atoi(string(resp.ID)) id, err := strconv.Atoi(string(resp.ID))
if err != nil { if err != nil {
connCloseErr = fmt.Errorf("failed to retrieve response ID from string %s: %w", string(resp.ID), err)
break // Malformed response (invalid response ID). break // Malformed response (invalid response ID).
} }
ch := c.getResponseChannel(uint64(id)) ch := c.getResponseChannel(uint64(id))
if ch == nil { if ch == nil {
connCloseErr = fmt.Errorf("unknown response channel for response %d", id)
break // Unknown response (unexpected response ID). break // Unknown response (unexpected response ID).
} }
ch <- resp ch <- resp
} else { } else {
// Malformed response, neither valid request, nor valid response. // Malformed response, neither valid request, nor valid response.
connCloseErr = fmt.Errorf("malformed response")
break break
} }
} }
if connCloseErr != nil {
c.setCloseErr(connCloseErr)
}
close(c.done) close(c.done)
c.respLock.Lock() c.respLock.Lock()
for _, ch := range c.respChannels { for _, ch := range c.respChannels {
@ -216,6 +240,8 @@ func (c *WSClient) wsWriter() {
pingTicker := time.NewTicker(wsPingPeriod) pingTicker := time.NewTicker(wsPingPeriod)
defer c.ws.Close() defer c.ws.Close()
defer pingTicker.Stop() defer pingTicker.Stop()
var connCloseErr error
writeloop:
for { for {
select { select {
case <-c.shutdown: case <-c.shutdown:
@ -227,20 +253,27 @@ func (c *WSClient) wsWriter() {
return return
} }
if err := c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout)); err != nil { if err := c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout)); err != nil {
return connCloseErr = fmt.Errorf("failed to set request write deadline: %w", err)
break writeloop
} }
if err := c.ws.WriteJSON(req); err != nil { if err := c.ws.WriteJSON(req); err != nil {
return connCloseErr = fmt.Errorf("failed to write JSON request: %w", err)
break writeloop
} }
case <-pingTicker.C: case <-pingTicker.C:
if err := c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)); err != nil { if err := c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)); err != nil {
return connCloseErr = fmt.Errorf("failed to set ping write deadline: %w", err)
break writeloop
} }
if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
return connCloseErr = fmt.Errorf("failed to write ping message: %w", err)
break writeloop
} }
} }
} }
if connCloseErr != nil {
c.setCloseErr(connCloseErr)
}
} }
func (c *WSClient) unregisterRespChannel(id uint64) { func (c *WSClient) unregisterRespChannel(id uint64) {
@ -399,3 +432,21 @@ func (c *WSClient) UnsubscribeAll() error {
} }
return nil return nil
} }
// setCloseErr is a thread-safe method setting closeErr in case if it's not yet set.
func (c *WSClient) setCloseErr(err error) {
c.closeErrLock.Lock()
defer c.closeErrLock.Unlock()
if c.closeErr == nil {
c.closeErr = err
}
}
// GetError returns the reason of WS connection closing.
func (c *WSClient) GetError() error {
c.closeErrLock.RLock()
defer c.closeErrLock.RUnlock()
return c.closeErr
}

View file

@ -15,6 +15,8 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/request"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -468,3 +470,36 @@ func TestWS_RequestAfterClose(t *testing.T) {
require.Error(t, err) require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "connection lost before registering response channel")) require.True(t, strings.Contains(err.Error(), "connection lost before registering response channel"))
} }
func TestWSClient_ConnClosedError(t *testing.T) {
srv := initTestServer(t, "")
t.Run("standard closing", func(t *testing.T) {
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
c.Close()
err = c.GetError()
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "use of closed network connection"))
})
t.Run("malformed request", func(t *testing.T) {
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
defaultMaxBlockSize := 262144
_, err = c.SubmitP2PNotaryRequest(&payload.P2PNotaryRequest{
MainTransaction: &transaction.Transaction{
Script: make([]byte, defaultMaxBlockSize*3),
},
FallbackTransaction: &transaction.Transaction{},
})
require.Error(t, err)
err = c.GetError()
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "failed to read JSON response (timeout/connection loss/malformed response)"), err.Error())
})
}