diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 4d757049e..d40878e6e 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -52,12 +52,12 @@ func callExInternal(ic *interop.Context, h []byte, name string, args []stackitem return errors.New("disallowed method call") } } - return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty) + return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty, nil) } // CallExInternal calls a contract with flags and can't be invoked directly by user. func CallExInternal(ic *interop.Context, cs *state.Contract, - name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState) error { + name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState, callback func(ctx *vm.Context)) error { md := cs.Manifest.ABI.GetMethod(name) if md == nil { return fmt.Errorf("method '%s' not found", name) @@ -88,6 +88,7 @@ func CallExInternal(ic *interop.Context, cs *state.Contract, ic.VM.Jump(ic.VM.Context(), md.Offset) } ic.VM.Context().CheckReturn = checkReturn + ic.VM.Context().Callback = callback md = cs.Manifest.ABI.GetMethod(manifest.MethodInit) if md != nil { diff --git a/pkg/core/interop_neo.go b/pkg/core/interop_neo.go index 35f9e9f0e..4e143756a 100644 --- a/pkg/core/interop_neo.go +++ b/pkg/core/interop_neo.go @@ -195,7 +195,7 @@ func callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) error { md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy) if md != nil { return contract.CallExInternal(ic, cs, manifest.MethodDeploy, - []stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty) + []stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty, nil) } return nil } diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index 5f0be0574..42802165b 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -911,7 +911,7 @@ func TestContractCreateDeploy(t *testing.T) { require.NoError(t, ic.VM.Run()) v.LoadScriptWithFlags(currCs.Script, smartcontract.All) - err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty) + err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty, nil) require.NoError(t, err) require.NoError(t, v.Run()) require.Equal(t, "create", v.Estack().Pop().String()) @@ -932,7 +932,7 @@ func TestContractCreateDeploy(t *testing.T) { require.NoError(t, v.Run()) v.LoadScriptWithFlags(currCs.Script, smartcontract.All) - err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty) + err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty, nil) require.NoError(t, err) require.NoError(t, v.Run()) require.Equal(t, "update", v.Estack().Pop().String()) diff --git a/pkg/core/native/native_nep17.go b/pkg/core/native/native_nep17.go index c5853503d..aa18d7ed7 100644 --- a/pkg/core/native/native_nep17.go +++ b/pkg/core/native/native_nep17.go @@ -161,7 +161,7 @@ func (c *nep17TokenNative) postTransfer(ic *interop.Context, from, to *util.Uint stackitem.NewBigInteger(amount), data, } - if err := contract.CallExInternal(ic, cs, manifest.MethodOnPayment, args, smartcontract.All, vm.EnsureIsEmpty); err != nil { + if err := contract.CallExInternal(ic, cs, manifest.MethodOnPayment, args, smartcontract.All, vm.EnsureIsEmpty, nil); err != nil { panic(err) } } diff --git a/pkg/core/native_contract_test.go b/pkg/core/native_contract_test.go index ac0528f91..67ced4b50 100644 --- a/pkg/core/native_contract_test.go +++ b/pkg/core/native_contract_test.go @@ -133,7 +133,7 @@ func (tn *testNative) call(ic *interop.Context, args []stackitem.Item, retState if err != nil { panic(err) } - err = contract.CallExInternal(ic, cs, string(bs), args[2].Value().([]stackitem.Item), smartcontract.All, retState) + err = contract.CallExInternal(ic, cs, string(bs), args[2].Value().([]stackitem.Item), smartcontract.All, retState, nil) if err != nil { panic(err) } diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 4017f7516..08d38b87a 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -45,6 +45,14 @@ type Context struct { // Call flags this context was created with. callFlag smartcontract.CallFlag + // InvocationState contains expected return type and actions to be performed on context unload. + InvocationState +} + +// InvocationState contains return convention and callback to be executed on context unload. +type InvocationState struct { + // Callback is executed on context unload. + Callback func(ctx *Context) // CheckReturn specifies if amount of return values needs to be checked. CheckReturn CheckReturnState } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 31a3197a0..03190378e 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -1410,6 +1410,9 @@ func (v *VM) unloadContext(ctx *Context) { if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { ctx.static.Clear() } + if ctx.Callback != nil { + ctx.Callback(ctx) + } switch ctx.CheckReturn { case NoCheck: case EnsureIsEmpty: