From e63191d31f016189205010f9a01b4fb1e93d66bf Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Thu, 10 Dec 2020 16:52:16 +0300 Subject: [PATCH] core: hangle CallingScriptHash correctly When using native contracts, script hash of second-to-top context on invocation stack does not always correspond to a real calling contract. --- pkg/core/interop/runtime/engine.go | 7 ++++++- pkg/core/interop_system_test.go | 3 ++- pkg/core/native_neo_test.go | 27 ++++++++++++++++++++++++--- pkg/vm/vm.go | 1 + 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/pkg/core/interop/runtime/engine.go b/pkg/core/interop/runtime/engine.go index 7dac81171..0ddf03136 100644 --- a/pkg/core/interop/runtime/engine.go +++ b/pkg/core/interop/runtime/engine.go @@ -22,8 +22,13 @@ func GetExecutingScriptHash(ic *interop.Context) error { } // GetCallingScriptHash returns calling script hash. +// While Executing and Entry script hashes are always valid for non-native contracts, +// Calling hash is set explicitly when native contracts are used, because when switching from +// one native to another, no operations are performed on invocation stack. func GetCallingScriptHash(ic *interop.Context) error { - return ic.VM.PushContextScriptHash(1) + h := ic.VM.GetCallingScriptHash() + ic.VM.Estack().PushVal(h.BytesBE()) + return nil } // GetEntryScriptHash returns entry script hash. diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index a8f197214..c968bb3ef 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -440,7 +440,8 @@ func getTestContractState() (*state.Contract, *state.Contract) { emit.Syscall(w.BinWriter, interopnames.SystemStorageGet) emit.Opcodes(w.BinWriter, opcode.RET) onPaymentOff := w.Len() - emit.Int(w.BinWriter, 3) + emit.Syscall(w.BinWriter, interopnames.SystemRuntimeGetCallingScriptHash) + emit.Int(w.BinWriter, 4) emit.Opcodes(w.BinWriter, opcode.PACK) emit.String(w.BinWriter, "LastPayment") emit.Syscall(w.BinWriter, interopnames.SystemRuntimeNotify) diff --git a/pkg/core/native_neo_test.go b/pkg/core/native_neo_test.go index aecbf4bdd..0e2da1511 100644 --- a/pkg/core/native_neo_test.go +++ b/pkg/core/native_neo_test.go @@ -317,11 +317,32 @@ func TestNEO_TransferOnPayment(t *testing.T) { aer, err := bc.GetAppExecResults(tx.Hash(), trigger.Application) require.NoError(t, err) require.Equal(t, vm.HaltState, aer[0].VMState) - require.Len(t, aer[0].Events, 3) // transfer + auto GAS claim + onPayment + require.Len(t, aer[0].Events, 3) // transfer + GAS claim for sender + onPayment e := aer[0].Events[2] require.Equal(t, "LastPayment", e.Name) arr := e.Item.Value().([]stackitem.Item) - require.Equal(t, neoOwner.BytesBE(), arr[0].Value()) - require.Equal(t, big.NewInt(amount), arr[1].Value()) + require.Equal(t, bc.contracts.NEO.Hash.BytesBE(), arr[0].Value()) + require.Equal(t, neoOwner.BytesBE(), arr[1].Value()) + require.Equal(t, big.NewInt(amount), arr[2].Value()) + + tx = transferTokenFromMultisigAccount(t, bc, cs.Hash, bc.contracts.NEO.Hash, amount) + aer, err = bc.GetAppExecResults(tx.Hash(), trigger.Application) + require.NoError(t, err) + require.Equal(t, vm.HaltState, aer[0].VMState) + // Now we must also have GAS claim for contract and corresponding `onPayment`. + require.Len(t, aer[0].Events, 5) + + e = aer[0].Events[2] // onPayment for GAS claim + require.Equal(t, "LastPayment", e.Name) + arr = e.Item.Value().([]stackitem.Item) + require.Equal(t, stackitem.Null{}, arr[1]) + require.Equal(t, bc.contracts.GAS.Hash.BytesBE(), arr[0].Value()) + + e = aer[0].Events[4] // onPayment for NEO transfer + require.Equal(t, "LastPayment", e.Name) + arr = e.Item.Value().([]stackitem.Item) + require.Equal(t, bc.contracts.NEO.Hash.BytesBE(), arr[0].Value()) + require.Equal(t, neoOwner.BytesBE(), arr[1].Value()) + require.Equal(t, big.NewInt(amount), arr[2].Value()) } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index ae4b47eab..a586086b8 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -284,6 +284,7 @@ func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) { ctx.tryStack = NewStack("exception") ctx.callFlag = f ctx.static = newSlot(v.refs) + ctx.callingScriptHash = v.GetCurrentScriptHash() v.istack.PushVal(ctx) }