core, vm: move all isolation-related logic out of VM

Keep it inside the interop context.
This commit is contained in:
Anna Shaleva 2022-05-25 10:00:02 +03:00
parent f79f62dab4
commit a39b7cc3fd
6 changed files with 109 additions and 126 deletions

View file

@ -89,11 +89,6 @@ func (dao *Simple) GetWrapped() *Simple {
return d return d
} }
// GetUnwrapped returns the underlying DAO. It does not perform changes persist.
func (dao *Simple) GetUnwrapped() *Simple {
return dao.nativeCachePS
}
// GetPrivate returns a new DAO instance with another layer of private // GetPrivate returns a new DAO instance with another layer of private
// MemCachedStore around the current DAO Store. // MemCachedStore around the current DAO Store.
func (dao *Simple) GetPrivate() *Simple { func (dao *Simple) GetPrivate() *Simple {

View file

@ -318,33 +318,6 @@ func (ic *Context) SpawnVM() *vm.VM {
v := vm.NewWithTrigger(ic.Trigger) v := vm.NewWithTrigger(ic.Trigger)
v.GasLimit = -1 v.GasLimit = -1
v.SyscallHandler = ic.SyscallHandler v.SyscallHandler = ic.SyscallHandler
wrapper := func() {
if ic.DAO == nil {
return
}
ic.DAO = ic.DAO.GetPrivate()
}
unwrapper := func(commit bool, ntfToRemove int) error {
if !commit {
have := len(ic.Notifications)
if have < ntfToRemove {
panic(fmt.Errorf("inconsistent notifications count: should remove %d, have %d", ntfToRemove, len(ic.Notifications)))
}
ic.Notifications = ic.Notifications[:have-ntfToRemove]
}
if ic.DAO == nil {
return nil
}
if commit {
_, err := ic.DAO.Persist()
if err != nil {
return fmt.Errorf("failed to persist changes %w", err)
}
}
ic.DAO = ic.DAO.GetUnwrapped()
return nil
}
v.SetIsolationCallbacks(wrapper, unwrapper)
ic.VM = v ic.VM = v
return v return v
} }
@ -415,5 +388,4 @@ func (ic *Context) AddNotification(hash util.Uint160, name string, item *stackit
Name: name, Name: name,
Item: item, Item: item,
}) })
ic.VM.EmitNotification()
} }

View file

