Merge pull request #2234 from nspcc-dev/rpc/params-parsing

rpc: method-specific parameters parsing optimisation
This commit is contained in:
Roman Khimov 2021-11-10 20:45:44 +03:00 committed by GitHub
commit 6b8e615094
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 655 additions and 535 deletions

View file

@ -335,10 +335,7 @@ func (s *Server) handleIn(req *request.In, sub *subscriber) response.Abstract {
return s.packResponse(req, nil, response.NewInvalidParamsError("Problem parsing JSON", fmt.Errorf("invalid version, expected 2.0 got: '%s'", req.JSONRPC)))
}
reqParams, err := req.Params()
if err != nil {
return s.packResponse(req, nil, response.NewInvalidParamsError("Problem parsing request parameters", err))
}
reqParams := request.Params(req.RawParams)
s.log.Debug("processing rpc request",
zap.String("method", req.Method),
@ -349,11 +346,11 @@ func (s *Server) handleIn(req *request.In, sub *subscriber) response.Abstract {
resErr = response.NewMethodNotFoundError(fmt.Sprintf("Method '%s' not supported", req.Method), nil)
handler, ok := rpcHandlers[req.Method]
if ok {
res, resErr = handler(s, *reqParams)
res, resErr = handler(s, reqParams)
} else if sub != nil {
handler, ok := rpcWsHandlers[req.Method]
if ok {
res, resErr = handler(s, *reqParams, sub)
res, resErr = handler(s, reqParams, sub)
}
}
return s.packResponse(req, res, resErr)
@ -462,27 +459,20 @@ func (s *Server) getConnectionCount(_ request.Params) (interface{}, *response.Er
}
func (s *Server) blockHashFromParam(param *request.Param) (util.Uint256, *response.Error) {
var hash util.Uint256
var (
hash util.Uint256
err error
)
if param == nil {
return hash, response.ErrInvalidParams
}
switch param.Type {
case request.StringT:
var err error
hash, err = param.GetUint256()
if err != nil {
return hash, response.ErrInvalidParams
}
case request.NumberT:
num, err := s.blockHeightFromParam(param)
if err != nil {
return hash, response.ErrInvalidParams
if hash, err = param.GetUint256(); err != nil {
num, respErr := s.blockHeightFromParam(param)
if respErr != nil {
return hash, respErr
}
hash = s.chain.GetHeaderHash(num)
default:
return hash, response.ErrInvalidParams
}
return hash, nil
}
@ -499,7 +489,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro
return nil, response.NewInternalServerError(fmt.Sprintf("Problem locating block with hash: %s", hash), err)
}
if reqParams.Value(1).GetBoolean() {
if v, _ := reqParams.Value(1).GetBoolean(); v {
return result.NewBlock(block, s.chain), nil
}
writer := io.NewBufBinWriter()
@ -508,11 +498,7 @@ func (s *Server) getBlock(reqParams request.Params) (interface{}, *response.Erro
}
func (s *Server) getBlockHash(reqParams request.Params) (interface{}, *response.Error) {
param := reqParams.ValueWithType(0, request.NumberT)
if param == nil {
return nil, response.ErrInvalidParams
}
num, err := s.blockHeightFromParam(param)
num, err := s.blockHeightFromParam(reqParams.Value(0))
if err != nil {
return nil, response.ErrInvalidParams
}
@ -557,7 +543,7 @@ func (s *Server) getPeers(_ request.Params) (interface{}, *response.Error) {
}
func (s *Server) getRawMempool(reqParams request.Params) (interface{}, *response.Error) {
verbose := reqParams.Value(0).GetBoolean()
verbose, _ := reqParams.Value(0).GetBoolean()
mp := s.chain.GetMemPool()
hashList := make([]util.Uint256, 0)
for _, item := range mp.GetVerifiedTransactions() {
@ -574,11 +560,15 @@ func (s *Server) getRawMempool(reqParams request.Params) (interface{}, *response
}
func (s *Server) validateAddress(reqParams request.Params) (interface{}, *response.Error) {
param := reqParams.Value(0)
if param == nil {
param, err := reqParams.Value(0).GetString()
if err != nil {
return nil, response.ErrInvalidParams
}
return validateAddress(param.Value), nil
return result.ValidateAddress{
Address: reqParams.Value(0),
IsValid: validateAddress(param),
}, nil
}
// calculateNetworkFee calculates network fee for the transaction.
@ -644,11 +634,11 @@ func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *resp
trig := trigger.All
if len(reqParams) > 1 {
trigString := reqParams.ValueWithType(1, request.StringT)
if trigString == nil {
trigString, err := reqParams.Value(1).GetString()
if err != nil {
return nil, response.ErrInvalidParams
}
trig, err = trigger.FromString(trigString.String())
trig, err = trigger.FromString(trigString)
if err != nil {
return nil, response.ErrInvalidParams
}
@ -877,19 +867,13 @@ func (s *Server) contractIDFromParam(param *request.Param) (int32, *response.Err
if param == nil {
return 0, response.ErrInvalidParams
}
switch param.Type {
case request.StringT:
var err error
scriptHash, err := param.GetUint160FromHex()
if err != nil {
return 0, response.ErrInvalidParams
}
if scriptHash, err := param.GetUint160FromHex(); err == nil {
cs := s.chain.GetContractState(scriptHash)
if cs == nil {
return 0, response.ErrUnknown
}
result = cs.ID
case request.NumberT:
} else {
id, err := param.GetInt()
if err != nil {
return 0, response.ErrInvalidParams
@ -898,8 +882,6 @@ func (s *Server) contractIDFromParam(param *request.Param) (int32, *response.Err
return 0, response.WrapErrorWithData(response.ErrInvalidParams, err)
}
result = int32(id)
default:
return 0, response.ErrInvalidParams
}
return result, nil
}
@ -910,36 +892,29 @@ func (s *Server) contractScriptHashFromParam(param *request.Param) (util.Uint160
if param == nil {
return result, response.ErrInvalidParams
}
switch param.Type {
case request.StringT:
var err error
result, err = param.GetUint160FromAddressOrHex()
if err == nil {
return result, nil
}
name, err := param.GetString()
if err != nil {
return result, response.ErrInvalidParams
}
result, err = s.chain.GetNativeContractScriptHash(name)
if err != nil {
return result, response.NewRPCError("Unknown contract: querying by name is supported for native contracts only", "", nil)
}
case request.NumberT:
id, err := param.GetInt()
if err != nil {
return result, response.ErrInvalidParams
}
if err := checkInt32(id); err != nil {
return result, response.WrapErrorWithData(response.ErrInvalidParams, err)
}
result, err = s.chain.GetContractScriptHash(int32(id))
if err != nil {
return result, response.NewRPCError("Unknown contract", "", err)
}
default:
nameOrHashOrIndex, err := param.GetString()
if err != nil {
return result, response.ErrInvalidParams
}
result, err = param.GetUint160FromAddressOrHex()
if err == nil {
return result, nil
}
result, err = s.chain.GetNativeContractScriptHash(nameOrHashOrIndex)
if err == nil {
return result, nil
}
id, err := strconv.Atoi(nameOrHashOrIndex)
if err != nil {
return result, response.NewRPCError("Unknown contract", "", err)
}
if err := checkInt32(id); err != nil {
return result, response.WrapErrorWithData(response.ErrInvalidParams, err)
}
result, err = s.chain.GetContractScriptHash(int32(id))
if err != nil {
return result, response.NewRPCError("Unknown contract", "", err)
}
return result, nil
}
@ -1168,7 +1143,7 @@ func (s *Server) getStateRoot(ps request.Params) (interface{}, *response.Error)
}
var rt *state.MPTRoot
var h util.Uint256
height, err := p.GetInt()
height, err := p.GetIntStrict()
if err == nil {
if err := checkUint32(height); err != nil {
return nil, response.WrapErrorWithData(response.ErrInvalidParams, err)
@ -1219,7 +1194,7 @@ func (s *Server) getrawtransaction(reqParams request.Params) (interface{}, *resp
err = fmt.Errorf("invalid transaction %s: %w", txHash, err)
return nil, response.NewRPCError("Unknown transaction", err.Error(), err)
}
if reqParams.Value(1).GetBoolean() {
if v, _ := reqParams.Value(1).GetBoolean(); v {
if height == math.MaxUint32 {
return result.NewTransactionOutputRaw(tx, nil, nil, s.chain), nil
}
@ -1274,12 +1249,7 @@ func (s *Server) getNativeContracts(_ request.Params) (interface{}, *response.Er
// getBlockSysFee returns the system fees of the block, based on the specified index.
func (s *Server) getBlockSysFee(reqParams request.Params) (interface{}, *response.Error) {
param := reqParams.ValueWithType(0, request.NumberT)
if param == nil {
return 0, response.ErrInvalidParams
}
num, err := s.blockHeightFromParam(param)
num, err := s.blockHeightFromParam(reqParams.Value(0))
if err != nil {
return 0, response.NewRPCError("Invalid height", "", nil)
}
@ -1306,7 +1276,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons
return nil, respErr
}
verbose := reqParams.Value(1).GetBoolean()
verbose, _ := reqParams.Value(1).GetBoolean()
h, err := s.chain.GetHeader(hash)
if err != nil {
return nil, response.NewRPCError("unknown block", "", nil)
@ -1326,7 +1296,7 @@ func (s *Server) getBlockHeader(reqParams request.Params) (interface{}, *respons
// getUnclaimedGas returns unclaimed GAS amount of the specified address.
func (s *Server) getUnclaimedGas(ps request.Params) (interface{}, *response.Error) {
u, err := ps.ValueWithType(0, request.StringT).GetUint160FromAddressOrHex()
u, err := ps.Value(0).GetUint160FromAddressOrHex()
if err != nil {
return nil, response.ErrInvalidParams
}
@ -1398,7 +1368,11 @@ func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *respons
if len(tx.Signers) == 0 {
tx.Signers = []transaction.Signer{{Account: util.Uint160{}, Scopes: transaction.None}}
}
script, err := request.CreateFunctionInvocationScript(scriptHash, reqParams[1].String(), reqParams[2:checkWitnessHashesIndex])
method, err := reqParams[1].GetString()
if err != nil {
return nil, response.ErrInvalidParams
}
script, err := request.CreateFunctionInvocationScript(scriptHash, method, reqParams[2:checkWitnessHashesIndex])
if err != nil {
return nil, response.NewInternalServerError("can't create invocation script", err)
}
@ -1520,7 +1494,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash
// submitBlock broadcasts a raw block over the NEO network.
func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.Error) {
blockBytes, err := reqParams.ValueWithType(0, request.StringT).GetBytesBase64()
blockBytes, err := reqParams.Value(0).GetBytesBase64()
if err != nil {
return nil, response.NewInvalidParamsError("missing parameter or not base64", err)
}
@ -1550,10 +1524,7 @@ func (s *Server) submitNotaryRequest(ps request.Params) (interface{}, *response.
return nil, response.NewInternalServerError("P2PNotaryRequest was received, but P2PSignatureExtensions are disabled", nil)
}
if len(ps) < 1 {
return nil, response.NewInvalidParamsError("not enough parameters", nil)
}
bytePayload, err := ps[0].GetBytesBase64()
bytePayload, err := ps.Value(0).GetBytesBase64()
if err != nil {
return nil, response.NewInvalidParamsError("not base64", err)
}
@ -1645,34 +1616,27 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface
// Optional filter.
var filter interface{}
if p := reqParams.Value(1); p != nil {
param, ok := p.Value.(json.RawMessage)
if !ok {
return nil, response.ErrInvalidParams
}
jd := json.NewDecoder(bytes.NewReader(param))
param := *p
jd := json.NewDecoder(bytes.NewReader(param.RawMessage))
jd.DisallowUnknownFields()
switch event {
case response.BlockEventID:
flt := new(request.BlockFilter)
err = jd.Decode(flt)
p.Type = request.BlockFilterT
p.Value = *flt
filter = *flt
case response.TransactionEventID, response.NotaryRequestEventID:
flt := new(request.TxFilter)
err = jd.Decode(flt)
p.Type = request.TxFilterT
p.Value = *flt
filter = *flt
case response.NotificationEventID:
flt := new(request.NotificationFilter)
err = jd.Decode(flt)
p.Type = request.NotificationFilterT
p.Value = *flt
filter = *flt
case response.ExecutionEventID:
flt := new(request.ExecutionFilter)
err = jd.Decode(flt)
if err == nil && (flt.State == "HALT" || flt.State == "FAULT") {
p.Type = request.ExecutionFilterT
p.Value = *flt
filter = *flt
} else if err == nil {
err = errors.New("invalid state")
}
@ -1680,7 +1644,6 @@ func (s *Server) subscribe(reqParams request.Params, sub *subscriber) (interface
if err != nil {
return nil, response.ErrInvalidParams
}
filter = p.Value
}
s.subsLock.Lock()
@ -1912,7 +1875,7 @@ drainloop:
func (s *Server) blockHeightFromParam(param *request.Param) (int, *response.Error) {
num, err := param.GetInt()
if err != nil {
return 0, nil
return 0, response.ErrInvalidParams
}
if num < 0 || num > int(s.chain.BlockHeight()) {
@ -1946,10 +1909,8 @@ func (s *Server) logRequestError(r *request.Request, jsonErr *response.Error) {
if r.In != nil {
logFields = append(logFields, zap.String("method", r.In.Method))
params, err := r.In.Params()
if err == nil {
logFields = append(logFields, zap.Any("params", params))
}
params := request.Params(r.In.RawParams)
logFields = append(logFields, zap.Any("params", params))
}
s.log.Error("Error encountered with rpc request", logFields...)
@ -1996,11 +1957,10 @@ func (s *Server) writeHTTPServerResponse(r *request.Request, w http.ResponseWrit
// validateAddress verifies that the address is a correct NEO address
// see https://docs.neo.org/en-us/node/cli/2.9.4/api/validateaddress.html
func validateAddress(addr interface{}) result.ValidateAddress {
resp := result.ValidateAddress{Address: addr}
func validateAddress(addr interface{}) bool {
if addr, ok := addr.(string); ok {
_, err := address.StringToUint160(addr)
resp.IsValid = (err == nil)
return err == nil
}
return resp
return false
}