Merge pull request #2261 from nspcc-dev/fix-oracle-witnesses

native/interop: use oracle request signers for oracle response witness
This commit is contained in:
Roman Khimov 2021-11-16 12:02:08 +03:00 committed by GitHub
commit 965c3b2c13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 53 additions and 24 deletions

View file

@ -540,7 +540,7 @@ func addSigners(sender util.Uint160, txs ...*transaction.Transaction) {
for _, tx := range txs { for _, tx := range txs {
tx.Signers = []transaction.Signer{{ tx.Signers = []transaction.Signer{{
Account: sender, Account: sender,
Scopes: transaction.CalledByEntry, Scopes: transaction.Global,
AllowedContracts: nil, AllowedContracts: nil,
AllowedGroups: nil, AllowedGroups: nil,
}} }}

View file

@ -51,6 +51,7 @@ type Context struct {
cancelFuncs []context.CancelFunc cancelFuncs []context.CancelFunc
getContract func(dao.DAO, util.Uint160) (*state.Contract, error) getContract func(dao.DAO, util.Uint160) (*state.Contract, error)
baseExecFee int64 baseExecFee int64
signers []transaction.Signer
} }
// NewContext returns new interop context. // NewContext returns new interop context.
@ -89,6 +90,22 @@ func (ic *Context) InitNonceData() {
} }
} }
// UseSigners allows overriding signers used in this context.
func (ic *Context) UseSigners(s []transaction.Signer) {
ic.signers = s
}
// Signers returns signers witnessing current execution context.
func (ic *Context) Signers() []transaction.Signer {
if ic.signers != nil {
return ic.signers
}
if ic.Tx != nil {
return ic.Tx.Signers
}
return nil
}
// Function binds function name, id with the function itself and price, // Function binds function name, id with the function itself and price,
// it's supposed to be inited once for all interopContexts, so it doesn't use // it's supposed to be inited once for all interopContexts, so it doesn't use
// vm.InteropFuncPrice directly. // vm.InteropFuncPrice directly.

View file

