Merge pull request #925 from nspcc-dev/rpc-over-websocket
RPC over websocket
This commit is contained in:
commit
b04c8623c5
10 changed files with 543 additions and 255 deletions
|
@ -9,12 +9,12 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||
"github.com/nspcc-dev/neo-go/pkg/rpc"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
|
||||
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network"
|
||||
"github.com/nspcc-dev/neo-go/pkg/rpc"
|
||||
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
|
||||
"github.com/nspcc-dev/neo-go/pkg/rpc/response"
|
||||
"github.com/nspcc-dev/neo-go/pkg/rpc/response/result"
|
||||
|
@ -44,7 +45,21 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, error){
|
||||
const (
|
||||
// Message limit for receiving side.
|
||||
wsReadLimit = 4096
|
||||
|
||||
// Disconnection timeout.
|
||||
wsPongLimit = 60 * time.Second
|
||||
|
||||
// Ping period for connection liveness check.
|
||||
wsPingPeriod = wsPongLimit / 2
|
||||
|
||||
// Write deadline.
|
||||
wsWriteLimit = wsPingPeriod / 2
|
||||
)
|
||||
|
||||
var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){
|
||||
"getaccountstate": (*Server).getAccountState,
|
||||
"getapplicationlog": (*Server).getApplicationLog,
|
||||
"getassetstate": (*Server).getAssetState,
|
||||
|
@ -77,10 +92,14 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, error){
|
|||
"validateaddress": (*Server).validateAddress,
|
||||
}
|
||||
|
||||
var invalidBlockHeightError = func(index int, height int) error {
|
||||
return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height)
|
||||
var invalidBlockHeightError = func(index int, height int) *response.Error {
|
||||
return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil)
|
||||
}
|
||||
|
||||
// upgrader is a no-op websocket.Upgrader that reuses HTTP server buffers and
|
||||
// doesn't set any Error function.
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
// New creates a new Server struct.
|
||||
func New(chain blockchainer.Blockchainer, conf rpc.Config, coreServer *network.Server, log *zap.Logger) Server {
|
||||
httpServer := &http.Server{
|
||||
|
@ -111,11 +130,11 @@ func (s *Server) Start(errChan chan error) {
|
|||
s.log.Info("RPC server is not enabled")
|
||||
return
|
||||
}
|
||||
s.Handler = http.HandlerFunc(s.requestHandler)
|
||||
s.Handler = http.HandlerFunc(s.handleHTTPRequest)
|
||||
s.log.Info("starting rpc-server", zap.String("endpoint", s.Addr))
|
||||
|
||||
if cfg := s.config.TLSConfig; cfg.Enabled {
|
||||
s.https.Handler = http.HandlerFunc(s.requestHandler)
|
||||
s.https.Handler = http.HandlerFunc(s.handleHTTPRequest)
|
||||
s.log.Info("starting rpc-server (https)", zap.String("endpoint", s.https.Addr))
|
||||
go func() {
|
||||
err := s.https.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile)
|
||||
|
@ -149,11 +168,23 @@ func (s *Server) Shutdown() error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *Server) requestHandler(w http.ResponseWriter, httpRequest *http.Request) {
|
||||
func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) {
|
||||
if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" {
|
||||
ws, err := upgrader.Upgrade(w, httpRequest, nil)
|
||||
if err != nil {
|
||||
s.log.Info("websocket connection upgrade failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
resChan := make(chan response.Raw)
|
||||
go s.handleWsWrites(ws, resChan)
|
||||
s.handleWsReads(ws, resChan)
|
||||
return
|
||||
}
|
||||
|
||||
req := request.NewIn()
|
||||
|
||||
if httpRequest.Method != "POST" {
|
||||
s.WriteErrorResponse(
|
||||
s.writeHTTPErrorResponse(
|
||||
req,
|
||||
w,
|
||||
response.NewInvalidParamsError(
|
||||
|
@ -165,59 +196,90 @@ func (s *Server) requestHandler(w http.ResponseWriter, httpRequest *http.Request
|
|||
|
||||
err := req.DecodeData(httpRequest.Body)
|
||||
if err != nil {
|
||||
s.WriteErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err))
|
||||
s.writeHTTPErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err))
|
||||
return
|
||||
}
|
||||
|
||||
reqParams, err := req.Params()
|
||||
if err != nil {
|
||||
s.WriteErrorResponse(req, w, response.NewInvalidParamsError("Problem parsing request parameters", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.methodHandler(w, req, *reqParams)
|
||||
resp := s.handleRequest(req)
|
||||
s.writeHTTPServerResponse(req, w, resp)
|
||||
}
|
||||
|
||||
func (s *Server) methodHandler(w http.ResponseWriter, req *request.In, reqParams request.Params) {
|
||||
func (s *Server) handleRequest(req *request.In) response.Raw {
|
||||
reqParams, err := req.Params()
|
||||
if err != nil {
|
||||
return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing request parameters", err))
|
||||
}
|
||||
|
||||
s.log.Debug("processing rpc request",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("params", fmt.Sprintf("%v", reqParams)))
|
||||
|
||||
var (
|
||||
results interface{}
|
||||
resultsErr error
|
||||
)
|
||||
|
||||
incCounter(req.Method)
|
||||
|
||||
handler, ok := rpcHandlers[req.Method]
|
||||
if ok {
|
||||
results, resultsErr = handler(s, reqParams)
|
||||
} else {
|
||||
resultsErr = response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil)
|
||||
if !ok {
|
||||
return s.packResponseToRaw(req, nil, response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil))
|
||||
}
|
||||
|
||||
if resultsErr != nil {
|
||||
s.WriteErrorResponse(req, w, resultsErr)
|
||||
return
|
||||
}
|
||||
|
||||
s.WriteResponse(req, w, results)
|
||||
res, resErr := handler(s, *reqParams)
|
||||
return s.packResponseToRaw(req, res, resErr)
|
||||
}
|
||||
|
||||
func (s *Server) getBestBlockHash(_ request.Params) (interface{}, error) {
|
||||
func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) {
|
||||
pingTicker := time.NewTicker(wsPingPeriod)
|
||||
defer ws.Close()
|
||||
defer pingTicker.Stop()
|
||||
for {
|
||||
select {
|
||||
case res, ok := <-resChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
|
||||
if err := ws.WriteJSON(res); err != nil {
|
||||
return
|
||||
}
|
||||
case <-pingTicker.C:
|
||||
ws.SetWriteDeadline(time.Now().Add(wsWriteLimit))
|
||||
if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw) {
|
||||
ws.SetReadLimit(wsReadLimit)
|
||||
ws.SetReadDeadline(time.Now().Add(wsPongLimit))
|
||||
ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil })
|
||||
for {
|
||||
req := new(request.In)
|
||||
err := ws.ReadJSON(req)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
res := s.handleRequest(req)
|
||||
if res.Error != nil {
|
||||
s.logRequestError(req, res.Error)
|
||||
}
|
||||
resChan <- res
|
||||
}
|
||||
close(resChan)
|
||||
ws.Close()
|
||||
}
|
||||
|
||||
func (s *Server) getBestBlockHash(_ request.Params) (interface{}, *response.Error) {
|
||||
return "0x" + s.chain.CurrentBlockHash().StringLE(), nil
|
||||
}
|
||||
|
||||
func (s *Server) getBlockCount(_ request.Params) (interface{}, error) {
|
||||
func (s *Server) getBlockCount(_ request.Params) (interface{}, *response.Error) {
|
||||
return s.chain.BlockHeight() + 1, nil
|
||||
}
|
||||
|
||||
func (s *Server) getConnectionCount(_ request.Params) (interface{}, error) {
|
||||
func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Error) {
|
||||
return s.coreServer.PeerCount(), nil
|
||||
}
|
||||
|
||||
func (s *Server) getBlock(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) {
|
||||
var hash util.Uint256
|
||||
|
||||
param, ok := reqParams.Value(0)
|
||||
|
@ -255,7 +317,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, error) {
|
|||
return hex.EncodeToString(writer.Bytes()), nil
|
||||
}
|
||||
|
||||
func (s *Server) getBlockHash(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.NumberT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -268,7 +330,7 @@ func (s *Server) getBlockHash(reqParams request.Params) (interface{}, error) {
|
|||
return s.chain.GetHeaderHash(num), nil
|
||||
}
|
||||
|
||||
func (s *Server) getVersion(_ request.Params) (interface{}, error) {
|
||||
func (s *Server) getVersion(_ request.Params) (interface{}, *response.Error) {
|
||||
return result.Version{
|
||||
Port: s.coreServer.Port,
|
||||
Nonce: s.coreServer.ID(),
|
||||
|
@ -276,7 +338,7 @@ func (s *Server) getVersion(_ request.Params) (interface{}, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) getPeers(_ request.Params) (interface{}, error) {
|
||||
func (s *Server) getPeers(_ request.Params) (interface{}, *response.Error) {
|
||||
peers := result.NewGetPeers()
|
||||
peers.AddUnconnected(s.coreServer.UnconnectedPeers())
|
||||
peers.AddConnected(s.coreServer.ConnectedPeers())
|
||||
|
@ -284,7 +346,7 @@ func (s *Server) getPeers(_ request.Params) (interface{}, error) {
|
|||
return peers, nil
|
||||
}
|
||||
|
||||
func (s *Server) getRawMempool(_ request.Params) (interface{}, error) {
|
||||
func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error) {
|
||||
mp := s.chain.GetMemPool()
|
||||
hashList := make([]util.Uint256, 0)
|
||||
for _, item := range mp.GetVerifiedTransactions() {
|
||||
|
@ -293,7 +355,7 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, error) {
|
|||
return hashList, nil
|
||||
}
|
||||
|
||||
func (s *Server) validateAddress(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -301,7 +363,7 @@ func (s *Server) validateAddress(reqParams request.Params) (interface{}, error)
|
|||
return validateAddress(param.Value), nil
|
||||
}
|
||||
|
||||
func (s *Server) getAssetState(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -320,7 +382,7 @@ func (s *Server) getAssetState(reqParams request.Params) (interface{}, error) {
|
|||
}
|
||||
|
||||
// getApplicationLog returns the contract log based on the specified txid.
|
||||
func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -352,7 +414,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, error
|
|||
return result.NewApplicationLog(appExecResult, scriptHash), nil
|
||||
}
|
||||
|
||||
func (s *Server) getClaimable(ps request.Params) (interface{}, error) {
|
||||
func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -369,7 +431,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, error) {
|
|||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, response.NewInternalServerError("Unclaimed processing failure", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -404,7 +466,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) getNEP5Balances(ps request.Params) (interface{}, error) {
|
||||
func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -437,7 +499,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, error) {
|
|||
return bs, nil
|
||||
}
|
||||
|
||||
func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, error) {
|
||||
func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -501,7 +563,7 @@ func amountToString(amount int64, decimals int64) string {
|
|||
return fmt.Sprintf(fs, q, r)
|
||||
}
|
||||
|
||||
func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int64, error) {
|
||||
func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int64, *response.Error) {
|
||||
if d, ok := cache[h]; ok {
|
||||
return d, nil
|
||||
}
|
||||
|
@ -516,11 +578,11 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6
|
|||
},
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, response.NewInternalServerError("Can't create script", err)
|
||||
}
|
||||
res := s.runScriptInVM(script)
|
||||
if res == nil || res.State != "HALT" || len(res.Stack) == 0 {
|
||||
return 0, errors.New("execution error")
|
||||
return 0, response.NewInternalServerError("execution error", errors.New("no result"))
|
||||
}
|
||||
|
||||
var d int64
|
||||
|
@ -530,16 +592,16 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6
|
|||
case smartcontract.ByteArrayType:
|
||||
d = emit.BytesToInt(item.Value.([]byte)).Int64()
|
||||
default:
|
||||
return 0, errors.New("invalid result")
|
||||
return 0, response.NewInternalServerError("invalid result", errors.New("not an integer"))
|
||||
}
|
||||
if d < 0 {
|
||||
return 0, errors.New("negative decimals")
|
||||
return 0, response.NewInternalServerError("incorrect result", errors.New("negative result"))
|
||||
}
|
||||
cache[h] = d
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (s *Server) getStorage(ps request.Params) (interface{}, error) {
|
||||
func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
|
||||
param, ok := ps.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -570,8 +632,8 @@ func (s *Server) getStorage(ps request.Params) (interface{}, error) {
|
|||
return hex.EncodeToString(item.Value), nil
|
||||
}
|
||||
|
||||
func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, error) {
|
||||
var resultsErr error
|
||||
func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *response.Error) {
|
||||
var resultsErr *response.Error
|
||||
var results interface{}
|
||||
|
||||
if param0, ok := reqParams.Value(0); !ok {
|
||||
|
@ -607,7 +669,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, error
|
|||
return results, resultsErr
|
||||
}
|
||||
|
||||
func (s *Server) getTransactionHeight(ps request.Params) (interface{}, error) {
|
||||
func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -626,7 +688,7 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, error) {
|
|||
return height, nil
|
||||
}
|
||||
|
||||
func (s *Server) getTxOut(ps request.Params) (interface{}, error) {
|
||||
func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -661,7 +723,7 @@ func (s *Server) getTxOut(ps request.Params) (interface{}, error) {
|
|||
}
|
||||
|
||||
// getContractState returns contract state (contract information, according to the contract script hash).
|
||||
func (s *Server) getContractState(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) {
|
||||
var results interface{}
|
||||
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
|
@ -680,17 +742,17 @@ func (s *Server) getContractState(reqParams request.Params) (interface{}, error)
|
|||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Server) getAccountState(ps request.Params) (interface{}, error) {
|
||||
func (s *Server) getAccountState(ps request.Params) (interface{}, *response.Error) {
|
||||
return s.getAccountStateAux(ps, false)
|
||||
}
|
||||
|
||||
func (s *Server) getUnspents(ps request.Params) (interface{}, error) {
|
||||
func (s *Server) getUnspents(ps request.Params) (interface{}, *response.Error) {
|
||||
return s.getAccountStateAux(ps, true)
|
||||
}
|
||||
|
||||
// getAccountState returns account state either in short or full (unspents included) form.
|
||||
func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (interface{}, error) {
|
||||
var resultsErr error
|
||||
func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (interface{}, *response.Error) {
|
||||
var resultsErr *response.Error
|
||||
var results interface{}
|
||||
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
|
@ -717,7 +779,7 @@ func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (in
|
|||
}
|
||||
|
||||
// getBlockSysFee returns the system fees of the block, based on the specified index.
|
||||
func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.NumberT)
|
||||
if !ok {
|
||||
return 0, response.ErrInvalidParams
|
||||
|
@ -729,9 +791,9 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) {
|
|||
}
|
||||
|
||||
headerHash := s.chain.GetHeaderHash(num)
|
||||
block, err := s.chain.GetBlock(headerHash)
|
||||
if err != nil {
|
||||
return 0, response.NewRPCError(err.Error(), "", nil)
|
||||
block, errBlock := s.chain.GetBlock(headerHash)
|
||||
if errBlock != nil {
|
||||
return 0, response.NewRPCError(errBlock.Error(), "", nil)
|
||||
}
|
||||
|
||||
var blockSysFee util.Fixed8
|
||||
|
@ -743,7 +805,7 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) {
|
|||
}
|
||||
|
||||
// getBlockHeader returns the corresponding block header information according to the specified script hash.
|
||||
func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) {
|
||||
var verbose bool
|
||||
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
|
@ -776,13 +838,13 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, error) {
|
|||
buf := io.NewBufBinWriter()
|
||||
h.EncodeBinary(buf.BinWriter)
|
||||
if buf.Err != nil {
|
||||
return nil, err
|
||||
return nil, response.NewInternalServerError("encoding error", buf.Err)
|
||||
}
|
||||
return hex.EncodeToString(buf.Bytes()), nil
|
||||
}
|
||||
|
||||
// getUnclaimed returns unclaimed GAS amount of the specified address.
|
||||
func (s *Server) getUnclaimed(ps request.Params) (interface{}, error) {
|
||||
func (s *Server) getUnclaimed(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -796,21 +858,24 @@ func (s *Server) getUnclaimed(ps request.Params) (interface{}, error) {
|
|||
if acc == nil {
|
||||
return nil, response.NewInternalServerError("unknown account", nil)
|
||||
}
|
||||
|
||||
return result.NewUnclaimed(acc, s.chain)
|
||||
res, errRes := result.NewUnclaimed(acc, s.chain)
|
||||
if errRes != nil {
|
||||
return nil, response.NewInternalServerError("can't create unclaimed response", errRes)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// getValidators returns the current NEO consensus nodes information and voting status.
|
||||
func (s *Server) getValidators(_ request.Params) (interface{}, error) {
|
||||
func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error) {
|
||||
var validators keys.PublicKeys
|
||||
|
||||
validators, err := s.chain.GetValidators()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, response.NewRPCError("can't get validators", "", err)
|
||||
}
|
||||
enrollments, err := s.chain.GetEnrollments()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, response.NewRPCError("can't get enrollments", "", err)
|
||||
}
|
||||
var res []result.Validator
|
||||
for _, v := range enrollments {
|
||||
|
@ -824,14 +889,14 @@ func (s *Server) getValidators(_ request.Params) (interface{}, error) {
|
|||
}
|
||||
|
||||
// invoke implements the `invoke` RPC call.
|
||||
func (s *Server) invoke(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) {
|
||||
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
scriptHash, err := scriptHashHex.GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
sliceP, ok := reqParams.ValueWithType(1, request.ArrayT)
|
||||
if !ok {
|
||||
|
@ -839,34 +904,34 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, error) {
|
|||
}
|
||||
slice, err := sliceP.GetArray()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
script, err := request.CreateInvocationScript(scriptHash, slice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, response.NewInternalServerError("can't create invocation script", err)
|
||||
}
|
||||
return s.runScriptInVM(script), nil
|
||||
}
|
||||
|
||||
// invokescript implements the `invokescript` RPC call.
|
||||
func (s *Server) invokeFunction(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) {
|
||||
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
scriptHash, err := scriptHashHex.GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
script, err := request.CreateFunctionInvocationScript(scriptHash, reqParams[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, response.NewInternalServerError("can't create invocation script", err)
|
||||
}
|
||||
return s.runScriptInVM(script), nil
|
||||
}
|
||||
|
||||
// invokescript implements the `invokescript` RPC call.
|
||||
func (s *Server) invokescript(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) invokescript(reqParams request.Params) (interface{}, *response.Error) {
|
||||
if len(reqParams) < 1 {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -896,7 +961,7 @@ func (s *Server) runScriptInVM(script []byte) *result.Invoke {
|
|||
}
|
||||
|
||||
// submitBlock broadcasts a raw block over the NEO network.
|
||||
func (s *Server) submitBlock(reqParams request.Params) (interface{}, error) {
|
||||
func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
|
@ -923,8 +988,8 @@ func (s *Server) submitBlock(reqParams request.Params) (interface{}, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, error) {
|
||||
var resultsErr error
|
||||
func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *response.Error) {
|
||||
var resultsErr *response.Error
|
||||
var results interface{}
|
||||
|
||||
if len(reqParams) < 1 {
|
||||
|
@ -958,7 +1023,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, erro
|
|||
return results, resultsErr
|
||||
}
|
||||
|
||||
func (s *Server) blockHeightFromParam(param *request.Param) (int, error) {
|
||||
func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) {
|
||||
num, err := param.GetInt()
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
|
@ -970,23 +1035,33 @@ func (s *Server) blockHeightFromParam(param *request.Param) (int, error) {
|
|||
return num, nil
|
||||
}
|
||||
|
||||
// WriteErrorResponse writes an error response to the ResponseWriter.
|
||||
func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, err error) {
|
||||
jsonErr, ok := err.(*response.Error)
|
||||
if !ok {
|
||||
jsonErr = response.NewInternalServerError("Internal server error", err)
|
||||
}
|
||||
|
||||
func (s *Server) packResponseToRaw(r *request.In, result interface{}, respErr *response.Error) response.Raw {
|
||||
resp := response.Raw{
|
||||
HeaderAndError: response.HeaderAndError{
|
||||
Header: response.Header{
|
||||
JSONRPC: r.JSONRPC,
|
||||
ID: r.RawID,
|
||||
},
|
||||
Error: jsonErr,
|
||||
},
|
||||
}
|
||||
if respErr != nil {
|
||||
resp.Error = respErr
|
||||
} else {
|
||||
resJSON, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
s.log.Error("failed to marshal result",
|
||||
zap.Error(err),
|
||||
zap.String("method", r.Method))
|
||||
resp.Error = response.NewInternalServerError("failed to encode result", err)
|
||||
} else {
|
||||
resp.Result = resJSON
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// logRequestError is a request error logger.
|
||||
func (s *Server) logRequestError(r *request.In, jsonErr *response.Error) {
|
||||
logFields := []zap.Field{
|
||||
zap.Error(jsonErr.Cause),
|
||||
zap.String("method", r.Method),
|
||||
|
@ -998,35 +1073,20 @@ func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, err er
|
|||
}
|
||||
|
||||
s.log.Error("Error encountered with rpc request", logFields...)
|
||||
|
||||
w.WriteHeader(jsonErr.HTTPCode)
|
||||
s.writeServerResponse(r, w, resp)
|
||||
}
|
||||
|
||||
// WriteResponse encodes the response and writes it to the ResponseWriter.
|
||||
func (s *Server) WriteResponse(r *request.In, w http.ResponseWriter, result interface{}) {
|
||||
resJSON, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
s.log.Error("Error encountered while encoding response",
|
||||
zap.String("err", err.Error()),
|
||||
zap.String("method", r.Method))
|
||||
return
|
||||
}
|
||||
|
||||
resp := response.Raw{
|
||||
HeaderAndError: response.HeaderAndError{
|
||||
Header: response.Header{
|
||||
JSONRPC: r.JSONRPC,
|
||||
ID: r.RawID,
|
||||
},
|
||||
},
|
||||
Result: resJSON,
|
||||
}
|
||||
|
||||
s.writeServerResponse(r, w, resp)
|
||||
// writeHTTPErrorResponse writes an error response to the ResponseWriter.
|
||||
func (s *Server) writeHTTPErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) {
|
||||
resp := s.packResponseToRaw(r, nil, jsonErr)
|
||||
s.writeHTTPServerResponse(r, w, resp)
|
||||
}
|
||||
|
||||
func (s *Server) writeServerResponse(r *request.In, w http.ResponseWriter, resp response.Raw) {
|
||||
func (s *Server) writeHTTPServerResponse(r *request.In, w http.ResponseWriter, resp response.Raw) {
|
||||
// Errors can happen in many places and we can only catch ALL of them here.
|
||||
if resp.Error != nil {
|
||||
s.logRequestError(r, resp.Error)
|
||||
w.WriteHeader(resp.Error.HTTPCode)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
if s.config.EnableCORSWorkaround {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
|
@ -17,7 +18,7 @@ import (
|
|||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, http.HandlerFunc) {
|
||||
func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *httptest.Server) {
|
||||
var nBlocks uint32
|
||||
|
||||
net := config.ModeUnitTestNet
|
||||
|
@ -55,9 +56,11 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, http.HandlerFu
|
|||
server, err := network.NewServer(serverConfig, chain, logger)
|
||||
require.NoError(t, err)
|
||||
rpcServer := New(chain, cfg.ApplicationConfiguration.RPC, server, logger)
|
||||
handler := http.HandlerFunc(rpcServer.requestHandler)
|
||||
|
||||
return chain, handler
|
||||
handler := http.HandlerFunc(rpcServer.handleHTTPRequest)
|
||||
srv := httptest.NewServer(handler)
|
||||
|
||||
return chain, srv
|
||||
}
|
||||
|
||||
type FeerStub struct{}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
|
||||
|
@ -31,7 +32,7 @@ import (
|
|||
|
||||
type executor struct {
|
||||
chain *core.Blockchain
|
||||
handler http.HandlerFunc
|
||||
httpSrv *httptest.Server
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -813,18 +814,31 @@ var rpcTestCases = map[string][]rpcTestCase{
|
|||
}
|
||||
|
||||
func TestRPC(t *testing.T) {
|
||||
chain, handler := initServerWithInMemoryChain(t)
|
||||
t.Run("http", func(t *testing.T) {
|
||||
testRPCProtocol(t, doRPCCallOverHTTP)
|
||||
})
|
||||
|
||||
t.Run("websocket", func(t *testing.T) {
|
||||
testRPCProtocol(t, doRPCCallOverWS)
|
||||
})
|
||||
}
|
||||
|
||||
// testRPCProtocol runs a full set of tests using given callback to make actual
|
||||
// calls. Some tests change the chain state, thus we reinitialize the chain from
|
||||
// scratch here.
|
||||
func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []byte) {
|
||||
chain, httpSrv := initServerWithInMemoryChain(t)
|
||||
|
||||
defer chain.Close()
|
||||
|
||||
e := &executor{chain: chain, handler: handler}
|
||||
e := &executor{chain: chain, httpSrv: httpSrv}
|
||||
for method, cases := range rpcTestCases {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}`
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), handler, t)
|
||||
body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t)
|
||||
result := checkErrGetResult(t, body, tc.fail)
|
||||
if tc.fail {
|
||||
return
|
||||
|
@ -848,7 +862,7 @@ func TestRPC(t *testing.T) {
|
|||
rpc := `{"jsonrpc": "2.0", "id": 1, "method": "submitblock", "params": ["%s"]}`
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
s := newBlock(t, chain, 1)
|
||||
body := doRPCCall(fmt.Sprintf(rpc, s), handler, t)
|
||||
body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, s)), httpSrv.URL, t)
|
||||
checkErrGetResult(t, body, true)
|
||||
})
|
||||
|
||||
|
@ -867,13 +881,13 @@ func TestRPC(t *testing.T) {
|
|||
|
||||
t.Run("invalid height", func(t *testing.T) {
|
||||
b := newBlock(t, chain, 2, newTx())
|
||||
body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), handler, t)
|
||||
body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), httpSrv.URL, t)
|
||||
checkErrGetResult(t, body, true)
|
||||
})
|
||||
|
||||
t.Run("positive", func(t *testing.T) {
|
||||
b := newBlock(t, chain, 1, newTx())
|
||||
body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), handler, t)
|
||||
body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), httpSrv.URL, t)
|
||||
data := checkErrGetResult(t, body, false)
|
||||
var res bool
|
||||
require.NoError(t, json.Unmarshal(data, &res))
|
||||
|
@ -885,7 +899,7 @@ func TestRPC(t *testing.T) {
|
|||
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
|
||||
TXHash := block.Transactions[0].Hash()
|
||||
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s"]}"`, TXHash.StringLE())
|
||||
body := doRPCCall(rpc, handler, t)
|
||||
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||
result := checkErrGetResult(t, body, false)
|
||||
var res string
|
||||
err := json.Unmarshal(result, &res)
|
||||
|
@ -897,7 +911,7 @@ func TestRPC(t *testing.T) {
|
|||
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
|
||||
TXHash := block.Transactions[0].Hash()
|
||||
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 0]}"`, TXHash.StringLE())
|
||||
body := doRPCCall(rpc, handler, t)
|
||||
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||
result := checkErrGetResult(t, body, false)
|
||||
var res string
|
||||
err := json.Unmarshal(result, &res)
|
||||
|
@ -909,7 +923,7 @@ func TestRPC(t *testing.T) {
|
|||
block, _ := chain.GetBlock(chain.GetHeaderHash(0))
|
||||
TXHash := block.Transactions[0].Hash()
|
||||
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 1]}"`, TXHash.StringLE())
|
||||
body := doRPCCall(rpc, handler, t)
|
||||
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||
txOut := checkErrGetResult(t, body, false)
|
||||
actual := result.TransactionOutputRaw{}
|
||||
err := json.Unmarshal(txOut, &actual)
|
||||
|
@ -936,7 +950,7 @@ func TestRPC(t *testing.T) {
|
|||
hdr := e.getHeader(testHeaderHash)
|
||||
|
||||
runCase := func(t *testing.T, rpc string, expected, actual interface{}) {
|
||||
body := doRPCCall(rpc, handler, t)
|
||||
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||
data := checkErrGetResult(t, body, false)
|
||||
require.NoError(t, json.Unmarshal(data, actual))
|
||||
require.Equal(t, expected, actual)
|
||||
|
@ -984,7 +998,7 @@ func TestRPC(t *testing.T) {
|
|||
tx := block.Transactions[2]
|
||||
rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "gettxout", "params": [%s, %d]}"`,
|
||||
`"`+tx.Hash().StringLE()+`"`, 0)
|
||||
body := doRPCCall(rpc, handler, t)
|
||||
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||
res := checkErrGetResult(t, body, false)
|
||||
|
||||
var txOut result.TransactionOutput
|
||||
|
@ -1010,7 +1024,7 @@ func TestRPC(t *testing.T) {
|
|||
}
|
||||
|
||||
rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getrawmempool", "params": []}`
|
||||
body := doRPCCall(rpc, handler, t)
|
||||
body := doRPCCall(rpc, httpSrv.URL, t)
|
||||
res := checkErrGetResult(t, body, false)
|
||||
|
||||
var actual []util.Uint256
|
||||
|
@ -1085,12 +1099,23 @@ func checkErrGetResult(t *testing.T, body []byte, expectingFail bool) json.RawMe
|
|||
return resp.Result
|
||||
}
|
||||
|
||||
func doRPCCall(rpcCall string, handler http.HandlerFunc, t *testing.T) []byte {
|
||||
req := httptest.NewRequest("POST", "http://0.0.0.0:20333/", strings.NewReader(rpcCall))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
handler(w, req)
|
||||
resp := w.Result()
|
||||
func doRPCCallOverWS(rpcCall string, url string, t *testing.T) []byte {
|
||||
dialer := websocket.Dialer{HandshakeTimeout: time.Second}
|
||||
url = "ws" + strings.TrimPrefix(url, "http")
|
||||
c, _, err := dialer.Dial(url+"/ws", nil)
|
||||
require.NoError(t, err)
|
||||
c.SetWriteDeadline(time.Now().Add(time.Second))
|
||||
require.NoError(t, c.WriteMessage(1, []byte(rpcCall)))
|
||||
c.SetReadDeadline(time.Now().Add(time.Second))
|
||||
_, body, err := c.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
return bytes.TrimSpace(body)
|
||||
}
|
||||
|
||||
func doRPCCallOverHTTP(rpcCall string, url string, t *testing.T) []byte {
|
||||
cl := http.Client{Timeout: time.Second}
|
||||
resp, err := cl.Post(url, "application/json", strings.NewReader(rpcCall))
|
||||
require.NoErrorf(t, err, "could not make a POST request")
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoErrorf(t, err, "could not read response from the request: %s", rpcCall)
|
||||
return bytes.TrimSpace(body)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue