diff --git a/pkg/core/native/oracle.go b/pkg/core/native/oracle.go index dc3e57fed..313a78805 100644 --- a/pkg/core/native/oracle.go +++ b/pkg/core/native/oracle.go @@ -265,9 +265,15 @@ func (o *Oracle) request(ic *interop.Context, args []stackitem.Item) stackitem.I if err != nil { panic(err) } - filter, err := stackitem.ToString(args[1]) - if err != nil { - panic(err) + var filter *string + _, ok := args[1].(stackitem.Null) + if !ok { + // Check UTF-8 validity. + str, err := stackitem.ToString(args[1]) + if err != nil { + panic(err) + } + filter = &str } cb, err := stackitem.ToString(args[2]) if err != nil { @@ -285,8 +291,8 @@ func (o *Oracle) request(ic *interop.Context, args []stackitem.Item) stackitem.I } // RequestInternal processes oracle request. -func (o *Oracle) RequestInternal(ic *interop.Context, url, filter, cb string, userData stackitem.Item, gas *big.Int) error { - if len(url) > maxURLLength || len(filter) > maxFilterLength || len(cb) > maxCallbackLength || gas.Uint64() < 1000_0000 { +func (o *Oracle) RequestInternal(ic *interop.Context, url string, filter *string, cb string, userData stackitem.Item, gas *big.Int) error { + if len(url) > maxURLLength || (filter != nil && len(*filter) > maxFilterLength) || len(cb) > maxCallbackLength || gas.Uint64() < 1000_0000 { return ErrBigArgument } @@ -317,6 +323,12 @@ func (o *Oracle) RequestInternal(ic *interop.Context, url, filter, cb string, us return ErrBigArgument } + var filterNotif stackitem.Item + if filter != nil { + filterNotif = stackitem.Make(*filter) + } else { + filterNotif = stackitem.Null{} + } ic.Notifications = append(ic.Notifications, state.NotificationEvent{ ScriptHash: o.Hash, Name: "OracleRequest", @@ -324,14 +336,14 @@ func (o *Oracle) RequestInternal(ic *interop.Context, url, filter, cb string, us stackitem.Make(id), stackitem.Make(ic.VM.GetCallingScriptHash().BytesBE()), stackitem.Make(url), - stackitem.Make(filter), + filterNotif, }), }) req := &state.OracleRequest{ OriginalTxID: o.getOriginalTxID(ic.DAO, ic.Tx), GasForResponse: gas.Uint64(), URL: url, - Filter: &filter, + Filter: filter, CallbackContract: ic.VM.GetCallingScriptHash(), CallbackMethod: cb, UserData: data, diff --git a/pkg/core/native_oracle_test.go b/pkg/core/native_oracle_test.go index 0c2aca4a5..2d603e2b2 100644 --- a/pkg/core/native_oracle_test.go +++ b/pkg/core/native_oracle_test.go @@ -18,6 +18,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -88,10 +89,14 @@ func getOracleContractState(h util.Uint160) *state.Contract { } func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain, - url, filter string, userData []byte, gas int64) util.Uint256 { + url string, filter *string, userData []byte, gas int64) util.Uint256 { w := io.NewBufBinWriter() + var filtItem interface{} + if filter != nil { + filtItem = *filter + } emit.AppCallWithOperationAndArgs(w.BinWriter, h, "requestURL", - url, filter, "handle", userData, gas) + url, filtItem, "handle", userData, gas) require.NoError(t, w.Err) gas += 50_000_000 + 5_000_000 // request + contract call with args @@ -113,15 +118,16 @@ func TestOracle_Request(t *testing.T) { require.NoError(t, bc.dao.PutContractState(cs)) gasForResponse := int64(2000_1234) + var filter = "flt" userData := []byte("custom info") - txHash := putOracleRequest(t, cs.ScriptHash(), bc, "url", "flt", userData, gasForResponse) + txHash := putOracleRequest(t, cs.ScriptHash(), bc, "url", &filter, userData, gasForResponse) req, err := orc.GetRequestInternal(bc.dao, 1) require.NotNil(t, req) require.NoError(t, err) require.Equal(t, txHash, req.OriginalTxID) require.Equal(t, "url", req.URL) - require.Equal(t, "flt", *req.Filter) + require.Equal(t, filter, *req.Filter) require.Equal(t, cs.ScriptHash(), req.CallbackContract) require.Equal(t, "handle", req.CallbackMethod) require.Equal(t, uint64(gasForResponse), req.GasForResponse) @@ -192,7 +198,7 @@ func TestOracle_Request(t *testing.T) { t.Run("ErrorOnFinish", func(t *testing.T) { const reqID = 2 - putOracleRequest(t, cs.ScriptHash(), bc, "url", "flt", []byte{1, 2}, gasForResponse) + putOracleRequest(t, cs.ScriptHash(), bc, "url", nil, []byte{1, 2}, gasForResponse) _, err := orc.GetRequestInternal(bc.dao, reqID) // ensure ID is 2 require.NoError(t, err) @@ -215,4 +221,23 @@ func TestOracle_Request(t *testing.T) { _, err = orc.GetRequestInternal(ic.DAO, reqID) require.Error(t, err) }) + t.Run("BadRequest", func(t *testing.T) { + var doBadRequest = func(t *testing.T, h util.Uint160, url string, filter *string, userData []byte, gas int64) { + txHash := putOracleRequest(t, h, bc, url, filter, userData, gas) + aer, err := bc.GetAppExecResults(txHash, trigger.Application) + require.NoError(t, err) + require.Equal(t, 1, len(aer)) + require.Equal(t, vm.FaultState, aer[0].VMState) + } + t.Run("non-UTF8 url", func(t *testing.T) { + doBadRequest(t, cs.ScriptHash(), "\xff", nil, []byte{1, 2}, gasForResponse) + }) + t.Run("non-UTF8 filter", func(t *testing.T) { + var f = "\xff" + doBadRequest(t, cs.ScriptHash(), "url", &f, []byte{1, 2}, gasForResponse) + }) + t.Run("not enough gas", func(t *testing.T) { + doBadRequest(t, cs.ScriptHash(), "url", nil, nil, 1000) + }) + }) }