From 35f952e44f82de0878a8b17ed9da7c271a2c20ef Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 4 Jun 2020 14:58:47 +0300 Subject: [PATCH] rpc/server: simplify errors handling during parameter parsing Forward-ported from 2.x with some updates. --- pkg/rpc/client/wsclient_test.go | 24 +++--- pkg/rpc/request/param.go | 30 +++++-- pkg/rpc/request/params.go | 15 ++-- pkg/rpc/server/server.go | 140 ++++++++++---------------------- 4 files changed, 82 insertions(+), 127 deletions(-) diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 50fed7ccf..09e331d8e 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -183,8 +183,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.BlockFilterT, param.Type) filt, ok := param.Value.(request.BlockFilter) require.Equal(t, true, ok) @@ -198,8 +198,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.TxFilterT, param.Type) filt, ok := param.Value.(request.TxFilter) require.Equal(t, true, ok) @@ -214,8 +214,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.TxFilterT, param.Type) filt, ok := param.Value.(request.TxFilter) require.Equal(t, true, ok) @@ -231,8 +231,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.TxFilterT, param.Type) filt, ok := param.Value.(request.TxFilter) require.Equal(t, true, ok) @@ -247,8 +247,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.NotificationFilterT, param.Type) filt, ok := param.Value.(request.NotificationFilter) require.Equal(t, true, ok) @@ -262,8 +262,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.ExecutionFilterT, param.Type) filt, ok := param.Value.(request.ExecutionFilter) require.Equal(t, true, ok) diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index 87af906fc..7df182043 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -69,12 +69,17 @@ const ( Cosigner ) +var errMissingParameter = errors.New("parameter is missing") + func (p Param) String() string { return fmt.Sprintf("%v", p.Value) } // GetString returns string value of the parameter. -func (p Param) GetString() (string, error) { +func (p *Param) GetString() (string, error) { + if p == nil { + return "", errMissingParameter + } str, ok := p.Value.(string) if !ok { return "", errors.New("not a string") @@ -83,7 +88,10 @@ func (p Param) GetString() (string, error) { } // GetInt returns int value of te parameter. -func (p Param) GetInt() (int, error) { +func (p *Param) GetInt() (int, error) { + if p == nil { + return 0, errMissingParameter + } i, ok := p.Value.(int) if ok { return i, nil @@ -94,7 +102,10 @@ func (p Param) GetInt() (int, error) { } // GetArray returns a slice of Params stored in the parameter. -func (p Param) GetArray() ([]Param, error) { +func (p *Param) GetArray() ([]Param, error) { + if p == nil { + return nil, errMissingParameter + } a, ok := p.Value.([]Param) if !ok { return nil, errors.New("not an array") @@ -103,7 +114,7 @@ func (p Param) GetArray() ([]Param, error) { } // GetUint256 returns Uint256 value of the parameter. -func (p Param) GetUint256() (util.Uint256, error) { +func (p *Param) GetUint256() (util.Uint256, error) { s, err := p.GetString() if err != nil { return util.Uint256{}, err @@ -113,7 +124,7 @@ func (p Param) GetUint256() (util.Uint256, error) { } // GetUint160FromHex returns Uint160 value of the parameter encoded in hex. -func (p Param) GetUint160FromHex() (util.Uint160, error) { +func (p *Param) GetUint160FromHex() (util.Uint160, error) { s, err := p.GetString() if err != nil { return util.Uint160{}, err @@ -127,7 +138,7 @@ func (p Param) GetUint160FromHex() (util.Uint160, error) { // GetUint160FromAddress returns Uint160 value of the parameter that was // supplied as an address. -func (p Param) GetUint160FromAddress() (util.Uint160, error) { +func (p *Param) GetUint160FromAddress() (util.Uint160, error) { s, err := p.GetString() if err != nil { return util.Uint160{}, err @@ -137,7 +148,10 @@ func (p Param) GetUint160FromAddress() (util.Uint160, error) { } // GetFuncParam returns current parameter as a function call parameter. -func (p Param) GetFuncParam() (FuncParam, error) { +func (p *Param) GetFuncParam() (FuncParam, error) { + if p == nil { + return FuncParam{}, errMissingParameter + } fp, ok := p.Value.(FuncParam) if !ok { return FuncParam{}, errors.New("not a function parameter") @@ -147,7 +161,7 @@ func (p Param) GetFuncParam() (FuncParam, error) { // GetBytesHex returns []byte value of the parameter if // it is a hex-encoded string. -func (p Param) GetBytesHex() ([]byte, error) { +func (p *Param) GetBytesHex() ([]byte, error) { s, err := p.GetString() if err != nil { return nil, err diff --git a/pkg/rpc/request/params.go b/pkg/rpc/request/params.go index 8b1945cb1..dd2ac35b9 100644 --- a/pkg/rpc/request/params.go +++ b/pkg/rpc/request/params.go @@ -7,20 +7,19 @@ type ( // Value returns the param struct for the given // index if it exists. -func (p Params) Value(index int) (*Param, bool) { +func (p Params) Value(index int) *Param { if len(p) > index { - return &p[index], true + return &p[index] } - return nil, false + return nil } // ValueWithType returns the param struct at the given index if it // exists and matches the given type. -func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) { - if val, ok := p.Value(index); ok && val.Type == valType { - return val, true +func (p Params) ValueWithType(index int, valType paramType) *Param { + if val := p.Value(index); val != nil && val.Type == valType { + return val } - - return nil, false + return nil } diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index d9299574d..5fce6c7be 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -386,6 +386,10 @@ func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Er func (s *Server) blockHashFromParam(param *request.Param) (util.Uint256, *response.Error) { var hash util.Uint256 + if param == nil { + return hash, response.ErrInvalidParams + } + switch param.Type { case request.StringT: var err error @@ -406,11 +410,7 @@ func (s *Server) blockHashFromParam(param *request.Param) (util.Uint256, *respon } func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - + param := reqParams.Value(0) hash, respErr := s.blockHashFromParam(param) if respErr != nil { return nil, respErr @@ -430,8 +430,8 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro } func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.NumberT) - if !ok { + param := reqParams.ValueWithType(0, request.NumberT) + if param == nil { return nil, response.ErrInvalidParams } num, err := s.blockHeightFromParam(param) @@ -472,8 +472,8 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error) } func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { + param := reqParams.Value(0) + if param == nil { return nil, response.ErrInvalidParams } return validateAddress(param.Value), nil @@ -481,12 +481,7 @@ func (s *Server) validateAddress(reqParams request.Params) (interface{}, *respon // getApplicationLog returns the contract log based on the specified txid. func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - txHash, err := param.GetUint256() + txHash, err := reqParams.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -500,11 +495,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *resp } func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromHex() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } @@ -533,11 +524,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Erro } func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromAddress() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams } @@ -636,6 +623,9 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 func (s *Server) contractIDFromParam(param *request.Param) (int32, *response.Error) { var result int32 + if param == nil { + return 0, response.ErrInvalidParams + } switch param.Type { case request.StringT: var err error @@ -661,11 +651,7 @@ func (s *Server) contractIDFromParam(param *request.Param) (int32, *response.Err } func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { - param, ok := ps.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - id, rErr := s.contractIDFromParam(param) + id, rErr := s.contractIDFromParam(ps.Value(0)) if rErr == response.ErrUnknown { return nil, nil } @@ -673,12 +659,7 @@ func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { return nil, rErr } - param, ok = ps.Value(1) - if !ok { - return nil, response.ErrInvalidParams - } - - key, err := param.GetBytesHex() + key, err := ps.Value(1).GetBytesHex() if err != nil { return nil, response.ErrInvalidParams } @@ -695,9 +676,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp var resultsErr *response.Error var results interface{} - if param0, ok := reqParams.Value(0); !ok { - return nil, response.ErrInvalidParams - } else if txHash, err := param0.GetUint256(); err != nil { + if txHash, err := reqParams.Value(0).GetUint256(); err != nil { resultsErr = response.ErrInvalidParams } else if tx, height, err := s.chain.GetTransaction(txHash); err != nil { err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash) @@ -709,7 +688,10 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp resultsErr = response.NewInvalidParamsError(err.Error(), err) } - param1, _ := reqParams.Value(1) + param1 := reqParams.Value(1) + if param1 == nil { + param1 = &request.Param{} + } switch v := param1.Value.(type) { case int, float64, bool, string: @@ -729,12 +711,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp } func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - h, err := p.GetUint256() + h, err := ps.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -749,27 +726,21 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response // getContractState returns contract state (contract information, according to the contract script hash). func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) { - var results interface{} - - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { + scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex() + if err != nil { return nil, response.ErrInvalidParams - } else if scriptHash, err := param.GetUint160FromHex(); err != nil { - return nil, response.ErrInvalidParams - } else { - cs := s.chain.GetContractState(scriptHash) - if cs == nil { - return nil, response.NewRPCError("Unknown contract", "", nil) - } - results = cs } - return results, nil + cs := s.chain.GetContractState(scriptHash) + if cs == nil { + return nil, response.NewRPCError("Unknown contract", "", nil) + } + return cs, nil } // getBlockSysFee returns the system fees of the block, based on the specified index. func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.NumberT) - if !ok { + param := reqParams.ValueWithType(0, request.NumberT) + if param == nil { return 0, response.ErrInvalidParams } @@ -796,22 +767,14 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *respons func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) { var verbose bool - param, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - + param := reqParams.Value(0) hash, respErr := s.blockHashFromParam(param) if respErr != nil { return nil, respErr } - param, ok = reqParams.ValueWithType(1, request.NumberT) - if ok { - v, err := param.GetInt() - if err != nil { - return nil, response.ErrInvalidParams - } + v, err := reqParams.ValueWithType(1, request.NumberT).GetInt() + if err == nil { verbose = v != 0 } @@ -834,11 +797,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons // getUnclaimedGas returns unclaimed GAS amount of the specified address. func (s *Server) getUnclaimedGas(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromAddress() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams } @@ -876,11 +835,7 @@ func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error) // invokeFunction implements the `invokeFunction` RPC call. func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) { - scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - scriptHash, err := scriptHashHex.GetUint160FromHex() + scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } @@ -941,11 +896,7 @@ func (s *Server) runScriptInVM(script []byte, tx *transaction.Transaction) *resu // submitBlock broadcasts a raw block over the NEO network. func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - blockBytes, err := param.GetBytesHex() + blockBytes, err := reqParams.ValueWithType(0, request.StringT).GetBytesHex() if err != nil { return nil, response.ErrInvalidParams } @@ -1004,11 +955,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res // subscribe handles subscription requests from websocket clients. func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { - p, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - streamName, err := p.GetString() + streamName, err := reqParams.Value(0).GetString() if err != nil { return nil, response.ErrInvalidParams } @@ -1018,8 +965,7 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface } // Optional filter. var filter interface{} - p, ok = reqParams.Value(1) - if ok { + if p := reqParams.Value(1); p != nil { switch event { case response.BlockEventID: if p.Type != request.BlockFilterT { @@ -1093,11 +1039,7 @@ func (s *Server) subscribeToChannel(event response.EventID) { // unsubscribe handles unsubscription requests from websocket clients. func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { - p, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - id, err := p.GetInt() + id, err := reqParams.Value(0).GetInt() if err != nil || id < 0 { return nil, response.ErrInvalidParams }