vm, core: push Null return value only if no exception occurs

Close https://github.com/nspcc-dev/neo-go/issues/2509.
This commit is contained in:
Anna Shaleva 2022-05-23 11:35:01 +03:00
parent ce226f6b76
commit 08b68e9b82
5 changed files with 178 additions and 14 deletions

View file

@ -2247,7 +2247,7 @@ func (bc *Blockchain) InitVerificationContext(ic *interop.Context, hash util.Uin
} }
ic.Invocations[cs.Hash]++ ic.Invocations[cs.Hash]++
ic.VM.LoadNEFMethod(&cs.NEF, util.Uint160{}, hash, callflag.ReadOnly, ic.VM.LoadNEFMethod(&cs.NEF, util.Uint160{}, hash, callflag.ReadOnly,
true, verifyOffset, initOffset) true, verifyOffset, initOffset, nil)
} }
if len(witness.InvocationScript) != 0 { if len(witness.InvocationScript) != 0 {
err := vm.IsScriptCorrect(witness.InvocationScript, nil) err := vm.IsScriptCorrect(witness.InvocationScript, nil)

View file

@ -40,7 +40,7 @@ func LoadToken(ic *interop.Context) func(id int32) error {
if err != nil { if err != nil {
return fmt.Errorf("token contract %s not found: %w", tok.Hash.StringLE(), err) return fmt.Errorf("token contract %s not found: %w", tok.Hash.StringLE(), err)
} }
return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args) return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, nil)
} }
} }
@ -69,14 +69,17 @@ func Call(ic *interop.Context) error {
return fmt.Errorf("method not found: %s/%d", method, len(args)) return fmt.Errorf("method not found: %s/%d", method, len(args))
} }
hasReturn := md.ReturnType != smartcontract.VoidType hasReturn := md.ReturnType != smartcontract.VoidType
var cb vm.ContextUnloadCallback
if !hasReturn { if !hasReturn {
ic.VM.Estack().PushItem(stackitem.Null{}) cb = func(estack *vm.Stack) {
estack.PushItem(stackitem.Null{})
}
} }
return callInternal(ic, cs, method, fs, hasReturn, args) return callInternal(ic, cs, method, fs, hasReturn, args, cb)
} }
func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
hasReturn bool, args []stackitem.Item) error { hasReturn bool, args []stackitem.Item, cb vm.ContextUnloadCallback) error {
md := cs.Manifest.ABI.GetMethod(name, len(args)) md := cs.Manifest.ABI.GetMethod(name, len(args))
if md.Safe { if md.Safe {
f &^= (callflag.WriteStates | callflag.AllowNotify) f &^= (callflag.WriteStates | callflag.AllowNotify)
@ -88,12 +91,12 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl
} }
} }
} }
return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn) return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, cb)
} }
// callExFromNative calls a contract with flags using the provided calling hash. // callExFromNative calls a contract with flags using the provided calling hash.
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error { name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, cb vm.ContextUnloadCallback) error {
for _, nc := range ic.Natives { for _, nc := range ic.Natives {
if nc.Metadata().Name == nativenames.Policy { if nc.Metadata().Name == nativenames.Policy {
var pch = nc.(policyChecker) var pch = nc.(policyChecker)
@ -120,7 +123,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra
} }
ic.Invocations[cs.Hash]++ ic.Invocations[cs.Hash]++
ic.VM.LoadNEFMethod(&cs.NEF, caller, cs.Hash, ic.VM.Context().GetCallFlags()&f, ic.VM.LoadNEFMethod(&cs.NEF, caller, cs.Hash, ic.VM.Context().GetCallFlags()&f,
hasReturn, methodOff, initOff) hasReturn, methodOff, initOff, cb)
for e, i := ic.VM.Estack(), len(args)-1; i >= 0; i-- { for e, i := ic.VM.Estack(), len(args)-1; i >= 0; i-- {
e.PushItem(args[i]) e.PushItem(args[i])
@ -134,7 +137,7 @@ var ErrNativeCall = errors.New("failed native call")
// CallFromNative performs synchronous call from native contract. // CallFromNative performs synchronous call from native contract.
func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, hasReturn bool) error { func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, hasReturn bool) error {
startSize := ic.VM.Istack().Len() startSize := ic.VM.Istack().Len()
if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn); err != nil { if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, nil); err != nil {
return err return err
} }

View file

@ -507,3 +507,150 @@ func TestSnapshotIsolation_CallToItself(t *testing.T) {
// unwrapped and persisted during the previous call. // unwrapped and persisted during the previous call.
ctrInvoker.Invoke(t, stackitem.Null{}, "check") ctrInvoker.Invoke(t, stackitem.Null{}, "check")
} }
// This test is written to check https://github.com/nspcc-dev/neo-go/issues/2509
// and https://github.com/neo-project/neo/pull/2745#discussion_r879167180.
func TestRET_after_FINALLY_PanicInsideVoidMethod(t *testing.T) {
bc, acc := chain.NewSingle(t)
e := neotest.NewExecutor(t, bc, acc, acc)
// Contract A throws catchable exception. It also has a non-void method.
srcA := `package contractA
func Panic() {
panic("panic from A")
}
func ReturnSomeValue() int {
return 5
}`
ctrA := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcA), &compiler.Options{
NoEventsCheck: true,
NoPermissionsCheck: true,
Name: "contractA",
})
e.DeployContract(t, ctrA, nil)
var hashAStr string
for i := 0; i < util.Uint160Size; i++ {
hashAStr += fmt.Sprintf("%#x", ctrA.Hash[i])
if i != util.Uint160Size-1 {
hashAStr += ", "
}
}
// Contract B calls A and catches the exception thrown by A.
srcB := `package contractB
import (
"github.com/nspcc-dev/neo-go/pkg/interop"
"github.com/nspcc-dev/neo-go/pkg/interop/contract"
)
func Catch() {
defer func() {
if r := recover(); r != nil {
// Call method with return value to check https://github.com/neo-project/neo/pull/2745#discussion_r879167180.
contract.Call(interop.Hash160{` + hashAStr + `}, "returnSomeValue", contract.All)
}
}()
contract.Call(interop.Hash160{` + hashAStr + `}, "panic", contract.All)
}`
ctrB := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcB), &compiler.Options{
Name: "contractB",
NoEventsCheck: true,
NoPermissionsCheck: true,
Permissions: []manifest.Permission{
{
Methods: manifest.WildStrings{Value: nil},
},
},
})
e.DeployContract(t, ctrB, nil)
ctrInvoker := e.NewInvoker(ctrB.Hash, e.Committee)
ctrInvoker.Invoke(t, stackitem.Null{}, "catch")
}
// This test is written to check https://github.com/neo-project/neo/pull/2745#discussion_r879125733.
func TestRET_after_FINALLY_CallNonVoidAfterVoidMethod(t *testing.T) {
bc, acc := chain.NewSingle(t)
e := neotest.NewExecutor(t, bc, acc, acc)
// Contract A has two methods. One of them has no return value, and the other has it.
srcA := `package contractA
import "github.com/nspcc-dev/neo-go/pkg/interop/runtime"
func NoRet() {
runtime.Notify("no ret")
}
func HasRet() int {
runtime.Notify("ret")
return 5
}`
ctrA := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcA), &compiler.Options{
NoEventsCheck: true,
NoPermissionsCheck: true,
Name: "contractA",
})
e.DeployContract(t, ctrA, nil)
var hashAStr string
for i := 0; i < util.Uint160Size; i++ {
hashAStr += fmt.Sprintf("%#x", ctrA.Hash[i])
if i != util.Uint160Size-1 {
hashAStr += ", "
}
}
// Contract B calls A in try-catch block.
srcB := `package contractB
import (
"github.com/nspcc-dev/neo-go/pkg/interop"
"github.com/nspcc-dev/neo-go/pkg/interop/contract"
"github.com/nspcc-dev/neo-go/pkg/interop/util"
)
func CallAInTryCatch() {
defer func() {
if r := recover(); r != nil {
util.Abort() // should never happen
}
}()
contract.Call(interop.Hash160{` + hashAStr + `}, "noRet", contract.All)
contract.Call(interop.Hash160{` + hashAStr + `}, "hasRet", contract.All)
}`
ctrB := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcB), &compiler.Options{
Name: "contractB",
NoEventsCheck: true,
NoPermissionsCheck: true,
Permissions: []manifest.Permission{
{
Methods: manifest.WildStrings{Value: nil},
},
},
})
e.DeployContract(t, ctrB, nil)
ctrInvoker := e.NewInvoker(ctrB.Hash, e.Committee)
h := ctrInvoker.Invoke(t, stackitem.Null{}, "callAInTryCatch")
aer := e.GetTxExecResult(t, h)
require.Equal(t, 1, len(aer.Stack))
}
// This test is created to check https://github.com/neo-project/neo/pull/2755#discussion_r880087983.
func TestCALLL_from_VoidContext(t *testing.T) {
bc, acc := chain.NewSingle(t)
e := neotest.NewExecutor(t, bc, acc, acc)
// Contract A has void method `CallHasRet` which calls non-void method `HasRet`.
srcA := `package contractA
func CallHasRet() { // Creates a context with non-nil onUnload.
HasRet()
}
func HasRet() int { // CALL_L clones parent context, check that onUnload is not cloned.
return 5
}`
ctrA := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcA), &compiler.Options{
NoEventsCheck: true,
NoPermissionsCheck: true,
Name: "contractA",
})
e.DeployContract(t, ctrA, nil)
ctrInvoker := e.NewInvoker(ctrA.Hash, e.Committee)
ctrInvoker.Invoke(t, stackitem.Null{}, "callHasRet")
}

