diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 279bbb5e3..aa5989c6d 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -67,11 +68,11 @@ func Call(ic *interop.Context) error { return fmt.Errorf("method not found: %s/%d", method, len(args)) } hasReturn := md.ReturnType != smartcontract.VoidType - return callInternal(ic, cs, method, fs, hasReturn, args, !hasReturn) + return callInternal(ic, cs, method, fs, hasReturn, args, true) } func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, - hasReturn bool, args []stackitem.Item, pushNullOnUnloading bool) error { + hasReturn bool, args []stackitem.Item, isDynamic bool) error { md := cs.Manifest.ABI.GetMethod(name, len(args)) if md.Safe { f &^= (callflag.WriteStates | callflag.AllowNotify) @@ -83,12 +84,12 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl } } } - return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, pushNullOnUnloading, false) + return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, isDynamic, false) } // callExFromNative calls a contract with flags using the provided calling hash. func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, - name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, pushNullOnUnloading bool, callFromNative bool) error { + name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, isDynamic bool, callFromNative bool) error { for _, nc := range ic.Natives { if nc.Metadata().Name == nativenames.Policy { var pch = nc.(policyChecker) @@ -123,7 +124,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra if wrapped { ic.DAO = ic.DAO.GetPrivate() } - onUnload := func(commit bool) error { + onUnload := func(ctx *vm.Context, commit bool) error { if wrapped { if commit { _, err := ic.DAO.Persist() @@ -135,8 +136,13 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra } ic.DAO = baseDAO } - if pushNullOnUnloading && commit { - ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack. + if isDynamic && commit { + eLen := ctx.Estack().Len() + if eLen == 0 && ctx.NumOfReturnVals() == 0 { // No return value and none expected. + ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack. + } else if eLen > 1 { // 1 or -1 (all) retrun values expected, but only one can be returned. + return errors.New("multiple return values in a cross-contract call") + } // All other rvcount/stack length mismatches are checked by the VM. } if callFromNative && !commit { return fmt.Errorf("unhandled exception") diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 31c4280f7..6c086249f 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -73,7 +73,7 @@ type Context struct { } // ContextUnloadCallback is a callback method used on context unloading from istack. -type ContextUnloadCallback func(commit bool) error +type ContextUnloadCallback func(ctx *Context, commit bool) error var errNoInstParam = errors.New("failed to read instruction parameter") @@ -239,6 +239,11 @@ func (c *Context) GetNEF() *nef.File { return c.sc.NEF } +// NumOfReturnVals returns the number of return values expected from this context. +func (c *Context) NumOfReturnVals() int { + return c.retCount +} + // Value implements the stackitem.Item interface. func (c *Context) Value() interface{} { return c diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 378a898b2..44e3f50d5 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -1634,7 +1634,7 @@ func (v *VM) unloadContext(ctx *Context) { ctx.sc.static.ClearRefs(&v.refs) } if ctx.sc.onUnload != nil { - err := ctx.sc.onUnload(v.uncaughtException == nil) + err := ctx.sc.onUnload(ctx, v.uncaughtException == nil) if err != nil { errMessage := fmt.Sprintf("context unload callback failed: %s", err) if v.uncaughtException != nil {