rpc/client: add minimalistic websocket client
This commit is contained in:
parent
6333060897
commit
556ab39a5a
5 changed files with 230 additions and 11 deletions
|
@ -32,9 +32,10 @@ type Client struct {
|
||||||
cli *http.Client
|
cli *http.Client
|
||||||
endpoint *url.URL
|
endpoint *url.URL
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
opts Options
|
||||||
|
requestF func(*request.Raw) (*response.Raw, error)
|
||||||
wifMu *sync.Mutex
|
wifMu *sync.Mutex
|
||||||
wif *keys.WIF
|
wif *keys.WIF
|
||||||
balancer request.BalanceGetter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options defines options for the RPC client.
|
// Options defines options for the RPC client.
|
||||||
|
@ -94,7 +95,8 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
if opts.Balancer == nil {
|
if opts.Balancer == nil {
|
||||||
opts.Balancer = cl
|
opts.Balancer = cl
|
||||||
}
|
}
|
||||||
cl.balancer = opts.Balancer
|
cl.opts = opts
|
||||||
|
cl.requestF = cl.makeHTTPRequest
|
||||||
return cl, nil
|
return cl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,12 +154,14 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
raw, err := c.makeHTTPRequest(&r)
|
raw, err := c.requestF(&r)
|
||||||
|
|
||||||
if raw != nil && raw.Error != nil {
|
if raw != nil && raw.Error != nil {
|
||||||
return raw.Error
|
return raw.Error
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
|
} else if raw == nil || raw.Result == nil {
|
||||||
|
return errors.New("no result returned")
|
||||||
}
|
}
|
||||||
return json.Unmarshal(raw.Result, v)
|
return json.Unmarshal(raw.Result, v)
|
||||||
}
|
}
|
||||||
|
|
|
@ -484,7 +484,7 @@ func (c *Client) TransferAsset(asset util.Uint256, address string, amount util.F
|
||||||
Address: address,
|
Address: address,
|
||||||
Value: amount,
|
Value: amount,
|
||||||
WIF: c.WIF(),
|
WIF: c.WIF(),
|
||||||
Balancer: c.balancer,
|
Balancer: c.opts.Balancer,
|
||||||
}
|
}
|
||||||
resp util.Uint256
|
resp util.Uint256
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,11 +3,13 @@ package client
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
|
@ -1427,7 +1429,22 @@ var rpcClientErrorCases = map[string][]rpcClientErrorCase{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRPCClient(t *testing.T) {
|
func TestRPCClients(t *testing.T) {
|
||||||
|
t.Run("Client", func(t *testing.T) {
|
||||||
|
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
|
return New(ctx, endpoint, opts)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("WSClient", func(t *testing.T) {
|
||||||
|
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
|
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return &wsc.Client, nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testRPCClient(t *testing.T, newClient func(context.Context, string, Options) (*Client, error)) {
|
||||||
for method, testBatch := range rpcClientTestCases {
|
for method, testBatch := range rpcClientTestCases {
|
||||||
t.Run(method, func(t *testing.T) {
|
t.Run(method, func(t *testing.T) {
|
||||||
for _, testCase := range testBatch {
|
for _, testCase := range testBatch {
|
||||||
|
@ -1437,7 +1454,7 @@ func TestRPCClient(t *testing.T) {
|
||||||
|
|
||||||
endpoint := srv.URL
|
endpoint := srv.URL
|
||||||
opts := Options{}
|
opts := Options{}
|
||||||
c, err := New(context.TODO(), endpoint, opts)
|
c, err := newClient(context.TODO(), endpoint, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -1461,7 +1478,7 @@ func TestRPCClient(t *testing.T) {
|
||||||
|
|
||||||
endpoint := srv.URL
|
endpoint := srv.URL
|
||||||
opts := Options{}
|
opts := Options{}
|
||||||
c, err := New(context.TODO(), endpoint, opts)
|
c, err := newClient(context.TODO(), endpoint, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -1475,8 +1492,31 @@ func TestRPCClient(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func httpURLtoWS(url string) string {
|
||||||
|
return "ws" + strings.TrimPrefix(url, "http") + "/ws"
|
||||||
|
}
|
||||||
|
|
||||||
func initTestServer(t *testing.T, resp string) *httptest.Server {
|
func initTestServer(t *testing.T, resp string) *httptest.Server {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.URL.Path == "/ws" && req.Method == "GET" {
|
||||||
|
var upgrader = websocket.Upgrader{}
|
||||||
|
ws, err := upgrader.Upgrade(w, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for {
|
||||||
|
ws.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
_, _, err = ws.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
ws.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
err = ws.WriteMessage(1, []byte(resp))
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ws.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
requestHandler(t, w, resp)
|
requestHandler(t, w, resp)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@ -1485,10 +1525,9 @@ func initTestServer(t *testing.T, resp string) *httptest.Server {
|
||||||
|
|
||||||
func requestHandler(t *testing.T, w http.ResponseWriter, resp string) {
|
func requestHandler(t *testing.T, w http.ResponseWriter, resp string) {
|
||||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
encoder := json.NewEncoder(w)
|
_, err := w.Write([]byte(resp))
|
||||||
err := encoder.Encode(json.RawMessage(resp))
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error encountered while encoding response: %s", err.Error())
|
t.Fatalf("Error writing response: %s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
160
pkg/rpc/client/wsclient.go
Normal file
160
pkg/rpc/client/wsclient.go
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WSClient is a websocket-enabled RPC client that can be used with appropriate
|
||||||
|
// servers. It's supposed to be faster than Client because it has persistent
|
||||||
|
// connection to the server and at the same time is exposes some functionality
|
||||||
|
// that is only provided via websockets (like event subscription mechanism).
|
||||||
|
type WSClient struct {
|
||||||
|
Client
|
||||||
|
ws *websocket.Conn
|
||||||
|
done chan struct{}
|
||||||
|
notifications chan *request.In
|
||||||
|
responses chan *response.Raw
|
||||||
|
requests chan *request.Raw
|
||||||
|
shutdown chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestResponse is a combined type for request and response since we can get
|
||||||
|
// any of them here.
|
||||||
|
type requestResponse struct {
|
||||||
|
request.In
|
||||||
|
Error *response.Error `json:"error,omitempty"`
|
||||||
|
Result json.RawMessage `json:"result,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Message limit for receiving side.
|
||||||
|
wsReadLimit = 10 * 1024 * 1024
|
||||||
|
|
||||||
|
// Disconnection timeout.
|
||||||
|
wsPongLimit = 60 * time.Second
|
||||||
|
|
||||||
|
// Ping period for connection liveness check.
|
||||||
|
wsPingPeriod = wsPongLimit / 2
|
||||||
|
|
||||||
|
// Write deadline.
|
||||||
|
wsWriteLimit = wsPingPeriod / 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWS returns a new WSClient ready to use (with established websocket
|
||||||
|
// connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`.
|
||||||
|
func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) {
|
||||||
|
cl, err := New(ctx, endpoint, opts)
|
||||||
|
cl.cli = nil
|
||||||
|
|
||||||
|
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
|
||||||
|
ws, _, err := dialer.Dial(endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
wsc := &WSClient{
|
||||||
|
Client: *cl,
|
||||||
|
ws: ws,
|
||||||
|
shutdown: make(chan struct{}),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
responses: make(chan *response.Raw),
|
||||||
|
requests: make(chan *request.Raw),
|
||||||
|
}
|
||||||
|
go wsc.wsReader()
|
||||||
|
go wsc.wsWriter()
|
||||||
|
wsc.requestF = wsc.makeWsRequest
|
||||||
|
return wsc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes connection to the remote side rendering this client instance
|
||||||
|
// unusable.
|
||||||
|
func (c *WSClient) Close() {
|
||||||
|
// Closing shutdown channel send signal to wsWriter to break out of the
|
||||||
|
// loop. In doing so it does ws.Close() closing the network connection
|
||||||
|
// which in turn makes wsReader receieve err from ws,ReadJSON() and also
|
||||||
|
// break out of the loop closing c.done channel in its shutdown sequence.
|
||||||
|
close(c.shutdown)
|
||||||
|
<-c.done
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) wsReader() {
|
||||||
|
c.ws.SetReadLimit(wsReadLimit)
|
||||||
|
c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil })
|
||||||
|
for {
|
||||||
|
rr := new(requestResponse)
|
||||||
|
c.ws.SetReadDeadline(time.Now().Add(wsPongLimit))
|
||||||
|
err := c.ws.ReadJSON(rr)
|
||||||
|
if err != nil {
|
||||||
|
// Timeout/connection loss/malformed response.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rr.RawID == nil && rr.Method != "" {
|
||||||
|
if c.notifications != nil {
|
||||||
|
c.notifications <- &rr.In
|
||||||
|
}
|
||||||
|
} else if rr.RawID != nil && (rr.Error != nil || rr.Result != nil) {
|
||||||
|
resp := new(response.Raw)
|
||||||
|
resp.ID = rr.RawID
|
||||||
|
resp.JSONRPC = rr.JSONRPC
|
||||||
|
resp.Error = rr.Error
|
||||||
|
resp.Result = rr.Result
|
||||||
|
c.responses <- resp
|
||||||
|
} else {
|
||||||
|
// Malformed response, neither valid request, nor valid response.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(c.done)
|
||||||
|
close(c.responses)
|
||||||
|
if c.notifications != nil {
|
||||||
|
close(c.notifications)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) wsWriter() {
|
||||||
|
pingTicker := time.NewTicker(wsPingPeriod)
|
||||||
|
defer c.ws.Close()
|
||||||
|
defer pingTicker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.shutdown:
|
||||||
|
return
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
case req, ok := <-c.requests:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout))
|
||||||
|
if err := c.ws.WriteJSON(req); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-pingTicker.C:
|
||||||
|
c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
|
||||||
|
if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return nil, errors.New("connection lost")
|
||||||
|
case c.requests <- r:
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return nil, errors.New("connection lost")
|
||||||
|
case resp := <-c.responses:
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
16
pkg/rpc/client/wsclient_test.go
Normal file
16
pkg/rpc/client/wsclient_test.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWSClientClose(t *testing.T) {
|
||||||
|
srv := initTestServer(t, "")
|
||||||
|
defer srv.Close()
|
||||||
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
wsc.Close()
|
||||||
|
}
|
Loading…
Reference in a new issue