diff --git a/cli/wallet/validator.go b/cli/wallet/validator.go index 2e9a4c8be..fc03899cb 100644 --- a/cli/wallet/validator.go +++ b/cli/wallet/validator.go @@ -100,7 +100,7 @@ func handleCandidate(ctx *cli.Context, method string) error { w := io.NewBufBinWriter() emit.AppCallWithOperationAndArgs(w.BinWriter, client.NeoContractHash, method, acc.PrivateKey().PublicKey().Bytes()) emit.Opcode(w.BinWriter, opcode.ASSERT) - tx, err := c.CreateTxFromScript(w.Bytes(), acc, int64(gas)) + tx, err := c.CreateTxFromScript(w.Bytes(), acc, -1, int64(gas)) if err != nil { return cli.NewExitError(err, 1) } else if err = acc.SignTx(tx); err != nil { @@ -155,7 +155,7 @@ func handleVote(ctx *cli.Context) error { emit.AppCallWithOperationAndArgs(w.BinWriter, client.NeoContractHash, "vote", addr.BytesBE(), pubArg) emit.Opcode(w.BinWriter, opcode.ASSERT) - tx, err := c.CreateTxFromScript(w.Bytes(), acc, int64(gas)) + tx, err := c.CreateTxFromScript(w.Bytes(), acc, -1, int64(gas)) if err != nil { return cli.NewExitError(err, 1) } diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index b7a9a8747..1eae743fa 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -9,11 +9,9 @@ import ( "net" "net/http" "net/url" - "sync" "time" "github.com/nspcc-dev/neo-go/pkg/config/netmode" - "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/response" ) @@ -34,8 +32,6 @@ type Client struct { ctx context.Context opts Options requestF func(*request.Raw) (*response.Raw, error) - wifMu *sync.Mutex - wif *keys.WIF cache cache } @@ -96,7 +92,6 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { cl := &Client{ ctx: ctx, cli: httpClient, - wifMu: new(sync.Mutex), endpoint: url, } cl.opts = opts @@ -104,31 +99,6 @@ func New(ctx context.Context, endpoint string, opts Options) (*Client, error) { return cl, nil } -// WIF returns WIF structure associated with the client. -func (c *Client) WIF() keys.WIF { - c.wifMu.Lock() - defer c.wifMu.Unlock() - return keys.WIF{ - Version: c.wif.Version, - Compressed: c.wif.Compressed, - PrivateKey: c.wif.PrivateKey, - S: c.wif.S, - } -} - -// SetWIF decodes given WIF and adds some wallet -// data to client. Useful for RPC calls that require an open wallet. -func (c *Client) SetWIF(wif string) error { - c.wifMu.Lock() - defer c.wifMu.Unlock() - decodedWif, err := keys.WIFDecode(wif, 0x00) - if err != nil { - return fmt.Errorf("failed to decode WIF: %w", err) - } - c.wif = decodedWif - return nil -} - func (c *Client) performRequest(method string, p request.RawParams, v interface{}) error { var r = request.Raw{ JSONRPC: request.JSONRPCVersion, diff --git a/pkg/rpc/client/nep5.go b/pkg/rpc/client/nep5.go index 000376901..7a1f22cc5 100644 --- a/pkg/rpc/client/nep5.go +++ b/pkg/rpc/client/nep5.go @@ -134,39 +134,38 @@ func (c *Client) CreateNEP5MultiTransferTx(acc *wallet.Account, gas int64, recip recipients[i].Address, recipients[i].Amount) emit.Opcode(w.BinWriter, opcode.ASSERT) } - return c.CreateTxFromScript(w.Bytes(), acc, gas) + return c.CreateTxFromScript(w.Bytes(), acc, -1, gas) } // CreateTxFromScript creates transaction and properly sets cosigners and NetworkFee. -func (c *Client) CreateTxFromScript(script []byte, acc *wallet.Account, gas int64) (*transaction.Transaction, error) { +// If sysFee <= 0, it is determined via result of `invokescript` RPC. +func (c *Client) CreateTxFromScript(script []byte, acc *wallet.Account, sysFee, netFee int64, + cosigners ...transaction.Signer) (*transaction.Transaction, error) { from, err := address.StringToUint160(acc.Address) if err != nil { return nil, fmt.Errorf("bad account address: %v", err) } - result, err := c.InvokeScript(script, []transaction.Signer{ - { - Account: from, - Scopes: transaction.CalledByEntry, - }, - }) - if err != nil { - return nil, fmt.Errorf("can't add system fee to transaction: %w", err) - } - tx := transaction.New(c.opts.Network, script, result.GasConsumed) - tx.Signers = []transaction.Signer{ - { - Account: from, - Scopes: transaction.CalledByEntry, - }, - } - tx.ValidUntilBlock, err = c.CalculateValidUntilBlock() - if err != nil { - return nil, fmt.Errorf("can't calculate validUntilBlock: %w", err) + + signers := getSigners(from, cosigners) + if sysFee < 0 { + result, err := c.InvokeScript(script, signers) + if err != nil { + return nil, fmt.Errorf("can't add system fee to transaction: %w", err) + } + sysFee = result.GasConsumed } - err = c.AddNetworkFee(tx, gas, acc) + tx := transaction.New(c.opts.Network, script, sysFee) + tx.Signers = signers + + tx.ValidUntilBlock, err = c.CalculateValidUntilBlock() if err != nil { - return nil, fmt.Errorf("can't add network fee to transaction: %w", err) + return nil, fmt.Errorf("failed to add validUntilBlock to transaction: %w", err) + } + + err = c.AddNetworkFee(tx, netFee, acc) + if err != nil { + return nil, fmt.Errorf("failed to add network fee: %w", err) } return tx, nil diff --git a/pkg/rpc/client/rpc.go b/pkg/rpc/client/rpc.go index 897420b1f..f45363f6e 100644 --- a/pkg/rpc/client/rpc.go +++ b/pkg/rpc/client/rpc.go @@ -9,8 +9,6 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" - "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/rpc/request" "github.com/nspcc-dev/neo-go/pkg/rpc/response/result" @@ -411,26 +409,7 @@ func (c *Client) SignAndPushInvocationTx(script []byte, acc *wallet.Account, sys var txHash util.Uint256 var err error - tx := transaction.New(c.opts.Network, script, sysfee) - tx.SystemFee = sysfee - - addr, err := address.StringToUint160(acc.Address) - if err != nil { - return txHash, fmt.Errorf("failed to get address: %w", err) - } - tx.Signers = getSigners(addr, cosigners) - - validUntilBlock, err := c.CalculateValidUntilBlock() - if err != nil { - return txHash, fmt.Errorf("failed to add validUntilBlock to transaction: %w", err) - } - tx.ValidUntilBlock = validUntilBlock - - err = c.AddNetworkFee(tx, int64(netfee), acc) - if err != nil { - return txHash, fmt.Errorf("failed to add network fee: %w", err) - } - + tx, err := c.CreateTxFromScript(script, acc, sysfee, int64(netfee), cosigners...) if err = acc.SignTx(tx); err != nil { return txHash, fmt.Errorf("failed to sign tx: %w", err) } @@ -513,27 +492,25 @@ func (c *Client) CalculateValidUntilBlock() (uint32, error) { } // AddNetworkFee adds network fee for each witness script and optional extra -// network fee to transaction. -func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, acc *wallet.Account) error { - size := io.GetVarSize(tx) - if acc.Contract != nil { - netFee, sizeDelta := core.CalculateNetworkFee(acc.Contract.Script) - tx.NetworkFee += netFee - size += sizeDelta +// network fee to transaction. `accs` is an array signer's accounts. +func (c *Client) AddNetworkFee(tx *transaction.Transaction, extraFee int64, accs ...*wallet.Account) error { + if len(tx.Signers) != len(accs) { + return errors.New("number of signers must match number of scripts") } - for _, cosigner := range tx.Signers { - script := acc.Contract.Script - if !cosigner.Account.Equals(hash.Hash160(acc.Contract.Script)) { + size := io.GetVarSize(tx) + for i, cosigner := range tx.Signers { + if accs[i].Contract.Script == nil { contract, err := c.GetContractState(cosigner.Account) - if err != nil { - return err + if err == nil { + if contract == nil { + continue + } + netFee, sizeDelta := core.CalculateNetworkFee(contract.Script) + tx.NetworkFee += netFee + size += sizeDelta } - if contract == nil { - continue - } - script = contract.Script } - netFee, sizeDelta := core.CalculateNetworkFee(script) + netFee, sizeDelta := core.CalculateNetworkFee(accs[i].Contract.Script) tx.NetworkFee += netFee size += sizeDelta } diff --git a/pkg/rpc/server/client_test.go b/pkg/rpc/server/client_test.go index ca4b9750e..059a3c7b0 100644 --- a/pkg/rpc/server/client_test.go +++ b/pkg/rpc/server/client_test.go @@ -5,9 +5,16 @@ import ( "testing" "github.com/nspcc-dev/neo-go/pkg/config/netmode" + "github.com/nspcc-dev/neo-go/pkg/core" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/internal/testchain" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/rpc/client" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/require" ) @@ -57,3 +64,137 @@ func TestClient_NEP5(t *testing.T) { require.EqualValues(t, 877, b) }) } + +func TestAddNetworkFee(t *testing.T) { + chain, rpcSrv, httpSrv := initServerWithInMemoryChain(t) + defer chain.Close() + defer rpcSrv.Shutdown() + + c, err := client.New(context.Background(), httpSrv.URL, client.Options{Network: testchain.Network()}) + require.NoError(t, err) + + getAccounts := func(t *testing.T, n int) []*wallet.Account { + accs := make([]*wallet.Account, n) + var err error + for i := range accs { + accs[i], err = wallet.NewAccount() + require.NoError(t, err) + } + return accs + } + + feePerByte := chain.FeePerByte() + + t.Run("Invalid", func(t *testing.T) { + tx := transaction.New(testchain.Network(), []byte{byte(opcode.PUSH1)}, 0) + accs := getAccounts(t, 2) + tx.Signers = []transaction.Signer{{ + Account: accs[0].PrivateKey().GetScriptHash(), + Scopes: transaction.CalledByEntry, + }} + require.Error(t, c.AddNetworkFee(tx, 10, accs[0], accs[1])) + }) + t.Run("Simple", func(t *testing.T) { + tx := transaction.New(testchain.Network(), []byte{byte(opcode.PUSH1)}, 0) + accs := getAccounts(t, 1) + tx.Signers = []transaction.Signer{{ + Account: accs[0].PrivateKey().GetScriptHash(), + Scopes: transaction.CalledByEntry, + }} + require.NoError(t, c.AddNetworkFee(tx, 10, accs[0])) + require.NoError(t, accs[0].SignTx(tx)) + cFee, _ := core.CalculateNetworkFee(accs[0].Contract.Script) + require.Equal(t, int64(io.GetVarSize(tx))*feePerByte+cFee+10, tx.NetworkFee) + }) + + t.Run("Multi", func(t *testing.T) { + tx := transaction.New(testchain.Network(), []byte{byte(opcode.PUSH1)}, 0) + accs := getAccounts(t, 4) + pubs := keys.PublicKeys{accs[1].PrivateKey().PublicKey(), accs[2].PrivateKey().PublicKey(), accs[3].PrivateKey().PublicKey()} + require.NoError(t, accs[1].ConvertMultisig(2, pubs)) + require.NoError(t, accs[2].ConvertMultisig(2, pubs)) + tx.Signers = []transaction.Signer{ + { + Account: accs[0].PrivateKey().GetScriptHash(), + Scopes: transaction.CalledByEntry, + }, + { + Account: hash.Hash160(accs[1].Contract.Script), + Scopes: transaction.Global, + }, + } + require.NoError(t, c.AddNetworkFee(tx, 10, accs[0], accs[1])) + require.NoError(t, accs[0].SignTx(tx)) + require.NoError(t, accs[1].SignTx(tx)) + require.NoError(t, accs[2].SignTx(tx)) + cFee, _ := core.CalculateNetworkFee(accs[0].Contract.Script) + cFeeM, _ := core.CalculateNetworkFee(accs[1].Contract.Script) + require.Equal(t, int64(io.GetVarSize(tx))*feePerByte+cFee+cFeeM+10, tx.NetworkFee) + }) +} + +func TestSignAndPushInvocationTx(t *testing.T) { + chain, rpcSrv, httpSrv := initServerWithInMemoryChain(t) + defer chain.Close() + defer rpcSrv.Shutdown() + + c, err := client.New(context.Background(), httpSrv.URL, client.Options{Network: testchain.Network()}) + require.NoError(t, err) + + priv := testchain.PrivateKey(0) + acc, err := wallet.NewAccountFromWIF(priv.WIF()) + require.NoError(t, err) + h, err := c.SignAndPushInvocationTx([]byte{byte(opcode.PUSH1)}, acc, 30, 0, []transaction.Signer{{ + Account: priv.GetScriptHash(), + Scopes: transaction.CalledByEntry, + }}) + require.NoError(t, err) + + mp := chain.GetMemPool() + tx, ok := mp.TryGetValue(h) + require.True(t, ok) + require.Equal(t, h, tx.Hash()) + require.EqualValues(t, 30, tx.SystemFee) +} + +func TestPing(t *testing.T) { + chain, rpcSrv, httpSrv := initServerWithInMemoryChain(t) + defer chain.Close() + + c, err := client.New(context.Background(), httpSrv.URL, client.Options{Network: testchain.Network()}) + require.NoError(t, err) + + require.NoError(t, c.Ping()) + require.NoError(t, rpcSrv.Shutdown()) + httpSrv.Close() + require.Error(t, c.Ping()) +} + +func TestCreateTxFromScript(t *testing.T) { + chain, rpcSrv, httpSrv := initServerWithInMemoryChain(t) + defer chain.Close() + defer rpcSrv.Shutdown() + + c, err := client.New(context.Background(), httpSrv.URL, client.Options{Network: testchain.Network()}) + require.NoError(t, err) + + priv := testchain.PrivateKey(0) + acc, err := wallet.NewAccountFromWIF(priv.WIF()) + require.NoError(t, err) + t.Run("NoSystemFee", func(t *testing.T) { + tx, err := c.CreateTxFromScript([]byte{byte(opcode.PUSH1)}, acc, -1, 10) + require.NoError(t, err) + require.True(t, tx.ValidUntilBlock > chain.BlockHeight()) + require.EqualValues(t, 30, tx.SystemFee) // PUSH1 + require.True(t, len(tx.Signers) == 1) + require.Equal(t, acc.PrivateKey().GetScriptHash(), tx.Signers[0].Account) + }) + t.Run("ProvideSystemFee", func(t *testing.T) { + tx, err := c.CreateTxFromScript([]byte{byte(opcode.PUSH1)}, acc, 123, 10) + require.NoError(t, err) + require.True(t, tx.ValidUntilBlock > chain.BlockHeight()) + require.EqualValues(t, 123, tx.SystemFee) + require.True(t, len(tx.Signers) == 1) + require.Equal(t, acc.PrivateKey().GetScriptHash(), tx.Signers[0].Account) + }) +} diff --git a/pkg/vm/context.go b/pkg/vm/context.go index aff90f0a5..f2cf01ab1 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -221,7 +221,7 @@ func (c *Context) Equals(s stackitem.Item) bool { func (c *Context) atBreakPoint() bool { for _, n := range c.breakPoints { - if n == c.ip { + if n == c.nextip { return true } } diff --git a/pkg/vm/debug_test.go b/pkg/vm/debug_test.go new file mode 100644 index 000000000..7e6a8877e --- /dev/null +++ b/pkg/vm/debug_test.go @@ -0,0 +1,42 @@ +package vm + +import ( + "math/big" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "github.com/stretchr/testify/require" +) + +func TestVM_Debug(t *testing.T) { + prog := makeProgram(opcode.CALL, 3, opcode.RET, + opcode.PUSH2, opcode.PUSH3, opcode.ADD, opcode.RET) + t.Run("BreakPoint", func(t *testing.T) { + v := load(prog) + v.AddBreakPoint(3) + v.AddBreakPoint(5) + require.NoError(t, v.Run()) + require.Equal(t, 3, v.Context().NextIP()) + require.NoError(t, v.Run()) + require.Equal(t, 5, v.Context().NextIP()) + require.NoError(t, v.Run()) + require.Equal(t, 1, v.estack.len) + require.Equal(t, big.NewInt(5), v.estack.Top().Value()) + }) + t.Run("StepInto", func(t *testing.T) { + v := load(prog) + require.NoError(t, v.StepInto()) + require.Equal(t, 3, v.Context().NextIP()) + require.NoError(t, v.StepOut()) + require.Equal(t, 2, v.Context().NextIP()) + require.Equal(t, 1, v.estack.len) + require.Equal(t, big.NewInt(5), v.estack.Top().Value()) + }) + t.Run("StepOver", func(t *testing.T) { + v := load(prog) + require.NoError(t, v.StepOver()) + require.Equal(t, 2, v.Context().NextIP()) + require.Equal(t, 1, v.estack.len) + require.Equal(t, big.NewInt(5), v.estack.Top().Value()) + }) +} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index ac19b20e4..af9cfff6e 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -89,7 +89,7 @@ func New() *VM { // NewWithTrigger returns a new VM for executions triggered by t. func NewWithTrigger(t trigger.Type) *VM { vm := &VM{ - state: HaltState, + state: NoneState, istack: NewStack("invocation"), refs: newRefCounter(), keys: make(map[string]*keys.PublicKey), @@ -358,11 +358,6 @@ func (v *VM) Run() error { // HaltState (the default) or BreakState are safe to continue. v.state = NoneState for { - // check for breakpoint before executing the next instruction - ctx := v.Context() - if ctx != nil && ctx.atBreakPoint() { - v.state = BreakState - } switch { case v.state.HasFlag(FaultState): // Should be caught and reported already by the v.Step(), @@ -379,6 +374,11 @@ func (v *VM) Run() error { v.state = FaultState return errors.New("unknown state") } + // check for breakpoint before executing the next instruction + ctx := v.Context() + if ctx != nil && ctx.atBreakPoint() { + v.state = BreakState + } } } @@ -430,14 +430,15 @@ func (v *VM) StepOut() error { var err error if v.state == BreakState { v.state = NoneState - } else { - v.state = BreakState } expSize := v.istack.len for v.state == NoneState && v.istack.len >= expSize { err = v.StepInto() } + if v.state == NoneState { + v.state = BreakState + } return err } @@ -451,8 +452,6 @@ func (v *VM) StepOver() error { if v.state == BreakState { v.state = NoneState - } else { - v.state = BreakState } expSize := v.istack.len diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index cd2fae269..fbf354f05 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -516,7 +516,7 @@ func checkEnumeratorStack(t *testing.T, vm *VM, arr []stackitem.Item) { } } -func testIterableCreate(t *testing.T, typ string) { +func testIterableCreate(t *testing.T, typ string, isByteArray bool) { isIter := typ == "Iterator" prog := getSyscallProg("System." + typ + ".Create") prog = append(prog, getEnumeratorProg(2, isIter)...) @@ -526,7 +526,12 @@ func testIterableCreate(t *testing.T, typ string) { stackitem.NewBigInteger(big.NewInt(42)), stackitem.NewByteArray([]byte{3, 2, 1}), } - vm.estack.Push(&Element{value: stackitem.NewArray(arr)}) + if isByteArray { + arr[1] = stackitem.Make(7) + vm.estack.PushVal([]byte{42, 7}) + } else { + vm.estack.Push(&Element{value: stackitem.NewArray(arr)}) + } runVM(t, vm) if isIter { @@ -543,11 +548,13 @@ func testIterableCreate(t *testing.T, typ string) { } func TestEnumeratorCreate(t *testing.T) { - testIterableCreate(t, "Enumerator") + t.Run("Array", func(t *testing.T) { testIterableCreate(t, "Enumerator", false) }) + t.Run("ByteArray", func(t *testing.T) { testIterableCreate(t, "Enumerator", true) }) } func TestIteratorCreate(t *testing.T) { - testIterableCreate(t, "Iterator") + t.Run("Array", func(t *testing.T) { testIterableCreate(t, "Iterator", false) }) + t.Run("ByteArray", func(t *testing.T) { testIterableCreate(t, "Iterator", true) }) } func testIterableConcat(t *testing.T, typ string) { diff --git a/pkg/wallet/account.go b/pkg/wallet/account.go index 6ec829f21..393bf22a7 100644 --- a/pkg/wallet/account.go +++ b/pkg/wallet/account.go @@ -104,9 +104,17 @@ func (a *Account) SignTx(t *transaction.Transaction) error { } sign := a.privateKey.Sign(data) + verif := a.getVerificationScript() + invoc := append([]byte{byte(opcode.PUSHDATA1), 64}, sign...) + for i := range t.Scripts { + if bytes.Equal(t.Scripts[i].VerificationScript, verif) { + t.Scripts[i].InvocationScript = append(t.Scripts[i].InvocationScript, invoc...) + return nil + } + } t.Scripts = append(t.Scripts, transaction.Witness{ - InvocationScript: append([]byte{byte(opcode.PUSHDATA1), 64}, sign...), - VerificationScript: a.getVerificationScript(), + InvocationScript: invoc, + VerificationScript: verif, }) return nil