From 70aed34d77d916ca9d2551b871332fb2ba85a3f5 Mon Sep 17 00:00:00 2001 From: Roman Khimov Date: Thu, 29 Jun 2023 11:18:30 +0300 Subject: [PATCH] interop/contract: fix state rollbacks for nested contexts Our wrapping optimization relied on the caller context having a TRY block, but each context (including internal calls!) has an exception handling stack of its own, which means that for an invocation stack of entry A.someMethodFromEntry() # this one has a TRY A.internalMethodViaCALL() # this one doesn't B.someMethod() we get `HasTryBlock() == false` for `A.internalMethodViaCALL()` context, which leads to missing wrapper and missing rollbacks if B is to THROW. What this patch does instead is it checks for any context within contract boundaries. Fixes #3045. Signed-off-by: Roman Khimov --- pkg/core/interop/contract/call.go | 2 +- pkg/core/interop/contract/call_test.go | 3 +++ pkg/vm/context.go | 10 ---------- pkg/vm/vm.go | 22 ++++++++++++++++++++++ 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index 170e596a4..e1b8b72be 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -117,7 +117,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra ic.Invocations[cs.Hash]++ f = ic.VM.Context().GetCallFlags() & f - wrapped := ic.VM.Context().HasTryBlock() && // If the method is not wrapped into try-catch block, then changes should be discarded anyway if exception occurs. + wrapped := ic.VM.ContractHasTryBlock() && // If the method is not wrapped into try-catch block, then changes should be discarded anyway if exception occurs. f&(callflag.All^callflag.ReadOnly) != 0 // If the method is safe, then it's read-only and doesn't perform storage changes or emit notifications. baseNtfCount := len(ic.Notifications) baseDAO := ic.DAO diff --git a/pkg/core/interop/contract/call_test.go b/pkg/core/interop/contract/call_test.go index 50648a4b6..909e6a7c4 100644 --- a/pkg/core/interop/contract/call_test.go +++ b/pkg/core/interop/contract/call_test.go @@ -304,6 +304,9 @@ func TestSnapshotIsolation_Exceptions(t *testing.T) { for i := 0; i < nNtfB1; i++ { runtime.Notify("NotificationFromB before panic", i) } + internalCaller(keyA, valueA, nNtfA) + } + func internalCaller(keyA, valueA []byte, nNtfA int) { contract.Call(interop.Hash160{` + hashAStr + `}, "doAndPanic", contract.All, keyA, valueA, nNtfA) } func CheckStorageChanges() bool { diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 95d693296..1d44c29c5 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -315,16 +315,6 @@ func (v *VM) PushContextScriptHash(n int) error { return nil } -func (c *Context) HasTryBlock() bool { - for i := 0; i < c.tryStack.Len(); i++ { - eCtx := c.tryStack.Peek(i).Value().(*exceptionHandlingContext) - if eCtx.State == eTry { - return true - } - } - return false -} - // MarshalJSON implements the JSON marshalling interface. func (c *Context) MarshalJSON() ([]byte, error) { var aux = contextAux{ diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 950dfed7e..4ee109516 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -1804,6 +1804,28 @@ func throwUnhandledException(item stackitem.Item) { panic(msg) } +// ContractHasTryBlock checks if the currently executing contract has a TRY +// block in one of its contexts. +func (v *VM) ContractHasTryBlock() bool { + var topctx *Context // Currently executing context. + for i := 0; i < len(v.istack); i++ { + ictx := v.istack[len(v.istack)-1-i] // It's a stack, going backwards like handleException(). + if topctx == nil { + topctx = ictx + } + if ictx.sc != topctx.sc { + return false // Different contract -> no one cares. + } + for j := 0; j < ictx.tryStack.Len(); j++ { + eCtx := ictx.tryStack.Peek(j).Value().(*exceptionHandlingContext) + if eCtx.State == eTry { + return true + } + } + } + return false +} + // CheckMultisigPar checks if the sigs contains sufficient valid signatures. func CheckMultisigPar(v *VM, curve elliptic.Curve, h []byte, pkeys [][]byte, sigs [][]byte) bool { if len(sigs) == 1 {