native: allow NULL filter in oracle requests

Follow neo-project/neo#2067
This commit is contained in:
Roman Khimov 2020-11-18 23:59:13 +03:00
parent 21317c25cf
commit eef921b8e0
2 changed files with 49 additions and 12 deletions

View file

@ -265,9 +265,15 @@ func (o *Oracle) request(ic *interop.Context, args []stackitem.Item) stackitem.I
if err != nil { if err != nil {
panic(err) panic(err)
} }
filter, err := stackitem.ToString(args[1]) var filter *string
if err != nil { _, ok := args[1].(stackitem.Null)
panic(err) if !ok {
// Check UTF-8 validity.
str, err := stackitem.ToString(args[1])
if err != nil {
panic(err)
}
filter = &str
} }
cb, err := stackitem.ToString(args[2]) cb, err := stackitem.ToString(args[2])
if err != nil { if err != nil {
@ -285,8 +291,8 @@ func (o *Oracle) request(ic *interop.Context, args []stackitem.Item) stackitem.I
} }
// RequestInternal processes oracle request. // RequestInternal processes oracle request.
func (o *Oracle) RequestInternal(ic *interop.Context, url, filter, cb string, userData stackitem.Item, gas *big.Int) error { func (o *Oracle) RequestInternal(ic *interop.Context, url string, filter *string, cb string, userData stackitem.Item, gas *big.Int) error {
if len(url) > maxURLLength || len(filter) > maxFilterLength || len(cb) > maxCallbackLength || gas.Uint64() < 1000_0000 { if len(url) > maxURLLength || (filter != nil && len(*filter) > maxFilterLength) || len(cb) > maxCallbackLength || gas.Uint64() < 1000_0000 {
return ErrBigArgument return ErrBigArgument
} }
@ -317,6 +323,12 @@ func (o *Oracle) RequestInternal(ic *interop.Context, url, filter, cb string, us
return ErrBigArgument return ErrBigArgument
} }
var filterNotif stackitem.Item
if filter != nil {
filterNotif = stackitem.Make(*filter)
} else {
filterNotif = stackitem.Null{}
}
ic.Notifications = append(ic.Notifications, state.NotificationEvent{ ic.Notifications = append(ic.Notifications, state.NotificationEvent{
ScriptHash: o.Hash, ScriptHash: o.Hash,
Name: "OracleRequest", Name: "OracleRequest",
@ -324,14 +336,14 @@ func (o *Oracle) RequestInternal(ic *interop.Context, url, filter, cb string, us
stackitem.Make(id), stackitem.Make(id),
stackitem.Make(ic.VM.GetCallingScriptHash().BytesBE()), stackitem.Make(ic.VM.GetCallingScriptHash().BytesBE()),
stackitem.Make(url), stackitem.Make(url),
stackitem.Make(filter), filterNotif,
}), }),
}) })
req := &state.OracleRequest{ req := &state.OracleRequest{
OriginalTxID: o.getOriginalTxID(ic.DAO, ic.Tx), OriginalTxID: o.getOriginalTxID(ic.DAO, ic.Tx),
GasForResponse: gas.Uint64(), GasForResponse: gas.Uint64(),
URL: url, URL: url,
Filter: &filter, Filter: filter,
CallbackContract: ic.VM.GetCallingScriptHash(), CallbackContract: ic.VM.GetCallingScriptHash(),
CallbackMethod: cb, CallbackMethod: cb,
UserData: data, UserData: data,

View file

@ -18,6 +18,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -88,10 +89,14 @@ func getOracleContractState(h util.Uint160) *state.Contract {
} }
func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain, func putOracleRequest(t *testing.T, h util.Uint160, bc *Blockchain,
url, filter string, userData []byte, gas int64) util.Uint256 { url string, filter *string, userData []byte, gas int64) util.Uint256 {
w := io.NewBufBinWriter() w := io.NewBufBinWriter()
var filtItem interface{}
if filter != nil {
filtItem = *filter
}
emit.AppCallWithOperationAndArgs(w.BinWriter, h, "requestURL", emit.AppCallWithOperationAndArgs(w.BinWriter, h, "requestURL",
url, filter, "handle", userData, gas) url, filtItem, "handle", userData, gas)
require.NoError(t, w.Err) require.NoError(t, w.Err)
gas += 50_000_000 + 5_000_000 // request + contract call with args gas += 50_000_000 + 5_000_000 // request + contract call with args
@ -113,15 +118,16 @@ func TestOracle_Request(t *testing.T) {
require.NoError(t, bc.dao.PutContractState(cs)) require.NoError(t, bc.dao.PutContractState(cs))
gasForResponse := int64(2000_1234) gasForResponse := int64(2000_1234)
var filter = "flt"
userData := []byte("custom info") userData := []byte("custom info")
txHash := putOracleRequest(t, cs.ScriptHash(), bc, "url", "flt", userData, gasForResponse) txHash := putOracleRequest(t, cs.ScriptHash(), bc, "url", &filter, userData, gasForResponse)
req, err := orc.GetRequestInternal(bc.dao, 1) req, err := orc.GetRequestInternal(bc.dao, 1)
require.NotNil(t, req) require.NotNil(t, req)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, txHash, req.OriginalTxID) require.Equal(t, txHash, req.OriginalTxID)
require.Equal(t, "url", req.URL) require.Equal(t, "url", req.URL)
require.Equal(t, "flt", *req.Filter) require.Equal(t, filter, *req.Filter)
require.Equal(t, cs.ScriptHash(), req.CallbackContract) require.Equal(t, cs.ScriptHash(), req.CallbackContract)
require.Equal(t, "handle", req.CallbackMethod) require.Equal(t, "handle", req.CallbackMethod)
require.Equal(t, uint64(gasForResponse), req.GasForResponse) require.Equal(t, uint64(gasForResponse), req.GasForResponse)
@ -192,7 +198,7 @@ func TestOracle_Request(t *testing.T) {
t.Run("ErrorOnFinish", func(t *testing.T) { t.Run("ErrorOnFinish", func(t *testing.T) {
const reqID = 2 const reqID = 2
putOracleRequest(t, cs.ScriptHash(), bc, "url", "flt", []byte{1, 2}, gasForResponse) putOracleRequest(t, cs.ScriptHash(), bc, "url", nil, []byte{1, 2}, gasForResponse)
_, err := orc.GetRequestInternal(bc.dao, reqID) // ensure ID is 2 _, err := orc.GetRequestInternal(bc.dao, reqID) // ensure ID is 2
require.NoError(t, err) require.NoError(t, err)
@ -215,4 +221,23 @@ func TestOracle_Request(t *testing.T) {
_, err = orc.GetRequestInternal(ic.DAO, reqID) _, err = orc.GetRequestInternal(ic.DAO, reqID)
require.Error(t, err) require.Error(t, err)
}) })
t.Run("BadRequest", func(t *testing.T) {
var doBadRequest = func(t *testing.T, h util.Uint160, url string, filter *string, userData []byte, gas int64) {
txHash := putOracleRequest(t, h, bc, url, filter, userData, gas)
aer, err := bc.GetAppExecResults(txHash, trigger.Application)
require.NoError(t, err)
require.Equal(t, 1, len(aer))
require.Equal(t, vm.FaultState, aer[0].VMState)
}
t.Run("non-UTF8 url", func(t *testing.T) {
doBadRequest(t, cs.ScriptHash(), "\xff", nil, []byte{1, 2}, gasForResponse)
})
t.Run("non-UTF8 filter", func(t *testing.T) {
var f = "\xff"
doBadRequest(t, cs.ScriptHash(), "url", &f, []byte{1, 2}, gasForResponse)
})
t.Run("not enough gas", func(t *testing.T) {
doBadRequest(t, cs.ScriptHash(), "url", nil, nil, 1000)
})
})
} }