rpcclient: introduce WSOptions for WSClient

Make a separate structure for WSClient configuration.

Signed-off-by: Anna Shaleva <shaleva.ann@nspcc.ru>
This commit is contained in:
Anna Shaleva 2023-04-25 11:26:36 +03:00
parent dd8218f87a
commit 4a49bf5de4
4 changed files with 28 additions and 19 deletions

View file

@ -1803,7 +1803,7 @@ func TestRPCClients(t *testing.T) {
}) })
t.Run("WSClient", func(t *testing.T) { t.Run("WSClient", func(t *testing.T) {
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts) wsc, err := NewWS(ctx, httpURLtoWS(endpoint), WSOptions{opts})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())

View file

@ -66,6 +66,7 @@ type WSClient struct {
Notifications chan Notification Notifications chan Notification
ws *websocket.Conn ws *websocket.Conn
wsOpts WSOptions
done chan struct{} done chan struct{}
requests chan *neorpc.Request requests chan *neorpc.Request
shutdown chan struct{} shutdown chan struct{}
@ -86,6 +87,13 @@ type WSClient struct {
respChannels map[uint64]chan *neorpc.Response respChannels map[uint64]chan *neorpc.Response
} }
// WSOptions defines options for the web-socket RPC client. It contains a
// set of options for the underlying standard RPC client as far as
// WSClient-specific options. See Options documentation for more details.
type WSOptions struct {
Options
}
// notificationReceiver is an interface aimed to provide WS subscriber functionality // notificationReceiver is an interface aimed to provide WS subscriber functionality
// for different types of subscriptions. // for different types of subscriptions.
type notificationReceiver interface { type notificationReceiver interface {
@ -382,7 +390,7 @@ var errConnClosedByUser = errors.New("connection closed by user")
// connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`. // connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`.
// You should call Init method to initialize the network magic the client is // You should call Init method to initialize the network magic the client is
// operating on. // operating on.
func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) { func NewWS(ctx context.Context, endpoint string, opts WSOptions) (*WSClient, error) {
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout} dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
ws, resp, err := dialer.DialContext(ctx, endpoint, nil) ws, resp, err := dialer.DialContext(ctx, endpoint, nil)
if resp != nil && resp.Body != nil { // Can be non-nil even with error returned. if resp != nil && resp.Body != nil { // Can be non-nil even with error returned.
@ -405,6 +413,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error
Notifications: make(chan Notification), Notifications: make(chan Notification),
ws: ws, ws: ws,
wsOpts: opts,
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
done: make(chan struct{}), done: make(chan struct{}),
closeCalled: *atomic.NewBool(false), closeCalled: *atomic.NewBool(false),
@ -414,7 +423,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error
receivers: make(map[any][]string), receivers: make(map[any][]string),
} }
err = initClient(ctx, &wsc.Client, endpoint, opts) err = initClient(ctx, &wsc.Client, endpoint, opts.Options)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -31,7 +31,7 @@ import (
func TestWSClientClose(t *testing.T) { func TestWSClientClose(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID wsc.getNextRequestID = getTestRequestID
bCh := make(chan *block.Block) bCh := make(chan *block.Block)
@ -70,7 +70,7 @@ func TestWSClientSubscription(t *testing.T) {
for name, f := range cases { for name, f := range cases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
@ -84,7 +84,7 @@ func TestWSClientSubscription(t *testing.T) {
for name, f := range cases { for name, f := range cases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
@ -134,7 +134,7 @@ func TestWSClientUnsubscription(t *testing.T) {
for name, rc := range cases { for name, rc := range cases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
srv := initTestServer(t, rc.response) srv := initTestServer(t, rc.response)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
@ -170,7 +170,7 @@ func TestWSClientEvents(t *testing.T) {
return return
} }
})) }))
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID wsc.getNextRequestID = getTestRequestID
wsc.cacheLock.Lock() wsc.cacheLock.Lock()
@ -292,7 +292,7 @@ func TestWSClientEvents(t *testing.T) {
func TestWSExecutionVMStateCheck(t *testing.T) { func TestWSExecutionVMStateCheck(t *testing.T) {
// Will answer successfully if request slips through. // Will answer successfully if request slips through.
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
@ -527,7 +527,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
ws.Close() ws.Close()
} }
})) }))
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID wsc.getNextRequestID = getTestRequestID
wsc.cache.network = netmode.UnitTestNet wsc.cache.network = netmode.UnitTestNet
@ -541,14 +541,14 @@ func TestNewWS(t *testing.T) {
srv := initTestServer(t, "") srv := initTestServer(t, "")
t.Run("good", func(t *testing.T) { t.Run("good", func(t *testing.T) {
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
c.getNextRequestID = getTestRequestID c.getNextRequestID = getTestRequestID
c.cache.network = netmode.UnitTestNet c.cache.network = netmode.UnitTestNet
require.NoError(t, c.Init()) require.NoError(t, c.Init())
}) })
t.Run("bad URL", func(t *testing.T) { t.Run("bad URL", func(t *testing.T) {
_, err := NewWS(context.TODO(), strings.TrimPrefix(srv.URL, "http://"), Options{}) _, err := NewWS(context.TODO(), strings.TrimPrefix(srv.URL, "http://"), WSOptions{})
require.Error(t, err) require.Error(t, err)
}) })
} }
@ -605,7 +605,7 @@ func TestWSConcurrentAccess(t *testing.T) {
})) }))
t.Cleanup(srv.Close) t.Cleanup(srv.Close)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
batchCount := 100 batchCount := 100
completed := atomic.NewInt32(0) completed := atomic.NewInt32(0)
@ -649,7 +649,7 @@ func TestWSConcurrentAccess(t *testing.T) {
func TestWSDoubleClose(t *testing.T) { func TestWSDoubleClose(t *testing.T) {
srv := initTestServer(t, "") srv := initTestServer(t, "")
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
require.NotPanics(t, func() { require.NotPanics(t, func() {
@ -661,7 +661,7 @@ func TestWSDoubleClose(t *testing.T) {
func TestWS_RequestAfterClose(t *testing.T) { func TestWS_RequestAfterClose(t *testing.T) {
srv := initTestServer(t, "") srv := initTestServer(t, "")
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
c.Close() c.Close()
@ -676,7 +676,7 @@ func TestWS_RequestAfterClose(t *testing.T) {
func TestWSClient_ConnClosedError(t *testing.T) { func TestWSClient_ConnClosedError(t *testing.T) {
t.Run("standard closing", func(t *testing.T) { t.Run("standard closing", func(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": 123}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": 123}`)
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
// Check client is working. // Check client is working.
@ -692,7 +692,7 @@ func TestWSClient_ConnClosedError(t *testing.T) {
t.Run("malformed request", func(t *testing.T) { t.Run("malformed request", func(t *testing.T) {
srv := initTestServer(t, "") srv := initTestServer(t, "")
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), WSOptions{})
require.NoError(t, err) require.NoError(t, err)
defaultMaxBlockSize := 262144 defaultMaxBlockSize := 262144

View file

@ -2074,7 +2074,7 @@ func mkSubsClient(t *testing.T, rpcSrv *Server, httpSrv *httptest.Server, local
icl, err = rpcclient.NewInternal(context.Background(), rpcSrv.RegisterLocal) icl, err = rpcclient.NewInternal(context.Background(), rpcSrv.RegisterLocal)
} else { } else {
url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws"
c, err = rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) c, err = rpcclient.NewWS(context.Background(), url, rpcclient.WSOptions{})
} }
require.NoError(t, err) require.NoError(t, err)
if local { if local {
@ -2240,7 +2240,7 @@ func TestWSClientHandshakeError(t *testing.T) {
defer rpcSrv.Shutdown() defer rpcSrv.Shutdown()
url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" url := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws"
_, err := rpcclient.NewWS(context.Background(), url, rpcclient.Options{}) _, err := rpcclient.NewWS(context.Background(), url, rpcclient.WSOptions{})
require.ErrorContains(t, err, "websocket users limit reached") require.ErrorContains(t, err, "websocket users limit reached")
} }