Merge pull request #3392 from nspcc-dev/adjust-deadlines

*: adjust WS connection RW deadlines
This commit is contained in:
Roman Khimov 2024-04-03 14:50:14 +03:00 committed by GitHub
commit be1b97d04e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 53 additions and 31 deletions

View file

@ -1938,7 +1938,7 @@ func initTestServer(t *testing.T, resp string) *httptest.Server {
ws, err := upgrader.Upgrade(w, req, nil) ws, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err) require.NoError(t, err)
for { for {
err = ws.SetReadDeadline(time.Now().Add(2 * time.Second)) err = ws.SetReadDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
_, p, err := ws.ReadMessage() _, p, err := ws.ReadMessage()
if err != nil { if err != nil {
@ -1950,7 +1950,7 @@ func initTestServer(t *testing.T, resp string) *httptest.Server {
t.Fatalf("Cannot decode request body: %s", req.Body) t.Fatalf("Cannot decode request body: %s", req.Body)
} }
response := wrapInitResponse(r, resp) response := wrapInitResponse(r, resp)
err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) err = ws.SetWriteDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
err = ws.WriteMessage(1, []byte(response)) err = ws.WriteMessage(1, []byte(response))
if err != nil { if err != nil {

View file

@ -163,7 +163,7 @@ func TestWSClientEvents(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
<-startSending <-startSending
for _, event := range events { for _, event := range events {
err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) err = ws.SetWriteDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
err = ws.WriteMessage(1, []byte(event)) err = ws.WriteMessage(1, []byte(event))
if err != nil { if err != nil {
@ -308,7 +308,7 @@ func TestWSClientNonBlockingEvents(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
<-startSending <-startSending
for _, event := range events { for _, event := range events {
err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) err = ws.SetWriteDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
err = ws.WriteMessage(1, []byte(event)) err = ws.WriteMessage(1, []byte(event))
if err != nil { if err != nil {
@ -738,14 +738,14 @@ func TestWSFilteredSubscriptions(t *testing.T) {
var upgrader = websocket.Upgrader{} var upgrader = websocket.Upgrader{}
ws, err := upgrader.Upgrade(w, req, nil) ws, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err) require.NoError(t, err)
err = ws.SetReadDeadline(time.Now().Add(2 * time.Second)) err = ws.SetReadDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
req := params.In{} req := params.In{}
err = ws.ReadJSON(&req) err = ws.ReadJSON(&req)
require.NoError(t, err) require.NoError(t, err)
params := params.Params(req.RawParams) params := params.Params(req.RawParams)
c.serverCode(t, &params) c.serverCode(t, &params)
err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) err = ws.SetWriteDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
err = ws.WriteMessage(1, []byte(`{"jsonrpc": "2.0", "id": 1, "result": "0"}`)) err = ws.WriteMessage(1, []byte(`{"jsonrpc": "2.0", "id": 1, "result": "0"}`))
require.NoError(t, err) require.NoError(t, err)
@ -793,7 +793,7 @@ func TestWSConcurrentAccess(t *testing.T) {
ws, err := upgrader.Upgrade(w, req, nil) ws, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err) require.NoError(t, err)
for { for {
err = ws.SetReadDeadline(time.Now().Add(2 * time.Second)) err = ws.SetReadDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
_, p, err := ws.ReadMessage() _, p, err := ws.ReadMessage()
if err != nil { if err != nil {
@ -819,7 +819,7 @@ func TestWSConcurrentAccess(t *testing.T) {
case "getblockhash": case "getblockhash":
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":"0x157ca5e5b8cf8f84c9660502a3270b346011612bded1514a6847f877c433a9bb"}`, r.RawID) response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":"0x157ca5e5b8cf8f84c9660502a3270b346011612bded1514a6847f877c433a9bb"}`, r.RawID)
} }
err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) err = ws.SetWriteDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
err = ws.WriteMessage(1, []byte(response)) err = ws.WriteMessage(1, []byte(response))
if err != nil { if err != nil {

View file

@ -3477,10 +3477,10 @@ func doRPCCallOverWS(rpcCall string, url string, t *testing.T) []byte {
c, r, err := dialer.Dial(url+"/ws", nil) c, r, err := dialer.Dial(url+"/ws", nil)
require.NoError(t, err) require.NoError(t, err)
defer r.Body.Close() defer r.Body.Close()
err = c.SetWriteDeadline(time.Now().Add(time.Second)) err = c.SetWriteDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, c.WriteMessage(1, []byte(rpcCall))) require.NoError(t, c.WriteMessage(1, []byte(rpcCall)))
err = c.SetReadDeadline(time.Now().Add(time.Second)) err = c.SetReadDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
_, body, err := c.ReadMessage() _, body, err := c.ReadMessage()
require.NoError(t, err) require.NoError(t, err)

View file

@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -20,21 +19,36 @@ import (
const testOverflow = false const testOverflow = false
func wsReader(t *testing.T, ws *websocket.Conn, msgCh chan<- []byte, isFinished *atomic.Bool, readerToExitCh chan struct{}) { func wsReader(t *testing.T, ws *websocket.Conn, msgCh chan<- []byte, readerStopCh chan struct{}, readerToExitCh chan struct{}) {
for !isFinished.Load() { readLoop:
err := ws.SetReadDeadline(time.Now().Add(time.Second)) for {
if isFinished.Load() { select {
require.Error(t, err) case <-readerStopCh:
break break readLoop
default:
err := ws.SetReadDeadline(time.Now().Add(5 * time.Second))
select {
case <-readerStopCh:
break readLoop
default:
require.NoError(t, err)
}
_, body, err := ws.ReadMessage()
select {
case <-readerStopCh:
break readLoop
default:
require.NoError(t, err)
}
select {
case msgCh <- body:
case <-time.After(10 * time.Second):
t.Log("exiting wsReader loop: unable to send response to receiver")
break readLoop
}
} }
require.NoError(t, err)
_, body, err := ws.ReadMessage()
if isFinished.Load() {
require.Error(t, err)
break
}
require.NoError(t, err)
msgCh <- body
} }
close(readerToExitCh) close(readerToExitCh)
} }
@ -42,7 +56,7 @@ func wsReader(t *testing.T, ws *websocket.Conn, msgCh chan<- []byte, isFinished
func callWSGetRaw(t *testing.T, ws *websocket.Conn, msg string, respCh <-chan []byte) *neorpc.Response { func callWSGetRaw(t *testing.T, ws *websocket.Conn, msg string, respCh <-chan []byte) *neorpc.Response {
var resp = new(neorpc.Response) var resp = new(neorpc.Response)
require.NoError(t, ws.SetWriteDeadline(time.Now().Add(time.Second))) require.NoError(t, ws.SetWriteDeadline(time.Now().Add(5*time.Second)))
require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(msg))) require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(msg)))
body := <-respCh body := <-respCh
@ -69,14 +83,22 @@ func initCleanServerAndWSClient(t *testing.T, startNetworkServer ...bool) (*core
// Use buffered channel to read server's messages and then read expected // Use buffered channel to read server's messages and then read expected
// responses from it. // responses from it.
respMsgs := make(chan []byte, 16) respMsgs := make(chan []byte, 16)
finishedFlag := &atomic.Bool{} readerStopCh := make(chan struct{})
readerToExitCh := make(chan struct{}) readerToExitCh := make(chan struct{})
go wsReader(t, ws, respMsgs, finishedFlag, readerToExitCh) go wsReader(t, ws, respMsgs, readerStopCh, readerToExitCh)
if len(startNetworkServer) != 0 && startNetworkServer[0] { if len(startNetworkServer) != 0 && startNetworkServer[0] {
rpcSrv.coreServer.Start() rpcSrv.coreServer.Start()
} }
t.Cleanup(func() { t.Cleanup(func() {
finishedFlag.Store(true) drainLoop:
for {
select {
case <-respMsgs:
default:
break drainLoop
}
}
close(readerStopCh)
<-readerToExitCh <-readerToExitCh
ws.Close() ws.Close()
if len(startNetworkServer) != 0 && startNetworkServer[0] { if len(startNetworkServer) != 0 && startNetworkServer[0] {
@ -556,11 +578,11 @@ func TestBadSubUnsub(t *testing.T) {
} }
func doSomeWSRequest(t *testing.T, ws *websocket.Conn) { func doSomeWSRequest(t *testing.T, ws *websocket.Conn) {
require.NoError(t, ws.SetWriteDeadline(time.Now().Add(time.Second))) require.NoError(t, ws.SetWriteDeadline(time.Now().Add(5*time.Second)))
// It could be just about anything including invalid request, // It could be just about anything including invalid request,
// we only care about server handling being active. // we only care about server handling being active.
require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(`{"jsonrpc": "2.0", "method": "getversion", "params": [], "id": 1}`))) require.NoError(t, ws.WriteMessage(websocket.TextMessage, []byte(`{"jsonrpc": "2.0", "method": "getversion", "params": [], "id": 1}`)))
err := ws.SetReadDeadline(time.Now().Add(time.Second)) err := ws.SetReadDeadline(time.Now().Add(5 * time.Second))
require.NoError(t, err) require.NoError(t, err)
_, _, err = ws.ReadMessage() _, _, err = ws.ReadMessage()
require.NoError(t, err) require.NoError(t, err)