Merge pull request #1600 from nspcc-dev/nativesync

core: call from native contracts synchronously
This commit is contained in:
Roman Khimov 2020-12-10 17:52:10 +03:00 committed by GitHub
commit f0dba26d43
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 120 additions and 43 deletions

View file

@ -60,12 +60,18 @@ func callExInternal(ic *interop.Context, h []byte, name string, args []stackitem
} }
} }
} }
return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty, nil) return CallExInternal(ic, cs, name, args, f, vm.EnsureNotEmpty)
} }
// CallExInternal calls a contract with flags and can't be invoked directly by user. // CallExInternal calls a contract with flags and can't be invoked directly by user.
func CallExInternal(ic *interop.Context, cs *state.Contract, func CallExInternal(ic *interop.Context, cs *state.Contract,
name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState, callback func(ctx *vm.Context)) error { name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState) error {
return callExFromNative(ic, ic.VM.GetCurrentScriptHash(), cs, name, args, f, checkReturn)
}
// callExFromNative calls a contract with flags using provided calling hash.
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
name string, args []stackitem.Item, f smartcontract.CallFlag, checkReturn vm.CheckReturnState) error {
md := cs.Manifest.ABI.GetMethod(name) md := cs.Manifest.ABI.GetMethod(name)
if md == nil { if md == nil {
return fmt.Errorf("method '%s' not found", name) return fmt.Errorf("method '%s' not found", name)
@ -76,7 +82,7 @@ func CallExInternal(ic *interop.Context, cs *state.Contract,
} }
ic.VM.Invocations[cs.Hash]++ ic.VM.Invocations[cs.Hash]++
ic.VM.LoadScriptWithHash(cs.Script, cs.Hash, ic.VM.Context().GetCallFlags()&f) ic.VM.LoadScriptWithCallingHash(caller, cs.Script, cs.Hash, ic.VM.Context().GetCallFlags()&f)
var isNative bool var isNative bool
for i := range ic.Natives { for i := range ic.Natives {
if ic.Natives[i].Metadata().Hash.Equals(cs.Hash) { if ic.Natives[i].Metadata().Hash.Equals(cs.Hash) {
@ -95,7 +101,6 @@ func CallExInternal(ic *interop.Context, cs *state.Contract,
ic.VM.Jump(ic.VM.Context(), md.Offset) ic.VM.Jump(ic.VM.Context(), md.Offset)
} }
ic.VM.Context().CheckReturn = checkReturn ic.VM.Context().CheckReturn = checkReturn
ic.VM.Context().Callback = callback
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit) md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
if md != nil { if md != nil {
@ -104,3 +109,24 @@ func CallExInternal(ic *interop.Context, cs *state.Contract,
return nil return nil
} }
// ErrNativeCall is returned for failed calls from native.
var ErrNativeCall = errors.New("error during call from native")
// CallFromNative performs synchronous call from native contract.
func CallFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract, method string, args []stackitem.Item, checkReturn vm.CheckReturnState) error {
startSize := ic.VM.Istack().Len()
if err := callExFromNative(ic, caller, cs, method, args, smartcontract.All, checkReturn); err != nil {
return err
}
for !ic.VM.HasStopped() && ic.VM.Istack().Len() > startSize {
if err := ic.VM.Step(); err != nil {
return fmt.Errorf("%w: %v", ErrNativeCall, err)
}
}
if ic.VM.State() == vm.FaultState {
return ErrNativeCall
}
return nil
}

View file

@ -22,8 +22,13 @@ func GetExecutingScriptHash(ic *interop.Context) error {
} }
// GetCallingScriptHash returns calling script hash. // GetCallingScriptHash returns calling script hash.
// While Executing and Entry script hashes are always valid for non-native contracts,
// Calling hash is set explicitly when native contracts are used, because when switching from
// one native to another, no operations are performed on invocation stack.
func GetCallingScriptHash(ic *interop.Context) error { func GetCallingScriptHash(ic *interop.Context) error {
return ic.VM.PushContextScriptHash(1) h := ic.VM.GetCallingScriptHash()
ic.VM.Estack().PushVal(h.BytesBE())
return nil
} }
// GetEntryScriptHash returns entry script hash. // GetEntryScriptHash returns entry script hash.

View file

@ -187,7 +187,7 @@ func callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) error {
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy) md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
if md != nil { if md != nil {
return contract.CallExInternal(ic, cs, manifest.MethodDeploy, return contract.CallExInternal(ic, cs, manifest.MethodDeploy,
[]stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty, nil) []stackitem.Item{stackitem.NewBool(isUpdate)}, smartcontract.All, vm.EnsureIsEmpty)
} }
return nil return nil
} }

