diff --git a/pkg/core/helper_test.go b/pkg/core/helper_test.go index 0cef3e47a..58bcab296 100644 --- a/pkg/core/helper_test.go +++ b/pkg/core/helper_test.go @@ -540,7 +540,7 @@ func addSigners(sender util.Uint160, txs ...*transaction.Transaction) { for _, tx := range txs { tx.Signers = []transaction.Signer{{ Account: sender, - Scopes: transaction.CalledByEntry, + Scopes: transaction.Global, AllowedContracts: nil, AllowedGroups: nil, }} diff --git a/pkg/core/interop/context.go b/pkg/core/interop/context.go index 1c7a3ad6e..ce6ba1550 100644 --- a/pkg/core/interop/context.go +++ b/pkg/core/interop/context.go @@ -51,6 +51,7 @@ type Context struct { cancelFuncs []context.CancelFunc getContract func(dao.DAO, util.Uint160) (*state.Contract, error) baseExecFee int64 + signers []transaction.Signer } // NewContext returns new interop context. @@ -89,6 +90,22 @@ func (ic *Context) InitNonceData() { } } +// UseSigners allows overriding signers used in this context. +func (ic *Context) UseSigners(s []transaction.Signer) { + ic.signers = s +} + +// Signers returns signers witnessing current execution context. +func (ic *Context) Signers() []transaction.Signer { + if ic.signers != nil { + return ic.signers + } + if ic.Tx != nil { + return ic.Tx.Signers + } + return nil +} + // Function binds function name, id with the function itself and price, // it's supposed to be inited once for all interopContexts, so it doesn't use // vm.InteropFuncPrice directly. diff --git a/pkg/core/interop/runtime/witness.go b/pkg/core/interop/runtime/witness.go index cc951724d..ba5674993 100644 --- a/pkg/core/interop/runtime/witness.go +++ b/pkg/core/interop/runtime/witness.go @@ -22,11 +22,7 @@ func CheckHashedWitness(ic *interop.Context, hash util.Uint160) (bool, error) { if !callingSH.Equals(util.Uint160{}) && hash.Equals(callingSH) { return true, nil } - if tx, ok := ic.Container.(*transaction.Transaction); ok { - return checkScope(ic, tx, ic.VM, hash) - } - - return false, errors.New("script container is not a transaction") + return checkScope(ic, hash) } type scopeContext struct { @@ -61,21 +57,26 @@ func (sc scopeContext) CurrentScriptHasGroup(k *keys.PublicKey) (bool, error) { return sc.checkScriptGroups(sc.GetCurrentScriptHash(), k) } -func checkScope(ic *interop.Context, tx *transaction.Transaction, v *vm.VM, hash util.Uint160) (bool, error) { - for _, c := range tx.Signers { +func checkScope(ic *interop.Context, hash util.Uint160) (bool, error) { + signers := ic.Signers() + if len(signers) == 0 { + return false, errors.New("no valid signers") + } + for i := range signers { + c := &signers[i] if c.Account == hash { if c.Scopes == transaction.Global { return true, nil } if c.Scopes&transaction.CalledByEntry != 0 { - callingScriptHash := v.GetCallingScriptHash() - entryScriptHash := v.GetEntryScriptHash() + callingScriptHash := ic.VM.GetCallingScriptHash() + entryScriptHash := ic.VM.GetEntryScriptHash() if callingScriptHash.Equals(util.Uint160{}) || callingScriptHash == entryScriptHash { return true, nil } } if c.Scopes&transaction.CustomContracts != 0 { - currentScriptHash := v.GetCurrentScriptHash() + currentScriptHash := ic.VM.GetCurrentScriptHash() for _, allowedContract := range c.AllowedContracts { if allowedContract == currentScriptHash { return true, nil @@ -83,7 +84,7 @@ func checkScope(ic *interop.Context, tx *transaction.Transaction, v *vm.VM, hash } } if c.Scopes&transaction.CustomGroups != 0 { - groups, err := getContractGroups(v, ic, v.GetCurrentScriptHash()) + groups, err := getContractGroups(ic.VM, ic, ic.VM.GetCurrentScriptHash()) if err != nil { return false, err } @@ -95,7 +96,7 @@ func checkScope(ic *interop.Context, tx *transaction.Transaction, v *vm.VM, hash } } if c.Scopes&transaction.Rules != 0 { - ctx := scopeContext{v, ic} + ctx := scopeContext{ic.VM, ic} for _, r := range c.Rules { res, err := r.Condition.Match(ctx) if err != nil { diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index e2adda0a8..4ee1c63da 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -1183,7 +1183,7 @@ func TestRuntimeCheckWitness(t *testing.T) { }, }, } - ic.Container = tx + ic.Tx = tx callingScriptHash := scriptHash loadScriptWithHashAndFlags(ic, script, callingScriptHash, callflag.All) ic.VM.LoadScriptWithHash([]byte{0x1}, random.Uint160(), callflag.AllowCall) @@ -1205,7 +1205,7 @@ func TestRuntimeCheckWitness(t *testing.T) { }, }, } - ic.Container = tx + ic.Tx = tx callingScriptHash := scriptHash loadScriptWithHashAndFlags(ic, script, callingScriptHash, callflag.All) ic.VM.LoadScriptWithHash([]byte{0x1}, random.Uint160(), callflag.AllowCall) @@ -1242,7 +1242,7 @@ func TestRuntimeCheckWitness(t *testing.T) { }, } loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) - ic.Container = tx + ic.Tx = tx check(t, ic, hash.BytesBE(), false, true) }) t.Run("CalledByEntry", func(t *testing.T) { @@ -1256,7 +1256,7 @@ func TestRuntimeCheckWitness(t *testing.T) { }, } loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) - ic.Container = tx + ic.Tx = tx check(t, ic, hash.BytesBE(), false, true) }) t.Run("CustomContracts", func(t *testing.T) { @@ -1271,7 +1271,7 @@ func TestRuntimeCheckWitness(t *testing.T) { }, } loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) - ic.Container = tx + ic.Tx = tx check(t, ic, hash.BytesBE(), false, true) }) t.Run("CustomGroups", func(t *testing.T) { @@ -1287,7 +1287,7 @@ func TestRuntimeCheckWitness(t *testing.T) { }, } loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) - ic.Container = tx + ic.Tx = tx check(t, ic, hash.BytesBE(), false, false) }) t.Run("positive", func(t *testing.T) { @@ -1319,7 +1319,7 @@ func TestRuntimeCheckWitness(t *testing.T) { } require.NoError(t, bc.contracts.Management.PutContractState(ic.DAO, contractState)) loadScriptWithHashAndFlags(ic, contractScript, contractScriptHash, callflag.All) - ic.Container = tx + ic.Tx = tx check(t, ic, targetHash.BytesBE(), false, true) }) }) @@ -1339,7 +1339,7 @@ func TestRuntimeCheckWitness(t *testing.T) { }, } loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) - ic.Container = tx + ic.Tx = tx check(t, ic, hash.BytesBE(), false, false) }) t.Run("allow", func(t *testing.T) { @@ -1358,7 +1358,7 @@ func TestRuntimeCheckWitness(t *testing.T) { }, } loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) - ic.Container = tx + ic.Tx = tx check(t, ic, hash.BytesBE(), false, true) }) t.Run("deny", func(t *testing.T) { @@ -1377,7 +1377,7 @@ func TestRuntimeCheckWitness(t *testing.T) { }, } loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) - ic.Container = tx + ic.Tx = tx check(t, ic, hash.BytesBE(), false, false) }) }) @@ -1392,7 +1392,7 @@ func TestRuntimeCheckWitness(t *testing.T) { }, } loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) - ic.Container = tx + ic.Tx = tx check(t, ic, hash.BytesBE(), false, false) }) }) diff --git a/pkg/core/native/oracle.go b/pkg/core/native/oracle.go index 393c0d979..18c168743 100644 --- a/pkg/core/native/oracle.go +++ b/pkg/core/native/oracle.go @@ -275,6 +275,12 @@ func (o *Oracle) FinishInternal(ic *interop.Context) error { stackitem.Make(req.OriginalTxID.BytesBE()), }), }) + origTx, _, err := ic.DAO.GetTransaction(req.OriginalTxID) + if err != nil { + return ErrRequestNotFound + } + ic.UseSigners(origTx.Signers) + defer ic.UseSigners(nil) userData, err := stackitem.Deserialize(req.UserData) if err != nil { diff --git a/pkg/core/native_oracle_test.go b/pkg/core/native_oracle_test.go index 4ba681a31..28f883879 100644 --- a/pkg/core/native_oracle_test.go +++ b/pkg/core/native_oracle_test.go @@ -43,6 +43,11 @@ func getOracleContractState(h util.Uint160, stdHash util.Uint160) *state.Contrac // `handle` method aborts if len(userData) == 2 offset := w.Len() + emit.Bytes(w.BinWriter, neoOwner.BytesBE()) + emit.Syscall(w.BinWriter, interopnames.SystemRuntimeCheckWitness) + emit.Instruction(w.BinWriter, opcode.JMPIF, []byte{3}) + emit.Opcodes(w.BinWriter, opcode.ABORT) + emit.Opcodes(w.BinWriter, opcode.OVER) emit.Opcodes(w.BinWriter, opcode.SIZE) emit.Int(w.BinWriter, 2)