Merge pull request #925 from nspcc-dev/rpc-over-websocket

RPC over websocket
This commit is contained in:
Roman Khimov 2020-05-06 12:52:58 +03:00 committed by GitHub
commit b04c8623c5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 543 additions and 255 deletions

View file

@ -31,33 +31,34 @@ const (
// Client represents the middleman for executing JSON RPC calls
// to remote NEO RPC nodes.
type Client struct {
// The underlying http client. It's never a good practice to use
// the http.DefaultClient, therefore we will role our own.
cliMu *sync.Mutex
cli *http.Client
endpoint *url.URL
ctx context.Context
version string
wifMu *sync.Mutex
wif *keys.WIF
balancerMu *sync.Mutex
balancer request.BalanceGetter
cache cache
cli *http.Client
endpoint *url.URL
ctx context.Context
opts Options
requestF func(*request.Raw) (*response.Raw, error)
wifMu *sync.Mutex
wif *keys.WIF
cache cache
}
// Options defines options for the RPC client.
// All Values are optional. If any duration is not specified
// a default of 3 seconds will be used.
// All values are optional. If any duration is not specified
// a default of 4 seconds will be used.
type Options struct {
Cert string
Key string
CACert string
DialTimeout time.Duration
Client *http.Client
// Version is the version of the client that will be send
// along with the request body. If no version is specified
// the default version (currently 2.0) will be used.
Version string
// Balancer is an implementation of request.BalanceGetter interface,
// if not set then the default Client's implementation will be used, but
// it relies on server support for `getunspents` RPC call which is
// standard for neo-go, but only implemented as a plugin for C# node. So
// you can override it here to use NeoScanServer for example.
Balancer request.BalanceGetter
// Cert is a client-side certificate, it doesn't work at the moment along
// with the other two options below.
Cert string
Key string
CACert string
DialTimeout time.Duration
RequestTimeout time.Duration
}
// cache stores cache values for the RPC client methods
@ -79,37 +80,39 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
return nil, err
}
if opts.Version == "" {
opts.Version = defaultClientVersion
if opts.DialTimeout <= 0 {
opts.DialTimeout = defaultDialTimeout
}
if opts.Client == nil {
opts.Client = &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: opts.DialTimeout,
}).DialContext,
},
}
if opts.RequestTimeout <= 0 {
opts.RequestTimeout = defaultRequestTimeout
}
httpClient := &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: opts.DialTimeout,
}).DialContext,
},
Timeout: opts.RequestTimeout,
}
// TODO(@antdm): Enable SSL.
if opts.Cert != "" && opts.Key != "" {
}
if opts.Client.Timeout == 0 {
opts.Client.Timeout = defaultRequestTimeout
cl := &Client{
ctx: ctx,
cli: httpClient,
wifMu: new(sync.Mutex),
endpoint: url,
}
return &Client{
ctx: ctx,
cli: opts.Client,
cliMu: new(sync.Mutex),
balancerMu: new(sync.Mutex),
wifMu: new(sync.Mutex),
endpoint: url,
version: opts.Version,
}, nil
if opts.Balancer == nil {
opts.Balancer = cl
}
cl.opts = opts
cl.requestF = cl.makeHTTPRequest
return cl, nil
}
// WIF returns WIF structure associated with the client.
@ -137,43 +140,10 @@ func (c *Client) SetWIF(wif string) error {
return nil
}
// Balancer is a getter for balance field.
func (c *Client) Balancer() request.BalanceGetter {
c.balancerMu.Lock()
defer c.balancerMu.Unlock()
return c.balancer
}
// SetBalancer is a setter for balance field.
func (c *Client) SetBalancer(b request.BalanceGetter) {
c.balancerMu.Lock()
defer c.balancerMu.Unlock()
if b != nil {
c.balancer = b
}
}
// Client is a getter for client field.
func (c *Client) Client() *http.Client {
c.cliMu.Lock()
defer c.cliMu.Unlock()
return c.cli
}
// SetClient is a setter for client field.
func (c *Client) SetClient(cli *http.Client) {
c.cliMu.Lock()
defer c.cliMu.Unlock()
if cli != nil {
c.cli = cli
}
}
// CalculateInputs creates input transactions for the specified amount of given
// asset belonging to specified address. This implementation uses GetUnspents
// JSON-RPC call internally, so make sure your RPC server supports that.
// CalculateInputs implements request.BalanceGetter interface and returns inputs
// array for the specified amount of given asset belonging to specified address.
// This implementation uses GetUnspents JSON-RPC call internally, so make sure
// your RPC server supports that.
func (c *Client) CalculateInputs(address string, asset util.Uint256, cost util.Fixed8) ([]transaction.Input, util.Fixed8, error) {
var utxos state.UnspentBalances
@ -192,47 +162,59 @@ func (c *Client) CalculateInputs(address string, asset util.Uint256, cost util.F
}
func (c *Client) performRequest(method string, p request.RawParams, v interface{}) error {
var r = request.Raw{
JSONRPC: request.JSONRPCVersion,
Method: method,
RawParams: p.Values,
ID: 1,
}
raw, err := c.requestF(&r)
if raw != nil && raw.Error != nil {
return raw.Error
} else if err != nil {
return err
} else if raw == nil || raw.Result == nil {
return errors.New("no result returned")
}
return json.Unmarshal(raw.Result, v)
}
func (c *Client) makeHTTPRequest(r *request.Raw) (*response.Raw, error) {
var (
r = request.Raw{
JSONRPC: c.version,
Method: method,
RawParams: p.Values,
ID: 1,
}
buf = new(bytes.Buffer)
raw = &response.Raw{}
raw = new(response.Raw)
)
if err := json.NewEncoder(buf).Encode(r); err != nil {
return err
return nil, err
}
req, err := http.NewRequest("POST", c.endpoint.String(), buf)
if err != nil {
return err
return nil, err
}
resp, err := c.Client().Do(req)
resp, err := c.cli.Do(req)
if err != nil {
return err
return nil, err
}
defer resp.Body.Close()
// The node might send us proper JSON anyway, so look there first and if
// it parses, then it has more relevant data than HTTP error code.
err = json.NewDecoder(resp.Body).Decode(raw)
if err == nil {
if raw.Error != nil {
err = raw.Error
if err != nil {
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("HTTP %d/%s", resp.StatusCode, http.StatusText(resp.StatusCode))
} else {
err = json.Unmarshal(raw.Result, v)
err = errors.Wrap(err, "JSON decoding")
}
} else if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("HTTP %d/%s", resp.StatusCode, http.StatusText(resp.StatusCode))
} else {
err = errors.Wrap(err, "JSON decoding")
}
return err
if err != nil {
return nil, err
}
return raw, nil
}
// Ping attempts to create a connection to the endpoint.

View file

@ -482,7 +482,7 @@ func (c *Client) TransferAsset(asset util.Uint256, address string, amount util.F
Address: address,
Value: amount,
WIF: c.WIF(),
Balancer: c.Balancer(),
Balancer: c.opts.Balancer,
}
resp util.Uint256
)