View file

@ -440,7 +440,8 @@ func getTestContractState() (*state.Contract, *state.Contract) {
emit.Syscall(w.BinWriter, interopnames.SystemStorageGet) emit.Syscall(w.BinWriter, interopnames.SystemStorageGet)
emit.Opcodes(w.BinWriter, opcode.RET) emit.Opcodes(w.BinWriter, opcode.RET)
onPaymentOff := w.Len() onPaymentOff := w.Len()
emit.Int(w.BinWriter, 3) emit.Syscall(w.BinWriter, interopnames.SystemRuntimeGetCallingScriptHash)
emit.Int(w.BinWriter, 4)
emit.Opcodes(w.BinWriter, opcode.PACK) emit.Opcodes(w.BinWriter, opcode.PACK)
emit.String(w.BinWriter, "LastPayment") emit.String(w.BinWriter, "LastPayment")
emit.Syscall(w.BinWriter, interopnames.SystemRuntimeNotify) emit.Syscall(w.BinWriter, interopnames.SystemRuntimeNotify)
@ -972,7 +973,7 @@ func TestContractCreateDeploy(t *testing.T) {
cs.Hash = state.CreateContractHash(sender, cs.Script) cs.Hash = state.CreateContractHash(sender, cs.Script)
v.LoadScriptWithHash(currCs.Script, cs.Hash, smartcontract.All) v.LoadScriptWithHash(currCs.Script, cs.Hash, smartcontract.All)
err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty, nil) err := contract.CallExInternal(ic, cs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, v.Run()) require.NoError(t, v.Run())
require.Equal(t, "create", v.Estack().Pop().String()) require.Equal(t, "create", v.Estack().Pop().String())
@ -993,7 +994,7 @@ func TestContractCreateDeploy(t *testing.T) {
require.NoError(t, v.Run()) require.NoError(t, v.Run())
v.LoadScriptWithHash(currCs.Script, cs.Hash, smartcontract.All) v.LoadScriptWithHash(currCs.Script, cs.Hash, smartcontract.All)
err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty, nil) err = contract.CallExInternal(ic, newCs, "getValue", nil, smartcontract.All, vm.EnsureNotEmpty)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, v.Run()) require.NoError(t, v.Run())
require.Equal(t, "update", v.Estack().Pop().String()) require.Equal(t, "update", v.Estack().Pop().String())

View file

@ -159,7 +159,7 @@ func (c *nep17TokenNative) postTransfer(ic *interop.Context, from, to *util.Uint
stackitem.NewBigInteger(amount), stackitem.NewBigInteger(amount),
data, data,
} }
if err := contract.CallExInternal(ic, cs, manifest.MethodOnPayment, args, smartcontract.All, vm.EnsureIsEmpty, nil); err != nil { if err := contract.CallFromNative(ic, c.Hash, cs, manifest.MethodOnPayment, args, vm.EnsureIsEmpty); err != nil {
panic(err) panic(err)
} }
} }

View file

