From 252e03bc34fa37c056d77ee8d4c77156c74cf09e Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Wed, 24 Mar 2021 20:32:48 +0300 Subject: [PATCH] rpc: add CalculateNetworkFee RPC method --- pkg/core/blockchainer/policer.go | 1 + pkg/rpc/client/rpc.go | 14 ++++++ pkg/rpc/server/client_test.go | 34 ++++++++++++-- pkg/rpc/server/server.go | 79 ++++++++++++++++++++++++++++++++ 4 files changed, 125 insertions(+), 3 deletions(-) diff --git a/pkg/core/blockchainer/policer.go b/pkg/core/blockchainer/policer.go index 19eef7545..29c3a3ca5 100644 --- a/pkg/core/blockchainer/policer.go +++ b/pkg/core/blockchainer/policer.go @@ -5,4 +5,5 @@ type Policer interface { GetBaseExecFee() int64 GetMaxVerificationGAS() int64 GetStoragePrice() int64 + FeePerByte() int64 } diff --git a/pkg/rpc/client/rpc.go b/pkg/rpc/client/rpc.go index 9465badb3..884482c32 100644 --- a/pkg/rpc/client/rpc.go +++ b/pkg/rpc/client/rpc.go @@ -28,6 +28,20 @@ import ( var errNetworkNotInitialized = errors.New("RPC client network is not initialized") +// CalculateNetworkFee calculates network fee for transaction. The transaction may +// have empty witnesses for contract signers and may have only verification scripts +// filled for standard sig/multisig signers. +func (c *Client) CalculateNetworkFee(tx *transaction.Transaction) (int64, error) { + var ( + params = request.NewRawParams(tx.Bytes()) + resp int64 + ) + if err := c.performRequest("calculatenetworkfee", params, &resp); err != nil { + return 0, err + } + return resp, nil +} + // GetApplicationLog returns the contract log based on the specified txid. func (c *Client) GetApplicationLog(hash util.Uint256, trig *trigger.Type) (*result.ApplicationLog, error) { var ( diff --git a/pkg/rpc/server/client_test.go b/pkg/rpc/server/client_test.go index 4fb084ac4..efa18d793 100644 --- a/pkg/rpc/server/client_test.go +++ b/pkg/rpc/server/client_test.go @@ -66,7 +66,7 @@ func TestClient_NEP17(t *testing.T) { }) } -func TestAddNetworkFee(t *testing.T) { +func TestAddNetworkFeeCalculateNetworkFee(t *testing.T) { chain, rpcSrv, httpSrv := initServerWithInMemoryChain(t) defer chain.Close() defer rpcSrv.Shutdown() @@ -110,6 +110,13 @@ func TestAddNetworkFee(t *testing.T) { tx.Nonce = nonce nonce++ + tx.Scripts = []transaction.Witness{ + {VerificationScript: acc0.GetVerificationScript()}, + } + actualCalculatedNetFee, err := c.CalculateNetworkFee(tx) + require.NoError(t, err) + + tx.Scripts = nil require.NoError(t, c.AddNetworkFee(tx, extraFee, acc0)) actual := tx.NetworkFee @@ -118,7 +125,8 @@ func TestAddNetworkFee(t *testing.T) { expected := int64(io.GetVarSize(tx))*feePerByte + cFee + extraFee require.Equal(t, expected, actual) - err := chain.VerifyTx(tx) + require.Equal(t, expected, actualCalculatedNetFee+extraFee) + err = chain.VerifyTx(tx) if extraFee < 0 { require.Error(t, err) } else { @@ -165,6 +173,15 @@ func TestAddNetworkFee(t *testing.T) { tx.Nonce = nonce nonce++ + tx.Scripts = []transaction.Witness{ + {VerificationScript: acc0.GetVerificationScript()}, + {VerificationScript: acc1.GetVerificationScript()}, + } + actualCalculatedNetFee, err := c.CalculateNetworkFee(tx) + require.NoError(t, err) + + tx.Scripts = nil + require.NoError(t, c.AddNetworkFee(tx, extraFee, acc0, acc1)) actual := tx.NetworkFee @@ -178,7 +195,8 @@ func TestAddNetworkFee(t *testing.T) { expected := int64(io.GetVarSize(tx))*feePerByte + cFee + cFeeM + extraFee require.Equal(t, expected, actual) - err := chain.VerifyTx(tx) + require.Equal(t, expected, actualCalculatedNetFee+extraFee) + err = chain.VerifyTx(tx) if extraFee < 0 { require.Error(t, err) } else { @@ -228,9 +246,19 @@ func TestAddNetworkFee(t *testing.T) { Scopes: transaction.Global, }, } + // we need to fill standard verification scripts to use CalculateNetworkFee. + tx.Scripts = []transaction.Witness{ + {VerificationScript: acc0.GetVerificationScript()}, + {}, + } + actual, err := c.CalculateNetworkFee(tx) + require.NoError(t, err) + tx.Scripts = nil + require.NoError(t, c.AddNetworkFee(tx, extraFee, acc0, acc1)) require.NoError(t, acc0.SignTx(testchain.Network(), tx)) tx.Scripts = append(tx.Scripts, transaction.Witness{}) + require.Equal(t, tx.NetworkFee, actual+extraFee) err = chain.VerifyTx(tx) if extraFee < 0 { require.Error(t, err) diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 3479fc9b4..8f10088ca 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -20,6 +20,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" + "github.com/nspcc-dev/neo-go/pkg/core/fee" "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -93,6 +94,7 @@ const ( ) var rpcHandlers = map[string]func(*Server, request.Params) (interface{}, *response.Error){ + "calculatenetworkfee": (*Server).calculateNetworkFee, "getapplicationlog": (*Server).getApplicationLog, "getbestblockhash": (*Server).getBestBlockHash, "getblock": (*Server).getBlock, @@ -549,6 +551,83 @@ func (s *Server) validateAddress(reqParams request.Params) (interface{}, *respon return validateAddress(param.Value), nil } +// calculateNetworkFee calculates network fee for the transaction. +func (s *Server) calculateNetworkFee(reqParams request.Params) (interface{}, *response.Error) { + if len(reqParams) < 1 { + return 0, response.ErrInvalidParams + } + byteTx, err := reqParams[0].GetBytesBase64() + if err != nil { + return 0, response.WrapErrorWithData(response.ErrInvalidParams, err) + } + tx, err := transaction.NewTransactionFromBytes(byteTx) + if err != nil { + return 0, response.WrapErrorWithData(response.ErrInvalidParams, err) + } + hashablePart, err := tx.EncodeHashableFields() + if err != nil { + return 0, response.WrapErrorWithData(response.ErrInvalidParams, fmt.Errorf("failed to compute tx size: %w", err)) + } + size := len(hashablePart) + io.GetVarSize(len(tx.Signers)) + var ( + ef int64 + netFee int64 + ) + for i, signer := range tx.Signers { + var verificationScript []byte + for _, w := range tx.Scripts { + if w.VerificationScript != nil && hash.Hash160(w.VerificationScript).Equals(signer.Account) { + // then it's a standard sig/multisig witness + verificationScript = w.VerificationScript + break + } + } + if verificationScript == nil { // then it still might be a contract-based verification + verificationErr := fmt.Sprintf("contract verification for signer #%d failed", i) + res, respErr := s.runScriptInVM(trigger.Verification, []byte{}, signer.Account, tx) + if respErr != nil && errors.Is(respErr.Cause, core.ErrUnknownVerificationContract) { + // it's neither a contract-based verification script nor a standard witness attached to + // the tx, so the user did not provide enough data to calculate fee for that witness => + // it's a user error + return 0, response.NewRPCError(verificationErr, respErr.Cause.Error(), respErr.Cause) + } + if respErr != nil { + return 0, respErr + } + if res.State != "HALT" { + cause := fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException) + return 0, response.NewRPCError(verificationErr, cause.Error(), cause) + } + if l := len(res.Stack); l != 1 { + cause := fmt.Errorf("result stack length should be equal to 1, got %d", l) + return 0, response.NewRPCError(verificationErr, cause.Error(), cause) + } + isOK, err := res.Stack[0].TryBool() + if err != nil { + cause := fmt.Errorf("resulting stackitem cannot be converted to Boolean: %w", err) + return 0, response.NewRPCError(verificationErr, cause.Error(), cause) + } + if !isOK { + cause := errors.New("`verify` method returned `false` on stack") + return 0, response.NewRPCError(verificationErr, cause.Error(), cause) + } + netFee += res.GasConsumed + size += io.GetVarSize([]byte{}) * 2 // both scripts are empty + continue + } + + if ef == 0 { + ef = s.chain.GetPolicer().GetBaseExecFee() + } + fee, sizeDelta := fee.Calculate(ef, verificationScript) + netFee += fee + size += sizeDelta + } + fee := s.chain.GetPolicer().FeePerByte() + netFee += int64(size) * fee + return netFee, nil +} + // getApplicationLog returns the contract log based on the specified txid or blockid. func (s *Server) getApplicationLog(reqParams request.Params) (interface{}, *response.Error) { hash, err := reqParams.Value(0).GetUint256()