forked from TrueCloudLab/neoneo-go
core: call from native contracts synchronously
Follow neo-project/neo#2130.
This commit is contained in:
parent
189d0d801a
commit
e903e40085
9 changed files with 87 additions and 38 deletions
|
@ -60,12 +60,18 @@ func callExInternal(ic *interop.Context, h []byte, name string, args []stackitem
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty, nil)
|
return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CallExInternal calls a contract with flags and can't be invoked directly by user.
|
// CallExInternal calls a contract with flags and can't be invoked directly by user.
|
||||||
func CallExInternal(ic *interop.Context, cs *state.Contract,
|
func CallExInternal(ic *interop.Context, cs *state.Contract,
|
||||||
name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState, callback func(ctx *vm.Context)) error {
|
name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState) error {
|
||||||
|
return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, checkReturn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// callExFromNative calls a contract with flags using provided calling hash.
|
||||||
|
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
|
||||||
|
name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState) error {
|
||||||
md := cs.Manifest.ABI.GetMethod(name)
|
md := cs.Manifest.ABI.GetMethod(name)
|
||||||
if md == nil {
|
if md == nil {
|
||||||
return fmt.Errorf("method '%s' not found", name)
|
return fmt.Errorf("method '%s' not found", name)
|
||||||
|
@ -76,7 +82,7 @@ func CallExInternal(ic *interop.Context, cs *state.Contract,
|
||||||
}
|
}
|
||||||
|
|
||||||
ic.VM.Invocations[cs.Hash]++
|
ic.VM.Invocations[cs.Hash]++
|
||||||
ic.VM.LoadScriptWithHash(cs.Script, cs.Hash, ic.VM.Context().GetCallFlags()&f)
|
ic.VM.LoadScriptWithCallingHash(caller, cs.Script, cs.Hash, ic.VM.Context().GetCallFlags()&f)
|
||||||
var isNative bool
|
var isNative bool
|
||||||
for i := range ic.Natives {
|
for i := range ic.Natives {
|
||||||
if ic.Natives[i].Metadata().Hash.Equals(cs.Hash) {
|
if ic.Natives[i].Metadata().Hash.Equals(cs.Hash) {
|
||||||
|
@ -95,7 +101,6 @@ func CallExInternal(ic *interop.Context, cs *state.Contract,
|
||||||
ic.VM.Jump(ic.VM.Context(), md.Offset)
|
ic.VM.Jump(ic.VM.Context(), md.Offset)
|
||||||
}
|
}
|
||||||
ic.VM.Context().CheckReturn = checkReturn
|
ic.VM.Context().CheckReturn = checkReturn
|
||||||
ic.VM.Context().Callback = callback
|
|
||||||
|
|
||||||
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
|
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
|
||||||
if md != nil {
|
if md != nil {
|
||||||
|
@ -104,3 +109,24 @@ func CallExInternal(ic *interop.Context, cs *state.Contract,
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrNativeCall is returned for failed calls from native.
|
||||||
|
var ErrNativeCall = errors.New("error during call from native")
|
||||||
|
|
||||||
|
// CallFromNative performs synchronous call from native contract.
|
||||||
|
func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, checkReturn vm.CheckReturnState) error {
|
||||||
|
startSize := ic.VM.Istack().Len()
|
||||||
|
if err := callExFromNative(ic, caller, cs, method, args, smartcontract.All, checkReturn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for !ic.VM.HasStopped() && ic.VM.Istack().Len() > startSize {
|
||||||
|
if err := ic.VM.Step(); err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", ErrNativeCall, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ic.VM.State() == vm.FaultState {
|
||||||
|
return ErrNativeCall
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -187,7 +187,7 @@ func callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) error {
|
||||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
|
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
|
||||||
if md != nil {
|
if md != nil {
|
||||||
return contract.CallExInternal(ic, cs, manifest.MethodDeploy,
|
return contract.CallExInternal(ic, cs, manifest.MethodDeploy,
|
||||||
[]stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty, nil)
|
[]stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -972,7 +972,7 @@ func TestContractCreateDeploy(t *testing.T) {
|
||||||
|
|
||||||
cs.Hash = state.CreateContractHash(sender, cs.Script)
|
cs.Hash = state.CreateContractHash(sender, cs.Script)
|
||||||
v.LoadScriptWithHash(currCs.Script, cs.Hash, smartcontract.All)
|
v.LoadScriptWithHash(currCs.Script, cs.Hash, smartcontract.All)
|
||||||
err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty, nil)
|
err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, v.Run())
|
require.NoError(t, v.Run())
|
||||||
require.Equal(t, "create", v.Estack().Pop().String())
|
require.Equal(t, "create", v.Estack().Pop().String())
|
||||||
|
@ -993,7 +993,7 @@ func TestContractCreateDeploy(t *testing.T) {
|
||||||
require.NoError(t, v.Run())
|
require.NoError(t, v.Run())
|
||||||
|
|
||||||
v.LoadScriptWithHash(currCs.Script, cs.Hash, smartcontract.All)
|
v.LoadScriptWithHash(currCs.Script, cs.Hash, smartcontract.All)
|
||||||
err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty, nil)
|
err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, v.Run())
|
require.NoError(t, v.Run())
|
||||||
require.Equal(t, "update", v.Estack().Pop().String())
|
require.Equal(t, "update", v.Estack().Pop().String())
|
||||||
|
|
|
@ -159,7 +159,7 @@ func (c *nep17TokenNative) postTransfer(ic *interop.Context, from, to *util.Uint
|
||||||
stackitem.NewBigInteger(amount),
|
stackitem.NewBigInteger(amount),
|
||||||
data,
|
data,
|
||||||
}
|
}
|
||||||
if err := contract.CallExInternal(ic, cs, manifest.MethodOnPayment, args, smartcontract.All, vm.EnsureIsEmpty, nil); err != nil {
|
if err := contract.CallFromNative(ic, c.Hash, cs, manifest.MethodOnPayment, args, vm.EnsureIsEmpty); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -247,15 +247,13 @@ func (n *Notary) withdraw(ic *interop.Context, args []stackitem.Item) stackitem.
|
||||||
panic(fmt.Errorf("failed to get GAS contract state: %w", err))
|
panic(fmt.Errorf("failed to get GAS contract state: %w", err))
|
||||||
}
|
}
|
||||||
transferArgs := []stackitem.Item{stackitem.NewByteArray(n.Hash.BytesBE()), stackitem.NewByteArray(to.BytesBE()), stackitem.NewBigInteger(deposit.Amount), stackitem.Null{}}
|
transferArgs := []stackitem.Item{stackitem.NewByteArray(n.Hash.BytesBE()), stackitem.NewByteArray(to.BytesBE()), stackitem.NewBigInteger(deposit.Amount), stackitem.Null{}}
|
||||||
err = contract.CallExInternal(ic, cs, "transfer", transferArgs, smartcontract.All, vm.EnsureIsEmpty, func(ctx *vm.Context) { // we need EnsureIsEmpty because there's a callback popping result from the stack
|
err = contract.CallFromNative(ic, n.Hash, cs, "transfer", transferArgs, vm.EnsureNotEmpty)
|
||||||
isTransferOk := ic.VM.Estack().Pop().Bool()
|
|
||||||
if !isTransferOk {
|
|
||||||
panic("failed to transfer GAS from Notary account")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to transfer GAS from Notary account: %w", err))
|
panic(fmt.Errorf("failed to transfer GAS from Notary account: %w", err))
|
||||||
}
|
}
|
||||||
|
if !ic.VM.Estack().Pop().Bool() {
|
||||||
|
panic("failed to transfer GAS from Notary account: `transfer` returned false")
|
||||||
|
}
|
||||||
if err := n.removeDepositFor(ic.DAO, from); err != nil {
|
if err := n.removeDepositFor(ic.DAO, from); err != nil {
|
||||||
panic(fmt.Errorf("failed to remove withdrawn deposit for %s from the storage: %w", from.StringBE(), err))
|
panic(fmt.Errorf("failed to remove withdrawn deposit for %s from the storage: %w", from.StringBE(), err))
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
|
@ -248,16 +249,17 @@ func (o *Oracle) FinishInternal(ic *interop.Context) error {
|
||||||
|
|
||||||
r := io.NewBinReaderFromBuf(req.UserData)
|
r := io.NewBinReaderFromBuf(req.UserData)
|
||||||
userData := stackitem.DecodeBinaryStackItem(r)
|
userData := stackitem.DecodeBinaryStackItem(r)
|
||||||
args := stackitem.NewArray([]stackitem.Item{
|
args := []stackitem.Item{
|
||||||
stackitem.Make(req.URL),
|
stackitem.Make(req.URL),
|
||||||
stackitem.Make(userData),
|
stackitem.Make(userData),
|
||||||
stackitem.Make(resp.Code),
|
stackitem.Make(resp.Code),
|
||||||
stackitem.Make(resp.Result),
|
stackitem.Make(resp.Result),
|
||||||
})
|
}
|
||||||
ic.VM.Estack().PushVal(args)
|
cs, err := ic.DAO.GetContractState(req.CallbackContract)
|
||||||
ic.VM.Estack().PushVal(req.CallbackMethod)
|
if err != nil {
|
||||||
ic.VM.Estack().PushVal(req.CallbackContract.BytesBE())
|
return err
|
||||||
return contract.Call(ic)
|
}
|
||||||
|
return contract.CallFromNative(ic, o.Hash, cs, req.CallbackMethod, args, vm.EnsureIsEmpty)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Oracle) request(ic *interop.Context, args []stackitem.Item) stackitem.Item {
|
func (o *Oracle) request(ic *interop.Context, args []stackitem.Item) stackitem.Item {
|
||||||
|
|
|
@ -91,6 +91,21 @@ func newTestNative() *testNative {
|
||||||
RequiredFlags: smartcontract.NoneFlag}
|
RequiredFlags: smartcontract.NoneFlag}
|
||||||
tn.meta.AddMethod(md, desc)
|
tn.meta.AddMethod(md, desc)
|
||||||
|
|
||||||
|
desc = &manifest.Method{
|
||||||
|
Name: "callOtherContractWithReturn",
|
||||||
|
Parameters: []manifest.Parameter{
|
||||||
|
manifest.NewParameter("contractHash", smartcontract.Hash160Type),
|
||||||
|
manifest.NewParameter("method", smartcontract.StringType),
|
||||||
|
manifest.NewParameter("arg", smartcontract.ArrayType),
|
||||||
|
},
|
||||||
|
ReturnType: smartcontract.IntegerType,
|
||||||
|
}
|
||||||
|
md = &interop.MethodAndPrice{
|
||||||
|
Func: tn.callOtherContractWithReturn,
|
||||||
|
Price: testSumPrice,
|
||||||
|
RequiredFlags: smartcontract.NoneFlag}
|
||||||
|
tn.meta.AddMethod(md, desc)
|
||||||
|
|
||||||
desc = &manifest.Method{Name: "onPersist", ReturnType: smartcontract.BoolType}
|
desc = &manifest.Method{Name: "onPersist", ReturnType: smartcontract.BoolType}
|
||||||
md = &interop.MethodAndPrice{Func: tn.OnPersist, RequiredFlags: smartcontract.AllowModifyStates}
|
md = &interop.MethodAndPrice{Func: tn.OnPersist, RequiredFlags: smartcontract.AllowModifyStates}
|
||||||
tn.meta.AddMethod(md, desc)
|
tn.meta.AddMethod(md, desc)
|
||||||
|
@ -122,16 +137,17 @@ func toUint160(item stackitem.Item) util.Uint160 {
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tn *testNative) call(ic *interop.Context, args []stackitem.Item, retState vm.CheckReturnState) {
|
func (tn *testNative) call(ic *interop.Context, args []stackitem.Item, checkReturn vm.CheckReturnState) {
|
||||||
cs, err := ic.DAO.GetContractState(toUint160(args[0]))
|
h := toUint160(args[0])
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
bs, err := args[1].TryBytes()
|
bs, err := args[1].TryBytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
err = contract.CallExInternal(ic, cs, string(bs), args[2].Value().([]stackitem.Item), smartcontract.All, retState, nil)
|
cs, err := ic.DAO.GetContractState(h)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = contract.CallFromNative(ic, tn.meta.Hash, cs, string(bs), args[2].Value().([]stackitem.Item), checkReturn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -142,6 +158,12 @@ func (tn *testNative) callOtherContractNoReturn(ic *interop.Context, args []stac
|
||||||
return stackitem.Null{}
|
return stackitem.Null{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tn *testNative) callOtherContractWithReturn(ic *interop.Context, args []stackitem.Item) stackitem.Item {
|
||||||
|
tn.call(ic, args, vm.EnsureNotEmpty)
|
||||||
|
bi := ic.VM.Estack().Pop().BigInt()
|
||||||
|
return stackitem.Make(bi.Add(bi, big.NewInt(1)))
|
||||||
|
}
|
||||||
|
|
||||||
func TestNativeContract_Invoke(t *testing.T) {
|
func TestNativeContract_Invoke(t *testing.T) {
|
||||||
chain := newTestChain(t)
|
chain := newTestChain(t)
|
||||||
defer chain.Close()
|
defer chain.Close()
|
||||||
|
@ -238,4 +260,10 @@ func TestNativeContract_InvokeOtherContract(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
checkResult(t, res, stackitem.Null{}) // simple call is done with EnsureNotEmpty
|
checkResult(t, res, stackitem.Null{}) // simple call is done with EnsureNotEmpty
|
||||||
})
|
})
|
||||||
|
t.Run("non-native, with return", func(t *testing.T) {
|
||||||
|
res, err := invokeContractMethod(chain, testSumPrice*4+10000, tn.Metadata().Hash,
|
||||||
|
"callOtherContractWithReturn", cs.Hash, "ret7", []interface{}{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
checkResult(t, res, stackitem.Make(8))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,14 +48,6 @@ type Context struct {
|
||||||
// Call flags this context was created with.
|
// Call flags this context was created with.
|
||||||
callFlag smartcontract.CallFlag
|
callFlag smartcontract.CallFlag
|
||||||
|
|
||||||
// InvocationState contains expected return type and actions to be performed on context unload.
|
|
||||||
InvocationState
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvocationState contains return convention and callback to be executed on context unload.
|
|
||||||
type InvocationState struct {
|
|
||||||
// Callback is executed on context unload.
|
|
||||||
Callback func(ctx *Context)
|
|
||||||
// CheckReturn specifies if amount of return values needs to be checked.
|
// CheckReturn specifies if amount of return values needs to be checked.
|
||||||
CheckReturn CheckReturnState
|
CheckReturn CheckReturnState
|
||||||
}
|
}
|
||||||
|
|
11
pkg/vm/vm.go
11
pkg/vm/vm.go
|
@ -295,11 +295,17 @@ func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) {
|
||||||
// each other.
|
// each other.
|
||||||
func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) {
|
func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) {
|
||||||
shash := v.GetCurrentScriptHash()
|
shash := v.GetCurrentScriptHash()
|
||||||
|
v.LoadScriptWithCallingHash(shash, b, hash, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadScriptWithCallingHash is similar to LoadScriptWithHash but sets calling hash explicitly.
|
||||||
|
// It should be used for calling from native contracts.
|
||||||
|
func (v *VM) LoadScriptWithCallingHash(caller util.Uint160, b []byte, hash util.Uint160, f smartcontract.CallFlag) {
|
||||||
v.LoadScriptWithFlags(b, f)
|
v.LoadScriptWithFlags(b, f)
|
||||||
ctx := v.Context()
|
ctx := v.Context()
|
||||||
ctx.isDeployed = true
|
ctx.isDeployed = true
|
||||||
ctx.scriptHash = hash
|
ctx.scriptHash = hash
|
||||||
ctx.callingScriptHash = shash
|
ctx.callingScriptHash = caller
|
||||||
}
|
}
|
||||||
|
|
||||||
// Context returns the current executed context. Nil if there is no context,
|
// Context returns the current executed context. Nil if there is no context,
|
||||||
|
@ -1418,9 +1424,6 @@ func (v *VM) unloadContext(ctx *Context) {
|
||||||
if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static {
|
if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static {
|
||||||
ctx.static.Clear()
|
ctx.static.Clear()
|
||||||
}
|
}
|
||||||
if ctx.Callback != nil {
|
|
||||||
ctx.Callback(ctx)
|
|
||||||
}
|
|
||||||
switch ctx.CheckReturn {
|
switch ctx.CheckReturn {
|
||||||
case NoCheck:
|
case NoCheck:
|
||||||
case EnsureIsEmpty:
|
case EnsureIsEmpty:
|
||||||
|
|
Loading…
Reference in a new issue