@ -247,15 +247,13 @@ func (n *Notary) withdraw(ic *interop.Context, args []stackitem.Item) stackitem.
panic(fmt.Errorf("failed to get GAS contract state: %w", err)) panic(fmt.Errorf("failed to get GAS contract state: %w", err))
} }
transferArgs := []stackitem.Item{stackitem.NewByteArray(n.Hash.BytesBE()), stackitem.NewByteArray(to.BytesBE()), stackitem.NewBigInteger(deposit.Amount), stackitem.Null{}} transferArgs := []stackitem.Item{stackitem.NewByteArray(n.Hash.BytesBE()), stackitem.NewByteArray(to.BytesBE()), stackitem.NewBigInteger(deposit.Amount), stackitem.Null{}}
err = contract.CallExInternal(ic, cs, "transfer", transferArgs, smartcontract.All, vm.EnsureIsEmpty, func(ctx *vm.Context) { // we need EnsureIsEmpty because there's a callback popping result from the stack err = contract.CallFromNative(ic, n.Hash, cs, "transfer", transferArgs, vm.EnsureNotEmpty)
isTransferOk := ic.VM.Estack().Pop().Bool()
if !isTransferOk {
panic("failed to transfer GAS from Notary account")
}
})
if err != nil { if err != nil {
panic(fmt.Errorf("failed to transfer GAS from Notary account: %w", err)) panic(fmt.Errorf("failed to transfer GAS from Notary account: %w", err))
} }
if !ic.VM.Estack().Pop().Bool() {
panic("failed to transfer GAS from Notary account: `transfer` returned false")
}
if err := n.removeDepositFor(ic.DAO, from); err != nil { if err := n.removeDepositFor(ic.DAO, from); err != nil {
panic(fmt.Errorf("failed to remove withdrawn deposit for %s from the storage: %w", from.StringBE(), err)) panic(fmt.Errorf("failed to remove withdrawn deposit for %s from the storage: %w", from.StringBE(), err))
} }

View file

