diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go index bdac24816..6774dd421 100644 --- a/pkg/rpc/client/wsclient.go +++ b/pkg/rpc/client/wsclient.go @@ -7,8 +7,11 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/response" + "github.com/nspcc-dev/neo-go/pkg/rpc/response/result" ) // WSClient is a websocket-enabled RPC client that can be used with appropriate @@ -17,12 +20,28 @@ import ( // that is only provided via websockets (like event subscription mechanism). type WSClient struct { Client + // Notifications is a channel that is used to send events received from + // server. Client's code is supposed to be reading from this channel if + // it wants to use subscription mechanism, failing to do so will cause + // WSClient to block even regular requests. This channel is not buffered. + // In case of protocol error or upon connection closure this channel will + // be closed, so make sure to handle this. + Notifications chan Notification + ws *websocket.Conn done chan struct{} - notifications chan *request.In responses chan *response.Raw requests chan *request.Raw shutdown chan struct{} + subscriptions map[string]bool +} + +// Notification represents server-generated notification for client subscriptions. +// Value can be one of block.Block, result.ApplicationLog, result.NotificationEvent +// or transaction.Transaction based on Type. +type Notification struct { + Type response.EventID + Value interface{} } // requestResponse is a combined type for request and response since we can get @@ -59,12 +78,15 @@ func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error return nil, err } wsc := &WSClient{ - Client: *cl, - ws: ws, - shutdown: make(chan struct{}), - done: make(chan struct{}), - responses: make(chan *response.Raw), - requests: make(chan *request.Raw), + Client: *cl, + Notifications: make(chan Notification), + + ws: ws, + shutdown: make(chan struct{}), + done: make(chan struct{}), + responses: make(chan *response.Raw), + requests: make(chan *request.Raw), + subscriptions: make(map[string]bool), } go wsc.wsReader() go wsc.wsWriter() @@ -86,6 +108,7 @@ func (c *WSClient) Close() { func (c *WSClient) wsReader() { c.ws.SetReadLimit(wsReadLimit) c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) +readloop: for { rr := new(requestResponse) c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) @@ -95,9 +118,37 @@ func (c *WSClient) wsReader() { break } if rr.RawID == nil && rr.Method != "" { - if c.notifications != nil { - c.notifications <- &rr.In + event, err := response.GetEventIDFromString(rr.Method) + if err != nil { + // Bad event received. + break } + var slice []json.RawMessage + err = json.Unmarshal(rr.RawParams, &slice) + if err != nil || len(slice) != 1 { + // Bad event received. + break + } + var val interface{} + switch event { + case response.BlockEventID: + val = new(block.Block) + case response.TransactionEventID: + val = new(transaction.Transaction) + case response.NotificationEventID: + val = new(result.NotificationEvent) + case response.ExecutionEventID: + val = new(result.ApplicationLog) + default: + // Bad event received. + break readloop + } + err = json.Unmarshal(slice[0], val) + if err != nil || len(slice) != 1 { + // Bad event received. + break + } + c.Notifications <- Notification{event, val} } else if rr.RawID != nil && (rr.Error != nil || rr.Result != nil) { resp := new(response.Raw) resp.ID = rr.RawID @@ -112,9 +163,7 @@ func (c *WSClient) wsReader() { } close(c.done) close(c.responses) - if c.notifications != nil { - close(c.notifications) - } + close(c.Notifications) } func (c *WSClient) wsWriter() { @@ -158,3 +207,73 @@ func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) { return resp, nil } } + +func (c *WSClient) performSubscription(params request.RawParams) (string, error) { + var resp string + + if err := c.performRequest("subscribe", params, &resp); err != nil { + return "", err + } + c.subscriptions[resp] = true + return resp, nil +} + +func (c *WSClient) performUnsubscription(id string) error { + var resp bool + + if !c.subscriptions[id] { + return errors.New("no subscription with this ID") + } + if err := c.performRequest("unsubscribe", request.NewRawParams(id), &resp); err != nil { + return err + } + if !resp { + return errors.New("unsubscribe method returned false result") + } + delete(c.subscriptions, id) + return nil +} + +// SubscribeForNewBlocks adds subscription for new block events to this instance +// of client. +func (c *WSClient) SubscribeForNewBlocks() (string, error) { + params := request.NewRawParams("block_added") + return c.performSubscription(params) +} + +// SubscribeForNewTransactions adds subscription for new transaction events to +// this instance of client. +func (c *WSClient) SubscribeForNewTransactions() (string, error) { + params := request.NewRawParams("transaction_added") + return c.performSubscription(params) +} + +// SubscribeForExecutionNotifications adds subscription for notifications +// generated during transaction execution to this instance of client. +func (c *WSClient) SubscribeForExecutionNotifications() (string, error) { + params := request.NewRawParams("notification_from_execution") + return c.performSubscription(params) +} + +// SubscribeForTransactionExecutions adds subscription for application execution +// results generated during transaction execution to this instance of client. +func (c *WSClient) SubscribeForTransactionExecutions() (string, error) { + params := request.NewRawParams("transaction_executed") + return c.performSubscription(params) +} + +// Unsubscribe removes subscription for given event stream. +func (c *WSClient) Unsubscribe(id string) error { + return c.performUnsubscription(id) +} + +// UnsubscribeAll removes all active subscriptions of current client. +func (c *WSClient) UnsubscribeAll() error { + for id := range c.subscriptions { + err := c.performUnsubscription(id) + if err != nil { + return err + } + } + return nil +} diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 2a996999a..f747c1710 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -2,8 +2,12 @@ package client import ( "context" + "net/http" + "net/http/httptest" "testing" + "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/require" ) @@ -14,3 +18,129 @@ func TestWSClientClose(t *testing.T) { require.NoError(t, err) wsc.Close() } + +func TestWSClientSubscription(t *testing.T) { + var cases = map[string]func(*WSClient) (string, error){ + "blocks": (*WSClient).SubscribeForNewBlocks, + "transactions": (*WSClient).SubscribeForNewTransactions, + "notifications": (*WSClient).SubscribeForExecutionNotifications, + "executions": (*WSClient).SubscribeForTransactionExecutions, + } + t.Run("good", func(t *testing.T) { + for name, f := range cases { + t.Run(name, func(t *testing.T) { + srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "result": "55aaff00"}`) + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + id, err := f(wsc) + require.NoError(t, err) + require.Equal(t, "55aaff00", id) + }) + } + }) + t.Run("bad", func(t *testing.T) { + for name, f := range cases { + t.Run(name, func(t *testing.T) { + srv := initTestServer(t, `{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`) + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + _, err = f(wsc) + require.Error(t, err) + }) + } + }) +} + +func TestWSClientUnsubscription(t *testing.T) { + type responseCheck struct { + response string + code func(*testing.T, *WSClient) + } + var cases = map[string]responseCheck{ + "good": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) { + // We can't really subscribe using this stub server, so set up wsc internals. + wsc.subscriptions["0"] = true + err := wsc.Unsubscribe("0") + require.NoError(t, err) + }}, + "all": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) { + // We can't really subscribe using this stub server, so set up wsc internals. + wsc.subscriptions["0"] = true + err := wsc.UnsubscribeAll() + require.NoError(t, err) + require.Equal(t, 0, len(wsc.subscriptions)) + }}, + "not subscribed": {`{"jsonrpc": "2.0", "id": 1, "result": true}`, func(t *testing.T, wsc *WSClient) { + err := wsc.Unsubscribe("0") + require.Error(t, err) + }}, + "error returned": {`{"jsonrpc": "2.0", "id": 1, "error":{"code":-32602,"message":"Invalid Params"}}`, func(t *testing.T, wsc *WSClient) { + // We can't really subscribe using this stub server, so set up wsc internals. + wsc.subscriptions["0"] = true + err := wsc.Unsubscribe("0") + require.Error(t, err) + }}, + "false returned": {`{"jsonrpc": "2.0", "id": 1, "result": false}`, func(t *testing.T, wsc *WSClient) { + // We can't really subscribe using this stub server, so set up wsc internals. + wsc.subscriptions["0"] = true + err := wsc.Unsubscribe("0") + require.Error(t, err) + }}, + } + for name, rc := range cases { + t.Run(name, func(t *testing.T) { + srv := initTestServer(t, rc.response) + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + rc.code(t, wsc) + }) + } +} + +func TestWSClientEvents(t *testing.T) { + var ok bool + // Events from RPC server test chain. + var events = []string{ + `{"jsonrpc":"2.0","method":"transaction_executed","params":[{"txid":"0x93670859cc8a42f6ea994869c944879678d33d7501d388f5a446a8c7de147df7","executions":[{"trigger":"Application","contract":"0x0000000000000000000000000000000000000000","vmstate":"HALT","gas_consumed":"1.048","stack":[{"type":"Integer","value":"1"}],"notifications":[{"contract":"0xc2789e5ab9bab828743833965b1df0d5fbcc206f","state":{"type":"Array","value":[{"type":"ByteArray","value":"636f6e74726163742063616c6c"},{"type":"ByteArray","value":"507574"},{"type":"Array","value":[{"type":"ByteArray","value":"746573746b6579"},{"type":"ByteArray","value":"7465737476616c7565"}]}]}}]}]}]}`, + `{"jsonrpc":"2.0","method":"notification_from_execution","params":[{"contract":"0xc2789e5ab9bab828743833965b1df0d5fbcc206f","state":{"type":"Array","value":[{"type":"ByteArray","value":"636f6e74726163742063616c6c"},{"type":"ByteArray","value":"507574"},{"type":"Array","value":[{"type":"ByteArray","value":"746573746b6579"},{"type":"ByteArray","value":"7465737476616c7565"}]}]}}]}`, + `{"jsonrpc":"2.0","method":"transaction_added","params":[{"txid":"0x93670859cc8a42f6ea994869c944879678d33d7501d388f5a446a8c7de147df7","size":60,"type":"InvocationTransaction","version":1,"attributes":[],"vin":[],"vout":[],"scripts":[],"script":"097465737476616c756507746573746b657952c103507574676f20ccfbd5f01d5b9633387428b8bab95a9e78c2"}]}`, + `{"jsonrpc":"2.0","method":"block_added","params":[{"version":0,"previousblockhash":"0x33f3e0e24542b2ec3b6420e6881c31f6460a39a4e733d88f7557cbcc3b5ed560","merkleroot":"0x9d922c5cfd4c8cd1da7a6b2265061998dc438bd0dea7145192e2858155e6c57a","time":1586154525,"height":205,"nonce":1111,"next_consensus":"0xa21e4f7178607089e4fe9fab1300d1f5a3d348be","script":{"invocation":"4047a444a51218ac856f1cbc629f251c7c88187910534d6ba87847c86a9a73ed4951d203fd0a87f3e65657a7259269473896841f65c0a0c8efc79d270d917f4ff640435ee2f073c94a02f0276dfe4465037475e44e1c34c0decb87ec9c2f43edf688059fc4366a41c673d72ba772b4782c39e79f01cb981247353216d52d2df1651140527eb0dfd80a800fdd7ac8fbe68fc9366db2d71655d8ba235525a97a69a7181b1e069b82091be711c25e504a17c3c55eee6e76e6af13cb488fbe35d5c5d025c34041f39a02ebe9bb08be0e4aaa890f447dc9453209bbfb4705d8f2d869c2b55ee2d41dbec2ee476a059d77fb7c26400284328d05aece5f3168b48f1db1c6f7be0b","verification":"532102103a7f7dd016558597f7960d27c516a4394fd968b9e65155eb4b013e4040406e2102a7bc55fe8684e0119768d104ba30795bdcc86619e864add26156723ed185cd622102b3622bf4017bdfe317c58aed5f4c753f206b7db896046fa7d774bbc4bf7f8dc22103d90c07df63e690ce77912e10ab51acc944b66860237b608c4f8f8309e71ee69954ae"},"tx":[{"txid":"0xf9adfde059810f37b3d0686d67f6b29034e0c669537df7e59b40c14a0508b9ed","size":10,"type":"MinerTransaction","version":0,"attributes":[],"vin":[],"vout":[],"scripts":[]},{"txid":"0x93670859cc8a42f6ea994869c944879678d33d7501d388f5a446a8c7de147df7","size":60,"type":"InvocationTransaction","version":1,"attributes":[],"vin":[],"vout":[],"scripts":[],"script":"097465737476616c756507746573746b657952c103507574676f20ccfbd5f01d5b9633387428b8bab95a9e78c2"}]}]}`, + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/ws" && req.Method == "GET" { + var upgrader = websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + for _, event := range events { + ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) + err = ws.WriteMessage(1, []byte(event)) + if err != nil { + break + } + } + ws.Close() + return + } + })) + + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + for range events { + select { + case _, ok = <-wsc.Notifications: + case <-time.After(time.Second): + t.Fatal("timeout waiting for event") + } + require.Equal(t, true, ok) + } + select { + case _, ok = <-wsc.Notifications: + case <-time.After(time.Second): + t.Fatal("timeout waiting for event") + } + // Connection closed by server. + require.Equal(t, false, ok) +}