rpc: add test for concurrent WSClient access and request IDs generation
This commit is contained in:
parent
0d8723527c
commit
2896c0a83a
1 changed files with 97 additions and 0 deletions
|
@ -6,7 +6,10 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -15,6 +18,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
func TestWSClientClose(t *testing.T) {
|
||||
|
@ -344,3 +348,96 @@ func TestNewWS(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWSConcurrentAccess(t *testing.T) {
|
||||
var ids struct {
|
||||
lock sync.RWMutex
|
||||
m map[int]struct{}
|
||||
}
|
||||
ids.m = make(map[int]struct{})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path == "/ws" && req.Method == "GET" {
|
||||
var upgrader = websocket.Upgrader{}
|
||||
ws, err := upgrader.Upgrade(w, req, nil)
|
||||
require.NoError(t, err)
|
||||
for {
|
||||
err = ws.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
require.NoError(t, err)
|
||||
_, p, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
r := request.NewIn()
|
||||
err = json.Unmarshal(p, r)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot decode request body: %s", req.Body)
|
||||
}
|
||||
i, err := strconv.Atoi(string(r.RawID))
|
||||
require.NoError(t, err)
|
||||
ids.lock.Lock()
|
||||
ids.m[i] = struct{}{}
|
||||
ids.lock.Unlock()
|
||||
var response string
|
||||
// Different responses to catch possible unmarshalling errors connected with invalid IDs distribution.
|
||||
switch r.Method {
|
||||
case "getblockcount":
|
||||
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":123}`, r.RawID)
|
||||
case "getversion":
|
||||
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":{"network":42,"tcpport":20332,"wsport":20342,"nonce":2153672787,"useragent":"/NEO-GO:0.73.1-pre-273-ge381358/"}}`, r.RawID)
|
||||
case "getblockhash":
|
||||
response = fmt.Sprintf(`{"id":%s,"jsonrpc":"2.0","result":"0x157ca5e5b8cf8f84c9660502a3270b346011612bded1514a6847f877c433a9bb"}`, r.RawID)
|
||||
}
|
||||
err = ws.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||
require.NoError(t, err)
|
||||
err = ws.WriteMessage(1, []byte(response))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
ws.Close()
|
||||
return
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||
require.NoError(t, err)
|
||||
batchCount := 100
|
||||
completed := atomic.NewInt32(0)
|
||||
for i := 0; i < batchCount; i++ {
|
||||
go func() {
|
||||
_, err := wsc.GetBlockCount()
|
||||
require.NoError(t, err)
|
||||
completed.Inc()
|
||||
}()
|
||||
go func() {
|
||||
_, err := wsc.GetBlockHash(123)
|
||||
require.NoError(t, err)
|
||||
completed.Inc()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := wsc.GetVersion()
|
||||
require.NoError(t, err)
|
||||
completed.Inc()
|
||||
}()
|
||||
}
|
||||
require.Eventually(t, func() bool {
|
||||
return int(completed.Load()) == batchCount*3
|
||||
}, time.Second, 100*time.Millisecond)
|
||||
|
||||
ids.lock.RLock()
|
||||
require.True(t, len(ids.m) > batchCount)
|
||||
idsList := make([]int, 0, len(ids.m))
|
||||
for i := range ids.m {
|
||||
idsList = append(idsList, i)
|
||||
}
|
||||
ids.lock.RUnlock()
|
||||
|
||||
sort.Ints(idsList)
|
||||
require.Equal(t, 1, idsList[0])
|
||||
require.Less(t, idsList[len(idsList)-1],
|
||||
batchCount*3+1) // batchCount*requestsPerBatch+1
|
||||
wsc.Close()
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue