diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 82b26fbd7..bf7f48cac 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -2450,8 +2450,8 @@ func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([ } // GetTestVM returns a VM and a Store setup for a test run of some sort of code. -func (bc *Blockchain) GetTestVM() *vm.VM { - systemInterop := bc.newInteropContext(trigger.Application, bc.dao, nil, nil) +func (bc *Blockchain) GetTestVM(tx *transaction.Transaction) *vm.VM { + systemInterop := bc.newInteropContext(trigger.Application, bc.dao, nil, tx) vm := systemInterop.SpawnVM() vm.SetPriceGetter(getPrice) return vm diff --git a/pkg/core/blockchainer.go b/pkg/core/blockchainer.go index 03f0bd4e3..8f8bc0729 100644 --- a/pkg/core/blockchainer.go +++ b/pkg/core/blockchainer.go @@ -45,7 +45,7 @@ type Blockchainer interface { GetStateRoot(height uint32) (*state.MPTRootState, error) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem GetStorageItems(hash util.Uint160) (map[string]*state.StorageItem, error) - GetTestVM() *vm.VM + GetTestVM(tx *transaction.Transaction) *vm.VM GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) GetUnspentCoinState(util.Uint256) *state.UnspentCoin References(t *transaction.Transaction) ([]transaction.InOut, error) diff --git a/pkg/core/interop_neo.go b/pkg/core/interop_neo.go index 5f4d56b7e..12d653d0c 100644 --- a/pkg/core/interop_neo.go +++ b/pkg/core/interop_neo.go @@ -602,7 +602,7 @@ func (ic *interopContext) contractMigrate(v *vm.VM) error { ic.dao.MigrateNEP5Balances(hash, contract.ScriptHash()) // save NEP5 metadata if any - v := ic.bc.GetTestVM() + v := ic.bc.GetTestVM(nil) w := io.NewBufBinWriter() emit.AppCallWithOperationAndArgs(w.BinWriter, hash, "decimals") v.SetGasLimit(ic.bc.GetConfig().FreeGasLimit) diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 93aa36ccf..d35b222e5 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -123,7 +123,7 @@ func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) func (chain testChain) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem { panic("TODO") } -func (chain testChain) GetTestVM() *vm.VM { +func (chain testChain) GetTestVM(tx *transaction.Transaction) *vm.VM { panic("TODO") } func (chain testChain) GetStorageItems(hash util.Uint160) (map[string]*state.StorageItem, error) { diff --git a/pkg/rpc/request/param.go b/pkg/rpc/request/param.go index a6a254dd0..cd78a1078 100644 --- a/pkg/rpc/request/param.go +++ b/pkg/rpc/request/param.go @@ -164,6 +164,27 @@ func (p *Param) GetUint160FromAddressOrHex() (util.Uint160, error) { return p.GetUint160FromAddress() } +// GetArrayUint160FromHex returns array of Uint160 values of the parameter that +// was supply as array of raw hex. +func (p *Param) GetArrayUint160FromHex() ([]util.Uint160, error) { + if p == nil { + return nil, nil + } + arr, err := p.GetArray() + if err != nil { + return nil, err + } + var result = make([]util.Uint160, len(arr)) + for i, parameter := range arr { + hash, err := parameter.GetUint160FromHex() + if err != nil { + return nil, err + } + result[i] = hash + } + return result, nil +} + // GetFuncParam returns current parameter as a function call parameter. func (p *Param) GetFuncParam() (FuncParam, error) { if p == nil { diff --git a/pkg/rpc/request/param_test.go b/pkg/rpc/request/param_test.go index 0f51bc1db..feaee36a7 100644 --- a/pkg/rpc/request/param_test.go +++ b/pkg/rpc/request/param_test.go @@ -185,6 +185,24 @@ func TestParam_GetUint160FromAddressOrHex(t *testing.T) { }) } +func TestParam_GetArrayUint160FromHex(t *testing.T) { + in1 := util.Uint160{1, 2, 3} + in2 := util.Uint160{4, 5, 6} + p := Param{Type: ArrayT, Value: []Param{ + { + Type: StringT, + Value: in1.StringLE(), + }, + { + Type: StringT, + Value: in2.StringLE(), + }, + }} + arr, err := p.GetArrayUint160FromHex() + require.NoError(t, err) + require.Equal(t, []util.Uint160{in1, in2}, arr) +} + func TestParamGetFuncParam(t *testing.T) { fp := FuncParam{ Type: smartcontract.StringType, diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 07318983c..1a633045f 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -1075,11 +1075,15 @@ func (s *Server) invoke(reqParams request.Params) (interface{}, *response.Error) if err != nil { return nil, response.ErrInvalidParams } + hashesForVerifying, err := reqParams.ValueWithType(2, request.ArrayT).GetArrayUint160FromHex() + if err != nil { + return nil, response.ErrInvalidParams + } script, err := request.CreateInvocationScript(scriptHash, slice) if err != nil { return nil, response.NewInternalServerError("can't create invocation script", err) } - return s.runScriptInVM(script), nil + return s.runScriptInVM(script, hashesForVerifying), nil } // invokescript implements the `invokescript` RPC call. @@ -1088,11 +1092,20 @@ func (s *Server) invokeFunction(reqParams request.Params) (interface{}, *respons if err != nil { return nil, response.ErrInvalidParams } - script, err := request.CreateFunctionInvocationScript(scriptHash, reqParams[1:]) + var hashesForVerifying []util.Uint160 + hashesForVerifyingIndex := len(reqParams) + if hashesForVerifyingIndex > 3 { + hashesForVerifying, err = reqParams.ValueWithType(3, request.ArrayT).GetArrayUint160FromHex() + if err != nil { + return nil, response.ErrInvalidParams + } + hashesForVerifyingIndex-- + } + script, err := request.CreateFunctionInvocationScript(scriptHash, reqParams[1:hashesForVerifyingIndex]) if err != nil { return nil, response.NewInternalServerError("can't create invocation script", err) } - return s.runScriptInVM(script), nil + return s.runScriptInVM(script, hashesForVerifying), nil } // invokescript implements the `invokescript` RPC call. @@ -1106,13 +1119,27 @@ func (s *Server) invokescript(reqParams request.Params) (interface{}, *response. return nil, response.ErrInvalidParams } - return s.runScriptInVM(script), nil + hashesForVerifying, err := reqParams.ValueWithType(1, request.ArrayT).GetArrayUint160FromHex() + if err != nil { + return nil, response.ErrInvalidParams + } + + return s.runScriptInVM(script, hashesForVerifying), nil } // runScriptInVM runs given script in a new test VM and returns the invocation // result. -func (s *Server) runScriptInVM(script []byte) *result.Invoke { - vm := s.chain.GetTestVM() +func (s *Server) runScriptInVM(script []byte, scriptHashesForVerifying []util.Uint160) *result.Invoke { + var tx *transaction.Transaction + if count := len(scriptHashesForVerifying); count != 0 { + tx := new(transaction.Transaction) + tx.Attributes = make([]transaction.Attribute, count) + for i, a := range tx.Attributes { + a.Data = scriptHashesForVerifying[i].BytesBE() + a.Usage = transaction.Script + } + } + vm := s.chain.GetTestVM(tx) vm.SetGasLimit(s.config.MaxGasInvoke) vm.LoadScript(script) _ = vm.Run()