@ -20,6 +20,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -248,16 +249,17 @@ func (o *Oracle) FinishInternal(ic *interop.Context) error {
r := io.NewBinReaderFromBuf(req.UserData) r := io.NewBinReaderFromBuf(req.UserData)
userData := stackitem.DecodeBinaryStackItem(r) userData := stackitem.DecodeBinaryStackItem(r)
args := stackitem.NewArray([]stackitem.Item{ args := []stackitem.Item{
stackitem.Make(req.URL), stackitem.Make(req.URL),
stackitem.Make(userData), stackitem.Make(userData),
stackitem.Make(resp.Code), stackitem.Make(resp.Code),
stackitem.Make(resp.Result), stackitem.Make(resp.Result),
}) }
ic.VM.Estack().PushVal(args) cs, err := ic.DAO.GetContractState(req.CallbackContract)
ic.VM.Estack().PushVal(req.CallbackMethod) if err != nil {
ic.VM.Estack().PushVal(req.CallbackContract.BytesBE()) return err
return contract.Call(ic) }
return contract.CallFromNative(ic, o.Hash, cs, req.CallbackMethod, args, vm.EnsureIsEmpty)
} }
func (o *Oracle) request(ic *interop.Context, args []stackitem.Item) stackitem.Item { func (o *Oracle) request(ic *interop.Context, args []stackitem.Item) stackitem.Item {

View file

@ -91,6 +91,21 @@ func newTestNative() *testNative {
RequiredFlags: smartcontract.NoneFlag} RequiredFlags: smartcontract.NoneFlag}
tn.meta.AddMethod(md, desc) tn.meta.AddMethod(md, desc)
desc = &manifest.Method{
Name: "callOtherContractWithReturn",
Parameters: []manifest.Parameter{
manifest.NewParameter("contractHash", smartcontract.Hash160Type),
manifest.NewParameter("method", smartcontract.StringType),
manifest.NewParameter("arg", smartcontract.ArrayType),
},
ReturnType: smartcontract.IntegerType,
}
md = &interop.MethodAndPrice{
Func: tn.callOtherContractWithReturn,
Price: testSumPrice,
RequiredFlags: smartcontract.NoneFlag}
tn.meta.AddMethod(md, desc)
desc = &manifest.Method{Name: "onPersist", ReturnType: smartcontract.BoolType} desc = &manifest.Method{Name: "onPersist", ReturnType: smartcontract.BoolType}
md = &interop.MethodAndPrice{Func: tn.OnPersist, RequiredFlags: smartcontract.AllowModifyStates} md = &interop.MethodAndPrice{Func: tn.OnPersist, RequiredFlags: smartcontract.AllowModifyStates}
tn.meta.AddMethod(md, desc) tn.meta.AddMethod(md, desc)
@ -122,16 +137,17 @@ func toUint160(item stackitem.Item) util.Uint160 {
return u return u
} }
func (tn *testNative) call(ic *interop.Context, args []stackitem.Item, retState vm.CheckReturnState) { func (tn *testNative) call(ic *interop.Context, args []stackitem.Item, checkReturn vm.CheckReturnState) {
cs, err := ic.DAO.GetContractState(toUint160(args[0])) h := toUint160(args[0])
if err != nil {
panic(err)
}
bs, err := args[1].TryBytes() bs, err := args[1].TryBytes()
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = contract.CallExInternal(ic, cs, string(bs), args[2].Value().([]stackitem.Item), smartcontract.All, retState, nil) cs, err := ic.DAO.GetContractState(h)
if err != nil {
panic(err)
}
err = contract.CallFromNative(ic, tn.meta.Hash, cs, string(bs), args[2].Value().([]stackitem.Item), checkReturn)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -142,6 +158,12 @@ func (tn *testNative) callOtherContractNoReturn(ic *interop.Context, args []stac
return stackitem.Null{} return stackitem.Null{}
} }
func (tn *testNative) callOtherContractWithReturn(ic *interop.Context, args []stackitem.Item) stackitem.Item {
tn.call(ic, args, vm.EnsureNotEmpty)
bi := ic.VM.Estack().Pop().BigInt()
return stackitem.Make(bi.Add(bi, big.NewInt(1)))
}
func TestNativeContract_Invoke(t *testing.T) { func TestNativeContract_Invoke(t *testing.T) {
chain := newTestChain(t) chain := newTestChain(t)
defer chain.Close() defer chain.Close()
@ -238,4 +260,10 @@ func TestNativeContract_InvokeOtherContract(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
checkResult(t, res, stackitem.Null{}) // simple call is done with EnsureNotEmpty checkResult(t, res, stackitem.Null{}) // simple call is done with EnsureNotEmpty
}) })
t.Run("non-native, with return", func(t *testing.T) {
res, err := invokeContractMethod(chain, testSumPrice*4+10000, tn.Metadata().Hash,
"callOtherContractWithReturn", cs.Hash, "ret7", []interface{}{})
require.NoError(t, err)
checkResult(t, res, stackitem.Make(8))
})
} }

View file

