forked from TrueCloudLab/neoneo-go
parent
73ef36e03e
commit
19646e0967
2 changed files with 92 additions and 6 deletions
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -41,6 +42,9 @@ type WSClient struct {
|
||||||
shutdown chan struct{}
|
shutdown chan struct{}
|
||||||
closeCalled atomic.Bool
|
closeCalled atomic.Bool
|
||||||
|
|
||||||
|
closeErrLock sync.RWMutex
|
||||||
|
closeErr error
|
||||||
|
|
||||||
subscriptionsLock sync.RWMutex
|
subscriptionsLock sync.RWMutex
|
||||||
subscriptions map[string]bool
|
subscriptions map[string]bool
|
||||||
|
|
||||||
|
@ -128,27 +132,38 @@ func (c *WSClient) Close() {
|
||||||
|
|
||||||
func (c *WSClient) wsReader() {
|
func (c *WSClient) wsReader() {
|
||||||
c.ws.SetReadLimit(wsReadLimit)
|
c.ws.SetReadLimit(wsReadLimit)
|
||||||
c.ws.SetPongHandler(func(string) error { return c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) })
|
c.ws.SetPongHandler(func(string) error {
|
||||||
|
err := c.ws.SetReadDeadline(time.Now().Add(wsPongLimit))
|
||||||
|
if err != nil {
|
||||||
|
c.setCloseErr(fmt.Errorf("failed to set pong read deadline: %w", err))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
var connCloseErr error
|
||||||
readloop:
|
readloop:
|
||||||
for {
|
for {
|
||||||
rr := new(requestResponse)
|
rr := new(requestResponse)
|
||||||
err := c.ws.SetReadDeadline(time.Now().Add(wsPongLimit))
|
err := c.ws.SetReadDeadline(time.Now().Add(wsPongLimit))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
connCloseErr = fmt.Errorf("failed to set response read deadline: %w", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
err = c.ws.ReadJSON(rr)
|
err = c.ws.ReadJSON(rr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Timeout/connection loss/malformed response.
|
// Timeout/connection loss/malformed response.
|
||||||
|
connCloseErr = fmt.Errorf("failed to read JSON response (timeout/connection loss/malformed response): %w", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if rr.RawID == nil && rr.Method != "" {
|
if rr.RawID == nil && rr.Method != "" {
|
||||||
event, err := response.GetEventIDFromString(rr.Method)
|
event, err := response.GetEventIDFromString(rr.Method)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Bad event received.
|
// Bad event received.
|
||||||
|
connCloseErr = fmt.Errorf("failed to perse event ID from string %s: %w", rr.Method, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if event != response.MissedEventID && len(rr.RawParams) != 1 {
|
if event != response.MissedEventID && len(rr.RawParams) != 1 {
|
||||||
// Bad event received.
|
// Bad event received.
|
||||||
|
connCloseErr = fmt.Errorf("bad event received: %s / %d", event, len(rr.RawParams))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
var val interface{}
|
var val interface{}
|
||||||
|
@ -157,7 +172,8 @@ readloop:
|
||||||
sr, err := c.StateRootInHeader()
|
sr, err := c.StateRootInHeader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Client is not initialized.
|
// Client is not initialized.
|
||||||
break
|
connCloseErr = fmt.Errorf("failed to fetch StateRootInHeader: %w", err)
|
||||||
|
break readloop
|
||||||
}
|
}
|
||||||
val = block.New(sr)
|
val = block.New(sr)
|
||||||
case response.TransactionEventID:
|
case response.TransactionEventID:
|
||||||
|
@ -172,12 +188,14 @@ readloop:
|
||||||
// No value.
|
// No value.
|
||||||
default:
|
default:
|
||||||
// Bad event received.
|
// Bad event received.
|
||||||
|
connCloseErr = fmt.Errorf("unknown event received: %d", event)
|
||||||
break readloop
|
break readloop
|
||||||
}
|
}
|
||||||
if event != response.MissedEventID {
|
if event != response.MissedEventID {
|
||||||
err = json.Unmarshal(rr.RawParams[0].RawMessage, val)
|
err = json.Unmarshal(rr.RawParams[0].RawMessage, val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Bad event received.
|
// Bad event received.
|
||||||
|
connCloseErr = fmt.Errorf("failed to unmarshal event of type %s from JSON: %w", event, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -190,18 +208,24 @@ readloop:
|
||||||
resp.Result = rr.Result
|
resp.Result = rr.Result
|
||||||
id, err := strconv.Atoi(string(resp.ID))
|
id, err := strconv.Atoi(string(resp.ID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
connCloseErr = fmt.Errorf("failed to retrieve response ID from string %s: %w", string(resp.ID), err)
|
||||||
break // Malformed response (invalid response ID).
|
break // Malformed response (invalid response ID).
|
||||||
}
|
}
|
||||||
ch := c.getResponseChannel(uint64(id))
|
ch := c.getResponseChannel(uint64(id))
|
||||||
if ch == nil {
|
if ch == nil {
|
||||||
|
connCloseErr = fmt.Errorf("unknown response channel for response %d", id)
|
||||||
break // Unknown response (unexpected response ID).
|
break // Unknown response (unexpected response ID).
|
||||||
}
|
}
|
||||||
ch <- resp
|
ch <- resp
|
||||||
} else {
|
} else {
|
||||||
// Malformed response, neither valid request, nor valid response.
|
// Malformed response, neither valid request, nor valid response.
|
||||||
|
connCloseErr = fmt.Errorf("malformed response")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if connCloseErr != nil {
|
||||||
|
c.setCloseErr(connCloseErr)
|
||||||
|
}
|
||||||
close(c.done)
|
close(c.done)
|
||||||
c.respLock.Lock()
|
c.respLock.Lock()
|
||||||
for _, ch := range c.respChannels {
|
for _, ch := range c.respChannels {
|
||||||
|
@ -216,6 +240,8 @@ func (c *WSClient) wsWriter() {
|
||||||
pingTicker := time.NewTicker(wsPingPeriod)
|
pingTicker := time.NewTicker(wsPingPeriod)
|
||||||
defer c.ws.Close()
|
defer c.ws.Close()
|
||||||
defer pingTicker.Stop()
|
defer pingTicker.Stop()
|
||||||
|
var connCloseErr error
|
||||||
|
writeloop:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-c.shutdown:
|
case <-c.shutdown:
|
||||||
|
@ -227,20 +253,27 @@ func (c *WSClient) wsWriter() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout)); err != nil {
|
if err := c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout)); err != nil {
|
||||||
return
|
connCloseErr = fmt.Errorf("failed to set request write deadline: %w", err)
|
||||||
|
break writeloop
|
||||||
}
|
}
|
||||||
if err := c.ws.WriteJSON(req); err != nil {
|
if err := c.ws.WriteJSON(req); err != nil {
|
||||||
return
|
connCloseErr = fmt.Errorf("failed to write JSON request: %w", err)
|
||||||
|
break writeloop
|
||||||
}
|
}
|
||||||
case <-pingTicker.C:
|
case <-pingTicker.C:
|
||||||
if err := c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)); err != nil {
|
if err := c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)); err != nil {
|
||||||
return
|
connCloseErr = fmt.Errorf("failed to set ping write deadline: %w", err)
|
||||||
|
break writeloop
|
||||||
}
|
}
|
||||||
if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
||||||
return
|
connCloseErr = fmt.Errorf("failed to write ping message: %w", err)
|
||||||
|
break writeloop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if connCloseErr != nil {
|
||||||
|
c.setCloseErr(connCloseErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WSClient) unregisterRespChannel(id uint64) {
|
func (c *WSClient) unregisterRespChannel(id uint64) {
|
||||||
|
@ -399,3 +432,21 @@ func (c *WSClient) UnsubscribeAll() error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setCloseErr is a thread-safe method setting closeErr in case if it's not yet set.
|
||||||
|
func (c *WSClient) setCloseErr(err error) {
|
||||||
|
c.closeErrLock.Lock()
|
||||||
|
defer c.closeErrLock.Unlock()
|
||||||
|
|
||||||
|
if c.closeErr == nil {
|
||||||
|
c.closeErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetError returns the reason of WS connection closing.
|
||||||
|
func (c *WSClient) GetError() error {
|
||||||
|
c.closeErrLock.RLock()
|
||||||
|
defer c.closeErrLock.RUnlock()
|
||||||
|
|
||||||
|
return c.closeErr
|
||||||
|
}
|
||||||
|
|
|
@ -15,6 +15,8 @@ import (
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
|
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -468,3 +470,36 @@ func TestWS_RequestAfterClose(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.True(t, strings.Contains(err.Error(), "connection lost before registering response channel"))
|
require.True(t, strings.Contains(err.Error(), "connection lost before registering response channel"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWSClient_ConnClosedError(t *testing.T) {
|
||||||
|
srv := initTestServer(t, "")
|
||||||
|
|
||||||
|
t.Run("standard closing", func(t *testing.T) {
|
||||||
|
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
c.Close()
|
||||||
|
|
||||||
|
err = c.GetError()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, strings.Contains(err.Error(), "use of closed network connection"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("malformed request", func(t *testing.T) {
|
||||||
|
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defaultMaxBlockSize := 262144
|
||||||
|
_, err = c.SubmitP2PNotaryRequest(&payload.P2PNotaryRequest{
|
||||||
|
MainTransaction: &transaction.Transaction{
|
||||||
|
Script: make([]byte, defaultMaxBlockSize*3),
|
||||||
|
},
|
||||||
|
FallbackTransaction: &transaction.Transaction{},
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
err = c.GetError()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, strings.Contains(err.Error(), "failed to read JSON response (timeout/connection loss/malformed response)"), err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue