From 8991ee91cd6e2f172bc901416d1eb0bd8773047f Mon Sep 17 00:00:00 2001 From: AnnaShaleva Date: Fri, 18 Feb 2022 20:28:13 +0300 Subject: [PATCH] rpc: make RPC WSClient thread-safe Add ability to use unique request IDs for RPC requests. --- pkg/rpc/client/client.go | 15 +++++++- pkg/rpc/client/rpc_test.go | 23 ++++++++---- pkg/rpc/client/wsclient.go | 63 +++++++++++++++++++++++++++------ pkg/rpc/client/wsclient_test.go | 8 +++++ pkg/rpc/request/types.go | 4 +-- 5 files changed, 93 insertions(+), 20 deletions(-) diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index 4dfc40ef4..20baae7ba 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -17,6 +17,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/response" "github.com/nspcc-dev/neo-go/pkg/util" + "go.uber.org/atomic" ) const ( @@ -40,6 +41,12 @@ type Client struct { // cache is mostly filled in during Init(), but can also be updated // during regular Client lifecycle. cache cache + + latestReqID *atomic.Uint64 + // getNextRequestID returns ID to be used for subsequent request creation. + // It is defined on Client so that our testing code can override this method + // for the sake of more predictable request IDs generation behaviour. + getNextRequestID func() uint64 } // Options defines options for the RPC client. @@ -110,12 +117,18 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { cache: cache{ nativeHashes: make(map[string]util.Uint160), }, + latestReqID: atomic.NewUint64(0), } + cl.getNextRequestID = (cl).getRequestID cl.opts = opts cl.requestF = cl.makeHTTPRequest return cl, nil } +func (c *Client) getRequestID() uint64 { + return c.latestReqID.Inc() +} + // Init sets magic of the network client connected to, stateRootInHeader option // and native NEO, GAS and Policy contracts scripthashes. This method should be // called before any header- or block-related requests in order to deserialize @@ -159,7 +172,7 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{ JSONRPC: request.JSONRPCVersion, Method: method, RawParams: p.Values, - ID: 1, + ID: c.getNextRequestID(), } raw, err := c.requestF(&r) diff --git a/pkg/rpc/client/rpc_test.go b/pkg/rpc/client/rpc_test.go index 8c099e810..1e58d2857 100644 --- a/pkg/rpc/client/rpc_test.go +++ b/pkg/rpc/client/rpc_test.go @@ -1705,6 +1705,7 @@ func TestRPCClients(t *testing.T) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { c, err := New(ctx, endpoint, opts) require.NoError(t, err) + c.getNextRequestID = getTestRequestID require.NoError(t, c.Init()) return c, nil }) @@ -1713,6 +1714,7 @@ func TestRPCClients(t *testing.T) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) return &wsc.Client, nil }) @@ -1732,6 +1734,7 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options if err != nil { t.Fatal(err) } + c.getNextRequestID = getTestRequestID actual, err := testCase.invoke(c) if testCase.fails { @@ -1755,14 +1758,14 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options endpoint := srv.URL opts := Options{} - c, err := newClient(context.TODO(), endpoint, opts) - if err != nil { - t.Fatal(err) - } - for _, testCase := range testBatch { t.Run(testCase.name, func(t *testing.T) { - _, err := testCase.invoke(c) + c, err := newClient(context.TODO(), endpoint, opts) + if err != nil { + t.Fatal(err) + } + c.getNextRequestID = getTestRequestID + _, err = testCase.invoke(c) assert.Error(t, err) }) } @@ -1878,6 +1881,7 @@ func TestCalculateValidUntilBlock(t *testing.T) { if err != nil { t.Fatal(err) } + c.getNextRequestID = getTestRequestID require.NoError(t, c.Init()) validUntilBlock, err := c.CalculateValidUntilBlock() @@ -1913,6 +1917,7 @@ func TestGetNetwork(t *testing.T) { if err != nil { t.Fatal(err) } + c.getNextRequestID = getTestRequestID // network was not initialised _, err = c.GetNetwork() require.True(t, errors.Is(err, errNetworkNotInitialized)) @@ -1924,6 +1929,7 @@ func TestGetNetwork(t *testing.T) { if err != nil { t.Fatal(err) } + c.getNextRequestID = getTestRequestID require.NoError(t, c.Init()) m, err := c.GetNetwork() require.NoError(t, err) @@ -1945,6 +1951,7 @@ func TestUninitedClient(t *testing.T) { c, err := New(context.TODO(), endpoint, opts) require.NoError(t, err) + c.getNextRequestID = getTestRequestID _, err = c.GetBlockByIndex(0) require.Error(t, err) @@ -1970,3 +1977,7 @@ func newTestNEF(script []byte) nef.File { ne.Checksum = ne.CalculateChecksum() return ne } + +func getTestRequestID() uint64 { + return 1 +} diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go index 6cb5e6049..c3f13a522 100644 --- a/pkg/rpc/client/wsclient.go +++ b/pkg/rpc/client/wsclient.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "strconv" "sync" "time" @@ -31,14 +32,16 @@ type WSClient struct { // be closed, so make sure to handle this. Notifications chan Notification - ws *websocket.Conn - done chan struct{} - responses chan *response.Raw - requests chan *request.Raw - shutdown chan struct{} + ws *websocket.Conn + done chan struct{} + requests chan *request.Raw + shutdown chan struct{} subscriptionsLock sync.RWMutex subscriptions map[string]bool + + respLock sync.RWMutex + respChannels map[uint64]chan *response.Raw } // Notification represents server-generated notification for client subscriptions. @@ -95,7 +98,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error ws: ws, shutdown: make(chan struct{}), done: make(chan struct{}), - responses: make(chan *response.Raw), + respChannels: make(map[uint64]chan *response.Raw), requests: make(chan *request.Raw), subscriptions: make(map[string]bool), } @@ -178,14 +181,27 @@ readloop: resp.JSONRPC = rr.JSONRPC resp.Error = rr.Error resp.Result = rr.Result - c.responses <- resp + id, err := strconv.Atoi(string(resp.ID)) + if err != nil { + break // Malformed response (invalid response ID). + } + ch := c.getResponseChannel(uint64(id)) + if ch == nil { + break // Unknown response (unexpected response ID). + } + ch <- resp } else { // Malformed response, neither valid request, nor valid response. break } } close(c.done) - close(c.responses) + c.respLock.Lock() + for _, ch := range c.respChannels { + close(ch) + } + c.respChannels = nil + c.respLock.Unlock() close(c.Notifications) } @@ -220,16 +236,41 @@ func (c *WSClient) wsWriter() { } } +func (c *WSClient) registerRespChannel(id uint64, ch chan *response.Raw) { + c.respLock.Lock() + defer c.respLock.Unlock() + c.respChannels[id] = ch +} + +func (c *WSClient) unregisterRespChannel(id uint64) { + c.respLock.Lock() + defer c.respLock.Unlock() + if ch, ok := c.respChannels[id]; ok { + delete(c.respChannels, id) + close(ch) + } +} + +func (c *WSClient) getResponseChannel(id uint64) chan *response.Raw { + c.respLock.RLock() + defer c.respLock.RUnlock() + return c.respChannels[id] +} + func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) { + ch := make(chan *response.Raw) + c.registerRespChannel(r.ID, ch) + select { case <-c.done: - return nil, errors.New("connection lost") + return nil, errors.New("connection lost before sending the request") case c.requests <- r: } select { case <-c.done: - return nil, errors.New("connection lost") - case resp := <-c.responses: + return nil, errors.New("connection lost while waiting for the response") + case resp := <-ch: + c.unregisterRespChannel(r.ID) return resp, nil } } diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 7a8eda105..675cee021 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -45,6 +45,7 @@ func TestWSClientSubscription(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) id, err := f(wsc) require.NoError(t, err) @@ -58,6 +59,7 @@ func TestWSClientSubscription(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) _, err = f(wsc) require.Error(t, err) @@ -107,6 +109,7 @@ func TestWSClientUnsubscription(t *testing.T) { srv := initTestServer(t, rc.response) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) rc.code(t, wsc) }) @@ -143,6 +146,7 @@ func TestWSClientEvents(t *testing.T) { wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually. wsc.cache.network = netmode.UnitTestNet for range events { @@ -167,6 +171,7 @@ func TestWSExecutionVMStateCheck(t *testing.T) { srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID require.NoError(t, wsc.Init()) filter := "NONE" _, err = wsc.SubscribeForTransactionExecutions(&filter) @@ -316,6 +321,7 @@ func TestWSFilteredSubscriptions(t *testing.T) { })) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + wsc.getNextRequestID = getTestRequestID wsc.cache.network = netmode.UnitTestNet c.clientCode(t, wsc) wsc.Close() @@ -329,6 +335,8 @@ func TestNewWS(t *testing.T) { t.Run("good", func(t *testing.T) { c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) require.NoError(t, err) + c.getNextRequestID = getTestRequestID + c.cache.network = netmode.UnitTestNet require.NoError(t, c.Init()) }) t.Run("bad URL", func(t *testing.T) { diff --git a/pkg/rpc/request/types.go b/pkg/rpc/request/types.go index 654771060..2281ef3ae 100644 --- a/pkg/rpc/request/types.go +++ b/pkg/rpc/request/types.go @@ -32,12 +32,12 @@ func NewRawParams(vals ...interface{}) RawParams { return p } -// Raw represents JSON-RPC request. +// Raw represents JSON-RPC request on the Client side. type Raw struct { JSONRPC string `json:"jsonrpc"` Method string `json:"method"` RawParams []interface{} `json:"params"` - ID int `json:"id"` + ID uint64 `json:"id"` } // Request contains standard JSON-RPC 2.0 request and batch of