@ -317,11 +317,32 @@ func TestNEO_TransferOnPayment(t *testing.T) {
aer, err := bc.GetAppExecResults(tx.Hash(), trigger.Application) aer, err := bc.GetAppExecResults(tx.Hash(), trigger.Application)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, vm.HaltState, aer[0].VMState) require.Equal(t, vm.HaltState, aer[0].VMState)
require.Len(t, aer[0].Events, 3) // transfer + auto GAS claim + onPayment require.Len(t, aer[0].Events, 3) // transfer + GAS claim for sender + onPayment
e := aer[0].Events[2] e := aer[0].Events[2]
require.Equal(t, "LastPayment", e.Name) require.Equal(t, "LastPayment", e.Name)
arr := e.Item.Value().([]stackitem.Item) arr := e.Item.Value().([]stackitem.Item)
require.Equal(t, neoOwner.BytesBE(), arr[0].Value()) require.Equal(t, bc.contracts.NEO.Hash.BytesBE(), arr[0].Value())
require.Equal(t, big.NewInt(amount), arr[1].Value()) require.Equal(t, neoOwner.BytesBE(), arr[1].Value())
require.Equal(t, big.NewInt(amount), arr[2].Value())
tx = transferTokenFromMultisigAccount(t, bc, cs.Hash, bc.contracts.NEO.Hash, amount)
aer, err = bc.GetAppExecResults(tx.Hash(), trigger.Application)
require.NoError(t, err)
require.Equal(t, vm.HaltState, aer[0].VMState)
// Now we must also have GAS claim for contract and corresponding `onPayment`.
require.Len(t, aer[0].Events, 5)
e = aer[0].Events[2] // onPayment for GAS claim
require.Equal(t, "LastPayment", e.Name)
arr = e.Item.Value().([]stackitem.Item)
require.Equal(t, stackitem.Null{}, arr[1])
require.Equal(t, bc.contracts.GAS.Hash.BytesBE(), arr[0].Value())
e = aer[0].Events[4] // onPayment for NEO transfer
require.Equal(t, "LastPayment", e.Name)
arr = e.Item.Value().([]stackitem.Item)
require.Equal(t, bc.contracts.NEO.Hash.BytesBE(), arr[0].Value())
require.Equal(t, neoOwner.BytesBE(), arr[1].Value())
require.Equal(t, big.NewInt(amount), arr[2].Value())
} }

View file

@ -48,14 +48,6 @@ type Context struct {
// Call flags this context was created with. // Call flags this context was created with.
callFlag smartcontract.CallFlag callFlag smartcontract.CallFlag
// InvocationState contains expected return type and actions to be performed on context unload.
InvocationState
}
// InvocationState contains return convention and callback to be executed on context unload.
type InvocationState struct {
// Callback is executed on context unload.
Callback func(ctx *Context)
// CheckReturn specifies if amount of return values needs to be checked. // CheckReturn specifies if amount of return values needs to be checked.
CheckReturn CheckReturnState CheckReturn CheckReturnState
} }

View file

@ -284,6 +284,7 @@ func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) {
ctx.tryStack = NewStack("exception") ctx.tryStack = NewStack("exception")
ctx.callFlag = f ctx.callFlag = f
ctx.static = newSlot(v.refs) ctx.static = newSlot(v.refs)
ctx.callingScriptHash = v.GetCurrentScriptHash()
v.istack.PushVal(ctx) v.istack.PushVal(ctx)
} }
@ -295,11 +296,17 @@ func (v *VM) LoadScriptWithFlags(b []byte, f smartcontract.CallFlag) {
// each other. // each other.
func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) { func (v *VM) LoadScriptWithHash(b []byte, hash util.Uint160, f smartcontract.CallFlag) {
shash := v.GetCurrentScriptHash() shash := v.GetCurrentScriptHash()
v.LoadScriptWithCallingHash(shash, b, hash, f)
}
// LoadScriptWithCallingHash is similar to LoadScriptWithHash but sets calling hash explicitly.
// It should be used for calling from native contracts.
func (v *VM) LoadScriptWithCallingHash(caller util.Uint160, b []byte, hash util.Uint160, f smartcontract.CallFlag) {
v.LoadScriptWithFlags(b, f) v.LoadScriptWithFlags(b, f)
ctx := v.Context() ctx := v.Context()
ctx.isDeployed = true ctx.isDeployed = true
ctx.scriptHash = hash ctx.scriptHash = hash
ctx.callingScriptHash = shash ctx.callingScriptHash = caller
} }
// Context returns the current executed context. Nil if there is no context, // Context returns the current executed context. Nil if there is no context,
@ -1418,9 +1425,6 @@ func (v *VM) unloadContext(ctx *Context) {
if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static {
ctx.static.Clear() ctx.static.Clear()
} }
if ctx.Callback != nil {
ctx.Callback(ctx)
}
switch ctx.CheckReturn { switch ctx.CheckReturn {
case NoCheck: case NoCheck:
case EnsureIsEmpty: case EnsureIsEmpty: