rpc: make parameter type an enum

Signed-off-by: Evgenii Stratonikov <evgeniy@nspcc.ru>
This commit is contained in:
Evgenii Stratonikov 2019-11-21 16:42:51 +03:00
parent d2bdae99e4
commit 7331127556
3 changed files with 105 additions and 117 deletions

View file

@ -1,7 +1,11 @@
package rpc package rpc
import ( import (
"encoding/json"
"fmt" "fmt"
"github.com/CityOfZion/neo-go/pkg/util"
"github.com/pkg/errors"
) )
type ( type (
@ -9,13 +13,56 @@ type (
// the server or to send to a server using // the server or to send to a server using
// the client. // the client.
Param struct { Param struct {
StringVal string Type paramType
IntVal int Value interface{}
Type string
RawValue interface{}
} }
paramType int
)
const (
defaultT paramType = iota
stringT
numberT
) )
func (p Param) String() string { func (p Param) String() string {
return fmt.Sprintf("%v", p.RawValue) return fmt.Sprintf("%v", p.Value)
}
// GetString returns string value of the parameter.
func (p Param) GetString() string { return p.Value.(string) }
// GetInt returns int value of te parameter.
func (p Param) GetInt() int { return p.Value.(int) }
// GetUint256 returns Uint256 value of the parameter.
func (p Param) GetUint256() (util.Uint256, error) {
s, ok := p.Value.(string)
if !ok {
return util.Uint256{}, errors.New("must be a string")
}
return util.Uint256DecodeReverseString(s)
}
// UnmarshalJSON implements json.Unmarshaler interface.
func (p *Param) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err == nil {
p.Type = stringT
p.Value = s
return nil
}
var num float64
if err := json.Unmarshal(data, &num); err == nil {
p.Type = numberT
p.Value = int(num)
return nil
}
return errors.New("unknown type")
} }

View file

@ -1,49 +1,13 @@
package rpc package rpc
import (
"encoding/json"
)
type ( type (
// Params represents the JSON-RPC params. // Params represents the JSON-RPC params.
Params []Param Params []Param
) )
// UnmarshalJSON implements the Unmarshaller // Value returns the param struct for the given
// interface.
func (p *Params) UnmarshalJSON(data []byte) error {
var params []interface{}
err := json.Unmarshal(data, &params)
if err != nil {
return err
}
for i := 0; i < len(params); i++ {
param := Param{
RawValue: params[i],
}
switch val := params[i].(type) {
case string:
param.StringVal = val
param.Type = "string"
case float64:
newVal, _ := params[i].(float64)
param.IntVal = int(newVal)
param.Type = "number"
}
*p = append(*p, param)
}
return nil
}
// ValueAt returns the param struct for the given
// index if it exists. // index if it exists.
func (p Params) ValueAt(index int) (*Param, bool) { func (p Params) Value(index int) (*Param, bool) {
if len(p) > index { if len(p) > index {
return &p[index], true return &p[index], true
} }
@ -51,33 +15,12 @@ func (p Params) ValueAt(index int) (*Param, bool) {
return nil, false return nil, false
} }
// ValueAtAndType returns the param struct at the given index if it // ValueWithType returns the param struct at the given index if it
// exists and matches the given type. // exists and matches the given type.
func (p Params) ValueAtAndType(index int, valueType string) (*Param, bool) { func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) {
if len(p) > index && valueType == p[index].Type { if val, ok := p.Value(index); ok && val.Type == valType {
return &p[index], true return val, true
} }
return nil, false return nil, false
} }
// Value returns the param struct for the given
// index if it exists.
func (p Params) Value(index int) (*Param, error) {
if len(p) <= index {
return nil, errInvalidParams
}
return &p[index], nil
}
// ValueWithType returns the param struct at the given index if it
// exists and matches the given type.
func (p Params) ValueWithType(index int, valType string) (*Param, error) {
val, err := p.Value(index)
if err != nil {
return nil, err
} else if val.Type != valType {
return nil, errInvalidParams
}
return &p[index], nil
}

View file

