rpc/server: simplify errors handling during parameter parsing
Forward-ported from 2.x with some updates.
This commit is contained in:
parent
6ea0d87934
commit
35f952e44f
4 changed files with 82 additions and 127 deletions
|
@ -183,8 +183,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
},
|
||||
func(t *testing.T, p *request.Params) {
|
||||
param, ok := p.Value(1)
|
||||
require.Equal(t, true, ok)
|
||||
param := p.Value(1)
|
||||
require.NotNil(t, param)
|
||||
require.Equal(t, request.BlockFilterT, param.Type)
|
||||
filt, ok := param.Value.(request.BlockFilter)
|
||||
require.Equal(t, true, ok)
|
||||
|
@ -198,8 +198,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
},
|
||||
func(t *testing.T, p *request.Params) {
|
||||
param, ok := p.Value(1)
|
||||
require.Equal(t, true, ok)
|
||||
param := p.Value(1)
|
||||
require.NotNil(t, param)
|
||||
require.Equal(t, request.TxFilterT, param.Type)
|
||||
filt, ok := param.Value.(request.TxFilter)
|
||||
require.Equal(t, true, ok)
|
||||
|
@ -214,8 +214,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
},
|
||||
func(t *testing.T, p *request.Params) {
|
||||
param, ok := p.Value(1)
|
||||
require.Equal(t, true, ok)
|
||||
param := p.Value(1)
|
||||
require.NotNil(t, param)
|
||||
require.Equal(t, request.TxFilterT, param.Type)
|
||||
filt, ok := param.Value.(request.TxFilter)
|
||||
require.Equal(t, true, ok)
|
||||
|
@ -231,8 +231,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
},
|
||||
func(t *testing.T, p *request.Params) {
|
||||
param, ok := p.Value(1)
|
||||
require.Equal(t, true, ok)
|
||||
param := p.Value(1)
|
||||
require.NotNil(t, param)
|
||||
require.Equal(t, request.TxFilterT, param.Type)
|
||||
filt, ok := param.Value.(request.TxFilter)
|
||||
require.Equal(t, true, ok)
|
||||
|
@ -247,8 +247,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
},
|
||||
func(t *testing.T, p *request.Params) {
|
||||
param, ok := p.Value(1)
|
||||
require.Equal(t, true, ok)
|
||||
param := p.Value(1)
|
||||
require.NotNil(t, param)
|
||||
require.Equal(t, request.NotificationFilterT, param.Type)
|
||||
filt, ok := param.Value.(request.NotificationFilter)
|
||||
require.Equal(t, true, ok)
|
||||
|
@ -262,8 +262,8 @@ func TestWSFilteredSubscriptions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
},
|
||||
func(t *testing.T, p *request.Params) {
|
||||
param, ok := p.Value(1)
|
||||
require.Equal(t, true, ok)
|
||||
param := p.Value(1)
|
||||
require.NotNil(t, param)
|
||||
require.Equal(t, request.ExecutionFilterT, param.Type)
|
||||
filt, ok := param.Value.(request.ExecutionFilter)
|
||||
require.Equal(t, true, ok)
|
||||
|
|
|
@ -69,12 +69,17 @@ const (
|
|||
Cosigner
|
||||
)
|
||||
|
||||
var errMissingParameter = errors.New("parameter is missing")
|
||||
|
||||
func (p Param) String() string {
|
||||
return fmt.Sprintf("%v", p.Value)
|
||||
}
|
||||
|
||||
// GetString returns string value of the parameter.
|
||||
func (p Param) GetString() (string, error) {
|
||||
func (p *Param) GetString() (string, error) {
|
||||
if p == nil {
|
||||
return "", errMissingParameter
|
||||
}
|
||||
str, ok := p.Value.(string)
|
||||
if !ok {
|
||||
return "", errors.New("not a string")
|
||||
|
@ -83,7 +88,10 @@ func (p Param) GetString() (string, error) {
|
|||
}
|
||||
|
||||
// GetInt returns int value of te parameter.
|
||||
func (p Param) GetInt() (int, error) {
|
||||
func (p *Param) GetInt() (int, error) {
|
||||
if p == nil {
|
||||
return 0, errMissingParameter
|
||||
}
|
||||
i, ok := p.Value.(int)
|
||||
if ok {
|
||||
return i, nil
|
||||
|
@ -94,7 +102,10 @@ func (p Param) GetInt() (int, error) {
|
|||
}
|
||||
|
||||
// GetArray returns a slice of Params stored in the parameter.
|
||||
func (p Param) GetArray() ([]Param, error) {
|
||||
func (p *Param) GetArray() ([]Param, error) {
|
||||
if p == nil {
|
||||
return nil, errMissingParameter
|
||||
}
|
||||
a, ok := p.Value.([]Param)
|
||||
if !ok {
|
||||
return nil, errors.New("not an array")
|
||||
|
@ -103,7 +114,7 @@ func (p Param) GetArray() ([]Param, error) {
|
|||
}
|
||||
|
||||
// GetUint256 returns Uint256 value of the parameter.
|
||||
func (p Param) GetUint256() (util.Uint256, error) {
|
||||
func (p *Param) GetUint256() (util.Uint256, error) {
|
||||
s, err := p.GetString()
|
||||
if err != nil {
|
||||
return util.Uint256{}, err
|
||||
|
@ -113,7 +124,7 @@ func (p Param) GetUint256() (util.Uint256, error) {
|
|||
}
|
||||
|
||||
// GetUint160FromHex returns Uint160 value of the parameter encoded in hex.
|
||||
func (p Param) GetUint160FromHex() (util.Uint160, error) {
|
||||
func (p *Param) GetUint160FromHex() (util.Uint160, error) {
|
||||
s, err := p.GetString()
|
||||
if err != nil {
|
||||
return util.Uint160{}, err
|
||||
|
@ -127,7 +138,7 @@ func (p Param) GetUint160FromHex() (util.Uint160, error) {
|
|||
|
||||
// GetUint160FromAddress returns Uint160 value of the parameter that was
|
||||
// supplied as an address.
|
||||
func (p Param) GetUint160FromAddress() (util.Uint160, error) {
|
||||
func (p *Param) GetUint160FromAddress() (util.Uint160, error) {
|
||||
s, err := p.GetString()
|
||||
if err != nil {
|
||||
return util.Uint160{}, err
|
||||
|
@ -137,7 +148,10 @@ func (p Param) GetUint160FromAddress() (util.Uint160, error) {
|
|||
}
|
||||
|
||||
// GetFuncParam returns current parameter as a function call parameter.
|
||||
func (p Param) GetFuncParam() (FuncParam, error) {
|
||||
func (p *Param) GetFuncParam() (FuncParam, error) {
|
||||
if p == nil {
|
||||
return FuncParam{}, errMissingParameter
|
||||
}
|
||||
fp, ok := p.Value.(FuncParam)
|
||||
if !ok {
|
||||
return FuncParam{}, errors.New("not a function parameter")
|
||||
|
@ -147,7 +161,7 @@ func (p Param) GetFuncParam() (FuncParam, error) {
|
|||
|
||||
// GetBytesHex returns []byte value of the parameter if
|
||||
// it is a hex-encoded string.
|
||||
func (p Param) GetBytesHex() ([]byte, error) {
|
||||
func (p *Param) GetBytesHex() ([]byte, error) {
|
||||
s, err := p.GetString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -7,20 +7,19 @@ type (
|
|||
|
||||
// Value returns the param struct for the given
|
||||
// index if it exists.
|
||||
func (p Params) Value(index int) (*Param, bool) {
|
||||
func (p Params) Value(index int) *Param {
|
||||
if len(p) > index {
|
||||
return &p[index], true
|
||||
return &p[index]
|
||||
}
|
||||
|
||||
return nil, false
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValueWithType returns the param struct at the given index if it
|
||||
// exists and matches the given type.
|
||||
func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) {
|
||||
if val, ok := p.Value(index); ok && val.Type == valType {
|
||||
return val, true
|
||||
func (p Params) ValueWithType(index int, valType paramType) *Param {
|
||||
if val := p.Value(index); val != nil && val.Type == valType {
|
||||
return val
|
||||
}
|
||||
|
||||
return nil, false
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -386,6 +386,10 @@ func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Er
|
|||
func (s *Server) blockHashFromParam(param *request.Param) (util.Uint256, *response.Error) {
|
||||
var hash util.Uint256
|
||||
|
||||
if param == nil {
|
||||
return hash, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
switch param.Type {
|
||||
case request.StringT:
|
||||
var err error
|
||||
|
@ -406,11 +410,7 @@ func (s *Server) blockHashFromParam(param *request.Param) (util.Uint256, *respon
|
|||
}
|
||||
|
||||
func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
param := reqParams.Value(0)
|
||||
hash, respErr := s.blockHashFromParam(param)
|
||||
if respErr != nil {
|
||||
return nil, respErr
|
||||
|
@ -430,8 +430,8 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro
|
|||
}
|
||||
|
||||
func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.NumberT)
|
||||
if !ok {
|
||||
param := reqParams.ValueWithType(0, request.NumberT)
|
||||
if param == nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
num, err := s.blockHeightFromParam(param)
|
||||
|
@ -472,8 +472,8 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error)
|
|||
}
|
||||
|
||||
func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
param := reqParams.Value(0)
|
||||
if param == nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
return validateAddress(param.Value), nil
|
||||
|
@ -481,12 +481,7 @@ func (s *Server) validateAddress(reqParams request.Params) (interface{}, *respon
|
|||
|
||||
// getApplicationLog returns the contract log based on the specified txid.
|
||||
func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
txHash, err := param.GetUint256()
|
||||
txHash, err := reqParams.Value(0).GetUint256()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -500,11 +495,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *resp
|
|||
}
|
||||
|
||||
func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
u, err := p.GetUint160FromHex()
|
||||
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -533,11 +524,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Erro
|
|||
}
|
||||
|
||||
func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
u, err := p.GetUint160FromAddress()
|
||||
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -636,6 +623,9 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6
|
|||
|
||||
func (s *Server) contractIDFromParam(param *request.Param) (int32, *response.Error) {
|
||||
var result int32
|
||||
if param == nil {
|
||||
return 0, response.ErrInvalidParams
|
||||
}
|
||||
switch param.Type {
|
||||
case request.StringT:
|
||||
var err error
|
||||
|
@ -661,11 +651,7 @@ func (s *Server) contractIDFromParam(param *request.Param) (int32, *response.Err
|
|||
}
|
||||
|
||||
func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
|
||||
param, ok := ps.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
id, rErr := s.contractIDFromParam(param)
|
||||
id, rErr := s.contractIDFromParam(ps.Value(0))
|
||||
if rErr == response.ErrUnknown {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -673,12 +659,7 @@ func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) {
|
|||
return nil, rErr
|
||||
}
|
||||
|
||||
param, ok = ps.Value(1)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
key, err := param.GetBytesHex()
|
||||
key, err := ps.Value(1).GetBytesHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -695,9 +676,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
|
|||
var resultsErr *response.Error
|
||||
var results interface{}
|
||||
|
||||
if param0, ok := reqParams.Value(0); !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
} else if txHash, err := param0.GetUint256(); err != nil {
|
||||
if txHash, err := reqParams.Value(0).GetUint256(); err != nil {
|
||||
resultsErr = response.ErrInvalidParams
|
||||
} else if tx, height, err := s.chain.GetTransaction(txHash); err != nil {
|
||||
err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash)
|
||||
|
@ -709,7 +688,10 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
|
|||
resultsErr = response.NewInvalidParamsError(err.Error(), err)
|
||||
}
|
||||
|
||||
param1, _ := reqParams.Value(1)
|
||||
param1 := reqParams.Value(1)
|
||||
if param1 == nil {
|
||||
param1 = &request.Param{}
|
||||
}
|
||||
switch v := param1.Value.(type) {
|
||||
|
||||
case int, float64, bool, string:
|
||||
|
@ -729,12 +711,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
|
|||
}
|
||||
|
||||
func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
h, err := p.GetUint256()
|
||||
h, err := ps.Value(0).GetUint256()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -749,27 +726,21 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response
|
|||
|
||||
// getContractState returns contract state (contract information, according to the contract script hash).
|
||||
func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) {
|
||||
var results interface{}
|
||||
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
} else if scriptHash, err := param.GetUint160FromHex(); err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
} else {
|
||||
}
|
||||
cs := s.chain.GetContractState(scriptHash)
|
||||
if cs == nil {
|
||||
return nil, response.NewRPCError("Unknown contract", "", nil)
|
||||
}
|
||||
results = cs
|
||||
}
|
||||
return results, nil
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
// getBlockSysFee returns the system fees of the block, based on the specified index.
|
||||
func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.NumberT)
|
||||
if !ok {
|
||||
param := reqParams.ValueWithType(0, request.NumberT)
|
||||
if param == nil {
|
||||
return 0, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
|
@ -796,22 +767,14 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *respons
|
|||
func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) {
|
||||
var verbose bool
|
||||
|
||||
param, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
||||
param := reqParams.Value(0)
|
||||
hash, respErr := s.blockHashFromParam(param)
|
||||
if respErr != nil {
|
||||
return nil, respErr
|
||||
}
|
||||
|
||||
param, ok = reqParams.ValueWithType(1, request.NumberT)
|
||||
if ok {
|
||||
v, err := param.GetInt()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
v, err := reqParams.ValueWithType(1, request.NumberT).GetInt()
|
||||
if err == nil {
|
||||
verbose = v != 0
|
||||
}
|
||||
|
||||
|
@ -834,11 +797,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons
|
|||
|
||||
// getUnclaimedGas returns unclaimed GAS amount of the specified address.
|
||||
func (s *Server) getUnclaimedGas(ps request.Params) (interface{}, *response.Error) {
|
||||
p, ok := ps.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
u, err := p.GetUint160FromAddress()
|
||||
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -876,11 +835,7 @@ func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error)
|
|||
|
||||
// invokeFunction implements the `invokeFunction` RPC call.
|
||||
func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) {
|
||||
scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
scriptHash, err := scriptHashHex.GetUint160FromHex()
|
||||
scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -941,11 +896,7 @@ func (s *Server) runScriptInVM(script []byte, tx *transaction.Transaction) *resu
|
|||
|
||||
// submitBlock broadcasts a raw block over the NEO network.
|
||||
func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) {
|
||||
param, ok := reqParams.ValueWithType(0, request.StringT)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
blockBytes, err := param.GetBytesHex()
|
||||
blockBytes, err := reqParams.ValueWithType(0, request.StringT).GetBytesHex()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -1004,11 +955,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res
|
|||
|
||||
// subscribe handles subscription requests from websocket clients.
|
||||
func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
|
||||
p, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
streamName, err := p.GetString()
|
||||
streamName, err := reqParams.Value(0).GetString()
|
||||
if err != nil {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
@ -1018,8 +965,7 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface
|
|||
}
|
||||
// Optional filter.
|
||||
var filter interface{}
|
||||
p, ok = reqParams.Value(1)
|
||||
if ok {
|
||||
if p := reqParams.Value(1); p != nil {
|
||||
switch event {
|
||||
case response.BlockEventID:
|
||||
if p.Type != request.BlockFilterT {
|
||||
|
@ -1093,11 +1039,7 @@ func (s *Server) subscribeToChannel(event response.EventID) {
|
|||
|
||||
// unsubscribe handles unsubscription requests from websocket clients.
|
||||
func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) {
|
||||
p, ok := reqParams.Value(0)
|
||||
if !ok {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
id, err := p.GetInt()
|
||||
id, err := reqParams.Value(0).GetInt()
|
||||
if err != nil || id < 0 {
|
||||
return nil, response.ErrInvalidParams
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue