From 556ab39a5a2745893f16407e0c391526ee64a05e Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 29 Apr 2020 22:51:43 +0300 Subject: [PATCH] rpc/client: add minimalistic websocket client --- pkg/rpc/client/client.go | 10 +- pkg/rpc/client/rpc.go | 2 +- pkg/rpc/client/rpc_test.go | 53 +++++++++-- pkg/rpc/client/wsclient.go | 160 ++++++++++++++++++++++++++++++++ pkg/rpc/client/wsclient_test.go | 16 ++++ 5 files changed, 230 insertions(+), 11 deletions(-) create mode 100644 pkg/rpc/client/wsclient.go create mode 100644 pkg/rpc/client/wsclient_test.go diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index f6f44729a..0f44dc123 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -32,9 +32,10 @@ type Client struct { cli *http.Client endpoint *url.URL ctx context.Context + opts Options + requestF func(*request.Raw) (*response.Raw, error) wifMu *sync.Mutex wif *keys.WIF - balancer request.BalanceGetter } // Options defines options for the RPC client. @@ -94,7 +95,8 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { if opts.Balancer == nil { opts.Balancer = cl } - cl.balancer = opts.Balancer + cl.opts = opts + cl.requestF = cl.makeHTTPRequest return cl, nil } @@ -152,12 +154,14 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{ ID: 1, } - raw, err := c.makeHTTPRequest(&r) + raw, err := c.requestF(&r) if raw != nil && raw.Error != nil { return raw.Error } else if err != nil { return err + } else if raw == nil || raw.Result == nil { + return errors.New("no result returned") } return json.Unmarshal(raw.Result, v) } diff --git a/pkg/rpc/client/rpc.go b/pkg/rpc/client/rpc.go index 400303f88..09257d09a 100644 --- a/pkg/rpc/client/rpc.go +++ b/pkg/rpc/client/rpc.go @@ -484,7 +484,7 @@ func (c *Client) TransferAsset(asset util.Uint256, address string, amount util.F Address: address, Value: amount, WIF: c.WIF(), - Balancer: c.balancer, + Balancer: c.opts.Balancer, } resp util.Uint256 ) diff --git a/pkg/rpc/client/rpc_test.go b/pkg/rpc/client/rpc_test.go index 6cef59d25..0191e82d4 100644 --- a/pkg/rpc/client/rpc_test.go +++ b/pkg/rpc/client/rpc_test.go @@ -3,11 +3,13 @@ package client import ( "context" "encoding/hex" - "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" + "time" + "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -1427,7 +1429,22 @@ var rpcClientErrorCases = map[string][]rpcClientErrorCase{ }, } -func TestRPCClient(t *testing.T) { +func TestRPCClients(t *testing.T) { + t.Run("Client", func(t *testing.T) { + testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { + return New(ctx, endpoint, opts) + }) + }) + t.Run("WSClient", func(t *testing.T) { + testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { + wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts) + require.NoError(t, err) + return &wsc.Client, nil + }) + }) +} + +func testRPCClient(t *testing.T, newClient func(context.Context, string, Options) (*Client, error)) { for method, testBatch := range rpcClientTestCases { t.Run(method, func(t *testing.T) { for _, testCase := range testBatch { @@ -1437,7 +1454,7 @@ func TestRPCClient(t *testing.T) { endpoint := srv.URL opts := Options{} - c, err := New(context.TODO(), endpoint, opts) + c, err := newClient(context.TODO(), endpoint, opts) if err != nil { t.Fatal(err) } @@ -1461,7 +1478,7 @@ func TestRPCClient(t *testing.T) { endpoint := srv.URL opts := Options{} - c, err := New(context.TODO(), endpoint, opts) + c, err := newClient(context.TODO(), endpoint, opts) if err != nil { t.Fatal(err) } @@ -1475,8 +1492,31 @@ func TestRPCClient(t *testing.T) { } } +func httpURLtoWS(url string) string { + return "ws" + strings.TrimPrefix(url, "http") + "/ws" +} + func initTestServer(t *testing.T, resp string) *httptest.Server { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/ws" && req.Method == "GET" { + var upgrader = websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + for { + ws.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, _, err = ws.ReadMessage() + if err != nil { + break + } + ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) + err = ws.WriteMessage(1, []byte(resp)) + if err != nil { + break + } + } + ws.Close() + return + } requestHandler(t, w, resp) })) @@ -1485,10 +1525,9 @@ func initTestServer(t *testing.T, resp string) *httptest.Server { func requestHandler(t *testing.T, w http.ResponseWriter, resp string) { w.Header().Set("Content-Type", "application/json; charset=utf-8") - encoder := json.NewEncoder(w) - err := encoder.Encode(json.RawMessage(resp)) + _, err := w.Write([]byte(resp)) if err != nil { - t.Fatalf("Error encountered while encoding response: %s", err.Error()) + t.Fatalf("Error writing response: %s", err.Error()) } } diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go new file mode 100644 index 000000000..bdac24816 --- /dev/null +++ b/pkg/rpc/client/wsclient.go @@ -0,0 +1,160 @@ +package client + +import ( + "context" + "encoding/json" + "errors" + "time" + + "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/rpc/request" + "github.com/nspcc-dev/neo-go/pkg/rpc/response" +) + +// WSClient is a websocket-enabled RPC client that can be used with appropriate +// servers. It's supposed to be faster than Client because it has persistent +// connection to the server and at the same time is exposes some functionality +// that is only provided via websockets (like event subscription mechanism). +type WSClient struct { + Client + ws *websocket.Conn + done chan struct{} + notifications chan *request.In + responses chan *response.Raw + requests chan *request.Raw + shutdown chan struct{} +} + +// requestResponse is a combined type for request and response since we can get +// any of them here. +type requestResponse struct { + request.In + Error *response.Error `json:"error,omitempty"` + Result json.RawMessage `json:"result,omitempty"` +} + +const ( + // Message limit for receiving side. + wsReadLimit = 10 * 1024 * 1024 + + // Disconnection timeout. + wsPongLimit = 60 * time.Second + + // Ping period for connection liveness check. + wsPingPeriod = wsPongLimit / 2 + + // Write deadline. + wsWriteLimit = wsPingPeriod / 2 +) + +// NewWS returns a new WSClient ready to use (with established websocket +// connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`. +func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) { + cl, err := New(ctx, endpoint, opts) + cl.cli = nil + + dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout} + ws, _, err := dialer.Dial(endpoint, nil) + if err != nil { + return nil, err + } + wsc := &WSClient{ + Client: *cl, + ws: ws, + shutdown: make(chan struct{}), + done: make(chan struct{}), + responses: make(chan *response.Raw), + requests: make(chan *request.Raw), + } + go wsc.wsReader() + go wsc.wsWriter() + wsc.requestF = wsc.makeWsRequest + return wsc, nil +} + +// Close closes connection to the remote side rendering this client instance +// unusable. +func (c *WSClient) Close() { + // Closing shutdown channel send signal to wsWriter to break out of the + // loop. In doing so it does ws.Close() closing the network connection + // which in turn makes wsReader receieve err from ws,ReadJSON() and also + // break out of the loop closing c.done channel in its shutdown sequence. + close(c.shutdown) + <-c.done +} + +func (c *WSClient) wsReader() { + c.ws.SetReadLimit(wsReadLimit) + c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) + for { + rr := new(requestResponse) + c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) + err := c.ws.ReadJSON(rr) + if err != nil { + // Timeout/connection loss/malformed response. + break + } + if rr.RawID == nil && rr.Method != "" { + if c.notifications != nil { + c.notifications <- &rr.In + } + } else if rr.RawID != nil && (rr.Error != nil || rr.Result != nil) { + resp := new(response.Raw) + resp.ID = rr.RawID + resp.JSONRPC = rr.JSONRPC + resp.Error = rr.Error + resp.Result = rr.Result + c.responses <- resp + } else { + // Malformed response, neither valid request, nor valid response. + break + } + } + close(c.done) + close(c.responses) + if c.notifications != nil { + close(c.notifications) + } +} + +func (c *WSClient) wsWriter() { + pingTicker := time.NewTicker(wsPingPeriod) + defer c.ws.Close() + defer pingTicker.Stop() + for { + select { + case <-c.shutdown: + return + case <-c.done: + return + case req, ok := <-c.requests: + if !ok { + return + } + c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout)) + if err := c.ws.WriteJSON(req); err != nil { + return + } + case <-pingTicker.C: + c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)) + if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + return + } + } + } + +} + +func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) { + select { + case <-c.done: + return nil, errors.New("connection lost") + case c.requests <- r: + } + select { + case <-c.done: + return nil, errors.New("connection lost") + case resp := <-c.responses: + return resp, nil + } +} diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go new file mode 100644 index 000000000..2a996999a --- /dev/null +++ b/pkg/rpc/client/wsclient_test.go @@ -0,0 +1,16 @@ +package client + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWSClientClose(t *testing.T) { + srv := initTestServer(t, "") + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + wsc.Close() +}