@ -30,11 +30,9 @@ type (
} }
) )
var ( var invalidBlockHeightError = func(index int, height int) error {
invalidBlockHeightError = func(index int, height int) error { return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height)
return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height) }
}
)
// NewServer creates a new Server struct. // NewServer creates a new Server struct.
func NewServer(chain core.Blockchainer, conf config.RPCConfig, coreServer *network.Server) Server { func NewServer(chain core.Blockchainer, conf config.RPCConfig, coreServer *network.Server) Server {
@ -123,27 +121,28 @@ Methods:
getbestblockCalled.Inc() getbestblockCalled.Inc()
var hash util.Uint256 var hash util.Uint256
param, err := reqParams.Value(0) param, ok := reqParams.Value(0)
if err != nil { if !ok {
resultsErr = err resultsErr = errInvalidParams
break Methods break Methods
} }
switch param.Type { switch param.Type {
case "string": case stringT:
hash, err = util.Uint256DecodeReverseString(param.StringVal) var err error
hash, err = param.GetUint256()
if err != nil { if err != nil {
resultsErr = errInvalidParams resultsErr = errInvalidParams
break Methods break Methods
} }
case "number": case numberT:
if !s.validBlockHeight(param) { if !s.validBlockHeight(param) {
resultsErr = errInvalidParams resultsErr = errInvalidParams
break Methods break Methods
} }
hash = s.chain.GetHeaderHash(param.IntVal) hash = s.chain.GetHeaderHash(param.GetInt())
case "default": case defaultT:
resultsErr = errInvalidParams resultsErr = errInvalidParams
break Methods break Methods
} }
@ -161,16 +160,16 @@ Methods:
case "getblockhash": case "getblockhash":
getblockHashCalled.Inc() getblockHashCalled.Inc()
param, err := reqParams.ValueWithType(0, "number") param, ok := reqParams.ValueWithType(0, numberT)
if err != nil { if !ok {
resultsErr = err resultsErr = errInvalidParams
break Methods break Methods
} else if !s.validBlockHeight(param) { } else if !s.validBlockHeight(param) {
resultsErr = invalidBlockHeightError(0, param.IntVal) resultsErr = invalidBlockHeightError(0, param.GetInt())
break Methods break Methods
} }
results = s.chain.GetHeaderHash(param.IntVal) results = s.chain.GetHeaderHash(param.GetInt())
case "getconnectioncount": case "getconnectioncount":
getconnectioncountCalled.Inc() getconnectioncountCalled.Inc()
@ -203,22 +202,22 @@ Methods:
case "validateaddress": case "validateaddress":
validateaddressCalled.Inc() validateaddressCalled.Inc()
param, err := reqParams.Value(0) param, ok := reqParams.Value(0)
if err != nil { if !ok {
resultsErr = err resultsErr = errInvalidParams
break Methods break Methods
} }
results = wrappers.ValidateAddress(param.RawValue) results = wrappers.ValidateAddress(param.Value)
case "getassetstate": case "getassetstate":
getassetstateCalled.Inc() getassetstateCalled.Inc()
param, err := reqParams.ValueWithType(0, "string") param, ok := reqParams.ValueWithType(0, stringT)
if err != nil { if !ok {
resultsErr = err resultsErr = errInvalidParams
break Methods break Methods
} }
paramAssetID, err := util.Uint256DecodeReverseString(param.StringVal) paramAssetID, err := param.GetUint256()
if err != nil { if err != nil {
resultsErr = errInvalidParams resultsErr = errInvalidParams
break break
@ -266,14 +265,13 @@ func (s *Server) getrawtransaction(reqParams Params) (interface{}, error) {
var resultsErr error var resultsErr error
var results interface{} var results interface{}
param0, err := reqParams.ValueWithType(0, "string") if param0, ok := reqParams.Value(0); !ok {
if err != nil { return nil, errInvalidParams
resultsErr = err } else if txHash, err := param0.GetUint256(); err != nil {
} else if txHash, err := util.Uint256DecodeReverseString(param0.StringVal); err != nil {
resultsErr = errInvalidParams resultsErr = errInvalidParams
} else if tx, height, err := s.chain.GetTransaction(txHash); err != nil { } else if tx, height, err := s.chain.GetTransaction(txHash); err != nil {
err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash) err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash)
resultsErr = NewInvalidParamsError(err.Error(), err) return nil, NewInvalidParamsError(err.Error(), err)
} else if len(reqParams) >= 2 { } else if len(reqParams) >= 2 {
_header := s.chain.GetHeaderHash(int(height)) _header := s.chain.GetHeaderHash(int(height))
header, err := s.chain.GetHeader(_header) header, err := s.chain.GetHeader(_header)
@ -281,8 +279,8 @@ func (s *Server) getrawtransaction(reqParams Params) (interface{}, error) {
resultsErr = NewInvalidParamsError(err.Error(), err) resultsErr = NewInvalidParamsError(err.Error(), err)
} }
param1, _ := reqParams.ValueAt(1) param1, _ := reqParams.Value(1)
switch v := param1.RawValue.(type) { switch v := param1.Value.(type) {
case int, float64, bool, string: case int, float64, bool, string:
if v == 0 || v == "0" || v == 0.0 || v == false || v == "false" { if v == 0 || v == "0" || v == 0.0 || v == false || v == "false" {
@ -305,14 +303,14 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{},
var resultsErr error var resultsErr error
var results interface{} var results interface{}
param, err := reqParams.ValueWithType(0, "string") param, ok := reqParams.ValueWithType(0, stringT)
if err != nil { if !ok {
resultsErr = err return nil, errInvalidParams
} else if scriptHash, err := crypto.Uint160DecodeAddress(param.StringVal); err != nil { } else if scriptHash, err := crypto.Uint160DecodeAddress(param.GetString()); err != nil {
resultsErr = errInvalidParams return nil, errInvalidParams
} else if as := s.chain.GetAccountState(scriptHash); as != nil { } else if as := s.chain.GetAccountState(scriptHash); as != nil {
if unspents { if unspents {
results = wrappers.NewUnspents(as, s.chain, param.StringVal) results = wrappers.NewUnspents(as, s.chain, param.GetString())
} else { } else {
results = wrappers.NewAccountState(as) results = wrappers.NewAccountState(as)
} }
@ -324,11 +322,11 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{},
// invokescript implements the `invokescript` RPC call. // invokescript implements the `invokescript` RPC call.
func (s *Server) invokescript(reqParams Params) (interface{}, error) { func (s *Server) invokescript(reqParams Params) (interface{}, error) {
hexScript, err := reqParams.ValueWithType(0, "string") hexScript, ok := reqParams.ValueWithType(0, stringT)
if err != nil { if !ok {
return nil, err return nil, errInvalidParams
} }
script, err := hex.DecodeString(hexScript.StringVal) script, err := hex.DecodeString(hexScript.GetString())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -338,7 +336,7 @@ func (s *Server) invokescript(reqParams Params) (interface{}, error) {
result := &wrappers.InvokeResult{ result := &wrappers.InvokeResult{
State: vm.State(), State: vm.State(),
GasConsumed: "0.1", GasConsumed: "0.1",
Script: hexScript.StringVal, Script: hexScript.GetString(),
Stack: vm.Estack(), Stack: vm.Estack(),
} }
return result, nil return result, nil
@ -348,10 +346,10 @@ func (s *Server) sendrawtransaction(reqParams Params) (interface{}, error) {
var resultsErr error var resultsErr error
var results interface{} var results interface{}
param, err := reqParams.ValueWithType(0, "string") param, ok := reqParams.ValueWithType(0, stringT)
if err != nil { if !ok {
resultsErr = err resultsErr = errInvalidParams
} else if byteTx, err := hex.DecodeString(param.StringVal); err != nil { } else if byteTx, err := hex.DecodeString(param.GetString()); err != nil {
resultsErr = errInvalidParams resultsErr = errInvalidParams
} else { } else {
r := io.NewBinReaderFromBuf(byteTx) r := io.NewBinReaderFromBuf(byteTx)
@ -387,5 +385,5 @@ func (s *Server) sendrawtransaction(reqParams Params) (interface{}, error) {
} }
func (s Server) validBlockHeight(param *Param) bool { func (s Server) validBlockHeight(param *Param) bool {
return param.IntVal >= 0 && param.IntVal <= int(s.chain.BlockHeight()) return param.GetInt() >= 0 && param.GetInt() <= int(s.chain.BlockHeight())
} }