rpc: make parameter type an enum

Signed-off-by: Evgenii Stratonikov <evgeniy@nspcc.ru>
This commit is contained in:
Evgenii Stratonikov 2019-11-21 16:42:51 +03:00
parent d2bdae99e4
commit 7331127556
3 changed files with 105 additions and 117 deletions

View file

@ -1,7 +1,11 @@
package rpc
import (
"encoding/json"
"fmt"
"github.com/CityOfZion/neo-go/pkg/util"
"github.com/pkg/errors"
)
type (
@ -9,13 +13,56 @@ type (
// the server or to send to a server using
// the client.
Param struct {
StringVal string
IntVal int
Type string
RawValue interface{}
Type paramType
Value interface{}
}
paramType int
)
const (
defaultT paramType = iota
stringT
numberT
)
func (p Param) String() string {
return fmt.Sprintf("%v", p.RawValue)
return fmt.Sprintf("%v", p.Value)
}
// GetString returns string value of the parameter.
func (p Param) GetString() string { return p.Value.(string) }
// GetInt returns int value of te parameter.
func (p Param) GetInt() int { return p.Value.(int) }
// GetUint256 returns Uint256 value of the parameter.
func (p Param) GetUint256() (util.Uint256, error) {
s, ok := p.Value.(string)
if !ok {
return util.Uint256{}, errors.New("must be a string")
}
return util.Uint256DecodeReverseString(s)
}
// UnmarshalJSON implements json.Unmarshaler interface.
func (p *Param) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err == nil {
p.Type = stringT
p.Value = s
return nil
}
var num float64
if err := json.Unmarshal(data, &num); err == nil {
p.Type = numberT
p.Value = int(num)
return nil
}
return errors.New("unknown type")
}

View file

