diff --git a/pkg/core/interop/runtime/witness.go b/pkg/core/interop/runtime/witness.go index d6a01e652..09c4335d4 100644 --- a/pkg/core/interop/runtime/witness.go +++ b/pkg/core/interop/runtime/witness.go @@ -41,6 +41,10 @@ func getContractGroups(v *vm.VM, ic *interop.Context, h util.Uint160) (manifest. return manifest.Groups(cs.Manifest.Groups), nil } +func (sc scopeContext) IsCalledByEntry() bool { + return sc.VM.Context().IsCalledByEntry() +} + func (sc scopeContext) checkScriptGroups(h util.Uint160, k *keys.PublicKey) (bool, error) { groups, err := getContractGroups(sc.VM, sc.ic, h) if err != nil { @@ -69,9 +73,7 @@ func checkScope(ic *interop.Context, hash util.Uint160) (bool, error) { return true, nil } if c.Scopes&transaction.CalledByEntry != 0 { - callingScriptHash := ic.VM.GetCallingScriptHash() - entryScriptHash := ic.VM.GetEntryScriptHash() - if callingScriptHash.Equals(util.Uint160{}) || callingScriptHash == entryScriptHash { + if ic.VM.Context().IsCalledByEntry() { return true, nil } } diff --git a/pkg/core/transaction/witness_condition.go b/pkg/core/transaction/witness_condition.go index 71fcc1d33..65a73b09e 100644 --- a/pkg/core/transaction/witness_condition.go +++ b/pkg/core/transaction/witness_condition.go @@ -63,9 +63,9 @@ type WitnessCondition interface { type MatchContext interface { GetCallingScriptHash() util.Uint160 GetCurrentScriptHash() util.Uint160 - GetEntryScriptHash() util.Uint160 CallingScriptHasGroup(*keys.PublicKey) (bool, error) CurrentScriptHasGroup(*keys.PublicKey) (bool, error) + IsCalledByEntry() bool } type ( @@ -394,8 +394,7 @@ func (c ConditionCalledByEntry) Type() WitnessConditionType { // Match implements the WitnessCondition interface checking whether this condition // matches given context. func (c ConditionCalledByEntry) Match(ctx MatchContext) (bool, error) { - entry := ctx.GetEntryScriptHash() - return entry.Equals(ctx.GetCallingScriptHash()) || entry.Equals(ctx.GetCurrentScriptHash()), nil + return ctx.IsCalledByEntry(), nil } // EncodeBinary implements the WitnessCondition interface allowing to serialize condition. diff --git a/pkg/core/transaction/witness_condition_test.go b/pkg/core/transaction/witness_condition_test.go index e8ac6d4fc..2a4f56244 100644 --- a/pkg/core/transaction/witness_condition_test.go +++ b/pkg/core/transaction/witness_condition_test.go @@ -170,6 +170,9 @@ func (t *TestMC) GetCurrentScriptHash() util.Uint160 { func (t *TestMC) GetEntryScriptHash() util.Uint160 { return t.entry } +func (t *TestMC) IsCalledByEntry() bool { + return t.entry.Equals(t.calling) || t.calling.Equals(util.Uint160{}) +} func (t *TestMC) CallingScriptHasGroup(k *keys.PublicKey) (bool, error) { res, err := t.CurrentScriptHasGroup(k) return !res, err // To differentiate from current we invert the logic value. diff --git a/pkg/core/transaction/witness_scope.go b/pkg/core/transaction/witness_scope.go index 5daf15eaa..aa67eb33a 100644 --- a/pkg/core/transaction/witness_scope.go +++ b/pkg/core/transaction/witness_scope.go @@ -13,7 +13,7 @@ type WitnessScope byte const ( // None specifies that no contract was witnessed. Only sign the transaction. None WitnessScope = 0 - // CalledByEntry means that this condition must hold: EntryScriptHash == CallingScriptHash. + // CalledByEntry witness is only valid in entry script and ones directly called by it. // No params is needed, as the witness/permission/signature given on first invocation will // automatically expire if entering deeper internal invokes. This can be default safe // choice for native NEO/GAS (previously used on Neo 2 as "attach" mode). diff --git a/pkg/vm/context.go b/pkg/vm/context.go index aaf74b01e..31c4280f7 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -37,6 +37,9 @@ type scriptContext struct { // Caller's contract script hash. callingScriptHash util.Uint160 + // Caller's scriptContext, if not entry. + callingContext *scriptContext + // Call flags this context was created with. callFlag callflag.CallFlag @@ -326,6 +329,12 @@ func (v *VM) getContextScriptHash(n int) util.Uint160 { return ctx.ScriptHash() } +// IsCalledByEntry checks parent script contexts and return true if the current one +// is an entry script (the first loaded into the VM) or one called by it. +func (c *Context) IsCalledByEntry() bool { + return c.sc.callingContext == nil || c.sc.callingContext.callingContext == nil +} + // PushContextScriptHash pushes the script hash of the // invocation stack element number n to the evaluation stack. func (v *VM) PushContextScriptHash(n int) error { diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 47c68a772..378a898b2 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -340,15 +340,18 @@ func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint if rvcount != -1 || v.estack.Len() != 0 { v.estack = subStack(v.estack) } + parent := v.Context() ctx.sc.estack = v.estack initStack(&ctx.tryStack, "exception", nil) ctx.sc.callFlag = f ctx.sc.scriptHash = hash ctx.sc.callingScriptHash = caller + if parent != nil { + ctx.sc.callingContext = parent.sc + } ctx.sc.NEF = exe if v.invTree != nil { curTree := v.invTree - parent := v.Context() if parent != nil { curTree = parent.sc.invTree }