diff --git a/pkg/compiler/interop_test.go b/pkg/compiler/interop_test.go index 120e29709..d737a3058 100644 --- a/pkg/compiler/interop_test.go +++ b/pkg/compiler/interop_test.go @@ -202,7 +202,7 @@ func TestAppCall(t *testing.T) { fc := fakechain.NewFakeChain() ic := interop.NewContext(trigger.Application, fc, dao.NewSimple(storage.NewMemoryStore(), false, false), - interop.DefaultBaseExecFee, native.DefaultStoragePrice, contractGetter, nil, nil, nil, zaptest.NewLogger(t)) + interop.DefaultBaseExecFee, native.DefaultStoragePrice, contractGetter, nil, nil, nil, nil, zaptest.NewLogger(t)) t.Run("valid script", func(t *testing.T) { src := getAppCallScript(fmt.Sprintf("%#v", ih.BytesBE())) diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 504bebd1e..94b55283a 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1110,7 +1110,6 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error v := systemInterop.SpawnVM() v.LoadScriptWithFlags(tx.Script, callflag.All) v.SetPriceGetter(systemInterop.GetPrice) - v.LoadToken = contract.LoadToken(systemInterop) v.GasLimit = tx.SystemFee err := systemInterop.Exec() @@ -2169,7 +2168,6 @@ func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b * systemInterop := bc.newInteropContext(t, bc.dao, b, tx) vm := systemInterop.SpawnVM() vm.SetPriceGetter(systemInterop.GetPrice) - vm.LoadToken = contract.LoadToken(systemInterop) return systemInterop } @@ -2204,7 +2202,6 @@ func (bc *Blockchain) GetTestHistoricVM(t trigger.Type, tx *transaction.Transact systemInterop := bc.newInteropContext(t, dTrie, b, tx) vm := systemInterop.SpawnVM() vm.SetPriceGetter(systemInterop.GetPrice) - vm.LoadToken = contract.LoadToken(systemInterop) return systemInterop, nil } @@ -2280,7 +2277,6 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa vm := interopCtx.SpawnVM() vm.SetPriceGetter(interopCtx.GetPrice) - vm.LoadToken = contract.LoadToken(interopCtx) vm.GasLimit = gas if err := bc.InitVerificationContext(interopCtx, hash, witness); err != nil { return 0, err @@ -2376,7 +2372,7 @@ func (bc *Blockchain) newInteropContext(trigger trigger.Type, d *dao.Simple, blo // changes that were not yet persisted to Blockchain's dao. baseStorageFee = bc.contracts.Policy.GetStoragePriceInternal(d) } - ic := interop.NewContext(trigger, bc, d, baseExecFee, baseStorageFee, bc.contracts.Management.GetContract, bc.contracts.Contracts, block, tx, bc.log) + ic := interop.NewContext(trigger, bc, d, baseExecFee, baseStorageFee, bc.contracts.Management.GetContract, bc.contracts.Contracts, contract.LoadToken, block, tx, bc.log) ic.Functions = systemInterops switch { case tx != nil: diff --git a/pkg/core/interop/context.go b/pkg/core/interop/context.go index 0c7419284..0d78ce900 100644 --- a/pkg/core/interop/context.go +++ b/pkg/core/interop/context.go @@ -64,6 +64,7 @@ type Context struct { getContract func(*dao.Simple, util.Uint160) (*state.Contract, error) baseExecFee int64 baseStorageFee int64 + loadToken func(ic *Context, id int32) error GetRandomCounter uint32 signers []transaction.Signer } @@ -71,6 +72,7 @@ type Context struct { // NewContext returns new interop context. func NewContext(trigger trigger.Type, bc Ledger, d *dao.Simple, baseExecFee, baseStorageFee int64, getContract func(*dao.Simple, util.Uint160) (*state.Contract, error), natives []Contract, + loadTokenFunc func(ic *Context, id int32) error, block *block.Block, tx *transaction.Transaction, log *zap.Logger) *Context { dao := d.GetPrivate() cfg := bc.GetConfig() @@ -88,6 +90,7 @@ func NewContext(trigger trigger.Type, bc Ledger, d *dao.Simple, baseExecFee, bas getContract: getContract, baseExecFee: baseExecFee, baseStorageFee: baseStorageFee, + loadToken: loadTokenFunc, } } @@ -298,6 +301,12 @@ func (ic *Context) BaseStorageFee() int64 { return ic.baseStorageFee } +// LoadToken wraps externally provided load-token loading function providing it with context, +// this function can then be easily used by VM. +func (ic *Context) LoadToken(id int32) error { + return ic.loadToken(ic, id) +} + // SyscallHandler handles syscall with id. func (ic *Context) SyscallHandler(_ *vm.VM, id uint32) error { f := ic.GetFunction(id) @@ -317,6 +326,7 @@ func (ic *Context) SyscallHandler(_ *vm.VM, id uint32) error { // SpawnVM spawns a new VM with the specified gas limit and set context.VM field. func (ic *Context) SpawnVM() *vm.VM { v := vm.NewWithTrigger(ic.Trigger) + v.LoadToken = ic.LoadToken v.GasLimit = -1 v.SyscallHandler = ic.SyscallHandler ic.VM = v diff --git a/pkg/core/interop/contract/call.go b/pkg/core/interop/contract/call.go index a322bcc7c..26a23aaa1 100644 --- a/pkg/core/interop/contract/call.go +++ b/pkg/core/interop/contract/call.go @@ -22,26 +22,24 @@ type policyChecker interface { } // LoadToken calls method specified by the token id. -func LoadToken(ic *interop.Context) func(id int32) error { - return func(id int32) error { - ctx := ic.VM.Context() - if !ctx.GetCallFlags().Has(callflag.ReadStates | callflag.AllowCall) { - return errors.New("invalid call flags") - } - tok := ctx.NEF.Tokens[id] - if int(tok.ParamCount) > ctx.Estack().Len() { - return errors.New("stack is too small") - } - args := make([]stackitem.Item, tok.ParamCount) - for i := range args { - args[i] = ic.VM.Estack().Pop().Item() - } - cs, err := ic.GetContract(tok.Hash) - if err != nil { - return fmt.Errorf("token contract %s not found: %w", tok.Hash.StringLE(), err) - } - return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, false) +func LoadToken(ic *interop.Context, id int32) error { + ctx := ic.VM.Context() + if !ctx.GetCallFlags().Has(callflag.ReadStates | callflag.AllowCall) { + return errors.New("invalid call flags") } + tok := ctx.NEF.Tokens[id] + if int(tok.ParamCount) > ctx.Estack().Len() { + return errors.New("stack is too small") + } + args := make([]stackitem.Item, tok.ParamCount) + for i := range args { + args[i] = ic.VM.Estack().Pop().Item() + } + cs, err := ic.GetContract(tok.Hash) + if err != nil { + return fmt.Errorf("token contract %s not found: %w", tok.Hash.StringLE(), err) + } + return callInternal(ic, cs, tok.Method, tok.CallFlag, tok.HasReturn, args, false) } // Call calls a contract with flags. diff --git a/pkg/core/interop/crypto/ecdsa_test.go b/pkg/core/interop/crypto/ecdsa_test.go index 9cfb370e1..fc6f70bfd 100644 --- a/pkg/core/interop/crypto/ecdsa_test.go +++ b/pkg/core/interop/crypto/ecdsa_test.go @@ -73,7 +73,7 @@ func initCheckMultisigVMNoArgs(container *transaction.Transaction) *vm.VM { trigger.Verification, fakechain.NewFakeChain(), dao.NewSimple(storage.NewMemoryStore(), false, false), - interop.DefaultBaseExecFee, native.DefaultStoragePrice, nil, nil, nil, + interop.DefaultBaseExecFee, native.DefaultStoragePrice, nil, nil, nil, nil, container, nil) ic.Container = container