@ -40,7 +40,7 @@ func LoadToken(ic *interop.Context) func(id int32) error {
if err != nil { if err != nil {
return fmt.Errorf("token contract %s not found: %w", tok.Hash.StringLE(), err) return fmt.Errorf("token contract %s not found: %w", tok.Hash.StringLE(), err)
} }
return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, nil) return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, false)
} }
} }
@ -69,17 +69,11 @@ func Call(ic *interop.Context) error {
return fmt.Errorf("method not found: %s/%d", method, len(args)) return fmt.Errorf("method not found: %s/%d", method, len(args))
} }
hasReturn := md.ReturnType != smartcontract.VoidType hasReturn := md.ReturnType != smartcontract.VoidType
var cb vm.ContextUnloadCallback return callInternal(ic, cs, method, fs, hasReturn, args, !hasReturn)
if !hasReturn {
cb = func(estack *vm.Stack) {
estack.PushItem(stackitem.Null{})
}
}
return callInternal(ic, cs, method, fs, hasReturn, args, cb)
} }
func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag, func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
hasReturn bool, args []stackitem.Item, cb vm.ContextUnloadCallback) error { hasReturn bool, args []stackitem.Item, pushNullOnUnloading bool) error {
md := cs.Manifest.ABI.GetMethod(name, len(args)) md := cs.Manifest.ABI.GetMethod(name, len(args))
if md.Safe { if md.Safe {
f &^= (callflag.WriteStates | callflag.AllowNotify) f &^= (callflag.WriteStates | callflag.AllowNotify)
@ -91,12 +85,12 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl
} }
} }
} }
return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, cb) return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, hasReturn, pushNullOnUnloading)
} }
// callExFromNative calls a contract with flags using the provided calling hash. // callExFromNative calls a contract with flags using the provided calling hash.
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, cb vm.ContextUnloadCallback) error { name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool, pushNullOnUnloading bool) error {
for _, nc := range ic.Natives { for _, nc := range ic.Natives {
if nc.Metadata().Name == nativenames.Policy { if nc.Metadata().Name == nativenames.Policy {
var pch = nc.(policyChecker) var pch = nc.(policyChecker)
@ -122,8 +116,34 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra
initOff = md.Offset initOff = md.Offset
} }
ic.Invocations[cs.Hash]++ ic.Invocations[cs.Hash]++
ic.VM.LoadNEFMethod(&cs.NEF, caller, cs.Hash, ic.VM.Context().GetCallFlags()&f, f = ic.VM.Context().GetCallFlags() & f
hasReturn, methodOff, initOff, cb)
wrapped := f&(callflag.All^callflag.ReadOnly) != 0 || // If the method is safe, then it's read-only and doesn't perform storage changes or emit notifications.
ic.VM.Context().HasTryBlock() // If the method is not wrapped into try-catch block, then changes should be discarded anyway if exception occurs.
baseNtfCount := len(ic.Notifications)
baseDAO := ic.DAO
if wrapped {
ic.DAO = ic.DAO.GetPrivate()
}
onUnload := func(commit bool) error {
if wrapped {
if commit {
_, err := ic.DAO.Persist()
if err != nil {
return fmt.Errorf("failed to persist changes %w", err)
}
} else {
ic.Notifications = ic.Notifications[:baseNtfCount] // Rollback all notification changes made by current context.
}
ic.DAO = baseDAO
}
if pushNullOnUnloading && commit {
ic.VM.Context().Estack().PushItem(stackitem.Null{}) // Must use current context stack.
}
return nil
}
ic.VM.LoadNEFMethod(&cs.NEF, caller, cs.Hash, f,
hasReturn, methodOff, initOff, onUnload)
for e, i := ic.VM.Estack(), len(args)-1; i >= 0; i-- { for e, i := ic.VM.Estack(), len(args)-1; i >= 0; i-- {
e.PushItem(args[i]) e.PushItem(args[i])
@ -137,7 +157,7 @@ var ErrNativeCall = errors.New("failed native call")
// CallFromNative performs synchronous call from native contract. // CallFromNative performs synchronous call from native contract.
func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, hasReturn bool) error { func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, hasReturn bool) error {
startSize := ic.VM.Istack().Len() startSize := ic.VM.Istack().Len()
if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, nil); err != nil { if err := callExFromNative(ic, caller, cs, method, args, callflag.All, hasReturn, false); err != nil {
return err return err
} }

View file

@ -456,6 +456,59 @@ func TestSnapshotIsolation_Exceptions(t *testing.T) {
require.Equal(t, nNtfBBeforePanic+nNtfBAfterPanic, len(aer.Events)) require.Equal(t, nNtfBBeforePanic+nNtfBAfterPanic, len(aer.Events))
} }
// This test is written to test nested calls with try-catch block and proper notifications handling.
func TestSnapshotIsolation_NestedContextException(t *testing.T) {
bc, acc := chain.NewSingle(t)
e := neotest.NewExecutor(t, bc, acc, acc)
srcA := `package contractA
import (
"github.com/nspcc-dev/neo-go/pkg/interop/contract"
"github.com/nspcc-dev/neo-go/pkg/interop/runtime"
)
func CallA() {
runtime.Notify("Calling A")
contract.Call(runtime.GetExecutingScriptHash(), "a", contract.All)
runtime.Notify("Finish")
}
func A() {
defer func() {
if r := recover(); r != nil {
runtime.Notify("Caught")
}
}()
runtime.Notify("A")
contract.Call(runtime.GetExecutingScriptHash(), "b", contract.All)
runtime.Notify("Unreachable A")
}
func B() int {
runtime.Notify("B")
contract.Call(runtime.GetExecutingScriptHash(), "c", contract.All)
runtime.Notify("Unreachable B")
return 5
}
func C() {
runtime.Notify("C")
panic("exception from C")
}`
ctrA := neotest.CompileSource(t, acc.ScriptHash(), strings.NewReader(srcA), &compiler.Options{
NoEventsCheck: true,
NoPermissionsCheck: true,
Name: "contractA",
Permissions: []manifest.Permission{{Methods: manifest.WildStrings{Value: nil}}},
})
e.DeployContract(t, ctrA, nil)
ctrInvoker := e.NewInvoker(ctrA.Hash, e.Committee)
h := ctrInvoker.Invoke(t, stackitem.Null{}, "callA")
aer := e.GetTxExecResult(t, h)
require.Equal(t, 4, len(aer.Events))
require.Equal(t, "Calling A", aer.Events[0].Name)
require.Equal(t, "A", aer.Events[1].Name)
require.Equal(t, "Caught", aer.Events[2].Name)
require.Equal(t, "Finish", aer.Events[3].Name)
}
// This test is written to avoid https://github.com/neo-project/neo/issues/2746. // This test is written to avoid https://github.com/neo-project/neo/issues/2746.
func TestSnapshotIsolation_CallToItself(t *testing.T) { func TestSnapshotIsolation_CallToItself(t *testing.T) {
bc, acc := chain.NewSingle(t) bc, acc := chain.NewSingle(t)

View file

@ -54,41 +54,28 @@ type Context struct {
NEF *nef.File NEF *nef.File
// invTree is an invocation tree (or branch of it) for this context. // invTree is an invocation tree (or branch of it) for this context.
invTree *InvocationTree invTree *InvocationTree
// notificationsCount stores number of notifications emitted during current context
// handling.
notificationsCount *int
// persistNotificationsCountOnUnloading denotes whether notificationsCount should be
// persisted to the parent context on current context unloading.
persistNotificationsCountOnUnloading bool
// isWrapped tells whether the context's DAO was wrapped into another layer of
// MemCachedStore on creation and whether it should be unwrapped on context unloading.
isWrapped bool
// onUnload is a callback that should be called after current context unloading // onUnload is a callback that should be called after current context unloading
// if no exception occurs. // if no exception occurs.
onUnload ContextUnloadCallback onUnload ContextUnloadCallback
} }
// ContextUnloadCallback is a callback method used on context unloading from istack. // ContextUnloadCallback is a callback method used on context unloading from istack.
type ContextUnloadCallback func(parentEstack *Stack) type ContextUnloadCallback func(commit bool) error
var errNoInstParam = errors.New("failed to read instruction parameter") var errNoInstParam = errors.New("failed to read instruction parameter")
// NewContext returns a new Context object. // NewContext returns a new Context object.
func NewContext(b []byte) *Context { func NewContext(b []byte) *Context {
return NewContextWithParams(b, -1, 0, nil) return NewContextWithParams(b, -1, 0)
} }
// NewContextWithParams creates new Context objects using script, parameter count, // NewContextWithParams creates new Context objects using script, parameter count,
// return value count and initial position in script. // return value count and initial position in script.
func NewContextWithParams(b []byte, rvcount int, pos int, notificationsCount *int) *Context { func NewContextWithParams(b []byte, rvcount int, pos int) *Context {
if notificationsCount == nil {
notificationsCount = new(int)
}
return &Context{ return &Context{
prog: b, prog: b,
retCount: rvcount, retCount: rvcount,
nextip: pos, nextip: pos,
notificationsCount: notificationsCount,
} }
} }
@ -335,3 +322,13 @@ func (v *VM) PushContextScriptHash(n int) error {
v.Estack().PushItem(stackitem.NewByteArray(h.BytesBE())) v.Estack().PushItem(stackitem.NewByteArray(h.BytesBE()))
return nil return nil
} }
func (c *Context) HasTryBlock() bool {
for i := 0; i < c.tryStack.Len(); i++ {
eCtx := c.tryStack.Peek(i).Value().(*exceptionHandlingContext)
if eCtx.State == eTry {
return true
}
}
return false
}

View file

@ -67,11 +67,6 @@ type VM struct {
// callback to get interop price // callback to get interop price
getPrice func(opcode.Opcode, []byte) int64 getPrice func(opcode.Opcode, []byte) int64
// wraps DAO with private MemCachedStore
wrapDao func()
// either commits or discards changes made in the current context; performs DAO unwrapping.
unwrapDAO func(commit bool, notificationsCount int) error
istack Stack // invocation stack. istack Stack // invocation stack.
estack *Stack // execution stack. estack *Stack // execution stack.
@ -121,24 +116,6 @@ func NewWithTrigger(t trigger.Type) *VM {
return vm return vm
} }
func (v *VM) EmitNotification() {
currCtx := v.Context()
if currCtx == nil {
return
}
*currCtx.notificationsCount++
}
// SetIsolationCallbacks registers given callbacks to perform DAO and interop context
// isolation between contract calls.
// wrapper performs DAO cloning;
// committer persists changes made in the upper snapshot to the underlying DAO;
// reverter rolls back the whole set of changes made in the current snapshot.
func (v *VM) SetIsolationCallbacks(wrapper func(), unwrapper func(commit bool, notificationsCount int) error) {
v.wrapDao = wrapper
v.unwrapDAO = unwrapper
}
// SetPriceGetter registers the given PriceGetterFunc in v. // SetPriceGetter registers the given PriceGetterFunc in v.
// f accepts vm's Context, current instruction and instruction parameter. // f accepts vm's Context, current instruction and instruction parameter.
func (v *VM) SetPriceGetter(f func(opcode.Opcode, []byte) int64) { func (v *VM) SetPriceGetter(f func(opcode.Opcode, []byte) int64) {
@ -343,7 +320,7 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint
var sl slot var sl slot
v.checkInvocationStackSize() v.checkInvocationStackSize()
ctx := NewContextWithParams(b, rvcount, offset, nil) ctx := NewContextWithParams(b, rvcount, offset)
if rvcount != -1 || v.estack.Len() != 0 { if rvcount != -1 || v.estack.Len() != 0 {
v.estack = newStack("evaluation", &v.refs) v.estack = newStack("evaluation", &v.refs)
} }
@ -354,9 +331,9 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint
ctx.scriptHash = hash ctx.scriptHash = hash
ctx.callingScriptHash = caller ctx.callingScriptHash = caller
ctx.NEF = exe ctx.NEF = exe
parent := v.Context()
if v.invTree != nil { if v.invTree != nil {
curTree := v.invTree curTree := v.invTree
parent := v.Context()
if parent != nil { if parent != nil {
curTree = parent.invTree curTree = parent.invTree
} }
@ -364,23 +341,6 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint
curTree.Calls = append(curTree.Calls, newTree) curTree.Calls = append(curTree.Calls, newTree)
ctx.invTree = newTree ctx.invTree = newTree
} }
if v.wrapDao != nil {
needWrap := f&(callflag.All^callflag.ReadOnly) != 0 // If the method is safe, then it's read-only and doesn't perform storage changes or emit notifications.
if !needWrap && parent != nil { // If the method is not wrapped into try-catch block, then changes should be discarded anyway if exception occurs.
for i := 0; i < parent.tryStack.Len(); i++ {
eCtx := parent.tryStack.Peek(i).Value().(*exceptionHandlingContext)
if eCtx.State == eTry {
needWrap = true // TODO: is it correct to wrap it only once and break after the first occurrence?
break
}
}
}
if needWrap {
v.wrapDao()
ctx.isWrapped = true
}
}
ctx.persistNotificationsCountOnUnloading = true
ctx.onUnload = onContextUnload ctx.onUnload = onContextUnload
v.istack.PushItem(ctx) v.istack.PushItem(ctx)
} }
@ -1632,21 +1592,12 @@ func (v *VM) unloadContext(ctx *Context) {
if ctx.static != nil && (currCtx == nil || ctx.static != currCtx.static) { if ctx.static != nil && (currCtx == nil || ctx.static != currCtx.static) {
ctx.static.ClearRefs(&v.refs) ctx.static.ClearRefs(&v.refs)
} }
if ctx.isWrapped && v.unwrapDAO != nil { // In case of CALL, CALLA, CALLL we don't need to commit/discard changes, unwrap DAO and change notificationsCount. if ctx.onUnload != nil {
err := v.unwrapDAO(v.uncaughtException == nil, *ctx.notificationsCount) err := ctx.onUnload(v.uncaughtException == nil)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to unwrap DAO: %w", err)) panic(fmt.Errorf("context unload callback failed: %w", err))
} }
} }
if currCtx != nil && ctx.persistNotificationsCountOnUnloading && !(ctx.isWrapped && v.uncaughtException != nil) {
*currCtx.notificationsCount += *ctx.notificationsCount
}
if currCtx != nil && ctx.onUnload != nil {
if v.uncaughtException == nil {
ctx.onUnload(currCtx.Estack()) // Use the estack of current context.
}
ctx.onUnload = nil
}
} }
// getTryParams splits TRY(L) instruction parameter into offsets for catch and finally blocks. // getTryParams splits TRY(L) instruction parameter into offsets for catch and finally blocks.
@ -1703,12 +1654,7 @@ func (v *VM) call(ctx *Context, offset int) {
newCtx.tryStack.elems = nil newCtx.tryStack.elems = nil
initStack(&newCtx.tryStack, "exception", nil) initStack(&newCtx.tryStack, "exception", nil)
newCtx.NEF = ctx.NEF newCtx.NEF = ctx.NEF
// Use exactly the same counter and don't use v.wrapDao() for this context. // Do not clone unloading callback, new context does not require any actions to perform on unloading.
// Unloading of such unwrapped context will be properly handled inside
// unloadContext without unnecessary DAO unwrapping and notificationsCount changes.
newCtx.notificationsCount = ctx.notificationsCount
newCtx.isWrapped = false
newCtx.persistNotificationsCountOnUnloading = false
newCtx.onUnload = nil newCtx.onUnload = nil
v.istack.PushItem(newCtx) v.istack.PushItem(newCtx)
newCtx.Jump(offset) newCtx.Jump(offset)