diff --git a/pkg/rpc/client/wsclient_test.go b/pkg/rpc/client/wsclient_test.go index 1eebe08bd..b548e74a3 100644 --- a/pkg/rpc/client/wsclient_test.go +++ b/pkg/rpc/client/wsclient_test.go @@ -181,8 +181,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.TxFilterT, param.Type) filt, ok := param.Value.(request.TxFilter) require.Equal(t, true, ok) @@ -196,8 +196,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.NotificationFilterT, param.Type) filt, ok := param.Value.(request.NotificationFilter) require.Equal(t, true, ok) @@ -211,8 +211,8 @@ func TestWSFilteredSubscriptions(t *testing.T) { require.NoError(t, err) }, func(t *testing.T, p *request.Params) { - param, ok := p.Value(1) - require.Equal(t, true, ok) + param := p.Value(1) + require.NotNil(t, param) require.Equal(t, request.ExecutionFilterT, param.Type) filt, ok := param.Value.(request.ExecutionFilter) require.Equal(t, true, ok) diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index 42159c336..5a961ffe9 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -61,12 +61,17 @@ const ( ExecutionFilterT ) +var errMissingParameter = errors.New("parameter is missing") + func (p Param) String() string { return fmt.Sprintf("%v", p.Value) } // GetString returns string value of the parameter. -func (p Param) GetString() (string, error) { +func (p *Param) GetString() (string, error) { + if p == nil { + return "", errMissingParameter + } str, ok := p.Value.(string) if !ok { return "", errors.New("not a string") @@ -75,7 +80,10 @@ func (p Param) GetString() (string, error) { } // GetInt returns int value of te parameter. -func (p Param) GetInt() (int, error) { +func (p *Param) GetInt() (int, error) { + if p == nil { + return 0, errMissingParameter + } i, ok := p.Value.(int) if ok { return i, nil @@ -86,7 +94,10 @@ func (p Param) GetInt() (int, error) { } // GetArray returns a slice of Params stored in the parameter. -func (p Param) GetArray() ([]Param, error) { +func (p *Param) GetArray() ([]Param, error) { + if p == nil { + return nil, errMissingParameter + } a, ok := p.Value.([]Param) if !ok { return nil, errors.New("not an array") @@ -95,7 +106,7 @@ func (p Param) GetArray() ([]Param, error) { } // GetUint256 returns Uint256 value of the parameter. -func (p Param) GetUint256() (util.Uint256, error) { +func (p *Param) GetUint256() (util.Uint256, error) { s, err := p.GetString() if err != nil { return util.Uint256{}, err @@ -105,7 +116,7 @@ func (p Param) GetUint256() (util.Uint256, error) { } // GetUint160FromHex returns Uint160 value of the parameter encoded in hex. -func (p Param) GetUint160FromHex() (util.Uint160, error) { +func (p *Param) GetUint160FromHex() (util.Uint160, error) { s, err := p.GetString() if err != nil { return util.Uint160{}, err @@ -119,7 +130,7 @@ func (p Param) GetUint160FromHex() (util.Uint160, error) { // GetUint160FromAddress returns Uint160 value of the parameter that was // supplied as an address. -func (p Param) GetUint160FromAddress() (util.Uint160, error) { +func (p *Param) GetUint160FromAddress() (util.Uint160, error) { s, err := p.GetString() if err != nil { return util.Uint160{}, err @@ -129,7 +140,10 @@ func (p Param) GetUint160FromAddress() (util.Uint160, error) { } // GetFuncParam returns current parameter as a function call parameter. -func (p Param) GetFuncParam() (FuncParam, error) { +func (p *Param) GetFuncParam() (FuncParam, error) { + if p == nil { + return FuncParam{}, errMissingParameter + } fp, ok := p.Value.(FuncParam) if !ok { return FuncParam{}, errors.New("not a function parameter") @@ -139,7 +153,7 @@ func (p Param) GetFuncParam() (FuncParam, error) { // GetBytesHex returns []byte value of the parameter if // it is a hex-encoded string. -func (p Param) GetBytesHex() ([]byte, error) { +func (p *Param) GetBytesHex() ([]byte, error) { s, err := p.GetString() if err != nil { return nil, err diff --git a/pkg/rpc/request/params.go b/pkg/rpc/request/params.go index 8b1945cb1..dd2ac35b9 100644 --- a/pkg/rpc/request/params.go +++ b/pkg/rpc/request/params.go @@ -7,20 +7,19 @@ type ( // Value returns the param struct for the given // index if it exists. -func (p Params) Value(index int) (*Param, bool) { +func (p Params) Value(index int) *Param { if len(p) > index { - return &p[index], true + return &p[index] } - return nil, false + return nil } // ValueWithType returns the param struct at the given index if it // exists and matches the given type. -func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) { - if val, ok := p.Value(index); ok && val.Type == valType { - return val, true +func (p Params) ValueWithType(index int, valType paramType) *Param { + if val := p.Value(index); val != nil && val.Type == valType { + return val } - - return nil, false + return nil } diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index a91502e5e..f62a368fe 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -389,8 +389,8 @@ func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Er func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Error) { var hash util.Uint256 - param, ok := reqParams.Value(0) - if !ok { + param := reqParams.Value(0) + if param == nil { return nil, response.ErrInvalidParams } @@ -425,8 +425,8 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro } func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.NumberT) - if !ok { + param := reqParams.ValueWithType(0, request.NumberT) + if param == nil { return nil, response.ErrInvalidParams } num, err := s.blockHeightFromParam(param) @@ -463,20 +463,15 @@ func (s *Server) getRawMempool(_ request.Params) (interface{}, *response.Error) } func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { + param := reqParams.Value(0) + if param == nil { return nil, response.ErrInvalidParams } return validateAddress(param.Value), nil } func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - - paramAssetID, err := param.GetUint256() + paramAssetID, err := reqParams.ValueWithType(0, request.StringT).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -490,12 +485,7 @@ func (s *Server) getAssetState(reqParams request.Params) (interface{}, *response // getApplicationLog returns the contract log based on the specified txid. func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - txHash, err := param.GetUint256() + txHash, err := reqParams.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -522,10 +512,7 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *resp } func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } + p := ps.ValueWithType(0, request.StringT) u, err := p.GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams @@ -574,11 +561,7 @@ func (s *Server) getClaimable(ps request.Params) (interface{}, *response.Error) } func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromHex() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } @@ -607,11 +590,7 @@ func (s *Server) getNEP5Balances(ps request.Params) (interface{}, *response.Erro } func (s *Server) getNEP5Transfers(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromAddress() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams } @@ -709,24 +688,14 @@ func (s *Server) getDecimals(h util.Uint160, cache map[util.Uint160]int64) (int6 } func (s *Server) getStorage(ps request.Params) (interface{}, *response.Error) { - param, ok := ps.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - scriptHash, err := param.GetUint160FromHex() + scriptHash, err := ps.Value(0).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } scriptHash = scriptHash.Reverse() - param, ok = ps.Value(1) - if !ok { - return nil, response.ErrInvalidParams - } - - key, err := param.GetBytesHex() + key, err := ps.Value(1).GetBytesHex() if err != nil { return nil, response.ErrInvalidParams } @@ -743,9 +712,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp var resultsErr *response.Error var results interface{} - if param0, ok := reqParams.Value(0); !ok { - return nil, response.ErrInvalidParams - } else if txHash, err := param0.GetUint256(); err != nil { + if txHash, err := reqParams.Value(0).GetUint256(); err != nil { resultsErr = response.ErrInvalidParams } else if tx, height, err := s.chain.GetTransaction(txHash); err != nil { err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash) @@ -757,7 +724,10 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp resultsErr = response.NewInvalidParamsError(err.Error(), err) } - param1, _ := reqParams.Value(1) + param1 := reqParams.Value(1) + if param1 == nil { + param1 = &request.Param{} + } switch v := param1.Value.(type) { case int, float64, bool, string: @@ -777,12 +747,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp } func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - h, err := p.GetUint256() + h, err := ps.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } @@ -796,22 +761,12 @@ func (s *Server) getTransactionHeight(ps request.Params) (interface{}, *response } func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - - h, err := p.GetUint256() + h, err := ps.Value(0).GetUint256() if err != nil { return nil, response.ErrInvalidParams } - p, ok = ps.ValueWithType(1, request.NumberT) - if !ok { - return nil, response.ErrInvalidParams - } - - num, err := p.GetInt() + num, err := ps.ValueWithType(1, request.NumberT).GetInt() if err != nil || num < 0 { return nil, response.ErrInvalidParams } @@ -833,18 +788,15 @@ func (s *Server) getTxOut(ps request.Params) (interface{}, *response.Error) { func (s *Server) getContractState(reqParams request.Params) (interface{}, *response.Error) { var results interface{} - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } else if scriptHash, err := param.GetUint160FromHex(); err != nil { + scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex() + if err != nil { return nil, response.ErrInvalidParams + } + cs := s.chain.GetContractState(scriptHash) + if cs != nil { + results = result.NewContractState(cs) } else { - cs := s.chain.GetContractState(scriptHash) - if cs != nil { - results = result.NewContractState(cs) - } else { - return nil, response.NewRPCError("Unknown contract", "", nil) - } + return nil, response.NewRPCError("Unknown contract", "", nil) } return results, nil } @@ -862,33 +814,31 @@ func (s *Server) getAccountStateAux(reqParams request.Params, unspents bool) (in var resultsErr *response.Error var results interface{} - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } else if scriptHash, err := param.GetUint160FromAddress(); err != nil { + param := reqParams.ValueWithType(0, request.StringT) + scriptHash, err := param.GetUint160FromAddress() + if err != nil { return nil, response.ErrInvalidParams + } + as := s.chain.GetAccountState(scriptHash) + if as == nil { + as = state.NewAccount(scriptHash) + } + if unspents { + str, err := param.GetString() + if err != nil { + return nil, response.ErrInvalidParams + } + results = result.NewUnspents(as, s.chain, str) } else { - as := s.chain.GetAccountState(scriptHash) - if as == nil { - as = state.NewAccount(scriptHash) - } - if unspents { - str, err := param.GetString() - if err != nil { - return nil, response.ErrInvalidParams - } - results = result.NewUnspents(as, s.chain, str) - } else { - results = result.NewAccountState(as) - } + results = result.NewAccountState(as) } return results, resultsErr } // getBlockSysFee returns the system fees of the block, based on the specified index. func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.NumberT) - if !ok { + param := reqParams.ValueWithType(0, request.NumberT) + if param == nil { return 0, response.ErrInvalidParams } @@ -915,21 +865,13 @@ func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *respons func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *response.Error) { var verbose bool - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - hash, err := param.GetUint256() + hash, err := reqParams.ValueWithType(0, request.StringT).GetUint256() if err != nil { return nil, response.ErrInvalidParams } - param, ok = reqParams.ValueWithType(1, request.NumberT) - if ok { - v, err := param.GetInt() - if err != nil { - return nil, response.ErrInvalidParams - } + v, err := reqParams.ValueWithType(1, request.NumberT).GetInt() + if err == nil { verbose = v != 0 } @@ -952,11 +894,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons // getUnclaimed returns unclaimed GAS amount of the specified address. func (s *Server) getUnclaimed(ps request.Params) (interface{}, *response.Error) { - p, ok := ps.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - u, err := p.GetUint160FromAddress() + u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddress() if err != nil { return nil, response.ErrInvalidParams } @@ -997,19 +935,11 @@ func (s *Server) getValidators(_ request.Params) (interface{}, *response.Error) // invoke implements the `invoke` RPC call. func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) { - scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - scriptHash, err := scriptHashHex.GetUint160FromHex() + scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } - sliceP, ok := reqParams.ValueWithType(1, request.ArrayT) - if !ok { - return nil, response.ErrInvalidParams - } - slice, err := sliceP.GetArray() + slice, err := reqParams.ValueWithType(1, request.ArrayT).GetArray() if err != nil { return nil, response.ErrInvalidParams } @@ -1022,11 +952,7 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) // invokescript implements the `invokescript` RPC call. func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *response.Error) { - scriptHashHex, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - scriptHash, err := scriptHashHex.GetUint160FromHex() + scriptHash, err := reqParams.ValueWithType(0, request.StringT).GetUint160FromHex() if err != nil { return nil, response.ErrInvalidParams } @@ -1069,11 +995,7 @@ func (s *Server) runScriptInVM(script []byte) *result.Invoke { // submitBlock broadcasts a raw block over the NEO network. func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) { - param, ok := reqParams.ValueWithType(0, request.StringT) - if !ok { - return nil, response.ErrInvalidParams - } - blockBytes, err := param.GetBytesHex() + blockBytes, err := reqParams.ValueWithType(0, request.StringT).GetBytesHex() if err != nil { return nil, response.ErrInvalidParams } @@ -1134,11 +1056,7 @@ func (s *Server) sendrawtransaction(reqParams request.Params) (interface{}, *res // subscribe handles subscription requests from websocket clients. func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { - p, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - streamName, err := p.GetString() + streamName, err := reqParams.Value(0).GetString() if err != nil { return nil, response.ErrInvalidParams } @@ -1148,8 +1066,7 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface } // Optional filter. var filter interface{} - p, ok = reqParams.Value(1) - if ok { + if p := reqParams.Value(1); p != nil { // It doesn't accept filters. if event == response.BlockEventID { return nil, response.ErrInvalidParams @@ -1224,11 +1141,7 @@ func (s *Server) subscribeToChannel(event response.EventID) { // unsubscribe handles unsubscription requests from websocket clients. func (s *Server) unsubscribe(reqParams request.Params, sub *subscriber) (interface{}, *response.Error) { - p, ok := reqParams.Value(0) - if !ok { - return nil, response.ErrInvalidParams - } - id, err := p.GetInt() + id, err := reqParams.Value(0).GetInt() if err != nil || id < 0 { return nil, response.ErrInvalidParams }