diff --git a/go.mod b/go.mod index 9f503d101..b3e0facdf 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/dgraph-io/badger/v2 v2.0.3 github.com/go-redis/redis v6.10.2+incompatible github.com/go-yaml/yaml v2.1.0+incompatible + github.com/gorilla/websocket v1.4.2 github.com/mr-tron/base58 v1.1.2 github.com/nspcc-dev/dbft v0.0.0-20200427132226-05feeca847dd github.com/nspcc-dev/rfc6979 v0.2.0 diff --git a/go.sum b/go.sum index 9747c253e..3b88ae8b1 100644 --- a/go.sum +++ b/go.sum @@ -85,6 +85,8 @@ github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index 3a5c2820f..d5ad26e7b 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -31,33 +31,34 @@ const ( // Client represents the middleman for executing JSON RPC calls // to remote NEO RPC nodes. type Client struct { - // The underlying http client. It's never a good practice to use - // the http.DefaultClient, therefore we will role our own. - cliMu *sync.Mutex - cli *http.Client - endpoint *url.URL - ctx context.Context - version string - wifMu *sync.Mutex - wif *keys.WIF - balancerMu *sync.Mutex - balancer request.BalanceGetter - cache cache + cli *http.Client + endpoint *url.URL + ctx context.Context + opts Options + requestF func(*request.Raw) (*response.Raw, error) + wifMu *sync.Mutex + wif *keys.WIF + cache cache } // Options defines options for the RPC client. -// All Values are optional. If any duration is not specified -// a default of 3 seconds will be used. +// All values are optional. If any duration is not specified +// a default of 4 seconds will be used. type Options struct { - Cert string - Key string - CACert string - DialTimeout time.Duration - Client *http.Client - // Version is the version of the client that will be send - // along with the request body. If no version is specified - // the default version (currently 2.0) will be used. - Version string + // Balancer is an implementation of request.BalanceGetter interface, + // if not set then the default Client's implementation will be used, but + // it relies on server support for `getunspents` RPC call which is + // standard for neo-go, but only implemented as a plugin for C# node. So + // you can override it here to use NeoScanServer for example. + Balancer request.BalanceGetter + + // Cert is a client-side certificate, it doesn't work at the moment along + // with the other two options below. + Cert string + Key string + CACert string + DialTimeout time.Duration + RequestTimeout time.Duration } // cache stores cache values for the RPC client methods @@ -79,37 +80,39 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { return nil, err } - if opts.Version == "" { - opts.Version = defaultClientVersion + if opts.DialTimeout <= 0 { + opts.DialTimeout = defaultDialTimeout } - if opts.Client == nil { - opts.Client = &http.Client{ - Transport: &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: opts.DialTimeout, - }).DialContext, - }, - } + if opts.RequestTimeout <= 0 { + opts.RequestTimeout = defaultRequestTimeout + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: opts.DialTimeout, + }).DialContext, + }, + Timeout: opts.RequestTimeout, } // TODO(@antdm): Enable SSL. if opts.Cert != "" && opts.Key != "" { } - if opts.Client.Timeout == 0 { - opts.Client.Timeout = defaultRequestTimeout + cl := &Client{ + ctx: ctx, + cli: httpClient, + wifMu: new(sync.Mutex), + endpoint: url, } - - return &Client{ - ctx: ctx, - cli: opts.Client, - cliMu: new(sync.Mutex), - balancerMu: new(sync.Mutex), - wifMu: new(sync.Mutex), - endpoint: url, - version: opts.Version, - }, nil + if opts.Balancer == nil { + opts.Balancer = cl + } + cl.opts = opts + cl.requestF = cl.makeHTTPRequest + return cl, nil } // WIF returns WIF structure associated with the client. @@ -137,43 +140,10 @@ func (c *Client) SetWIF(wif string) error { return nil } -// Balancer is a getter for balance field. -func (c *Client) Balancer() request.BalanceGetter { - c.balancerMu.Lock() - defer c.balancerMu.Unlock() - return c.balancer -} - -// SetBalancer is a setter for balance field. -func (c *Client) SetBalancer(b request.BalanceGetter) { - c.balancerMu.Lock() - defer c.balancerMu.Unlock() - - if b != nil { - c.balancer = b - } -} - -// Client is a getter for client field. -func (c *Client) Client() *http.Client { - c.cliMu.Lock() - defer c.cliMu.Unlock() - return c.cli -} - -// SetClient is a setter for client field. -func (c *Client) SetClient(cli *http.Client) { - c.cliMu.Lock() - defer c.cliMu.Unlock() - - if cli != nil { - c.cli = cli - } -} - -// CalculateInputs creates input transactions for the specified amount of given -// asset belonging to specified address. This implementation uses GetUnspents -// JSON-RPC call internally, so make sure your RPC server supports that. +// CalculateInputs implements request.BalanceGetter interface and returns inputs +// array for the specified amount of given asset belonging to specified address. +// This implementation uses GetUnspents JSON-RPC call internally, so make sure +// your RPC server supports that. func (c *Client) CalculateInputs(address string, asset util.Uint256, cost util.Fixed8) ([]transaction.Input, util.Fixed8, error) { var utxos state.UnspentBalances @@ -192,47 +162,59 @@ func (c *Client) CalculateInputs(address string, asset util.Uint256, cost util.F } func (c *Client) performRequest(method string, p request.RawParams, v interface{}) error { + var r = request.Raw{ + JSONRPC: request.JSONRPCVersion, + Method: method, + RawParams: p.Values, + ID: 1, + } + + raw, err := c.requestF(&r) + + if raw != nil && raw.Error != nil { + return raw.Error + } else if err != nil { + return err + } else if raw == nil || raw.Result == nil { + return errors.New("no result returned") + } + return json.Unmarshal(raw.Result, v) +} + +func (c *Client) makeHTTPRequest(r *request.Raw) (*response.Raw, error) { var ( - r = request.Raw{ - JSONRPC: c.version, - Method: method, - RawParams: p.Values, - ID: 1, - } buf = new(bytes.Buffer) - raw = &response.Raw{} + raw = new(response.Raw) ) if err := json.NewEncoder(buf).Encode(r); err != nil { - return err + return nil, err } req, err := http.NewRequest("POST", c.endpoint.String(), buf) if err != nil { - return err + return nil, err } - resp, err := c.Client().Do(req) + resp, err := c.cli.Do(req) if err != nil { - return err + return nil, err } defer resp.Body.Close() // The node might send us proper JSON anyway, so look there first and if // it parses, then it has more relevant data than HTTP error code. err = json.NewDecoder(resp.Body).Decode(raw) - if err == nil { - if raw.Error != nil { - err = raw.Error + if err != nil { + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf("HTTP %d/%s", resp.StatusCode, http.StatusText(resp.StatusCode)) } else { - err = json.Unmarshal(raw.Result, v) + err = errors.Wrap(err, "JSON decoding") } - } else if resp.StatusCode != http.StatusOK { - err = fmt.Errorf("HTTP %d/%s", resp.StatusCode, http.StatusText(resp.StatusCode)) - } else { - err = errors.Wrap(err, "JSON decoding") } - - return err + if err != nil { + return nil, err + } + return raw, nil } // Ping attempts to create a connection to the endpoint. diff --git a/pkg/rpc/client/rpc.go b/pkg/rpc/client/rpc.go index 49e830112..3e53db23d 100644 --- a/pkg/rpc/client/rpc.go +++ b/pkg/rpc/client/rpc.go @@ -482,7 +482,7 @@ func (c *Client) TransferAsset(asset util.Uint256, address string, amount util.F Address: address, Value: amount, WIF: c.WIF(), - Balancer: c.Balancer(), + Balancer: c.opts.Balancer, } resp util.Uint256 ) diff --git a/pkg/rpc/client/rpc_test.go b/pkg/rpc/client/rpc_test.go index 6a4b1aba8..8ceda6640 100644 --- a/pkg/rpc/client/rpc_test.go +++ b/pkg/rpc/client/rpc_test.go @@ -3,11 +3,13 @@ package client import ( "context" "encoding/hex" - "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" + "time" + "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -1510,7 +1512,22 @@ var rpcClientErrorCases = map[string][]rpcClientErrorCase{ }, } -func TestRPCClient(t *testing.T) { +func TestRPCClients(t *testing.T) { + t.Run("Client", func(t *testing.T) { + testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { + return New(ctx, endpoint, opts) + }) + }) + t.Run("WSClient", func(t *testing.T) { + testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { + wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts) + require.NoError(t, err) + return &wsc.Client, nil + }) + }) +} + +func testRPCClient(t *testing.T, newClient func(context.Context, string, Options) (*Client, error)) { for method, testBatch := range rpcClientTestCases { t.Run(method, func(t *testing.T) { for _, testCase := range testBatch { @@ -1520,7 +1537,7 @@ func TestRPCClient(t *testing.T) { endpoint := srv.URL opts := Options{} - c, err := New(context.TODO(), endpoint, opts) + c, err := newClient(context.TODO(), endpoint, opts) if err != nil { t.Fatal(err) } @@ -1544,7 +1561,7 @@ func TestRPCClient(t *testing.T) { endpoint := srv.URL opts := Options{} - c, err := New(context.TODO(), endpoint, opts) + c, err := newClient(context.TODO(), endpoint, opts) if err != nil { t.Fatal(err) } @@ -1558,8 +1575,31 @@ func TestRPCClient(t *testing.T) { } } +func httpURLtoWS(url string) string { + return "ws" + strings.TrimPrefix(url, "http") + "/ws" +} + func initTestServer(t *testing.T, resp string) *httptest.Server { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/ws" && req.Method == "GET" { + var upgrader = websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + for { + ws.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, _, err = ws.ReadMessage() + if err != nil { + break + } + ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) + err = ws.WriteMessage(1, []byte(resp)) + if err != nil { + break + } + } + ws.Close() + return + } requestHandler(t, w, resp) })) @@ -1568,11 +1608,10 @@ func initTestServer(t *testing.T, resp string) *httptest.Server { func requestHandler(t *testing.T, w http.ResponseWriter, resp string) { w.Header().Set("Content-Type", "application/json; charset=utf-8") - encoder := json.NewEncoder(w) - err := encoder.Encode(json.RawMessage(resp)) + _, err := w.Write([]byte(resp)) if err != nil { - t.Fatalf("Error encountered while encoding response: %s", err.Error()) + t.Fatalf("Error writing response: %s", err.Error()) } } diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go new file mode 100644 index 000000000..bdac24816 --- /dev/null +++ b/pkg/rpc/client/wsclient.go @@ -0,0 +1,160 @@ +package client + +import ( + "context" + "encoding/json" + "errors" + "time" + + "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/rpc/request" + "github.com/nspcc-dev/neo-go/pkg/rpc/response" +) + +// WSClient is a websocket-enabled RPC client that can be used with appropriate +// servers. It's supposed to be faster than Client because it has persistent +// connection to the server and at the same time is exposes some functionality +// that is only provided via websockets (like event subscription mechanism). +type WSClient struct { + Client + ws *websocket.Conn + done chan struct{} + notifications chan *request.In + responses chan *response.Raw + requests chan *request.Raw + shutdown chan struct{} +} + +// requestResponse is a combined type for request and response since we can get +// any of them here. +type requestResponse struct { + request.In + Error *response.Error `json:"error,omitempty"` + Result json.RawMessage `json:"result,omitempty"` +} + +const ( + // Message limit for receiving side. + wsReadLimit = 10 * 1024 * 1024 + + // Disconnection timeout. + wsPongLimit = 60 * time.Second + + // Ping period for connection liveness check. + wsPingPeriod = wsPongLimit / 2 + + // Write deadline. + wsWriteLimit = wsPingPeriod / 2 +) + +// NewWS returns a new WSClient ready to use (with established websocket +// connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`. +func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) { + cl, err := New(ctx, endpoint, opts) + cl.cli = nil + + dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout} + ws, _, err := dialer.Dial(endpoint, nil) + if err != nil { + return nil, err + } + wsc := &WSClient{ + Client: *cl, + ws: ws, + shutdown: make(chan struct{}), + done: make(chan struct{}), + responses: make(chan *response.Raw), + requests: make(chan *request.Raw), + } + go wsc.wsReader() + go wsc.wsWriter() + wsc.requestF = wsc.makeWsRequest + return wsc, nil +} + +// Close closes connection to the remote side rendering this client instance +// unusable. +func (c *WSClient) Close() { + // Closing shutdown channel send signal to wsWriter to break out of the + // loop. In doing so it does ws.Close() closing the network connection + // which in turn makes wsReader receieve err from ws,ReadJSON() and also + // break out of the loop closing c.done channel in its shutdown sequence. + close(c.shutdown) + <-c.done +} + +func (c *WSClient) wsReader() { + c.ws.SetReadLimit(wsReadLimit) + c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) + for { + rr := new(requestResponse) + c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) + err := c.ws.ReadJSON(rr) + if err != nil { + // Timeout/connection loss/malformed response. + break + } + if rr.RawID == nil && rr.Method != "" { + if c.notifications != nil { + c.notifications <- &rr.In + } + } else if rr.RawID != nil && (rr.Error != nil || rr.Result != nil) { + resp := new(response.Raw) + resp.ID = rr.RawID + resp.JSONRPC = rr.JSONRPC + resp.Error = rr.Error + resp.Result = rr.Result + c.responses <- resp + } else { + // Malformed response, neither valid request, nor valid response. + break + } + } + close(c.done) + close(c.responses) + if c.notifications != nil { + close(c.notifications) + } +} + +func (c *WSClient) wsWriter() { + pingTicker := time.NewTicker(wsPingPeriod) + defer c.ws.Close() + defer pingTicker.Stop() + for { + select { + case <-c.shutdown: + return + case <-c.done: + return + case req, ok := <-c.requests: + if !ok { + return + } + c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout)) + if err := c.ws.WriteJSON(req); err != nil { + return + } + case <-pingTicker.C: + c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)) + if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + return + } + } + } + +} + +func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) { + select { + case <-c.done: + return nil, errors.New("connection lost") + case c.requests <- r: + } + select { + case <-c.done: + return nil, errors.New("connection lost") + case resp := <-c.responses: + return resp, nil + } +} diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go new file mode 100644 index 000000000..2a996999a --- /dev/null +++ b/pkg/rpc/client/wsclient_test.go @@ -0,0 +1,16 @@ +package client + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWSClientClose(t *testing.T) { + srv := initTestServer(t, "") + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + wsc.Close() +} diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index b4164ef2f..817cba0ba 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -9,12 +9,12 @@ import ( "net" "net/http" "strconv" + "time" - "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" - "github.com/nspcc-dev/neo-go/pkg/rpc" - + "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -22,6 +22,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/network" + "github.com/nspcc-dev/neo-go/pkg/rpc" "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/response" "github.com/nspcc-dev/neo-go/pkg/rpc/response/result" @@ -44,7 +45,21 @@ type ( } ) -var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, error){ +const ( + // Message limit for receiving side. + wsReadLimit = 4096 + + // Disconnection timeout. + wsPongLimit = 60 * time.Second + + // Ping period for connection liveness check. + wsPingPeriod = wsPongLimit / 2 + + // Write deadline. + wsWriteLimit = wsPingPeriod / 2 +) + +var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){ "getaccountstate": (*Server).getAccountState, "getapplicationlog": (*Server).getApplicationLog, "getassetstate": (*Server).getAssetState, @@ -77,10 +92,14 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, error){ "validateaddress": (*Server).validateAddress, } -var invalidBlockHeightError = func(index int, height int) error { - return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height) +var invalidBlockHeightError = func(index int, height int) *response.Error { + return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil) } +// upgrader is a no-op websocket.Upgrader that reuses HTTP server buffers and +// doesn't set any Error function. +var upgrader = websocket.Upgrader{} + // New creates a new Server struct. func New(chain blockchainer.Blockchainer, conf rpc.Config, coreServer *network.Server, log *zap.Logger) Server { httpServer := &http.Server{ @@ -111,11 +130,11 @@ func (s *Server) Start(errChan chan error) { s.log.Info("RPC server is not enabled") return } - s.Handler = http.HandlerFunc(s.requestHandler) + s.Handler = http.HandlerFunc(s.handleHTTPRequest) s.log.Info("starting rpc-server", zap.String("endpoint", s.Addr)) if cfg := s.config.TLSConfig; cfg.Enabled { - s.https.Handler = http.HandlerFunc(s.requestHandler) + s.https.Handler = http.HandlerFunc(s.handleHTTPRequest) s.log.Info("starting rpc-server (https)", zap.String("endpoint", s.https.Addr)) go func() { err := s.https.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile) @@ -149,11 +168,23 @@ func (s *Server) Shutdown() error { return err } -func (s *Server) requestHandler(w http.ResponseWriter, httpRequest *http.Request) { +func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { + if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" { + ws, err := upgrader.Upgrade(w, httpRequest, nil) + if err != nil { + s.log.Info("websocket connection upgrade failed", zap.Error(err)) + return + } + resChan := make(chan response.Raw) + go s.handleWsWrites(ws, resChan) + s.handleWsReads(ws, resChan) + return + } + req := request.NewIn() if httpRequest.Method != "POST" { - s.WriteErrorResponse( + s.writeHTTPErrorResponse( req, w, response.NewInvalidParamsError( @@ -165,59 +196,90 @@ func (s *Server) requestHandler(w http.ResponseWriter, httpRequest *http.Request err := req.DecodeData(httpRequest.Body) if err != nil { - s.WriteErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err)) + s.writeHTTPErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err)) return } - reqParams, err := req.Params() - if err != nil { - s.WriteErrorResponse(req, w, response.NewInvalidParamsError("Problem parsing request parameters", err)) - return - } - - s.methodHandler(w, req, *reqParams) + resp := s.handleRequest(req) + s.writeHTTPServerResponse(req, w, resp) } -func (s *Server) methodHandler(w http.ResponseWriter, req *request.In, reqParams request.Params) { +func (s *Server) handleRequest(req *request.In) response.Raw { + reqParams, err := req.Params() + if err != nil { + return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing request parameters", err)) + } + s.log.Debug("processing rpc request", zap.String("method", req.Method), zap.String("params", fmt.Sprintf("%v", reqParams))) - var ( - results interface{} - resultsErr error - ) - incCounter(req.Method) handler, ok := rpcHandlers[req.Method] - if ok { - results, resultsErr = handler(s, reqParams) - } else { - resultsErr = response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil) + if !ok { + return s.packResponseToRaw(req, nil, response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil)) } - - if resultsErr != nil { - s.WriteErrorResponse(req, w, resultsErr) - return - } - - s.WriteResponse(req, w, results) + res, resErr := handler(s, *reqParams) + return s.packResponseToRaw(req, res, resErr) } -func (s *Server) getBestBlockHash(_ request.Params) (interface{}, error) { +func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) { + pingTicker := time.NewTicker(wsPingPeriod) + defer ws.Close() + defer pingTicker.Stop() + for { + select { + case res, ok := <-resChan: + if !ok { + return + } + ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)) + if err := ws.WriteJSON(res); err != nil { + return + } + case <-pingTicker.C: + ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)) + if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + return + } + } + } +} + +func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw) { + ws.SetReadLimit(wsReadLimit) + ws.SetReadDeadline(time.Now().Add(wsPongLimit)) + ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) + for { + req := new(request.In) + err := ws.ReadJSON(req) + if err != nil { + break + } + res := s.handleRequest(req) + if res.Error != nil { + s.logRequestError(req, res.Error) + } + resChan <- res + } + close(resChan) + ws.Close() +} + +func (s *Server) getBestBlockHash(_ request.Params) (interface{}, *response.Error) { return "0x" + s.chain.CurrentBlockHash().StringLE(), nil } -func (s *Server) getBlockCount(_ request.Params) (interface{}, error) { +func (s *Server) getBlockCount(_ request.Params) (interface{}, *response.Error) { return s.chain.BlockHeight() + 1, nil } -func (s *Server) getConnectionCount(_ request.Params) (interface{}, error) { +func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Error) { return s.coreServer.PeerCount(), nil } -func (s *Server) getBlock(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) { var hash util.Uint256 param, ok := reqParams.Value(0) @@ -255,7 +317,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, error) { return hex.EncodeToString(writer.Bytes()), nil } -func (s *Server) getBlockHash(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.NumberT) if !ok { return nil, response.ErrInvalidParams @@ -268,7 +330,7 @@ func (s *Server) getBlockHash(reqParams request.Params) (interface{}, error) { return s.chain.GetHeaderHash(num), nil } -func (s *Server) getVersion(_ request.Params) (interface{}, error) { +func (s *Server) getVersion(_ request.Params) (interface{}, *response.Error) { return result.Version{ Port: s.coreServer.Port, Nonce: s.coreServer.ID(), @@ -276,7 +338,7 @@ func (s *Server) getVersion(_ request.Params) (interface{}, error) { }, nil } -func (s *Server) getPeers(_ request.Params) (interface{}, error) { +func (s *Server) getPeers(_ request.Params) (interface{}, *response.Error) { peers := result.NewGetPeers() peers.AddUnconnected(s.coreServer.UnconnectedPeers()) peers.AddConnected(s.coreServer.ConnectedPeers()) @@ -284,7 +346,7 @@ func (s *Server) getPeers(_ request.Params) (interface{}, error) { return peers, nil } -func (s *Server) getRawMempool(_ request.Params) (interface{}, error) { +func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error) { mp := s.chain.GetMemPool() hashList := make([]util.Uint256, 0) for _, item := range mp.GetVerifiedTransactions() { @@ -293,7 +355,7 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, error) { return hashList, nil } -func (s *Server) validateAddress(reqParams request.Params) (interface{}, error) { +func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -301,7 +363,7 @@ func (s *Server) validateAddress(reqParams request.Params) (interface{}, error) return validateAddress(param.Value), nil } -func (s *Server) getAssetState(reqParams request.Params) (interface{}, error) { +func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -320,7 +382,7 @@ func (s *Server) getAssetState(reqParams request.Params) (interface{}, error) { } // getApplicationLog returns the contract log based on the specified txid. -func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, error) { +func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -352,7 +414,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, error return result.NewApplicationLog(appExecResult, scriptHash), nil } -func (s *Server) getClaimable(ps request.Params) (interface{}, error) { +func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -369,7 +431,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, error) { return nil }) if err != nil { - return nil, err + return nil, response.NewInternalServerError("Unclaimed processing failure", err) } } @@ -404,7 +466,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, error) { }, nil } -func (s *Server) getNEP5Balances(ps request.Params) (interface{}, error) { +func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -437,7 +499,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, error) { return bs, nil } -func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, error) { +func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -501,7 +563,7 @@ func amountToString(amount int64, decimals int64) string { return fmt.Sprintf(fs, q, r) } -func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int64, error) { +func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int64, *response.Error) { if d, ok := cache[h]; ok { return d, nil } @@ -516,11 +578,11 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 }, }) if err != nil { - return 0, err + return 0, response.NewInternalServerError("Can't create script", err) } res := s.runScriptInVM(script) if res == nil || res.State != "HALT" || len(res.Stack) == 0 { - return 0, errors.New("execution error") + return 0, response.NewInternalServerError("execution error", errors.New("no result")) } var d int64 @@ -530,16 +592,16 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 case smartcontract.ByteArrayType: d = emit.BytesToInt(item.Value.([]byte)).Int64() default: - return 0, errors.New("invalid result") + return 0, response.NewInternalServerError("invalid result", errors.New("not an integer")) } if d < 0 { - return 0, errors.New("negative decimals") + return 0, response.NewInternalServerError("incorrect result", errors.New("negative result")) } cache[h] = d return d, nil } -func (s *Server) getStorage(ps request.Params) (interface{}, error) { +func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { param, ok := ps.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -570,8 +632,8 @@ func (s *Server) getStorage(ps request.Params) (interface{}, error) { return hex.EncodeToString(item.Value), nil } -func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, error) { - var resultsErr error +func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *response.Error) { + var resultsErr *response.Error var results interface{} if param0, ok := reqParams.Value(0); !ok { @@ -607,7 +669,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, error return results, resultsErr } -func (s *Server) getTransactionHeight(ps request.Params) (interface{}, error) { +func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) { p, ok := ps.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -626,7 +688,7 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, error) { return height, nil } -func (s *Server) getTxOut(ps request.Params) (interface{}, error) { +func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) { p, ok := ps.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -661,7 +723,7 @@ func (s *Server) getTxOut(ps request.Params) (interface{}, error) { } // getContractState returns contract state (contract information, according to the contract script hash). -func (s *Server) getContractState(reqParams request.Params) (interface{}, error) { +func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) { var results interface{} param, ok := reqParams.ValueWithType(0, request.StringT) @@ -680,17 +742,17 @@ func (s *Server) getContractState(reqParams request.Params) (interface{}, error) return results, nil } -func (s *Server) getAccountState(ps request.Params) (interface{}, error) { +func (s *Server) getAccountState(ps request.Params) (interface{}, *response.Error) { return s.getAccountStateAux(ps, false) } -func (s *Server) getUnspents(ps request.Params) (interface{}, error) { +func (s *Server) getUnspents(ps request.Params) (interface{}, *response.Error) { return s.getAccountStateAux(ps, true) } // getAccountState returns account state either in short or full (unspents included) form. -func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (interface{}, error) { - var resultsErr error +func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (interface{}, *response.Error) { + var resultsErr *response.Error var results interface{} param, ok := reqParams.ValueWithType(0, request.StringT) @@ -717,7 +779,7 @@ func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (in } // getBlockSysFee returns the system fees of the block, based on the specified index. -func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.NumberT) if !ok { return 0, response.ErrInvalidParams @@ -729,9 +791,9 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) { } headerHash := s.chain.GetHeaderHash(num) - block, err := s.chain.GetBlock(headerHash) - if err != nil { - return 0, response.NewRPCError(err.Error(), "", nil) + block, errBlock := s.chain.GetBlock(headerHash) + if errBlock != nil { + return 0, response.NewRPCError(errBlock.Error(), "", nil) } var blockSysFee util.Fixed8 @@ -743,7 +805,7 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) { } // getBlockHeader returns the corresponding block header information according to the specified script hash. -func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) { var verbose bool param, ok := reqParams.ValueWithType(0, request.StringT) @@ -776,13 +838,13 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, error) { buf := io.NewBufBinWriter() h.EncodeBinary(buf.BinWriter) if buf.Err != nil { - return nil, err + return nil, response.NewInternalServerError("encoding error", buf.Err) } return hex.EncodeToString(buf.Bytes()), nil } // getUnclaimed returns unclaimed GAS amount of the specified address. -func (s *Server) getUnclaimed(ps request.Params) (interface{}, error) { +func (s *Server) getUnclaimed(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -796,21 +858,24 @@ func (s *Server) getUnclaimed(ps request.Params) (interface{}, error) { if acc == nil { return nil, response.NewInternalServerError("unknown account", nil) } - - return result.NewUnclaimed(acc, s.chain) + res, errRes := result.NewUnclaimed(acc, s.chain) + if errRes != nil { + return nil, response.NewInternalServerError("can't create unclaimed response", errRes) + } + return res, nil } // getValidators returns the current NEO consensus nodes information and voting status. -func (s *Server) getValidators(_ request.Params) (interface{}, error) { +func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error) { var validators keys.PublicKeys validators, err := s.chain.GetValidators() if err != nil { - return nil, err + return nil, response.NewRPCError("can't get validators", "", err) } enrollments, err := s.chain.GetEnrollments() if err != nil { - return nil, err + return nil, response.NewRPCError("can't get enrollments", "", err) } var res []result.Validator for _, v := range enrollments { @@ -824,14 +889,14 @@ func (s *Server) getValidators(_ request.Params) (interface{}, error) { } // invoke implements the `invoke` RPC call. -func (s *Server) invoke(reqParams request.Params) (interface{}, error) { +func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) { scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams } scriptHash, err := scriptHashHex.GetUint160FromHex() if err != nil { - return nil, err + return nil, response.ErrInvalidParams } sliceP, ok := reqParams.ValueWithType(1, request.ArrayT) if !ok { @@ -839,34 +904,34 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, error) { } slice, err := sliceP.GetArray() if err != nil { - return nil, err + return nil, response.ErrInvalidParams } script, err := request.CreateInvocationScript(scriptHash, slice) if err != nil { - return nil, err + return nil, response.NewInternalServerError("can't create invocation script", err) } return s.runScriptInVM(script), nil } // invokescript implements the `invokescript` RPC call. -func (s *Server) invokeFunction(reqParams request.Params) (interface{}, error) { +func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) { scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams } scriptHash, err := scriptHashHex.GetUint160FromHex() if err != nil { - return nil, err + return nil, response.ErrInvalidParams } script, err := request.CreateFunctionInvocationScript(scriptHash, reqParams[1:]) if err != nil { - return nil, err + return nil, response.NewInternalServerError("can't create invocation script", err) } return s.runScriptInVM(script), nil } // invokescript implements the `invokescript` RPC call. -func (s *Server) invokescript(reqParams request.Params) (interface{}, error) { +func (s *Server) invokescript(reqParams request.Params) (interface{}, *response.Error) { if len(reqParams) < 1 { return nil, response.ErrInvalidParams } @@ -896,7 +961,7 @@ func (s *Server) runScriptInVM(script []byte) *result.Invoke { } // submitBlock broadcasts a raw block over the NEO network. -func (s *Server) submitBlock(reqParams request.Params) (interface{}, error) { +func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -923,8 +988,8 @@ func (s *Server) submitBlock(reqParams request.Params) (interface{}, error) { return true, nil } -func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, error) { - var resultsErr error +func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *response.Error) { + var resultsErr *response.Error var results interface{} if len(reqParams) < 1 { @@ -958,7 +1023,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, erro return results, resultsErr } -func (s *Server) blockHeightFromParam(param *request.Param) (int, error) { +func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) { num, err := param.GetInt() if err != nil { return 0, nil @@ -970,23 +1035,33 @@ func (s *Server) blockHeightFromParam(param *request.Param) (int, error) { return num, nil } -// WriteErrorResponse writes an error response to the ResponseWriter. -func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, err error) { - jsonErr, ok := err.(*response.Error) - if !ok { - jsonErr = response.NewInternalServerError("Internal server error", err) - } - +func (s *Server) packResponseToRaw(r *request.In, result interface{}, respErr *response.Error) response.Raw { resp := response.Raw{ HeaderAndError: response.HeaderAndError{ Header: response.Header{ JSONRPC: r.JSONRPC, ID: r.RawID, }, - Error: jsonErr, }, } + if respErr != nil { + resp.Error = respErr + } else { + resJSON, err := json.Marshal(result) + if err != nil { + s.log.Error("failed to marshal result", + zap.Error(err), + zap.String("method", r.Method)) + resp.Error = response.NewInternalServerError("failed to encode result", err) + } else { + resp.Result = resJSON + } + } + return resp +} +// logRequestError is a request error logger. +func (s *Server) logRequestError(r *request.In, jsonErr *response.Error) { logFields := []zap.Field{ zap.Error(jsonErr.Cause), zap.String("method", r.Method), @@ -998,35 +1073,20 @@ func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, err er } s.log.Error("Error encountered with rpc request", logFields...) - - w.WriteHeader(jsonErr.HTTPCode) - s.writeServerResponse(r, w, resp) } -// WriteResponse encodes the response and writes it to the ResponseWriter. -func (s *Server) WriteResponse(r *request.In, w http.ResponseWriter, result interface{}) { - resJSON, err := json.Marshal(result) - if err != nil { - s.log.Error("Error encountered while encoding response", - zap.String("err", err.Error()), - zap.String("method", r.Method)) - return - } - - resp := response.Raw{ - HeaderAndError: response.HeaderAndError{ - Header: response.Header{ - JSONRPC: r.JSONRPC, - ID: r.RawID, - }, - }, - Result: resJSON, - } - - s.writeServerResponse(r, w, resp) +// writeHTTPErrorResponse writes an error response to the ResponseWriter. +func (s *Server) writeHTTPErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) { + resp := s.packResponseToRaw(r, nil, jsonErr) + s.writeHTTPServerResponse(r, w, resp) } -func (s *Server) writeServerResponse(r *request.In, w http.ResponseWriter, resp response.Raw) { +func (s *Server) writeHTTPServerResponse(r *request.In, w http.ResponseWriter, resp response.Raw) { + // Errors can happen in many places and we can only catch ALL of them here. + if resp.Error != nil { + s.logRequestError(r, resp.Error) + w.WriteHeader(resp.Error.HTTPCode) + } w.Header().Set("Content-Type", "application/json; charset=utf-8") if s.config.EnableCORSWorkaround { w.Header().Set("Access-Control-Allow-Origin", "*") diff --git a/pkg/rpc/server/server_helper_test.go b/pkg/rpc/server/server_helper_test.go index 1263c0dcd..c6ee3167e 100644 --- a/pkg/rpc/server/server_helper_test.go +++ b/pkg/rpc/server/server_helper_test.go @@ -2,6 +2,7 @@ package server import ( "net/http" + "net/http/httptest" "os" "testing" @@ -17,7 +18,7 @@ import ( "go.uber.org/zap/zaptest" ) -func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, http.HandlerFunc) { +func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *httptest.Server) { var nBlocks uint32 net := config.ModeUnitTestNet @@ -55,9 +56,11 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, http.HandlerFu server, err := network.NewServer(serverConfig, chain, logger) require.NoError(t, err) rpcServer := New(chain, cfg.ApplicationConfiguration.RPC, server, logger) - handler := http.HandlerFunc(rpcServer.requestHandler) - return chain, handler + handler := http.HandlerFunc(rpcServer.handleHTTPRequest) + srv := httptest.NewServer(handler) + + return chain, srv } type FeerStub struct{} diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index 6a5ef7135..35a1b98e5 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" @@ -31,7 +32,7 @@ import ( type executor struct { chain *core.Blockchain - handler http.HandlerFunc + httpSrv *httptest.Server } const ( @@ -813,18 +814,31 @@ var rpcTestCases = map[string][]rpcTestCase{ } func TestRPC(t *testing.T) { - chain, handler := initServerWithInMemoryChain(t) + t.Run("http", func(t *testing.T) { + testRPCProtocol(t, doRPCCallOverHTTP) + }) + + t.Run("websocket", func(t *testing.T) { + testRPCProtocol(t, doRPCCallOverWS) + }) +} + +// testRPCProtocol runs a full set of tests using given callback to make actual +// calls. Some tests change the chain state, thus we reinitialize the chain from +// scratch here. +func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []byte) { + chain, httpSrv := initServerWithInMemoryChain(t) defer chain.Close() - e := &executor{chain: chain, handler: handler} + e := &executor{chain: chain, httpSrv: httpSrv} for method, cases := range rpcTestCases { t.Run(method, func(t *testing.T) { rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}` for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), handler, t) + body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t) result := checkErrGetResult(t, body, tc.fail) if tc.fail { return @@ -848,7 +862,7 @@ func TestRPC(t *testing.T) { rpc := `{"jsonrpc": "2.0", "id": 1, "method": "submitblock", "params": ["%s"]}` t.Run("empty", func(t *testing.T) { s := newBlock(t, chain, 1) - body := doRPCCall(fmt.Sprintf(rpc, s), handler, t) + body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, s)), httpSrv.URL, t) checkErrGetResult(t, body, true) }) @@ -867,13 +881,13 @@ func TestRPC(t *testing.T) { t.Run("invalid height", func(t *testing.T) { b := newBlock(t, chain, 2, newTx()) - body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), handler, t) + body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), httpSrv.URL, t) checkErrGetResult(t, body, true) }) t.Run("positive", func(t *testing.T) { b := newBlock(t, chain, 1, newTx()) - body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), handler, t) + body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), httpSrv.URL, t) data := checkErrGetResult(t, body, false) var res bool require.NoError(t, json.Unmarshal(data, &res)) @@ -885,7 +899,7 @@ func TestRPC(t *testing.T) { block, _ := chain.GetBlock(chain.GetHeaderHash(0)) TXHash := block.Transactions[0].Hash() rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s"]}"`, TXHash.StringLE()) - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) result := checkErrGetResult(t, body, false) var res string err := json.Unmarshal(result, &res) @@ -897,7 +911,7 @@ func TestRPC(t *testing.T) { block, _ := chain.GetBlock(chain.GetHeaderHash(0)) TXHash := block.Transactions[0].Hash() rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 0]}"`, TXHash.StringLE()) - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) result := checkErrGetResult(t, body, false) var res string err := json.Unmarshal(result, &res) @@ -909,7 +923,7 @@ func TestRPC(t *testing.T) { block, _ := chain.GetBlock(chain.GetHeaderHash(0)) TXHash := block.Transactions[0].Hash() rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 1]}"`, TXHash.StringLE()) - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) txOut := checkErrGetResult(t, body, false) actual := result.TransactionOutputRaw{} err := json.Unmarshal(txOut, &actual) @@ -936,7 +950,7 @@ func TestRPC(t *testing.T) { hdr := e.getHeader(testHeaderHash) runCase := func(t *testing.T, rpc string, expected, actual interface{}) { - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) data := checkErrGetResult(t, body, false) require.NoError(t, json.Unmarshal(data, actual)) require.Equal(t, expected, actual) @@ -984,7 +998,7 @@ func TestRPC(t *testing.T) { tx := block.Transactions[2] rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "gettxout", "params": [%s, %d]}"`, `"`+tx.Hash().StringLE()+`"`, 0) - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) res := checkErrGetResult(t, body, false) var txOut result.TransactionOutput @@ -1010,7 +1024,7 @@ func TestRPC(t *testing.T) { } rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getrawmempool", "params": []}` - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) res := checkErrGetResult(t, body, false) var actual []util.Uint256 @@ -1085,12 +1099,23 @@ func checkErrGetResult(t *testing.T, body []byte, expectingFail bool) json.RawMe return resp.Result } -func doRPCCall(rpcCall string, handler http.HandlerFunc, t *testing.T) []byte { - req := httptest.NewRequest("POST", "http://0.0.0.0:20333/", strings.NewReader(rpcCall)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - handler(w, req) - resp := w.Result() +func doRPCCallOverWS(rpcCall string, url string, t *testing.T) []byte { + dialer := websocket.Dialer{HandshakeTimeout: time.Second} + url = "ws" + strings.TrimPrefix(url, "http") + c, _, err := dialer.Dial(url+"/ws", nil) + require.NoError(t, err) + c.SetWriteDeadline(time.Now().Add(time.Second)) + require.NoError(t, c.WriteMessage(1, []byte(rpcCall))) + c.SetReadDeadline(time.Now().Add(time.Second)) + _, body, err := c.ReadMessage() + require.NoError(t, err) + return bytes.TrimSpace(body) +} + +func doRPCCallOverHTTP(rpcCall string, url string, t *testing.T) []byte { + cl := http.Client{Timeout: time.Second} + resp, err := cl.Post(url, "application/json", strings.NewReader(rpcCall)) + require.NoErrorf(t, err, "could not make a POST request") body, err := ioutil.ReadAll(resp.Body) assert.NoErrorf(t, err, "could not read response from the request: %s", rpcCall) return bytes.TrimSpace(body)