mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-30 09:33:36 +00:00
core: add callback to VM context
This commit is contained in:
parent
c9acc43023
commit
0f68528095
7 changed files with 19 additions and 7 deletions
|
@ -52,12 +52,12 @@ func callExInternal(ic *interop.Context, h []byte, name string, args []stackitem
|
||||||
return errors.New("disallowed method call")
|
return errors.New("disallowed method call")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty)
|
return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CallExInternal calls a contract with flags and can't be invoked directly by user.
|
// CallExInternal calls a contract with flags and can't be invoked directly by user.
|
||||||
func CallExInternal(ic *interop.Context, cs *state.Contract,
|
func CallExInternal(ic *interop.Context, cs *state.Contract,
|
||||||
name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState) error {
|
name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState, callback func(ctx *vm.Context)) error {
|
||||||
md := cs.Manifest.ABI.GetMethod(name)
|
md := cs.Manifest.ABI.GetMethod(name)
|
||||||
if md == nil {
|
if md == nil {
|
||||||
return fmt.Errorf("method '%s' not found", name)
|
return fmt.Errorf("method '%s' not found", name)
|
||||||
|
@ -88,6 +88,7 @@ func CallExInternal(ic *interop.Context, cs *state.Contract,
|
||||||
ic.VM.Jump(ic.VM.Context(), md.Offset)
|
ic.VM.Jump(ic.VM.Context(), md.Offset)
|
||||||
}
|
}
|
||||||
ic.VM.Context().CheckReturn = checkReturn
|
ic.VM.Context().CheckReturn = checkReturn
|
||||||
|
ic.VM.Context().Callback = callback
|
||||||
|
|
||||||
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
|
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
|
||||||
if md != nil {
|
if md != nil {
|
||||||
|
|
|
@ -195,7 +195,7 @@ func callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) error {
|
||||||
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
|
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
|
||||||
if md != nil {
|
if md != nil {
|
||||||
return contract.CallExInternal(ic, cs, manifest.MethodDeploy,
|
return contract.CallExInternal(ic, cs, manifest.MethodDeploy,
|
||||||
[]stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty)
|
[]stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty, nil)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -911,7 +911,7 @@ func TestContractCreateDeploy(t *testing.T) {
|
||||||
require.NoError(t, ic.VM.Run())
|
require.NoError(t, ic.VM.Run())
|
||||||
|
|
||||||
v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
|
v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
|
||||||
err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty)
|
err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, v.Run())
|
require.NoError(t, v.Run())
|
||||||
require.Equal(t, "create", v.Estack().Pop().String())
|
require.Equal(t, "create", v.Estack().Pop().String())
|
||||||
|
@ -932,7 +932,7 @@ func TestContractCreateDeploy(t *testing.T) {
|
||||||
require.NoError(t, v.Run())
|
require.NoError(t, v.Run())
|
||||||
|
|
||||||
v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
|
v.LoadScriptWithFlags(currCs.Script, smartcontract.All)
|
||||||
err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty)
|
err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, v.Run())
|
require.NoError(t, v.Run())
|
||||||
require.Equal(t, "update", v.Estack().Pop().String())
|
require.Equal(t, "update", v.Estack().Pop().String())
|
||||||
|
|
|
@ -161,7 +161,7 @@ func (c *nep17TokenNative) postTransfer(ic *interop.Context, from, to *util.Uint
|
||||||
stackitem.NewBigInteger(amount),
|
stackitem.NewBigInteger(amount),
|
||||||
data,
|
data,
|
||||||
}
|
}
|
||||||
if err := contract.CallExInternal(ic, cs, manifest.MethodOnPayment, args, smartcontract.All, vm.EnsureIsEmpty); err != nil {
|
if err := contract.CallExInternal(ic, cs, manifest.MethodOnPayment, args, smartcontract.All, vm.EnsureIsEmpty, nil); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -133,7 +133,7 @@ func (tn *testNative) call(ic *interop.Context, args []stackitem.Item, retState
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
err = contract.CallExInternal(ic, cs, string(bs), args[2].Value().([]stackitem.Item), smartcontract.All, retState)
|
err = contract.CallExInternal(ic, cs, string(bs), args[2].Value().([]stackitem.Item), smartcontract.All, retState, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,6 +45,14 @@ type Context struct {
|
||||||
// Call flags this context was created with.
|
// Call flags this context was created with.
|
||||||
callFlag smartcontract.CallFlag
|
callFlag smartcontract.CallFlag
|
||||||
|
|
||||||
|
// InvocationState contains expected return type and actions to be performed on context unload.
|
||||||
|
InvocationState
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvocationState contains return convention and callback to be executed on context unload.
|
||||||
|
type InvocationState struct {
|
||||||
|
// Callback is executed on context unload.
|
||||||
|
Callback func(ctx *Context)
|
||||||
// CheckReturn specifies if amount of return values needs to be checked.
|
// CheckReturn specifies if amount of return values needs to be checked.
|
||||||
CheckReturn CheckReturnState
|
CheckReturn CheckReturnState
|
||||||
}
|
}
|
||||||
|
|
|
@ -1410,6 +1410,9 @@ func (v *VM) unloadContext(ctx *Context) {
|
||||||
if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static {
|
if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static {
|
||||||
ctx.static.Clear()
|
ctx.static.Clear()
|
||||||
}
|
}
|
||||||
|
if ctx.Callback != nil {
|
||||||
|
ctx.Callback(ctx)
|
||||||
|
}
|
||||||
switch ctx.CheckReturn {
|
switch ctx.CheckReturn {
|
||||||
case NoCheck:
|
case NoCheck:
|
||||||
case EnsureIsEmpty:
|
case EnsureIsEmpty:
|
||||||
|
|
Loading…
Reference in a new issue