From 7b53a0c239405c4641070f9b32f591173da03e8b Mon Sep 17 00:00:00 2001 From: Dmitrii Stepanov Date: Tue, 28 Nov 2023 09:23:55 +0300 Subject: [PATCH] neotest: Add contract signer support Signed-off-by: Dmitrii Stepanov --- pkg/core/native/native_test/ledger_test.go | 2 +- pkg/neotest/basic.go | 30 ++++++-- pkg/neotest/signer.go | 79 ++++++++++++++++++++++ 3 files changed, 104 insertions(+), 7 deletions(-) diff --git a/pkg/core/native/native_test/ledger_test.go b/pkg/core/native/native_test/ledger_test.go index 3af6d346a..a7d878aef 100644 --- a/pkg/core/native/native_test/ledger_test.go +++ b/pkg/core/native/native_test/ledger_test.go @@ -225,7 +225,7 @@ func TestLedger_GetTransactionSignersInteropAPI(t *testing.T) { }, }, }} - neotest.AddNetworkFee(e.Chain, tx, c.Committee) + neotest.AddNetworkFee(t, e.Chain, tx, c.Committee) neotest.AddSystemFee(e.Chain, tx, -1) require.NoError(t, c.Committee.SignTx(e.Chain.GetConfig().Magic, tx)) c.AddNewBlock(t, tx) diff --git a/pkg/neotest/basic.go b/pkg/neotest/basic.go index 4e5a7eb4d..c64a6b4a2 100644 --- a/pkg/neotest/basic.go +++ b/pkg/neotest/basic.go @@ -106,7 +106,7 @@ func (e *Executor) SignTx(t testing.TB, tx *transaction.Transaction, sysFee int6 Scopes: transaction.Global, }) } - AddNetworkFee(e.Chain, tx, signers...) + AddNetworkFee(t, e.Chain, tx, signers...) AddSystemFee(e.Chain, tx, sysFee) for _, acc := range signers { @@ -280,7 +280,7 @@ func NewDeployTxBy(t testing.TB, bc *core.Blockchain, signer Signer, c *Contract Account: signer.ScriptHash(), Scopes: transaction.Global, }} - AddNetworkFee(bc, tx, signer) + AddNetworkFee(t, bc, tx, signer) require.NoError(t, signer.SignTx(netmode.UnitTestNet, tx)) return tx } @@ -297,13 +297,31 @@ func AddSystemFee(bc *core.Blockchain, tx *transaction.Transaction, sysFee int64 } // AddNetworkFee adds network fee to the transaction. -func AddNetworkFee(bc *core.Blockchain, tx *transaction.Transaction, signers ...Signer) { +func AddNetworkFee(t testing.TB, bc *core.Blockchain, tx *transaction.Transaction, signers ...Signer) { baseFee := bc.GetBaseExecFee() size := io.GetVarSize(tx) for _, sgr := range signers { - netFee, sizeDelta := fee.Calculate(baseFee, sgr.Script()) - tx.NetworkFee += netFee - size += sizeDelta + if csgr, ok := sgr.(ContractSigner); ok { + sc, err := csgr.InvocationScript(tx) + require.NoError(t, err) + + txCopy := *tx + ic, err := bc.GetTestVM(trigger.Verification, &txCopy, nil) + require.NoError(t, err) + + ic.UseSigners(tx.Signers) + ic.VM.GasLimit = bc.GetMaxVerificationGAS() + + require.NoError(t, bc.InitVerificationContext(ic, csgr.ScriptHash(), &transaction.Witness{InvocationScript: sc, VerificationScript: csgr.Script()})) + require.NoError(t, ic.VM.Run()) + + tx.NetworkFee += ic.VM.GasConsumed() + size += io.GetVarSize(sc) + io.GetVarSize(csgr.Script()) + } else { + netFee, sizeDelta := fee.Calculate(baseFee, sgr.Script()) + tx.NetworkFee += netFee + size += sizeDelta + } } tx.NetworkFee += int64(size)*bc.FeePerByte() + bc.CalculateAttributesFee(tx) } diff --git a/pkg/neotest/signer.go b/pkg/neotest/signer.go index 45369101b..8446e177f 100644 --- a/pkg/neotest/signer.go +++ b/pkg/neotest/signer.go @@ -2,6 +2,7 @@ package neotest import ( "bytes" + "errors" "fmt" "sort" "testing" @@ -10,8 +11,10 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/require" @@ -44,6 +47,13 @@ type MultiSigner interface { Single(n int) SingleSigner } +// ContractSigner is an interface for contract signer. +type ContractSigner interface { + Signer + // InvocationScript returns an invocation script to be used as invocation script for contract-based witness. + InvocationScript(tx *transaction.Transaction) ([]byte, error) +} + // signer represents a simple-signature signer. type signer wallet.Account @@ -179,3 +189,72 @@ func checkMultiSigner(t testing.TB, s Signer) { require.Equal(t, h, accs[i].Contract.ScriptHash(), "inconsistent multi-signer accounts") } } + +type contractSigner struct { + params func(tx *transaction.Transaction) []any + scriptHash util.Uint160 +} + +// NewContractSigner returns a contract signer for the provided contract hash. +// getInvParams must return params to be used as invocation script for contract-based witness. +func NewContractSigner(h util.Uint160, getInvParams func(tx *transaction.Transaction) []any) ContractSigner { + return &contractSigner{ + scriptHash: h, + params: getInvParams, + } +} + +// InvocationScript implements ContractSigner. +func (s *contractSigner) InvocationScript(tx *transaction.Transaction) ([]byte, error) { + params := s.params(tx) + script := io.NewBufBinWriter() + for i := range params { + emit.Any(script.BinWriter, params[i]) + } + if script.Err != nil { + return nil, script.Err + } + return script.Bytes(), nil +} + +// Script implements ContractSigner. +func (s *contractSigner) Script() []byte { + return []byte{} +} + +// ScriptHash implements ContractSigner. +func (s *contractSigner) ScriptHash() util.Uint160 { + return s.scriptHash +} + +// SignHashable implements ContractSigner. +func (s *contractSigner) SignHashable(uint32, hash.Hashable) []byte { + panic("not supported") +} + +// SignTx implements ContractSigner. +func (s *contractSigner) SignTx(magic netmode.Magic, tx *transaction.Transaction) error { + pos := -1 + for idx := range tx.Signers { + if tx.Signers[idx].Account.Equals(s.ScriptHash()) { + pos = idx + break + } + } + if pos < 0 { + return fmt.Errorf("signer %s not found", s.ScriptHash().String()) + } + if len(tx.Scripts) < pos { + return errors.New("transaction is not yet signed by the previous signer") + } + invoc, err := s.InvocationScript(tx) + if err != nil { + return err + } + if len(tx.Scripts) == pos { + tx.Scripts = append(tx.Scripts, transaction.Witness{}) + } + tx.Scripts[pos].InvocationScript = invoc + tx.Scripts[pos].VerificationScript = s.Script() + return nil +}