mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-26 19:42:23 +00:00
rpc: make RPC WSClient thread-safe
Add ability to use unique request IDs for RPC requests.
This commit is contained in:
parent
5b2e88b916
commit
8991ee91cd
5 changed files with 93 additions and 20 deletions
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
|
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -40,6 +41,12 @@ type Client struct {
|
||||||
// cache is mostly filled in during Init(), but can also be updated
|
// cache is mostly filled in during Init(), but can also be updated
|
||||||
// during regular Client lifecycle.
|
// during regular Client lifecycle.
|
||||||
cache cache
|
cache cache
|
||||||
|
|
||||||
|
latestReqID *atomic.Uint64
|
||||||
|
// getNextRequestID returns ID to be used for subsequent request creation.
|
||||||
|
// It is defined on Client so that our testing code can override this method
|
||||||
|
// for the sake of more predictable request IDs generation behaviour.
|
||||||
|
getNextRequestID func() uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options defines options for the RPC client.
|
// Options defines options for the RPC client.
|
||||||
|
@ -110,12 +117,18 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
cache: cache{
|
cache: cache{
|
||||||
nativeHashes: make(map[string]util.Uint160),
|
nativeHashes: make(map[string]util.Uint160),
|
||||||
},
|
},
|
||||||
|
latestReqID: atomic.NewUint64(0),
|
||||||
}
|
}
|
||||||
|
cl.getNextRequestID = (cl).getRequestID
|
||||||
cl.opts = opts
|
cl.opts = opts
|
||||||
cl.requestF = cl.makeHTTPRequest
|
cl.requestF = cl.makeHTTPRequest
|
||||||
return cl, nil
|
return cl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) getRequestID() uint64 {
|
||||||
|
return c.latestReqID.Inc()
|
||||||
|
}
|
||||||
|
|
||||||
// Init sets magic of the network client connected to, stateRootInHeader option
|
// Init sets magic of the network client connected to, stateRootInHeader option
|
||||||
// and native NEO, GAS and Policy contracts scripthashes. This method should be
|
// and native NEO, GAS and Policy contracts scripthashes. This method should be
|
||||||
// called before any header- or block-related requests in order to deserialize
|
// called before any header- or block-related requests in order to deserialize
|
||||||
|
@ -159,7 +172,7 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{
|
||||||
JSONRPC: request.JSONRPCVersion,
|
JSONRPC: request.JSONRPCVersion,
|
||||||
Method: method,
|
Method: method,
|
||||||
RawParams: p.Values,
|
RawParams: p.Values,
|
||||||
ID: 1,
|
ID: c.getNextRequestID(),
|
||||||
}
|
}
|
||||||
|
|
||||||
raw, err := c.requestF(&r)
|
raw, err := c.requestF(&r)
|
||||||
|
|
|
@ -1705,6 +1705,7 @@ func TestRPCClients(t *testing.T) {
|
||||||
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
c, err := New(ctx, endpoint, opts)
|
c, err := New(ctx, endpoint, opts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, c.Init())
|
require.NoError(t, c.Init())
|
||||||
return c, nil
|
return c, nil
|
||||||
})
|
})
|
||||||
|
@ -1713,6 +1714,7 @@ func TestRPCClients(t *testing.T) {
|
||||||
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
|
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
return &wsc.Client, nil
|
return &wsc.Client, nil
|
||||||
})
|
})
|
||||||
|
@ -1732,6 +1734,7 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
|
|
||||||
actual, err := testCase.invoke(c)
|
actual, err := testCase.invoke(c)
|
||||||
if testCase.fails {
|
if testCase.fails {
|
||||||
|
@ -1755,14 +1758,14 @@ func testRPCClient(t *testing.T, newClient func(context.Context, string, Options
|
||||||
|
|
||||||
endpoint := srv.URL
|
endpoint := srv.URL
|
||||||
opts := Options{}
|
opts := Options{}
|
||||||
c, err := newClient(context.TODO(), endpoint, opts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range testBatch {
|
for _, testCase := range testBatch {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
_, err := testCase.invoke(c)
|
c, err := newClient(context.TODO(), endpoint, opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
|
_, err = testCase.invoke(c)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1878,6 +1881,7 @@ func TestCalculateValidUntilBlock(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, c.Init())
|
require.NoError(t, c.Init())
|
||||||
|
|
||||||
validUntilBlock, err := c.CalculateValidUntilBlock()
|
validUntilBlock, err := c.CalculateValidUntilBlock()
|
||||||
|
@ -1913,6 +1917,7 @@ func TestGetNetwork(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
// network was not initialised
|
// network was not initialised
|
||||||
_, err = c.GetNetwork()
|
_, err = c.GetNetwork()
|
||||||
require.True(t, errors.Is(err, errNetworkNotInitialized))
|
require.True(t, errors.Is(err, errNetworkNotInitialized))
|
||||||
|
@ -1924,6 +1929,7 @@ func TestGetNetwork(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, c.Init())
|
require.NoError(t, c.Init())
|
||||||
m, err := c.GetNetwork()
|
m, err := c.GetNetwork()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -1945,6 +1951,7 @@ func TestUninitedClient(t *testing.T) {
|
||||||
|
|
||||||
c, err := New(context.TODO(), endpoint, opts)
|
c, err := New(context.TODO(), endpoint, opts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
|
|
||||||
_, err = c.GetBlockByIndex(0)
|
_, err = c.GetBlockByIndex(0)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
@ -1970,3 +1977,7 @@ func newTestNEF(script []byte) nef.File {
|
||||||
ne.Checksum = ne.CalculateChecksum()
|
ne.Checksum = ne.CalculateChecksum()
|
||||||
return ne
|
return ne
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTestRequestID() uint64 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -31,14 +32,16 @@ type WSClient struct {
|
||||||
// be closed, so make sure to handle this.
|
// be closed, so make sure to handle this.
|
||||||
Notifications chan Notification
|
Notifications chan Notification
|
||||||
|
|
||||||
ws *websocket.Conn
|
ws *websocket.Conn
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
responses chan *response.Raw
|
requests chan *request.Raw
|
||||||
requests chan *request.Raw
|
shutdown chan struct{}
|
||||||
shutdown chan struct{}
|
|
||||||
|
|
||||||
subscriptionsLock sync.RWMutex
|
subscriptionsLock sync.RWMutex
|
||||||
subscriptions map[string]bool
|
subscriptions map[string]bool
|
||||||
|
|
||||||
|
respLock sync.RWMutex
|
||||||
|
respChannels map[uint64]chan *response.Raw
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notification represents server-generated notification for client subscriptions.
|
// Notification represents server-generated notification for client subscriptions.
|
||||||
|
@ -95,7 +98,7 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error
|
||||||
ws: ws,
|
ws: ws,
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
responses: make(chan *response.Raw),
|
respChannels: make(map[uint64]chan *response.Raw),
|
||||||
requests: make(chan *request.Raw),
|
requests: make(chan *request.Raw),
|
||||||
subscriptions: make(map[string]bool),
|
subscriptions: make(map[string]bool),
|
||||||
}
|
}
|
||||||
|
@ -178,14 +181,27 @@ readloop:
|
||||||
resp.JSONRPC = rr.JSONRPC
|
resp.JSONRPC = rr.JSONRPC
|
||||||
resp.Error = rr.Error
|
resp.Error = rr.Error
|
||||||
resp.Result = rr.Result
|
resp.Result = rr.Result
|
||||||
c.responses <- resp
|
id, err := strconv.Atoi(string(resp.ID))
|
||||||
|
if err != nil {
|
||||||
|
break // Malformed response (invalid response ID).
|
||||||
|
}
|
||||||
|
ch := c.getResponseChannel(uint64(id))
|
||||||
|
if ch == nil {
|
||||||
|
break // Unknown response (unexpected response ID).
|
||||||
|
}
|
||||||
|
ch <- resp
|
||||||
} else {
|
} else {
|
||||||
// Malformed response, neither valid request, nor valid response.
|
// Malformed response, neither valid request, nor valid response.
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
close(c.done)
|
close(c.done)
|
||||||
close(c.responses)
|
c.respLock.Lock()
|
||||||
|
for _, ch := range c.respChannels {
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
c.respChannels = nil
|
||||||
|
c.respLock.Unlock()
|
||||||
close(c.Notifications)
|
close(c.Notifications)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,16 +236,41 @@ func (c *WSClient) wsWriter() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) registerRespChannel(id uint64, ch chan *response.Raw) {
|
||||||
|
c.respLock.Lock()
|
||||||
|
defer c.respLock.Unlock()
|
||||||
|
c.respChannels[id] = ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) unregisterRespChannel(id uint64) {
|
||||||
|
c.respLock.Lock()
|
||||||
|
defer c.respLock.Unlock()
|
||||||
|
if ch, ok := c.respChannels[id]; ok {
|
||||||
|
delete(c.respChannels, id)
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) getResponseChannel(id uint64) chan *response.Raw {
|
||||||
|
c.respLock.RLock()
|
||||||
|
defer c.respLock.RUnlock()
|
||||||
|
return c.respChannels[id]
|
||||||
|
}
|
||||||
|
|
||||||
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
|
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
|
||||||
|
ch := make(chan *response.Raw)
|
||||||
|
c.registerRespChannel(r.ID, ch)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
return nil, errors.New("connection lost")
|
return nil, errors.New("connection lost before sending the request")
|
||||||
case c.requests <- r:
|
case c.requests <- r:
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
return nil, errors.New("connection lost")
|
return nil, errors.New("connection lost while waiting for the response")
|
||||||
case resp := <-c.responses:
|
case resp := <-ch:
|
||||||
|
c.unregisterRespChannel(r.ID)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,6 +45,7 @@ func TestWSClientSubscription(t *testing.T) {
|
||||||
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
|
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
id, err := f(wsc)
|
id, err := f(wsc)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -58,6 +59,7 @@ func TestWSClientSubscription(t *testing.T) {
|
||||||
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`)
|
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`)
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
_, err = f(wsc)
|
_, err = f(wsc)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
@ -107,6 +109,7 @@ func TestWSClientUnsubscription(t *testing.T) {
|
||||||
srv := initTestServer(t, rc.response)
|
srv := initTestServer(t, rc.response)
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
rc.code(t, wsc)
|
rc.code(t, wsc)
|
||||||
})
|
})
|
||||||
|
@ -143,6 +146,7 @@ func TestWSClientEvents(t *testing.T) {
|
||||||
|
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually.
|
wsc.cache.initDone = true // Our server mock is restricted, so perform initialisation manually.
|
||||||
wsc.cache.network = netmode.UnitTestNet
|
wsc.cache.network = netmode.UnitTestNet
|
||||||
for range events {
|
for range events {
|
||||||
|
@ -167,6 +171,7 @@ func TestWSExecutionVMStateCheck(t *testing.T) {
|
||||||
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
|
srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`)
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
require.NoError(t, wsc.Init())
|
require.NoError(t, wsc.Init())
|
||||||
filter := "NONE"
|
filter := "NONE"
|
||||||
_, err = wsc.SubscribeForTransactionExecutions(&filter)
|
_, err = wsc.SubscribeForTransactionExecutions(&filter)
|
||||||
|
@ -316,6 +321,7 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
wsc.getNextRequestID = getTestRequestID
|
||||||
wsc.cache.network = netmode.UnitTestNet
|
wsc.cache.network = netmode.UnitTestNet
|
||||||
c.clientCode(t, wsc)
|
c.clientCode(t, wsc)
|
||||||
wsc.Close()
|
wsc.Close()
|
||||||
|
@ -329,6 +335,8 @@ func TestNewWS(t *testing.T) {
|
||||||
t.Run("good", func(t *testing.T) {
|
t.Run("good", func(t *testing.T) {
|
||||||
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
c, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
c.getNextRequestID = getTestRequestID
|
||||||
|
c.cache.network = netmode.UnitTestNet
|
||||||
require.NoError(t, c.Init())
|
require.NoError(t, c.Init())
|
||||||
})
|
})
|
||||||
t.Run("bad URL", func(t *testing.T) {
|
t.Run("bad URL", func(t *testing.T) {
|
||||||
|
|
|
@ -32,12 +32,12 @@ func NewRawParams(vals ...interface{}) RawParams {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// Raw represents JSON-RPC request.
|
// Raw represents JSON-RPC request on the Client side.
|
||||||
type Raw struct {
|
type Raw struct {
|
||||||
JSONRPC string `json:"jsonrpc"`
|
JSONRPC string `json:"jsonrpc"`
|
||||||
Method string `json:"method"`
|
Method string `json:"method"`
|
||||||
RawParams []interface{} `json:"params"`
|
RawParams []interface{} `json:"params"`
|
||||||
ID int `json:"id"`
|
ID uint64 `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request contains standard JSON-RPC 2.0 request and batch of
|
// Request contains standard JSON-RPC 2.0 request and batch of
|
||||||
|
|
Loading…
Reference in a new issue