@ -22,11 +22,7 @@ func CheckHashedWitness(ic *interop.Context, hash util.Uint160) (bool, error) {
if !callingSH.Equals(util.Uint160{}) && hash.Equals(callingSH) { if !callingSH.Equals(util.Uint160{}) && hash.Equals(callingSH) {
return true, nil return true, nil
} }
if tx, ok := ic.Container.(*transaction.Transaction); ok { return checkScope(ic, hash)
return checkScope(ic, tx, ic.VM, hash)
}
return false, errors.New("script container is not a transaction")
} }
type scopeContext struct { type scopeContext struct {
@ -61,21 +57,26 @@ func (sc scopeContext) CurrentScriptHasGroup(k *keys.PublicKey) (bool, error) {
return sc.checkScriptGroups(sc.GetCurrentScriptHash(), k) return sc.checkScriptGroups(sc.GetCurrentScriptHash(), k)
} }
func checkScope(ic *interop.Context, tx *transaction.Transaction, v *vm.VM, hash util.Uint160) (bool, error) { func checkScope(ic *interop.Context, hash util.Uint160) (bool, error) {
for _, c := range tx.Signers { signers := ic.Signers()
if len(signers) == 0 {
return false, errors.New("no valid signers")
}
for i := range signers {
c := &signers[i]
if c.Account == hash { if c.Account == hash {
if c.Scopes == transaction.Global { if c.Scopes == transaction.Global {
return true, nil return true, nil
} }
if c.Scopes&transaction.CalledByEntry != 0 { if c.Scopes&transaction.CalledByEntry != 0 {
callingScriptHash := v.GetCallingScriptHash() callingScriptHash := ic.VM.GetCallingScriptHash()
entryScriptHash := v.GetEntryScriptHash() entryScriptHash := ic.VM.GetEntryScriptHash()
if callingScriptHash.Equals(util.Uint160{}) || callingScriptHash == entryScriptHash { if callingScriptHash.Equals(util.Uint160{}) || callingScriptHash == entryScriptHash {
return true, nil return true, nil
} }
} }
if c.Scopes&transaction.CustomContracts != 0 { if c.Scopes&transaction.CustomContracts != 0 {
currentScriptHash := v.GetCurrentScriptHash() currentScriptHash := ic.VM.GetCurrentScriptHash()
for _, allowedContract := range c.AllowedContracts { for _, allowedContract := range c.AllowedContracts {
if allowedContract == currentScriptHash { if allowedContract == currentScriptHash {
return true, nil return true, nil
@ -83,7 +84,7 @@ func checkScope(ic *interop.Context, tx *transaction.Transaction, v *vm.VM, hash
} }
} }
if c.Scopes&transaction.CustomGroups != 0 { if c.Scopes&transaction.CustomGroups != 0 {
groups, err := getContractGroups(v, ic, v.GetCurrentScriptHash()) groups, err := getContractGroups(ic.VM, ic, ic.VM.GetCurrentScriptHash())
if err != nil { if err != nil {
return false, err return false, err
} }
@ -95,7 +96,7 @@ func checkScope(ic *interop.Context, tx *transaction.Transaction, v *vm.VM, hash
} }
} }
if c.Scopes&transaction.Rules != 0 { if c.Scopes&transaction.Rules != 0 {
ctx := scopeContext{v, ic} ctx := scopeContext{ic.VM, ic}
for _, r := range c.Rules { for _, r := range c.Rules {
res, err := r.Condition.Match(ctx) res, err := r.Condition.Match(ctx)
if err != nil { if err != nil {

View file

@ -1183,7 +1183,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
}, },
}, },
} }
ic.Container = tx ic.Tx = tx
callingScriptHash := scriptHash callingScriptHash := scriptHash
loadScriptWithHashAndFlags(ic, script, callingScriptHash, callflag.All) loadScriptWithHashAndFlags(ic, script, callingScriptHash, callflag.All)
ic.VM.LoadScriptWithHash([]byte{0x1}, random.Uint160(), callflag.AllowCall) ic.VM.LoadScriptWithHash([]byte{0x1}, random.Uint160(), callflag.AllowCall)
@ -1205,7 +1205,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
}, },
}, },
} }
ic.Container = tx ic.Tx = tx
callingScriptHash := scriptHash callingScriptHash := scriptHash
loadScriptWithHashAndFlags(ic, script, callingScriptHash, callflag.All) loadScriptWithHashAndFlags(ic, script, callingScriptHash, callflag.All)
ic.VM.LoadScriptWithHash([]byte{0x1}, random.Uint160(), callflag.AllowCall) ic.VM.LoadScriptWithHash([]byte{0x1}, random.Uint160(), callflag.AllowCall)
@ -1242,7 +1242,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
}, },
} }
loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates)
ic.Container = tx ic.Tx = tx
check(t, ic, hash.BytesBE(), false, true) check(t, ic, hash.BytesBE(), false, true)
}) })
t.Run("CalledByEntry", func(t *testing.T) { t.Run("CalledByEntry", func(t *testing.T) {
@ -1256,7 +1256,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
}, },
} }
loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates)
ic.Container = tx ic.Tx = tx
check(t, ic, hash.BytesBE(), false, true) check(t, ic, hash.BytesBE(), false, true)
}) })
t.Run("CustomContracts", func(t *testing.T) { t.Run("CustomContracts", func(t *testing.T) {
@ -1271,7 +1271,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
}, },
} }
loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates)
ic.Container = tx ic.Tx = tx
check(t, ic, hash.BytesBE(), false, true) check(t, ic, hash.BytesBE(), false, true)
}) })
t.Run("CustomGroups", func(t *testing.T) { t.Run("CustomGroups", func(t *testing.T) {
@ -1287,7 +1287,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
}, },
} }
loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates)
ic.Container = tx ic.Tx = tx
check(t, ic, hash.BytesBE(), false, false) check(t, ic, hash.BytesBE(), false, false)
}) })
t.Run("positive", func(t *testing.T) { t.Run("positive", func(t *testing.T) {
@ -1319,7 +1319,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
} }
require.NoError(t, bc.contracts.Management.PutContractState(ic.DAO, contractState)) require.NoError(t, bc.contracts.Management.PutContractState(ic.DAO, contractState))
loadScriptWithHashAndFlags(ic, contractScript, contractScriptHash, callflag.All) loadScriptWithHashAndFlags(ic, contractScript, contractScriptHash, callflag.All)
ic.Container = tx ic.Tx = tx
check(t, ic, targetHash.BytesBE(), false, true) check(t, ic, targetHash.BytesBE(), false, true)
}) })
}) })
@ -1339,7 +1339,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
}, },
} }
loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates)
ic.Container = tx ic.Tx = tx
check(t, ic, hash.BytesBE(), false, false) check(t, ic, hash.BytesBE(), false, false)
}) })
t.Run("allow", func(t *testing.T) { t.Run("allow", func(t *testing.T) {
@ -1358,7 +1358,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
}, },
} }
loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates)
ic.Container = tx ic.Tx = tx
check(t, ic, hash.BytesBE(), false, true) check(t, ic, hash.BytesBE(), false, true)
}) })
t.Run("deny", func(t *testing.T) { t.Run("deny", func(t *testing.T) {
@ -1377,7 +1377,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
}, },
} }
loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates)
ic.Container = tx ic.Tx = tx
check(t, ic, hash.BytesBE(), false, false) check(t, ic, hash.BytesBE(), false, false)
}) })
}) })
@ -1392,7 +1392,7 @@ func TestRuntimeCheckWitness(t *testing.T) {
}, },
} }
loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates) loadScriptWithHashAndFlags(ic, script, scriptHash, callflag.ReadStates)
ic.Container = tx ic.Tx = tx
check(t, ic, hash.BytesBE(), false, false) check(t, ic, hash.BytesBE(), false, false)
}) })
}) })

View file

@ -275,6 +275,12 @@ func (o *Oracle) FinishInternal(ic *interop.Context) error {
stackitem.Make(req.OriginalTxID.BytesBE()), stackitem.Make(req.OriginalTxID.BytesBE()),
}), }),
}) })
origTx, _, err := ic.DAO.GetTransaction(req.OriginalTxID)
if err != nil {
return ErrRequestNotFound
}
ic.UseSigners(origTx.Signers)
defer ic.UseSigners(nil)
userData, err := stackitem.Deserialize(req.UserData) userData, err := stackitem.Deserialize(req.UserData)
if err != nil { if err != nil {

View file

@ -43,6 +43,11 @@ func getOracleContractState(h util.Uint160, stdHash util.Uint160) *state.Contrac
// `handle` method aborts if len(userData) == 2 // `handle` method aborts if len(userData) == 2
offset := w.Len() offset := w.Len()
emit.Bytes(w.BinWriter, neoOwner.BytesBE())
emit.Syscall(w.BinWriter, interopnames.SystemRuntimeCheckWitness)
emit.Instruction(w.BinWriter, opcode.JMPIF, []byte{3})
emit.Opcodes(w.BinWriter, opcode.ABORT)
emit.Opcodes(w.BinWriter, opcode.OVER) emit.Opcodes(w.BinWriter, opcode.OVER)
emit.Opcodes(w.BinWriter, opcode.SIZE) emit.Opcodes(w.BinWriter, opcode.SIZE)
emit.Int(w.BinWriter, 2) emit.Int(w.BinWriter, 2)