diff --git a/pkg/services/rpcsrv/subscription_test.go b/pkg/services/rpcsrv/subscription_test.go index 48251a3f7..b37780688 100644 --- a/pkg/services/rpcsrv/subscription_test.go +++ b/pkg/services/rpcsrv/subscription_test.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "strings" - "sync/atomic" "testing" "time" @@ -20,21 +19,36 @@ import ( const testOverflow = false -func wsReader(t *testing.T, ws *websocket.Conn, msgCh chan<- []byte, isFinished *atomic.Bool, readerToExitCh chan struct{}) { - for !isFinished.Load() { - err := ws.SetReadDeadline(time.Now().Add(5 * time.Second)) - if isFinished.Load() { - require.Error(t, err) - break +func wsReader(t *testing.T, ws *websocket.Conn, msgCh chan<- []byte, readerStopCh chan struct{}, readerToExitCh chan struct{}) { +readLoop: + for { + select { + case <-readerStopCh: + break readLoop + default: + err := ws.SetReadDeadline(time.Now().Add(5 * time.Second)) + select { + case <-readerStopCh: + break readLoop + default: + require.NoError(t, err) + } + + _, body, err := ws.ReadMessage() + select { + case <-readerStopCh: + break readLoop + default: + require.NoError(t, err) + } + + select { + case msgCh <- body: + case <-time.After(10 * time.Second): + t.Log("exiting wsReader loop: unable to send response to receiver") + break readLoop + } } - require.NoError(t, err) - _, body, err := ws.ReadMessage() - if isFinished.Load() { - require.Error(t, err) - break - } - require.NoError(t, err) - msgCh <- body } close(readerToExitCh) } @@ -69,14 +83,22 @@ func initCleanServerAndWSClient(t *testing.T, startNetworkServer ...bool) (*core // Use buffered channel to read server's messages and then read expected // responses from it. respMsgs := make(chan []byte, 16) - finishedFlag := &atomic.Bool{} + readerStopCh := make(chan struct{}) readerToExitCh := make(chan struct{}) - go wsReader(t, ws, respMsgs, finishedFlag, readerToExitCh) + go wsReader(t, ws, respMsgs, readerStopCh, readerToExitCh) if len(startNetworkServer) != 0 && startNetworkServer[0] { rpcSrv.coreServer.Start() } t.Cleanup(func() { - finishedFlag.Store(true) + drainLoop: + for { + select { + case <-respMsgs: + default: + break drainLoop + } + } + close(readerStopCh) <-readerToExitCh ws.Close() if len(startNetworkServer) != 0 && startNetworkServer[0] {