core, vm: move all isolation-related logic out of VM
Keep it inside the interop context.
This commit is contained in:
parent
f79f62dab4
commit
a39b7cc3fd
6 changed files with 109 additions and 126 deletions
|
@ -89,11 +89,6 @@ func (dao *Simple) GetWrapped() *Simple {
|
|||
return d
|
||||
}
|
||||
|
||||
// GetUnwrapped returns the underlying DAO. It does not perform changes persist.
|
||||
func (dao *Simple) GetUnwrapped() *Simple {
|
||||
return dao.nativeCachePS
|
||||
}
|
||||
|
||||
// GetPrivate returns a new DAO instance with another layer of private
|
||||
// MemCachedStore around the current DAO Store.
|
||||
func (dao *Simple) GetPrivate() *Simple {
|
||||
|
|
|
@ -318,33 +318,6 @@ func (ic *Context) SpawnVM() *vm.VM {
|
|||
v := vm.NewWithTrigger(ic.Trigger)
|
||||
v.GasLimit = -1
|
||||
v.SyscallHandler = ic.SyscallHandler
|
||||
wrapper := func() {
|
||||
if ic.DAO == nil {
|
||||
return
|
||||
}
|
||||
ic.DAO = ic.DAO.GetPrivate()
|
||||
}
|
||||
unwrapper := func(commit bool, ntfToRemove int) error {
|
||||
if !commit {
|
||||
have := len(ic.Notifications)
|
||||
if have < ntfToRemove {
|
||||
panic(fmt.Errorf("inconsistent notifications count: should remove %d, have %d", ntfToRemove, len(ic.Notifications)))
|
||||
}
|
||||
ic.Notifications = ic.Notifications[:have-ntfToRemove]
|
||||
}
|
||||
if ic.DAO == nil {
|
||||
return nil
|
||||
}
|
||||
if commit {
|
||||
_, err := ic.DAO.Persist()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to persist changes %w", err)
|
||||
}
|
||||
}
|
||||
ic.DAO = ic.DAO.GetUnwrapped()
|
||||
return nil
|
||||
}
|
||||
v.SetIsolationCallbacks(wrapper, unwrapper)
|
||||
ic.VM = v
|
||||
return v
|
||||
}
|
||||
|
@ -415,5 +388,4 @@ func (ic *Context) AddNotification(hash util.Uint160, name string, item *stackit
|
|||
Name: name,
|
||||
Item: item,
|
||||
})
|
||||
ic.VM.EmitNotification()
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func LoadToken(ic *interop.Context) func(id int32) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("token contract %s not found: %w", tok.Hash.StringLE(), err)
|
||||
}
|
||||
return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, nil)
|
||||
return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, false)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -69,17 +69,11 @@ func Call(ic *interop.Context) error {
|
|||
return fmt.Errorf("method not found: %s/%d", method, len(args))
|
||||
}
|
||||
hasReturn := md.ReturnType != smartcontract.VoidType
|
||||
var cb vm.ContextUnloadCallback
|
||||
if !hasReturn {
|
||||
cb = func(estack *vm.Stack) {
|
||||
estack.PushItem(stackitem.Null{})
|
||||
}
|
||||
}
|
||||
return callInternal(ic, cs, method, fs, hasReturn, args, cb)
|
||||
return callInternal(ic, cs, method, fs, hasReturn, args, !hasReturn)
|
||||
}
|
||||
|
||||
func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
|
||||
hasReturn bool, args []stackitem.Item, cb vm.ContextUnloadCallback) error {
|
||||
hasReturn bool, args []stackitem.Item, pushNullOnUnloading bool) error {
|
||||
md := cs.Manifest.ABI.GetMethod(name, len(args))
|
||||
if md.Safe {
|
||||
f &^= (callflag.WriteStates | callflag.AllowNotify)
|
||||
|
@ -91,12 +85,12 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl
|
|||
}
|
||||
}
|
||||
}
|
||||
return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, cb)
|
||||
return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, pushNullOnUnloading)
|
||||
}
|
||||
|
||||
// callExFromNative calls a contract with flags using the provided calling hash.
|
||||
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
|
||||
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, cb vm.ContextUnloadCallback) error {
|
||||
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, pushNullOnUnloading bool) error {
|
||||
for _, nc := range ic.Natives {
|
||||
if nc.Metadata().Name == nativenames.Policy {
|
||||
var pch = nc.(policyChecker)
|
||||
|
@ -122,8 +116,34 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra
|
|||
initOff = md.Offset
|
||||
}
|
||||
ic.Invocations[cs.Hash]++
|
||||
ic.VM.LoadNEFMethod(&cs.NEF, caller, cs.Hash, ic.VM.Context().GetCallFlags()&f,
|
||||
hasReturn, methodOff, initOff, cb)
|
||||
f = ic.VM.Context().GetCallFlags() & f
|
||||
|
||||
wrapped := f&(callflag.All^callflag.ReadOnly) != 0 || // If the method is safe, then it's read-only and doesn't perform storage changes or emit notifications.
|
||||
ic.VM.Context().HasTryBlock() // If the method is not wrapped into try-catch block, then changes should be discarded anyway if exception occurs.
|
||||
baseNtfCount := len(ic.Notifications)
|
||||
baseDAO := ic.DAO
|
||||
if wrapped {
|
||||
ic.DAO = ic.DAO.GetPrivate()
|
||||
}
|
||||
onUnload := func(commit bool) error {
|
||||
if wrapped {
|
||||
if commit {
|
||||
_, err := ic.DAO.Persist()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to persist changes %w", err)
|
||||
}
|
||||
} else {
|
||||
ic.Notifications = ic.Notifications[:baseNtfCount] // Rollback all notification changes made by current context.
|
||||
}
|
||||
ic.DAO = baseDAO
|
||||
}
|
||||
if pushNullOnUnloading && commit {
|
||||
ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ic.VM.LoadNEFMethod(&cs.NEF, caller, cs.Hash, f,
|
||||
hasReturn, methodOff, initOff, onUnload)
|
||||
|
||||
for e, i := ic.VM.Estack(), len(args)-1; i >= 0; i-- {
|
||||
e.PushItem(args[i])
|
||||
|
@ -137,7 +157,7 @@ var ErrNativeCall = errors.New("failed native call")
|
|||
// CallFromNative performs synchronous call from native contract.
|
||||
func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, hasReturn bool) error {
|
||||
startSize := ic.VM.Istack().Len()
|
||||
if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, nil); err != nil {
|
||||
if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -456,6 +456,59 @@ func TestSnapshotIsolation_Exceptions(t *testing.T) {
|
|||
require.Equal(t, nNtfBBeforePanic+nNtfBAfterPanic, len(aer.Events))
|
||||
}
|
||||
|
||||
// This test is written to test nested calls with try-catch block and proper notifications handling.
|
||||
func TestSnapshotIsolation_NestedContextException(t *testing.T) {
|
||||
bc, acc := chain.NewSingle(t)
|
||||
e := neotest.NewExecutor(t, bc, acc, acc)
|
||||
|
||||
srcA := `package contractA
|
||||
import (
|
||||
"github.com/nspcc-dev/neo-go/pkg/interop/contract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/interop/runtime"
|
||||
)
|
||||
func CallA() {
|
||||
runtime.Notify("Calling A")
|
||||
contract.Call(runtime.GetExecutingScriptHash(), "a", contract.All)
|
||||
runtime.Notify("Finish")
|
||||
}
|
||||
func A() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
runtime.Notify("Caught")
|
||||
}
|
||||
}()
|
||||
runtime.Notify("A")
|
||||
contract.Call(runtime.GetExecutingScriptHash(), "b", contract.All)
|
||||
runtime.Notify("Unreachable A")
|
||||
}
|
||||
func B() int {
|
||||
runtime.Notify("B")
|
||||
contract.Call(runtime.GetExecutingScriptHash(), "c", contract.All)
|
||||
runtime.Notify("Unreachable B")
|
||||
return 5
|
||||
}
|
||||
func C() {
|
||||
runtime.Notify("C")
|
||||
panic("exception from C")
|
||||
}`
|
||||
ctrA := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcA), &compiler.Options{
|
||||
NoEventsCheck: true,
|
||||
NoPermissionsCheck: true,
|
||||
Name: "contractA",
|
||||
Permissions: []manifest.Permission{{Methods: manifest.WildStrings{Value: nil}}},
|
||||
})
|
||||
e.DeployContract(t, ctrA, nil)
|
||||
|
||||
ctrInvoker := e.NewInvoker(ctrA.Hash, e.Committee)
|
||||
h := ctrInvoker.Invoke(t, stackitem.Null{}, "callA")
|
||||
aer := e.GetTxExecResult(t, h)
|
||||
require.Equal(t, 4, len(aer.Events))
|
||||
require.Equal(t, "Calling A", aer.Events[0].Name)
|
||||
require.Equal(t, "A", aer.Events[1].Name)
|
||||
require.Equal(t, "Caught", aer.Events[2].Name)
|
||||
require.Equal(t, "Finish", aer.Events[3].Name)
|
||||
}
|
||||
|
||||
// This test is written to avoid https://github.com/neo-project/neo/issues/2746.
|
||||
func TestSnapshotIsolation_CallToItself(t *testing.T) {
|
||||
bc, acc := chain.NewSingle(t)
|
||||
|
|
|
@ -54,41 +54,28 @@ type Context struct {
|
|||
NEF *nef.File
|
||||
// invTree is an invocation tree (or branch of it) for this context.
|
||||
invTree *InvocationTree
|
||||
// notificationsCount stores number of notifications emitted during current context
|
||||
// handling.
|
||||
notificationsCount *int
|
||||
// persistNotificationsCountOnUnloading denotes whether notificationsCount should be
|
||||
// persisted to the parent context on current context unloading.
|
||||
persistNotificationsCountOnUnloading bool
|
||||
// isWrapped tells whether the context's DAO was wrapped into another layer of
|
||||
// MemCachedStore on creation and whether it should be unwrapped on context unloading.
|
||||
isWrapped bool
|
||||
// onUnload is a callback that should be called after current context unloading
|
||||
// if no exception occurs.
|
||||
onUnload ContextUnloadCallback
|
||||
}
|
||||
|
||||
// ContextUnloadCallback is a callback method used on context unloading from istack.
|
||||
type ContextUnloadCallback func(parentEstack *Stack)
|
||||
type ContextUnloadCallback func(commit bool) error
|
||||
|
||||
var errNoInstParam = errors.New("failed to read instruction parameter")
|
||||
|
||||
// NewContext returns a new Context object.
|
||||
func NewContext(b []byte) *Context {
|
||||
return NewContextWithParams(b, -1, 0, nil)
|
||||
return NewContextWithParams(b, -1, 0)
|
||||
}
|
||||
|
||||
// NewContextWithParams creates new Context objects using script, parameter count,
|
||||
// return value count and initial position in script.
|
||||
func NewContextWithParams(b []byte, rvcount int, pos int, notificationsCount *int) *Context {
|
||||
if notificationsCount == nil {
|
||||
notificationsCount = new(int)
|
||||
}
|
||||
func NewContextWithParams(b []byte, rvcount int, pos int) *Context {
|
||||
return &Context{
|
||||
prog: b,
|
||||
retCount: rvcount,
|
||||
nextip: pos,
|
||||
notificationsCount: notificationsCount,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -335,3 +322,13 @@ func (v *VM) PushContextScriptHash(n int) error {
|
|||
v.Estack().PushItem(stackitem.NewByteArray(h.BytesBE()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Context) HasTryBlock() bool {
|
||||
for i := 0; i < c.tryStack.Len(); i++ {
|
||||
eCtx := c.tryStack.Peek(i).Value().(*exceptionHandlingContext)
|
||||
if eCtx.State == eTry {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
66
pkg/vm/vm.go
66
pkg/vm/vm.go
|
@ -67,11 +67,6 @@ type VM struct {
|
|||
// callback to get interop price
|
||||
getPrice func(opcode.Opcode, []byte) int64
|
||||
|
||||
// wraps DAO with private MemCachedStore
|
||||
wrapDao func()
|
||||
// either commits or discards changes made in the current context; performs DAO unwrapping.
|
||||
unwrapDAO func(commit bool, notificationsCount int) error
|
||||
|
||||
istack Stack // invocation stack.
|
||||
estack *Stack // execution stack.
|
||||
|
||||
|
@ -121,24 +116,6 @@ func NewWithTrigger(t trigger.Type) *VM {
|
|||
return vm
|
||||
}
|
||||
|
||||
func (v *VM) EmitNotification() {
|
||||
currCtx := v.Context()
|
||||
if currCtx == nil {
|
||||
return
|
||||
}
|
||||
*currCtx.notificationsCount++
|
||||
}
|
||||
|
||||
// SetIsolationCallbacks registers given callbacks to perform DAO and interop context
|
||||
// isolation between contract calls.
|
||||
// wrapper performs DAO cloning;
|
||||
// committer persists changes made in the upper snapshot to the underlying DAO;
|
||||
// reverter rolls back the whole set of changes made in the current snapshot.
|
||||
func (v *VM) SetIsolationCallbacks(wrapper func(), unwrapper func(commit bool, notificationsCount int) error) {
|
||||
v.wrapDao = wrapper
|
||||
v.unwrapDAO = unwrapper
|
||||
}
|
||||
|
||||
// SetPriceGetter registers the given PriceGetterFunc in v.
|
||||
// f accepts vm's Context, current instruction and instruction parameter.
|
||||
func (v *VM) SetPriceGetter(f func(opcode.Opcode, []byte) int64) {
|
||||
|
@ -343,7 +320,7 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint
|
|||
var sl slot
|
||||
|
||||
v.checkInvocationStackSize()
|
||||
ctx := NewContextWithParams(b, rvcount, offset, nil)
|
||||
ctx := NewContextWithParams(b, rvcount, offset)
|
||||
if rvcount != -1 || v.estack.Len() != 0 {
|
||||
v.estack = newStack("evaluation", &v.refs)
|
||||
}
|
||||
|
@ -354,9 +331,9 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint
|
|||
ctx.scriptHash = hash
|
||||
ctx.callingScriptHash = caller
|
||||
ctx.NEF = exe
|
||||
parent := v.Context()
|
||||
if v.invTree != nil {
|
||||
curTree := v.invTree
|
||||
parent := v.Context()
|
||||
if parent != nil {
|
||||
curTree = parent.invTree
|
||||
}
|
||||
|
@ -364,23 +341,6 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint
|
|||
curTree.Calls = append(curTree.Calls, newTree)
|
||||
ctx.invTree = newTree
|
||||
}
|
||||
if v.wrapDao != nil {
|
||||
needWrap := f&(callflag.All^callflag.ReadOnly) != 0 // If the method is safe, then it's read-only and doesn't perform storage changes or emit notifications.
|
||||
if !needWrap && parent != nil { // If the method is not wrapped into try-catch block, then changes should be discarded anyway if exception occurs.
|
||||
for i := 0; i < parent.tryStack.Len(); i++ {
|
||||
eCtx := parent.tryStack.Peek(i).Value().(*exceptionHandlingContext)
|
||||
if eCtx.State == eTry {
|
||||
needWrap = true // TODO: is it correct to wrap it only once and break after the first occurrence?
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if needWrap {
|
||||
v.wrapDao()
|
||||
ctx.isWrapped = true
|
||||
}
|
||||
}
|
||||
ctx.persistNotificationsCountOnUnloading = true
|
||||
ctx.onUnload = onContextUnload
|
||||
v.istack.PushItem(ctx)
|
||||
}
|
||||
|
@ -1632,21 +1592,12 @@ func (v *VM) unloadContext(ctx *Context) {
|
|||
if ctx.static != nil && (currCtx == nil || ctx.static != currCtx.static) {
|
||||
ctx.static.ClearRefs(&v.refs)
|
||||
}
|
||||
if ctx.isWrapped && v.unwrapDAO != nil { // In case of CALL, CALLA, CALLL we don't need to commit/discard changes, unwrap DAO and change notificationsCount.
|
||||
err := v.unwrapDAO(v.uncaughtException == nil, *ctx.notificationsCount)
|
||||
if ctx.onUnload != nil {
|
||||
err := ctx.onUnload(v.uncaughtException == nil)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to unwrap DAO: %w", err))
|
||||
panic(fmt.Errorf("context unload callback failed: %w", err))
|
||||
}
|
||||
}
|
||||
if currCtx != nil && ctx.persistNotificationsCountOnUnloading && !(ctx.isWrapped && v.uncaughtException != nil) {
|
||||
*currCtx.notificationsCount += *ctx.notificationsCount
|
||||
}
|
||||
if currCtx != nil && ctx.onUnload != nil {
|
||||
if v.uncaughtException == nil {
|
||||
ctx.onUnload(currCtx.Estack()) // Use the estack of current context.
|
||||
}
|
||||
ctx.onUnload = nil
|
||||
}
|
||||
}
|
||||
|
||||
// getTryParams splits TRY(L) instruction parameter into offsets for catch and finally blocks.
|
||||
|
@ -1703,12 +1654,7 @@ func (v *VM) call(ctx *Context, offset int) {
|
|||
newCtx.tryStack.elems = nil
|
||||
initStack(&newCtx.tryStack, "exception", nil)
|
||||
newCtx.NEF = ctx.NEF
|
||||
// Use exactly the same counter and don't use v.wrapDao() for this context.
|
||||
// Unloading of such unwrapped context will be properly handled inside
|
||||
// unloadContext without unnecessary DAO unwrapping and notificationsCount changes.
|
||||
newCtx.notificationsCount = ctx.notificationsCount
|
||||
newCtx.isWrapped = false
|
||||
newCtx.persistNotificationsCountOnUnloading = false
|
||||
// Do not clone unloading callback, new context does not require any actions to perform on unloading.
|
||||
newCtx.onUnload = nil
|
||||
v.istack.PushItem(newCtx)
|
||||
newCtx.Jump(offset)
|
||||
|
|
Loading…
Reference in a new issue