diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index fe2925771..6362c443e 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -60,12 +60,18 @@ func callExInternal(ic *interop.Context, h []byte, name string, args []stackitem } } } - return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty, nil) + return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty) } // CallExInternal calls a contract with flags and can't be invoked directly by user. func CallExInternal(ic *interop.Context, cs *state.Contract, - name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState, callback func(ctx *vm.Context)) error { + name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState) error { + return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, checkReturn) +} + +// callExFromNative calls a contract with flags using provided calling hash. +func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, + name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState) error { md := cs.Manifest.ABI.GetMethod(name) if md == nil { return fmt.Errorf("method '%s' not found", name) @@ -76,7 +82,7 @@ func CallExInternal(ic *interop.Context, cs *state.Contract, } ic.VM.Invocations[cs.Hash]++ - ic.VM.LoadScriptWithHash(cs.Script, cs.Hash, ic.VM.Context().GetCallFlags()&f) + ic.VM.LoadScriptWithCallingHash(caller, cs.Script, cs.Hash, ic.VM.Context().GetCallFlags()&f) var isNative bool for i := range ic.Natives { if ic.Natives[i].Metadata().Hash.Equals(cs.Hash) { @@ -95,7 +101,6 @@ func CallExInternal(ic *interop.Context, cs *state.Contract, ic.VM.Jump(ic.VM.Context(), md.Offset) } ic.VM.Context().CheckReturn = checkReturn - ic.VM.Context().Callback = callback md = cs.Manifest.ABI.GetMethod(manifest.MethodInit) if md != nil { @@ -104,3 +109,24 @@ func CallExInternal(ic *interop.Context, cs *state.Contract, return nil } + +// ErrNativeCall is returned for failed calls from native. +var ErrNativeCall = errors.New("error during call from native") + +// CallFromNative performs synchronous call from native contract. +func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, checkReturn vm.CheckReturnState) error { + startSize := ic.VM.Istack().Len() + if err := callExFromNative(ic, caller, cs, method, args, smartcontract.All, checkReturn); err != nil { + return err + } + + for !ic.VM.HasStopped() && ic.VM.Istack().Len() > startSize { + if err := ic.VM.Step(); err != nil { + return fmt.Errorf("%w: %v", ErrNativeCall, err) + } + } + if ic.VM.State() == vm.FaultState { + return ErrNativeCall + } + return nil +} diff --git a/pkg/core/interop/runtime/engine.go b/pkg/core/interop/runtime/engine.go index 7dac81171..0ddf03136 100644 --- a/pkg/core/interop/runtime/engine.go +++ b/pkg/core/interop/runtime/engine.go @@ -22,8 +22,13 @@ func GetExecutingScriptHash(ic *interop.Context) error { } // GetCallingScriptHash returns calling script hash. +// While Executing and Entry script hashes are always valid for non-native contracts, +// Calling hash is set explicitly when native contracts are used, because when switching from +// one native to another, no operations are performed on invocation stack. func GetCallingScriptHash(ic *interop.Context) error { - return ic.VM.PushContextScriptHash(1) + h := ic.VM.GetCallingScriptHash() + ic.VM.Estack().PushVal(h.BytesBE()) + return nil } // GetEntryScriptHash returns entry script hash. diff --git a/pkg/core/interop_neo.go b/pkg/core/interop_neo.go index a65902903..2971d57b3 100644 --- a/pkg/core/interop_neo.go +++ b/pkg/core/interop_neo.go @@ -187,7 +187,7 @@ func callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) error { md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy) if md != nil { return contract.CallExInternal(ic, cs, manifest.MethodDeploy, - []stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty, nil) + []stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty) } return nil } diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index 618ab24a5..c968bb3ef 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -440,7 +440,8 @@ func getTestContractState() (*state.Contract, *state.Contract) { emit.Syscall(w.BinWriter, interopnames.SystemStorageGet) emit.Opcodes(w.BinWriter, opcode.RET) onPaymentOff := w.Len() - emit.Int(w.BinWriter, 3) + emit.Syscall(w.BinWriter, interopnames.SystemRuntimeGetCallingScriptHash) + emit.Int(w.BinWriter, 4) emit.Opcodes(w.BinWriter, opcode.PACK) emit.String(w.BinWriter, "LastPayment") emit.Syscall(w.BinWriter, interopnames.SystemRuntimeNotify) @@ -972,7 +973,7 @@ func TestContractCreateDeploy(t *testing.T) { cs.Hash = state.CreateContractHash(sender, cs.Script) v.LoadScriptWithHash(currCs.Script, cs.Hash, smartcontract.All) - err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty, nil) + err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty) require.NoError(t, err) require.NoError(t, v.Run()) require.Equal(t, "create", v.Estack().Pop().String()) @@ -993,7 +994,7 @@ func TestContractCreateDeploy(t *testing.T) { require.NoError(t, v.Run()) v.LoadScriptWithHash(currCs.Script, cs.Hash, smartcontract.All) - err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty, nil) + err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty) require.NoError(t, err) require.NoError(t, v.Run()) require.Equal(t, "update", v.Estack().Pop().String()) diff --git a/pkg/core/native/native_nep17.go b/pkg/core/native/native_nep17.go index 8fd519c01..edd97018f 100644 --- a/pkg/core/native/native_nep17.go +++ b/pkg/core/native/native_nep17.go @@ -159,7 +159,7 @@ func (c *nep17TokenNative) postTransfer(ic *interop.Context, from, to *util.Uint stackitem.NewBigInteger(amount), data, } - if err := contract.CallExInternal(ic, cs, manifest.MethodOnPayment, args, smartcontract.All, vm.EnsureIsEmpty, nil); err != nil { + if err := contract.CallFromNative(ic, c.Hash, cs, manifest.MethodOnPayment, args, vm.EnsureIsEmpty); err != nil { panic(err) } } diff --git a/pkg/core/native/notary.go b/pkg/core/native/notary.go index 5d84d03e9..d8e1d66f8 100644 --- a/pkg/core/native/notary.go +++ b/pkg/core/native/notary.go @@ -247,15 +247,13 @@ func (n *Notary) withdraw(ic *interop.Context, args []stackitem.Item) stackitem. panic(fmt.Errorf("failed to get GAS contract state: %w", err)) } transferArgs := []stackitem.Item{stackitem.NewByteArray(n.Hash.BytesBE()), stackitem.NewByteArray(to.BytesBE()), stackitem.NewBigInteger(deposit.Amount), stackitem.Null{}} - err = contract.CallExInternal(ic, cs, "transfer", transferArgs, smartcontract.All, vm.EnsureIsEmpty, func(ctx *vm.Context) { // we need EnsureIsEmpty because there's a callback popping result from the stack - isTransferOk := ic.VM.Estack().Pop().Bool() - if !isTransferOk { - panic("failed to transfer GAS from Notary account") - } - }) + err = contract.CallFromNative(ic, n.Hash, cs, "transfer", transferArgs, vm.EnsureNotEmpty) if err != nil { panic(fmt.Errorf("failed to transfer GAS from Notary account: %w", err)) } + if !ic.VM.Estack().Pop().Bool() { + panic("failed to transfer GAS from Notary account: `transfer` returned false") + } if err := n.removeDepositFor(ic.DAO, from); err != nil { panic(fmt.Errorf("failed to remove withdrawn deposit for %s from the storage: %w", from.StringBE(), err)) } diff --git a/pkg/core/native/oracle.go b/pkg/core/native/oracle.go index 35d405d62..7c3a21b90 100644 --- a/pkg/core/native/oracle.go +++ b/pkg/core/native/oracle.go @@ -20,6 +20,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -248,16 +249,17 @@ func (o *Oracle) FinishInternal(ic *interop.Context) error { r := io.NewBinReaderFromBuf(req.UserData) userData := stackitem.DecodeBinaryStackItem(r) - args := stackitem.NewArray([]stackitem.Item{ + args := []stackitem.Item{ stackitem.Make(req.URL), stackitem.Make(userData), stackitem.Make(resp.Code), stackitem.Make(resp.Result), - }) - ic.VM.Estack().PushVal(args) - ic.VM.Estack().PushVal(req.CallbackMethod) - ic.VM.Estack().PushVal(req.CallbackContract.BytesBE()) - return contract.Call(ic) + } + cs, err := ic.DAO.GetContractState(req.CallbackContract) + if err != nil { + return err + } + return contract.CallFromNative(ic, o.Hash, cs, req.CallbackMethod, args, vm.EnsureIsEmpty) } func (o *Oracle) request(ic *interop.Context, args []stackitem.Item) stackitem.Item { diff --git a/pkg/core/native_contract_test.go b/pkg/core/native_contract_test.go index 19f1236ce..5c09011cd 100644 --- a/pkg/core/native_contract_test.go +++ b/pkg/core/native_contract_test.go @@ -91,6 +91,21 @@ func newTestNative() *testNative { RequiredFlags: smartcontract.NoneFlag} tn.meta.AddMethod(md, desc) + desc = &manifest.Method{ + Name: "callOtherContractWithReturn", + Parameters: []manifest.Parameter{ + manifest.NewParameter("contractHash", smartcontract.Hash160Type), + manifest.NewParameter("method", smartcontract.StringType), + manifest.NewParameter("arg", smartcontract.ArrayType), + }, + ReturnType: smartcontract.IntegerType, + } + md = &interop.MethodAndPrice{ + Func: tn.callOtherContractWithReturn, + Price: testSumPrice, + RequiredFlags: smartcontract.NoneFlag} + tn.meta.AddMethod(md, desc) + desc = &manifest.Method{Name: "onPersist", ReturnType: smartcontract.BoolType} md = &interop.MethodAndPrice{Func: tn.OnPersist, RequiredFlags: smartcontract.AllowModifyStates} tn.meta.AddMethod(md, desc) @@ -122,16 +137,17 @@ func toUint160(item stackitem.Item) util.Uint160 { return u } -func (tn *testNative) call(ic *interop.Context, args []stackitem.Item, retState vm.CheckReturnState) { - cs, err := ic.DAO.GetContractState(toUint160(args[0])) - if err != nil { - panic(err) - } +func (tn *testNative) call(ic *interop.Context, args []stackitem.Item, checkReturn vm.CheckReturnState) { + h := toUint160(args[0]) bs, err := args[1].TryBytes() if err != nil { panic(err) } - err = contract.CallExInternal(ic, cs, string(bs), args[2].Value().([]stackitem.Item), smartcontract.All, retState, nil) + cs, err := ic.DAO.GetContractState(h) + if err != nil { + panic(err) + } + err = contract.CallFromNative(ic, tn.meta.Hash, cs, string(bs), args[2].Value().([]stackitem.Item), checkReturn) if err != nil { panic(err) } @@ -142,6 +158,12 @@ func (tn *testNative) callOtherContractNoReturn(ic *interop.Context, args []stac return stackitem.Null{} } +func (tn *testNative) callOtherContractWithReturn(ic *interop.Context, args []stackitem.Item) stackitem.Item { + tn.call(ic, args, vm.EnsureNotEmpty) + bi := ic.VM.Estack().Pop().BigInt() + return stackitem.Make(bi.Add(bi, big.NewInt(1))) +} + func TestNativeContract_Invoke(t *testing.T) { chain := newTestChain(t) defer chain.Close() @@ -238,4 +260,10 @@ func TestNativeContract_InvokeOtherContract(t *testing.T) { require.NoError(t, err) checkResult(t, res, stackitem.Null{}) // simple call is done with EnsureNotEmpty }) + t.Run("non-native, with return", func(t *testing.T) { + res, err := invokeContractMethod(chain, testSumPrice*4+10000, tn.Metadata().Hash, + "callOtherContractWithReturn", cs.Hash, "ret7", []interface{}{}) + require.NoError(t, err) + checkResult(t, res, stackitem.Make(8)) + }) } diff --git a/pkg/core/native_neo_test.go b/pkg/core/native_neo_test.go index aecbf4bdd..0e2da1511 100644 --- a/pkg/core/native_neo_test.go +++ b/pkg/core/native_neo_test.go @@ -317,11 +317,32 @@ func TestNEO_TransferOnPayment(t *testing.T) { aer, err := bc.GetAppExecResults(tx.Hash(), trigger.Application) require.NoError(t, err) require.Equal(t, vm.HaltState, aer[0].VMState) - require.Len(t, aer[0].Events, 3) // transfer + auto GAS claim + onPayment + require.Len(t, aer[0].Events, 3) // transfer + GAS claim for sender + onPayment e := aer[0].Events[2] require.Equal(t, "LastPayment", e.Name) arr := e.Item.Value().([]stackitem.Item) - require.Equal(t, neoOwner.BytesBE(), arr[0].Value()) - require.Equal(t, big.NewInt(amount), arr[1].Value()) + require.Equal(t, bc.contracts.NEO.Hash.BytesBE(), arr[0].Value()) + require.Equal(t, neoOwner.BytesBE(), arr[1].Value()) + require.Equal(t, big.NewInt(amount), arr[2].Value()) + + tx = transferTokenFromMultisigAccount(t, bc, cs.Hash, bc.contracts.NEO.Hash, amount) + aer, err = bc.GetAppExecResults(tx.Hash(), trigger.Application) + require.NoError(t, err) + require.Equal(t, vm.HaltState, aer[0].VMState) + // Now we must also have GAS claim for contract and corresponding `onPayment`. + require.Len(t, aer[0].Events, 5) + + e = aer[0].Events[2] // onPayment for GAS claim + require.Equal(t, "LastPayment", e.Name) + arr = e.Item.Value().([]stackitem.Item) + require.Equal(t, stackitem.Null{}, arr[1]) + require.Equal(t, bc.contracts.GAS.Hash.BytesBE(), arr[0].Value()) + + e = aer[0].Events[4] // onPayment for NEO transfer + require.Equal(t, "LastPayment", e.Name) + arr = e.Item.Value().([]stackitem.Item) + require.Equal(t, bc.contracts.NEO.Hash.BytesBE(), arr[0].Value()) + require.Equal(t, neoOwner.BytesBE(), arr[1].Value()) + require.Equal(t, big.NewInt(amount), arr[2].Value()) } diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 6cb7300a3..7d70b08e9 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -48,14 +48,6 @@ type Context struct { // Call flags this context was created with. callFlag smartcontract.CallFlag - // InvocationState contains expected return type and actions to be performed on context unload. - InvocationState -} - -// InvocationState contains return convention and callback to be executed on context unload. -type InvocationState struct { - // Callback is executed on context unload. - Callback func(ctx *Context) // CheckReturn specifies if amount of return values needs to be checked. CheckReturn CheckReturnState } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index c96ad4676..a586086b8 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -284,6 +284,7 @@ func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) { ctx.tryStack = NewStack("exception") ctx.callFlag = f ctx.static = newSlot(v.refs) + ctx.callingScriptHash = v.GetCurrentScriptHash() v.istack.PushVal(ctx) } @@ -295,11 +296,17 @@ func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) { // each other. func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) { shash := v.GetCurrentScriptHash() + v.LoadScriptWithCallingHash(shash, b, hash, f) +} + +// LoadScriptWithCallingHash is similar to LoadScriptWithHash but sets calling hash explicitly. +// It should be used for calling from native contracts. +func (v *VM) LoadScriptWithCallingHash(caller util.Uint160, b []byte, hash util.Uint160, f smartcontract.CallFlag) { v.LoadScriptWithFlags(b, f) ctx := v.Context() ctx.isDeployed = true ctx.scriptHash = hash - ctx.callingScriptHash = shash + ctx.callingScriptHash = caller } // Context returns the current executed context. Nil if there is no context, @@ -1418,9 +1425,6 @@ func (v *VM) unloadContext(ctx *Context) { if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { ctx.static.Clear() } - if ctx.Callback != nil { - ctx.Callback(ctx) - } switch ctx.CheckReturn { case NoCheck: case EnsureIsEmpty: