From 7331127556e774ccfea79c4b9b0f2bb2d5af271f Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 21 Nov 2019 16:42:51 +0300 Subject: [PATCH] rpc: make parameter type an enum Signed-off-by: Evgenii Stratonikov --- pkg/rpc/param.go | 57 +++++++++++++++++++++++++--- pkg/rpc/params.go | 69 +++------------------------------- pkg/rpc/server.go | 96 +++++++++++++++++++++++------------------------ 3 files changed, 105 insertions(+), 117 deletions(-) diff --git a/pkg/rpc/param.go b/pkg/rpc/param.go index d86b2ef2f..52360830d 100644 --- a/pkg/rpc/param.go +++ b/pkg/rpc/param.go @@ -1,7 +1,11 @@ package rpc import ( + "encoding/json" "fmt" + + "github.com/CityOfZion/neo-go/pkg/util" + "github.com/pkg/errors" ) type ( @@ -9,13 +13,56 @@ type ( // the server or to send to a server using // the client. Param struct { - StringVal string - IntVal int - Type string - RawValue interface{} + Type paramType + Value interface{} } + + paramType int +) + +const ( + defaultT paramType = iota + stringT + numberT ) func (p Param) String() string { - return fmt.Sprintf("%v", p.RawValue) + return fmt.Sprintf("%v", p.Value) +} + +// GetString returns string value of the parameter. +func (p Param) GetString() string { return p.Value.(string) } + +// GetInt returns int value of te parameter. +func (p Param) GetInt() int { return p.Value.(int) } + +// GetUint256 returns Uint256 value of the parameter. +func (p Param) GetUint256() (util.Uint256, error) { + s, ok := p.Value.(string) + if !ok { + return util.Uint256{}, errors.New("must be a string") + } + + return util.Uint256DecodeReverseString(s) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (p *Param) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err == nil { + p.Type = stringT + p.Value = s + + return nil + } + + var num float64 + if err := json.Unmarshal(data, &num); err == nil { + p.Type = numberT + p.Value = int(num) + + return nil + } + + return errors.New("unknown type") } diff --git a/pkg/rpc/params.go b/pkg/rpc/params.go index bef402f0a..1affcf267 100644 --- a/pkg/rpc/params.go +++ b/pkg/rpc/params.go @@ -1,49 +1,13 @@ package rpc -import ( - "encoding/json" -) - type ( // Params represents the JSON-RPC params. Params []Param ) -// UnmarshalJSON implements the Unmarshaller -// interface. -func (p *Params) UnmarshalJSON(data []byte) error { - var params []interface{} - - err := json.Unmarshal(data, ¶ms) - if err != nil { - return err - } - - for i := 0; i < len(params); i++ { - param := Param{ - RawValue: params[i], - } - - switch val := params[i].(type) { - case string: - param.StringVal = val - param.Type = "string" - - case float64: - newVal, _ := params[i].(float64) - param.IntVal = int(newVal) - param.Type = "number" - } - - *p = append(*p, param) - } - - return nil -} - -// ValueAt returns the param struct for the given +// Value returns the param struct for the given // index if it exists. -func (p Params) ValueAt(index int) (*Param, bool) { +func (p Params) Value(index int) (*Param, bool) { if len(p) > index { return &p[index], true } @@ -51,33 +15,12 @@ func (p Params) ValueAt(index int) (*Param, bool) { return nil, false } -// ValueAtAndType returns the param struct at the given index if it +// ValueWithType returns the param struct at the given index if it // exists and matches the given type. -func (p Params) ValueAtAndType(index int, valueType string) (*Param, bool) { - if len(p) > index && valueType == p[index].Type { - return &p[index], true +func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) { + if val, ok := p.Value(index); ok && val.Type == valType { + return val, true } return nil, false } - -// Value returns the param struct for the given -// index if it exists. -func (p Params) Value(index int) (*Param, error) { - if len(p) <= index { - return nil, errInvalidParams - } - return &p[index], nil -} - -// ValueWithType returns the param struct at the given index if it -// exists and matches the given type. -func (p Params) ValueWithType(index int, valType string) (*Param, error) { - val, err := p.Value(index) - if err != nil { - return nil, err - } else if val.Type != valType { - return nil, errInvalidParams - } - return &p[index], nil -} diff --git a/pkg/rpc/server.go b/pkg/rpc/server.go index 8b521ed36..c9e730321 100644 --- a/pkg/rpc/server.go +++ b/pkg/rpc/server.go @@ -30,11 +30,9 @@ type ( } ) -var ( - invalidBlockHeightError = func(index int, height int) error { - return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height) - } -) +var invalidBlockHeightError = func(index int, height int) error { + return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height) +} // NewServer creates a new Server struct. func NewServer(chain core.Blockchainer, conf config.RPCConfig, coreServer *network.Server) Server { @@ -123,27 +121,28 @@ Methods: getbestblockCalled.Inc() var hash util.Uint256 - param, err := reqParams.Value(0) - if err != nil { - resultsErr = err + param, ok := reqParams.Value(0) + if !ok { + resultsErr = errInvalidParams break Methods } switch param.Type { - case "string": - hash, err = util.Uint256DecodeReverseString(param.StringVal) + case stringT: + var err error + hash, err = param.GetUint256() if err != nil { resultsErr = errInvalidParams break Methods } - case "number": + case numberT: if !s.validBlockHeight(param) { resultsErr = errInvalidParams break Methods } - hash = s.chain.GetHeaderHash(param.IntVal) - case "default": + hash = s.chain.GetHeaderHash(param.GetInt()) + case defaultT: resultsErr = errInvalidParams break Methods } @@ -161,16 +160,16 @@ Methods: case "getblockhash": getblockHashCalled.Inc() - param, err := reqParams.ValueWithType(0, "number") - if err != nil { - resultsErr = err + param, ok := reqParams.ValueWithType(0, numberT) + if !ok { + resultsErr = errInvalidParams break Methods } else if !s.validBlockHeight(param) { - resultsErr = invalidBlockHeightError(0, param.IntVal) + resultsErr = invalidBlockHeightError(0, param.GetInt()) break Methods } - results = s.chain.GetHeaderHash(param.IntVal) + results = s.chain.GetHeaderHash(param.GetInt()) case "getconnectioncount": getconnectioncountCalled.Inc() @@ -203,22 +202,22 @@ Methods: case "validateaddress": validateaddressCalled.Inc() - param, err := reqParams.Value(0) - if err != nil { - resultsErr = err + param, ok := reqParams.Value(0) + if !ok { + resultsErr = errInvalidParams break Methods } - results = wrappers.ValidateAddress(param.RawValue) + results = wrappers.ValidateAddress(param.Value) case "getassetstate": getassetstateCalled.Inc() - param, err := reqParams.ValueWithType(0, "string") - if err != nil { - resultsErr = err + param, ok := reqParams.ValueWithType(0, stringT) + if !ok { + resultsErr = errInvalidParams break Methods } - paramAssetID, err := util.Uint256DecodeReverseString(param.StringVal) + paramAssetID, err := param.GetUint256() if err != nil { resultsErr = errInvalidParams break @@ -266,14 +265,13 @@ func (s *Server) getrawtransaction(reqParams Params) (interface{}, error) { var resultsErr error var results interface{} - param0, err := reqParams.ValueWithType(0, "string") - if err != nil { - resultsErr = err - } else if txHash, err := util.Uint256DecodeReverseString(param0.StringVal); err != nil { + if param0, ok := reqParams.Value(0); !ok { + return nil, errInvalidParams + } else if txHash, err := param0.GetUint256(); err != nil { resultsErr = errInvalidParams } else if tx, height, err := s.chain.GetTransaction(txHash); err != nil { err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash) - resultsErr = NewInvalidParamsError(err.Error(), err) + return nil, NewInvalidParamsError(err.Error(), err) } else if len(reqParams) >= 2 { _header := s.chain.GetHeaderHash(int(height)) header, err := s.chain.GetHeader(_header) @@ -281,8 +279,8 @@ func (s *Server) getrawtransaction(reqParams Params) (interface{}, error) { resultsErr = NewInvalidParamsError(err.Error(), err) } - param1, _ := reqParams.ValueAt(1) - switch v := param1.RawValue.(type) { + param1, _ := reqParams.Value(1) + switch v := param1.Value.(type) { case int, float64, bool, string: if v == 0 || v == "0" || v == 0.0 || v == false || v == "false" { @@ -305,14 +303,14 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{}, var resultsErr error var results interface{} - param, err := reqParams.ValueWithType(0, "string") - if err != nil { - resultsErr = err - } else if scriptHash, err := crypto.Uint160DecodeAddress(param.StringVal); err != nil { - resultsErr = errInvalidParams + param, ok := reqParams.ValueWithType(0, stringT) + if !ok { + return nil, errInvalidParams + } else if scriptHash, err := crypto.Uint160DecodeAddress(param.GetString()); err != nil { + return nil, errInvalidParams } else if as := s.chain.GetAccountState(scriptHash); as != nil { if unspents { - results = wrappers.NewUnspents(as, s.chain, param.StringVal) + results = wrappers.NewUnspents(as, s.chain, param.GetString()) } else { results = wrappers.NewAccountState(as) } @@ -324,11 +322,11 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{}, // invokescript implements the `invokescript` RPC call. func (s *Server) invokescript(reqParams Params) (interface{}, error) { - hexScript, err := reqParams.ValueWithType(0, "string") - if err != nil { - return nil, err + hexScript, ok := reqParams.ValueWithType(0, stringT) + if !ok { + return nil, errInvalidParams } - script, err := hex.DecodeString(hexScript.StringVal) + script, err := hex.DecodeString(hexScript.GetString()) if err != nil { return nil, err } @@ -338,7 +336,7 @@ func (s *Server) invokescript(reqParams Params) (interface{}, error) { result := &wrappers.InvokeResult{ State: vm.State(), GasConsumed: "0.1", - Script: hexScript.StringVal, + Script: hexScript.GetString(), Stack: vm.Estack(), } return result, nil @@ -348,10 +346,10 @@ func (s *Server) sendrawtransaction(reqParams Params) (interface{}, error) { var resultsErr error var results interface{} - param, err := reqParams.ValueWithType(0, "string") - if err != nil { - resultsErr = err - } else if byteTx, err := hex.DecodeString(param.StringVal); err != nil { + param, ok := reqParams.ValueWithType(0, stringT) + if !ok { + resultsErr = errInvalidParams + } else if byteTx, err := hex.DecodeString(param.GetString()); err != nil { resultsErr = errInvalidParams } else { r := io.NewBinReaderFromBuf(byteTx) @@ -387,5 +385,5 @@ func (s *Server) sendrawtransaction(reqParams Params) (interface{}, error) { } func (s Server) validBlockHeight(param *Param) bool { - return param.IntVal >= 0 && param.IntVal <= int(s.chain.BlockHeight()) + return param.GetInt() >= 0 && param.GetInt() <= int(s.chain.BlockHeight()) }