From 57de98e1a31389380f1c97bcb22f8fd28328ce5e Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Tue, 28 Apr 2020 16:56:33 +0300 Subject: [PATCH 01/12] rpc/server: refactor handler methods a little request.In is a natural request representation, one can always get request.Params from it. --- pkg/rpc/server/server.go | 16 ++++++++-------- pkg/rpc/server/server_helper_test.go | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index b4164ef2f..b43a626f2 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -111,11 +111,11 @@ func (s *Server) Start(errChan chan error) { s.log.Info("RPC server is not enabled") return } - s.Handler = http.HandlerFunc(s.requestHandler) + s.Handler = http.HandlerFunc(s.handleHTTPRequest) s.log.Info("starting rpc-server", zap.String("endpoint", s.Addr)) if cfg := s.config.TLSConfig; cfg.Enabled { - s.https.Handler = http.HandlerFunc(s.requestHandler) + s.https.Handler = http.HandlerFunc(s.handleHTTPRequest) s.log.Info("starting rpc-server (https)", zap.String("endpoint", s.https.Addr)) go func() { err := s.https.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile) @@ -149,7 +149,7 @@ func (s *Server) Shutdown() error { return err } -func (s *Server) requestHandler(w http.ResponseWriter, httpRequest *http.Request) { +func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { req := request.NewIn() if httpRequest.Method != "POST" { @@ -169,16 +169,16 @@ func (s *Server) requestHandler(w http.ResponseWriter, httpRequest *http.Request return } + s.handleRequest(w, req) +} + +func (s *Server) handleRequest(w http.ResponseWriter, req *request.In) { reqParams, err := req.Params() if err != nil { s.WriteErrorResponse(req, w, response.NewInvalidParamsError("Problem parsing request parameters", err)) return } - s.methodHandler(w, req, *reqParams) -} - -func (s *Server) methodHandler(w http.ResponseWriter, req *request.In, reqParams request.Params) { s.log.Debug("processing rpc request", zap.String("method", req.Method), zap.String("params", fmt.Sprintf("%v", reqParams))) @@ -192,7 +192,7 @@ func (s *Server) methodHandler(w http.ResponseWriter, req *request.In, reqParams handler, ok := rpcHandlers[req.Method] if ok { - results, resultsErr = handler(s, reqParams) + results, resultsErr = handler(s, *reqParams) } else { resultsErr = response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil) } diff --git a/pkg/rpc/server/server_helper_test.go b/pkg/rpc/server/server_helper_test.go index 1263c0dcd..9b99a6505 100644 --- a/pkg/rpc/server/server_helper_test.go +++ b/pkg/rpc/server/server_helper_test.go @@ -55,7 +55,7 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, http.HandlerFu server, err := network.NewServer(serverConfig, chain, logger) require.NoError(t, err) rpcServer := New(chain, cfg.ApplicationConfiguration.RPC, server, logger) - handler := http.HandlerFunc(rpcServer.requestHandler) + handler := http.HandlerFunc(rpcServer.handleHTTPRequest) return chain, handler } From 236f3dabdd629d3fc02fa077fe3fb668b1a75aad Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Tue, 28 Apr 2020 22:35:19 +0300 Subject: [PATCH 02/12] rpc: change handlers to always return response.Error for errors As it's expected by WriteErrorResponse() actually. --- pkg/rpc/server/server.go | 126 +++++++++++++++++++-------------------- 1 file changed, 62 insertions(+), 64 deletions(-) diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index b43a626f2..2bf8e7056 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -44,7 +44,7 @@ type ( } ) -var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, error){ +var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){ "getaccountstate": (*Server).getAccountState, "getapplicationlog": (*Server).getApplicationLog, "getassetstate": (*Server).getAssetState, @@ -77,8 +77,8 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, error){ "validateaddress": (*Server).validateAddress, } -var invalidBlockHeightError = func(index int, height int) error { - return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height) +var invalidBlockHeightError = func(index int, height int) *response.Error { + return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil) } // New creates a new Server struct. @@ -185,7 +185,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, req *request.In) { var ( results interface{} - resultsErr error + resultsErr *response.Error ) incCounter(req.Method) @@ -205,19 +205,19 @@ func (s *Server) handleRequest(w http.ResponseWriter, req *request.In) { s.WriteResponse(req, w, results) } -func (s *Server) getBestBlockHash(_ request.Params) (interface{}, error) { +func (s *Server) getBestBlockHash(_ request.Params) (interface{}, *response.Error) { return "0x" + s.chain.CurrentBlockHash().StringLE(), nil } -func (s *Server) getBlockCount(_ request.Params) (interface{}, error) { +func (s *Server) getBlockCount(_ request.Params) (interface{}, *response.Error) { return s.chain.BlockHeight() + 1, nil } -func (s *Server) getConnectionCount(_ request.Params) (interface{}, error) { +func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Error) { return s.coreServer.PeerCount(), nil } -func (s *Server) getBlock(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) { var hash util.Uint256 param, ok := reqParams.Value(0) @@ -255,7 +255,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, error) { return hex.EncodeToString(writer.Bytes()), nil } -func (s *Server) getBlockHash(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.NumberT) if !ok { return nil, response.ErrInvalidParams @@ -268,7 +268,7 @@ func (s *Server) getBlockHash(reqParams request.Params) (interface{}, error) { return s.chain.GetHeaderHash(num), nil } -func (s *Server) getVersion(_ request.Params) (interface{}, error) { +func (s *Server) getVersion(_ request.Params) (interface{}, *response.Error) { return result.Version{ Port: s.coreServer.Port, Nonce: s.coreServer.ID(), @@ -276,7 +276,7 @@ func (s *Server) getVersion(_ request.Params) (interface{}, error) { }, nil } -func (s *Server) getPeers(_ request.Params) (interface{}, error) { +func (s *Server) getPeers(_ request.Params) (interface{}, *response.Error) { peers := result.NewGetPeers() peers.AddUnconnected(s.coreServer.UnconnectedPeers()) peers.AddConnected(s.coreServer.ConnectedPeers()) @@ -284,7 +284,7 @@ func (s *Server) getPeers(_ request.Params) (interface{}, error) { return peers, nil } -func (s *Server) getRawMempool(_ request.Params) (interface{}, error) { +func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error) { mp := s.chain.GetMemPool() hashList := make([]util.Uint256, 0) for _, item := range mp.GetVerifiedTransactions() { @@ -293,7 +293,7 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, error) { return hashList, nil } -func (s *Server) validateAddress(reqParams request.Params) (interface{}, error) { +func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -301,7 +301,7 @@ func (s *Server) validateAddress(reqParams request.Params) (interface{}, error) return validateAddress(param.Value), nil } -func (s *Server) getAssetState(reqParams request.Params) (interface{}, error) { +func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -320,7 +320,7 @@ func (s *Server) getAssetState(reqParams request.Params) (interface{}, error) { } // getApplicationLog returns the contract log based on the specified txid. -func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, error) { +func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -352,7 +352,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, error return result.NewApplicationLog(appExecResult, scriptHash), nil } -func (s *Server) getClaimable(ps request.Params) (interface{}, error) { +func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -369,7 +369,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, error) { return nil }) if err != nil { - return nil, err + return nil, response.NewInternalServerError("Unclaimed processing failure", err) } } @@ -404,7 +404,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, error) { }, nil } -func (s *Server) getNEP5Balances(ps request.Params) (interface{}, error) { +func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -437,7 +437,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, error) { return bs, nil } -func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, error) { +func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -501,7 +501,7 @@ func amountToString(amount int64, decimals int64) string { return fmt.Sprintf(fs, q, r) } -func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int64, error) { +func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int64, *response.Error) { if d, ok := cache[h]; ok { return d, nil } @@ -516,11 +516,11 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 }, }) if err != nil { - return 0, err + return 0, response.NewInternalServerError("Can't create script", err) } res := s.runScriptInVM(script) if res == nil || res.State != "HALT" || len(res.Stack) == 0 { - return 0, errors.New("execution error") + return 0, response.NewInternalServerError("execution error", errors.New("no result")) } var d int64 @@ -530,16 +530,16 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 case smartcontract.ByteArrayType: d = emit.BytesToInt(item.Value.([]byte)).Int64() default: - return 0, errors.New("invalid result") + return 0, response.NewInternalServerError("invalid result", errors.New("not an integer")) } if d < 0 { - return 0, errors.New("negative decimals") + return 0, response.NewInternalServerError("incorrect result", errors.New("negative result")) } cache[h] = d return d, nil } -func (s *Server) getStorage(ps request.Params) (interface{}, error) { +func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { param, ok := ps.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -570,8 +570,8 @@ func (s *Server) getStorage(ps request.Params) (interface{}, error) { return hex.EncodeToString(item.Value), nil } -func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, error) { - var resultsErr error +func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *response.Error) { + var resultsErr *response.Error var results interface{} if param0, ok := reqParams.Value(0); !ok { @@ -607,7 +607,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, error return results, resultsErr } -func (s *Server) getTransactionHeight(ps request.Params) (interface{}, error) { +func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) { p, ok := ps.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -626,7 +626,7 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, error) { return height, nil } -func (s *Server) getTxOut(ps request.Params) (interface{}, error) { +func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) { p, ok := ps.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -661,7 +661,7 @@ func (s *Server) getTxOut(ps request.Params) (interface{}, error) { } // getContractState returns contract state (contract information, according to the contract script hash). -func (s *Server) getContractState(reqParams request.Params) (interface{}, error) { +func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) { var results interface{} param, ok := reqParams.ValueWithType(0, request.StringT) @@ -680,17 +680,17 @@ func (s *Server) getContractState(reqParams request.Params) (interface{}, error) return results, nil } -func (s *Server) getAccountState(ps request.Params) (interface{}, error) { +func (s *Server) getAccountState(ps request.Params) (interface{}, *response.Error) { return s.getAccountStateAux(ps, false) } -func (s *Server) getUnspents(ps request.Params) (interface{}, error) { +func (s *Server) getUnspents(ps request.Params) (interface{}, *response.Error) { return s.getAccountStateAux(ps, true) } // getAccountState returns account state either in short or full (unspents included) form. -func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (interface{}, error) { - var resultsErr error +func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (interface{}, *response.Error) { + var resultsErr *response.Error var results interface{} param, ok := reqParams.ValueWithType(0, request.StringT) @@ -717,7 +717,7 @@ func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (in } // getBlockSysFee returns the system fees of the block, based on the specified index. -func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.NumberT) if !ok { return 0, response.ErrInvalidParams @@ -729,9 +729,9 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) { } headerHash := s.chain.GetHeaderHash(num) - block, err := s.chain.GetBlock(headerHash) - if err != nil { - return 0, response.NewRPCError(err.Error(), "", nil) + block, errBlock := s.chain.GetBlock(headerHash) + if errBlock != nil { + return 0, response.NewRPCError(errBlock.Error(), "", nil) } var blockSysFee util.Fixed8 @@ -743,7 +743,7 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) { } // getBlockHeader returns the corresponding block header information according to the specified script hash. -func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) { var verbose bool param, ok := reqParams.ValueWithType(0, request.StringT) @@ -776,13 +776,13 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, error) { buf := io.NewBufBinWriter() h.EncodeBinary(buf.BinWriter) if buf.Err != nil { - return nil, err + return nil, response.NewInternalServerError("encoding error", buf.Err) } return hex.EncodeToString(buf.Bytes()), nil } // getUnclaimed returns unclaimed GAS amount of the specified address. -func (s *Server) getUnclaimed(ps request.Params) (interface{}, error) { +func (s *Server) getUnclaimed(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -796,21 +796,24 @@ func (s *Server) getUnclaimed(ps request.Params) (interface{}, error) { if acc == nil { return nil, response.NewInternalServerError("unknown account", nil) } - - return result.NewUnclaimed(acc, s.chain) + res, errRes := result.NewUnclaimed(acc, s.chain) + if errRes != nil { + return nil, response.NewInternalServerError("can't create unclaimed response", errRes) + } + return res, nil } // getValidators returns the current NEO consensus nodes information and voting status. -func (s *Server) getValidators(_ request.Params) (interface{}, error) { +func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error) { var validators keys.PublicKeys validators, err := s.chain.GetValidators() if err != nil { - return nil, err + return nil, response.NewRPCError("can't get validators", "", err) } enrollments, err := s.chain.GetEnrollments() if err != nil { - return nil, err + return nil, response.NewRPCError("can't get enrollments", "", err) } var res []result.Validator for _, v := range enrollments { @@ -824,14 +827,14 @@ func (s *Server) getValidators(_ request.Params) (interface{}, error) { } // invoke implements the `invoke` RPC call. -func (s *Server) invoke(reqParams request.Params) (interface{}, error) { +func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) { scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams } scriptHash, err := scriptHashHex.GetUint160FromHex() if err != nil { - return nil, err + return nil, response.ErrInvalidParams } sliceP, ok := reqParams.ValueWithType(1, request.ArrayT) if !ok { @@ -839,34 +842,34 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, error) { } slice, err := sliceP.GetArray() if err != nil { - return nil, err + return nil, response.ErrInvalidParams } script, err := request.CreateInvocationScript(scriptHash, slice) if err != nil { - return nil, err + return nil, response.NewInternalServerError("can't create invocation script", err) } return s.runScriptInVM(script), nil } // invokescript implements the `invokescript` RPC call. -func (s *Server) invokeFunction(reqParams request.Params) (interface{}, error) { +func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) { scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams } scriptHash, err := scriptHashHex.GetUint160FromHex() if err != nil { - return nil, err + return nil, response.ErrInvalidParams } script, err := request.CreateFunctionInvocationScript(scriptHash, reqParams[1:]) if err != nil { - return nil, err + return nil, response.NewInternalServerError("can't create invocation script", err) } return s.runScriptInVM(script), nil } // invokescript implements the `invokescript` RPC call. -func (s *Server) invokescript(reqParams request.Params) (interface{}, error) { +func (s *Server) invokescript(reqParams request.Params) (interface{}, *response.Error) { if len(reqParams) < 1 { return nil, response.ErrInvalidParams } @@ -896,7 +899,7 @@ func (s *Server) runScriptInVM(script []byte) *result.Invoke { } // submitBlock broadcasts a raw block over the NEO network. -func (s *Server) submitBlock(reqParams request.Params) (interface{}, error) { +func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -923,8 +926,8 @@ func (s *Server) submitBlock(reqParams request.Params) (interface{}, error) { return true, nil } -func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, error) { - var resultsErr error +func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *response.Error) { + var resultsErr *response.Error var results interface{} if len(reqParams) < 1 { @@ -958,7 +961,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, erro return results, resultsErr } -func (s *Server) blockHeightFromParam(param *request.Param) (int, error) { +func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) { num, err := param.GetInt() if err != nil { return 0, nil @@ -971,12 +974,7 @@ func (s *Server) blockHeightFromParam(param *request.Param) (int, error) { } // WriteErrorResponse writes an error response to the ResponseWriter. -func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, err error) { - jsonErr, ok := err.(*response.Error) - if !ok { - jsonErr = response.NewInternalServerError("Internal server error", err) - } - +func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) { resp := response.Raw{ HeaderAndError: response.HeaderAndError{ Header: response.Header{ From 1b523be4b6591510ce6bb66524ec6adab394aedc Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Tue, 28 Apr 2020 22:56:19 +0300 Subject: [PATCH 03/12] rpc: shuffle handleHttpRequest/handleRequest responsibilities Make handleRequest reusable in other contexts like websockets. --- pkg/rpc/server/server.go | 88 +++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 50 deletions(-) diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 2bf8e7056..d6b64390c 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -153,7 +153,7 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ req := request.NewIn() if httpRequest.Method != "POST" { - s.WriteErrorResponse( + s.writeHTTPErrorResponse( req, w, response.NewInvalidParamsError( @@ -165,44 +165,32 @@ func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Requ err := req.DecodeData(httpRequest.Body) if err != nil { - s.WriteErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err)) + s.writeHTTPErrorResponse(req, w, response.NewParseError("Problem parsing JSON-RPC request body", err)) return } - s.handleRequest(w, req) + resp := s.handleRequest(req) + s.writeHTTPServerResponse(req, w, resp) } -func (s *Server) handleRequest(w http.ResponseWriter, req *request.In) { +func (s *Server) handleRequest(req *request.In) response.Raw { reqParams, err := req.Params() if err != nil { - s.WriteErrorResponse(req, w, response.NewInvalidParamsError("Problem parsing request parameters", err)) - return + return s.packResponseToRaw(req, nil, response.NewInvalidParamsError("Problem parsing request parameters", err)) } s.log.Debug("processing rpc request", zap.String("method", req.Method), zap.String("params", fmt.Sprintf("%v", reqParams))) - var ( - results interface{} - resultsErr *response.Error - ) - incCounter(req.Method) handler, ok := rpcHandlers[req.Method] - if ok { - results, resultsErr = handler(s, *reqParams) - } else { - resultsErr = response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil) + if !ok { + return s.packResponseToRaw(req, nil, response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil)) } - - if resultsErr != nil { - s.WriteErrorResponse(req, w, resultsErr) - return - } - - s.WriteResponse(req, w, results) + res, resErr := handler(s, *reqParams) + return s.packResponseToRaw(req, res, resErr) } func (s *Server) getBestBlockHash(_ request.Params) (interface{}, *response.Error) { @@ -973,18 +961,33 @@ func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Erro return num, nil } -// WriteErrorResponse writes an error response to the ResponseWriter. -func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) { +func (s *Server) packResponseToRaw(r *request.In, result interface{}, respErr *response.Error) response.Raw { resp := response.Raw{ HeaderAndError: response.HeaderAndError{ Header: response.Header{ JSONRPC: r.JSONRPC, ID: r.RawID, }, - Error: jsonErr, }, } + if respErr != nil { + resp.Error = respErr + } else { + resJSON, err := json.Marshal(result) + if err != nil { + s.log.Error("failed to marshal result", + zap.Error(err), + zap.String("method", r.Method)) + resp.Error = response.NewInternalServerError("failed to encode result", err) + } else { + resp.Result = resJSON + } + } + return resp +} +// logRequestError is a request error logger. +func (s *Server) logRequestError(r *request.In, jsonErr *response.Error) { logFields := []zap.Field{ zap.Error(jsonErr.Cause), zap.String("method", r.Method), @@ -996,35 +999,20 @@ func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, jsonEr } s.log.Error("Error encountered with rpc request", logFields...) - - w.WriteHeader(jsonErr.HTTPCode) - s.writeServerResponse(r, w, resp) } -// WriteResponse encodes the response and writes it to the ResponseWriter. -func (s *Server) WriteResponse(r *request.In, w http.ResponseWriter, result interface{}) { - resJSON, err := json.Marshal(result) - if err != nil { - s.log.Error("Error encountered while encoding response", - zap.String("err", err.Error()), - zap.String("method", r.Method)) - return - } - - resp := response.Raw{ - HeaderAndError: response.HeaderAndError{ - Header: response.Header{ - JSONRPC: r.JSONRPC, - ID: r.RawID, - }, - }, - Result: resJSON, - } - - s.writeServerResponse(r, w, resp) +// writeHTTPErrorResponse writes an error response to the ResponseWriter. +func (s *Server) writeHTTPErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) { + resp := s.packResponseToRaw(r, nil, jsonErr) + s.writeHTTPServerResponse(r, w, resp) } -func (s *Server) writeServerResponse(r *request.In, w http.ResponseWriter, resp response.Raw) { +func (s *Server) writeHTTPServerResponse(r *request.In, w http.ResponseWriter, resp response.Raw) { + // Errors can happen in many places and we can only catch ALL of them here. + if resp.Error != nil { + s.logRequestError(r, resp.Error) + w.WriteHeader(resp.Error.HTTPCode) + } w.Header().Set("Content-Type", "application/json; charset=utf-8") if s.config.EnableCORSWorkaround { w.Header().Set("Access-Control-Allow-Origin", "*") From d275652b375dfa78d5d9ebbde0263d5814685bda Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 29 Apr 2020 15:14:56 +0300 Subject: [PATCH 04/12] rpc/server: use httptest.Server for testing Which allows to reuse it for websockets. --- pkg/rpc/server/server_helper_test.go | 9 ++++--- pkg/rpc/server/server_test.go | 36 +++++++++++++--------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/pkg/rpc/server/server_helper_test.go b/pkg/rpc/server/server_helper_test.go index 9b99a6505..c6ee3167e 100644 --- a/pkg/rpc/server/server_helper_test.go +++ b/pkg/rpc/server/server_helper_test.go @@ -2,6 +2,7 @@ package server import ( "net/http" + "net/http/httptest" "os" "testing" @@ -17,7 +18,7 @@ import ( "go.uber.org/zap/zaptest" ) -func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, http.HandlerFunc) { +func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, *httptest.Server) { var nBlocks uint32 net := config.ModeUnitTestNet @@ -55,9 +56,11 @@ func initServerWithInMemoryChain(t *testing.T) (*core.Blockchain, http.HandlerFu server, err := network.NewServer(serverConfig, chain, logger) require.NoError(t, err) rpcServer := New(chain, cfg.ApplicationConfiguration.RPC, server, logger) - handler := http.HandlerFunc(rpcServer.handleHTTPRequest) - return chain, handler + handler := http.HandlerFunc(rpcServer.handleHTTPRequest) + srv := httptest.NewServer(handler) + + return chain, srv } type FeerStub struct{} diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index ab5e059de..56f9de38f 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -32,7 +32,7 @@ import ( type executor struct { chain *core.Blockchain - handler http.HandlerFunc + httpSrv *httptest.Server } const ( @@ -814,18 +814,18 @@ var rpcTestCases = map[string][]rpcTestCase{ } func TestRPC(t *testing.T) { - chain, handler := initServerWithInMemoryChain(t) + chain, httpSrv := initServerWithInMemoryChain(t) defer chain.Close() - e := &executor{chain: chain, handler: handler} + e := &executor{chain: chain, httpSrv: httpSrv} for method, cases := range rpcTestCases { t.Run(method, func(t *testing.T) { rpc := `{"jsonrpc": "2.0", "id": 1, "method": "%s", "params": %s}` for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), handler, t) + body := doRPCCall(fmt.Sprintf(rpc, method, tc.params), httpSrv.URL, t) result := checkErrGetResult(t, body, tc.fail) if tc.fail { return @@ -849,7 +849,7 @@ func TestRPC(t *testing.T) { rpc := `{"jsonrpc": "2.0", "id": 1, "method": "submitblock", "params": ["%s"]}` t.Run("empty", func(t *testing.T) { s := newBlock(t, chain, 1) - body := doRPCCall(fmt.Sprintf(rpc, s), handler, t) + body := doRPCCall(fmt.Sprintf(rpc, s), httpSrv.URL, t) checkErrGetResult(t, body, true) }) @@ -867,13 +867,13 @@ func TestRPC(t *testing.T) { t.Run("invalid height", func(t *testing.T) { b := newBlock(t, chain, 2, newTx()) - body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), handler, t) + body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), httpSrv.URL, t) checkErrGetResult(t, body, true) }) t.Run("positive", func(t *testing.T) { b := newBlock(t, chain, 1, newTx()) - body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), handler, t) + body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, b)), httpSrv.URL, t) data := checkErrGetResult(t, body, false) var res bool require.NoError(t, json.Unmarshal(data, &res)) @@ -885,7 +885,7 @@ func TestRPC(t *testing.T) { block, _ := chain.GetBlock(chain.GetHeaderHash(0)) TXHash := block.Transactions[1].Hash() rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s"]}"`, TXHash.StringLE()) - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) result := checkErrGetResult(t, body, false) var res string err := json.Unmarshal(result, &res) @@ -897,7 +897,7 @@ func TestRPC(t *testing.T) { block, _ := chain.GetBlock(chain.GetHeaderHash(0)) TXHash := block.Transactions[1].Hash() rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 0]}"`, TXHash.StringLE()) - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) result := checkErrGetResult(t, body, false) var res string err := json.Unmarshal(result, &res) @@ -909,7 +909,7 @@ func TestRPC(t *testing.T) { block, _ := chain.GetBlock(chain.GetHeaderHash(0)) TXHash := block.Transactions[1].Hash() rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "getrawtransaction", "params": ["%s", 1]}"`, TXHash.StringLE()) - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) txOut := checkErrGetResult(t, body, false) actual := result.TransactionOutputRaw{} err := json.Unmarshal(txOut, &actual) @@ -936,7 +936,7 @@ func TestRPC(t *testing.T) { hdr := e.getHeader(testHeaderHash) runCase := func(t *testing.T, rpc string, expected, actual interface{}) { - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) data := checkErrGetResult(t, body, false) require.NoError(t, json.Unmarshal(data, actual)) require.Equal(t, expected, actual) @@ -984,7 +984,7 @@ func TestRPC(t *testing.T) { tx := block.Transactions[3] rpc := fmt.Sprintf(`{"jsonrpc": "2.0", "id": 1, "method": "gettxout", "params": [%s, %d]}"`, `"`+tx.Hash().StringLE()+`"`, 0) - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) res := checkErrGetResult(t, body, false) var txOut result.TransactionOutput @@ -1010,7 +1010,7 @@ func TestRPC(t *testing.T) { } rpc := `{"jsonrpc": "2.0", "id": 1, "method": "getrawmempool", "params": []}` - body := doRPCCall(rpc, handler, t) + body := doRPCCall(rpc, httpSrv.URL, t) res := checkErrGetResult(t, body, false) var actual []util.Uint256 @@ -1082,12 +1082,10 @@ func checkErrGetResult(t *testing.T, body []byte, expectingFail bool) json.RawMe return resp.Result } -func doRPCCall(rpcCall string, handler http.HandlerFunc, t *testing.T) []byte { - req := httptest.NewRequest("POST", "http://0.0.0.0:20333/", strings.NewReader(rpcCall)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - handler(w, req) - resp := w.Result() +func doRPCCall(rpcCall string, url string, t *testing.T) []byte { + cl := http.Client{Timeout: time.Second} + resp, err := cl.Post(url, "application/json", strings.NewReader(rpcCall)) + require.NoErrorf(t, err, "could not make a POST request") body, err := ioutil.ReadAll(resp.Body) assert.NoErrorf(t, err, "could not read response from the request: %s", rpcCall) return bytes.TrimSpace(body) From 8cec6694ae9e90f7b279fa5b5d66d5bf07fb9627 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Mon, 4 May 2020 16:53:36 +0300 Subject: [PATCH 05/12] rpc/server: fix test block encoding The end result of the previous code wasn't even a valid JSON. --- pkg/rpc/server/server_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index 56f9de38f..4edb9b6da 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -849,7 +849,7 @@ func TestRPC(t *testing.T) { rpc := `{"jsonrpc": "2.0", "id": 1, "method": "submitblock", "params": ["%s"]}` t.Run("empty", func(t *testing.T) { s := newBlock(t, chain, 1) - body := doRPCCall(fmt.Sprintf(rpc, s), httpSrv.URL, t) + body := doRPCCall(fmt.Sprintf(rpc, encodeBlock(t, s)), httpSrv.URL, t) checkErrGetResult(t, body, true) }) From ec62edac68501b82f866139112b6bf9cd2898df3 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 29 Apr 2020 15:25:58 +0300 Subject: [PATCH 06/12] rpc/server: add websockets support via '/ws' URL --- go.mod | 1 + go.sum | 2 + pkg/rpc/server/server.go | 80 +++++++++++++++++++++++++++++++++-- pkg/rpc/server/server_test.go | 29 ++++++++++++- 4 files changed, 108 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index 0d4eb0ec8..570f75872 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/dgraph-io/badger/v2 v2.0.3 github.com/go-redis/redis v6.10.2+incompatible github.com/go-yaml/yaml v2.1.0+incompatible + github.com/gorilla/websocket v1.4.2 github.com/mr-tron/base58 v1.1.2 github.com/nspcc-dev/dbft v0.0.0-20200303183127-36d3da79c682 github.com/nspcc-dev/rfc6979 v0.2.0 diff --git a/go.sum b/go.sum index abbecef44..9d9886a13 100644 --- a/go.sum +++ b/go.sum @@ -85,6 +85,8 @@ github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index d6b64390c..817cba0ba 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -9,12 +9,12 @@ import ( "net" "net/http" "strconv" + "time" - "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" - "github.com/nspcc-dev/neo-go/pkg/rpc" - + "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -22,6 +22,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/network" + "github.com/nspcc-dev/neo-go/pkg/rpc" "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/response" "github.com/nspcc-dev/neo-go/pkg/rpc/response/result" @@ -44,6 +45,20 @@ type ( } ) +const ( + // Message limit for receiving side. + wsReadLimit = 4096 + + // Disconnection timeout. + wsPongLimit = 60 * time.Second + + // Ping period for connection liveness check. + wsPingPeriod = wsPongLimit / 2 + + // Write deadline. + wsWriteLimit = wsPingPeriod / 2 +) + var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){ "getaccountstate": (*Server).getAccountState, "getapplicationlog": (*Server).getApplicationLog, @@ -81,6 +96,10 @@ var invalidBlockHeightError = func(index int, height int) *response.Error { return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil) } +// upgrader is a no-op websocket.Upgrader that reuses HTTP server buffers and +// doesn't set any Error function. +var upgrader = websocket.Upgrader{} + // New creates a new Server struct. func New(chain blockchainer.Blockchainer, conf rpc.Config, coreServer *network.Server, log *zap.Logger) Server { httpServer := &http.Server{ @@ -150,6 +169,18 @@ func (s *Server) Shutdown() error { } func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) { + if httpRequest.URL.Path == "/ws" && httpRequest.Method == "GET" { + ws, err := upgrader.Upgrade(w, httpRequest, nil) + if err != nil { + s.log.Info("websocket connection upgrade failed", zap.Error(err)) + return + } + resChan := make(chan response.Raw) + go s.handleWsWrites(ws, resChan) + s.handleWsReads(ws, resChan) + return + } + req := request.NewIn() if httpRequest.Method != "POST" { @@ -193,6 +224,49 @@ func (s *Server) handleRequest(req *request.In) response.Raw { return s.packResponseToRaw(req, res, resErr) } +func (s *Server) handleWsWrites(ws *websocket.Conn, resChan <-chan response.Raw) { + pingTicker := time.NewTicker(wsPingPeriod) + defer ws.Close() + defer pingTicker.Stop() + for { + select { + case res, ok := <-resChan: + if !ok { + return + } + ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)) + if err := ws.WriteJSON(res); err != nil { + return + } + case <-pingTicker.C: + ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)) + if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + return + } + } + } +} + +func (s *Server) handleWsReads(ws *websocket.Conn, resChan chan<- response.Raw) { + ws.SetReadLimit(wsReadLimit) + ws.SetReadDeadline(time.Now().Add(wsPongLimit)) + ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) + for { + req := new(request.In) + err := ws.ReadJSON(req) + if err != nil { + break + } + res := s.handleRequest(req) + if res.Error != nil { + s.logRequestError(req, res.Error) + } + resChan <- res + } + close(resChan) + ws.Close() +} + func (s *Server) getBestBlockHash(_ request.Params) (interface{}, *response.Error) { return "0x" + s.chain.CurrentBlockHash().StringLE(), nil } diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index 4edb9b6da..b01a204b2 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" @@ -814,6 +815,19 @@ var rpcTestCases = map[string][]rpcTestCase{ } func TestRPC(t *testing.T) { + t.Run("http", func(t *testing.T) { + testRPCProtocol(t, doRPCCallOverHTTP) + }) + + t.Run("websocket", func(t *testing.T) { + testRPCProtocol(t, doRPCCallOverWS) + }) +} + +// testRPCProtocol runs a full set of tests using given callback to make actual +// calls. Some tests change the chain state, thus we reinitialize the chain from +// scratch here. +func testRPCProtocol(t *testing.T, doRPCCall func(string, string, *testing.T) []byte) { chain, httpSrv := initServerWithInMemoryChain(t) defer chain.Close() @@ -1082,7 +1096,20 @@ func checkErrGetResult(t *testing.T, body []byte, expectingFail bool) json.RawMe return resp.Result } -func doRPCCall(rpcCall string, url string, t *testing.T) []byte { +func doRPCCallOverWS(rpcCall string, url string, t *testing.T) []byte { + dialer := websocket.Dialer{HandshakeTimeout: time.Second} + url = "ws" + strings.TrimPrefix(url, "http") + c, _, err := dialer.Dial(url+"/ws", nil) + require.NoError(t, err) + c.SetWriteDeadline(time.Now().Add(time.Second)) + require.NoError(t, c.WriteMessage(1, []byte(rpcCall))) + c.SetReadDeadline(time.Now().Add(time.Second)) + _, body, err := c.ReadMessage() + require.NoError(t, err) + return bytes.TrimSpace(body) +} + +func doRPCCallOverHTTP(rpcCall string, url string, t *testing.T) []byte { cl := http.Client{Timeout: time.Second} resp, err := cl.Post(url, "application/json", strings.NewReader(rpcCall)) require.NoErrorf(t, err, "could not make a POST request") From 315aabde564fd8429b3f262cff83080535e9084b Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 29 Apr 2020 17:19:31 +0300 Subject: [PATCH 07/12] client: make http.Client internal to the Client Exposing it the outside users is strange, so incapsulate it completely. Fix DialTimeout setting along the way, handle negative timeouts as invalid. --- pkg/rpc/client/client.go | 62 ++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 40 deletions(-) diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index 3a5c2820f..cb0e623db 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -31,9 +31,6 @@ const ( // Client represents the middleman for executing JSON RPC calls // to remote NEO RPC nodes. type Client struct { - // The underlying http client. It's never a good practice to use - // the http.DefaultClient, therefore we will role our own. - cliMu *sync.Mutex cli *http.Client endpoint *url.URL ctx context.Context @@ -49,11 +46,11 @@ type Client struct { // All Values are optional. If any duration is not specified // a default of 3 seconds will be used. type Options struct { - Cert string - Key string - CACert string - DialTimeout time.Duration - Client *http.Client + Cert string + Key string + CACert string + DialTimeout time.Duration + RequestTimeout time.Duration // Version is the version of the client that will be send // along with the request body. If no version is specified // the default version (currently 2.0) will be used. @@ -79,32 +76,34 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { return nil, err } + if opts.DialTimeout <= 0 { + opts.DialTimeout = defaultDialTimeout + } + + if opts.RequestTimeout <= 0 { + opts.RequestTimeout = defaultRequestTimeout + } + if opts.Version == "" { opts.Version = defaultClientVersion } - if opts.Client == nil { - opts.Client = &http.Client{ - Transport: &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: opts.DialTimeout, - }).DialContext, - }, - } + httpClient := &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: opts.DialTimeout, + }).DialContext, + }, + Timeout: opts.RequestTimeout, } // TODO(@antdm): Enable SSL. if opts.Cert != "" && opts.Key != "" { } - if opts.Client.Timeout == 0 { - opts.Client.Timeout = defaultRequestTimeout - } - return &Client{ ctx: ctx, - cli: opts.Client, - cliMu: new(sync.Mutex), + cli: httpClient, balancerMu: new(sync.Mutex), wifMu: new(sync.Mutex), endpoint: url, @@ -154,23 +153,6 @@ func (c *Client) SetBalancer(b request.BalanceGetter) { } } -// Client is a getter for client field. -func (c *Client) Client() *http.Client { - c.cliMu.Lock() - defer c.cliMu.Unlock() - return c.cli -} - -// SetClient is a setter for client field. -func (c *Client) SetClient(cli *http.Client) { - c.cliMu.Lock() - defer c.cliMu.Unlock() - - if cli != nil { - c.cli = cli - } -} - // CalculateInputs creates input transactions for the specified amount of given // asset belonging to specified address. This implementation uses GetUnspents // JSON-RPC call internally, so make sure your RPC server supports that. @@ -211,7 +193,7 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{ if err != nil { return err } - resp, err := c.Client().Do(req) + resp, err := c.cli.Do(req) if err != nil { return err } From 20d477cbd8bd606b9dc6b76332e8f8e18c19496d Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 29 Apr 2020 18:04:05 +0300 Subject: [PATCH 08/12] client: remove Balancer getter/setter, make it an Option Keep it internal to the client instance, it makes no sense exposing it to the outside user. --- pkg/rpc/client/client.go | 67 ++++++++++++++++++---------------------- pkg/rpc/client/rpc.go | 2 +- 2 files changed, 31 insertions(+), 38 deletions(-) diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index cb0e623db..ecc3a2d7c 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -31,21 +31,26 @@ const ( // Client represents the middleman for executing JSON RPC calls // to remote NEO RPC nodes. type Client struct { - cli *http.Client - endpoint *url.URL - ctx context.Context - version string - wifMu *sync.Mutex - wif *keys.WIF - balancerMu *sync.Mutex - balancer request.BalanceGetter - cache cache + cli *http.Client + endpoint *url.URL + ctx context.Context + version string + wifMu *sync.Mutex + wif *keys.WIF + balancer request.BalanceGetter + cache cache } // Options defines options for the RPC client. // All Values are optional. If any duration is not specified // a default of 3 seconds will be used. type Options struct { + // Balancer is an implementation of request.BalanceGetter interface, + // if not set then the default Client's implementation will be used, but + // it relies on server support for `getunspents` RPC call which is + // standard for neo-go, but only implemented as a plugin for C# node. So + // you can override it here to use NeoScanServer for example. + Balancer request.BalanceGetter Cert string Key string CACert string @@ -101,14 +106,18 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { if opts.Cert != "" && opts.Key != "" { } - return &Client{ - ctx: ctx, - cli: httpClient, - balancerMu: new(sync.Mutex), - wifMu: new(sync.Mutex), - endpoint: url, - version: opts.Version, - }, nil + cl := &Client{ + ctx: ctx, + cli: httpClient, + wifMu: new(sync.Mutex), + endpoint: url, + version: opts.Version, + } + if opts.Balancer == nil { + opts.Balancer = cl + } + cl.balancer = opts.Balancer + return cl, nil } // WIF returns WIF structure associated with the client. @@ -136,26 +145,10 @@ func (c *Client) SetWIF(wif string) error { return nil } -// Balancer is a getter for balance field. -func (c *Client) Balancer() request.BalanceGetter { - c.balancerMu.Lock() - defer c.balancerMu.Unlock() - return c.balancer -} - -// SetBalancer is a setter for balance field. -func (c *Client) SetBalancer(b request.BalanceGetter) { - c.balancerMu.Lock() - defer c.balancerMu.Unlock() - - if b != nil { - c.balancer = b - } -} - -// CalculateInputs creates input transactions for the specified amount of given -// asset belonging to specified address. This implementation uses GetUnspents -// JSON-RPC call internally, so make sure your RPC server supports that. +// CalculateInputs implements request.BalanceGetter interface and returns inputs +// array for the specified amount of given asset belonging to specified address. +// This implementation uses GetUnspents JSON-RPC call internally, so make sure +// your RPC server supports that. func (c *Client) CalculateInputs(address string, asset util.Uint256, cost util.Fixed8) ([]transaction.Input, util.Fixed8, error) { var utxos state.UnspentBalances diff --git a/pkg/rpc/client/rpc.go b/pkg/rpc/client/rpc.go index 49e830112..dc199074b 100644 --- a/pkg/rpc/client/rpc.go +++ b/pkg/rpc/client/rpc.go @@ -482,7 +482,7 @@ func (c *Client) TransferAsset(asset util.Uint256, address string, amount util.F Address: address, Value: amount, WIF: c.WIF(), - Balancer: c.Balancer(), + Balancer: c.balancer, } resp util.Uint256 ) From 19397ec4a8e039d80b6359bbd151229691fbf410 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 29 Apr 2020 18:15:26 +0300 Subject: [PATCH 09/12] rpc/client: fix some comments --- pkg/rpc/client/client.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index ecc3a2d7c..953c25477 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -42,15 +42,18 @@ type Client struct { } // Options defines options for the RPC client. -// All Values are optional. If any duration is not specified -// a default of 3 seconds will be used. +// All values are optional. If any duration is not specified +// a default of 4 seconds will be used. type Options struct { // Balancer is an implementation of request.BalanceGetter interface, // if not set then the default Client's implementation will be used, but // it relies on server support for `getunspents` RPC call which is // standard for neo-go, but only implemented as a plugin for C# node. So // you can override it here to use NeoScanServer for example. - Balancer request.BalanceGetter + Balancer request.BalanceGetter + + // Cert is a client-side certificate, it doesn't work at the moment along + // with the other two options below. Cert string Key string CACert string From 6d202ad4c582faa1fd3dea2cd797072a003fa662 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 29 Apr 2020 18:17:18 +0300 Subject: [PATCH 10/12] rpc/client: drop Version from Options It makes no sense at all, it's a JSON-RPC version. --- pkg/rpc/client/client.go | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index 953c25477..3c6184bb7 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -34,7 +34,6 @@ type Client struct { cli *http.Client endpoint *url.URL ctx context.Context - version string wifMu *sync.Mutex wif *keys.WIF balancer request.BalanceGetter @@ -59,10 +58,6 @@ type Options struct { CACert string DialTimeout time.Duration RequestTimeout time.Duration - // Version is the version of the client that will be send - // along with the request body. If no version is specified - // the default version (currently 2.0) will be used. - Version string } // cache stores cache values for the RPC client methods @@ -92,10 +87,6 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { opts.RequestTimeout = defaultRequestTimeout } - if opts.Version == "" { - opts.Version = defaultClientVersion - } - httpClient := &http.Client{ Transport: &http.Transport{ DialContext: (&net.Dialer{ @@ -114,7 +105,6 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { cli: httpClient, wifMu: new(sync.Mutex), endpoint: url, - version: opts.Version, } if opts.Balancer == nil { opts.Balancer = cl @@ -172,7 +162,7 @@ func (c *Client) CalculateInputs(address string, asset util.Uint256, cost util.F func (c *Client) performRequest(method string, p request.RawParams, v interface{}) error { var ( r = request.Raw{ - JSONRPC: c.version, + JSONRPC: request.JSONRPCVersion, Method: method, RawParams: p.Values, ID: 1, From a458a1774892e9fb52594741f6d460911798c9de Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 29 Apr 2020 18:39:24 +0300 Subject: [PATCH 11/12] rpc/client: separate out http-related functionality --- pkg/rpc/client/client.go | 50 ++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index 3c6184bb7..a3fbbd00e 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -160,47 +160,57 @@ func (c *Client) CalculateInputs(address string, asset util.Uint256, cost util.F } func (c *Client) performRequest(method string, p request.RawParams, v interface{}) error { + var r = request.Raw{ + JSONRPC: request.JSONRPCVersion, + Method: method, + RawParams: p.Values, + ID: 1, + } + + raw, err := c.makeHTTPRequest(&r) + + if raw != nil && raw.Error != nil { + return raw.Error + } else if err != nil { + return err + } + return json.Unmarshal(raw.Result, v) +} + +func (c *Client) makeHTTPRequest(r *request.Raw) (*response.Raw, error) { var ( - r = request.Raw{ - JSONRPC: request.JSONRPCVersion, - Method: method, - RawParams: p.Values, - ID: 1, - } buf = new(bytes.Buffer) - raw = &response.Raw{} + raw = new(response.Raw) ) if err := json.NewEncoder(buf).Encode(r); err != nil { - return err + return nil, err } req, err := http.NewRequest("POST", c.endpoint.String(), buf) if err != nil { - return err + return nil, err } resp, err := c.cli.Do(req) if err != nil { - return err + return nil, err } defer resp.Body.Close() // The node might send us proper JSON anyway, so look there first and if // it parses, then it has more relevant data than HTTP error code. err = json.NewDecoder(resp.Body).Decode(raw) - if err == nil { - if raw.Error != nil { - err = raw.Error + if err != nil { + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf("HTTP %d/%s", resp.StatusCode, http.StatusText(resp.StatusCode)) } else { - err = json.Unmarshal(raw.Result, v) + err = errors.Wrap(err, "JSON decoding") } - } else if resp.StatusCode != http.StatusOK { - err = fmt.Errorf("HTTP %d/%s", resp.StatusCode, http.StatusText(resp.StatusCode)) - } else { - err = errors.Wrap(err, "JSON decoding") } - - return err + if err != nil { + return nil, err + } + return raw, nil } // Ping attempts to create a connection to the endpoint. From 3de48d7d9063d714591d87750a092382adcdb6c6 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Wed, 29 Apr 2020 22:51:43 +0300 Subject: [PATCH 12/12] rpc/client: add minimalistic websocket client --- pkg/rpc/client/client.go | 10 +- pkg/rpc/client/rpc.go | 2 +- pkg/rpc/client/rpc_test.go | 53 +++++++++-- pkg/rpc/client/wsclient.go | 160 ++++++++++++++++++++++++++++++++ pkg/rpc/client/wsclient_test.go | 16 ++++ 5 files changed, 230 insertions(+), 11 deletions(-) create mode 100644 pkg/rpc/client/wsclient.go create mode 100644 pkg/rpc/client/wsclient_test.go diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index a3fbbd00e..d5ad26e7b 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -34,9 +34,10 @@ type Client struct { cli *http.Client endpoint *url.URL ctx context.Context + opts Options + requestF func(*request.Raw) (*response.Raw, error) wifMu *sync.Mutex wif *keys.WIF - balancer request.BalanceGetter cache cache } @@ -109,7 +110,8 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { if opts.Balancer == nil { opts.Balancer = cl } - cl.balancer = opts.Balancer + cl.opts = opts + cl.requestF = cl.makeHTTPRequest return cl, nil } @@ -167,12 +169,14 @@ func (c *Client) performRequest(method string, p request.RawParams, v interface{ ID: 1, } - raw, err := c.makeHTTPRequest(&r) + raw, err := c.requestF(&r) if raw != nil && raw.Error != nil { return raw.Error } else if err != nil { return err + } else if raw == nil || raw.Result == nil { + return errors.New("no result returned") } return json.Unmarshal(raw.Result, v) } diff --git a/pkg/rpc/client/rpc.go b/pkg/rpc/client/rpc.go index dc199074b..3e53db23d 100644 --- a/pkg/rpc/client/rpc.go +++ b/pkg/rpc/client/rpc.go @@ -482,7 +482,7 @@ func (c *Client) TransferAsset(asset util.Uint256, address string, amount util.F Address: address, Value: amount, WIF: c.WIF(), - Balancer: c.balancer, + Balancer: c.opts.Balancer, } resp util.Uint256 ) diff --git a/pkg/rpc/client/rpc_test.go b/pkg/rpc/client/rpc_test.go index f630868b2..3756d50c1 100644 --- a/pkg/rpc/client/rpc_test.go +++ b/pkg/rpc/client/rpc_test.go @@ -3,11 +3,13 @@ package client import ( "context" "encoding/hex" - "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" + "time" + "github.com/gorilla/websocket" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -1433,7 +1435,22 @@ var rpcClientErrorCases = map[string][]rpcClientErrorCase{ }, } -func TestRPCClient(t *testing.T) { +func TestRPCClients(t *testing.T) { + t.Run("Client", func(t *testing.T) { + testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { + return New(ctx, endpoint, opts) + }) + }) + t.Run("WSClient", func(t *testing.T) { + testRPCClient(t, func(ctx context.Context, endpoint string, opts Options) (*Client, error) { + wsc, err := NewWS(ctx, httpURLtoWS(endpoint), opts) + require.NoError(t, err) + return &wsc.Client, nil + }) + }) +} + +func testRPCClient(t *testing.T, newClient func(context.Context, string, Options) (*Client, error)) { for method, testBatch := range rpcClientTestCases { t.Run(method, func(t *testing.T) { for _, testCase := range testBatch { @@ -1443,7 +1460,7 @@ func TestRPCClient(t *testing.T) { endpoint := srv.URL opts := Options{} - c, err := New(context.TODO(), endpoint, opts) + c, err := newClient(context.TODO(), endpoint, opts) if err != nil { t.Fatal(err) } @@ -1467,7 +1484,7 @@ func TestRPCClient(t *testing.T) { endpoint := srv.URL opts := Options{} - c, err := New(context.TODO(), endpoint, opts) + c, err := newClient(context.TODO(), endpoint, opts) if err != nil { t.Fatal(err) } @@ -1481,8 +1498,31 @@ func TestRPCClient(t *testing.T) { } } +func httpURLtoWS(url string) string { + return "ws" + strings.TrimPrefix(url, "http") + "/ws" +} + func initTestServer(t *testing.T, resp string) *httptest.Server { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/ws" && req.Method == "GET" { + var upgrader = websocket.Upgrader{} + ws, err := upgrader.Upgrade(w, req, nil) + require.NoError(t, err) + for { + ws.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, _, err = ws.ReadMessage() + if err != nil { + break + } + ws.SetWriteDeadline(time.Now().Add(2 * time.Second)) + err = ws.WriteMessage(1, []byte(resp)) + if err != nil { + break + } + } + ws.Close() + return + } requestHandler(t, w, resp) })) @@ -1491,11 +1531,10 @@ func initTestServer(t *testing.T, resp string) *httptest.Server { func requestHandler(t *testing.T, w http.ResponseWriter, resp string) { w.Header().Set("Content-Type", "application/json; charset=utf-8") - encoder := json.NewEncoder(w) - err := encoder.Encode(json.RawMessage(resp)) + _, err := w.Write([]byte(resp)) if err != nil { - t.Fatalf("Error encountered while encoding response: %s", err.Error()) + t.Fatalf("Error writing response: %s", err.Error()) } } diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go new file mode 100644 index 000000000..bdac24816 --- /dev/null +++ b/pkg/rpc/client/wsclient.go @@ -0,0 +1,160 @@ +package client + +import ( + "context" + "encoding/json" + "errors" + "time" + + "github.com/gorilla/websocket" + "github.com/nspcc-dev/neo-go/pkg/rpc/request" + "github.com/nspcc-dev/neo-go/pkg/rpc/response" +) + +// WSClient is a websocket-enabled RPC client that can be used with appropriate +// servers. It's supposed to be faster than Client because it has persistent +// connection to the server and at the same time is exposes some functionality +// that is only provided via websockets (like event subscription mechanism). +type WSClient struct { + Client + ws *websocket.Conn + done chan struct{} + notifications chan *request.In + responses chan *response.Raw + requests chan *request.Raw + shutdown chan struct{} +} + +// requestResponse is a combined type for request and response since we can get +// any of them here. +type requestResponse struct { + request.In + Error *response.Error `json:"error,omitempty"` + Result json.RawMessage `json:"result,omitempty"` +} + +const ( + // Message limit for receiving side. + wsReadLimit = 10 * 1024 * 1024 + + // Disconnection timeout. + wsPongLimit = 60 * time.Second + + // Ping period for connection liveness check. + wsPingPeriod = wsPongLimit / 2 + + // Write deadline. + wsWriteLimit = wsPingPeriod / 2 +) + +// NewWS returns a new WSClient ready to use (with established websocket +// connection). You need to use websocket URL for it like `ws://1.2.3.4/ws`. +func NewWS(ctx context.Context, endpoint string, opts Options) (*WSClient, error) { + cl, err := New(ctx, endpoint, opts) + cl.cli = nil + + dialer := websocket.Dialer{HandshakeTimeout: opts.DialTimeout} + ws, _, err := dialer.Dial(endpoint, nil) + if err != nil { + return nil, err + } + wsc := &WSClient{ + Client: *cl, + ws: ws, + shutdown: make(chan struct{}), + done: make(chan struct{}), + responses: make(chan *response.Raw), + requests: make(chan *request.Raw), + } + go wsc.wsReader() + go wsc.wsWriter() + wsc.requestF = wsc.makeWsRequest + return wsc, nil +} + +// Close closes connection to the remote side rendering this client instance +// unusable. +func (c *WSClient) Close() { + // Closing shutdown channel send signal to wsWriter to break out of the + // loop. In doing so it does ws.Close() closing the network connection + // which in turn makes wsReader receieve err from ws,ReadJSON() and also + // break out of the loop closing c.done channel in its shutdown sequence. + close(c.shutdown) + <-c.done +} + +func (c *WSClient) wsReader() { + c.ws.SetReadLimit(wsReadLimit) + c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)); return nil }) + for { + rr := new(requestResponse) + c.ws.SetReadDeadline(time.Now().Add(wsPongLimit)) + err := c.ws.ReadJSON(rr) + if err != nil { + // Timeout/connection loss/malformed response. + break + } + if rr.RawID == nil && rr.Method != "" { + if c.notifications != nil { + c.notifications <- &rr.In + } + } else if rr.RawID != nil && (rr.Error != nil || rr.Result != nil) { + resp := new(response.Raw) + resp.ID = rr.RawID + resp.JSONRPC = rr.JSONRPC + resp.Error = rr.Error + resp.Result = rr.Result + c.responses <- resp + } else { + // Malformed response, neither valid request, nor valid response. + break + } + } + close(c.done) + close(c.responses) + if c.notifications != nil { + close(c.notifications) + } +} + +func (c *WSClient) wsWriter() { + pingTicker := time.NewTicker(wsPingPeriod) + defer c.ws.Close() + defer pingTicker.Stop() + for { + select { + case <-c.shutdown: + return + case <-c.done: + return + case req, ok := <-c.requests: + if !ok { + return + } + c.ws.SetWriteDeadline(time.Now().Add(c.opts.RequestTimeout)) + if err := c.ws.WriteJSON(req); err != nil { + return + } + case <-pingTicker.C: + c.ws.SetWriteDeadline(time.Now().Add(wsWriteLimit)) + if err := c.ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + return + } + } + } + +} + +func (c *WSClient) makeWsRequest(r *request.Raw) (*response.Raw, error) { + select { + case <-c.done: + return nil, errors.New("connection lost") + case c.requests <- r: + } + select { + case <-c.done: + return nil, errors.New("connection lost") + case resp := <-c.responses: + return resp, nil + } +} diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go new file mode 100644 index 000000000..2a996999a --- /dev/null +++ b/pkg/rpc/client/wsclient_test.go @@ -0,0 +1,16 @@ +package client + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWSClientClose(t *testing.T) { + srv := initTestServer(t, "") + defer srv.Close() + wsc, err := NewWS(context.TODO(), httpURLtoWS(srv.URL), Options{}) + require.NoError(t, err) + wsc.Close() +}