From 4a49bf5de4f94f64c21363024064989f6f29f136 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Tue, 25 Apr 2023 11:26:36 +0300 Subject: [PATCH] rpcclient: introduce WSOptions for WSClient Make a separate structure for WSClient configuration. Signed-off-by: Anna Shaleva --- pkg/rpcclient/rpc_test.go | 2 +- pkg/rpcclient/wsclient.go | 13 +++++++++++-- pkg/rpcclient/wsclient_test.go | 28 ++++++++++++++-------------- pkg/services/rpcsrv/client_test.go | 4 ++-- 4 files changed, 28 insertions(+), 19 deletions(-) diff --git a/pkg/rpcclient/rpc_test.go b/pkg/rpcclient/rpc_test.go index 3cf3b9f9b..1e032c7fc 100644 --- a/pkg/rpcclient/rpc_test.go +++ b/pkg/rpcclient/rpc_test.go @@ -1803,7 +1803,7 @@ func TestRPCClients(t *testing.T) { }) t.Run("WSClient", func(t *testing.T) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { - wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts) + wsc, err := NewWS(ctx, httpURLtoWS(endpoint), WSOptions{opts}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) diff --git a/pkg/rpcclient/wsclient.go b/pkg/rpcclient/wsclient.go index 49443f408..f6bb034e1 100644 --- a/pkg/rpcclient/wsclient.go +++ b/pkg/rpcclient/wsclient.go @@ -66,6 +66,7 @@ type WSClient struct { Notifications chan Notification ws *websocket.Conn + wsOpts WSOptions done chan struct{} requests chan *neorpc.Request shutdown chan struct{} @@ -86,6 +87,13 @@ type WSClient struct { respChannels map[uint64]chan *neorpc.Response } +// WSOptions defines options for the web-socket RPC client. It contains a +// set of options for the underlying standard RPC client as far as +// WSClient-specific options. See Options documentation for more details. +type WSOptions struct { + Options +} + // notificationReceiver is an interface aimed to provide WS subscriber functionality // for different types of subscriptions. type notificationReceiver interface { @@ -382,7 +390,7 @@ var errConnClosedByUser = errors.New("connection closed by user") // connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`. // You should call Init method to initialize the network magic the client is // operating on. -func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) { +func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, error) { dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout} ws, resp, err := dialer.DialContext(ctx, endpoint, nil) if resp != nil && resp.Body != nil { // Can be non-nil even with error returned. @@ -405,6 +413,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error Notifications: make(chan Notification), ws: ws, + wsOpts: opts, shutdown: make(chan struct{}), done: make(chan struct{}), closeCalled: *atomic.NewBool(false), @@ -414,7 +423,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error receivers: make(map[any][]string), } - err = initClient(ctx, &wsc.Client, endpoint, opts) + err = initClient(ctx, &wsc.Client, endpoint, opts.Options) if err != nil { return nil, err } diff --git a/pkg/rpcclient/wsclient_test.go b/pkg/rpcclient/wsclient_test.go index c3f30c894..b72d2ec6b 100644 --- a/pkg/rpcclient/wsclient_test.go +++ b/pkg/rpcclient/wsclient_test.go @@ -31,7 +31,7 @@ import ( func TestWSClientClose(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID bCh := make(chan *block.Block) @@ -70,7 +70,7 @@ func TestWSClientSubscription(t *testing.T) { for name, f := range cases { t.Run(name, func(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) @@ -84,7 +84,7 @@ func TestWSClientSubscription(t *testing.T) { for name, f := range cases { t.Run(name, func(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) @@ -134,7 +134,7 @@ func TestWSClientUnsubscription(t *testing.T) { for name, rc := range cases { t.Run(name, func(t *testing.T) { srv := initTestServer(t, rc.response) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) @@ -170,7 +170,7 @@ func TestWSClientEvents(t *testing.T) { return } })) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID wsc.cacheLock.Lock() @@ -292,7 +292,7 @@ func TestWSClientEvents(t *testing.T) { func TestWSExecutionVMStateCheck(t *testing.T) { // Will answer successfully if request slips through. srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) @@ -527,7 +527,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { ws.Close() } })) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) wsc.getNextRequestID = getTestRequestID wsc.cache.network = netmode.UnitTestNet @@ -541,14 +541,14 @@ func TestNewWS(t *testing.T) { srv := initTestServer(t, "") t.Run("good", func(t *testing.T) { - c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) c.getNextRequestID = getTestRequestID c.cache.network = netmode.UnitTestNet require.NoError(t, c.Init()) }) t.Run("bad URL", func(t *testing.T) { - _, err := NewWS(context.TODO(), strings.TrimPrefix(srv.URL, "http://"), Options{}) + _, err := NewWS(context.TODO(), strings.TrimPrefix(srv.URL, "http://"), WSOptions{}) require.Error(t, err) }) } @@ -605,7 +605,7 @@ func TestWSConcurrentAccess(t *testing.T) { })) t.Cleanup(srv.Close) - wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) batchCount := 100 completed := atomic.NewInt32(0) @@ -649,7 +649,7 @@ func TestWSConcurrentAccess(t *testing.T) { func TestWSDoubleClose(t *testing.T) { srv := initTestServer(t, "") - c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) require.NotPanics(t, func() { @@ -661,7 +661,7 @@ func TestWSDoubleClose(t *testing.T) { func TestWS_RequestAfterClose(t *testing.T) { srv := initTestServer(t, "") - c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) c.Close() @@ -676,7 +676,7 @@ func TestWS_RequestAfterClose(t *testing.T) { func TestWSClient_ConnClosedError(t *testing.T) { t.Run("standard closing", func(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": 123}`) - c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) // Check client is working. @@ -692,7 +692,7 @@ func TestWSClient_ConnClosedError(t *testing.T) { t.Run("malformed request", func(t *testing.T) { srv := initTestServer(t, "") - c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{}) require.NoError(t, err) defaultMaxBlockSize := 262144 diff --git a/pkg/services/rpcsrv/client_test.go b/pkg/services/rpcsrv/client_test.go index 55b756c42..cb65ed676 100644 --- a/pkg/services/rpcsrv/client_test.go +++ b/pkg/services/rpcsrv/client_test.go @@ -2074,7 +2074,7 @@ func mkSubsClient(t *testing.T, rpcSrv *Server, httpSrv *httptest.Server, local icl, err = rpcclient.NewInternal(context.Background(), rpcSrv.RegisterLocal) } else { url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" - c, err = rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) + c, err = rpcclient.NewWS(context.Background(), url, rpcclient.WSOptions{}) } require.NoError(t, err) if local { @@ -2240,7 +2240,7 @@ func TestWSClientHandshakeError(t *testing.T) { defer rpcSrv.Shutdown() url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" - _, err := rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) + _, err := rpcclient.NewWS(context.Background(), url, rpcclient.WSOptions{}) require.ErrorContains(t, err, "websocket users limit reached") }