rpc: make RPC WSClient thread-safe

Add ability to use unique request IDs for RPC requests.
This commit is contained in:
AnnaShaleva 2022-02-18 20:28:13 +03:00
parent 5b2e88b916
commit 8991ee91cd
5 changed files with 93 additions and 20 deletions

View file

@ -17,6 +17,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/request"
"github.com/nspcc-dev/neo-go/pkg/rpc/response" "github.com/nspcc-dev/neo-go/pkg/rpc/response"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"go.uber.org/atomic"
) )
const ( const (
@ -40,6 +41,12 @@ type Client struct {
// cache is mostly filled in during Init(), but can also be updated // cache is mostly filled in during Init(), but can also be updated
// during regular Client lifecycle. // during regular Client lifecycle.
cache cache cache cache
latestReqID *atomic.Uint64
// getNextRequestID returns ID to be used for subsequent request creation.
// It is defined on Client so that our testing code can override this method
// for the sake of more predictable request IDs generation behaviour.
getNextRequestID func() uint64
} }
// Options defines options for the RPC client. // Options defines options for the RPC client.
@ -110,12 +117,18 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
cache: cache{ cache: cache{
nativeHashes: make(map[string]util.Uint160), nativeHashes: make(map[string]util.Uint160),
}, },
latestReqID: atomic.NewUint64(0),
} }
cl.getNextRequestID = (cl).getRequestID
cl.opts = opts cl.opts = opts
cl.requestF = cl.makeHTTPRequest cl.requestF = cl.makeHTTPRequest
return cl, nil return cl, nil
} }
func (c *Client) getRequestID() uint64 {
return c.latestReqID.Inc()
}
// Init sets magic of the network client connected to, stateRootInHeader option // Init sets magic of the network client connected to, stateRootInHeader option
// and native NEO, GAS and Policy contracts scripthashes. This method should be // and native NEO, GAS and Policy contracts scripthashes. This method should be
// called before any header- or block-related requests in order to deserialize // called before any header- or block-related requests in order to deserialize
@ -159,7 +172,7 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{
JSONRPC: request.JSONRPCVersion, JSONRPC: request.JSONRPCVersion,
Method: method, Method: method,
RawParams: p.Values, RawParams: p.Values,
ID: 1, ID: c.getNextRequestID(),
} }
raw, err := c.requestF(&r) raw, err := c.requestF(&r)

View file

@ -1705,6 +1705,7 @@ func TestRPCClients(t *testing.T) {
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
c, err := New(ctx, endpoint, opts) c, err := New(ctx, endpoint, opts)
require.NoError(t, err) require.NoError(t, err)
c.getNextRequestID = getTestRequestID
require.NoError(t, c.Init()) require.NoError(t, c.Init())
return c, nil return c, nil
}) })
@ -1713,6 +1714,7 @@ func TestRPCClients(t *testing.T) {
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts) wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
return &wsc.Client, nil return &wsc.Client, nil
}) })
@ -1732,6 +1734,7 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.getNextRequestID = getTestRequestID
actual, err := testCase.invoke(c) actual, err := testCase.invoke(c)
if testCase.fails { if testCase.fails {
@ -1755,14 +1758,14 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options
endpoint := srv.URL endpoint := srv.URL
opts := Options{} opts := Options{}
c, err := newClient(context.TODO(), endpoint, opts)
if err != nil {
t.Fatal(err)
}
for _, testCase := range testBatch { for _, testCase := range testBatch {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
_, err := testCase.invoke(c) c, err := newClient(context.TODO(), endpoint, opts)
if err != nil {
t.Fatal(err)
}
c.getNextRequestID = getTestRequestID
_, err = testCase.invoke(c)
assert.Error(t, err) assert.Error(t, err)
}) })
} }
@ -1878,6 +1881,7 @@ func TestCalculateValidUntilBlock(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.getNextRequestID = getTestRequestID
require.NoError(t, c.Init()) require.NoError(t, c.Init())
validUntilBlock, err := c.CalculateValidUntilBlock() validUntilBlock, err := c.CalculateValidUntilBlock()
@ -1913,6 +1917,7 @@ func TestGetNetwork(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.getNextRequestID = getTestRequestID
// network was not initialised // network was not initialised
_, err = c.GetNetwork() _, err = c.GetNetwork()
require.True(t, errors.Is(err, errNetworkNotInitialized)) require.True(t, errors.Is(err, errNetworkNotInitialized))
@ -1924,6 +1929,7 @@ func TestGetNetwork(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.getNextRequestID = getTestRequestID
require.NoError(t, c.Init()) require.NoError(t, c.Init())
m, err := c.GetNetwork() m, err := c.GetNetwork()
require.NoError(t, err) require.NoError(t, err)
@ -1945,6 +1951,7 @@ func TestUninitedClient(t *testing.T) {
c, err := New(context.TODO(), endpoint, opts) c, err := New(context.TODO(), endpoint, opts)
require.NoError(t, err) require.NoError(t, err)
c.getNextRequestID = getTestRequestID
_, err = c.GetBlockByIndex(0) _, err = c.GetBlockByIndex(0)
require.Error(t, err) require.Error(t, err)
@ -1970,3 +1977,7 @@ func newTestNEF(script []byte) nef.File {
ne.Checksum = ne.CalculateChecksum() ne.Checksum = ne.CalculateChecksum()
return ne return ne
} }
func getTestRequestID() uint64 {
return 1
}

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"strconv"
"sync" "sync"
"time" "time"
@ -31,14 +32,16 @@ type WSClient struct {
// be closed, so make sure to handle this. // be closed, so make sure to handle this.
Notifications chan Notification Notifications chan Notification
ws *websocket.Conn ws *websocket.Conn
done chan struct{} done chan struct{}
responses chan *response.Raw requests chan *request.Raw
requests chan *request.Raw shutdown chan struct{}
shutdown chan struct{}
subscriptionsLock sync.RWMutex subscriptionsLock sync.RWMutex
subscriptions map[string]bool subscriptions map[string]bool
respLock sync.RWMutex
respChannels map[uint64]chan *response.Raw
} }
// Notification represents server-generated notification for client subscriptions. // Notification represents server-generated notification for client subscriptions.
@ -95,7 +98,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error
ws: ws, ws: ws,
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
done: make(chan struct{}), done: make(chan struct{}),
responses: make(chan *response.Raw), respChannels: make(map[uint64]chan *response.Raw),
requests: make(chan *request.Raw), requests: make(chan *request.Raw),
subscriptions: make(map[string]bool), subscriptions: make(map[string]bool),
} }
@ -178,14 +181,27 @@ readloop:
resp.JSONRPC = rr.JSONRPC resp.JSONRPC = rr.JSONRPC
resp.Error = rr.Error resp.Error = rr.Error
resp.Result = rr.Result resp.Result = rr.Result
c.responses <- resp id, err := strconv.Atoi(string(resp.ID))
if err != nil {
break // Malformed response (invalid response ID).
}
ch := c.getResponseChannel(uint64(id))
if ch == nil {
break // Unknown response (unexpected response ID).
}
ch <- resp
} else { } else {
// Malformed response, neither valid request, nor valid response. // Malformed response, neither valid request, nor valid response.
break break
} }
} }
close(c.done) close(c.done)
close(c.responses) c.respLock.Lock()
for _, ch := range c.respChannels {
close(ch)
}
c.respChannels = nil
c.respLock.Unlock()
close(c.Notifications) close(c.Notifications)
} }
@ -220,16 +236,41 @@ func (c *WSClient) wsWriter() {
} }
} }
func (c *WSClient) registerRespChannel(id uint64, ch chan *response.Raw) {
c.respLock.Lock()
defer c.respLock.Unlock()
c.respChannels[id] = ch
}
func (c *WSClient) unregisterRespChannel(id uint64) {
c.respLock.Lock()
defer c.respLock.Unlock()
if ch, ok := c.respChannels[id]; ok {
delete(c.respChannels, id)
close(ch)
}
}
func (c *WSClient) getResponseChannel(id uint64) chan *response.Raw {
c.respLock.RLock()
defer c.respLock.RUnlock()
return c.respChannels[id]
}
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) { func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
ch := make(chan *response.Raw)
c.registerRespChannel(r.ID, ch)
select { select {
case <-c.done: case <-c.done:
return nil, errors.New("connection lost") return nil, errors.New("connection lost before sending the request")
case c.requests <- r: case c.requests <- r:
} }
select { select {
case <-c.done: case <-c.done:
return nil, errors.New("connection lost") return nil, errors.New("connection lost while waiting for the response")
case resp := <-c.responses: case resp := <-ch:
c.unregisterRespChannel(r.ID)
return resp, nil return resp, nil
} }
} }