View file

@ -63,8 +63,14 @@ type Context struct {
// isWrapped tells whether the context's DAO was wrapped into another layer of // isWrapped tells whether the context's DAO was wrapped into another layer of
// MemCachedStore on creation and whether it should be unwrapped on context unloading. // MemCachedStore on creation and whether it should be unwrapped on context unloading.
isWrapped bool isWrapped bool
// onUnload is a callback that should be called after current context unloading
// if no exception occurs.
onUnload ContextUnloadCallback
} }
// ContextUnloadCallback is a callback method used on context unloading from istack.
type ContextUnloadCallback func(parentEstack *Stack)
var errNoInstParam = errors.New("failed to read instruction parameter") var errNoInstParam = errors.New("failed to read instruction parameter")
// NewContext returns a new Context object. // NewContext returns a new Context object.

View file

@ -311,7 +311,7 @@ func (v *VM) LoadScript(b []byte) {
// LoadScriptWithFlags loads script and sets call flag to f. // LoadScriptWithFlags loads script and sets call flag to f.
func (v *VM) LoadScriptWithFlags(b []byte, f callflag.CallFlag) { func (v *VM) LoadScriptWithFlags(b []byte, f callflag.CallFlag) {
v.loadScriptWithCallingHash(b, nil, v.GetCurrentScriptHash(), util.Uint160{}, f, -1, 0) v.loadScriptWithCallingHash(b, nil, v.GetCurrentScriptHash(), util.Uint160{}, f, -1, 0, nil)
} }
// LoadScriptWithHash is similar to the LoadScriptWithFlags method, but it also loads // LoadScriptWithHash is similar to the LoadScriptWithFlags method, but it also loads
@ -321,19 +321,19 @@ func (v *VM) LoadScriptWithFlags(b []byte, f callflag.CallFlag) {
// accordingly). It's up to the user of this function to make sure the script and hash match // accordingly). It's up to the user of this function to make sure the script and hash match
// each other. // each other.
func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f callflag.CallFlag) { func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f callflag.CallFlag) {
v.loadScriptWithCallingHash(b, nil, v.GetCurrentScriptHash(), hash, f, 1, 0) v.loadScriptWithCallingHash(b, nil, v.GetCurrentScriptHash(), hash, f, 1, 0, nil)
} }
// LoadNEFMethod allows to create a context to execute a method from the NEF // LoadNEFMethod allows to create a context to execute a method from the NEF
// file with the specified caller and executing hash, call flags, return value, // file with the specified caller and executing hash, call flags, return value,
// method and _initialize offsets. // method and _initialize offsets.
func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160, f callflag.CallFlag, func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160, f callflag.CallFlag,
hasReturn bool, methodOff int, initOff int) { hasReturn bool, methodOff int, initOff int, onContextUnload ContextUnloadCallback) {
var rvcount int var rvcount int
if hasReturn { if hasReturn {
rvcount = 1 rvcount = 1
} }
v.loadScriptWithCallingHash(exe.Script, exe, caller, hash, f, rvcount, methodOff) v.loadScriptWithCallingHash(exe.Script, exe, caller, hash, f, rvcount, methodOff, onContextUnload)
if initOff >= 0 { if initOff >= 0 {
v.Call(initOff) v.Call(initOff)
} }
@ -342,7 +342,7 @@ func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160
// loadScriptWithCallingHash is similar to LoadScriptWithHash but sets calling hash explicitly. // loadScriptWithCallingHash is similar to LoadScriptWithHash but sets calling hash explicitly.
// It should be used for calling from native contracts. // It should be used for calling from native contracts.
func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160, func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160,
hash util.Uint160, f callflag.CallFlag, rvcount int, offset int) { hash util.Uint160, f callflag.CallFlag, rvcount int, offset int, onContextUnload ContextUnloadCallback) {
var sl slot var sl slot
v.checkInvocationStackSize() v.checkInvocationStackSize()
@ -384,6 +384,7 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint
} }
} }
ctx.persistNotificationsCountOnUnloading = true ctx.persistNotificationsCountOnUnloading = true
ctx.onUnload = onContextUnload
v.istack.PushItem(ctx) v.istack.PushItem(ctx)
} }
@ -1651,6 +1652,12 @@ func (v *VM) unloadContext(ctx *Context) {
if currCtx != nil && ctx.persistNotificationsCountOnUnloading && !(ctx.isWrapped && v.uncaughtException != nil) { if currCtx != nil && ctx.persistNotificationsCountOnUnloading && !(ctx.isWrapped && v.uncaughtException != nil) {
*currCtx.notificationsCount += *ctx.notificationsCount *currCtx.notificationsCount += *ctx.notificationsCount
} }
if currCtx != nil && ctx.onUnload != nil {
if v.uncaughtException == nil {
ctx.onUnload(currCtx.Estack()) // Use the estack of current context.
}
ctx.onUnload = nil
}
} }
// getTryParams splits TRY(L) instruction parameter into offsets for catch and finally blocks. // getTryParams splits TRY(L) instruction parameter into offsets for catch and finally blocks.
@ -1713,6 +1720,7 @@ func (v *VM) call(ctx *Context, offset int) {
newCtx.notificationsCount = ctx.notificationsCount newCtx.notificationsCount = ctx.notificationsCount
newCtx.isWrapped = false newCtx.isWrapped = false
newCtx.persistNotificationsCountOnUnloading = false newCtx.persistNotificationsCountOnUnloading = false
newCtx.onUnload = nil
v.istack.PushItem(newCtx) v.istack.PushItem(newCtx)
newCtx.Jump(offset) newCtx.Jump(offset)
} }