From e8d2277fe56201a00d19e4cfae807f1d6032d0bb Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Thu, 4 Aug 2022 18:17:32 +0300 Subject: [PATCH] contract/vm: only push NULL after call in dynamic contexts And determine the need for Null dynamically. For some reason the only dynamic context is Contract.Call. CALLT is not dynamic and neither is a call from native contract, go figure... --- pkg/core/interop/contract/call.go | 20 +++++++++++++------- pkg/vm/context.go | 7 ++++++- pkg/vm/vm.go | 2 +- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 279bbb5e3..aa5989c6d 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -67,11 +68,11 @@ func Call(ic *interop.Context) error { return fmt.Errorf("method not found: %s/%d", method, len(args)) } hasReturn := md.ReturnType != smartcontract.VoidType - return callInternal(ic, cs, method, fs, hasReturn, args, !hasReturn) + return callInternal(ic, cs, method, fs, hasReturn, args, true) } func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, - hasReturn bool, args []stackitem.Item, pushNullOnUnloading bool) error { + hasReturn bool, args []stackitem.Item, isDynamic bool) error { md := cs.Manifest.ABI.GetMethod(name, len(args)) if md.Safe { f &^= (callflag.WriteStates | callflag.AllowNotify) @@ -83,12 +84,12 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl } } } - return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, pushNullOnUnloading, false) + return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, isDynamic, false) } // callExFromNative calls a contract with flags using the provided calling hash. func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, - name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, pushNullOnUnloading bool, callFromNative bool) error { + name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, isDynamic bool, callFromNative bool) error { for _, nc := range ic.Natives { if nc.Metadata().Name == nativenames.Policy { var pch = nc.(policyChecker) @@ -123,7 +124,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra if wrapped { ic.DAO = ic.DAO.GetPrivate() } - onUnload := func(commit bool) error { + onUnload := func(ctx *vm.Context, commit bool) error { if wrapped { if commit { _, err := ic.DAO.Persist() @@ -135,8 +136,13 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra } ic.DAO = baseDAO } - if pushNullOnUnloading && commit { - ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack. + if isDynamic && commit { + eLen := ctx.Estack().Len() + if eLen == 0 && ctx.NumOfReturnVals() == 0 { // No return value and none expected. + ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack. + } else if eLen > 1 { // 1 or -1 (all) retrun values expected, but only one can be returned. + return errors.New("multiple return values in a cross-contract call") + } // All other rvcount/stack length mismatches are checked by the VM. } if callFromNative && !commit { return fmt.Errorf("unhandled exception") diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 31c4280f7..6c086249f 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -73,7 +73,7 @@ type Context struct { } // ContextUnloadCallback is a callback method used on context unloading from istack. -type ContextUnloadCallback func(commit bool) error +type ContextUnloadCallback func(ctx *Context, commit bool) error var errNoInstParam = errors.New("failed to read instruction parameter") @@ -239,6 +239,11 @@ func (c *Context) GetNEF() *nef.File { return c.sc.NEF } +// NumOfReturnVals returns the number of return values expected from this context. +func (c *Context) NumOfReturnVals() int { + return c.retCount +} + // Value implements the stackitem.Item interface. func (c *Context) Value() interface{} { return c diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 378a898b2..44e3f50d5 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -1634,7 +1634,7 @@ func (v *VM) unloadContext(ctx *Context) { ctx.sc.static.ClearRefs(&v.refs) } if ctx.sc.onUnload != nil { - err := ctx.sc.onUnload(v.uncaughtException == nil) + err := ctx.sc.onUnload(ctx, v.uncaughtException == nil) if err != nil { errMessage := fmt.Sprintf("context unload callback failed: %s", err) if v.uncaughtException != nil {