@ -1,49 +1,13 @@
package rpc
import (
"encoding/json"
)
type (
// Params represents the JSON-RPC params.
Params []Param
)
// UnmarshalJSON implements the Unmarshaller
// interface.
func (p *Params) UnmarshalJSON(data []byte) error {
var params []interface{}
err := json.Unmarshal(data, &params)
if err != nil {
return err
}
for i := 0; i < len(params); i++ {
param := Param{
RawValue: params[i],
}
switch val := params[i].(type) {
case string:
param.StringVal = val
param.Type = "string"
case float64:
newVal, _ := params[i].(float64)
param.IntVal = int(newVal)
param.Type = "number"
}
*p = append(*p, param)
}
return nil
}
// ValueAt returns the param struct for the given
// Value returns the param struct for the given
// index if it exists.
func (p Params) ValueAt(index int) (*Param, bool) {
func (p Params) Value(index int) (*Param, bool) {
if len(p) > index {
return &p[index], true
}
@ -51,33 +15,12 @@ func (p Params) ValueAt(index int) (*Param, bool) {
return nil, false
}
// ValueAtAndType returns the param struct at the given index if it
// ValueWithType returns the param struct at the given index if it
// exists and matches the given type.
func (p Params) ValueAtAndType(index int, valueType string) (*Param, bool) {
if len(p) > index && valueType == p[index].Type {
return &p[index], true
func (p Params) ValueWithType(index int, valType paramType) (*Param, bool) {
if val, ok := p.Value(index); ok && val.Type == valType {
return val, true
}
return nil, false
}
// Value returns the param struct for the given
// index if it exists.
func (p Params) Value(index int) (*Param, error) {
if len(p) <= index {
return nil, errInvalidParams
}
return &p[index], nil
}
// ValueWithType returns the param struct at the given index if it
// exists and matches the given type.
func (p Params) ValueWithType(index int, valType string) (*Param, error) {
val, err := p.Value(index)
if err != nil {
return nil, err
} else if val.Type != valType {
return nil, errInvalidParams
}
return &p[index], nil
}

View file

@ -30,11 +30,9 @@ type (
}
)
var (
invalidBlockHeightError = func(index int, height int) error {
var invalidBlockHeightError = func(index int, height int) error {
return errors.Errorf("Param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height)
}
)
}
// NewServer creates a new Server struct.
func NewServer(chain core.Blockchainer, conf config.RPCConfig, coreServer *network.Server) Server {
@ -123,27 +121,28 @@ Methods:
getbestblockCalled.Inc()
var hash util.Uint256
param, err := reqParams.Value(0)
if err != nil {
resultsErr = err
param, ok := reqParams.Value(0)
if !ok {
resultsErr = errInvalidParams
break Methods
}
switch param.Type {
case "string":
hash, err = util.Uint256DecodeReverseString(param.StringVal)
case stringT:
var err error
hash, err = param.GetUint256()
if err != nil {
resultsErr = errInvalidParams
break Methods
}
case "number":
case numberT:
if !s.validBlockHeight(param) {
resultsErr = errInvalidParams
break Methods
}
hash = s.chain.GetHeaderHash(param.IntVal)
case "default":
hash = s.chain.GetHeaderHash(param.GetInt())
case defaultT:
resultsErr = errInvalidParams
break Methods
}
@ -161,16 +160,16 @@ Methods:
case "getblockhash":
getblockHashCalled.Inc()
param, err := reqParams.ValueWithType(0, "number")
if err != nil {
resultsErr = err
param, ok := reqParams.ValueWithType(0, numberT)
if !ok {
resultsErr = errInvalidParams
break Methods
} else if !s.validBlockHeight(param) {
resultsErr = invalidBlockHeightError(0, param.IntVal)
resultsErr = invalidBlockHeightError(0, param.GetInt())
break Methods
}
results = s.chain.GetHeaderHash(param.IntVal)
results = s.chain.GetHeaderHash(param.GetInt())
case "getconnectioncount":
getconnectioncountCalled.Inc()
@ -203,22 +202,22 @@ Methods:
case "validateaddress":
validateaddressCalled.Inc()
param, err := reqParams.Value(0)
if err != nil {
resultsErr = err
param, ok := reqParams.Value(0)
if !ok {
resultsErr = errInvalidParams
break Methods
}
results = wrappers.ValidateAddress(param.RawValue)
results = wrappers.ValidateAddress(param.Value)
case "getassetstate":
getassetstateCalled.Inc()
param, err := reqParams.ValueWithType(0, "string")
if err != nil {
resultsErr = err
param, ok := reqParams.ValueWithType(0, stringT)
if !ok {
resultsErr = errInvalidParams
break Methods
}
paramAssetID, err := util.Uint256DecodeReverseString(param.StringVal)
paramAssetID, err := param.GetUint256()
if err != nil {
resultsErr = errInvalidParams
break
@ -266,14 +265,13 @@ func (s *Server) getrawtransaction(reqParams Params) (interface{}, error) {
var resultsErr error
var results interface{}
param0, err := reqParams.ValueWithType(0, "string")
if err != nil {
resultsErr = err
} else if txHash, err := util.Uint256DecodeReverseString(param0.StringVal); err != nil {
if param0, ok := reqParams.Value(0); !ok {
return nil, errInvalidParams
} else if txHash, err := param0.GetUint256(); err != nil {
resultsErr = errInvalidParams
} else if tx, height, err := s.chain.GetTransaction(txHash); err != nil {
err = errors.Wrapf(err, "Invalid transaction hash: %s", txHash)
resultsErr = NewInvalidParamsError(err.Error(), err)
return nil, NewInvalidParamsError(err.Error(), err)
} else if len(reqParams) >= 2 {
_header := s.chain.GetHeaderHash(int(height))
header, err := s.chain.GetHeader(_header)
@ -281,8 +279,8 @@ func (s *Server) getrawtransaction(reqParams Params) (interface{}, error) {
resultsErr = NewInvalidParamsError(err.Error(), err)
}
param1, _ := reqParams.ValueAt(1)
switch v := param1.RawValue.(type) {
param1, _ := reqParams.Value(1)
switch v := param1.Value.(type) {
case int, float64, bool, string:
if v == 0 || v == "0" || v == 0.0 || v == false || v == "false" {
@ -305,14 +303,14 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{},
var resultsErr error
var results interface{}
param, err := reqParams.ValueWithType(0, "string")
if err != nil {
resultsErr = err
} else if scriptHash, err := crypto.Uint160DecodeAddress(param.StringVal); err != nil {
resultsErr = errInvalidParams
param, ok := reqParams.ValueWithType(0, stringT)
if !ok {
return nil, errInvalidParams
} else if scriptHash, err := crypto.Uint160DecodeAddress(param.GetString()); err != nil {
return nil, errInvalidParams
} else if as := s.chain.GetAccountState(scriptHash); as != nil {
if unspents {
results = wrappers.NewUnspents(as, s.chain, param.StringVal)
results = wrappers.NewUnspents(as, s.chain, param.GetString())
} else {
results = wrappers.NewAccountState(as)
}
@ -324,11 +322,11 @@ func (s *Server) getAccountState(reqParams Params, unspents bool) (interface{},
// invokescript implements the `invokescript` RPC call.
func (s *Server) invokescript(reqParams Params) (interface{}, error) {
hexScript, err := reqParams.ValueWithType(0, "string")
if err != nil {
return nil, err
hexScript, ok := reqParams.ValueWithType(0, stringT)
if !ok {
return nil, errInvalidParams
}
script, err := hex.DecodeString(hexScript.StringVal)
script, err := hex.DecodeString(hexScript.GetString())
if err != nil {
return nil, err
}
@ -338,7 +336,7 @@ func (s *Server) invokescript(reqParams Params) (interface{}, error) {
result := &wrappers.InvokeResult{
State: vm.State(),
GasConsumed: "0.1",
Script: hexScript.StringVal,
Script: hexScript.GetString(),
Stack: vm.Estack(),
}
return result, nil
@ -348,10 +346,10 @@ func (s *Server) sendrawtransaction(reqParams Params) (interface{}, error) {
var resultsErr error
var results interface{}
param, err := reqParams.ValueWithType(0, "string")
if err != nil {
resultsErr = err
} else if byteTx, err := hex.DecodeString(param.StringVal); err != nil {
param, ok := reqParams.ValueWithType(0, stringT)
if !ok {
resultsErr = errInvalidParams
} else if byteTx, err := hex.DecodeString(param.GetString()); err != nil {
resultsErr = errInvalidParams
} else {
r := io.NewBinReaderFromBuf(byteTx)
@ -387,5 +385,5 @@ func (s *Server) sendrawtransaction(reqParams Params) (interface{}, error) {
}
func (s Server) validBlockHeight(param *Param) bool {
return param.IntVal >= 0 && param.IntVal <= int(s.chain.BlockHeight())
return param.GetInt() >= 0 && param.GetInt() <= int(s.chain.BlockHeight())
}