From 2896c0a83a66b896bf66f10c2f7204c871ea9650 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Tue, 22 Feb 2022 16:18:52 +0300 Subject: [PATCH] rpc: add test for concurrent WSClient access and request IDs generation --- pkg/rpc/client/wsclient_test.go | 97 +++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 675cee021..0945aff77 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -6,7 +6,10 @@ import ( "fmt" "net/http" "net/http/httptest" + "sort" + "strconv" "strings" + "sync" "testing" "time" @@ -15,6 +18,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/require" + "go.uber.org/atomic" ) func TestWSClientClose(t *testing.T) { @@ -344,3 +348,96 @@ func TestNewWS(t *testing.T) { require.Error(t, err) }) } + +func TestWSConcurrentAccess(t *testing.T) { + var ids struct { + lock sync.RWMutex + m map[int]struct{} + } + ids.m = make(map[int]struct{}) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/ws" && req.Method == "GET" { + var upgrader = websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + for { + err = ws.SetReadDeadline(time.Now().Add(2 * time.Second)) + require.NoError(t, err) + _, p, err := ws.ReadMessage() + if err != nil { + break + } + r := request.NewIn() + err = json.Unmarshal(p, r) + if err != nil { + t.Fatalf("Cannot decode request body: %s", req.Body) + } + i, err := strconv.Atoi(string(r.RawID)) + require.NoError(t, err) + ids.lock.Lock() + ids.m[i] = struct{}{} + ids.lock.Unlock() + var response string + // Different responses to catch possible unmarshalling errors connected with invalid IDs distribution. + switch r.Method { + case "getblockcount": + response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":123}`, r.RawID) + case "getversion": + response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":{"network":42,"tcpport":20332,"wsport":20342,"nonce":2153672787,"useragent":"/NEO-GO:0.73.1-pre-273-ge381358/"}}`, r.RawID) + case "getblockhash": + response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":"0x157ca5e5b8cf8f84c9660502a3270b346011612bded1514a6847f877c433a9bb"}`, r.RawID) + } + err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) + require.NoError(t, err) + err = ws.WriteMessage(1, []byte(response)) + if err != nil { + break + } + } + ws.Close() + return + } + })) + t.Cleanup(srv.Close) + + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + batchCount := 100 + completed := atomic.NewInt32(0) + for i := 0; i < batchCount; i++ { + go func() { + _, err := wsc.GetBlockCount() + require.NoError(t, err) + completed.Inc() + }() + go func() { + _, err := wsc.GetBlockHash(123) + require.NoError(t, err) + completed.Inc() + }() + + go func() { + _, err := wsc.GetVersion() + require.NoError(t, err) + completed.Inc() + }() + } + require.Eventually(t, func() bool { + return int(completed.Load()) == batchCount*3 + }, time.Second, 100*time.Millisecond) + + ids.lock.RLock() + require.True(t, len(ids.m) > batchCount) + idsList := make([]int, 0, len(ids.m)) + for i := range ids.m { + idsList = append(idsList, i) + } + ids.lock.RUnlock() + + sort.Ints(idsList) + require.Equal(t, 1, idsList[0]) + require.Less(t, idsList[len(idsList)-1], + batchCount*3+1) // batchCount*requestsPerBatch+1 + wsc.Close() +}