forked from TrueCloudLab/neoneo-go
rpc/client: add minimalistic websocket client
This commit is contained in:
parent
a458a17748
commit
3de48d7d90
5 changed files with 230 additions and 11 deletions
|
@ -34,9 +34,10 @@ type Client struct {
|
|||
cli *http.Client
|
||||
endpoint *url.URL
|
||||
ctx context.Context
|
||||
opts Options
|
||||
requestF func(*request.Raw) (*response.Raw, error)
|
||||
wifMu *sync.Mutex
|
||||
wif *keys.WIF
|
||||
balancer request.BalanceGetter
|
||||
cache cache
|
||||
}
|
||||
|
||||
|
@ -109,7 +110,8 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
|||
if opts.Balancer == nil {
|
||||
opts.Balancer = cl
|
||||
}
|
||||
cl.balancer = opts.Balancer
|
||||
cl.opts = opts
|
||||
cl.requestF = cl.makeHTTPRequest
|
||||
return cl, nil
|
||||
}
|
||||
|
||||
|
@ -167,12 +169,14 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{
|
|||
ID: 1,
|
||||
}
|
||||
|
||||
raw, err := c.makeHTTPRequest(&r)
|
||||
raw, err := c.requestF(&r)
|
||||
|
||||
if raw != nil && raw.Error != nil {
|
||||
return raw.Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else if raw == nil || raw.Result == nil {
|
||||
return errors.New("no result returned")
|
||||
}
|
||||
return json.Unmarshal(raw.Result, v)
|
||||
}
|
||||
|
|
|
@ -482,7 +482,7 @@ func (c *Client) TransferAsset(asset util.Uint256, address string, amount util.F
|
|||
Address: address,
|
||||
Value: amount,
|
||||
WIF: c.WIF(),
|
||||
Balancer: c.balancer,
|
||||
Balancer: c.opts.Balancer,
|
||||
}
|
||||
resp util.Uint256
|
||||
)
|
||||
|
|
|
@ -3,11 +3,13 @@ package client
|
|||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
|
@ -1433,7 +1435,22 @@ var rpcClientErrorCases = map[string][]rpcClientErrorCase{
|
|||
},
|
||||
}
|
||||
|
||||
func TestRPCClient(t *testing.T) {
|
||||
func TestRPCClients(t *testing.T) {
|
||||
t.Run("Client", func(t *testing.T) {
|
||||
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||
return New(ctx, endpoint, opts)
|
||||
})
|
||||
})
|
||||
t.Run("WSClient", func(t *testing.T) {
|
||||
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
|
||||
require.NoError(t, err)
|
||||
return &wsc.Client, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testRPCClient(t *testing.T, newClient func(context.Context, string, Options) (*Client, error)) {
|
||||
for method, testBatch := range rpcClientTestCases {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
for _, testCase := range testBatch {
|
||||
|
@ -1443,7 +1460,7 @@ func TestRPCClient(t *testing.T) {
|
|||
|
||||
endpoint := srv.URL
|
||||
opts := Options{}
|
||||
c, err := New(context.TODO(), endpoint, opts)
|
||||
c, err := newClient(context.TODO(), endpoint, opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1467,7 +1484,7 @@ func TestRPCClient(t *testing.T) {
|
|||
|
||||
endpoint := srv.URL
|
||||
opts := Options{}
|
||||
c, err := New(context.TODO(), endpoint, opts)
|
||||
c, err := newClient(context.TODO(), endpoint, opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1481,8 +1498,31 @@ func TestRPCClient(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func httpURLtoWS(url string) string {
|
||||
return "ws" + strings.TrimPrefix(url, "http") + "/ws"
|
||||
}
|
||||
|
||||
func initTestServer(t *testing.T, resp string) *httptest.Server {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path == "/ws" && req.Method == "GET" {
|
||||
var upgrader = websocket.Upgrader{}
|
||||
ws, err := upgrader.Upgrade(w, req, nil)
|
||||
require.NoError(t, err)
|
||||
for {
|
||||
ws.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
_, _, err = ws.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
ws.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||
err = ws.WriteMessage(1, []byte(resp))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
ws.Close()
|
||||
return
|
||||
}
|
||||
requestHandler(t, w, resp)
|
||||
}))
|
||||
|
||||
|
@ -1491,11 +1531,10 @@ func initTestServer(t *testing.T, resp string) *httptest.Server {
|
|||
|
||||
func requestHandler(t *testing.T, w http.ResponseWriter, resp string) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
encoder := json.NewEncoder(w)
|
||||
err := encoder.Encode(json.RawMessage(resp))
|
||||
_, err := w.Write([]byte(resp))
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error encountered while encoding response: %s", err.Error())
|
||||
t.Fatalf("Error writing response: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
160
pkg/rpc/client/wsclient.go
Normal file
160
pkg/rpc/client/wsclient.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
||||
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
|
||||
)
|
||||
|
||||
// WSClient is a websocket-enabled RPC client that can be used with appropriate
|
||||
// servers. It's supposed to be faster than Client because it has persistent
|
||||
// connection to the server and at the same time is exposes some functionality
|
||||
// that is only provided via websockets (like event subscription mechanism).
|
||||
type WSClient struct {
|
||||
Client
|
||||
ws *websocket.Conn
|
||||
done chan struct{}
|
||||
notifications chan *request.In
|
||||
responses chan *response.Raw
|
||||
requests chan *request.Raw
|
||||
shutdown chan struct{}
|
||||
}
|
||||
|
||||
// requestResponse is a combined type for request and response since we can get
|
||||
// any of them here.
|
||||
type requestResponse struct {
|
||||
request.In
|
||||
Error *response.Error `json:"error,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// Message limit for receiving side.
|
||||
wsReadLimit = 10 * 1024 * 1024
|
||||
|
||||
// Disconnection timeout.
|
||||
wsPongLimit = 60 * time.Second
|
||||
|
||||
// Ping period for connection liveness check.
|
||||
wsPingPeriod = wsPongLimit / 2
|
||||
|
||||
// Write deadline.
|
||||
wsWriteLimit = wsPingPeriod / 2
|
||||
)
|
||||
|
||||
// NewWS returns a new WSClient ready to use (with established websocket
|
||||
// connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`.
|
||||
func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) {
|
||||
cl, err := New(ctx, endpoint, opts)
|
||||
cl.cli = nil
|
||||
|
||||
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
|
||||
ws, _, err := dialer.Dial(endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wsc := &WSClient{
|
||||
Client: *cl,
|
||||
ws: ws,
|
||||
shutdown: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
responses: make(chan *response.Raw),
|
||||
requests: make(chan *request.Raw),
|
||||
}
|
||||
go wsc.wsReader()
|
||||
go wsc.wsWriter()
|
||||
wsc.requestF = wsc.makeWsRequest
|
||||
return wsc, nil
|
||||
}
|
||||
|
||||
// Close closes connection to the remote side rendering this client instance
|
||||
// unusable.
|
||||
func (c *WSClient) Close() {
|
||||
// Closing shutdown channel send signal to wsWriter to break out of the
|
||||
// loop. In doing so it does ws.Close() closing the network connection
|
||||
// which in turn makes wsReader receieve err from ws,ReadJSON() and also
|
||||
// break out of the loop closing c.done channel in its shutdown sequence.
|
||||
close(c.shutdown)
|
||||
<-c.done
|
||||
}
|
||||
|
||||
func (c *WSClient) wsReader() {
|
||||
c.ws.SetReadLimit(wsReadLimit)
|
||||
c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil })
|
||||
for {
|
||||
rr := new(requestResponse)
|
||||
c.ws.SetReadDeadline(time.Now().Add(wsPongLimit))
|
||||
err := c.ws.ReadJSON(rr)
|
||||
if err != nil {
|
||||
// Timeout/connection loss/malformed response.
|
||||
break
|
||||
}
|
||||
if rr.RawID == nil && rr.Method != "" {
|
||||
if c.notifications != nil {
|
||||
c.notifications <- &rr.In
|
||||
}
|
||||
} else if rr.RawID != nil && (rr.Error != nil || rr.Result != nil) {
|
||||
resp := new(response.Raw)
|
||||
resp.ID = rr.RawID
|
||||
resp.JSONRPC = rr.JSONRPC
|
||||
resp.Error = rr.Error
|
||||
resp.Result = rr.Result
|
||||
c.responses <- resp
|
||||
} else {
|
||||
// Malformed response, neither valid request, nor valid response.
|
||||
break
|
||||
}
|
||||
}
|
||||
close(c.done)
|
||||
close(c.responses)
|
||||
if c.notifications != nil {
|
||||
close(c.notifications)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) wsWriter() {
|
||||
pingTicker := time.NewTicker(wsPingPeriod)
|
||||
defer c.ws.Close()
|
||||
defer pingTicker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.shutdown:
|
||||
return
|
||||
case <-c.done:
|
||||
return
|
||||
case req, ok := <-c.requests:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout))
|
||||
if err := c.ws.WriteJSON(req); err != nil {
|
||||
return
|
||||
}
|
||||
case <-pingTicker.C:
|
||||
c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
|
||||
if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
|
||||
select {
|
||||
case <-c.done:
|
||||
return nil, errors.New("connection lost")
|
||||
case c.requests <- r:
|
||||
}
|
||||
select {
|
||||
case <-c.done:
|
||||
return nil, errors.New("connection lost")
|
||||
case resp := <-c.responses:
|
||||
return resp, nil
|
||||
}
|
||||
}
|
16
pkg/rpc/client/wsclient_test.go
Normal file
16
pkg/rpc/client/wsclient_test.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWSClientClose(t *testing.T) {
|
||||
srv := initTestServer(t, "")
|
||||
defer srv.Close()
|
||||
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||
require.NoError(t, err)
|
||||
wsc.Close()
|
||||
}
|
Loading…
Reference in a new issue