View file

@ -45,6 +45,7 @@ func TestWSClientSubscription(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
id, err := f(wsc) id, err := f(wsc)
require.NoError(t, err) require.NoError(t, err)
@ -58,6 +59,7 @@ func TestWSClientSubscription(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
_, err = f(wsc) _, err = f(wsc)
require.Error(t, err) require.Error(t, err)
@ -107,6 +109,7 @@ func TestWSClientUnsubscription(t *testing.T) {
srv := initTestServer(t, rc.response) srv := initTestServer(t, rc.response)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
rc.code(t, wsc) rc.code(t, wsc)
}) })
@ -143,6 +146,7 @@ func TestWSClientEvents(t *testing.T) {
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually. wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually.
wsc.cache.network = netmode.UnitTestNet wsc.cache.network = netmode.UnitTestNet
for range events { for range events {
@ -167,6 +171,7 @@ func TestWSExecutionVMStateCheck(t *testing.T) {
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
require.NoError(t, wsc.Init()) require.NoError(t, wsc.Init())
filter := "NONE" filter := "NONE"
_, err = wsc.SubscribeForTransactionExecutions(&filter) _, err = wsc.SubscribeForTransactionExecutions(&filter)
@ -316,6 +321,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
})) }))
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
wsc.getNextRequestID = getTestRequestID
wsc.cache.network = netmode.UnitTestNet wsc.cache.network = netmode.UnitTestNet
c.clientCode(t, wsc) c.clientCode(t, wsc)
wsc.Close() wsc.Close()
@ -329,6 +335,8 @@ func TestNewWS(t *testing.T) {
t.Run("good", func(t *testing.T) { t.Run("good", func(t *testing.T) {
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err) require.NoError(t, err)
c.getNextRequestID = getTestRequestID
c.cache.network = netmode.UnitTestNet
require.NoError(t, c.Init()) require.NoError(t, c.Init())
}) })
t.Run("bad URL", func(t *testing.T) { t.Run("bad URL", func(t *testing.T) {

View file

@ -32,12 +32,12 @@ func NewRawParams(vals ...interface{}) RawParams {
return p return p
} }
// Raw represents JSON-RPC request. // Raw represents JSON-RPC request on the Client side.
type Raw struct { type Raw struct {
JSONRPC string `json:"jsonrpc"` JSONRPC string `json:"jsonrpc"`
Method string `json:"method"` Method string `json:"method"`
RawParams []interface{} `json:"params"` RawParams []interface{} `json:"params"`
ID int `json:"id"` ID uint64 `json:"id"`
} }
// Request contains standard JSON-RPC 2.0 request and batch of // Request contains standard JSON-RPC 2.0 request and batch of