mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2025-01-07 09:50:36 +00:00
Merge pull request #921 from nspcc-dev/rpc-websocket-2.x
RPC over websocket (2.x)
This commit is contained in:
commit
e097e86bfa
10 changed files with 538 additions and 249 deletions
1
go.mod
1
go.mod
|
@ -6,6 +6,7 @@ require (
|
||||||
github.com/dgraph-io/badger/v2 v2.0.3
|
github.com/dgraph-io/badger/v2 v2.0.3
|
||||||
github.com/go-redis/redis v6.10.2+incompatible
|
github.com/go-redis/redis v6.10.2+incompatible
|
||||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||||
|
github.com/gorilla/websocket v1.4.2
|
||||||
github.com/mr-tron/base58 v1.1.2
|
github.com/mr-tron/base58 v1.1.2
|
||||||
github.com/nspcc-dev/dbft v0.0.0-20200303183127-36d3da79c682
|
github.com/nspcc-dev/dbft v0.0.0-20200303183127-36d3da79c682
|
||||||
github.com/nspcc-dev/rfc6979 v0.2.0
|
github.com/nspcc-dev/rfc6979 v0.2.0
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -85,6 +85,8 @@ github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp
|
||||||
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
|
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
|
||||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||||
|
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||||
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
|
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
|
||||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||||
|
|
|
@ -29,32 +29,33 @@ const (
|
||||||
// Client represents the middleman for executing JSON RPC calls
|
// Client represents the middleman for executing JSON RPC calls
|
||||||
// to remote NEO RPC nodes.
|
// to remote NEO RPC nodes.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
// The underlying http client. It's never a good practice to use
|
|
||||||
// the http.DefaultClient, therefore we will role our own.
|
|
||||||
cliMu *sync.Mutex
|
|
||||||
cli *http.Client
|
cli *http.Client
|
||||||
endpoint *url.URL
|
endpoint *url.URL
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
version string
|
opts Options
|
||||||
|
requestF func(*request.Raw) (*response.Raw, error)
|
||||||
wifMu *sync.Mutex
|
wifMu *sync.Mutex
|
||||||
wif *keys.WIF
|
wif *keys.WIF
|
||||||
balancerMu *sync.Mutex
|
|
||||||
balancer request.BalanceGetter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options defines options for the RPC client.
|
// Options defines options for the RPC client.
|
||||||
// All Values are optional. If any duration is not specified
|
// All values are optional. If any duration is not specified
|
||||||
// a default of 3 seconds will be used.
|
// a default of 4 seconds will be used.
|
||||||
type Options struct {
|
type Options struct {
|
||||||
|
// Balancer is an implementation of request.BalanceGetter interface,
|
||||||
|
// if not set then the default Client's implementation will be used, but
|
||||||
|
// it relies on server support for `getunspents` RPC call which is
|
||||||
|
// standard for neo-go, but only implemented as a plugin for C# node. So
|
||||||
|
// you can override it here to use NeoScanServer for example.
|
||||||
|
Balancer request.BalanceGetter
|
||||||
|
|
||||||
|
// Cert is a client-side certificate, it doesn't work at the moment along
|
||||||
|
// with the other two options below.
|
||||||
Cert string
|
Cert string
|
||||||
Key string
|
Key string
|
||||||
CACert string
|
CACert string
|
||||||
DialTimeout time.Duration
|
DialTimeout time.Duration
|
||||||
Client *http.Client
|
RequestTimeout time.Duration
|
||||||
// Version is the version of the client that will be send
|
|
||||||
// along with the request body. If no version is specified
|
|
||||||
// the default version (currently 2.0) will be used.
|
|
||||||
Version string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new Client ready to use.
|
// New returns a new Client ready to use.
|
||||||
|
@ -64,37 +65,39 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.Version == "" {
|
if opts.DialTimeout <= 0 {
|
||||||
opts.Version = defaultClientVersion
|
opts.DialTimeout = defaultDialTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.Client == nil {
|
if opts.RequestTimeout <= 0 {
|
||||||
opts.Client = &http.Client{
|
opts.RequestTimeout = defaultRequestTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
DialContext: (&net.Dialer{
|
DialContext: (&net.Dialer{
|
||||||
Timeout: opts.DialTimeout,
|
Timeout: opts.DialTimeout,
|
||||||
}).DialContext,
|
}).DialContext,
|
||||||
},
|
},
|
||||||
}
|
Timeout: opts.RequestTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(@antdm): Enable SSL.
|
// TODO(@antdm): Enable SSL.
|
||||||
if opts.Cert != "" && opts.Key != "" {
|
if opts.Cert != "" && opts.Key != "" {
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.Client.Timeout == 0 {
|
cl := &Client{
|
||||||
opts.Client.Timeout = defaultRequestTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Client{
|
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cli: opts.Client,
|
cli: httpClient,
|
||||||
cliMu: new(sync.Mutex),
|
|
||||||
balancerMu: new(sync.Mutex),
|
|
||||||
wifMu: new(sync.Mutex),
|
wifMu: new(sync.Mutex),
|
||||||
endpoint: url,
|
endpoint: url,
|
||||||
version: opts.Version,
|
}
|
||||||
}, nil
|
if opts.Balancer == nil {
|
||||||
|
opts.Balancer = cl
|
||||||
|
}
|
||||||
|
cl.opts = opts
|
||||||
|
cl.requestF = cl.makeHTTPRequest
|
||||||
|
return cl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WIF returns WIF structure associated with the client.
|
// WIF returns WIF structure associated with the client.
|
||||||
|
@ -122,43 +125,10 @@ func (c *Client) SetWIF(wif string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Balancer is a getter for balance field.
|
// CalculateInputs implements request.BalanceGetter interface and returns inputs
|
||||||
func (c *Client) Balancer() request.BalanceGetter {
|
// array for the specified amount of given asset belonging to specified address.
|
||||||
c.balancerMu.Lock()
|
// This implementation uses GetUnspents JSON-RPC call internally, so make sure
|
||||||
defer c.balancerMu.Unlock()
|
// your RPC server supports that.
|
||||||
return c.balancer
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBalancer is a setter for balance field.
|
|
||||||
func (c *Client) SetBalancer(b request.BalanceGetter) {
|
|
||||||
c.balancerMu.Lock()
|
|
||||||
defer c.balancerMu.Unlock()
|
|
||||||
|
|
||||||
if b != nil {
|
|
||||||
c.balancer = b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client is a getter for client field.
|
|
||||||
func (c *Client) Client() *http.Client {
|
|
||||||
c.cliMu.Lock()
|
|
||||||
defer c.cliMu.Unlock()
|
|
||||||
return c.cli
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetClient is a setter for client field.
|
|
||||||
func (c *Client) SetClient(cli *http.Client) {
|
|
||||||
c.cliMu.Lock()
|
|
||||||
defer c.cliMu.Unlock()
|
|
||||||
|
|
||||||
if cli != nil {
|
|
||||||
c.cli = cli
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CalculateInputs creates input transactions for the specified amount of given
|
|
||||||
// asset belonging to specified address. This implementation uses GetUnspents
|
|
||||||
// JSON-RPC call internally, so make sure your RPC server suppors that.
|
|
||||||
func (c *Client) CalculateInputs(address string, asset util.Uint256, cost util.Fixed8) ([]transaction.Input, util.Fixed8, error) {
|
func (c *Client) CalculateInputs(address string, asset util.Uint256, cost util.Fixed8) ([]transaction.Input, util.Fixed8, error) {
|
||||||
var utxos state.UnspentBalances
|
var utxos state.UnspentBalances
|
||||||
|
|
||||||
|
@ -177,47 +147,59 @@ func (c *Client) CalculateInputs(address string, asset util.Uint256, cost util.F
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) performRequest(method string, p request.RawParams, v interface{}) error {
|
func (c *Client) performRequest(method string, p request.RawParams, v interface{}) error {
|
||||||
var (
|
var r = request.Raw{
|
||||||
r = request.Raw{
|
JSONRPC: request.JSONRPCVersion,
|
||||||
JSONRPC: c.version,
|
|
||||||
Method: method,
|
Method: method,
|
||||||
RawParams: p.Values,
|
RawParams: p.Values,
|
||||||
ID: 1,
|
ID: 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
raw, err := c.requestF(&r)
|
||||||
|
|
||||||
|
if raw != nil && raw.Error != nil {
|
||||||
|
return raw.Error
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
} else if raw == nil || raw.Result == nil {
|
||||||
|
return errors.New("no result returned")
|
||||||
|
}
|
||||||
|
return json.Unmarshal(raw.Result, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) makeHTTPRequest(r *request.Raw) (*response.Raw, error) {
|
||||||
|
var (
|
||||||
buf = new(bytes.Buffer)
|
buf = new(bytes.Buffer)
|
||||||
raw = &response.Raw{}
|
raw = new(response.Raw)
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := json.NewEncoder(buf).Encode(r); err != nil {
|
if err := json.NewEncoder(buf).Encode(r); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", c.endpoint.String(), buf)
|
req, err := http.NewRequest("POST", c.endpoint.String(), buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp, err := c.Client().Do(req)
|
resp, err := c.cli.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// The node might send us proper JSON anyway, so look there first and if
|
// The node might send us proper JSON anyway, so look there first and if
|
||||||
// it parses, then it has more relevant data than HTTP error code.
|
// it parses, then it has more relevant data than HTTP error code.
|
||||||
err = json.NewDecoder(resp.Body).Decode(raw)
|
err = json.NewDecoder(resp.Body).Decode(raw)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
if raw.Error != nil {
|
if resp.StatusCode != http.StatusOK {
|
||||||
err = raw.Error
|
|
||||||
} else {
|
|
||||||
err = json.Unmarshal(raw.Result, v)
|
|
||||||
}
|
|
||||||
} else if resp.StatusCode != http.StatusOK {
|
|
||||||
err = fmt.Errorf("HTTP %d/%s", resp.StatusCode, http.StatusText(resp.StatusCode))
|
err = fmt.Errorf("HTTP %d/%s", resp.StatusCode, http.StatusText(resp.StatusCode))
|
||||||
} else {
|
} else {
|
||||||
err = errors.Wrap(err, "JSON decoding")
|
err = errors.Wrap(err, "JSON decoding")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return raw, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ping attempts to create a connection to the endpoint.
|
// Ping attempts to create a connection to the endpoint.
|
||||||
|
|
|
@ -484,7 +484,7 @@ func (c *Client) TransferAsset(asset util.Uint256, address string, amount util.F
|
||||||
Address: address,
|
Address: address,
|
||||||
Value: amount,
|
Value: amount,
|
||||||
WIF: c.WIF(),
|
WIF: c.WIF(),
|
||||||
Balancer: c.Balancer(),
|
Balancer: c.opts.Balancer,
|
||||||
}
|
}
|
||||||
resp util.Uint256
|
resp util.Uint256
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,11 +3,13 @@ package client
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
|
@ -1427,7 +1429,22 @@ var rpcClientErrorCases = map[string][]rpcClientErrorCase{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRPCClient(t *testing.T) {
|
func TestRPCClients(t *testing.T) {
|
||||||
|
t.Run("Client", func(t *testing.T) {
|
||||||
|
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
|
return New(ctx, endpoint, opts)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("WSClient", func(t *testing.T) {
|
||||||
|
testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) {
|
||||||
|
wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return &wsc.Client, nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testRPCClient(t *testing.T, newClient func(context.Context, string, Options) (*Client, error)) {
|
||||||
for method, testBatch := range rpcClientTestCases {
|
for method, testBatch := range rpcClientTestCases {
|
||||||
t.Run(method, func(t *testing.T) {
|
t.Run(method, func(t *testing.T) {
|
||||||
for _, testCase := range testBatch {
|
for _, testCase := range testBatch {
|
||||||
|
@ -1437,7 +1454,7 @@ func TestRPCClient(t *testing.T) {
|
||||||
|
|
||||||
endpoint := srv.URL
|
endpoint := srv.URL
|
||||||
opts := Options{}
|
opts := Options{}
|
||||||
c, err := New(context.TODO(), endpoint, opts)
|
c, err := newClient(context.TODO(), endpoint, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -1461,7 +1478,7 @@ func TestRPCClient(t *testing.T) {
|
||||||
|
|
||||||
endpoint := srv.URL
|
endpoint := srv.URL
|
||||||
opts := Options{}
|
opts := Options{}
|
||||||
c, err := New(context.TODO(), endpoint, opts)
|
c, err := newClient(context.TODO(), endpoint, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -1475,8 +1492,31 @@ func TestRPCClient(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func httpURLtoWS(url string) string {
|
||||||
|
return "ws" + strings.TrimPrefix(url, "http") + "/ws"
|
||||||
|
}
|
||||||
|
|
||||||
func initTestServer(t *testing.T, resp string) *httptest.Server {
|
func initTestServer(t *testing.T, resp string) *httptest.Server {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.URL.Path == "/ws" && req.Method == "GET" {
|
||||||
|
var upgrader = websocket.Upgrader{}
|
||||||
|
ws, err := upgrader.Upgrade(w, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for {
|
||||||
|
ws.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
_, _, err = ws.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
ws.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
err = ws.WriteMessage(1, []byte(resp))
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ws.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
requestHandler(t, w, resp)
|
requestHandler(t, w, resp)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@ -1485,10 +1525,9 @@ func initTestServer(t *testing.T, resp string) *httptest.Server {
|
||||||
|
|
||||||
func requestHandler(t *testing.T, w http.ResponseWriter, resp string) {
|
func requestHandler(t *testing.T, w http.ResponseWriter, resp string) {
|
||||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
encoder := json.NewEncoder(w)
|
_, err := w.Write([]byte(resp))
|
||||||
err := encoder.Encode(json.RawMessage(resp))
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error encountered while encoding response: %s", err.Error())
|
t.Fatalf("Error writing response: %s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
160
pkg/rpc/client/wsclient.go
Normal file
160
pkg/rpc/client/wsclient.go
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WSClient is a websocket-enabled RPC client that can be used with appropriate
|
||||||
|
// servers. It's supposed to be faster than Client because it has persistent
|
||||||
|
// connection to the server and at the same time is exposes some functionality
|
||||||
|
// that is only provided via websockets (like event subscription mechanism).
|
||||||
|
type WSClient struct {
|
||||||
|
Client
|
||||||
|
ws *websocket.Conn
|
||||||
|
done chan struct{}
|
||||||
|
notifications chan *request.In
|
||||||
|
responses chan *response.Raw
|
||||||
|
requests chan *request.Raw
|
||||||
|
shutdown chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestResponse is a combined type for request and response since we can get
|
||||||
|
// any of them here.
|
||||||
|
type requestResponse struct {
|
||||||
|
request.In
|
||||||
|
Error *response.Error `json:"error,omitempty"`
|
||||||
|
Result json.RawMessage `json:"result,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Message limit for receiving side.
|
||||||
|
wsReadLimit = 10 * 1024 * 1024
|
||||||
|
|
||||||
|
// Disconnection timeout.
|
||||||
|
wsPongLimit = 60 * time.Second
|
||||||
|
|
||||||
|
// Ping period for connection liveness check.
|
||||||
|
wsPingPeriod = wsPongLimit / 2
|
||||||
|
|
||||||
|
// Write deadline.
|
||||||
|
wsWriteLimit = wsPingPeriod / 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWS returns a new WSClient ready to use (with established websocket
|
||||||
|
// connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`.
|
||||||
|
func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) {
|
||||||
|
cl, err := New(ctx, endpoint, opts)
|
||||||
|
cl.cli = nil
|
||||||
|
|
||||||
|
dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout}
|
||||||
|
ws, _, err := dialer.Dial(endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
wsc := &WSClient{
|
||||||
|
Client: *cl,
|
||||||
|
ws: ws,
|
||||||
|
shutdown: make(chan struct{}),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
responses: make(chan *response.Raw),
|
||||||
|
requests: make(chan *request.Raw),
|
||||||
|
}
|
||||||
|
go wsc.wsReader()
|
||||||
|
go wsc.wsWriter()
|
||||||
|
wsc.requestF = wsc.makeWsRequest
|
||||||
|
return wsc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes connection to the remote side rendering this client instance
|
||||||
|
// unusable.
|
||||||
|
func (c *WSClient) Close() {
|
||||||
|
// Closing shutdown channel send signal to wsWriter to break out of the
|
||||||
|
// loop. In doing so it does ws.Close() closing the network connection
|
||||||
|
// which in turn makes wsReader receieve err from ws,ReadJSON() and also
|
||||||
|
// break out of the loop closing c.done channel in its shutdown sequence.
|
||||||
|
close(c.shutdown)
|
||||||
|
<-c.done
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) wsReader() {
|
||||||
|
c.ws.SetReadLimit(wsReadLimit)
|
||||||
|
c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil })
|
||||||
|
for {
|
||||||
|
rr := new(requestResponse)
|
||||||
|
c.ws.SetReadDeadline(time.Now().Add(wsPongLimit))
|
||||||
|
err := c.ws.ReadJSON(rr)
|
||||||
|
if err != nil {
|
||||||
|
// Timeout/connection loss/malformed response.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rr.RawID == nil && rr.Method != "" {
|
||||||
|
if c.notifications != nil {
|
||||||
|
c.notifications <- &rr.In
|
||||||
|
}
|
||||||
|
} else if rr.RawID != nil && (rr.Error != nil || rr.Result != nil) {
|
||||||
|
resp := new(response.Raw)
|
||||||
|
resp.ID = rr.RawID
|
||||||
|
resp.JSONRPC = rr.JSONRPC
|
||||||
|
resp.Error = rr.Error
|
||||||
|
resp.Result = rr.Result
|
||||||
|
c.responses <- resp
|
||||||
|
} else {
|
||||||
|
// Malformed response, neither valid request, nor valid response.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(c.done)
|
||||||
|
close(c.responses)
|
||||||
|
if c.notifications != nil {
|
||||||
|
close(c.notifications)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) wsWriter() {
|
||||||
|
pingTicker := time.NewTicker(wsPingPeriod)
|
||||||
|
defer c.ws.Close()
|
||||||
|
defer pingTicker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.shutdown:
|
||||||
|
return
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
case req, ok := <-c.requests:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout))
|
||||||
|
if err := c.ws.WriteJSON(req); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-pingTicker.C:
|
||||||
|
c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
|
||||||
|
if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return nil, errors.New("connection lost")
|
||||||
|
case c.requests <- r:
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return nil, errors.New("connection lost")
|
||||||
|
case resp := <-c.responses:
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
16
pkg/rpc/client/wsclient_test.go
Normal file
16
pkg/rpc/client/wsclient_test.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWSClientClose(t *testing.T) {
|
||||||
|
srv := initTestServer(t, "")
|
||||||
|
defer srv.Close()
|
||||||
|
wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
wsc.Close()
|
||||||
|
}
|
|
@ -9,9 +9,9 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/rpc"
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||||
|
@ -21,6 +21,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
|
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network"
|
"github.com/nspcc-dev/neo-go/pkg/network"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/rpc"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
|
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/rpc/response/result"
|
"github.com/nspcc-dev/neo-go/pkg/rpc/response/result"
|
||||||
|
@ -43,7 +44,21 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, error){
|
const (
|
||||||
|
// Message limit for receiving side.
|
||||||
|
wsReadLimit = 4096
|
||||||
|
|
||||||
|
// Disconnection timeout.
|
||||||
|
wsPongLimit = 60 * time.Second
|
||||||
|
|
||||||
|
// Ping period for connection liveness check.
|
||||||
|
wsPingPeriod = wsPongLimit / 2
|
||||||
|
|
||||||
|
// Write deadline.
|
||||||
|
wsWriteLimit = wsPingPeriod / 2
|
||||||
|
)
|
||||||
|
|
||||||
|
var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){
|
||||||
"getaccountstate": (*Server).getAccountState,
|
"getaccountstate": (*Server).getAccountState,
|
||||||
"getapplicationlog": (*Server).getApplicationLog,
|
"getapplicationlog": (*Server).getApplicationLog,
|
||||||
"getassetstate": (*Server).getAssetState,
|
"getassetstate": (*Server).getAssetState,
|
||||||
|
@ -76,10 +91,14 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, error){
|
||||||
"validateaddress": (*Server).validateAddress,
|
"validateaddress": (*Server).validateAddress,
|
||||||
}
|
}
|
||||||
|
|
||||||
var invalidBlockHeightError = func(index int, height int) error {
|
var invalidBlockHeightError = func(index int, height int) *response.Error {
|
||||||
return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height)
|
return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// upgrader is a no-op websocket.Upgrader that reuses HTTP server buffers and
|
||||||
|
// doesn't set any Error function.
|
||||||
|
var upgrader = websocket.Upgrader{}
|
||||||
|
|
||||||
// New creates a new Server struct.
|
// New creates a new Server struct.
|
||||||
func New(chain core.Blockchainer, conf rpc.Config, coreServer *network.Server, log *zap.Logger) Server {
|
func New(chain core.Blockchainer, conf rpc.Config, coreServer *network.Server, log *zap.Logger) Server {
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
|
@ -110,11 +129,11 @@ func (s *Server) Start(errChan chan error) {
|
||||||
s.log.Info("RPC server is not enabled")
|
s.log.Info("RPC server is not enabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.Handler = http.HandlerFunc(s.requestHandler)
|
s.Handler = http.HandlerFunc(s.handleHTTPRequest)
|
||||||
s.log.Info("starting rpc-server", zap.String("endpoint", s.Addr))
|
s.log.Info("starting rpc-server", zap.String("endpoint", s.Addr))
|
||||||
|
|
||||||
if cfg := s.config.TLSConfig; cfg.Enabled {
|
if cfg := s.config.TLSConfig; cfg.Enabled {
|
||||||
s.https.Handler = http.HandlerFunc(s.requestHandler)
|
s.https.Handler = http.HandlerFunc(s.handleHTTPRequest)
|
||||||
s.log.Info("starting rpc-server (https)", zap.String("endpoint", s.https.Addr))
|
s.log.Info("starting rpc-server (https)", zap.String("endpoint", s.https.Addr))
|
||||||
go func() {
|
go func() {
|
||||||
err := s.https.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile)
|
err := s.https.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile)
|
||||||
|
@ -148,11 +167,23 @@ func (s *Server) Shutdown() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) requestHandler(w http.ResponseWriter, httpRequest *http.Request) {
|
func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) {
|
||||||
|
if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" {
|
||||||
|
ws, err := upgrader.Upgrade(w, httpRequest, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Info("websocket connection upgrade failed", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resChan := make(chan response.Raw)
|
||||||
|
go s.handleWsWrites(ws, resChan)
|
||||||
|
s.handleWsReads(ws, resChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
req := request.NewIn()
|
req := request.NewIn()
|
||||||
|
|
||||||
if httpRequest.Method != "POST" {
|
if httpRequest.Method != "POST" {
|
||||||
s.WriteErrorResponse(
|
s.writeHTTPErrorResponse(
|
||||||
req,
|
req,
|
||||||
w,
|
w,
|
||||||
response.NewInvalidParamsError(
|
response.NewInvalidParamsError(
|
||||||
|
@ -164,59 +195,90 @@ func (s *Server) requestHandler(w http.ResponseWriter, httpRequest *http.Request
|
||||||
|
|
||||||
err := req.DecodeData(httpRequest.Body)
|
err := req.DecodeData(httpRequest.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.WriteErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err))
|
s.writeHTTPErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp := s.handleRequest(req)
|
||||||
|
s.writeHTTPServerResponse(req, w, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleRequest(req *request.In) response.Raw {
|
||||||
reqParams, err := req.Params()
|
reqParams, err := req.Params()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.WriteErrorResponse(req, w, response.NewInvalidParamsError("Problem parsing request parameters", err))
|
return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing request parameters", err))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.methodHandler(w, req, *reqParams)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) methodHandler(w http.ResponseWriter, req *request.In, reqParams request.Params) {
|
|
||||||
s.log.Debug("processing rpc request",
|
s.log.Debug("processing rpc request",
|
||||||
zap.String("method", req.Method),
|
zap.String("method", req.Method),
|
||||||
zap.String("params", fmt.Sprintf("%v", reqParams)))
|
zap.String("params", fmt.Sprintf("%v", reqParams)))
|
||||||
|
|
||||||
var (
|
|
||||||
results interface{}
|
|
||||||
resultsErr error
|
|
||||||
)
|
|
||||||
|
|
||||||
incCounter(req.Method)
|
incCounter(req.Method)
|
||||||
|
|
||||||
handler, ok := rpcHandlers[req.Method]
|
handler, ok := rpcHandlers[req.Method]
|
||||||
if ok {
|
if !ok {
|
||||||
results, resultsErr = handler(s, reqParams)
|
return s.packResponseToRaw(req, nil, response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil))
|
||||||
} else {
|
}
|
||||||
resultsErr = response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil)
|
res, resErr := handler(s, *reqParams)
|
||||||
|
return s.packResponseToRaw(req, res, resErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resultsErr != nil {
|
func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) {
|
||||||
s.WriteErrorResponse(req, w, resultsErr)
|
pingTicker := time.NewTicker(wsPingPeriod)
|
||||||
|
defer ws.Close()
|
||||||
|
defer pingTicker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case res, ok := <-resChan:
|
||||||
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
|
||||||
s.WriteResponse(req, w, results)
|
if err := ws.WriteJSON(res); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-pingTicker.C:
|
||||||
|
ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
|
||||||
|
if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getBestBlockHash(_ request.Params) (interface{}, error) {
|
func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw) {
|
||||||
|
ws.SetReadLimit(wsReadLimit)
|
||||||
|
ws.SetReadDeadline(time.Now().Add(wsPongLimit))
|
||||||
|
ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil })
|
||||||
|
for {
|
||||||
|
req := new(request.In)
|
||||||
|
err := ws.ReadJSON(req)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
res := s.handleRequest(req)
|
||||||
|
if res.Error != nil {
|
||||||
|
s.logRequestError(req, res.Error)
|
||||||
|
}
|
||||||
|
resChan <- res
|
||||||
|
}
|
||||||
|
close(resChan)
|
||||||
|
ws.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getBestBlockHash(_ request.Params) (interface{}, *response.Error) {
|
||||||
return "0x" + s.chain.CurrentBlockHash().StringLE(), nil
|
return "0x" + s.chain.CurrentBlockHash().StringLE(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getBlockCount(_ request.Params) (interface{}, error) {
|
func (s *Server) getBlockCount(_ request.Params) (interface{}, *response.Error) {
|
||||||
return s.chain.BlockHeight() + 1, nil
|
return s.chain.BlockHeight() + 1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getConnectionCount(_ request.Params) (interface{}, error) {
|
func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Error) {
|
||||||
return s.coreServer.PeerCount(), nil
|
return s.coreServer.PeerCount(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getBlock(reqParams request.Params) (interface{}, error) {
|
func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
var hash util.Uint256
|
var hash util.Uint256
|
||||||
|
|
||||||
param, ok := reqParams.Value(0)
|
param, ok := reqParams.Value(0)
|
||||||
|
@ -254,7 +316,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, error) {
|
||||||
return hex.EncodeToString(writer.Bytes()), nil
|
return hex.EncodeToString(writer.Bytes()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getBlockHash(reqParams request.Params) (interface{}, error) {
|
func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
param, ok := reqParams.ValueWithType(0, request.NumberT)
|
param, ok := reqParams.ValueWithType(0, request.NumberT)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -267,7 +329,7 @@ func (s *Server) getBlockHash(reqParams request.Params) (interface{}, error) {
|
||||||
return s.chain.GetHeaderHash(num), nil
|
return s.chain.GetHeaderHash(num), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getVersion(_ request.Params) (interface{}, error) {
|
func (s *Server) getVersion(_ request.Params) (interface{}, *response.Error) {
|
||||||
return result.Version{
|
return result.Version{
|
||||||
Port: s.coreServer.Port,
|
Port: s.coreServer.Port,
|
||||||
Nonce: s.coreServer.ID(),
|
Nonce: s.coreServer.ID(),
|
||||||
|
@ -275,7 +337,7 @@ func (s *Server) getVersion(_ request.Params) (interface{}, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getPeers(_ request.Params) (interface{}, error) {
|
func (s *Server) getPeers(_ request.Params) (interface{}, *response.Error) {
|
||||||
peers := result.NewGetPeers()
|
peers := result.NewGetPeers()
|
||||||
peers.AddUnconnected(s.coreServer.UnconnectedPeers())
|
peers.AddUnconnected(s.coreServer.UnconnectedPeers())
|
||||||
peers.AddConnected(s.coreServer.ConnectedPeers())
|
peers.AddConnected(s.coreServer.ConnectedPeers())
|
||||||
|
@ -283,7 +345,7 @@ func (s *Server) getPeers(_ request.Params) (interface{}, error) {
|
||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getRawMempool(_ request.Params) (interface{}, error) {
|
func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error) {
|
||||||
mp := s.chain.GetMemPool()
|
mp := s.chain.GetMemPool()
|
||||||
hashList := make([]util.Uint256, 0)
|
hashList := make([]util.Uint256, 0)
|
||||||
for _, item := range mp.GetVerifiedTransactions() {
|
for _, item := range mp.GetVerifiedTransactions() {
|
||||||
|
@ -292,7 +354,7 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, error) {
|
||||||
return hashList, nil
|
return hashList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) validateAddress(reqParams request.Params) (interface{}, error) {
|
func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
param, ok := reqParams.Value(0)
|
param, ok := reqParams.Value(0)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -300,7 +362,7 @@ func (s *Server) validateAddress(reqParams request.Params) (interface{}, error)
|
||||||
return validateAddress(param.Value), nil
|
return validateAddress(param.Value), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getAssetState(reqParams request.Params) (interface{}, error) {
|
func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -319,7 +381,7 @@ func (s *Server) getAssetState(reqParams request.Params) (interface{}, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// getApplicationLog returns the contract log based on the specified txid.
|
// getApplicationLog returns the contract log based on the specified txid.
|
||||||
func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, error) {
|
func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
param, ok := reqParams.Value(0)
|
param, ok := reqParams.Value(0)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -351,7 +413,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, error
|
||||||
return result.NewApplicationLog(appExecResult, scriptHash), nil
|
return result.NewApplicationLog(appExecResult, scriptHash), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getClaimable(ps request.Params) (interface{}, error) {
|
func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) {
|
||||||
p, ok := ps.ValueWithType(0, request.StringT)
|
p, ok := ps.ValueWithType(0, request.StringT)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -368,7 +430,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, error) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, response.NewInternalServerError("Unclaimed processing failure", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -403,7 +465,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getNEP5Balances(ps request.Params) (interface{}, error) {
|
func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) {
|
||||||
p, ok := ps.ValueWithType(0, request.StringT)
|
p, ok := ps.ValueWithType(0, request.StringT)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -436,7 +498,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, error) {
|
||||||
return bs, nil
|
return bs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, error) {
|
func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) {
|
||||||
p, ok := ps.ValueWithType(0, request.StringT)
|
p, ok := ps.ValueWithType(0, request.StringT)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -500,7 +562,7 @@ func amountToString(amount int64, decimals int64) string {
|
||||||
return fmt.Sprintf(fs, q, r)
|
return fmt.Sprintf(fs, q, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int64, error) {
|
func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int64, *response.Error) {
|
||||||
if d, ok := cache[h]; ok {
|
if d, ok := cache[h]; ok {
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
@ -515,11 +577,11 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, response.NewInternalServerError("Can't create script", err)
|
||||||
}
|
}
|
||||||
res := s.runScriptInVM(script)
|
res := s.runScriptInVM(script)
|
||||||
if res == nil || res.State != "HALT" || len(res.Stack) == 0 {
|
if res == nil || res.State != "HALT" || len(res.Stack) == 0 {
|
||||||
return 0, errors.New("execution error")
|
return 0, response.NewInternalServerError("execution error", errors.New("no result"))
|
||||||
}
|
}
|
||||||
|
|
||||||
var d int64
|
var d int64
|
||||||
|
@ -529,16 +591,16 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6
|
||||||
case smartcontract.ByteArrayType:
|
case smartcontract.ByteArrayType:
|
||||||
d = emit.BytesToInt(item.Value.([]byte)).Int64()
|
d = emit.BytesToInt(item.Value.([]byte)).Int64()
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("invalid result")
|
return 0, response.NewInternalServerError("invalid result", errors.New("not an integer"))
|
||||||
}
|
}
|
||||||
if d < 0 {
|
if d < 0 {
|
||||||
return 0, errors.New("negative decimals")
|
return 0, response.NewInternalServerError("incorrect result", errors.New("negative result"))
|
||||||
}
|
}
|
||||||
cache[h] = d
|
cache[h] = d
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getStorage(ps request.Params) (interface{}, error) {
|
func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
|
||||||
param, ok := ps.Value(0)
|
param, ok := ps.Value(0)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -569,8 +631,8 @@ func (s *Server) getStorage(ps request.Params) (interface{}, error) {
|
||||||
return hex.EncodeToString(item.Value), nil
|
return hex.EncodeToString(item.Value), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, error) {
|
func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
var resultsErr error
|
var resultsErr *response.Error
|
||||||
var results interface{}
|
var results interface{}
|
||||||
|
|
||||||
if param0, ok := reqParams.Value(0); !ok {
|
if param0, ok := reqParams.Value(0); !ok {
|
||||||
|
@ -606,7 +668,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, error
|
||||||
return results, resultsErr
|
return results, resultsErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getTransactionHeight(ps request.Params) (interface{}, error) {
|
func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) {
|
||||||
p, ok := ps.Value(0)
|
p, ok := ps.Value(0)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -625,7 +687,7 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, error) {
|
||||||
return height, nil
|
return height, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getTxOut(ps request.Params) (interface{}, error) {
|
func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) {
|
||||||
p, ok := ps.Value(0)
|
p, ok := ps.Value(0)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -660,7 +722,7 @@ func (s *Server) getTxOut(ps request.Params) (interface{}, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// getContractState returns contract state (contract information, according to the contract script hash).
|
// getContractState returns contract state (contract information, according to the contract script hash).
|
||||||
func (s *Server) getContractState(reqParams request.Params) (interface{}, error) {
|
func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
var results interface{}
|
var results interface{}
|
||||||
|
|
||||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||||
|
@ -679,17 +741,17 @@ func (s *Server) getContractState(reqParams request.Params) (interface{}, error)
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getAccountState(ps request.Params) (interface{}, error) {
|
func (s *Server) getAccountState(ps request.Params) (interface{}, *response.Error) {
|
||||||
return s.getAccountStateAux(ps, false)
|
return s.getAccountStateAux(ps, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getUnspents(ps request.Params) (interface{}, error) {
|
func (s *Server) getUnspents(ps request.Params) (interface{}, *response.Error) {
|
||||||
return s.getAccountStateAux(ps, true)
|
return s.getAccountStateAux(ps, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAccountState returns account state either in short or full (unspents included) form.
|
// getAccountState returns account state either in short or full (unspents included) form.
|
||||||
func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (interface{}, error) {
|
func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (interface{}, *response.Error) {
|
||||||
var resultsErr error
|
var resultsErr *response.Error
|
||||||
var results interface{}
|
var results interface{}
|
||||||
|
|
||||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||||
|
@ -716,7 +778,7 @@ func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (in
|
||||||
}
|
}
|
||||||
|
|
||||||
// getBlockSysFee returns the system fees of the block, based on the specified index.
|
// getBlockSysFee returns the system fees of the block, based on the specified index.
|
||||||
func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) {
|
func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
param, ok := reqParams.ValueWithType(0, request.NumberT)
|
param, ok := reqParams.ValueWithType(0, request.NumberT)
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, response.ErrInvalidParams
|
return 0, response.ErrInvalidParams
|
||||||
|
@ -728,9 +790,9 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
headerHash := s.chain.GetHeaderHash(num)
|
headerHash := s.chain.GetHeaderHash(num)
|
||||||
block, err := s.chain.GetBlock(headerHash)
|
block, errBlock := s.chain.GetBlock(headerHash)
|
||||||
if err != nil {
|
if errBlock != nil {
|
||||||
return 0, response.NewRPCError(err.Error(), "", nil)
|
return 0, response.NewRPCError(errBlock.Error(), "", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var blockSysFee util.Fixed8
|
var blockSysFee util.Fixed8
|
||||||
|
@ -742,7 +804,7 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// getBlockHeader returns the corresponding block header information according to the specified script hash.
|
// getBlockHeader returns the corresponding block header information according to the specified script hash.
|
||||||
func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, error) {
|
func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
var verbose bool
|
var verbose bool
|
||||||
|
|
||||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||||
|
@ -775,13 +837,13 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, error) {
|
||||||
buf := io.NewBufBinWriter()
|
buf := io.NewBufBinWriter()
|
||||||
h.EncodeBinary(buf.BinWriter)
|
h.EncodeBinary(buf.BinWriter)
|
||||||
if buf.Err != nil {
|
if buf.Err != nil {
|
||||||
return nil, err
|
return nil, response.NewInternalServerError("encoding error", buf.Err)
|
||||||
}
|
}
|
||||||
return hex.EncodeToString(buf.Bytes()), nil
|
return hex.EncodeToString(buf.Bytes()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUnclaimed returns unclaimed GAS amount of the specified address.
|
// getUnclaimed returns unclaimed GAS amount of the specified address.
|
||||||
func (s *Server) getUnclaimed(ps request.Params) (interface{}, error) {
|
func (s *Server) getUnclaimed(ps request.Params) (interface{}, *response.Error) {
|
||||||
p, ok := ps.ValueWithType(0, request.StringT)
|
p, ok := ps.ValueWithType(0, request.StringT)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -795,21 +857,24 @@ func (s *Server) getUnclaimed(ps request.Params) (interface{}, error) {
|
||||||
if acc == nil {
|
if acc == nil {
|
||||||
return nil, response.NewInternalServerError("unknown account", nil)
|
return nil, response.NewInternalServerError("unknown account", nil)
|
||||||
}
|
}
|
||||||
|
res, errRes := result.NewUnclaimed(acc, s.chain)
|
||||||
return result.NewUnclaimed(acc, s.chain)
|
if errRes != nil {
|
||||||
|
return nil, response.NewInternalServerError("can't create unclaimed response", errRes)
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getValidators returns the current NEO consensus nodes information and voting status.
|
// getValidators returns the current NEO consensus nodes information and voting status.
|
||||||
func (s *Server) getValidators(_ request.Params) (interface{}, error) {
|
func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error) {
|
||||||
var validators keys.PublicKeys
|
var validators keys.PublicKeys
|
||||||
|
|
||||||
validators, err := s.chain.GetValidators()
|
validators, err := s.chain.GetValidators()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, response.NewRPCError("can't get validators", "", err)
|
||||||
}
|
}
|
||||||
enrollments, err := s.chain.GetEnrollments()
|
enrollments, err := s.chain.GetEnrollments()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, response.NewRPCError("can't get enrollments", "", err)
|
||||||
}
|
}
|
||||||
var res []result.Validator
|
var res []result.Validator
|
||||||
for _, v := range enrollments {
|
for _, v := range enrollments {
|
||||||
|
@ -823,14 +888,14 @@ func (s *Server) getValidators(_ request.Params) (interface{}, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// invoke implements the `invoke` RPC call.
|
// invoke implements the `invoke` RPC call.
|
||||||
func (s *Server) invoke(reqParams request.Params) (interface{}, error) {
|
func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
|
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
}
|
}
|
||||||
scriptHash, err := scriptHashHex.GetUint160FromHex()
|
scriptHash, err := scriptHashHex.GetUint160FromHex()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, response.ErrInvalidParams
|
||||||
}
|
}
|
||||||
sliceP, ok := reqParams.ValueWithType(1, request.ArrayT)
|
sliceP, ok := reqParams.ValueWithType(1, request.ArrayT)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -838,34 +903,34 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, error) {
|
||||||
}
|
}
|
||||||
slice, err := sliceP.GetArray()
|
slice, err := sliceP.GetArray()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, response.ErrInvalidParams
|
||||||
}
|
}
|
||||||
script, err := request.CreateInvocationScript(scriptHash, slice)
|
script, err := request.CreateInvocationScript(scriptHash, slice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, response.NewInternalServerError("can't create invocation script", err)
|
||||||
}
|
}
|
||||||
return s.runScriptInVM(script), nil
|
return s.runScriptInVM(script), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// invokescript implements the `invokescript` RPC call.
|
// invokescript implements the `invokescript` RPC call.
|
||||||
func (s *Server) invokeFunction(reqParams request.Params) (interface{}, error) {
|
func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
|
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
}
|
}
|
||||||
scriptHash, err := scriptHashHex.GetUint160FromHex()
|
scriptHash, err := scriptHashHex.GetUint160FromHex()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, response.ErrInvalidParams
|
||||||
}
|
}
|
||||||
script, err := request.CreateFunctionInvocationScript(scriptHash, reqParams[1:])
|
script, err := request.CreateFunctionInvocationScript(scriptHash, reqParams[1:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, response.NewInternalServerError("can't create invocation script", err)
|
||||||
}
|
}
|
||||||
return s.runScriptInVM(script), nil
|
return s.runScriptInVM(script), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// invokescript implements the `invokescript` RPC call.
|
// invokescript implements the `invokescript` RPC call.
|
||||||
func (s *Server) invokescript(reqParams request.Params) (interface{}, error) {
|
func (s *Server) invokescript(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
if len(reqParams) < 1 {
|
if len(reqParams) < 1 {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
}
|
}
|
||||||
|
@ -895,7 +960,7 @@ func (s *Server) runScriptInVM(script []byte) *result.Invoke {
|
||||||
}
|
}
|
||||||
|
|
||||||
// submitBlock broadcasts a raw block over the NEO network.
|
// submitBlock broadcasts a raw block over the NEO network.
|
||||||
func (s *Server) submitBlock(reqParams request.Params) (interface{}, error) {
|
func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, response.ErrInvalidParams
|
return nil, response.ErrInvalidParams
|
||||||
|
@ -922,8 +987,8 @@ func (s *Server) submitBlock(reqParams request.Params) (interface{}, error) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, error) {
|
func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *response.Error) {
|
||||||
var resultsErr error
|
var resultsErr *response.Error
|
||||||
var results interface{}
|
var results interface{}
|
||||||
|
|
||||||
if len(reqParams) < 1 {
|
if len(reqParams) < 1 {
|
||||||
|
@ -959,7 +1024,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, erro
|
||||||
return results, resultsErr
|
return results, resultsErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) blockHeightFromParam(param *request.Param) (int, error) {
|
func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) {
|
||||||
num, err := param.GetInt()
|
num, err := param.GetInt()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
@ -971,23 +1036,33 @@ func (s *Server) blockHeightFromParam(param *request.Param) (int, error) {
|
||||||
return num, nil
|
return num, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteErrorResponse writes an error response to the ResponseWriter.
|
func (s *Server) packResponseToRaw(r *request.In, result interface{}, respErr *response.Error) response.Raw {
|
||||||
func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, err error) {
|
|
||||||
jsonErr, ok := err.(*response.Error)
|
|
||||||
if !ok {
|
|
||||||
jsonErr = response.NewInternalServerError("Internal server error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := response.Raw{
|
resp := response.Raw{
|
||||||
HeaderAndError: response.HeaderAndError{
|
HeaderAndError: response.HeaderAndError{
|
||||||
Header: response.Header{
|
Header: response.Header{
|
||||||
JSONRPC: r.JSONRPC,
|
JSONRPC: r.JSONRPC,
|
||||||
ID: r.RawID,
|
ID: r.RawID,
|
||||||
},
|
},
|
||||||
Error: jsonErr,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
if respErr != nil {
|
||||||
|
resp.Error = respErr
|
||||||
|
} else {
|
||||||
|
resJSON, err := json.Marshal(result)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error("failed to marshal result",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("method", r.Method))
|
||||||
|
resp.Error = response.NewInternalServerError("failed to encode result", err)
|
||||||
|
} else {
|
||||||
|
resp.Result = resJSON
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// logRequestError is a request error logger.
|
||||||
|
func (s *Server) logRequestError(r *request.In, jsonErr *response.Error) {
|
||||||
logFields := []zap.Field{
|
logFields := []zap.Field{
|
||||||
zap.Error(jsonErr.Cause),
|
zap.Error(jsonErr.Cause),
|
||||||
zap.String("method", r.Method),
|
zap.String("method", r.Method),
|
||||||
|
@ -999,35 +1074,20 @@ func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, err er
|
||||||
}
|
}
|
||||||
|
|
||||||
s.log.Error("Error encountered with rpc request", logFields...)
|
s.log.Error("Error encountered with rpc request", logFields...)
|
||||||
|
|
||||||
w.WriteHeader(jsonErr.HTTPCode)
|
|
||||||
s.writeServerResponse(r, w, resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteResponse encodes the response and writes it to the ResponseWriter.
|
// writeHTTPErrorResponse writes an error response to the ResponseWriter.
|
||||||
func (s *Server) WriteResponse(r *request.In, w http.ResponseWriter, result interface{}) {
|
func (s *Server) writeHTTPErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) {
|
||||||
resJSON, err := json.Marshal(result)
|
resp := s.packResponseToRaw(r, nil, jsonErr)
|
||||||
if err != nil {
|
s.writeHTTPServerResponse(r, w, resp)
|
||||||
s.log.Error("Error encountered while encoding response",
|
|
||||||
zap.String("err", err.Error()),
|
|
||||||
zap.String("method", r.Method))
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := response.Raw{
|
func (s *Server) writeHTTPServerResponse(r *request.In, w http.ResponseWriter, resp response.Raw) {
|
||||||
HeaderAndError: response.HeaderAndError{
|
// Errors can happen in many places and we can only catch ALL of them here.
|
||||||
Header: response.Header{
|
if resp.Error != nil {
|
||||||
JSONRPC: r.JSONRPC,
|
s.logRequestError(r, resp.Error)
|
||||||
ID: r.RawID,
|
w.WriteHeader(resp.Error.HTTPCode)
|
||||||
},
|
|
||||||
},
|
|
||||||
Result: resJSON,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.writeServerResponse(r, w, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) writeServerResponse(r *request.In, w http.ResponseWriter, resp response.Raw) {
|
|
||||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
if s.config.EnableCORSWorkaround {
|
if s.config.EnableCORSWorkaround {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
|
|
@ -2,6 +2,7 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -17,7 +18,7 @@ import (
|
||||||
"go.uber.org/zap/zaptest"
|
"go.uber.org/zap/zaptest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, http.HandlerFunc) {
|
func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *httptest.Server) {
|
||||||
var nBlocks uint32
|
var nBlocks uint32
|
||||||
|
|
||||||
net := config.ModeUnitTestNet
|
net := config.ModeUnitTestNet
|
||||||
|
@ -55,9 +56,11 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, http.HandlerFu
|
||||||
server, err := network.NewServer(serverConfig, chain, logger)
|
server, err := network.NewServer(serverConfig, chain, logger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
rpcServer := New(chain, cfg.ApplicationConfiguration.RPC, server, logger)
|
rpcServer := New(chain, cfg.ApplicationConfiguration.RPC, server, logger)
|
||||||
handler := http.HandlerFunc(rpcServer.requestHandler)
|
|
||||||
|
|
||||||
return chain, handler
|
handler := http.HandlerFunc(rpcServer.handleHTTPRequest)
|
||||||
|
srv := httptest.NewServer(handler)
|
||||||
|
|
||||||
|
return chain, srv
|
||||||
}
|
}
|
||||||
|
|
||||||
type FeerStub struct{}
|
type FeerStub struct{}
|
||||||
|
|
|
@ -12,7 +12,9 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||||
|
@ -28,7 +30,7 @@ import (
|
||||||
|
|
||||||
type executor struct {
|
type executor struct {
|
||||||
chain *core.Blockchain
|
chain *core.Blockchain
|
||||||
handler http.HandlerFunc
|
httpSrv *httptest.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -881,18 +883,31 @@ var rpcTestCases = map[string][]rpcTestCase{
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRPC(t *testing.T) {
|
func TestRPC(t *testing.T) {
|
||||||
chain, handler := initServerWithInMemoryChain(t)
|
t.Run("http", func(t *testing.T) {
|
||||||
|
testRPCProtocol(t, doRPCCallOverHTTP)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("websocket", func(t *testing.T) {
|
||||||
|
testRPCProtocol(t, doRPCCallOverWS)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// testRPCProtocol runs a full set of tests using given callback to make actual
|
||||||
|
// calls. Some tests change the chain state, thus we reinitialize the chain from
|
||||||
|
// scratch here.
|
||||||
|
func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []byte) {
|
||||||
|
chain, httpSrv := initServerWithInMemoryChain(t)
|
||||||
|
|
||||||
defer chain.Close()
|
defer chain.Close()
|
||||||
|
|
||||||
e := &executor{chain: chain, handler: handler}
|
e := &executor{chain: chain, httpSrv: httpSrv}
|
||||||
for method, cases := range rpcTestCases {
|
for method, cases := range rpcTestCases {
|
||||||
t.Run(method, func(t *testing.T) {
|
t.Run(method, func(t *testing.T) {
|
||||||
rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}`
|
rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}`
|
||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), handler, t)
|
body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t)
|
||||||
result := checkErrGetResult(t, body, tc.fail)
|
result := checkErrGetResult(t, body, tc.fail)
|
||||||
if tc.fail {
|
if tc.fail {
|
||||||
return
|
return
|
||||||
|
@ -916,7 +931,7 @@ func TestRPC(t *testing.T) {
|
||||||
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
|
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
|
||||||
TXHash := block.Transactions[1].Hash()
|
TXHash := block.Transactions[1].Hash()
|
||||||
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s"]}"`, TXHash.StringLE())
|
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s"]}"`, TXHash.StringLE())
|
||||||
body := doRPCCall(rpc, handler, t)
|
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||||
result := checkErrGetResult(t, body, false)
|
result := checkErrGetResult(t, body, false)
|
||||||
var res string
|
var res string
|
||||||
err := json.Unmarshal(result, &res)
|
err := json.Unmarshal(result, &res)
|
||||||
|
@ -928,7 +943,7 @@ func TestRPC(t *testing.T) {
|
||||||
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
|
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
|
||||||
TXHash := block.Transactions[1].Hash()
|
TXHash := block.Transactions[1].Hash()
|
||||||
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 0]}"`, TXHash.StringLE())
|
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 0]}"`, TXHash.StringLE())
|
||||||
body := doRPCCall(rpc, handler, t)
|
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||||
result := checkErrGetResult(t, body, false)
|
result := checkErrGetResult(t, body, false)
|
||||||
var res string
|
var res string
|
||||||
err := json.Unmarshal(result, &res)
|
err := json.Unmarshal(result, &res)
|
||||||
|
@ -940,7 +955,7 @@ func TestRPC(t *testing.T) {
|
||||||
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
|
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
|
||||||
TXHash := block.Transactions[1].Hash()
|
TXHash := block.Transactions[1].Hash()
|
||||||
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 1]}"`, TXHash.StringLE())
|
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 1]}"`, TXHash.StringLE())
|
||||||
body := doRPCCall(rpc, handler, t)
|
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||||
txOut := checkErrGetResult(t, body, false)
|
txOut := checkErrGetResult(t, body, false)
|
||||||
actual := result.TransactionOutputRaw{}
|
actual := result.TransactionOutputRaw{}
|
||||||
err := json.Unmarshal(txOut, &actual)
|
err := json.Unmarshal(txOut, &actual)
|
||||||
|
@ -966,7 +981,7 @@ func TestRPC(t *testing.T) {
|
||||||
tx := block.Transactions[3]
|
tx := block.Transactions[3]
|
||||||
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "gettxout", "params": [%s, %d]}"`,
|
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "gettxout", "params": [%s, %d]}"`,
|
||||||
`"`+tx.Hash().StringLE()+`"`, 0)
|
`"`+tx.Hash().StringLE()+`"`, 0)
|
||||||
body := doRPCCall(rpc, handler, t)
|
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||||
res := checkErrGetResult(t, body, false)
|
res := checkErrGetResult(t, body, false)
|
||||||
|
|
||||||
var txOut result.TransactionOutput
|
var txOut result.TransactionOutput
|
||||||
|
@ -997,7 +1012,7 @@ func TestRPC(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getrawmempool", "params": []}`
|
rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getrawmempool", "params": []}`
|
||||||
body := doRPCCall(rpc, handler, t)
|
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||||
res := checkErrGetResult(t, body, false)
|
res := checkErrGetResult(t, body, false)
|
||||||
|
|
||||||
var actual []util.Uint256
|
var actual []util.Uint256
|
||||||
|
@ -1027,12 +1042,23 @@ func checkErrGetResult(t *testing.T, body []byte, expectingFail bool) json.RawMe
|
||||||
return resp.Result
|
return resp.Result
|
||||||
}
|
}
|
||||||
|
|
||||||
func doRPCCall(rpcCall string, handler http.HandlerFunc, t *testing.T) []byte {
|
func doRPCCallOverWS(rpcCall string, url string, t *testing.T) []byte {
|
||||||
req := httptest.NewRequest("POST", "http://0.0.0.0:20333/", strings.NewReader(rpcCall))
|
dialer := websocket.Dialer{HandshakeTimeout: time.Second}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
url = "ws" + strings.TrimPrefix(url, "http")
|
||||||
w := httptest.NewRecorder()
|
c, _, err := dialer.Dial(url+"/ws", nil)
|
||||||
handler(w, req)
|
require.NoError(t, err)
|
||||||
resp := w.Result()
|
c.SetWriteDeadline(time.Now().Add(time.Second))
|
||||||
|
require.NoError(t, c.WriteMessage(1, []byte(rpcCall)))
|
||||||
|
c.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
_, body, err := c.ReadMessage()
|
||||||
|
require.NoError(t, err)
|
||||||
|
return bytes.TrimSpace(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func doRPCCallOverHTTP(rpcCall string, url string, t *testing.T) []byte {
|
||||||
|
cl := http.Client{Timeout: time.Second}
|
||||||
|
resp, err := cl.Post(url, "application/json", strings.NewReader(rpcCall))
|
||||||
|
require.NoErrorf(t, err, "could not make a POST request")
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
assert.NoErrorf(t, err, "could not read response from the request: %s", rpcCall)
|
assert.NoErrorf(t, err, "could not read response from the request: %s", rpcCall)
|
||||||
return bytes.TrimSpace(body)
|
return bytes.TrimSpace(body)
|
||||||
|
|
Loading…
Reference in a new issue