diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index 5fdac1576..daf2e60c2 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -89,11 +89,6 @@ func (dao *Simple) GetWrapped() *Simple { return d } -// GetUnwrapped returns the underlying DAO. It does not perform changes persist. -func (dao *Simple) GetUnwrapped() *Simple { - return dao.nativeCachePS -} - // GetPrivate returns a new DAO instance with another layer of private // MemCachedStore around the current DAO Store. func (dao *Simple) GetPrivate() *Simple { diff --git a/pkg/core/interop/context.go b/pkg/core/interop/context.go index 3ff7f063c..9444fd3d4 100644 --- a/pkg/core/interop/context.go +++ b/pkg/core/interop/context.go @@ -318,33 +318,6 @@ func (ic *Context) SpawnVM() *vm.VM { v := vm.NewWithTrigger(ic.Trigger) v.GasLimit = -1 v.SyscallHandler = ic.SyscallHandler - wrapper := func() { - if ic.DAO == nil { - return - } - ic.DAO = ic.DAO.GetPrivate() - } - unwrapper := func(commit bool, ntfToRemove int) error { - if !commit { - have := len(ic.Notifications) - if have < ntfToRemove { - panic(fmt.Errorf("inconsistent notifications count: should remove %d, have %d", ntfToRemove, len(ic.Notifications))) - } - ic.Notifications = ic.Notifications[:have-ntfToRemove] - } - if ic.DAO == nil { - return nil - } - if commit { - _, err := ic.DAO.Persist() - if err != nil { - return fmt.Errorf("failed to persist changes %w", err) - } - } - ic.DAO = ic.DAO.GetUnwrapped() - return nil - } - v.SetIsolationCallbacks(wrapper, unwrapper) ic.VM = v return v } @@ -415,5 +388,4 @@ func (ic *Context) AddNotification(hash util.Uint160, name string, item *stackit Name: name, Item: item, }) - ic.VM.EmitNotification() } diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index c2f8feb9b..64e3c6731 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -40,7 +40,7 @@ func LoadToken(ic *interop.Context) func(id int32) error { if err != nil { return fmt.Errorf("token contract %s not found: %w", tok.Hash.StringLE(), err) } - return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, nil) + return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, false) } } @@ -69,17 +69,11 @@ func Call(ic *interop.Context) error { return fmt.Errorf("method not found: %s/%d", method, len(args)) } hasReturn := md.ReturnType != smartcontract.VoidType - var cb vm.ContextUnloadCallback - if !hasReturn { - cb = func(estack *vm.Stack) { - estack.PushItem(stackitem.Null{}) - } - } - return callInternal(ic, cs, method, fs, hasReturn, args, cb) + return callInternal(ic, cs, method, fs, hasReturn, args, !hasReturn) } func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, - hasReturn bool, args []stackitem.Item, cb vm.ContextUnloadCallback) error { + hasReturn bool, args []stackitem.Item, pushNullOnUnloading bool) error { md := cs.Manifest.ABI.GetMethod(name, len(args)) if md.Safe { f &^= (callflag.WriteStates | callflag.AllowNotify) @@ -91,12 +85,12 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl } } } - return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, cb) + return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, pushNullOnUnloading) } // callExFromNative calls a contract with flags using the provided calling hash. func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, - name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, cb vm.ContextUnloadCallback) error { + name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, pushNullOnUnloading bool) error { for _, nc := range ic.Natives { if nc.Metadata().Name == nativenames.Policy { var pch = nc.(policyChecker) @@ -122,8 +116,34 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra initOff = md.Offset } ic.Invocations[cs.Hash]++ - ic.VM.LoadNEFMethod(&cs.NEF, caller, cs.Hash, ic.VM.Context().GetCallFlags()&f, - hasReturn, methodOff, initOff, cb) + f = ic.VM.Context().GetCallFlags() & f + + wrapped := f&(callflag.All^callflag.ReadOnly) != 0 || // If the method is safe, then it's read-only and doesn't perform storage changes or emit notifications. + ic.VM.Context().HasTryBlock() // If the method is not wrapped into try-catch block, then changes should be discarded anyway if exception occurs. + baseNtfCount := len(ic.Notifications) + baseDAO := ic.DAO + if wrapped { + ic.DAO = ic.DAO.GetPrivate() + } + onUnload := func(commit bool) error { + if wrapped { + if commit { + _, err := ic.DAO.Persist() + if err != nil { + return fmt.Errorf("failed to persist changes %w", err) + } + } else { + ic.Notifications = ic.Notifications[:baseNtfCount] // Rollback all notification changes made by current context. + } + ic.DAO = baseDAO + } + if pushNullOnUnloading && commit { + ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack. + } + return nil + } + ic.VM.LoadNEFMethod(&cs.NEF, caller, cs.Hash, f, + hasReturn, methodOff, initOff, onUnload) for e, i := ic.VM.Estack(), len(args)-1; i >= 0; i-- { e.PushItem(args[i]) @@ -137,7 +157,7 @@ var ErrNativeCall = errors.New("failed native call") // CallFromNative performs synchronous call from native contract. func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, hasReturn bool) error { startSize := ic.VM.Istack().Len() - if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, nil); err != nil { + if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, false); err != nil { return err } diff --git a/pkg/core/interop_system_neotest_test.go b/pkg/core/interop_system_neotest_test.go index 1537c225c..39103604c 100644 --- a/pkg/core/interop_system_neotest_test.go +++ b/pkg/core/interop_system_neotest_test.go @@ -456,6 +456,59 @@ func TestSnapshotIsolation_Exceptions(t *testing.T) { require.Equal(t, nNtfBBeforePanic+nNtfBAfterPanic, len(aer.Events)) } +// This test is written to test nested calls with try-catch block and proper notifications handling. +func TestSnapshotIsolation_NestedContextException(t *testing.T) { + bc, acc := chain.NewSingle(t) + e := neotest.NewExecutor(t, bc, acc, acc) + + srcA := `package contractA + import ( + "github.com/nspcc-dev/neo-go/pkg/interop/contract" + "github.com/nspcc-dev/neo-go/pkg/interop/runtime" + ) + func CallA() { + runtime.Notify("Calling A") + contract.Call(runtime.GetExecutingScriptHash(), "a", contract.All) + runtime.Notify("Finish") + } + func A() { + defer func() { + if r := recover(); r != nil { + runtime.Notify("Caught") + } + }() + runtime.Notify("A") + contract.Call(runtime.GetExecutingScriptHash(), "b", contract.All) + runtime.Notify("Unreachable A") + } + func B() int { + runtime.Notify("B") + contract.Call(runtime.GetExecutingScriptHash(), "c", contract.All) + runtime.Notify("Unreachable B") + return 5 + } + func C() { + runtime.Notify("C") + panic("exception from C") + }` + ctrA := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcA), &compiler.Options{ + NoEventsCheck: true, + NoPermissionsCheck: true, + Name: "contractA", + Permissions: []manifest.Permission{{Methods: manifest.WildStrings{Value: nil}}}, + }) + e.DeployContract(t, ctrA, nil) + + ctrInvoker := e.NewInvoker(ctrA.Hash, e.Committee) + h := ctrInvoker.Invoke(t, stackitem.Null{}, "callA") + aer := e.GetTxExecResult(t, h) + require.Equal(t, 4, len(aer.Events)) + require.Equal(t, "Calling A", aer.Events[0].Name) + require.Equal(t, "A", aer.Events[1].Name) + require.Equal(t, "Caught", aer.Events[2].Name) + require.Equal(t, "Finish", aer.Events[3].Name) +} + // This test is written to avoid https://github.com/neo-project/neo/issues/2746. func TestSnapshotIsolation_CallToItself(t *testing.T) { bc, acc := chain.NewSingle(t) diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 752a0d1a8..c468e115e 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -54,41 +54,28 @@ type Context struct { NEF *nef.File // invTree is an invocation tree (or branch of it) for this context. invTree *InvocationTree - // notificationsCount stores number of notifications emitted during current context - // handling. - notificationsCount *int - // persistNotificationsCountOnUnloading denotes whether notificationsCount should be - // persisted to the parent context on current context unloading. - persistNotificationsCountOnUnloading bool - // isWrapped tells whether the context's DAO was wrapped into another layer of - // MemCachedStore on creation and whether it should be unwrapped on context unloading. - isWrapped bool // onUnload is a callback that should be called after current context unloading // if no exception occurs. onUnload ContextUnloadCallback } // ContextUnloadCallback is a callback method used on context unloading from istack. -type ContextUnloadCallback func(parentEstack *Stack) +type ContextUnloadCallback func(commit bool) error var errNoInstParam = errors.New("failed to read instruction parameter") // NewContext returns a new Context object. func NewContext(b []byte) *Context { - return NewContextWithParams(b, -1, 0, nil) + return NewContextWithParams(b, -1, 0) } // NewContextWithParams creates new Context objects using script, parameter count, // return value count and initial position in script. -func NewContextWithParams(b []byte, rvcount int, pos int, notificationsCount *int) *Context { - if notificationsCount == nil { - notificationsCount = new(int) - } +func NewContextWithParams(b []byte, rvcount int, pos int) *Context { return &Context{ - prog: b, - retCount: rvcount, - nextip: pos, - notificationsCount: notificationsCount, + prog: b, + retCount: rvcount, + nextip: pos, } } @@ -335,3 +322,13 @@ func (v *VM) PushContextScriptHash(n int) error { v.Estack().PushItem(stackitem.NewByteArray(h.BytesBE())) return nil } + +func (c *Context) HasTryBlock() bool { + for i := 0; i < c.tryStack.Len(); i++ { + eCtx := c.tryStack.Peek(i).Value().(*exceptionHandlingContext) + if eCtx.State == eTry { + return true + } + } + return false +} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 4733d777d..f73a6e7f6 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -67,11 +67,6 @@ type VM struct { // callback to get interop price getPrice func(opcode.Opcode, []byte) int64 - // wraps DAO with private MemCachedStore - wrapDao func() - // either commits or discards changes made in the current context; performs DAO unwrapping. - unwrapDAO func(commit bool, notificationsCount int) error - istack Stack // invocation stack. estack *Stack // execution stack. @@ -121,24 +116,6 @@ func NewWithTrigger(t trigger.Type) *VM { return vm } -func (v *VM) EmitNotification() { - currCtx := v.Context() - if currCtx == nil { - return - } - *currCtx.notificationsCount++ -} - -// SetIsolationCallbacks registers given callbacks to perform DAO and interop context -// isolation between contract calls. -// wrapper performs DAO cloning; -// committer persists changes made in the upper snapshot to the underlying DAO; -// reverter rolls back the whole set of changes made in the current snapshot. -func (v *VM) SetIsolationCallbacks(wrapper func(), unwrapper func(commit bool, notificationsCount int) error) { - v.wrapDao = wrapper - v.unwrapDAO = unwrapper -} - // SetPriceGetter registers the given PriceGetterFunc in v. // f accepts vm's Context, current instruction and instruction parameter. func (v *VM) SetPriceGetter(f func(opcode.Opcode, []byte) int64) { @@ -343,7 +320,7 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint var sl slot v.checkInvocationStackSize() - ctx := NewContextWithParams(b, rvcount, offset, nil) + ctx := NewContextWithParams(b, rvcount, offset) if rvcount != -1 || v.estack.Len() != 0 { v.estack = newStack("evaluation", &v.refs) } @@ -354,9 +331,9 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint ctx.scriptHash = hash ctx.callingScriptHash = caller ctx.NEF = exe - parent := v.Context() if v.invTree != nil { curTree := v.invTree + parent := v.Context() if parent != nil { curTree = parent.invTree } @@ -364,23 +341,6 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint curTree.Calls = append(curTree.Calls, newTree) ctx.invTree = newTree } - if v.wrapDao != nil { - needWrap := f&(callflag.All^callflag.ReadOnly) != 0 // If the method is safe, then it's read-only and doesn't perform storage changes or emit notifications. - if !needWrap && parent != nil { // If the method is not wrapped into try-catch block, then changes should be discarded anyway if exception occurs. - for i := 0; i < parent.tryStack.Len(); i++ { - eCtx := parent.tryStack.Peek(i).Value().(*exceptionHandlingContext) - if eCtx.State == eTry { - needWrap = true // TODO: is it correct to wrap it only once and break after the first occurrence? - break - } - } - } - if needWrap { - v.wrapDao() - ctx.isWrapped = true - } - } - ctx.persistNotificationsCountOnUnloading = true ctx.onUnload = onContextUnload v.istack.PushItem(ctx) } @@ -1632,21 +1592,12 @@ func (v *VM) unloadContext(ctx *Context) { if ctx.static != nil && (currCtx == nil || ctx.static != currCtx.static) { ctx.static.ClearRefs(&v.refs) } - if ctx.isWrapped && v.unwrapDAO != nil { // In case of CALL, CALLA, CALLL we don't need to commit/discard changes, unwrap DAO and change notificationsCount. - err := v.unwrapDAO(v.uncaughtException == nil, *ctx.notificationsCount) + if ctx.onUnload != nil { + err := ctx.onUnload(v.uncaughtException == nil) if err != nil { - panic(fmt.Errorf("failed to unwrap DAO: %w", err)) + panic(fmt.Errorf("context unload callback failed: %w", err)) } } - if currCtx != nil && ctx.persistNotificationsCountOnUnloading && !(ctx.isWrapped && v.uncaughtException != nil) { - *currCtx.notificationsCount += *ctx.notificationsCount - } - if currCtx != nil && ctx.onUnload != nil { - if v.uncaughtException == nil { - ctx.onUnload(currCtx.Estack()) // Use the estack of current context. - } - ctx.onUnload = nil - } } // getTryParams splits TRY(L) instruction parameter into offsets for catch and finally blocks. @@ -1703,12 +1654,7 @@ func (v *VM) call(ctx *Context, offset int) { newCtx.tryStack.elems = nil initStack(&newCtx.tryStack, "exception", nil) newCtx.NEF = ctx.NEF - // Use exactly the same counter and don't use v.wrapDao() for this context. - // Unloading of such unwrapped context will be properly handled inside - // unloadContext without unnecessary DAO unwrapping and notificationsCount changes. - newCtx.notificationsCount = ctx.notificationsCount - newCtx.isWrapped = false - newCtx.persistNotificationsCountOnUnloading = false + // Do not clone unloading callback, new context does not require any actions to perform on unloading. newCtx.onUnload = nil v.istack.PushItem(newCtx) newCtx.Jump(offset)