rpc/client: provide scripts in AddNetworkFee

To calculate network fee properly we must know type of every
signer (simple, multisig, contract). Providing scripts is the most
simple and flexible way to know this.
This commit is contained in:
Evgenii Stratonikov 2020-08-17 11:12:21 +03:00
parent 58af143f25
commit 8699a4c1a9
2 changed files with 89 additions and 18 deletions

View file

@ -9,7 +9,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/rpc/request"
@ -513,27 +512,25 @@ func (c *Client) CalculateValidUntilBlock() (uint32, error) {
}
// AddNetworkFee adds network fee for each witness script and optional extra
// network fee to transaction.
func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, acc *wallet.Account) error {
size := io.GetVarSize(tx)
if acc.Contract != nil {
netFee, sizeDelta := core.CalculateNetworkFee(acc.Contract.Script)
tx.NetworkFee += netFee
size += sizeDelta
// network fee to transaction. `accs` is an array signer's accounts.
func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, accs ...*wallet.Account) error {
if len(tx.Signers) != len(accs) {
return errors.New("number of signers must match number of scripts")
}
for _, cosigner := range tx.Signers {
script := acc.Contract.Script
if !cosigner.Account.Equals(hash.Hash160(acc.Contract.Script)) {
size := io.GetVarSize(tx)
for i, cosigner := range tx.Signers {
if accs[i].Contract.Script == nil {
contract, err := c.GetContractState(cosigner.Account)
if err != nil {
return err
if err == nil {
if contract == nil {
continue
}
netFee, sizeDelta := core.CalculateNetworkFee(contract.Script)
tx.NetworkFee += netFee
size += sizeDelta
}
if contract == nil {
continue
}
script = contract.Script
}
netFee, sizeDelta := core.CalculateNetworkFee(script)
netFee, sizeDelta := core.CalculateNetworkFee(accs[i].Contract.Script)
tx.NetworkFee += netFee
size += sizeDelta
}

View file

@ -5,9 +5,16 @@ import (
"testing"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/internal/testchain"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/rpc/client"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/wallet"
"github.com/stretchr/testify/require"
)
@ -57,3 +64,70 @@ func TestClient_NEP5(t *testing.T) {
require.EqualValues(t, 877, b)
})
}
func TestAddNetworkFee(t *testing.T) {
chain, rpcSrv, httpSrv := initServerWithInMemoryChain(t)
defer chain.Close()
defer rpcSrv.Shutdown()
c, err := client.New(context.Background(), httpSrv.URL, client.Options{Network: testchain.Network()})
require.NoError(t, err)
getAccounts := func(t *testing.T, n int) []*wallet.Account {
accs := make([]*wallet.Account, n)
var err error
for i := range accs {
accs[i], err = wallet.NewAccount()
require.NoError(t, err)
}
return accs
}
feePerByte := chain.FeePerByte()
t.Run("Invalid", func(t *testing.T) {
tx := transaction.New(testchain.Network(), []byte{byte(opcode.PUSH1)}, 0)
accs := getAccounts(t, 2)
tx.Signers = []transaction.Signer{{
Account: accs[0].PrivateKey().GetScriptHash(),
Scopes: transaction.CalledByEntry,
}}
require.Error(t, c.AddNetworkFee(tx, 10, accs[0], accs[1]))
})
t.Run("Simple", func(t *testing.T) {
tx := transaction.New(testchain.Network(), []byte{byte(opcode.PUSH1)}, 0)
accs := getAccounts(t, 1)
tx.Signers = []transaction.Signer{{
Account: accs[0].PrivateKey().GetScriptHash(),
Scopes: transaction.CalledByEntry,
}}
require.NoError(t, c.AddNetworkFee(tx, 10, accs[0]))
require.NoError(t, accs[0].SignTx(tx))
cFee, _ := core.CalculateNetworkFee(accs[0].Contract.Script)
require.Equal(t, int64(io.GetVarSize(tx))*feePerByte+cFee+10, tx.NetworkFee)
})
t.Run("Multi", func(t *testing.T) {
tx := transaction.New(testchain.Network(), []byte{byte(opcode.PUSH1)}, 0)
accs := getAccounts(t, 3)
pubs := keys.PublicKeys{accs[1].PrivateKey().PublicKey(), accs[2].PrivateKey().PublicKey()}
require.NoError(t, accs[1].ConvertMultisig(1, pubs))
require.NoError(t, accs[2].ConvertMultisig(1, pubs))
tx.Signers = []transaction.Signer{
{
Account: accs[0].PrivateKey().GetScriptHash(),
Scopes: transaction.CalledByEntry,
},
{
Account: hash.Hash160(accs[1].Contract.Script),
Scopes: transaction.Global,
},
}
require.NoError(t, c.AddNetworkFee(tx, 10, accs[0], accs[1]))
require.NoError(t, accs[0].SignTx(tx))
require.NoError(t, accs[1].SignTx(tx))
cFee, _ := core.CalculateNetworkFee(accs[0].Contract.Script)
cFeeM, _ := core.CalculateNetworkFee(accs[1].Contract.Script)
require.Equal(t, int64(io.GetVarSize(tx))*feePerByte+cFee+cFeeM+10, tx.NetworkFee)
})
}