From dbe7be5b0091643b9ec4be8bcfc58224c591e1d1 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Tue, 28 Apr 2020 22:35:19 +0300 Subject: [PATCH] rpc: change handlers to always return response.Error for errors As it's expected by WriteErrorResponse() actually. --- pkg/rpc/server/server.go | 126 +++++++++++++++++++-------------------- 1 file changed, 62 insertions(+), 64 deletions(-) diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index e1b6210bf..7d9de8267 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -43,7 +43,7 @@ type ( } ) -var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, error){ +var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){ "getaccountstate": (*Server).getAccountState, "getapplicationlog": (*Server).getApplicationLog, "getassetstate": (*Server).getAssetState, @@ -76,8 +76,8 @@ var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, error){ "validateaddress": (*Server).validateAddress, } -var invalidBlockHeightError = func(index int, height int) error { - return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height) +var invalidBlockHeightError = func(index int, height int) *response.Error { + return response.NewRPCError(fmt.Sprintf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height), "", nil) } // New creates a new Server struct. @@ -184,7 +184,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, req *request.In) { var ( results interface{} - resultsErr error + resultsErr *response.Error ) incCounter(req.Method) @@ -204,19 +204,19 @@ func (s *Server) handleRequest(w http.ResponseWriter, req *request.In) { s.WriteResponse(req, w, results) } -func (s *Server) getBestBlockHash(_ request.Params) (interface{}, error) { +func (s *Server) getBestBlockHash(_ request.Params) (interface{}, *response.Error) { return "0x" + s.chain.CurrentBlockHash().StringLE(), nil } -func (s *Server) getBlockCount(_ request.Params) (interface{}, error) { +func (s *Server) getBlockCount(_ request.Params) (interface{}, *response.Error) { return s.chain.BlockHeight() + 1, nil } -func (s *Server) getConnectionCount(_ request.Params) (interface{}, error) { +func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Error) { return s.coreServer.PeerCount(), nil } -func (s *Server) getBlock(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) { var hash util.Uint256 param, ok := reqParams.Value(0) @@ -254,7 +254,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, error) { return hex.EncodeToString(writer.Bytes()), nil } -func (s *Server) getBlockHash(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.NumberT) if !ok { return nil, response.ErrInvalidParams @@ -267,7 +267,7 @@ func (s *Server) getBlockHash(reqParams request.Params) (interface{}, error) { return s.chain.GetHeaderHash(num), nil } -func (s *Server) getVersion(_ request.Params) (interface{}, error) { +func (s *Server) getVersion(_ request.Params) (interface{}, *response.Error) { return result.Version{ Port: s.coreServer.Port, Nonce: s.coreServer.ID(), @@ -275,7 +275,7 @@ func (s *Server) getVersion(_ request.Params) (interface{}, error) { }, nil } -func (s *Server) getPeers(_ request.Params) (interface{}, error) { +func (s *Server) getPeers(_ request.Params) (interface{}, *response.Error) { peers := result.NewGetPeers() peers.AddUnconnected(s.coreServer.UnconnectedPeers()) peers.AddConnected(s.coreServer.ConnectedPeers()) @@ -283,7 +283,7 @@ func (s *Server) getPeers(_ request.Params) (interface{}, error) { return peers, nil } -func (s *Server) getRawMempool(_ request.Params) (interface{}, error) { +func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error) { mp := s.chain.GetMemPool() hashList := make([]util.Uint256, 0) for _, item := range mp.GetVerifiedTransactions() { @@ -292,7 +292,7 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, error) { return hashList, nil } -func (s *Server) validateAddress(reqParams request.Params) (interface{}, error) { +func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -300,7 +300,7 @@ func (s *Server) validateAddress(reqParams request.Params) (interface{}, error) return validateAddress(param.Value), nil } -func (s *Server) getAssetState(reqParams request.Params) (interface{}, error) { +func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -319,7 +319,7 @@ func (s *Server) getAssetState(reqParams request.Params) (interface{}, error) { } // getApplicationLog returns the contract log based on the specified txid. -func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, error) { +func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -351,7 +351,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, error return result.NewApplicationLog(appExecResult, scriptHash), nil } -func (s *Server) getClaimable(ps request.Params) (interface{}, error) { +func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -368,7 +368,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, error) { return nil }) if err != nil { - return nil, err + return nil, response.NewInternalServerError("Unclaimed processing failure", err) } } @@ -403,7 +403,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, error) { }, nil } -func (s *Server) getNEP5Balances(ps request.Params) (interface{}, error) { +func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -436,7 +436,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, error) { return bs, nil } -func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, error) { +func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -500,7 +500,7 @@ func amountToString(amount int64, decimals int64) string { return fmt.Sprintf(fs, q, r) } -func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int64, error) { +func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int64, *response.Error) { if d, ok := cache[h]; ok { return d, nil } @@ -515,11 +515,11 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 }, }) if err != nil { - return 0, err + return 0, response.NewInternalServerError("Can't create script", err) } res := s.runScriptInVM(script) if res == nil || res.State != "HALT" || len(res.Stack) == 0 { - return 0, errors.New("execution error") + return 0, response.NewInternalServerError("execution error", errors.New("no result")) } var d int64 @@ -529,16 +529,16 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 case smartcontract.ByteArrayType: d = emit.BytesToInt(item.Value.([]byte)).Int64() default: - return 0, errors.New("invalid result") + return 0, response.NewInternalServerError("invalid result", errors.New("not an integer")) } if d < 0 { - return 0, errors.New("negative decimals") + return 0, response.NewInternalServerError("incorrect result", errors.New("negative result")) } cache[h] = d return d, nil } -func (s *Server) getStorage(ps request.Params) (interface{}, error) { +func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { param, ok := ps.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -569,8 +569,8 @@ func (s *Server) getStorage(ps request.Params) (interface{}, error) { return hex.EncodeToString(item.Value), nil } -func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, error) { - var resultsErr error +func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *response.Error) { + var resultsErr *response.Error var results interface{} if param0, ok := reqParams.Value(0); !ok { @@ -606,7 +606,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, error return results, resultsErr } -func (s *Server) getTransactionHeight(ps request.Params) (interface{}, error) { +func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) { p, ok := ps.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -625,7 +625,7 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, error) { return height, nil } -func (s *Server) getTxOut(ps request.Params) (interface{}, error) { +func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) { p, ok := ps.Value(0) if !ok { return nil, response.ErrInvalidParams @@ -660,7 +660,7 @@ func (s *Server) getTxOut(ps request.Params) (interface{}, error) { } // getContractState returns contract state (contract information, according to the contract script hash). -func (s *Server) getContractState(reqParams request.Params) (interface{}, error) { +func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) { var results interface{} param, ok := reqParams.ValueWithType(0, request.StringT) @@ -679,17 +679,17 @@ func (s *Server) getContractState(reqParams request.Params) (interface{}, error) return results, nil } -func (s *Server) getAccountState(ps request.Params) (interface{}, error) { +func (s *Server) getAccountState(ps request.Params) (interface{}, *response.Error) { return s.getAccountStateAux(ps, false) } -func (s *Server) getUnspents(ps request.Params) (interface{}, error) { +func (s *Server) getUnspents(ps request.Params) (interface{}, *response.Error) { return s.getAccountStateAux(ps, true) } // getAccountState returns account state either in short or full (unspents included) form. -func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (interface{}, error) { - var resultsErr error +func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (interface{}, *response.Error) { + var resultsErr *response.Error var results interface{} param, ok := reqParams.ValueWithType(0, request.StringT) @@ -716,7 +716,7 @@ func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (in } // getBlockSysFee returns the system fees of the block, based on the specified index. -func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.NumberT) if !ok { return 0, response.ErrInvalidParams @@ -728,9 +728,9 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) { } headerHash := s.chain.GetHeaderHash(num) - block, err := s.chain.GetBlock(headerHash) - if err != nil { - return 0, response.NewRPCError(err.Error(), "", nil) + block, errBlock := s.chain.GetBlock(headerHash) + if errBlock != nil { + return 0, response.NewRPCError(errBlock.Error(), "", nil) } var blockSysFee util.Fixed8 @@ -742,7 +742,7 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, error) { } // getBlockHeader returns the corresponding block header information according to the specified script hash. -func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, error) { +func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) { var verbose bool param, ok := reqParams.ValueWithType(0, request.StringT) @@ -775,13 +775,13 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, error) { buf := io.NewBufBinWriter() h.EncodeBinary(buf.BinWriter) if buf.Err != nil { - return nil, err + return nil, response.NewInternalServerError("encoding error", buf.Err) } return hex.EncodeToString(buf.Bytes()), nil } // getUnclaimed returns unclaimed GAS amount of the specified address. -func (s *Server) getUnclaimed(ps request.Params) (interface{}, error) { +func (s *Server) getUnclaimed(ps request.Params) (interface{}, *response.Error) { p, ok := ps.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -795,21 +795,24 @@ func (s *Server) getUnclaimed(ps request.Params) (interface{}, error) { if acc == nil { return nil, response.NewInternalServerError("unknown account", nil) } - - return result.NewUnclaimed(acc, s.chain) + res, errRes := result.NewUnclaimed(acc, s.chain) + if errRes != nil { + return nil, response.NewInternalServerError("can't create unclaimed response", errRes) + } + return res, nil } // getValidators returns the current NEO consensus nodes information and voting status. -func (s *Server) getValidators(_ request.Params) (interface{}, error) { +func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error) { var validators keys.PublicKeys validators, err := s.chain.GetValidators() if err != nil { - return nil, err + return nil, response.NewRPCError("can't get validators", "", err) } enrollments, err := s.chain.GetEnrollments() if err != nil { - return nil, err + return nil, response.NewRPCError("can't get enrollments", "", err) } var res []result.Validator for _, v := range enrollments { @@ -823,14 +826,14 @@ func (s *Server) getValidators(_ request.Params) (interface{}, error) { } // invoke implements the `invoke` RPC call. -func (s *Server) invoke(reqParams request.Params) (interface{}, error) { +func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) { scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams } scriptHash, err := scriptHashHex.GetUint160FromHex() if err != nil { - return nil, err + return nil, response.ErrInvalidParams } sliceP, ok := reqParams.ValueWithType(1, request.ArrayT) if !ok { @@ -838,34 +841,34 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, error) { } slice, err := sliceP.GetArray() if err != nil { - return nil, err + return nil, response.ErrInvalidParams } script, err := request.CreateInvocationScript(scriptHash, slice) if err != nil { - return nil, err + return nil, response.NewInternalServerError("can't create invocation script", err) } return s.runScriptInVM(script), nil } // invokescript implements the `invokescript` RPC call. -func (s *Server) invokeFunction(reqParams request.Params) (interface{}, error) { +func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) { scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams } scriptHash, err := scriptHashHex.GetUint160FromHex() if err != nil { - return nil, err + return nil, response.ErrInvalidParams } script, err := request.CreateFunctionInvocationScript(scriptHash, reqParams[1:]) if err != nil { - return nil, err + return nil, response.NewInternalServerError("can't create invocation script", err) } return s.runScriptInVM(script), nil } // invokescript implements the `invokescript` RPC call. -func (s *Server) invokescript(reqParams request.Params) (interface{}, error) { +func (s *Server) invokescript(reqParams request.Params) (interface{}, *response.Error) { if len(reqParams) < 1 { return nil, response.ErrInvalidParams } @@ -895,7 +898,7 @@ func (s *Server) runScriptInVM(script []byte) *result.Invoke { } // submitBlock broadcasts a raw block over the NEO network. -func (s *Server) submitBlock(reqParams request.Params) (interface{}, error) { +func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) { param, ok := reqParams.ValueWithType(0, request.StringT) if !ok { return nil, response.ErrInvalidParams @@ -922,8 +925,8 @@ func (s *Server) submitBlock(reqParams request.Params) (interface{}, error) { return true, nil } -func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, error) { - var resultsErr error +func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *response.Error) { + var resultsErr *response.Error var results interface{} if len(reqParams) < 1 { @@ -959,7 +962,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, erro return results, resultsErr } -func (s *Server) blockHeightFromParam(param *request.Param) (int, error) { +func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) { num, err := param.GetInt() if err != nil { return 0, nil @@ -972,12 +975,7 @@ func (s *Server) blockHeightFromParam(param *request.Param) (int, error) { } // WriteErrorResponse writes an error response to the ResponseWriter. -func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, err error) { - jsonErr, ok := err.(*response.Error) - if !ok { - jsonErr = response.NewInternalServerError("Internal server error", err) - } - +func (s *Server) WriteErrorResponse(r *request.In, w http.ResponseWriter, jsonErr *response.Error) { resp := response.Raw{ HeaderAndError: response.HeaderAndError{ Header: response.Header{