View file

@ -3,11 +3,13 @@ package client
import (
"context"
"encoding/hex"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
@ -1510,7 +1512,22 @@ var rpcClientErrorCases = map[string][]rpcClientErrorCase{
},
}
func TestRPCClient(t *testing.T) {
func TestRPCClients(t *testing.T) {
t.Run("Client", func(t *testing.T) {
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
return New(ctx, endpoint, opts)
})
})
t.Run("WSClient", func(t *testing.T) {
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
require.NoError(t, err)
return &wsc.Client, nil
})
})
}
func testRPCClient(t *testing.T, newClient func(context.Context, string, Options) (*Client, error)) {
for method, testBatch := range rpcClientTestCases {
t.Run(method, func(t *testing.T) {
for _, testCase := range testBatch {
@ -1520,7 +1537,7 @@ func TestRPCClient(t *testing.T) {
endpoint := srv.URL
opts := Options{}
c, err := New(context.TODO(), endpoint, opts)
c, err := newClient(context.TODO(), endpoint, opts)
if err != nil {
t.Fatal(err)
}
@ -1544,7 +1561,7 @@ func TestRPCClient(t *testing.T) {
endpoint := srv.URL
opts := Options{}
c, err := New(context.TODO(), endpoint, opts)
c, err := newClient(context.TODO(), endpoint, opts)
if err != nil {
t.Fatal(err)
}
@ -1558,8 +1575,31 @@ func TestRPCClient(t *testing.T) {
}
}
func httpURLtoWS(url string) string {
return "ws" + strings.TrimPrefix(url, "http") + "/ws"
}
func initTestServer(t *testing.T, resp string) *httptest.Server {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/ws" && req.Method == "GET" {
var upgrader = websocket.Upgrader{}
ws, err := upgrader.Upgrade(w, req, nil)
require.NoError(t, err)
for {
ws.SetReadDeadline(time.Now().Add(2 * time.Second))
_, _, err = ws.ReadMessage()
if err != nil {
break
}
ws.SetWriteDeadline(time.Now().Add(2 * time.Second))
err = ws.WriteMessage(1, []byte(resp))
if err != nil {
break
}
}
ws.Close()
return
}
requestHandler(t, w, resp)
}))
@ -1568,11 +1608,10 @@ func initTestServer(t *testing.T, resp string) *httptest.Server {
func requestHandler(t *testing.T, w http.ResponseWriter, resp string) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
encoder := json.NewEncoder(w)
err := encoder.Encode(json.RawMessage(resp))
_, err := w.Write([]byte(resp))
if err != nil {
t.Fatalf("Error encountered while encoding response: %s", err.Error())
t.Fatalf("Error writing response: %s", err.Error())
}
}

160
pkg/rpc/client/wsclient.go Normal file
View file

@ -0,0 +1,160 @@
package client
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/gorilla/websocket"
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
)
// WSClient is a websocket-enabled RPC client that can be used with appropriate
// servers. It's supposed to be faster than Client because it has persistent
// connection to the server and at the same time is exposes some functionality
// that is only provided via websockets (like event subscription mechanism).
type WSClient struct {
Client
ws *websocket.Conn
done chan struct{}
notifications chan *request.In
responses chan *response.Raw
requests chan *request.Raw
shutdown chan struct{}
}
// requestResponse is a combined type for request and response since we can get
// any of them here.
type requestResponse struct {
request.In
Error *response.Error `json:"error,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
}
const (
// Message limit for receiving side.
wsReadLimit = 10 * 1024 * 1024
// Disconnection timeout.
wsPongLimit = 60 * time.Second
// Ping period for connection liveness check.
wsPingPeriod = wsPongLimit / 2
// Write deadline.
wsWriteLimit = wsPingPeriod / 2
)
// NewWS returns a new WSClient ready to use (with established websocket
// connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`.
func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) {
cl, err := New(ctx, endpoint, opts)
cl.cli = nil
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
ws, _, err := dialer.Dial(endpoint, nil)
if err != nil {
return nil, err
}
wsc := &WSClient{
Client: *cl,
ws: ws,
shutdown: make(chan struct{}),
done: make(chan struct{}),
responses: make(chan *response.Raw),
requests: make(chan *request.Raw),
}
go wsc.wsReader()
go wsc.wsWriter()
wsc.requestF = wsc.makeWsRequest
return wsc, nil
}
// Close closes connection to the remote side rendering this client instance
// unusable.
func (c *WSClient) Close() {
// Closing shutdown channel send signal to wsWriter to break out of the
// loop. In doing so it does ws.Close() closing the network connection
// which in turn makes wsReader receieve err from ws,ReadJSON() and also
// break out of the loop closing c.done channel in its shutdown sequence.
close(c.shutdown)
<-c.done
}
func (c *WSClient) wsReader() {
c.ws.SetReadLimit(wsReadLimit)
c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil })
for {
rr := new(requestResponse)
c.ws.SetReadDeadline(time.Now().Add(wsPongLimit))
err := c.ws.ReadJSON(rr)
if err != nil {
// Timeout/connection loss/malformed response.
break
}
if rr.RawID == nil && rr.Method != "" {
if c.notifications != nil {
c.notifications <- &rr.In
}
} else if rr.RawID != nil && (rr.Error != nil || rr.Result != nil) {
resp := new(response.Raw)
resp.ID = rr.RawID
resp.JSONRPC = rr.JSONRPC
resp.Error = rr.Error
resp.Result = rr.Result
c.responses <- resp
} else {
// Malformed response, neither valid request, nor valid response.
break
}
}
close(c.done)
close(c.responses)
if c.notifications != nil {
close(c.notifications)
}
}
func (c *WSClient) wsWriter() {
pingTicker := time.NewTicker(wsPingPeriod)
defer c.ws.Close()
defer pingTicker.Stop()
for {
select {
case <-c.shutdown:
return
case <-c.done:
return
case req, ok := <-c.requests:
if !ok {
return
}
c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout))
if err := c.ws.WriteJSON(req); err != nil {
return
}
case <-pingTicker.C:
c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
return
}
}
}
}
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
select {
case <-c.done:
return nil, errors.New("connection lost")
case c.requests <- r:
}
select {
case <-c.done:
return nil, errors.New("connection lost")
case resp := <-c.responses:
return resp, nil
}
}

View file

@ -0,0 +1,16 @@
package client
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestWSClientClose(t *testing.T) {
srv := initTestServer(t, "")
defer srv.Close()
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
require.NoError(t, err)
wsc.Close()
}