diff --git a/pkg/rpc/client/rpc.go b/pkg/rpc/client/rpc.go index 897420b1f..32bdbceb6 100644 --- a/pkg/rpc/client/rpc.go +++ b/pkg/rpc/client/rpc.go @@ -9,7 +9,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/rpc/request" @@ -513,27 +512,25 @@ func (c *Client) CalculateValidUntilBlock() (uint32, error) { } // AddNetworkFee adds network fee for each witness script and optional extra -// network fee to transaction. -func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, acc *wallet.Account) error { - size := io.GetVarSize(tx) - if acc.Contract != nil { - netFee, sizeDelta := core.CalculateNetworkFee(acc.Contract.Script) - tx.NetworkFee += netFee - size += sizeDelta +// network fee to transaction. `accs` is an array signer's accounts. +func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, accs ...*wallet.Account) error { + if len(tx.Signers) != len(accs) { + return errors.New("number of signers must match number of scripts") } - for _, cosigner := range tx.Signers { - script := acc.Contract.Script - if !cosigner.Account.Equals(hash.Hash160(acc.Contract.Script)) { + size := io.GetVarSize(tx) + for i, cosigner := range tx.Signers { + if accs[i].Contract.Script == nil { contract, err := c.GetContractState(cosigner.Account) - if err != nil { - return err + if err == nil { + if contract == nil { + continue + } + netFee, sizeDelta := core.CalculateNetworkFee(contract.Script) + tx.NetworkFee += netFee + size += sizeDelta } - if contract == nil { - continue - } - script = contract.Script } - netFee, sizeDelta := core.CalculateNetworkFee(script) + netFee, sizeDelta := core.CalculateNetworkFee(accs[i].Contract.Script) tx.NetworkFee += netFee size += sizeDelta } diff --git a/pkg/rpc/server/client_test.go b/pkg/rpc/server/client_test.go index ca4b9750e..9bcff6063 100644 --- a/pkg/rpc/server/client_test.go +++ b/pkg/rpc/server/client_test.go @@ -5,9 +5,16 @@ import ( "testing" "github.com/nspcc-dev/neo-go/pkg/config/netmode" + "github.com/nspcc-dev/neo-go/pkg/core" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/internal/testchain" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/rpc/client" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/require" ) @@ -57,3 +64,70 @@ func TestClient_NEP5(t *testing.T) { require.EqualValues(t, 877, b) }) } + +func TestAddNetworkFee(t *testing.T) { + chain, rpcSrv, httpSrv := initServerWithInMemoryChain(t) + defer chain.Close() + defer rpcSrv.Shutdown() + + c, err := client.New(context.Background(), httpSrv.URL, client.Options{Network: testchain.Network()}) + require.NoError(t, err) + + getAccounts := func(t *testing.T, n int) []*wallet.Account { + accs := make([]*wallet.Account, n) + var err error + for i := range accs { + accs[i], err = wallet.NewAccount() + require.NoError(t, err) + } + return accs + } + + feePerByte := chain.FeePerByte() + + t.Run("Invalid", func(t *testing.T) { + tx := transaction.New(testchain.Network(), []byte{byte(opcode.PUSH1)}, 0) + accs := getAccounts(t, 2) + tx.Signers = []transaction.Signer{{ + Account: accs[0].PrivateKey().GetScriptHash(), + Scopes: transaction.CalledByEntry, + }} + require.Error(t, c.AddNetworkFee(tx, 10, accs[0], accs[1])) + }) + t.Run("Simple", func(t *testing.T) { + tx := transaction.New(testchain.Network(), []byte{byte(opcode.PUSH1)}, 0) + accs := getAccounts(t, 1) + tx.Signers = []transaction.Signer{{ + Account: accs[0].PrivateKey().GetScriptHash(), + Scopes: transaction.CalledByEntry, + }} + require.NoError(t, c.AddNetworkFee(tx, 10, accs[0])) + require.NoError(t, accs[0].SignTx(tx)) + cFee, _ := core.CalculateNetworkFee(accs[0].Contract.Script) + require.Equal(t, int64(io.GetVarSize(tx))*feePerByte+cFee+10, tx.NetworkFee) + }) + + t.Run("Multi", func(t *testing.T) { + tx := transaction.New(testchain.Network(), []byte{byte(opcode.PUSH1)}, 0) + accs := getAccounts(t, 3) + pubs := keys.PublicKeys{accs[1].PrivateKey().PublicKey(), accs[2].PrivateKey().PublicKey()} + require.NoError(t, accs[1].ConvertMultisig(1, pubs)) + require.NoError(t, accs[2].ConvertMultisig(1, pubs)) + tx.Signers = []transaction.Signer{ + { + Account: accs[0].PrivateKey().GetScriptHash(), + Scopes: transaction.CalledByEntry, + }, + { + Account: hash.Hash160(accs[1].Contract.Script), + Scopes: transaction.Global, + }, + } + require.NoError(t, c.AddNetworkFee(tx, 10, accs[0], accs[1])) + require.NoError(t, accs[0].SignTx(tx)) + require.NoError(t, accs[1].SignTx(tx)) + cFee, _ := core.CalculateNetworkFee(accs[0].Contract.Script) + cFeeM, _ := core.CalculateNetworkFee(accs[1].Contract.Script) + require.Equal(t, int64(io.GetVarSize(tx))*feePerByte+cFee+cFeeM+10, tx.NetworkFee) + }) +}