rpc/server: simplify errors handling during parameter parsing

Forward-ported from 2.x with some updates.
This commit is contained in:
Evgenii Stratonikov 2020-06-04 14:58:47 +03:00 committed by Roman Khimov
parent 6ea0d87934
commit 35f952e44f
4 changed files with 82 additions and 127 deletions

View file

@ -386,6 +386,10 @@ func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Er
func (s *Server) blockHashFromParam(param *request.Param) (util.Uint256, *response.Error) {
var hash util.Uint256
if param == nil {
return hash, response.ErrInvalidParams
}
switch param.Type {
case request.StringT:
var err error
@ -406,11 +410,7 @@ func (s *Server) blockHashFromParam(param *request.Param) (util.Uint256, *respon
}
func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
param := reqParams.Value(0)
hash, respErr := s.blockHashFromParam(param)
if respErr != nil {
return nil, respErr
@ -430,8 +430,8 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro
}
func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.ValueWithType(0, request.NumberT)
if !ok {
param := reqParams.ValueWithType(0, request.NumberT)
if param == nil {
return nil, response.ErrInvalidParams
}
num, err := s.blockHeightFromParam(param)
@ -472,8 +472,8 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error)
}
func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.Value(0)
if !ok {
param := reqParams.Value(0)
if param == nil {
return nil, response.ErrInvalidParams
}
return validateAddress(param.Value), nil
@ -481,12 +481,7 @@ func (s *Server) validateAddress(reqParams request.Params) (interface{}, *respon
// getApplicationLog returns the contract log based on the specified txid.
func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
txHash, err := param.GetUint256()
txHash, err := reqParams.Value(0).GetUint256()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -500,11 +495,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *resp
}
func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) {
p, ok := ps.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
u, err := p.GetUint160FromHex()
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromHex()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -533,11 +524,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Erro
}
func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) {
p, ok := ps.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
u, err := p.GetUint160FromAddress()
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -636,6 +623,9 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6
func (s *Server) contractIDFromParam(param *request.Param) (int32, *response.Error) {
var result int32
if param == nil {
return 0, response.ErrInvalidParams
}
switch param.Type {
case request.StringT:
var err error
@ -661,11 +651,7 @@ func (s *Server) contractIDFromParam(param *request.Param) (int32, *response.Err
}
func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
param, ok := ps.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
id, rErr := s.contractIDFromParam(param)
id, rErr := s.contractIDFromParam(ps.Value(0))
if rErr == response.ErrUnknown {
return nil, nil
}
@ -673,12 +659,7 @@ func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
return nil, rErr
}
param, ok = ps.Value(1)
if !ok {
return nil, response.ErrInvalidParams
}
key, err := param.GetBytesHex()
key, err := ps.Value(1).GetBytesHex()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -695,9 +676,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
var resultsErr *response.Error
var results interface{}
if param0, ok := reqParams.Value(0); !ok {
return nil, response.ErrInvalidParams
} else if txHash, err := param0.GetUint256(); err != nil {
if txHash, err := reqParams.Value(0).GetUint256(); err != nil {
resultsErr = response.ErrInvalidParams
} else if tx, height, err := s.chain.GetTransaction(txHash); err != nil {
err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash)
@ -709,7 +688,10 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
resultsErr = response.NewInvalidParamsError(err.Error(), err)
}
param1, _ := reqParams.Value(1)
param1 := reqParams.Value(1)
if param1 == nil {
param1 = &request.Param{}
}
switch v := param1.Value.(type) {
case int, float64, bool, string:
@ -729,12 +711,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
}
func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) {
p, ok := ps.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
h, err := p.GetUint256()
h, err := ps.Value(0).GetUint256()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -749,27 +726,21 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response
// getContractState returns contract state (contract information, according to the contract script hash).
func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) {
var results interface{}
param, ok := reqParams.ValueWithType(0, request.StringT)
if !ok {
scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
if err != nil {
return nil, response.ErrInvalidParams
} else if scriptHash, err := param.GetUint160FromHex(); err != nil {
return nil, response.ErrInvalidParams
} else {
cs := s.chain.GetContractState(scriptHash)
if cs == nil {
return nil, response.NewRPCError("Unknown contract", "", nil)
}
results = cs
}
return results, nil
cs := s.chain.GetContractState(scriptHash)
if cs == nil {
return nil, response.NewRPCError("Unknown contract", "", nil)
}
return cs, nil
}
// getBlockSysFee returns the system fees of the block, based on the specified index.
func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.ValueWithType(0, request.NumberT)
if !ok {
param := reqParams.ValueWithType(0, request.NumberT)
if param == nil {
return 0, response.ErrInvalidParams
}
@ -796,22 +767,14 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *respons
func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) {
var verbose bool
param, ok := reqParams.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
param := reqParams.Value(0)
hash, respErr := s.blockHashFromParam(param)
if respErr != nil {
return nil, respErr
}
param, ok = reqParams.ValueWithType(1, request.NumberT)
if ok {
v, err := param.GetInt()
if err != nil {
return nil, response.ErrInvalidParams
}
v, err := reqParams.ValueWithType(1, request.NumberT).GetInt()
if err == nil {
verbose = v != 0
}
@ -834,11 +797,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons
// getUnclaimedGas returns unclaimed GAS amount of the specified address.
func (s *Server) getUnclaimedGas(ps request.Params) (interface{}, *response.Error) {
p, ok := ps.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
u, err := p.GetUint160FromAddress()
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -876,11 +835,7 @@ func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error)
// invokeFunction implements the `invokeFunction` RPC call.
func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) {
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
scriptHash, err := scriptHashHex.GetUint160FromHex()
scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -941,11 +896,7 @@ func (s *Server) runScriptInVM(script []byte, tx *transaction.Transaction) *resu
// submitBlock broadcasts a raw block over the NEO network.
func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) {
param, ok := reqParams.ValueWithType(0, request.StringT)
if !ok {
return nil, response.ErrInvalidParams
}
blockBytes, err := param.GetBytesHex()
blockBytes, err := reqParams.ValueWithType(0, request.StringT).GetBytesHex()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -1004,11 +955,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res
// subscribe handles subscription requests from websocket clients.
func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
p, ok := reqParams.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
streamName, err := p.GetString()
streamName, err := reqParams.Value(0).GetString()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -1018,8 +965,7 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface
}
// Optional filter.
var filter interface{}
p, ok = reqParams.Value(1)
if ok {
if p := reqParams.Value(1); p != nil {
switch event {
case response.BlockEventID:
if p.Type != request.BlockFilterT {
@ -1093,11 +1039,7 @@ func (s *Server) subscribeToChannel(event response.EventID) {
// unsubscribe handles unsubscription requests from websocket clients.
func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
p, ok := reqParams.Value(0)
if !ok {
return nil, response.ErrInvalidParams
}
id, err := p.GetInt()
id, err := reqParams.Value(0).GetInt()
if err != nil || id < 0 {
return nil, response.ErrInvalidParams
}