From 08b68e9b8200e195038ba9827c04614fac41c46c Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Mon, 23 May 2022 11:35:01 +0300 Subject: [PATCH] vm, core: push Null return value only if no exception occurs Close https://github.com/nspcc-dev/neo-go/issues/2509. --- pkg/core/blockchain.go | 2 +- pkg/core/interop/contract/call.go | 19 +-- pkg/core/interop_system_neotest_test.go | 147 ++++++++++++++++++++++++ pkg/vm/context.go | 6 + pkg/vm/vm.go | 18 ++- 5 files changed, 178 insertions(+), 14 deletions(-) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index bdd411d55..a1616389b 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -2247,7 +2247,7 @@ func (bc *Blockchain) InitVerificationContext(ic *interop.Context, hash util.Uin } ic.Invocations[cs.Hash]++ ic.VM.LoadNEFMethod(&cs.NEF, util.Uint160{}, hash, callflag.ReadOnly, - true, verifyOffset, initOffset) + true, verifyOffset, initOffset, nil) } if len(witness.InvocationScript) != 0 { err := vm.IsScriptCorrect(witness.InvocationScript, nil) diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 5627863ef..c2f8feb9b 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -40,7 +40,7 @@ func LoadToken(ic *interop.Context) func(id int32) error { if err != nil { return fmt.Errorf("token contract %s not found: %w", tok.Hash.StringLE(), err) } - return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args) + return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, nil) } } @@ -69,14 +69,17 @@ func Call(ic *interop.Context) error { return fmt.Errorf("method not found: %s/%d", method, len(args)) } hasReturn := md.ReturnType != smartcontract.VoidType + var cb vm.ContextUnloadCallback if !hasReturn { - ic.VM.Estack().PushItem(stackitem.Null{}) + cb = func(estack *vm.Stack) { + estack.PushItem(stackitem.Null{}) + } } - return callInternal(ic, cs, method, fs, hasReturn, args) + return callInternal(ic, cs, method, fs, hasReturn, args, cb) } func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, - hasReturn bool, args []stackitem.Item) error { + hasReturn bool, args []stackitem.Item, cb vm.ContextUnloadCallback) error { md := cs.Manifest.ABI.GetMethod(name, len(args)) if md.Safe { f &^= (callflag.WriteStates | callflag.AllowNotify) @@ -88,12 +91,12 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl } } } - return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn) + return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, cb) } // callExFromNative calls a contract with flags using the provided calling hash. func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, - name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error { + name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, cb vm.ContextUnloadCallback) error { for _, nc := range ic.Natives { if nc.Metadata().Name == nativenames.Policy { var pch = nc.(policyChecker) @@ -120,7 +123,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra } ic.Invocations[cs.Hash]++ ic.VM.LoadNEFMethod(&cs.NEF, caller, cs.Hash, ic.VM.Context().GetCallFlags()&f, - hasReturn, methodOff, initOff) + hasReturn, methodOff, initOff, cb) for e, i := ic.VM.Estack(), len(args)-1; i >= 0; i-- { e.PushItem(args[i]) @@ -134,7 +137,7 @@ var ErrNativeCall = errors.New("failed native call") // CallFromNative performs synchronous call from native contract. func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, hasReturn bool) error { startSize := ic.VM.Istack().Len() - if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn); err != nil { + if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, nil); err != nil { return err } diff --git a/pkg/core/interop_system_neotest_test.go b/pkg/core/interop_system_neotest_test.go index ab3820a95..1537c225c 100644 --- a/pkg/core/interop_system_neotest_test.go +++ b/pkg/core/interop_system_neotest_test.go @@ -507,3 +507,150 @@ func TestSnapshotIsolation_CallToItself(t *testing.T) { // unwrapped and persisted during the previous call. ctrInvoker.Invoke(t, stackitem.Null{}, "check") } + +// This test is written to check https://github.com/nspcc-dev/neo-go/issues/2509 +// and https://github.com/neo-project/neo/pull/2745#discussion_r879167180. +func TestRET_after_FINALLY_PanicInsideVoidMethod(t *testing.T) { + bc, acc := chain.NewSingle(t) + e := neotest.NewExecutor(t, bc, acc, acc) + + // Contract A throws catchable exception. It also has a non-void method. + srcA := `package contractA + func Panic() { + panic("panic from A") + } + func ReturnSomeValue() int { + return 5 + }` + ctrA := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcA), &compiler.Options{ + NoEventsCheck: true, + NoPermissionsCheck: true, + Name: "contractA", + }) + e.DeployContract(t, ctrA, nil) + + var hashAStr string + for i := 0; i < util.Uint160Size; i++ { + hashAStr += fmt.Sprintf("%#x", ctrA.Hash[i]) + if i != util.Uint160Size-1 { + hashAStr += ", " + } + } + // Contract B calls A and catches the exception thrown by A. + srcB := `package contractB + import ( + "github.com/nspcc-dev/neo-go/pkg/interop" + "github.com/nspcc-dev/neo-go/pkg/interop/contract" + ) + func Catch() { + defer func() { + if r := recover(); r != nil { + // Call method with return value to check https://github.com/neo-project/neo/pull/2745#discussion_r879167180. + contract.Call(interop.Hash160{` + hashAStr + `}, "returnSomeValue", contract.All) + } + }() + contract.Call(interop.Hash160{` + hashAStr + `}, "panic", contract.All) + }` + ctrB := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcB), &compiler.Options{ + Name: "contractB", + NoEventsCheck: true, + NoPermissionsCheck: true, + Permissions: []manifest.Permission{ + { + Methods: manifest.WildStrings{Value: nil}, + }, + }, + }) + e.DeployContract(t, ctrB, nil) + + ctrInvoker := e.NewInvoker(ctrB.Hash, e.Committee) + ctrInvoker.Invoke(t, stackitem.Null{}, "catch") +} + +// This test is written to check https://github.com/neo-project/neo/pull/2745#discussion_r879125733. +func TestRET_after_FINALLY_CallNonVoidAfterVoidMethod(t *testing.T) { + bc, acc := chain.NewSingle(t) + e := neotest.NewExecutor(t, bc, acc, acc) + + // Contract A has two methods. One of them has no return value, and the other has it. + srcA := `package contractA + import "github.com/nspcc-dev/neo-go/pkg/interop/runtime" + func NoRet() { + runtime.Notify("no ret") + } + func HasRet() int { + runtime.Notify("ret") + return 5 + }` + ctrA := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcA), &compiler.Options{ + NoEventsCheck: true, + NoPermissionsCheck: true, + Name: "contractA", + }) + e.DeployContract(t, ctrA, nil) + + var hashAStr string + for i := 0; i < util.Uint160Size; i++ { + hashAStr += fmt.Sprintf("%#x", ctrA.Hash[i]) + if i != util.Uint160Size-1 { + hashAStr += ", " + } + } + // Contract B calls A in try-catch block. + srcB := `package contractB + import ( + "github.com/nspcc-dev/neo-go/pkg/interop" + "github.com/nspcc-dev/neo-go/pkg/interop/contract" + "github.com/nspcc-dev/neo-go/pkg/interop/util" + ) + func CallAInTryCatch() { + defer func() { + if r := recover(); r != nil { + util.Abort() // should never happen + } + }() + contract.Call(interop.Hash160{` + hashAStr + `}, "noRet", contract.All) + contract.Call(interop.Hash160{` + hashAStr + `}, "hasRet", contract.All) + }` + ctrB := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcB), &compiler.Options{ + Name: "contractB", + NoEventsCheck: true, + NoPermissionsCheck: true, + Permissions: []manifest.Permission{ + { + Methods: manifest.WildStrings{Value: nil}, + }, + }, + }) + e.DeployContract(t, ctrB, nil) + + ctrInvoker := e.NewInvoker(ctrB.Hash, e.Committee) + h := ctrInvoker.Invoke(t, stackitem.Null{}, "callAInTryCatch") + aer := e.GetTxExecResult(t, h) + + require.Equal(t, 1, len(aer.Stack)) +} + +// This test is created to check https://github.com/neo-project/neo/pull/2755#discussion_r880087983. +func TestCALLL_from_VoidContext(t *testing.T) { + bc, acc := chain.NewSingle(t) + e := neotest.NewExecutor(t, bc, acc, acc) + + // Contract A has void method `CallHasRet` which calls non-void method `HasRet`. + srcA := `package contractA + func CallHasRet() { // Creates a context with non-nil onUnload. + HasRet() + } + func HasRet() int { // CALL_L clones parent context, check that onUnload is not cloned. + return 5 + }` + ctrA := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcA), &compiler.Options{ + NoEventsCheck: true, + NoPermissionsCheck: true, + Name: "contractA", + }) + e.DeployContract(t, ctrA, nil) + + ctrInvoker := e.NewInvoker(ctrA.Hash, e.Committee) + ctrInvoker.Invoke(t, stackitem.Null{}, "callHasRet") +} diff --git a/pkg/vm/context.go b/pkg/vm/context.go index ebf392c4e..752a0d1a8 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -63,8 +63,14 @@ type Context struct { // isWrapped tells whether the context's DAO was wrapped into another layer of // MemCachedStore on creation and whether it should be unwrapped on context unloading. isWrapped bool + // onUnload is a callback that should be called after current context unloading + // if no exception occurs. + onUnload ContextUnloadCallback } +// ContextUnloadCallback is a callback method used on context unloading from istack. +type ContextUnloadCallback func(parentEstack *Stack) + var errNoInstParam = errors.New("failed to read instruction parameter") // NewContext returns a new Context object. diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 36b289bd0..11f72ca77 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -311,7 +311,7 @@ func (v *VM) LoadScript(b []byte) { // LoadScriptWithFlags loads script and sets call flag to f. func (v *VM) LoadScriptWithFlags(b []byte, f callflag.CallFlag) { - v.loadScriptWithCallingHash(b, nil, v.GetCurrentScriptHash(), util.Uint160{}, f, -1, 0) + v.loadScriptWithCallingHash(b, nil, v.GetCurrentScriptHash(), util.Uint160{}, f, -1, 0, nil) } // LoadScriptWithHash is similar to the LoadScriptWithFlags method, but it also loads @@ -321,19 +321,19 @@ func (v *VM) LoadScriptWithFlags(b []byte, f callflag.CallFlag) { // accordingly). It's up to the user of this function to make sure the script and hash match // each other. func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f callflag.CallFlag) { - v.loadScriptWithCallingHash(b, nil, v.GetCurrentScriptHash(), hash, f, 1, 0) + v.loadScriptWithCallingHash(b, nil, v.GetCurrentScriptHash(), hash, f, 1, 0, nil) } // LoadNEFMethod allows to create a context to execute a method from the NEF // file with the specified caller and executing hash, call flags, return value, // method and _initialize offsets. func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160, f callflag.CallFlag, - hasReturn bool, methodOff int, initOff int) { + hasReturn bool, methodOff int, initOff int, onContextUnload ContextUnloadCallback) { var rvcount int if hasReturn { rvcount = 1 } - v.loadScriptWithCallingHash(exe.Script, exe, caller, hash, f, rvcount, methodOff) + v.loadScriptWithCallingHash(exe.Script, exe, caller, hash, f, rvcount, methodOff, onContextUnload) if initOff >= 0 { v.Call(initOff) } @@ -342,7 +342,7 @@ func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160 // loadScriptWithCallingHash is similar to LoadScriptWithHash but sets calling hash explicitly. // It should be used for calling from native contracts. func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160, - hash util.Uint160, f callflag.CallFlag, rvcount int, offset int) { + hash util.Uint160, f callflag.CallFlag, rvcount int, offset int, onContextUnload ContextUnloadCallback) { var sl slot v.checkInvocationStackSize() @@ -384,6 +384,7 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint } } ctx.persistNotificationsCountOnUnloading = true + ctx.onUnload = onContextUnload v.istack.PushItem(ctx) } @@ -1651,6 +1652,12 @@ func (v *VM) unloadContext(ctx *Context) { if currCtx != nil && ctx.persistNotificationsCountOnUnloading && !(ctx.isWrapped && v.uncaughtException != nil) { *currCtx.notificationsCount += *ctx.notificationsCount } + if currCtx != nil && ctx.onUnload != nil { + if v.uncaughtException == nil { + ctx.onUnload(currCtx.Estack()) // Use the estack of current context. + } + ctx.onUnload = nil + } } // getTryParams splits TRY(L) instruction parameter into offsets for catch and finally blocks. @@ -1713,6 +1720,7 @@ func (v *VM) call(ctx *Context, offset int) { newCtx.notificationsCount = ctx.notificationsCount newCtx.isWrapped = false newCtx.persistNotificationsCountOnUnloading = false + newCtx.onUnload = nil v.istack.PushItem(newCtx) newCtx.Jump(offset) }