rpc: add test for concurrent WSClient access and request IDs generation

This commit is contained in:
Anna Shaleva 2022-02-22 16:18:52 +03:00 committed by AnnaShaleva
parent 0d8723527c
commit 2896c0a83a

View file

@ -6,7 +6,10 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sort"
"strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -15,6 +18,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/request"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/atomic"
) )
func TestWSClientClose(t *testing.T) { func TestWSClientClose(t *testing.T) {
@ -344,3 +348,96 @@ func TestNewWS(t *testing.T) {
require.Error(t, err) require.Error(t, err)
}) })
} }
func TestWSConcurrentAccess(t *testing.T) {
var ids struct {
lock sync.RWMutex
m map[int]struct{}
}
ids.m = make(map[int]struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/ws" && req.Method == "GET" {
var upgrader = websocket.Upgrader{}
ws, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err)
for {
err = ws.SetReadDeadline(time.Now().Add(2 * time.Second))
require.NoError(t, err)
_, p, err := ws.ReadMessage()
if err != nil {
break
}
r := request.NewIn()
err = json.Unmarshal(p, r)
if err != nil {
t.Fatalf("Cannot decode request body: %s", req.Body)
}
i, err := strconv.Atoi(string(r.RawID))
require.NoError(t, err)
ids.lock.Lock()
ids.m[i] = struct{}{}
ids.lock.Unlock()
var response string
// Different responses to catch possible unmarshalling errors connected with invalid IDs distribution.
switch r.Method {
case "getblockcount":
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":123}`, r.RawID)
case "getversion":
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":{"network":42,"tcpport":20332,"wsport":20342,"nonce":2153672787,"useragent":"/NEO-GO:0.73.1-pre-273-ge381358/"}}`, r.RawID)
case "getblockhash":
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":"0x157ca5e5b8cf8f84c9660502a3270b346011612bded1514a6847f877c433a9bb"}`, r.RawID)
}
err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second))
require.NoError(t, err)
err = ws.WriteMessage(1, []byte(response))
if err != nil {
break
}
}
ws.Close()
return
}
}))
t.Cleanup(srv.Close)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
batchCount := 100
completed := atomic.NewInt32(0)
for i := 0; i < batchCount; i++ {
go func() {
_, err := wsc.GetBlockCount()
require.NoError(t, err)
completed.Inc()
}()
go func() {
_, err := wsc.GetBlockHash(123)
require.NoError(t, err)
completed.Inc()
}()
go func() {
_, err := wsc.GetVersion()
require.NoError(t, err)
completed.Inc()
}()
}
require.Eventually(t, func() bool {
return int(completed.Load()) == batchCount*3
}, time.Second, 100*time.Millisecond)
ids.lock.RLock()
require.True(t, len(ids.m) > batchCount)
idsList := make([]int, 0, len(ids.m))
for i := range ids.m {
idsList = append(idsList, i)
}
ids.lock.RUnlock()
sort.Ints(idsList)
require.Equal(t, 1, idsList[0])
require.Less(t, idsList[len(idsList)-1],
batchCount*3+1) // batchCount*requestsPerBatch+1
wsc.Close()
}